mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
1 Commits
fix/copilo
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b410bdd6e0 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -195,3 +195,4 @@ test.db
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
.claude/worktrees/
|
||||
test-results/
|
||||
|
||||
@@ -18,6 +18,7 @@ from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.message_dedup import acquire_dedup_lock
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -190,8 +191,6 @@ class SessionDetailResponse(BaseModel):
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
has_more_messages: bool = False
|
||||
oldest_sequence: int | None = None
|
||||
newest_sequence: int | None = None
|
||||
forward_paginated: bool = False
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
@@ -456,113 +455,52 @@ async def update_session_title_route(
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(
|
||||
default=50,
|
||||
ge=1,
|
||||
le=200,
|
||||
description="Maximum number of messages to return.",
|
||||
),
|
||||
before_sequence: int | None = Query(
|
||||
default=None,
|
||||
ge=0,
|
||||
description=(
|
||||
"Backward pagination cursor. Return messages with sequence number "
|
||||
"strictly less than this value. Used by active-session load-more. "
|
||||
"Mutually exclusive with after_sequence."
|
||||
),
|
||||
),
|
||||
after_sequence: int | None = Query(
|
||||
default=None,
|
||||
ge=0,
|
||||
description=(
|
||||
"Forward pagination cursor. Return messages with sequence number "
|
||||
"strictly greater than this value. Used by completed-session load-more. "
|
||||
"Mutually exclusive with before_sequence."
|
||||
),
|
||||
),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
before_sequence: int | None = Query(default=None, ge=0),
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Supports cursor-based pagination via ``limit``, ``before_sequence``, and
|
||||
``after_sequence``. The two cursor parameters are mutually exclusive.
|
||||
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
|
||||
When no pagination params are provided, returns the most recent messages.
|
||||
|
||||
On the initial load (no cursor provided) of a completed session, messages
|
||||
are returned in forward order starting from sequence 0 so the user always
|
||||
sees their initial prompt. Active sessions use the legacy newest-first
|
||||
order so streaming context is preserved.
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of messages to return (1-200, default 50).
|
||||
before_sequence: Return messages with sequence < this value (cursor).
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, including
|
||||
active_stream info and pagination metadata.
|
||||
"""
|
||||
if before_sequence is not None and after_sequence is not None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="before_sequence and after_sequence are mutually exclusive",
|
||||
)
|
||||
|
||||
is_initial_load = before_sequence is None and after_sequence is None
|
||||
|
||||
# Check active stream before the DB query on initial loads so we can
|
||||
# choose the correct pagination direction (forward for completed sessions,
|
||||
# newest-first for active ones).
|
||||
active_session = None
|
||||
last_message_id = None
|
||||
if is_initial_load:
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
# Completed sessions on initial load start from sequence 0 so the user's
|
||||
# initial prompt is always visible. Active sessions keep the legacy
|
||||
# newest-first behavior to preserve streaming context.
|
||||
from_start = is_initial_load and active_session is None
|
||||
forward_paginated = from_start or after_sequence is not None
|
||||
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id,
|
||||
limit,
|
||||
before_sequence=before_sequence,
|
||||
after_sequence=after_sequence,
|
||||
from_start=from_start,
|
||||
user_id=user_id,
|
||||
session_id, limit, before_sequence, user_id=user_id
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
# Close the TOCTOU window: if the session was active at pre-check, re-verify
|
||||
# after the DB fetch. The session may have completed between the two awaits,
|
||||
# which would have caused messages to be fetched newest-first even though the
|
||||
# session is now complete. Re-fetch from seq 0 so the initial prompt is
|
||||
# always visible.
|
||||
if is_initial_load and active_session is not None:
|
||||
post_active, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||
if post_active is None:
|
||||
active_session = None
|
||||
last_message_id = None
|
||||
from_start = True
|
||||
forward_paginated = True
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id,
|
||||
limit,
|
||||
before_sequence=None,
|
||||
after_sequence=None,
|
||||
from_start=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [
|
||||
_strip_injected_context(message.model_dump()) for message in page.messages
|
||||
]
|
||||
|
||||
# Only check active stream on initial load (not on "load more" requests)
|
||||
active_stream_info = None
|
||||
if active_session and last_message_id is not None:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
if before_sequence is None:
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Skip session metadata on "load more" — frontend only needs messages
|
||||
if not is_initial_load:
|
||||
if before_sequence is not None:
|
||||
return SessionDetailResponse(
|
||||
id=page.session.session_id,
|
||||
created_at=page.session.started_at.isoformat(),
|
||||
@@ -572,8 +510,6 @@ async def get_session(
|
||||
active_stream=None,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
newest_sequence=page.newest_sequence,
|
||||
forward_paginated=forward_paginated,
|
||||
total_prompt_tokens=0,
|
||||
total_completion_tokens=0,
|
||||
)
|
||||
@@ -590,8 +526,6 @@ async def get_session(
|
||||
active_stream=active_stream_info,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
newest_sequence=page.newest_sequence,
|
||||
forward_paginated=forward_paginated,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
metadata=page.session.metadata,
|
||||
@@ -912,6 +846,9 @@ async def stream_chat_post(
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
# Capture the original message text BEFORE any mutation (attachment enrichment)
|
||||
# so the idempotency hash is stable across retries.
|
||||
original_message = request.message
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
@@ -940,36 +877,58 @@ async def stream_chat_post(
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message returns None when a duplicate is
|
||||
# detected — in that case skip enqueue to avoid processing the message twice.
|
||||
is_duplicate_message = False
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
# ── Idempotency guard ────────────────────────────────────────────────────
|
||||
# Blocks duplicate executor tasks from concurrent/retried POSTs.
|
||||
# See backend/copilot/message_dedup.py for the full lifecycle description.
|
||||
dedup_lock = None
|
||||
if request.is_user_message:
|
||||
dedup_lock = await acquire_dedup_lock(
|
||||
session_id, original_message, sanitized_file_ids
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
is_duplicate_message = (
|
||||
await append_and_save_message(session_id, message)
|
||||
) is None
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
if not is_duplicate_message and request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
if dedup_lock is None and (original_message or sanitized_file_ids):
|
||||
|
||||
async def _empty_sse() -> AsyncGenerator[str, None]:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_empty_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
# Create a task in the stream registry for reconnection support.
|
||||
# For duplicate messages, skip create_session entirely so the infra-retry
|
||||
# client subscribes to the *existing* turn's Redis stream and receives the
|
||||
# in-progress executor output rather than an empty stream.
|
||||
turn_id = ""
|
||||
if not is_duplicate_message:
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
# message loss from concurrent requests.
|
||||
#
|
||||
# If any of these operations raises, release the dedup lock before propagating
|
||||
# so subsequent retries are not blocked for 30 s.
|
||||
try:
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
await append_and_save_message(session_id, message)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
@@ -987,6 +946,7 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
@@ -998,10 +958,10 @@ async def stream_chat_post(
|
||||
mode=request.mode,
|
||||
model=request.model,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
|
||||
)
|
||||
except Exception:
|
||||
if dedup_lock:
|
||||
await dedup_lock.release()
|
||||
raise
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
@@ -1025,6 +985,12 @@ async def stream_chat_post(
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
# True for every exit path except GeneratorExit (client disconnect).
|
||||
# On disconnect the backend turn is still running — releasing the lock
|
||||
# there would reopen the infra-retry duplicate window. The 30 s TTL
|
||||
# is the fallback. All other exits (normal finish, early return, error)
|
||||
# should release so the user can re-send the same message.
|
||||
release_dedup_lock_on_exit = True
|
||||
try:
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
@@ -1036,7 +1002,7 @@ async def stream_chat_post(
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
return
|
||||
return # finally releases dedup_lock
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
@@ -1078,7 +1044,7 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
break # finally releases dedup_lock
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
@@ -1094,6 +1060,7 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
release_dedup_lock_on_exit = False
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
@@ -1108,7 +1075,10 @@ async def stream_chat_post(
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
# finally releases dedup_lock
|
||||
finally:
|
||||
if dedup_lock and release_dedup_lock_on_exit:
|
||||
await dedup_lock.release()
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
|
||||
@@ -133,12 +133,21 @@ def test_stream_chat_rejects_too_many_file_ids():
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
|
||||
def _mock_stream_internals(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
redis_set_returns: object = True,
|
||||
):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing RabbitMQ.
|
||||
validation and enrichment logic without needing Redis/RabbitMQ.
|
||||
|
||||
Args:
|
||||
redis_set_returns: Value returned by the mocked Redis ``set`` call.
|
||||
``True`` (default) simulates a fresh key (new message);
|
||||
``None`` simulates a collision (duplicate blocked).
|
||||
|
||||
Returns:
|
||||
A namespace with ``save`` and ``enqueue`` mock objects so
|
||||
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
|
||||
callers can make additional assertions about side-effects.
|
||||
"""
|
||||
import types
|
||||
@@ -149,7 +158,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
|
||||
)
|
||||
mock_save = mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
@@ -165,9 +174,15 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
return types.SimpleNamespace(
|
||||
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=redis_set_returns)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
|
||||
return ns
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
@@ -196,29 +211,6 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ─── Duplicate message dedup ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_skips_enqueue_for_duplicate_message(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
):
|
||||
"""When append_and_save_message returns None (duplicate detected),
|
||||
enqueue_copilot_turn and stream_registry.create_session must NOT be called
|
||||
to avoid double-processing and to prevent overwriting the active stream's
|
||||
turn_id in Redis (which would cause reconnecting clients to miss the response)."""
|
||||
mocks = _mock_stream_internals(mocker)
|
||||
# Override save to return None — signalling a duplicate
|
||||
mocks.save.return_value = None
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mocks.enqueue.assert_not_called()
|
||||
mocks.registry.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -714,6 +706,237 @@ class TestStripInjectedContext:
|
||||
assert result["content"] == "hello"
|
||||
|
||||
|
||||
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""A second POST with the same message within the 30-s window must return
|
||||
an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
|
||||
turn complete without creating a ghost response."""
|
||||
# redis_set_returns=None simulates a collision: the NX key already exists.
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=None)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-dup/stream",
|
||||
json={"message": "duplicate message", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.text
|
||||
# The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
|
||||
assert '"finish"' in body
|
||||
assert "[DONE]" in body
|
||||
# The empty SSE response must include the AI SDK protocol header so the
|
||||
# frontend treats it as a valid stream and marks the turn complete.
|
||||
assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
|
||||
# The duplicate guard must prevent save/enqueue side effects.
|
||||
ns.save.assert_not_called()
|
||||
ns.enqueue.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_first_post_proceeds_normally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The first POST (Redis NX key set successfully) must proceed through the
|
||||
normal streaming path — no early return."""
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=True)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-new/stream",
|
||||
json={"message": "first message", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Redis set must have been called once with the NX flag.
|
||||
ns.redis.set.assert_called_once()
|
||||
call_kwargs = ns.redis.set.call_args
|
||||
assert call_kwargs.kwargs.get("nx") is True
|
||||
|
||||
|
||||
def test_stream_chat_dedup_skipped_for_non_user_messages(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""System/assistant messages (is_user_message=False) bypass the dedup
|
||||
guard — they are injected programmatically and must always be processed."""
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=None)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-sys/stream",
|
||||
json={"message": "system context", "is_user_message": False},
|
||||
)
|
||||
|
||||
# Even though redis_set_returns=None (would block a user message),
|
||||
# the endpoint must proceed because is_user_message=False.
|
||||
assert response.status_code == 200
|
||||
ns.redis.set.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The dedup hash must be computed from the original request message,
|
||||
not the mutated version that has the [Attached files] block appended.
|
||||
A file_id is sent so the route actually appends the [Attached files] block,
|
||||
exercising the mutation path — the hash must still match the original text."""
|
||||
import hashlib
|
||||
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=True)
|
||||
|
||||
file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
# Mock workspace + prisma so the attachment block is actually appended.
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
fake_file = type(
|
||||
"F",
|
||||
(),
|
||||
{
|
||||
"id": file_id,
|
||||
"name": "doc.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"sizeBytes": 1024,
|
||||
},
|
||||
)()
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-hash/stream",
|
||||
json={
|
||||
"message": "plain message",
|
||||
"is_user_message": True,
|
||||
"file_ids": [file_id],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
ns.redis.set.assert_called_once()
|
||||
call_args = ns.redis.set.call_args
|
||||
dedup_key = call_args.args[0]
|
||||
|
||||
# Hash must use the original message + sorted file IDs, not the mutated text.
|
||||
expected_hash = hashlib.sha256(
|
||||
f"sess-hash:plain message:{file_id}".encode()
|
||||
).hexdigest()[:16]
|
||||
expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
|
||||
assert dedup_key == expected_key, (
|
||||
f"Dedup key {dedup_key!r} does not match expected {expected_key!r} — "
|
||||
"hash may be using mutated message or wrong inputs"
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_after_stream_finish(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The dedup Redis key must be deleted after the turn completes (when
|
||||
subscriber_queue is None the route yields StreamFinish immediately and
|
||||
should release the key so the user can re-send the same message)."""
|
||||
from unittest.mock import AsyncMock as _AsyncMock
|
||||
|
||||
# Set up all internals manually so we can control subscribe_to_session.
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = _AsyncMock(return_value=None)
|
||||
# None → early-finish path: StreamFinish yielded immediately, dedup key released.
|
||||
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mock_redis = mocker.AsyncMock()
|
||||
mock_redis.set = _AsyncMock(return_value=True)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=_AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-finish/stream",
|
||||
json={"message": "hello", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.text
|
||||
assert '"finish"' in body
|
||||
# The dedup key must be released so intentional re-sends are allowed.
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The route must not crash when the dedup Redis delete fails on the
|
||||
subscriber_queue-is-None early-finish path (except Exception: pass)."""
|
||||
from unittest.mock import AsyncMock as _AsyncMock
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = _AsyncMock(return_value=None)
|
||||
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mock_redis = mocker.AsyncMock()
|
||||
mock_redis.set = _AsyncMock(return_value=True)
|
||||
# Make the delete raise so the except-pass branch is exercised.
|
||||
mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=_AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
|
||||
# Should not raise even though delete fails.
|
||||
response = client.post(
|
||||
"/sessions/sess-finish-err/stream",
|
||||
json={"message": "hello", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert '"finish"' in response.text
|
||||
# delete must have been attempted — the except-pass branch silenced the error.
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
|
||||
@@ -757,146 +980,3 @@ def test_disconnect_stream_returns_404_when_session_missing(
|
||||
|
||||
assert response.status_code == 404
|
||||
mock_disconnect.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── GET /sessions/{session_id} — forward/backward pagination ──────────────────
|
||||
|
||||
|
||||
def _make_paginated_messages(
|
||||
mocker: pytest_mock.MockerFixture, *, has_more: bool = False
|
||||
):
|
||||
"""Return a mock PaginatedMessages and configure the DB patch."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.db import PaginatedMessages
|
||||
from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata
|
||||
|
||||
now = datetime.now(UTC)
|
||||
session_info = ChatSessionInfo(
|
||||
session_id="sess-1",
|
||||
user_id=TEST_USER_ID,
|
||||
usage=[],
|
||||
started_at=now,
|
||||
updated_at=now,
|
||||
metadata=ChatSessionMetadata(),
|
||||
)
|
||||
page = PaginatedMessages(
|
||||
messages=[ChatMessage(role="user", content="hello", sequence=0)],
|
||||
has_more=has_more,
|
||||
oldest_sequence=0,
|
||||
newest_sequence=0,
|
||||
session=session_info,
|
||||
)
|
||||
mock_paginate = mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_messages_paginated",
|
||||
new_callable=AsyncMock,
|
||||
return_value=page,
|
||||
)
|
||||
return page, mock_paginate
|
||||
|
||||
|
||||
def test_get_session_completed_returns_forward_paginated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Completed sessions (no active stream) return forward_paginated=True."""
|
||||
_make_paginated_messages(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(None, None),
|
||||
)
|
||||
|
||||
response = client.get("/sessions/sess-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["forward_paginated"] is True
|
||||
assert data["newest_sequence"] == 0
|
||||
|
||||
|
||||
def test_get_session_active_returns_backward_paginated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Active sessions (with running stream) return forward_paginated=False."""
|
||||
from backend.copilot.stream_registry import ActiveSession
|
||||
|
||||
_make_paginated_messages(mocker)
|
||||
active = MagicMock(spec=ActiveSession)
|
||||
active.turn_id = "turn-1"
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(active, "msg-1"),
|
||||
)
|
||||
|
||||
response = client.get("/sessions/sess-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["forward_paginated"] is False
|
||||
assert data["active_stream"] is not None
|
||||
assert data["active_stream"]["turn_id"] == "turn-1"
|
||||
|
||||
|
||||
def test_get_session_after_sequence_returns_forward_paginated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""after_sequence param returns forward_paginated=True; no stream check needed."""
|
||||
_, mock_paginate = _make_paginated_messages(mocker)
|
||||
|
||||
response = client.get("/sessions/sess-1?after_sequence=10")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["forward_paginated"] is True
|
||||
call_kwargs = mock_paginate.call_args
|
||||
assert call_kwargs.kwargs.get("after_sequence") == 10
|
||||
assert call_kwargs.kwargs.get("before_sequence") is None
|
||||
|
||||
|
||||
def test_get_session_both_cursors_returns_400(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Sending both before_sequence and after_sequence returns 400."""
|
||||
response = client.get("/sessions/sess-1?before_sequence=5&after_sequence=10")
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_get_session_toctou_refetch_when_session_completes_mid_request(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Race condition: session was active at pre-check but completes before DB fetch.
|
||||
|
||||
The route should detect the race via a post-fetch re-check, then re-fetch
|
||||
from seq 0 so the initial prompt is always visible.
|
||||
"""
|
||||
from backend.copilot.stream_registry import ActiveSession
|
||||
|
||||
page, mock_paginate = _make_paginated_messages(mocker)
|
||||
active = MagicMock(spec=ActiveSession)
|
||||
active.turn_id = "turn-1"
|
||||
|
||||
# First call: session appears active. Second call: session has completed.
|
||||
mock_get_active = mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[(active, "msg-1"), (None, None)],
|
||||
)
|
||||
|
||||
response = client.get("/sessions/sess-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Post-race: session is now completed → forward_paginated=True, no stream
|
||||
assert data["forward_paginated"] is True
|
||||
assert data["active_stream"] is None
|
||||
# The DB was queried twice: once newest-first, once from_start=True
|
||||
assert mock_paginate.call_count == 2
|
||||
assert mock_get_active.call_count == 2
|
||||
second_call = mock_paginate.call_args_list[1]
|
||||
assert second_call.kwargs.get("from_start") is True
|
||||
|
||||
@@ -10,11 +10,9 @@ from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatMessageWhereInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
FindManyChatMessageArgsFromChatSession,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -32,8 +30,6 @@ from .model import get_chat_session as get_chat_session_cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BOUNDARY_SCAN_LIMIT = 10
|
||||
|
||||
|
||||
class PaginatedMessages(BaseModel):
|
||||
"""Result of a paginated message query."""
|
||||
@@ -41,7 +37,6 @@ class PaginatedMessages(BaseModel):
|
||||
messages: list[ChatMessage]
|
||||
has_more: bool
|
||||
oldest_sequence: int | None
|
||||
newest_sequence: int | None
|
||||
session: ChatSessionInfo
|
||||
|
||||
|
||||
@@ -66,48 +61,32 @@ async def get_chat_messages_paginated(
|
||||
session_id: str,
|
||||
limit: int = 50,
|
||||
before_sequence: int | None = None,
|
||||
after_sequence: int | None = None,
|
||||
from_start: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> PaginatedMessages | None:
|
||||
"""Get paginated messages for a session.
|
||||
"""Get paginated messages for a session, newest first.
|
||||
|
||||
Three modes:
|
||||
Verifies session existence (and ownership when ``user_id`` is provided)
|
||||
in parallel with the message query. Returns ``None`` when the session
|
||||
is not found or does not belong to the user.
|
||||
|
||||
- ``before_sequence`` set: backward pagination (DESC), returns messages
|
||||
with sequence < ``before_sequence``. Used for active sessions or manual
|
||||
backward navigation.
|
||||
- ``from_start=True`` or ``after_sequence`` set: forward pagination (ASC).
|
||||
Returns messages from sequence 0 (``from_start``) or after
|
||||
``after_sequence``. Used on initial load of completed sessions and for
|
||||
loading subsequent forward pages.
|
||||
- Both cursors ``None`` and ``from_start=False``: newest-first (DESC
|
||||
without filter). Used for active sessions on initial load.
|
||||
|
||||
Verifies session existence (and ownership when ``user_id`` is provided).
|
||||
Returns ``None`` when the session is not found or does not belong to the
|
||||
user.
|
||||
Args:
|
||||
session_id: The chat session ID.
|
||||
limit: Max messages to return.
|
||||
before_sequence: Cursor — return messages with sequence < this value.
|
||||
user_id: If provided, filters via ``Session.userId`` so only the
|
||||
session owner's messages are returned (acts as an ownership guard).
|
||||
"""
|
||||
# Build session-existence / ownership check
|
||||
session_where: ChatSessionWhereInput = {"id": session_id}
|
||||
if user_id is not None:
|
||||
session_where["userId"] = user_id
|
||||
|
||||
forward = from_start or after_sequence is not None
|
||||
|
||||
# Build message include — fetch paginated messages in the same query.
|
||||
# Note: when both from_start=True and after_sequence is not None, the
|
||||
# after_sequence filter takes precedence (the elif branch below is skipped).
|
||||
# This combination is not reachable via the HTTP route (mutual exclusion is
|
||||
# enforced there), so we rely on the documented priority here without an
|
||||
# additional assertion.
|
||||
msg_include: FindManyChatMessageArgsFromChatSession = {
|
||||
"order_by": {"sequence": "asc" if forward else "desc"},
|
||||
# Build message include — fetch paginated messages in the same query
|
||||
msg_include: dict[str, Any] = {
|
||||
"order_by": {"sequence": "desc"},
|
||||
"take": limit + 1,
|
||||
}
|
||||
if after_sequence is not None:
|
||||
msg_include["where"] = {"sequence": {"gt": after_sequence}}
|
||||
elif before_sequence is not None:
|
||||
if before_sequence is not None:
|
||||
msg_include["where"] = {"sequence": {"lt": before_sequence}}
|
||||
|
||||
# Single query: session existence/ownership + paginated messages
|
||||
@@ -125,96 +104,57 @@ async def get_chat_messages_paginated(
|
||||
has_more = len(results) > limit
|
||||
results = results[:limit]
|
||||
|
||||
if not forward:
|
||||
# Backward mode: DB returned DESC; reverse to ascending order.
|
||||
results.reverse()
|
||||
# Reverse to ascending order
|
||||
results.reverse()
|
||||
|
||||
# Tool-call boundary fix: if the oldest message is a tool message,
|
||||
# expand backward to include the preceding assistant message that
|
||||
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
|
||||
# can pair them correctly.
|
||||
if results and results[0].role == "tool":
|
||||
boundary_where: ChatMessageWhereInput = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
boundary_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=boundary_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_BOUNDARY_SCAN_LIMIT,
|
||||
# Tool-call boundary fix: if the oldest message is a tool message,
|
||||
# expand backward to include the preceding assistant message that
|
||||
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
|
||||
# can pair them correctly.
|
||||
_BOUNDARY_SCAN_LIMIT = 10
|
||||
if results and results[0].role == "tool":
|
||||
boundary_where: dict[str, Any] = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
boundary_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=boundary_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_BOUNDARY_SCAN_LIMIT,
|
||||
)
|
||||
# Find the first non-tool message (should be the assistant)
|
||||
boundary_msgs = []
|
||||
found_owner = False
|
||||
for msg in extra:
|
||||
boundary_msgs.append(msg)
|
||||
if msg.role != "tool":
|
||||
found_owner = True
|
||||
break
|
||||
boundary_msgs.reverse()
|
||||
if not found_owner:
|
||||
logger.warning(
|
||||
"Boundary expansion did not find owning assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
# Find the first non-tool message (should be the assistant)
|
||||
boundary_msgs = []
|
||||
found_owner = False
|
||||
for msg in extra:
|
||||
boundary_msgs.append(msg)
|
||||
if msg.role != "tool":
|
||||
found_owner = True
|
||||
break
|
||||
boundary_msgs.reverse()
|
||||
if not found_owner:
|
||||
logger.warning(
|
||||
"Boundary expansion did not find owning assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
if boundary_msgs:
|
||||
results = boundary_msgs + results
|
||||
# Only mark has_more if the expanded boundary isn't the
|
||||
# very start of the conversation (sequence 0).
|
||||
if boundary_msgs[0].sequence > 0:
|
||||
has_more = True
|
||||
else:
|
||||
# Forward mode: DB returned ASC.
|
||||
# Tool-call tail boundary fix: if the last message in this page is a
|
||||
# tool message, the NEXT forward page would start after it and begin
|
||||
# mid-tool-group — the owning assistant message is on this page but
|
||||
# the following tool results are on the next page.
|
||||
# Trim the current page so it ends on the owning assistant message,
|
||||
# which keeps tool groups intact across page boundaries.
|
||||
if results and results[-1].role == "tool":
|
||||
# Walk backward through results to find the last non-tool message.
|
||||
trim_idx = len(results) - 1
|
||||
while trim_idx >= 0 and results[trim_idx].role == "tool":
|
||||
trim_idx -= 1
|
||||
|
||||
if trim_idx >= 0:
|
||||
# Trim results so the page ends at the owning assistant.
|
||||
# Mark has_more=True so the client knows to fetch the rest.
|
||||
results = results[: trim_idx + 1]
|
||||
if boundary_msgs:
|
||||
results = boundary_msgs + results
|
||||
# Only mark has_more if the expanded boundary isn't the
|
||||
# very start of the conversation (sequence 0).
|
||||
if boundary_msgs[0].sequence > 0:
|
||||
has_more = True
|
||||
else:
|
||||
# Entire page is tool messages with no visible owner — log and
|
||||
# keep as-is so the caller is not stuck with an empty page.
|
||||
logger.warning(
|
||||
"Forward tail boundary: entire page is tool messages "
|
||||
"for session=%s, no owning assistant found (%d msgs)",
|
||||
session_id,
|
||||
len(results),
|
||||
)
|
||||
|
||||
messages = [ChatMessage.from_db(m) for m in results]
|
||||
# oldest_sequence is only meaningful in backward mode (used as backward
|
||||
# pagination cursor). In forward mode the page always starts near seq 0
|
||||
# and clients should use newest_sequence as the forward cursor instead.
|
||||
# Return None in forward mode so clients don't accidentally treat it as a
|
||||
# backward cursor on a forward-paginated session.
|
||||
oldest_sequence = messages[0].sequence if (messages and not forward) else None
|
||||
# newest_sequence is only meaningful in forward mode; in backward mode it
|
||||
# points to the last message of the page (not the session's newest message)
|
||||
# which is not a valid forward cursor. Return None in backward mode so
|
||||
# clients don't accidentally use it as one.
|
||||
newest_sequence = messages[-1].sequence if (messages and forward) else None
|
||||
oldest_sequence = messages[0].sequence if messages else None
|
||||
|
||||
return PaginatedMessages(
|
||||
messages=messages,
|
||||
has_more=has_more,
|
||||
oldest_sequence=oldest_sequence,
|
||||
newest_sequence=newest_sequence,
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
@@ -608,6 +548,46 @@ async def update_message_content_by_sequence(
|
||||
return False
|
||||
|
||||
|
||||
async def update_message_tool_calls(
|
||||
session_id: str,
|
||||
sequence: int,
|
||||
tool_calls: list[dict],
|
||||
) -> bool:
|
||||
"""Patch the toolCalls column of an already-saved assistant message.
|
||||
|
||||
Called when StreamToolInputAvailable arrives after an intermediate flush
|
||||
saved the assistant message with tool_calls=None. The DB save is
|
||||
append-only (uses get_next_sequence), so the already-persisted row must
|
||||
be updated in-place to reflect the tool_calls that arrived later.
|
||||
|
||||
Args:
|
||||
session_id: The chat session ID.
|
||||
sequence: The 0-based sequence number of the assistant message to patch.
|
||||
tool_calls: The full list of tool call dicts to set on the row.
|
||||
|
||||
Returns:
|
||||
True if the row was found and updated, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await PrismaChatMessage.prisma().update_many(
|
||||
where={"sessionId": session_id, "sequence": sequence},
|
||||
data={"toolCalls": SafeJson(tool_calls)},
|
||||
)
|
||||
if result == 0:
|
||||
logger.warning(
|
||||
f"update_message_tool_calls: no row found for session {session_id}, "
|
||||
f"sequence {sequence}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"update_message_tool_calls failed for session {session_id}, "
|
||||
f"sequence {sequence}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
|
||||
"""Set durationMs on the last assistant message in a session.
|
||||
|
||||
|
||||
@@ -175,187 +175,6 @@ async def test_no_where_on_messages_without_before_sequence(
|
||||
assert "where" not in include["Messages"]
|
||||
|
||||
|
||||
# ---------- Forward pagination (from_start / after_sequence) ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_start_uses_asc_order_no_where(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""from_start=True queries messages in ASC order with no where filter."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
|
||||
)
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
|
||||
|
||||
call_kwargs = find_first.call_args
|
||||
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
|
||||
assert include["Messages"]["order_by"] == {"sequence": "asc"}
|
||||
assert "where" not in include["Messages"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_start_returns_messages_ascending(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""from_start=True returns messages in ascending sequence order."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [0, 1, 2]
|
||||
assert (
|
||||
page.oldest_sequence is None
|
||||
) # None in forward mode — not a valid backward cursor
|
||||
assert page.newest_sequence == 2
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_start_has_more_when_results_exceed_limit(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""from_start=True sets has_more when DB returns more than limit items."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=2, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
assert page.has_more is True
|
||||
assert [m.sequence for m in page.messages] == [0, 1]
|
||||
assert page.newest_sequence == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_sequence_uses_gt_filter_asc_order(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""after_sequence adds a sequence > N where clause and uses ASC order."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(11), _make_msg(12)],
|
||||
)
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
|
||||
|
||||
call_kwargs = find_first.call_args
|
||||
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
|
||||
assert include["Messages"]["order_by"] == {"sequence": "asc"}
|
||||
assert include["Messages"]["where"] == {"sequence": {"gt": 10}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_sequence_returns_messages_in_order(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""after_sequence returns only messages with sequence > cursor, ascending."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(11), _make_msg(12), _make_msg(13)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [11, 12, 13]
|
||||
assert (
|
||||
page.oldest_sequence is None
|
||||
) # None in forward mode — not a valid backward cursor
|
||||
assert page.newest_sequence == 13
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_newest_sequence_none_for_backward_mode(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""newest_sequence is None in backward mode — it is not a valid forward cursor."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(5), _make_msg(4), _make_msg(3)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
|
||||
|
||||
assert page is not None
|
||||
assert page.newest_sequence is None
|
||||
assert page.oldest_sequence == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_mode_no_boundary_expansion(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Forward pagination never triggers backward boundary expansion."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(0, role="tool"), _make_msg(1, role="tool")],
|
||||
)
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
|
||||
|
||||
assert find_many.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_tail_boundary_trims_trailing_tool_messages(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Forward pages that end with tool messages are trimmed to the owning
|
||||
assistant so the next after_sequence page doesn't start mid-tool-group."""
|
||||
find_first, _ = mock_db
|
||||
# DB returns 4 messages ASC: assistant at 0, tool at 1, tool at 2, tool at 3
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(0, role="assistant"),
|
||||
_make_msg(1, role="tool"),
|
||||
_make_msg(2, role="tool"),
|
||||
_make_msg(3, role="tool"),
|
||||
],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
# Page should be trimmed to end at the assistant message
|
||||
assert [m.sequence for m in page.messages] == [0]
|
||||
assert page.newest_sequence == 0
|
||||
# has_more must be True so the client fetches the tool messages on next page
|
||||
assert page.has_more is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_tail_boundary_no_trim_when_last_not_tool(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Forward pages that end with a non-tool message are not trimmed."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(0, role="user"),
|
||||
_make_msg(1, role="assistant"),
|
||||
_make_msg(2, role="tool"),
|
||||
_make_msg(3, role="assistant"),
|
||||
],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [0, 1, 2, 3]
|
||||
assert page.newest_sequence == 3
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_id_filter_applied_to_session_where(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
|
||||
71
autogpt_platform/backend/backend/copilot/message_dedup.py
Normal file
71
autogpt_platform/backend/backend/copilot/message_dedup.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Per-request idempotency lock for the /stream endpoint.
|
||||
|
||||
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
|
||||
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
|
||||
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
|
||||
``None`` when the key already exists (duplicate request).
|
||||
2. ``release()`` — deletes the key. Must be called on turn completion or turn
|
||||
error so the next legitimate send is never blocked.
|
||||
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
|
||||
the backend turn is still running, and releasing would reopen the duplicate
|
||||
window for infra-level retries. The 30 s TTL is the safety net.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_KEY_PREFIX = "chat:msg_dedup"
|
||||
_TTL_SECONDS = 30
|
||||
|
||||
|
||||
class _DedupLock:
|
||||
def __init__(self, key: str, redis) -> None:
|
||||
self._key = key
|
||||
self._redis = redis
|
||||
|
||||
async def release(self) -> None:
|
||||
"""Best-effort key deletion. The TTL handles failures silently."""
|
||||
try:
|
||||
await self._redis.delete(self._key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def acquire_dedup_lock(
|
||||
session_id: str,
|
||||
message: str | None,
|
||||
file_ids: list[str] | None,
|
||||
) -> _DedupLock | None:
|
||||
"""Acquire the idempotency lock for this (session, message, files) tuple.
|
||||
|
||||
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
|
||||
Returns ``None`` when a duplicate is detected (lock already held).
|
||||
Returns ``None`` when there is nothing to deduplicate (no message, no files).
|
||||
"""
|
||||
if not message and not file_ids:
|
||||
return None
|
||||
|
||||
sorted_ids = ":".join(sorted(file_ids or []))
|
||||
content_hash = hashlib.sha256(
|
||||
f"{session_id}:{message or ''}:{sorted_ids}".encode()
|
||||
).hexdigest()[:16]
|
||||
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
|
||||
|
||||
redis = await get_redis_async()
|
||||
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"[STREAM] Duplicate user message blocked for session {session_id}, "
|
||||
f"hash={content_hash} — returning empty SSE",
|
||||
)
|
||||
return None
|
||||
|
||||
return _DedupLock(key, redis)
|
||||
@@ -0,0 +1,94 @@
|
||||
"""Unit tests for backend.copilot.message_dedup."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
|
||||
|
||||
|
||||
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=set_returns)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
return mock_redis
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_none_when_no_message_no_files(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Nothing to deduplicate — no Redis call made, None returned."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
result = await acquire_dedup_lock("sess-1", None, None)
|
||||
assert result is None
|
||||
mock_redis.set.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_lock_on_first_request(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""First request acquires the lock and returns a _DedupLock."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
mock_redis.set.assert_called_once()
|
||||
key_arg = mock_redis.set.call_args.args[0]
|
||||
assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_none_on_duplicate(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Duplicate request (NX fails) returns None to signal the caller."""
|
||||
_patch_redis(mocker, set_returns=None)
|
||||
result = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_key_stable_across_file_order(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""File IDs are sorted before hashing so order doesn't affect the key."""
|
||||
mock_redis_1 = _patch_redis(mocker, set_returns=True)
|
||||
await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
|
||||
key_ab = mock_redis_1.set.call_args.args[0]
|
||||
|
||||
mock_redis_2 = _patch_redis(mocker, set_returns=True)
|
||||
await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
|
||||
key_ba = mock_redis_2.set.call_args.args[0]
|
||||
|
||||
assert key_ab == key_ba
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_deletes_key(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""release() calls Redis delete exactly once."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
await lock.release()
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_swallows_redis_error(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""release() must not raise even when Redis delete fails."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
await lock.release() # must not raise
|
||||
mock_redis.delete.assert_called_once()
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, AsyncIterator, Self, cast
|
||||
from typing import Any, Self, cast
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
@@ -521,7 +522,10 @@ async def upsert_chat_session(
|
||||
callers are aware of the persistence failure.
|
||||
RedisError: If the cache write fails (after successful DB write).
|
||||
"""
|
||||
async with _get_session_lock(session.session_id) as _:
|
||||
# Acquire session-specific lock to prevent concurrent upserts
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
# Always query DB for existing message count to ensure consistency
|
||||
existing_message_count = await chat_db().get_next_sequence(session.session_id)
|
||||
|
||||
@@ -647,50 +651,20 @@ async def _save_session_to_db(
|
||||
msg.sequence = existing_message_count + i
|
||||
|
||||
|
||||
async def append_and_save_message(
|
||||
session_id: str, message: ChatMessage
|
||||
) -> ChatSession | None:
|
||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||
"""Atomically append a message to a session and persist it.
|
||||
|
||||
Returns the updated session, or None if the message was detected as a
|
||||
duplicate (idempotency guard). Callers must check for None and skip any
|
||||
downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected.
|
||||
|
||||
Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas.
|
||||
The idempotency check below provides a last-resort guard when the lock degrades.
|
||||
Acquires the session lock, re-fetches the latest session state,
|
||||
appends the message, and saves — preventing message loss when
|
||||
concurrent requests modify the same session.
|
||||
"""
|
||||
async with _get_session_lock(session_id) as lock_acquired:
|
||||
# When the lock degraded (Redis down or 2s timeout), bypass cache for
|
||||
# the idempotency check. Stale cache could let two concurrent writers
|
||||
# both see the old state, pass the check, and write the same message.
|
||||
if lock_acquired:
|
||||
session = await get_chat_session(session_id)
|
||||
else:
|
||||
session = await _get_session_from_db(session_id)
|
||||
lock = await _get_session_lock(session_id)
|
||||
|
||||
async with lock:
|
||||
session = await get_chat_session(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
# Idempotency: skip if the trailing block of same-role messages already
|
||||
# contains this content. Uses is_message_duplicate which checks all
|
||||
# consecutive trailing messages of the same role, not just [-1].
|
||||
#
|
||||
# This collapses infra/nginx retries whether they land on the same pod
|
||||
# (serialised by the Redis lock) or a different pod.
|
||||
#
|
||||
# Legit same-text messages are distinguished by the assistant turn
|
||||
# between them: if the user said "yes", got a response, and says
|
||||
# "yes" again, session.messages[-1] is the assistant reply, so the
|
||||
# role check fails and the second message goes through normally.
|
||||
#
|
||||
# Edge case: if a turn dies without writing any assistant message,
|
||||
# the user's next send of the same text is blocked here permanently.
|
||||
# The fix is to ensure failed turns always write an error/timeout
|
||||
# assistant message so the session always ends on an assistant turn.
|
||||
if message.content is not None and is_message_duplicate(
|
||||
session.messages, message.role, message.content
|
||||
):
|
||||
return None # duplicate — caller should skip enqueue
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||
|
||||
@@ -705,9 +679,6 @@ async def append_and_save_message(
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
||||
# Invalidate the stale entry so future reads fall back to DB,
|
||||
# preventing a retry from bypassing the idempotency check above.
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
return session
|
||||
|
||||
@@ -793,6 +764,10 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
# Shut down any local browser daemon for this session (best-effort).
|
||||
# Inline import required: all tool modules import ChatSession from this
|
||||
# module, so any top-level import from tools.* would create a cycle.
|
||||
@@ -857,38 +832,25 @@ async def update_session_title(
|
||||
|
||||
# ==================== Chat session locks ==================== #
|
||||
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_session_lock(session_id: str) -> AsyncIterator[bool]:
|
||||
"""Distributed Redis lock for a session, usable as an async context manager.
|
||||
|
||||
Yields True if the lock was acquired, False if it timed out or Redis was
|
||||
unavailable. Callers should treat False as a degraded mode and prefer fresh
|
||||
DB reads over cache to avoid acting on stale state.
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition
|
||||
is atomic and release is owner-verified. Blocks up to 2s for a concurrent
|
||||
writer to finish; the 10s TTL ensures a dead pod never holds the lock forever.
|
||||
This was originally added to solve the specific problem of race conditions between
|
||||
the session title thread and the conversation thread, which always occurs on the
|
||||
same instance as we prevent rapid request sends on the frontend.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks. Explicit cleanup also occurs
|
||||
in `delete_chat_session()`.
|
||||
"""
|
||||
_lock_key = f"copilot:session_lock:{session_id}"
|
||||
lock = None
|
||||
acquired = False
|
||||
try:
|
||||
_redis = await get_redis_async()
|
||||
lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2)
|
||||
acquired = await lock.acquire(blocking=True)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
"Could not acquire session lock for %s within 2s", session_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Redis unavailable for session lock on %s: %s", session_id, e)
|
||||
|
||||
try:
|
||||
yield acquired
|
||||
finally:
|
||||
if acquired and lock is not None:
|
||||
try:
|
||||
await lock.release()
|
||||
except Exception:
|
||||
pass # TTL will expire the key
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
@@ -11,13 +11,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
append_and_save_message,
|
||||
get_chat_session,
|
||||
is_message_duplicate,
|
||||
maybe_append_user_message,
|
||||
@@ -576,345 +574,3 @@ def test_maybe_append_assistant_skips_duplicate():
|
||||
result = maybe_append_user_message(session, "dup", is_user_message=False)
|
||||
assert result is False
|
||||
assert len(session.messages) == 2
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# append_and_save_message #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession:
|
||||
s = ChatSession.new(user_id="u1", dry_run=False)
|
||||
s.messages = list(msgs)
|
||||
return s
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_returns_none_for_duplicate(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""append_and_save_message returns None when the trailing message is a duplicate."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="user", content="hello"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
|
||||
result = await append_and_save_message(
|
||||
session.session_id, ChatMessage(role="user", content="hello")
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_appends_new_message(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""append_and_save_message appends a non-duplicate message and returns the session."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=2)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="second message")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
assert result is not None
|
||||
assert result.messages[-1].content == "second message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_raises_when_session_not_found(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""append_and_save_message raises ValueError when the session does not exist."""
|
||||
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await append_and_save_message(
|
||||
"missing-session-id", ChatMessage(role="user", content="hi")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_uses_db_when_lock_degraded(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When the Redis lock times out (acquired=False), the fallback reads from DB."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=False)
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mock_get_from_db = mocker.patch(
|
||||
"backend.copilot.model._get_session_from_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="new msg")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
# DB path was used (not cache-first)
|
||||
mock_get_from_db.assert_called_once_with(session.session_id)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_raises_database_error_on_save_failure(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When _save_session_to_db fails, append_and_save_message raises DatabaseError."""
|
||||
from backend.util.exceptions import DatabaseError
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
side_effect=RuntimeError("db down"),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(DatabaseError):
|
||||
await append_and_save_message(
|
||||
session.session_id, ChatMessage(role="user", content="new msg")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_invalidates_cache_on_cache_failure(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When cache_chat_session fails, invalidate_session_cache is called to avoid stale reads."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
side_effect=RuntimeError("redis write failed"),
|
||||
)
|
||||
mock_invalidate = mocker.patch(
|
||||
"backend.copilot.model.invalidate_session_cache",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
result = await append_and_save_message(
|
||||
session.session_id, ChatMessage(role="user", content="new msg")
|
||||
)
|
||||
# DB write succeeded, cache invalidation was called
|
||||
mock_invalidate.assert_called_once_with(session.session_id)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_uses_db_when_redis_unavailable(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When get_redis_async raises, _get_session_lock yields False (degraded) and DB is read."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
side_effect=ConnectionError("redis down"),
|
||||
)
|
||||
mock_get_from_db = mocker.patch(
|
||||
"backend.copilot.model._get_session_from_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="new msg")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
mock_get_from_db.assert_called_once_with(session.session_id)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_lock_release_failure_is_ignored(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""If lock.release() raises, the exception is swallowed (TTL will clean up)."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock(
|
||||
side_effect=RuntimeError("release failed")
|
||||
)
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="new msg")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
assert result is not None
|
||||
|
||||
@@ -44,6 +44,7 @@ from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ..config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from ..db import update_message_tool_calls
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
@@ -2262,6 +2263,30 @@ async def _run_stream_attempt(
|
||||
if dispatched is not None:
|
||||
yield dispatched
|
||||
|
||||
# If tool calls arrived this batch AND the assistant message was
|
||||
# already flushed to DB (sequence is set), patch the existing row
|
||||
# so tool_calls are not lost. The append-only save (start_sequence)
|
||||
# in _save_session_to_db never re-saves already-persisted rows, so
|
||||
# without this patch the assistant row keeps tool_calls=null.
|
||||
if acc.assistant_response.sequence is not None and any(
|
||||
isinstance(r, StreamToolInputAvailable) for r in adapter_responses
|
||||
):
|
||||
try:
|
||||
await asyncio.shield(
|
||||
update_message_tool_calls(
|
||||
ctx.session.session_id,
|
||||
acc.assistant_response.sequence,
|
||||
acc.accumulated_tool_calls,
|
||||
)
|
||||
)
|
||||
except Exception as patch_err:
|
||||
logger.warning(
|
||||
"%s tool_calls DB patch failed (sequence=%d): %s",
|
||||
ctx.log_prefix,
|
||||
acc.assistant_response.sequence,
|
||||
patch_err,
|
||||
)
|
||||
|
||||
# Append assistant entry AFTER convert_message so that
|
||||
# any stashed tool results from the previous turn are
|
||||
# recorded first, preserving the required API order:
|
||||
|
||||
@@ -20,7 +20,11 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.response_model import StreamStartStep, StreamTextDelta
|
||||
from backend.copilot.response_model import (
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.service import _dispatch_response, _StreamAccumulator
|
||||
|
||||
_NOW = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
@@ -215,3 +219,100 @@ class TestPreCreateAssistantMessage:
|
||||
_simulate_pre_create(acc, ctx)
|
||||
|
||||
assert len(ctx.session.messages) == 0
|
||||
|
||||
|
||||
class TestToolCallsLostAfterIntermediateFlush:
|
||||
"""Regression tests for the bug where tool_calls are lost when an
|
||||
intermediate flush saves the assistant message before StreamToolInputAvailable
|
||||
arrives.
|
||||
|
||||
Sequence that triggers the bug:
|
||||
1. StreamTextDelta → assistant message appended with tool_calls=None
|
||||
2. Intermediate flush fires (time/count threshold) → DB row written with tool_calls=null
|
||||
and acc.assistant_response.sequence is set (back-filled)
|
||||
3. StreamToolInputAvailable → acc.assistant_response.tool_calls mutated in-memory
|
||||
4. Final save: append-only — assistant row already in DB, tool_calls never updated
|
||||
|
||||
Fix: when StreamToolInputAvailable arrives and acc.assistant_response.sequence
|
||||
is not None, issue a DB UPDATE to patch toolCalls on the existing row.
|
||||
"""
|
||||
|
||||
def test_text_delta_then_tool_input_sets_tool_calls_on_message(self) -> None:
|
||||
"""After text arrives then tool input arrives, acc.assistant_response.tool_calls
|
||||
should be populated regardless of flush state."""
|
||||
session = _make_session()
|
||||
ctx = _make_ctx(session)
|
||||
state = _make_state()
|
||||
acc = _StreamAccumulator(
|
||||
assistant_response=ChatMessage(role="assistant", content=""),
|
||||
accumulated_tool_calls=[],
|
||||
)
|
||||
|
||||
# Step 1: text delta arrives, message appended
|
||||
_dispatch_response(
|
||||
StreamTextDelta(id="t1", delta="Let me run that for you."),
|
||||
acc,
|
||||
ctx,
|
||||
state,
|
||||
False,
|
||||
"[test]",
|
||||
)
|
||||
assert acc.has_appended_assistant
|
||||
assert session.messages[-1].tool_calls is None
|
||||
|
||||
# Step 2: simulate intermediate flush back-filling the sequence
|
||||
acc.assistant_response.sequence = 1 # back-filled by _save_session_to_db
|
||||
|
||||
# Step 3: tool input arrives
|
||||
_dispatch_response(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId="call_abc",
|
||||
toolName="bash_exec",
|
||||
input={"command": "ls"},
|
||||
),
|
||||
acc,
|
||||
ctx,
|
||||
state,
|
||||
False,
|
||||
"[test]",
|
||||
)
|
||||
|
||||
# tool_calls should be set in memory
|
||||
assert acc.assistant_response.tool_calls is not None
|
||||
assert len(acc.assistant_response.tool_calls) == 1
|
||||
assert acc.assistant_response.tool_calls[0]["id"] == "call_abc"
|
||||
|
||||
def test_sequence_set_when_flush_occurred_before_tool_input(self) -> None:
|
||||
"""When sequence is back-filled (flush happened) before tool calls arrive,
|
||||
it is detectable so the caller can issue a DB patch."""
|
||||
acc = _StreamAccumulator(
|
||||
assistant_response=ChatMessage(role="assistant", content="hello"),
|
||||
accumulated_tool_calls=[],
|
||||
has_appended_assistant=True,
|
||||
)
|
||||
# Simulate flush back-fill
|
||||
acc.assistant_response.sequence = 3
|
||||
|
||||
ctx = _make_ctx()
|
||||
state = _make_state()
|
||||
|
||||
_dispatch_response(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId="call_xyz",
|
||||
toolName="run_block",
|
||||
input={},
|
||||
),
|
||||
acc,
|
||||
ctx,
|
||||
state,
|
||||
False,
|
||||
"[test]",
|
||||
)
|
||||
|
||||
# Caller should detect this condition and issue a DB patch
|
||||
needs_db_patch = acc.assistant_response.sequence is not None and bool(
|
||||
acc.accumulated_tool_calls
|
||||
)
|
||||
assert (
|
||||
needs_db_patch
|
||||
), "Expected needs_db_patch=True when flush happened before tool calls arrived"
|
||||
|
||||
@@ -423,33 +423,20 @@ async def subscribe_to_session(
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||
)
|
||||
|
||||
# RACE CONDITION FIX: If session not found, retry with backoff.
|
||||
# Duplicate requests skip create_session and subscribe immediately; the
|
||||
# original request's create_session (a Redis hset) may not have completed
|
||||
# yet. 3 × 100ms gives a 300ms window which covers DB-write latency on the
|
||||
# original request before the hset even starts.
|
||||
# RACE CONDITION FIX: If session not found, retry once after small delay
|
||||
# This handles the case where subscribe_to_session is called immediately
|
||||
# after create_session but before Redis propagates the write
|
||||
if not meta:
|
||||
_max_retries = 3
|
||||
_retry_delay = 0.1 # 100ms per attempt
|
||||
for attempt in range(_max_retries):
|
||||
logger.warning(
|
||||
f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), "
|
||||
f"retrying after {int(_retry_delay * 1000)}ms",
|
||||
extra={"json_fields": {**log_meta, "attempt": attempt + 1}},
|
||||
)
|
||||
await asyncio.sleep(_retry_delay)
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
if meta:
|
||||
logger.info(
|
||||
f"[TIMING] Session found after {attempt + 1} retries",
|
||||
extra={"json_fields": {**log_meta, "attempts": attempt + 1}},
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
await asyncio.sleep(0.05) # 50ms
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
if not meta:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Session still not found in Redis after {_max_retries} retries "
|
||||
f"({elapsed:.1f}ms total)",
|
||||
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -459,6 +446,10 @@ async def subscribe_to_session(
|
||||
},
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"[TIMING] Session found after retry",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
session_status = meta.get("status", "")
|
||||
|
||||
@@ -93,7 +93,6 @@ export function CopilotPage() {
|
||||
hasMoreMessages,
|
||||
isLoadingMore,
|
||||
loadMore,
|
||||
forwardPaginated,
|
||||
// Mobile drawer
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
@@ -218,7 +217,6 @@ export function CopilotPage() {
|
||||
hasMoreMessages={hasMoreMessages}
|
||||
isLoadingMore={isLoadingMore}
|
||||
onLoadMore={loadMore}
|
||||
forwardPaginated={forwardPaginated}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
historicalDurations={historicalDurations}
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { useChatSession } from "../useChatSession";
|
||||
|
||||
const mockUseGetV2GetSession = vi.fn();
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
|
||||
useGetV2GetSession: (...args: unknown[]) => mockUseGetV2GetSession(...args),
|
||||
usePostV2CreateSession: () => ({ mutateAsync: vi.fn(), isPending: false }),
|
||||
getGetV2GetSessionQueryKey: (id: string) => ["session", id],
|
||||
getGetV2ListSessionsQueryKey: () => ["sessions"],
|
||||
}));
|
||||
|
||||
vi.mock("@tanstack/react-query", () => ({
|
||||
useQueryClient: () => ({
|
||||
invalidateQueries: vi.fn(),
|
||||
setQueryData: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("nuqs", () => ({
|
||||
parseAsString: { withDefault: (v: unknown) => v },
|
||||
useQueryState: () => ["sess-1", vi.fn()],
|
||||
}));
|
||||
|
||||
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
|
||||
convertChatSessionMessagesToUiMessages: vi.fn(() => ({
|
||||
messages: [],
|
||||
historicalDurations: new Map(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("../helpers", () => ({
|
||||
resolveSessionDryRun: vi.fn(() => false),
|
||||
}));
|
||||
|
||||
vi.mock("@sentry/nextjs", () => ({
|
||||
captureException: vi.fn(),
|
||||
}));
|
||||
|
||||
function makeQueryResult(data: object | null) {
|
||||
return {
|
||||
data: data ? { status: 200, data } : undefined,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
isFetching: false,
|
||||
refetch: vi.fn(),
|
||||
};
|
||||
}
|
||||
|
||||
describe("useChatSession — newestSequence and forwardPaginated", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("returns null / false when no session data", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(makeQueryResult(null));
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.newestSequence).toBeNull();
|
||||
expect(result.current.forwardPaginated).toBe(false);
|
||||
});
|
||||
|
||||
it("returns newestSequence from session data", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(
|
||||
makeQueryResult({
|
||||
messages: [],
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 0,
|
||||
newest_sequence: 99,
|
||||
forward_paginated: false,
|
||||
active_stream: null,
|
||||
}),
|
||||
);
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.newestSequence).toBe(99);
|
||||
});
|
||||
|
||||
it("returns null for newestSequence when field is missing", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(
|
||||
makeQueryResult({
|
||||
messages: [],
|
||||
has_more_messages: false,
|
||||
oldest_sequence: 0,
|
||||
newest_sequence: null,
|
||||
forward_paginated: false,
|
||||
active_stream: null,
|
||||
}),
|
||||
);
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.newestSequence).toBeNull();
|
||||
});
|
||||
|
||||
it("returns forwardPaginated=true when session is forward-paginated", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(
|
||||
makeQueryResult({
|
||||
messages: [],
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 0,
|
||||
newest_sequence: 49,
|
||||
forward_paginated: true,
|
||||
active_stream: null,
|
||||
}),
|
||||
);
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.forwardPaginated).toBe(true);
|
||||
});
|
||||
|
||||
it("returns forwardPaginated=false when session is backward-paginated", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(
|
||||
makeQueryResult({
|
||||
messages: [],
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 50,
|
||||
newest_sequence: 99,
|
||||
forward_paginated: false,
|
||||
active_stream: null,
|
||||
}),
|
||||
);
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.forwardPaginated).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -1,202 +0,0 @@
|
||||
import { act, renderHook, waitFor } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { useCopilotPage } from "../useCopilotPage";
|
||||
|
||||
const mockUseChatSession = vi.fn();
|
||||
const mockUseCopilotStream = vi.fn();
|
||||
const mockUseLoadMoreMessages = vi.fn();
|
||||
|
||||
vi.mock("../useChatSession", () => ({
|
||||
useChatSession: (...args: unknown[]) => mockUseChatSession(...args),
|
||||
}));
|
||||
vi.mock("../useCopilotStream", () => ({
|
||||
useCopilotStream: (...args: unknown[]) => mockUseCopilotStream(...args),
|
||||
}));
|
||||
vi.mock("../useLoadMoreMessages", () => ({
|
||||
useLoadMoreMessages: (...args: unknown[]) => mockUseLoadMoreMessages(...args),
|
||||
}));
|
||||
vi.mock("../useCopilotNotifications", () => ({
|
||||
useCopilotNotifications: () => undefined,
|
||||
}));
|
||||
vi.mock("../useWorkflowImportAutoSubmit", () => ({
|
||||
useWorkflowImportAutoSubmit: () => undefined,
|
||||
}));
|
||||
vi.mock("../store", () => ({
|
||||
useCopilotUIStore: () => ({
|
||||
sessionToDelete: null,
|
||||
setSessionToDelete: vi.fn(),
|
||||
isDrawerOpen: false,
|
||||
setDrawerOpen: vi.fn(),
|
||||
copilotChatMode: "chat",
|
||||
copilotLlmModel: null,
|
||||
isDryRun: false,
|
||||
}),
|
||||
}));
|
||||
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
|
||||
concatWithAssistantMerge: (a: unknown[], b: unknown[]) => [...a, ...b],
|
||||
}));
|
||||
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
|
||||
useDeleteV2DeleteSession: () => ({ mutate: vi.fn(), isPending: false }),
|
||||
useGetV2ListSessions: () => ({ data: undefined, isLoading: false }),
|
||||
getGetV2ListSessionsQueryKey: () => ["sessions"],
|
||||
}));
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
toast: vi.fn(),
|
||||
}));
|
||||
vi.mock("@/lib/direct-upload", () => ({
|
||||
uploadFileDirect: vi.fn(),
|
||||
}));
|
||||
vi.mock("@/lib/hooks/useBreakpoint", () => ({
|
||||
useBreakpoint: () => "lg",
|
||||
}));
|
||||
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
|
||||
useSupabase: () => ({ isUserLoading: false, isLoggedIn: true }),
|
||||
}));
|
||||
vi.mock("@tanstack/react-query", () => ({
|
||||
useQueryClient: () => ({ invalidateQueries: vi.fn() }),
|
||||
}));
|
||||
vi.mock("@/services/feature-flags/use-get-flag", () => ({
|
||||
Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" },
|
||||
useGetFlag: () => false,
|
||||
}));
|
||||
|
||||
function makeBaseChatSession(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
sessionId: "sess-1",
|
||||
setSessionId: vi.fn(),
|
||||
hydratedMessages: [],
|
||||
rawSessionMessages: [],
|
||||
historicalDurations: new Map(),
|
||||
hasActiveStream: false,
|
||||
hasMoreMessages: false,
|
||||
oldestSequence: null,
|
||||
newestSequence: null,
|
||||
forwardPaginated: false,
|
||||
isLoadingSession: false,
|
||||
isSessionError: false,
|
||||
createSession: vi.fn(),
|
||||
isCreatingSession: false,
|
||||
refetchSession: vi.fn(),
|
||||
sessionDryRun: false,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function makeBaseCopilotStream(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
messages: [],
|
||||
sendMessage: vi.fn(),
|
||||
stop: vi.fn(),
|
||||
status: "ready",
|
||||
error: undefined,
|
||||
isReconnecting: false,
|
||||
isSyncing: false,
|
||||
isUserStoppingRef: { current: false },
|
||||
rateLimitMessage: null,
|
||||
dismissRateLimit: vi.fn(),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function makeBaseLoadMore(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
pagedMessages: [],
|
||||
hasMore: false,
|
||||
isLoadingMore: false,
|
||||
loadMore: vi.fn(),
|
||||
resetPaged: vi.fn(),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("useCopilotPage — forwardPaginated message ordering", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("prepends pagedMessages before currentMessages when forwardPaginated=false", () => {
|
||||
const pagedMsg = { id: "paged", role: "user" };
|
||||
const currentMsg = { id: "current", role: "assistant" };
|
||||
mockUseChatSession.mockReturnValue(
|
||||
makeBaseChatSession({ forwardPaginated: false }),
|
||||
);
|
||||
mockUseCopilotStream.mockReturnValue(
|
||||
makeBaseCopilotStream({ messages: [currentMsg] }),
|
||||
);
|
||||
mockUseLoadMoreMessages.mockReturnValue(
|
||||
makeBaseLoadMore({ pagedMessages: [pagedMsg] }),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useCopilotPage());
|
||||
|
||||
// Backward: pagedMessages (older) come first
|
||||
expect(result.current.messages[0]).toEqual(pagedMsg);
|
||||
expect(result.current.messages[1]).toEqual(currentMsg);
|
||||
});
|
||||
|
||||
it("appends pagedMessages after currentMessages when forwardPaginated=true", () => {
|
||||
const pagedMsg = { id: "paged", role: "assistant" };
|
||||
const currentMsg = { id: "current", role: "user" };
|
||||
mockUseChatSession.mockReturnValue(
|
||||
makeBaseChatSession({ forwardPaginated: true }),
|
||||
);
|
||||
mockUseCopilotStream.mockReturnValue(
|
||||
makeBaseCopilotStream({ messages: [currentMsg] }),
|
||||
);
|
||||
mockUseLoadMoreMessages.mockReturnValue(
|
||||
makeBaseLoadMore({ pagedMessages: [pagedMsg] }),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useCopilotPage());
|
||||
|
||||
// Forward: currentMessages (beginning of session) come first
|
||||
expect(result.current.messages[0]).toEqual(currentMsg);
|
||||
expect(result.current.messages[1]).toEqual(pagedMsg);
|
||||
});
|
||||
|
||||
it("calls resetPaged when forwardPaginated transitions false→true with paged messages", async () => {
|
||||
const mockResetPaged = vi.fn();
|
||||
const pagedMsg = { id: "paged", role: "user" };
|
||||
|
||||
mockUseChatSession.mockReturnValue(
|
||||
makeBaseChatSession({ forwardPaginated: false }),
|
||||
);
|
||||
mockUseCopilotStream.mockReturnValue(makeBaseCopilotStream());
|
||||
mockUseLoadMoreMessages.mockReturnValue(
|
||||
makeBaseLoadMore({
|
||||
pagedMessages: [pagedMsg],
|
||||
resetPaged: mockResetPaged,
|
||||
}),
|
||||
);
|
||||
|
||||
const { rerender } = renderHook(() => useCopilotPage());
|
||||
|
||||
// Simulate session completing — forwardPaginated flips to true
|
||||
mockUseChatSession.mockReturnValue(
|
||||
makeBaseChatSession({ forwardPaginated: true }),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
rerender();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockResetPaged).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it("does not call resetPaged when forwardPaginated is already true on mount", () => {
|
||||
const mockResetPaged = vi.fn();
|
||||
mockUseChatSession.mockReturnValue(
|
||||
makeBaseChatSession({ forwardPaginated: true }),
|
||||
);
|
||||
mockUseCopilotStream.mockReturnValue(makeBaseCopilotStream());
|
||||
mockUseLoadMoreMessages.mockReturnValue(
|
||||
makeBaseLoadMore({ pagedMessages: [], resetPaged: mockResetPaged }),
|
||||
);
|
||||
|
||||
renderHook(() => useCopilotPage());
|
||||
|
||||
expect(mockResetPaged).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -1,568 +0,0 @@
|
||||
import { act, renderHook, waitFor } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { useLoadMoreMessages } from "../useLoadMoreMessages";
|
||||
|
||||
const mockGetV2GetSession = vi.fn();
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
|
||||
getV2GetSession: (...args: unknown[]) => mockGetV2GetSession(...args),
|
||||
}));
|
||||
|
||||
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
|
||||
convertChatSessionMessagesToUiMessages: vi.fn(() => ({ messages: [] })),
|
||||
extractToolOutputsFromRaw: vi.fn(() => []),
|
||||
}));
|
||||
|
||||
const BASE_ARGS = {
|
||||
sessionId: "sess-1",
|
||||
initialOldestSequence: 0,
|
||||
initialNewestSequence: 49,
|
||||
initialHasMore: true,
|
||||
forwardPaginated: true,
|
||||
initialPageRawMessages: [],
|
||||
};
|
||||
|
||||
function makeSuccessResponse(overrides: {
|
||||
messages?: unknown[];
|
||||
has_more_messages?: boolean;
|
||||
oldest_sequence?: number;
|
||||
newest_sequence?: number;
|
||||
}) {
|
||||
return {
|
||||
status: 200,
|
||||
data: {
|
||||
messages: overrides.messages ?? [],
|
||||
has_more_messages: overrides.has_more_messages ?? false,
|
||||
oldest_sequence: overrides.oldest_sequence ?? 0,
|
||||
newest_sequence: overrides.newest_sequence ?? 49,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe("useLoadMoreMessages", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("initialises with empty pagedMessages and correct cursors", () => {
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("resetPaged clears paged state and sets hasMore=false during transition", () => {
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
// hasMore must be false during transition to prevent forward loadMore
|
||||
// from firing on the now-active session before forwardPaginated updates.
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("resetPaged exposes a fresh loadMore via incremented epoch", () => {
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
// Just verify resetPaged is callable and doesn't throw.
|
||||
expect(() => {
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
it("resets all state on sessionId change", () => {
|
||||
const { result, rerender } = renderHook(
|
||||
(props) => useLoadMoreMessages(props),
|
||||
{ initialProps: BASE_ARGS },
|
||||
);
|
||||
|
||||
rerender({
|
||||
...BASE_ARGS,
|
||||
sessionId: "sess-2",
|
||||
initialOldestSequence: 10,
|
||||
initialNewestSequence: 59,
|
||||
initialHasMore: false,
|
||||
});
|
||||
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
describe("loadMore — forward pagination", () => {
|
||||
it("calls getV2GetSession with after_sequence and updates newestSequence", async () => {
|
||||
const rawMsg = { role: "user", content: "hi", sequence: 50 };
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [rawMsg],
|
||||
has_more_messages: true,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ after_sequence: 49 }),
|
||||
);
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("sets hasMore=false when response has no more messages", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({ has_more_messages: false }),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("is a no-op when hasMore is false", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
initialHasMore: false,
|
||||
forwardPaginated: true,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — backward pagination", () => {
|
||||
it("calls getV2GetSession with before_sequence", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [{ role: "user", content: "old", sequence: 0 }],
|
||||
has_more_messages: false,
|
||||
oldest_sequence: 0,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: 50,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ before_sequence: 50 }),
|
||||
);
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — error handling", () => {
|
||||
it("does not set hasMore=false on first error", async () => {
|
||||
mockGetV2GetSession.mockRejectedValueOnce(new Error("network error"));
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// First error — hasMore still true
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) errors", async () => {
|
||||
mockGetV2GetSession.mockRejectedValue(new Error("network error"));
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
// Reset the in-flight guard between calls
|
||||
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
|
||||
}
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("ignores non-200 response and increments error count", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce({ status: 500, data: {} });
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// One error, not yet at threshold — hasMore still true
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) non-200 responses", async () => {
|
||||
mockGetV2GetSession.mockResolvedValue({ status: 503, data: {} });
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
|
||||
}
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("discards in-flight error when epoch changes mid-flight (resetPaged called)", async () => {
|
||||
let rejectRequest!: (e: Error) => void;
|
||||
mockGetV2GetSession.mockReturnValueOnce(
|
||||
new Promise((_, rej) => {
|
||||
rejectRequest = rej;
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
act(() => {
|
||||
result.current.loadMore();
|
||||
});
|
||||
|
||||
// Reset epoch mid-flight
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
|
||||
// Reject the in-flight request — stale error should be discarded
|
||||
await act(async () => {
|
||||
rejectRequest(new Error("network error"));
|
||||
});
|
||||
|
||||
// State unchanged: no hasMore=false, no errorCount, isLoadingMore cleared
|
||||
expect(result.current.hasMore).toBe(false); // false from resetPaged
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — forward pagination cursor advancement", () => {
|
||||
it("advances newestSequence after a successful forward load", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [{ role: "user", content: "hi", sequence: 50 }],
|
||||
has_more_messages: true,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// A second loadMore should use after_sequence: 99 (advanced cursor)
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ after_sequence: 99 }),
|
||||
);
|
||||
});
|
||||
|
||||
it("does not regress newestSequence when parent refetches after pages loaded", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [{ role: "user", content: "msg", sequence: 50 }],
|
||||
has_more_messages: true,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
(props) => useLoadMoreMessages(props),
|
||||
{ initialProps: { ...BASE_ARGS, forwardPaginated: true } },
|
||||
);
|
||||
|
||||
// Load one page — newestSequence advances to 99
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// Parent refetches with a lower newest_sequence (49) — should NOT regress cursor
|
||||
rerender({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialNewestSequence: 49,
|
||||
});
|
||||
|
||||
// Next loadMore should still use the advanced cursor (99)
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ after_sequence: 99 }),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — MAX_OLDER_MESSAGES truncation", () => {
|
||||
it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => {
|
||||
// Single load of 2001 messages exceeds the limit in one shot.
|
||||
// This avoids relying on cross-render closure staleness: estimatedTotal =
|
||||
// pagedRawMessages.length (0, fresh) + 2001 = 2001 >= 2000 → hasMore=false.
|
||||
const args = { ...BASE_ARGS, forwardPaginated: false };
|
||||
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 2001 }, (_, i) => ({
|
||||
role: "user",
|
||||
content: `msg ${i}`,
|
||||
sequence: i,
|
||||
})),
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 0,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(args));
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("forward truncation keeps first MAX_OLDER_MESSAGES items (not last)", async () => {
|
||||
// 1990 messages already paged; load 20 more forward — total 2010 > 2000.
|
||||
// Forward truncation must keep slice(0, 2000), not slice(-2000),
|
||||
// to preserve the beginning of the conversation.
|
||||
const forwardNearLimitArgs = {
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialNewestSequence: 49,
|
||||
initialOldestSequence: 0,
|
||||
initialHasMore: true,
|
||||
};
|
||||
|
||||
const { result } = renderHook((props) => useLoadMoreMessages(props), {
|
||||
initialProps: forwardNearLimitArgs,
|
||||
});
|
||||
|
||||
// First load: 1990 messages — advances newestSequence to 2039
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 1990 }, (_, i) => ({
|
||||
role: "assistant",
|
||||
content: `msg ${i + 50}`,
|
||||
sequence: i + 50,
|
||||
})),
|
||||
has_more_messages: true,
|
||||
newest_sequence: 2039,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// Second load: 20 more messages pushes total to 2010 > 2000.
|
||||
// Truncation keeps seq 50..2049 (2000 items); discards seq 2050..2059 (10 items).
|
||||
// Even though the server says has_more_messages=false, hasMore stays true
|
||||
// because there are discarded items that need to be re-fetched.
|
||||
// The cursor (newestSequence) advances to 2049 — the last kept item's sequence.
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 20 }, (_, i) => ({
|
||||
role: "assistant",
|
||||
content: `msg ${i + 2040}`,
|
||||
sequence: i + 2040,
|
||||
})),
|
||||
has_more_messages: false,
|
||||
newest_sequence: 2059,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// Truncation occurred (2010 > 2000): hasMore=true so discarded items can be fetched.
|
||||
// Cursor advances to last kept item (seq 2049), not the server's newest (2059).
|
||||
await waitFor(() => expect(result.current.hasMore).toBe(true));
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — null cursor guard", () => {
|
||||
it("is a no-op when newestSequence is null (forwardPaginated=true)", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialNewestSequence: null,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("is a no-op when oldestSequence is null (forwardPaginated=false)", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: null,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("pagedMessages — initialPageRawMessages extraToolOutputs", () => {
|
||||
it("calls extractToolOutputsFromRaw for backward pagination with non-empty initialPageRawMessages", async () => {
|
||||
const { extractToolOutputsFromRaw } = await import(
|
||||
"../helpers/convertChatSessionToUiMessages"
|
||||
);
|
||||
|
||||
const rawMsg = { role: "user", content: "old", sequence: 0 };
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [rawMsg],
|
||||
has_more_messages: false,
|
||||
oldest_sequence: 0,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: 50,
|
||||
initialPageRawMessages: [{ role: "assistant", content: "response" }],
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(extractToolOutputsFromRaw).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("does NOT call extractToolOutputsFromRaw for forward pagination", async () => {
|
||||
const { extractToolOutputsFromRaw } = await import(
|
||||
"../helpers/convertChatSessionToUiMessages"
|
||||
);
|
||||
|
||||
const rawMsg = { role: "assistant", content: "hi", sequence: 50 };
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [rawMsg],
|
||||
has_more_messages: false,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialPageRawMessages: [{ role: "user", content: "hello" }],
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(extractToolOutputsFromRaw).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — epoch / stale-response guard", () => {
|
||||
it("discards response when epoch changes during flight (resetPaged called)", async () => {
|
||||
let resolveRequest!: (v: unknown) => void;
|
||||
mockGetV2GetSession.mockReturnValueOnce(
|
||||
new Promise((res) => {
|
||||
resolveRequest = res;
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
// Start the request without awaiting
|
||||
act(() => {
|
||||
result.current.loadMore();
|
||||
});
|
||||
|
||||
// Reset epoch mid-flight
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
|
||||
// Now resolve the in-flight request
|
||||
await act(async () => {
|
||||
resolveRequest(
|
||||
makeSuccessResponse({ messages: [{ role: "user", content: "hi" }] }),
|
||||
);
|
||||
});
|
||||
|
||||
// Response discarded — pagedMessages stays empty, isLoadingMore stays false
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -30,7 +30,6 @@ export interface ChatContainerProps {
|
||||
hasMoreMessages?: boolean;
|
||||
isLoadingMore?: boolean;
|
||||
onLoadMore?: () => void;
|
||||
forwardPaginated?: boolean;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
@@ -55,7 +54,6 @@ export const ChatContainer = ({
|
||||
hasMoreMessages,
|
||||
isLoadingMore,
|
||||
onLoadMore,
|
||||
forwardPaginated,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
historicalDurations,
|
||||
@@ -110,7 +108,6 @@ export const ChatContainer = ({
|
||||
hasMoreMessages={hasMoreMessages}
|
||||
isLoadingMore={isLoadingMore}
|
||||
onLoadMore={onLoadMore}
|
||||
forwardPaginated={forwardPaginated}
|
||||
onRetry={handleRetry}
|
||||
historicalDurations={historicalDurations}
|
||||
/>
|
||||
|
||||
@@ -43,10 +43,6 @@ interface Props {
|
||||
hasMoreMessages?: boolean;
|
||||
isLoadingMore?: boolean;
|
||||
onLoadMore?: () => void;
|
||||
/** When true the load-more sentinel is placed at the bottom (forward
|
||||
* pagination for completed sessions). When false it is at the top
|
||||
* (backward pagination for active sessions). */
|
||||
forwardPaginated?: boolean;
|
||||
onRetry?: () => void;
|
||||
historicalDurations?: Map<string, number>;
|
||||
}
|
||||
@@ -140,25 +136,11 @@ export function LoadMoreSentinel({
|
||||
isLoading,
|
||||
messageCount,
|
||||
onLoadMore,
|
||||
rootMargin = "200px 0px 0px 0px",
|
||||
adjustScroll = true,
|
||||
forwardPaginated = false,
|
||||
}: {
|
||||
hasMore: boolean;
|
||||
isLoading: boolean;
|
||||
messageCount: number;
|
||||
onLoadMore: () => void;
|
||||
/** IntersectionObserver rootMargin. Top sentinel uses "200px 0px 0px 0px"
|
||||
* (pre-trigger when approaching from above); bottom sentinel should use
|
||||
* "0px 0px 200px 0px" (pre-trigger when approaching from below). */
|
||||
rootMargin?: string;
|
||||
/** Whether to adjust scrollTop after load to preserve visual position.
|
||||
* True for backward pagination (prepend above); false for forward
|
||||
* pagination (append below) where no adjustment is needed. */
|
||||
adjustScroll?: boolean;
|
||||
/** When true the button reads "Load newer messages" (forward pagination).
|
||||
* When false (default) it reads "Load older messages". */
|
||||
forwardPaginated?: boolean;
|
||||
}) {
|
||||
const sentinelRef = useRef<HTMLDivElement>(null);
|
||||
const onLoadMoreRef = useRef(onLoadMore);
|
||||
@@ -203,11 +185,11 @@ export function LoadMoreSentinel({
|
||||
if (autoFillRoundsRef.current >= MAX_AUTO_FILL_ROUNDS) return;
|
||||
captureAndLoad(true);
|
||||
},
|
||||
{ rootMargin },
|
||||
{ rootMargin: "200px 0px 0px 0px" },
|
||||
);
|
||||
observer.observe(sentinelRef.current);
|
||||
return () => observer.disconnect();
|
||||
}, [hasMore, isLoading, rootMargin, scrollRef]);
|
||||
}, [hasMore, isLoading, scrollRef]);
|
||||
|
||||
// After React commits new DOM nodes (prepended messages), adjust
|
||||
// scrollTop so the user stays at the same visual position.
|
||||
@@ -220,9 +202,7 @@ export function LoadMoreSentinel({
|
||||
scrollSnapshotRef.current;
|
||||
if (!el || prevHeight === 0) return;
|
||||
const delta = el.scrollHeight - prevHeight;
|
||||
// Only restore scroll position for backward pagination (content prepended
|
||||
// above). Forward pagination appends below — no adjustment needed.
|
||||
if (adjustScroll && delta > 0) {
|
||||
if (delta > 0) {
|
||||
el.scrollTop = prevTop + delta;
|
||||
}
|
||||
// Reset the auto-fill backoff whenever the container becomes
|
||||
@@ -236,7 +216,7 @@ export function LoadMoreSentinel({
|
||||
}
|
||||
scrollSnapshotRef.current = { scrollHeight: 0, scrollTop: 0 };
|
||||
autoTriggeredRef.current = false;
|
||||
}, [adjustScroll, messageCount, scrollRef]);
|
||||
}, [messageCount, scrollRef]);
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -255,7 +235,7 @@ export function LoadMoreSentinel({
|
||||
size="small"
|
||||
onClick={() => captureAndLoad(false)}
|
||||
>
|
||||
{forwardPaginated ? "Load newer messages" : "Load older messages"}
|
||||
Load older messages
|
||||
</Button>
|
||||
)
|
||||
)}
|
||||
@@ -272,7 +252,6 @@ export function ChatMessagesContainer({
|
||||
hasMoreMessages,
|
||||
isLoadingMore,
|
||||
onLoadMore,
|
||||
forwardPaginated,
|
||||
onRetry,
|
||||
historicalDurations,
|
||||
}: Props) {
|
||||
@@ -351,7 +330,7 @@ export function ChatMessagesContainer({
|
||||
}
|
||||
>
|
||||
<ConversationContent className="flex min-h-full flex-1 flex-col gap-6 px-3 py-6">
|
||||
{hasMoreMessages && onLoadMore && !forwardPaginated && (
|
||||
{hasMoreMessages && onLoadMore && (
|
||||
<LoadMoreSentinel
|
||||
hasMore={hasMoreMessages}
|
||||
isLoading={!!isLoadingMore}
|
||||
@@ -510,17 +489,6 @@ export function ChatMessagesContainer({
|
||||
</pre>
|
||||
</details>
|
||||
)}
|
||||
{hasMoreMessages && onLoadMore && forwardPaginated && (
|
||||
<LoadMoreSentinel
|
||||
hasMore={hasMoreMessages}
|
||||
isLoading={!!isLoadingMore}
|
||||
messageCount={messages.length}
|
||||
onLoadMore={onLoadMore}
|
||||
rootMargin="0px 0px 200px 0px"
|
||||
adjustScroll={false}
|
||||
forwardPaginated
|
||||
/>
|
||||
)}
|
||||
</ConversationContent>
|
||||
<ConversationScrollButton />
|
||||
</Conversation>
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { ChatMessagesContainer } from "../ChatMessagesContainer";
|
||||
|
||||
const mockScrollEl = {
|
||||
scrollHeight: 100,
|
||||
scrollTop: 0,
|
||||
clientHeight: 500,
|
||||
};
|
||||
|
||||
vi.mock("use-stick-to-bottom", () => ({
|
||||
useStickToBottomContext: () => ({ scrollRef: { current: mockScrollEl } }),
|
||||
Conversation: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
ConversationContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
ConversationScrollButton: () => null,
|
||||
}));
|
||||
|
||||
vi.mock("@/components/ai-elements/conversation", () => ({
|
||||
Conversation: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
ConversationContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
ConversationScrollButton: () => null,
|
||||
}));
|
||||
|
||||
vi.mock("@/components/ai-elements/message", () => ({
|
||||
Message: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
MessageContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
MessageActions: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock("../components/AssistantMessageActions", () => ({
|
||||
AssistantMessageActions: () => null,
|
||||
}));
|
||||
vi.mock("../components/CopyButton", () => ({ CopyButton: () => null }));
|
||||
vi.mock("../components/CollapsedToolGroup", () => ({
|
||||
CollapsedToolGroup: () => null,
|
||||
}));
|
||||
vi.mock("../components/MessageAttachments", () => ({
|
||||
MessageAttachments: () => null,
|
||||
}));
|
||||
vi.mock("../components/MessagePartRenderer", () => ({
|
||||
MessagePartRenderer: () => null,
|
||||
}));
|
||||
vi.mock("../components/ReasoningCollapse", () => ({
|
||||
ReasoningCollapse: () => null,
|
||||
}));
|
||||
vi.mock("../components/ThinkingIndicator", () => ({
|
||||
ThinkingIndicator: () => null,
|
||||
}));
|
||||
vi.mock("../../JobStatsBar/TurnStatsBar", () => ({
|
||||
TurnStatsBar: () => null,
|
||||
}));
|
||||
vi.mock("../../JobStatsBar/useElapsedTimer", () => ({
|
||||
useElapsedTimer: () => ({ elapsedSeconds: 0 }),
|
||||
}));
|
||||
vi.mock("../../CopilotPendingReviews/CopilotPendingReviews", () => ({
|
||||
CopilotPendingReviews: () => null,
|
||||
}));
|
||||
vi.mock("../helpers", () => ({
|
||||
buildRenderSegments: () => [],
|
||||
getTurnMessages: () => [],
|
||||
parseSpecialMarkers: () => ({ markerType: null }),
|
||||
splitReasoningAndResponse: (parts: unknown[]) => ({
|
||||
reasoningParts: [],
|
||||
responseParts: parts,
|
||||
}),
|
||||
}));
|
||||
|
||||
type ObserverCallback = (entries: { isIntersecting: boolean }[]) => void;
|
||||
class MockIntersectionObserver {
|
||||
static lastCallback: ObserverCallback | null = null;
|
||||
private callback: ObserverCallback;
|
||||
constructor(cb: ObserverCallback) {
|
||||
this.callback = cb;
|
||||
MockIntersectionObserver.lastCallback = cb;
|
||||
}
|
||||
observe() {}
|
||||
disconnect() {}
|
||||
unobserve() {}
|
||||
takeRecords() {
|
||||
return [];
|
||||
}
|
||||
root = null;
|
||||
rootMargin = "";
|
||||
thresholds = [];
|
||||
}
|
||||
|
||||
const BASE_PROPS = {
|
||||
messages: [],
|
||||
status: "ready" as const,
|
||||
error: undefined,
|
||||
isLoading: false,
|
||||
sessionID: "sess-1",
|
||||
hasMoreMessages: true,
|
||||
isLoadingMore: false,
|
||||
onLoadMore: vi.fn(),
|
||||
onRetry: vi.fn(),
|
||||
};
|
||||
|
||||
describe("ChatMessagesContainer", () => {
|
||||
beforeEach(() => {
|
||||
mockScrollEl.scrollHeight = 100;
|
||||
mockScrollEl.scrollTop = 0;
|
||||
mockScrollEl.clientHeight = 500;
|
||||
MockIntersectionObserver.lastCallback = null;
|
||||
vi.stubGlobal("IntersectionObserver", MockIntersectionObserver);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("renders top sentinel when forwardPaginated is false (backward pagination)", () => {
|
||||
render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={false} />);
|
||||
expect(
|
||||
screen.getByRole("button", { name: /load older messages/i }),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders top sentinel when forwardPaginated is undefined (default, backward)", () => {
|
||||
render(<ChatMessagesContainer {...BASE_PROPS} />);
|
||||
expect(
|
||||
screen.getByRole("button", { name: /load older messages/i }),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders bottom sentinel when forwardPaginated is true (forward pagination)", () => {
|
||||
render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={true} />);
|
||||
expect(
|
||||
screen.getByRole("button", { name: /load newer messages/i }),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("hides sentinel when hasMoreMessages is false", () => {
|
||||
render(
|
||||
<ChatMessagesContainer
|
||||
{...BASE_PROPS}
|
||||
hasMoreMessages={false}
|
||||
forwardPaginated={true}
|
||||
/>,
|
||||
);
|
||||
expect(
|
||||
screen.queryByRole("button", { name: /load older messages/i }),
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it("hides sentinel when onLoadMore is not provided", () => {
|
||||
render(
|
||||
<ChatMessagesContainer
|
||||
{...BASE_PROPS}
|
||||
onLoadMore={undefined}
|
||||
forwardPaginated={true}
|
||||
/>,
|
||||
);
|
||||
expect(
|
||||
screen.queryByRole("button", { name: /load older messages/i }),
|
||||
).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -172,36 +172,6 @@ describe("LoadMoreSentinel", () => {
|
||||
expect(mockScrollEl.scrollTop).toBe(200);
|
||||
});
|
||||
|
||||
it("does NOT adjust scroll when adjustScroll=false (forward pagination)", () => {
|
||||
mockScrollEl.scrollHeight = 100;
|
||||
mockScrollEl.scrollTop = 50;
|
||||
const onLoadMore = vi.fn();
|
||||
const { rerender } = render(
|
||||
<LoadMoreSentinel
|
||||
hasMore={true}
|
||||
isLoading={false}
|
||||
messageCount={5}
|
||||
onLoadMore={onLoadMore}
|
||||
adjustScroll={false}
|
||||
/>,
|
||||
);
|
||||
// Fire observer to capture snapshot.
|
||||
MockIntersectionObserver.lastCallback?.([{ isIntersecting: true }]);
|
||||
// Simulate DOM growing from appended newer messages (forward load-more).
|
||||
mockScrollEl.scrollHeight = 300;
|
||||
rerender(
|
||||
<LoadMoreSentinel
|
||||
hasMore={true}
|
||||
isLoading={false}
|
||||
messageCount={10}
|
||||
onLoadMore={onLoadMore}
|
||||
adjustScroll={false}
|
||||
/>,
|
||||
);
|
||||
// scrollTop should remain unchanged — no jump for forward pagination.
|
||||
expect(mockScrollEl.scrollTop).toBe(50);
|
||||
});
|
||||
|
||||
it("ignores same-frame duplicate triggers until isLoading transitions", () => {
|
||||
const onLoadMore = vi.fn();
|
||||
const { rerender } = render(
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { convertChatSessionMessagesToUiMessages } from "../convertChatSessionToUiMessages";
|
||||
|
||||
const SESSION_ID = "sess-test";
|
||||
|
||||
describe("convertChatSessionMessagesToUiMessages", () => {
|
||||
it("does not drop user messages with null content", () => {
|
||||
const result = convertChatSessionMessagesToUiMessages(
|
||||
SESSION_ID,
|
||||
[{ role: "user", content: null, sequence: 0 }],
|
||||
{ isComplete: true },
|
||||
);
|
||||
|
||||
expect(result.messages).toHaveLength(1);
|
||||
expect(result.messages[0].role).toBe("user");
|
||||
});
|
||||
|
||||
it("does not drop user messages with empty string content", () => {
|
||||
const result = convertChatSessionMessagesToUiMessages(
|
||||
SESSION_ID,
|
||||
[{ role: "user", content: "", sequence: 0 }],
|
||||
{ isComplete: true },
|
||||
);
|
||||
|
||||
expect(result.messages).toHaveLength(1);
|
||||
expect(result.messages[0].role).toBe("user");
|
||||
});
|
||||
|
||||
it("still drops non-user messages with null content", () => {
|
||||
const result = convertChatSessionMessagesToUiMessages(
|
||||
SESSION_ID,
|
||||
[{ role: "assistant", content: null, sequence: 0 }],
|
||||
{ isComplete: true },
|
||||
);
|
||||
|
||||
expect(result.messages).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("still drops non-user messages with empty string content", () => {
|
||||
const result = convertChatSessionMessagesToUiMessages(
|
||||
SESSION_ID,
|
||||
[{ role: "assistant", content: "", sequence: 0 }],
|
||||
{ isComplete: true },
|
||||
);
|
||||
|
||||
expect(result.messages).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("includes user message with normal content", () => {
|
||||
const result = convertChatSessionMessagesToUiMessages(
|
||||
SESSION_ID,
|
||||
[{ role: "user", content: "hello", sequence: 0 }],
|
||||
{ isComplete: true },
|
||||
);
|
||||
|
||||
expect(result.messages).toHaveLength(1);
|
||||
expect(result.messages[0].role).toBe("user");
|
||||
});
|
||||
});
|
||||
@@ -253,11 +253,6 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
}
|
||||
}
|
||||
|
||||
// User messages must always be rendered, even with empty content, so the
|
||||
// initial prompt is visible when reloading a session.
|
||||
if (parts.length === 0 && msg.role === "user") {
|
||||
parts.push({ type: "text", text: "", state: "done" });
|
||||
}
|
||||
if (parts.length === 0) return;
|
||||
|
||||
// Merge consecutive assistant messages into a single UIMessage
|
||||
|
||||
@@ -86,16 +86,6 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
|
||||
return sessionQuery.data.data.oldest_sequence ?? null;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
const newestSequence = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200) return null;
|
||||
return sessionQuery.data.data.newest_sequence ?? null;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
const forwardPaginated = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200) return false;
|
||||
return !!sessionQuery.data.data.forward_paginated;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
|
||||
// array reference every render. Re-derives only when query data changes.
|
||||
// When the session is complete (no active stream), mark dangling tool
|
||||
@@ -195,8 +185,6 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
|
||||
hasActiveStream,
|
||||
hasMoreMessages,
|
||||
oldestSequence,
|
||||
newestSequence,
|
||||
forwardPaginated,
|
||||
isLoadingSession: sessionQuery.isLoading,
|
||||
isSessionError: sessionQuery.isError,
|
||||
createSession,
|
||||
|
||||
@@ -56,8 +56,6 @@ export function useCopilotPage() {
|
||||
hasActiveStream,
|
||||
hasMoreMessages,
|
||||
oldestSequence,
|
||||
newestSequence,
|
||||
forwardPaginated,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
createSession,
|
||||
@@ -86,26 +84,18 @@ export function useCopilotPage() {
|
||||
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
|
||||
});
|
||||
|
||||
const { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged } =
|
||||
const { olderMessages, hasMore, isLoadingMore, loadMore } =
|
||||
useLoadMoreMessages({
|
||||
sessionId,
|
||||
initialOldestSequence: oldestSequence,
|
||||
initialNewestSequence: newestSequence,
|
||||
initialHasMore: hasMoreMessages,
|
||||
forwardPaginated,
|
||||
initialPageRawMessages: rawSessionMessages,
|
||||
});
|
||||
|
||||
// Combine paginated messages with current page messages, merging consecutive
|
||||
// assistant UIMessages at the page boundary so reasoning + response parts
|
||||
// stay in a single bubble.
|
||||
// Forward pagination (completed sessions): current page is the beginning,
|
||||
// paged messages are newer pages appended after.
|
||||
// Backward pagination (active sessions): paged messages are older history
|
||||
// prepended before the current page.
|
||||
const messages = forwardPaginated
|
||||
? concatWithAssistantMerge(currentMessages, pagedMessages)
|
||||
: concatWithAssistantMerge(pagedMessages, currentMessages);
|
||||
// Combine older (paginated) messages with current page messages,
|
||||
// merging consecutive assistant UIMessages at the page boundary so
|
||||
// reasoning + response parts stay in a single bubble.
|
||||
const messages = concatWithAssistantMerge(olderMessages, currentMessages);
|
||||
|
||||
useCopilotNotifications(sessionId);
|
||||
|
||||
@@ -180,23 +170,6 @@ export function useCopilotPage() {
|
||||
}
|
||||
}, [sessionId, pendingMessage, sendMessage]);
|
||||
|
||||
// --- Clear backward-paginated messages when session completes ---
|
||||
// When a session transitions from active (forwardPaginated=false) to complete
|
||||
// (forwardPaginated=true), any backward-paginated older messages would be
|
||||
// appended after currentMessages instead of before, causing chronological
|
||||
// disorder. Reset paged state so the completed session renders cleanly.
|
||||
const prevForwardPaginatedRef = useRef(forwardPaginated);
|
||||
useEffect(() => {
|
||||
if (
|
||||
!prevForwardPaginatedRef.current &&
|
||||
forwardPaginated &&
|
||||
pagedMessages.length > 0
|
||||
) {
|
||||
resetPaged();
|
||||
}
|
||||
prevForwardPaginatedRef.current = forwardPaginated;
|
||||
}, [forwardPaginated, pagedMessages.length, resetPaged]);
|
||||
|
||||
// --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
|
||||
useWorkflowImportAutoSubmit({
|
||||
createSession,
|
||||
@@ -278,15 +251,6 @@ export function useCopilotPage() {
|
||||
isUserStoppingRef.current = false;
|
||||
|
||||
if (sessionId) {
|
||||
// When continuing a completed session that had forward-paginated history
|
||||
// loaded, the paged messages would appear in wrong position relative to
|
||||
// the new streaming turn (pagedMessages are newer pages, so they'd end
|
||||
// up after the streaming turn). Reset paged state so ordering is correct
|
||||
// during streaming; the user can reload history afterward if needed.
|
||||
if (forwardPaginated && pagedMessages.length > 0) {
|
||||
resetPaged();
|
||||
}
|
||||
|
||||
if (files && files.length > 0) {
|
||||
setIsUploadingFiles(true);
|
||||
try {
|
||||
@@ -433,7 +397,6 @@ export function useCopilotPage() {
|
||||
hasMoreMessages: hasMore,
|
||||
isLoadingMore,
|
||||
loadMore,
|
||||
forwardPaginated,
|
||||
// Mobile drawer
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
|
||||
@@ -9,11 +9,7 @@ import {
|
||||
interface UseLoadMoreMessagesArgs {
|
||||
sessionId: string | null;
|
||||
initialOldestSequence: number | null;
|
||||
initialNewestSequence: number | null;
|
||||
initialHasMore: boolean;
|
||||
/** True when the initial page was loaded from sequence 0 forward (completed
|
||||
* sessions). False when loaded newest-first (active sessions). */
|
||||
forwardPaginated: boolean;
|
||||
/** Raw messages from the initial page, used for cross-page tool output matching. */
|
||||
initialPageRawMessages: unknown[];
|
||||
}
|
||||
@@ -24,21 +20,16 @@ const MAX_OLDER_MESSAGES = 2000;
|
||||
export function useLoadMoreMessages({
|
||||
sessionId,
|
||||
initialOldestSequence,
|
||||
initialNewestSequence,
|
||||
initialHasMore,
|
||||
forwardPaginated,
|
||||
initialPageRawMessages,
|
||||
}: UseLoadMoreMessagesArgs) {
|
||||
// Accumulated raw messages from all extra pages (ascending order).
|
||||
// Store accumulated raw messages from all older pages (in ascending order).
|
||||
// Re-converting them all together ensures tool outputs are matched across
|
||||
// inter-page boundaries.
|
||||
const [pagedRawMessages, setPagedRawMessages] = useState<unknown[]>([]);
|
||||
const [olderRawMessages, setOlderRawMessages] = useState<unknown[]>([]);
|
||||
const [oldestSequence, setOldestSequence] = useState<number | null>(
|
||||
initialOldestSequence,
|
||||
);
|
||||
const [newestSequence, setNewestSequence] = useState<number | null>(
|
||||
initialNewestSequence,
|
||||
);
|
||||
const [hasMore, setHasMore] = useState(initialHasMore);
|
||||
const [isLoadingMore, setIsLoadingMore] = useState(false);
|
||||
const isLoadingMoreRef = useRef(false);
|
||||
@@ -55,7 +46,7 @@ export function useLoadMoreMessages({
|
||||
// The parent's `initialOldestSequence` drifts forward every time the
|
||||
// session query refetches (e.g. after a stream completes — see
|
||||
// `useCopilotStream` invalidation on `streaming → ready`). If we
|
||||
// wiped `pagedRawMessages` every time that happened, users who had
|
||||
// wiped `olderRawMessages` every time that happened, users who had
|
||||
// scrolled back would lose their loaded history on each new turn and
|
||||
// subsequent `loadMore` calls would fetch messages that overlap with
|
||||
// the AI SDK's retained state in `currentMessages`, producing visible
|
||||
@@ -72,9 +63,8 @@ export function useLoadMoreMessages({
|
||||
// Session changed — full reset
|
||||
prevSessionIdRef.current = sessionId;
|
||||
prevInitialOldestRef.current = initialOldestSequence;
|
||||
setPagedRawMessages([]);
|
||||
setOlderRawMessages([]);
|
||||
setOldestSequence(initialOldestSequence);
|
||||
setNewestSequence(initialNewestSequence);
|
||||
setHasMore(initialHasMore);
|
||||
setIsLoadingMore(false);
|
||||
isLoadingMoreRef.current = false;
|
||||
@@ -85,64 +75,49 @@ export function useLoadMoreMessages({
|
||||
|
||||
prevInitialOldestRef.current = initialOldestSequence;
|
||||
|
||||
// If we haven't paged yet, mirror the parent so the first
|
||||
// If we haven't paged back yet, mirror the parent so the first
|
||||
// `loadMore` starts from the correct cursor.
|
||||
//
|
||||
// When paged messages exist (pagedRawMessages.length > 0) we intentionally
|
||||
// do NOT update `hasMore` or `newestSequence` from the parent. A parent
|
||||
// refetch (e.g. after a new turn completes) may carry a fresh
|
||||
// `initialHasMore=true` or a larger `initialNewestSequence`, but those
|
||||
// reflect the *initial* page window, not the forward-paged window we have
|
||||
// already advanced into. Overwriting the local cursor here would cause the
|
||||
// next `loadMore` to re-fetch pages we already have. The local cursor is
|
||||
// advanced correctly inside `loadMore` itself via `setNewestSequence`.
|
||||
if (pagedRawMessages.length === 0) {
|
||||
if (olderRawMessages.length === 0) {
|
||||
setOldestSequence(initialOldestSequence);
|
||||
// Only regress the forward cursor if we haven't paged ahead yet —
|
||||
// otherwise a parent refetch would reset a cursor we already advanced.
|
||||
setNewestSequence((prev) =>
|
||||
prev !== null && prev > (initialNewestSequence ?? -1)
|
||||
? prev
|
||||
: initialNewestSequence,
|
||||
);
|
||||
setHasMore(initialHasMore);
|
||||
}
|
||||
}, [sessionId, initialOldestSequence, initialNewestSequence, initialHasMore]);
|
||||
}, [sessionId, initialOldestSequence, initialHasMore]);
|
||||
|
||||
// Convert all accumulated raw messages in one pass so tool outputs
|
||||
// are matched across inter-page boundaries.
|
||||
// For backward pagination only: include initial page tool outputs so older
|
||||
// paged pages can match tool calls whose outputs landed in the initial page.
|
||||
// For forward pagination this is unnecessary — tool calls in newer paged
|
||||
// pages cannot have their outputs in the older initial page.
|
||||
const pagedMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
|
||||
// are matched across inter-page boundaries. Initial page tool outputs
|
||||
// are included via extraToolOutputs to handle the boundary between
|
||||
// the last older page and the initial/streaming page.
|
||||
const olderMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
|
||||
useMemo(() => {
|
||||
if (!sessionId || pagedRawMessages.length === 0) return [];
|
||||
if (!sessionId || olderRawMessages.length === 0) return [];
|
||||
const extraToolOutputs =
|
||||
!forwardPaginated && initialPageRawMessages.length > 0
|
||||
initialPageRawMessages.length > 0
|
||||
? extractToolOutputsFromRaw(initialPageRawMessages)
|
||||
: undefined;
|
||||
return convertChatSessionMessagesToUiMessages(
|
||||
sessionId,
|
||||
pagedRawMessages,
|
||||
olderRawMessages,
|
||||
{ isComplete: true, extraToolOutputs },
|
||||
).messages;
|
||||
}, [sessionId, pagedRawMessages, initialPageRawMessages, forwardPaginated]);
|
||||
}, [sessionId, olderRawMessages, initialPageRawMessages]);
|
||||
|
||||
async function loadMore() {
|
||||
if (!sessionId || !hasMore || isLoadingMoreRef.current) return;
|
||||
|
||||
const cursor = forwardPaginated ? newestSequence : oldestSequence;
|
||||
if (cursor === null) return;
|
||||
if (
|
||||
!sessionId ||
|
||||
!hasMore ||
|
||||
isLoadingMoreRef.current ||
|
||||
oldestSequence === null
|
||||
)
|
||||
return;
|
||||
|
||||
const requestEpoch = epochRef.current;
|
||||
isLoadingMoreRef.current = true;
|
||||
setIsLoadingMore(true);
|
||||
try {
|
||||
const params = forwardPaginated
|
||||
? { limit: 50, after_sequence: cursor }
|
||||
: { limit: 50, before_sequence: cursor };
|
||||
const response = await getV2GetSession(sessionId, params);
|
||||
const response = await getV2GetSession(sessionId, {
|
||||
limit: 50,
|
||||
before_sequence: oldestSequence,
|
||||
});
|
||||
|
||||
// Discard response if session/pagination was reset while awaiting
|
||||
if (epochRef.current !== requestEpoch) return;
|
||||
@@ -161,66 +136,18 @@ export function useLoadMoreMessages({
|
||||
consecutiveErrorsRef.current = 0;
|
||||
|
||||
const newRaw = (response.data.messages ?? []) as unknown[];
|
||||
// Estimate total after merge using the closure-captured pagedRawMessages.length.
|
||||
// This is a safe approximation: worst case it's one page stale (one extra load
|
||||
// allowed), but it avoids the React-18-batching pitfall where a functional
|
||||
// updater's mutations are not visible until the next render.
|
||||
const estimatedTotal = pagedRawMessages.length + newRaw.length;
|
||||
setPagedRawMessages((prev) => {
|
||||
// Forward: append to end. Backward: prepend to start.
|
||||
const merged = forwardPaginated
|
||||
? [...prev, ...newRaw]
|
||||
: [...newRaw, ...prev];
|
||||
setOlderRawMessages((prev) => {
|
||||
const merged = [...newRaw, ...prev];
|
||||
if (merged.length > MAX_OLDER_MESSAGES) {
|
||||
// Backward: discard the oldest (front) items — user has scrolled far
|
||||
// back and we shed the furthest history.
|
||||
// Forward: discard the newest (tail) items — we only ever fetch
|
||||
// forward, so the tail is the most recently appended page; shedding
|
||||
// it means the sentinel stalls, which is safer than discarding the
|
||||
// beginning of the conversation the user is here to read.
|
||||
return forwardPaginated
|
||||
? merged.slice(0, MAX_OLDER_MESSAGES)
|
||||
: merged.slice(merged.length - MAX_OLDER_MESSAGES);
|
||||
return merged.slice(merged.length - MAX_OLDER_MESSAGES);
|
||||
}
|
||||
return merged;
|
||||
});
|
||||
|
||||
if (forwardPaginated) {
|
||||
const willTruncateForward = estimatedTotal > MAX_OLDER_MESSAGES;
|
||||
if (willTruncateForward) {
|
||||
// Truncation shed the newest tail. Advance the cursor to the last KEPT
|
||||
// item's sequence so the sentinel re-fetches the discarded items next
|
||||
// time rather than jumping past them.
|
||||
// lastKeptIdx: index within newRaw of the last item that survives.
|
||||
// prev contributes pagedRawMessages.length items; total kept = MAX.
|
||||
const lastKeptIdx = MAX_OLDER_MESSAGES - 1 - pagedRawMessages.length;
|
||||
if (lastKeptIdx >= 0 && lastKeptIdx < newRaw.length) {
|
||||
const lastKeptMsg = newRaw[lastKeptIdx] as { sequence?: number };
|
||||
if (typeof lastKeptMsg?.sequence === "number") {
|
||||
setNewestSequence(lastKeptMsg.sequence);
|
||||
setHasMore(true); // Discarded items still exist — keep sentinel active
|
||||
} else {
|
||||
// Sequence unavailable — fall back; truncated items will be lost
|
||||
setNewestSequence(response.data.newest_sequence ?? null);
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
}
|
||||
} else {
|
||||
// All of newRaw was dropped (already at MAX_OLDER_MESSAGES cap).
|
||||
// Stop to avoid an infinite re-fetch loop at the display cap.
|
||||
setHasMore(false);
|
||||
}
|
||||
} else {
|
||||
setNewestSequence(response.data.newest_sequence ?? null);
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
}
|
||||
setOldestSequence(response.data.oldest_sequence ?? null);
|
||||
if (newRaw.length + olderRawMessages.length >= MAX_OLDER_MESSAGES) {
|
||||
setHasMore(false);
|
||||
} else {
|
||||
setOldestSequence(response.data.oldest_sequence ?? null);
|
||||
if (estimatedTotal >= MAX_OLDER_MESSAGES) {
|
||||
// Backward: accumulated MAX_OLDER_MESSAGES — stop to avoid unbounded memory.
|
||||
setHasMore(false);
|
||||
} else {
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
}
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
}
|
||||
} catch (error) {
|
||||
if (epochRef.current !== requestEpoch) return;
|
||||
@@ -237,22 +164,5 @@ export function useLoadMoreMessages({
|
||||
}
|
||||
}
|
||||
|
||||
function resetPaged() {
|
||||
setPagedRawMessages([]);
|
||||
setOldestSequence(initialOldestSequence);
|
||||
setNewestSequence(initialNewestSequence);
|
||||
// Set hasMore=false during the session-transition window so no loadMore
|
||||
// fires with forward pagination (after_sequence) on the now-active session.
|
||||
// The useEffect will restore hasMore from the parent after the refetch
|
||||
// completes and forwardPaginated switches to false.
|
||||
setHasMore(false);
|
||||
// Clear the loading state so the spinner doesn't stay stuck if a loadMore
|
||||
// was in flight when resetPaged was called.
|
||||
setIsLoadingMore(false);
|
||||
isLoadingMoreRef.current = false;
|
||||
consecutiveErrorsRef.current = 0;
|
||||
epochRef.current += 1;
|
||||
}
|
||||
|
||||
return { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged };
|
||||
return { olderMessages, hasMore, isLoadingMore, loadMore };
|
||||
}
|
||||
|
||||
@@ -1498,7 +1498,7 @@
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Session",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nSupports cursor-based pagination via ``limit``, ``before_sequence``, and\n``after_sequence``. The two cursor parameters are mutually exclusive.\n\nOn the initial load (no cursor provided) of a completed session, messages\nare returned in forward order starting from sequence 0 so the user always\nsees their initial prompt. Active sessions use the legacy newest-first\norder so streaming context is preserved.",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nSupports cursor-based pagination via ``limit`` and ``before_sequence``.\nWhen no pagination params are provided, returns the most recent messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The authenticated user's ID.\n limit: Maximum number of messages to return (1-200, default 50).\n before_sequence: Return messages with sequence < this value (cursor).\n\nReturns:\n SessionDetailResponse: Details for the requested session, including\n active_stream info and pagination metadata.",
|
||||
"operationId": "getV2GetSession",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1516,11 +1516,9 @@
|
||||
"type": "integer",
|
||||
"maximum": 200,
|
||||
"minimum": 1,
|
||||
"description": "Maximum number of messages to return.",
|
||||
"default": 50,
|
||||
"title": "Limit"
|
||||
},
|
||||
"description": "Maximum number of messages to return."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "before_sequence",
|
||||
@@ -1531,24 +1529,8 @@
|
||||
{ "type": "integer", "minimum": 0 },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"description": "Backward pagination cursor. Return messages with sequence number strictly less than this value. Used by active-session load-more. Mutually exclusive with after_sequence.",
|
||||
"title": "Before Sequence"
|
||||
},
|
||||
"description": "Backward pagination cursor. Return messages with sequence number strictly less than this value. Used by active-session load-more. Mutually exclusive with after_sequence."
|
||||
},
|
||||
{
|
||||
"name": "after_sequence",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{ "type": "integer", "minimum": 0 },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"description": "Forward pagination cursor. Return messages with sequence number strictly greater than this value. Used by completed-session load-more. Mutually exclusive with before_sequence.",
|
||||
"title": "After Sequence"
|
||||
},
|
||||
"description": "Forward pagination cursor. Return messages with sequence number strictly greater than this value. Used by completed-session load-more. Mutually exclusive with before_sequence."
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -13291,15 +13273,6 @@
|
||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||
"title": "Oldest Sequence"
|
||||
},
|
||||
"newest_sequence": {
|
||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||
"title": "Newest Sequence"
|
||||
},
|
||||
"forward_paginated": {
|
||||
"type": "boolean",
|
||||
"title": "Forward Paginated",
|
||||
"default": false
|
||||
},
|
||||
"total_prompt_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Prompt Tokens",
|
||||
|
||||
Reference in New Issue
Block a user