Compare commits

..

1 Commits

Author SHA1 Message Date
Zamil Majdy
b410bdd6e0 fix(backend/copilot): patch toolCalls DB row when flush races ahead of StreamToolInputAvailable
The intermediate flush introduced in #12604 is append-only: rows already in
the DB are never re-saved.  When a flush fires between StreamTextDelta and
StreamToolInputAvailable, the assistant row is written with toolCalls=null
and the tool calls that arrive later are only applied in-memory — the DB
row is never updated.  The frontend silently drops tool calls when the column
is null, making them invisible in the UI.

Fix: after dispatching each adapter_responses batch, if StreamToolInputAvailable
was present AND acc.assistant_response.sequence is already set (flush happened),
issue a targeted UPDATE via update_message_tool_calls() to patch toolCalls on
the existing row.  asyncio.shield() keeps the patch from being cancelled on
GeneratorExit.

Adds regression tests in session_persistence_test.py covering the text→flush→
tool-input sequence.
2026-04-16 20:48:16 +07:00
26 changed files with 836 additions and 2448 deletions

1
.gitignore vendored
View File

@@ -195,3 +195,4 @@ test.db
# Implementation plans (generated by AI agents)
plans/
.claude/worktrees/
test-results/

View File

@@ -18,6 +18,7 @@ from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.message_dedup import acquire_dedup_lock
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -190,8 +191,6 @@ class SessionDetailResponse(BaseModel):
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
has_more_messages: bool = False
oldest_sequence: int | None = None
newest_sequence: int | None = None
forward_paginated: bool = False
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
@@ -456,113 +455,52 @@ async def update_session_title_route(
async def get_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(
default=50,
ge=1,
le=200,
description="Maximum number of messages to return.",
),
before_sequence: int | None = Query(
default=None,
ge=0,
description=(
"Backward pagination cursor. Return messages with sequence number "
"strictly less than this value. Used by active-session load-more. "
"Mutually exclusive with after_sequence."
),
),
after_sequence: int | None = Query(
default=None,
ge=0,
description=(
"Forward pagination cursor. Return messages with sequence number "
"strictly greater than this value. Used by completed-session load-more. "
"Mutually exclusive with before_sequence."
),
),
limit: int = Query(default=50, ge=1, le=200),
before_sequence: int | None = Query(default=None, ge=0),
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
Supports cursor-based pagination via ``limit``, ``before_sequence``, and
``after_sequence``. The two cursor parameters are mutually exclusive.
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
When no pagination params are provided, returns the most recent messages.
On the initial load (no cursor provided) of a completed session, messages
are returned in forward order starting from sequence 0 so the user always
sees their initial prompt. Active sessions use the legacy newest-first
order so streaming context is preserved.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The authenticated user's ID.
limit: Maximum number of messages to return (1-200, default 50).
before_sequence: Return messages with sequence < this value (cursor).
Returns:
SessionDetailResponse: Details for the requested session, including
active_stream info and pagination metadata.
"""
if before_sequence is not None and after_sequence is not None:
raise HTTPException(
status_code=400,
detail="before_sequence and after_sequence are mutually exclusive",
)
is_initial_load = before_sequence is None and after_sequence is None
# Check active stream before the DB query on initial loads so we can
# choose the correct pagination direction (forward for completed sessions,
# newest-first for active ones).
active_session = None
last_message_id = None
if is_initial_load:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
# Completed sessions on initial load start from sequence 0 so the user's
# initial prompt is always visible. Active sessions keep the legacy
# newest-first behavior to preserve streaming context.
from_start = is_initial_load and active_session is None
forward_paginated = from_start or after_sequence is not None
page = await get_chat_messages_paginated(
session_id,
limit,
before_sequence=before_sequence,
after_sequence=after_sequence,
from_start=from_start,
user_id=user_id,
session_id, limit, before_sequence, user_id=user_id
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
# Close the TOCTOU window: if the session was active at pre-check, re-verify
# after the DB fetch. The session may have completed between the two awaits,
# which would have caused messages to be fetched newest-first even though the
# session is now complete. Re-fetch from seq 0 so the initial prompt is
# always visible.
if is_initial_load and active_session is not None:
post_active, _ = await stream_registry.get_active_session(session_id, user_id)
if post_active is None:
active_session = None
last_message_id = None
from_start = True
forward_paginated = True
page = await get_chat_messages_paginated(
session_id,
limit,
before_sequence=None,
after_sequence=None,
from_start=True,
user_id=user_id,
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
messages = [
_strip_injected_context(message.model_dump()) for message in page.messages
]
# Only check active stream on initial load (not on "load more" requests)
active_stream_info = None
if active_session and last_message_id is not None:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
if before_sequence is None:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
# Skip session metadata on "load more" — frontend only needs messages
if not is_initial_load:
if before_sequence is not None:
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
@@ -572,8 +510,6 @@ async def get_session(
active_stream=None,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
newest_sequence=page.newest_sequence,
forward_paginated=forward_paginated,
total_prompt_tokens=0,
total_completion_tokens=0,
)
@@ -590,8 +526,6 @@ async def get_session(
active_stream=active_stream_info,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
newest_sequence=page.newest_sequence,
forward_paginated=forward_paginated,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=page.session.metadata,
@@ -912,6 +846,9 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
# Capture the original message text BEFORE any mutation (attachment enrichment)
# so the idempotency hash is stable across retries.
original_message = request.message
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
@@ -940,36 +877,58 @@ async def stream_chat_post(
)
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message returns None when a duplicate is
# detected — in that case skip enqueue to avoid processing the message twice.
is_duplicate_message = False
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
# ── Idempotency guard ────────────────────────────────────────────────────
# Blocks duplicate executor tasks from concurrent/retried POSTs.
# See backend/copilot/message_dedup.py for the full lifecycle description.
dedup_lock = None
if request.is_user_message:
dedup_lock = await acquire_dedup_lock(
session_id, original_message, sanitized_file_ids
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
is_duplicate_message = (
await append_and_save_message(session_id, message)
) is None
logger.info(f"[STREAM] User message saved for session {session_id}")
if not is_duplicate_message and request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
if dedup_lock is None and (original_message or sanitized_file_ids):
async def _empty_sse() -> AsyncGenerator[str, None]:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return StreamingResponse(
_empty_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
"x-vercel-ai-ui-message-stream": "v1",
},
)
# Create a task in the stream registry for reconnection support.
# For duplicate messages, skip create_session entirely so the infra-retry
# client subscribes to the *existing* turn's Redis stream and receives the
# in-progress executor output rather than an empty stream.
turn_id = ""
if not is_duplicate_message:
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
#
# If any of these operations raises, release the dedup lock before propagating
# so subsequent retries are not blocked for 30 s.
try:
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
@@ -987,6 +946,7 @@ async def stream_chat_post(
}
},
)
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
@@ -998,10 +958,10 @@ async def stream_chat_post(
mode=request.mode,
model=request.model,
)
else:
logger.info(
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
)
except Exception:
if dedup_lock:
await dedup_lock.release()
raise
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
@@ -1025,6 +985,12 @@ async def stream_chat_post(
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
# True for every exit path except GeneratorExit (client disconnect).
# On disconnect the backend turn is still running — releasing the lock
# there would reopen the infra-retry duplicate window. The 30 s TTL
# is the fallback. All other exits (normal finish, early return, error)
# should release so the user can re-send the same message.
release_dedup_lock_on_exit = True
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
@@ -1036,7 +1002,7 @@ async def stream_chat_post(
if subscriber_queue is None:
yield StreamFinish().to_sse()
return
return # finally releases dedup_lock
# Read from the subscriber queue and yield to SSE
logger.info(
@@ -1078,7 +1044,7 @@ async def stream_chat_post(
}
},
)
break
break # finally releases dedup_lock
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
@@ -1094,6 +1060,7 @@ async def stream_chat_post(
}
},
)
release_dedup_lock_on_exit = False
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
@@ -1108,7 +1075,10 @@ async def stream_chat_post(
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
# finally releases dedup_lock
finally:
if dedup_lock and release_dedup_lock_on_exit:
await dedup_lock.release()
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:

View File

@@ -133,12 +133,21 @@ def test_stream_chat_rejects_too_many_file_ids():
assert response.status_code == 422
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
def _mock_stream_internals(
mocker: pytest_mock.MockerFixture,
*,
redis_set_returns: object = True,
):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing RabbitMQ.
validation and enrichment logic without needing Redis/RabbitMQ.
Args:
redis_set_returns: Value returned by the mocked Redis ``set`` call.
``True`` (default) simulates a fresh key (new message);
``None`` simulates a collision (duplicate blocked).
Returns:
A namespace with ``save`` and ``enqueue`` mock objects so
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
callers can make additional assertions about side-effects.
"""
import types
@@ -149,7 +158,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
)
mock_save = mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = mocker.AsyncMock(return_value=None)
@@ -165,9 +174,15 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
return types.SimpleNamespace(
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=redis_set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
return ns
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
@@ -196,29 +211,6 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
assert response.status_code == 200
# ─── Duplicate message dedup ──────────────────────────────────────────
def test_stream_chat_skips_enqueue_for_duplicate_message(
mocker: pytest_mock.MockerFixture,
):
"""When append_and_save_message returns None (duplicate detected),
enqueue_copilot_turn and stream_registry.create_session must NOT be called
to avoid double-processing and to prevent overwriting the active stream's
turn_id in Redis (which would cause reconnecting clients to miss the response)."""
mocks = _mock_stream_internals(mocker)
# Override save to return None — signalling a duplicate
mocks.save.return_value = None
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 200
mocks.enqueue.assert_not_called()
mocks.registry.create_session.assert_not_called()
# ─── UUID format filtering ─────────────────────────────────────────────
@@ -714,6 +706,237 @@ class TestStripInjectedContext:
assert result["content"] == "hello"
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Duplicate POST inside the 30-s dedup window gets an empty SSE stream.

    The mocked Redis ``set`` returns ``None`` (the NX key already exists), so
    the route is expected to answer with StreamFinish + ``[DONE]``, carry the
    AI SDK stream header, and perform no save/enqueue side effects.
    """
    # ``None`` from the mocked ``set`` stands for an NX collision.
    mocks = _mock_stream_internals(mocker, redis_set_returns=None)

    payload = {"message": "duplicate message", "is_user_message": True}
    response = client.post("/sessions/sess-dup/stream", json=payload)

    assert response.status_code == 200

    # Stream body: a finish event followed by the SSE terminator.
    sse_text = response.text
    assert '"finish"' in sse_text
    assert "[DONE]" in sse_text

    # The protocol header lets the frontend treat the empty stream as valid
    # and mark the turn complete.
    assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"

    # Nothing may be persisted or queued for a blocked duplicate.
    mocks.save.assert_not_called()
    mocks.enqueue.assert_not_called()
def test_stream_chat_first_post_proceeds_normally(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """First POST (mocked Redis ``set`` succeeds) takes the normal path.

    Verifies the request returns 200 and that the dedup key was written
    exactly once with ``nx=True``.
    """
    mocks = _mock_stream_internals(mocker, redis_set_returns=True)

    response = client.post(
        "/sessions/sess-new/stream",
        json={"message": "first message", "is_user_message": True},
    )
    assert response.status_code == 200

    # Exactly one Redis ``set`` call, made with the NX flag.
    redis_set = mocks.redis.set
    redis_set.assert_called_once()
    recorded_call = redis_set.call_args
    assert recorded_call.kwargs.get("nx") is True
def test_stream_chat_dedup_skipped_for_non_user_messages(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Messages sent with ``is_user_message=False`` skip the dedup guard.

    The mocked Redis ``set`` is configured to return ``None`` (which would
    block a user message); the request must still succeed and Redis ``set``
    must never be called.
    """
    mocks = _mock_stream_internals(mocker, redis_set_returns=None)

    payload = {"message": "system context", "is_user_message": False}
    response = client.post("/sessions/sess-sys/stream", json=payload)

    # The collision-simulating mock is irrelevant here: non-user messages
    # never reach the dedup lock.
    assert response.status_code == 200
    mocks.redis.set.assert_not_called()
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """The dedup key must hash the original message, not the enriched one.

    A ``file_id`` is sent and the workspace/prisma lookups are mocked so the
    route appends its ``[Attached files]`` block to ``request.message``. The
    key passed to Redis ``set`` must still be derived from the unmodified
    text plus the file ID.
    """
    import hashlib

    ns = _mock_stream_internals(mocker, redis_set_returns=True)

    file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    # Mock workspace + prisma so the attachment block is actually appended.
    mocker.patch(
        "backend.api.features.chat.routes.get_or_create_workspace",
        return_value=type("W", (), {"id": "ws-1"})(),
    )
    fake_file = type(
        "F",
        (),
        {
            "id": file_id,
            "name": "doc.pdf",
            "mimeType": "application/pdf",
            "sizeBytes": 1024,
        },
    )()
    mock_prisma = mocker.MagicMock()
    mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
    mocker.patch(
        "prisma.models.UserWorkspaceFile.prisma",
        return_value=mock_prisma,
    )

    response = client.post(
        "/sessions/sess-hash/stream",
        json={
            "message": "plain message",
            "is_user_message": True,
            "file_ids": [file_id],
        },
    )
    assert response.status_code == 200

    ns.redis.set.assert_called_once()
    call_args = ns.redis.set.call_args
    dedup_key = call_args.args[0]

    # Expected key: first 16 hex chars of sha256("<session>:<message>:<file ids>"),
    # built from the original message text rather than the mutated one.
    expected_hash = hashlib.sha256(
        f"sess-hash:plain message:{file_id}".encode()
    ).hexdigest()[:16]
    expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
    # The two message fragments are adjacent literals; the trailing "; "
    # keeps them from running together in the failure output.
    assert dedup_key == expected_key, (
        f"Dedup key {dedup_key!r} does not match expected {expected_key!r}; "
        "hash may be using mutated message or wrong inputs"
    )
def test_stream_chat_dedup_key_released_after_stream_finish(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """The dedup Redis key is deleted once the stream finishes.

    ``subscribe_to_session`` is mocked to return ``None``, which drives the
    route down its early-finish path (StreamFinish yielded immediately). The
    key must then be released so the user can re-send the same message.
    """
    from unittest.mock import AsyncMock as _AsyncMock

    # Internals are patched by hand (instead of via the shared helper) so
    # that subscribe_to_session can be controlled.
    for route_attr in (
        "_validate_and_get_session",
        "append_and_save_message",
        "enqueue_copilot_turn",
        "track_user_message",
    ):
        mocker.patch(
            f"backend.api.features.chat.routes.{route_attr}",
            return_value=None,
        )

    registry_stub = mocker.MagicMock()
    registry_stub.create_session = _AsyncMock(return_value=None)
    # None → early-finish path: StreamFinish yielded immediately, dedup key released.
    registry_stub.subscribe_to_session = _AsyncMock(return_value=None)
    mocker.patch(
        "backend.api.features.chat.routes.stream_registry",
        registry_stub,
    )

    redis_stub = mocker.AsyncMock()
    redis_stub.set = _AsyncMock(return_value=True)
    mocker.patch(
        "backend.copilot.message_dedup.get_redis_async",
        new_callable=_AsyncMock,
        return_value=redis_stub,
    )

    response = client.post(
        "/sessions/sess-finish/stream",
        json={"message": "hello", "is_user_message": True},
    )

    assert response.status_code == 200
    assert '"finish"' in response.text
    # Releasing the key is what allows an intentional re-send.
    redis_stub.delete.assert_called_once()
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A failing Redis ``delete`` must not crash the early-finish path.

    ``subscribe_to_session`` returns ``None`` (early finish) and the mocked
    ``delete`` raises ``RuntimeError``; the response must still be a 200
    containing a finish event, with ``delete`` attempted exactly once.
    NOTE(review): the error is presumably swallowed inside the dedup lock's
    release (``except Exception: pass``) — confirm in message_dedup.
    """
    from unittest.mock import AsyncMock as _AsyncMock

    for route_attr in (
        "_validate_and_get_session",
        "append_and_save_message",
        "enqueue_copilot_turn",
        "track_user_message",
    ):
        mocker.patch(
            f"backend.api.features.chat.routes.{route_attr}",
            return_value=None,
        )

    registry_stub = mocker.MagicMock()
    registry_stub.create_session = _AsyncMock(return_value=None)
    registry_stub.subscribe_to_session = _AsyncMock(return_value=None)
    mocker.patch(
        "backend.api.features.chat.routes.stream_registry",
        registry_stub,
    )

    redis_stub = mocker.AsyncMock()
    redis_stub.set = _AsyncMock(return_value=True)
    # Force the release step to hit its error-handling branch.
    redis_stub.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
    mocker.patch(
        "backend.copilot.message_dedup.get_redis_async",
        new_callable=_AsyncMock,
        return_value=redis_stub,
    )

    # The request itself must succeed despite the delete failure.
    response = client.post(
        "/sessions/sess-finish-err/stream",
        json={"message": "hello", "is_user_message": True},
    )

    assert response.status_code == 200
    assert '"finish"' in response.text
    # The delete was attempted once; its failure was silenced downstream.
    redis_stub.delete.assert_called_once()
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
@@ -757,146 +980,3 @@ def test_disconnect_stream_returns_404_when_session_missing(
assert response.status_code == 404
mock_disconnect.assert_not_awaited()
# ─── GET /sessions/{session_id} — forward/backward pagination ──────────────────
def _make_paginated_messages(
mocker: pytest_mock.MockerFixture, *, has_more: bool = False
):
"""Return a mock PaginatedMessages and configure the DB patch."""
from datetime import UTC, datetime
from backend.copilot.db import PaginatedMessages
from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata
now = datetime.now(UTC)
session_info = ChatSessionInfo(
session_id="sess-1",
user_id=TEST_USER_ID,
usage=[],
started_at=now,
updated_at=now,
metadata=ChatSessionMetadata(),
)
page = PaginatedMessages(
messages=[ChatMessage(role="user", content="hello", sequence=0)],
has_more=has_more,
oldest_sequence=0,
newest_sequence=0,
session=session_info,
)
mock_paginate = mocker.patch(
"backend.api.features.chat.routes.get_chat_messages_paginated",
new_callable=AsyncMock,
return_value=page,
)
return page, mock_paginate
def test_get_session_completed_returns_forward_paginated(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Completed sessions (no active stream) return forward_paginated=True."""
_make_paginated_messages(mocker)
mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
new_callable=AsyncMock,
return_value=(None, None),
)
response = client.get("/sessions/sess-1")
assert response.status_code == 200
data = response.json()
assert data["forward_paginated"] is True
assert data["newest_sequence"] == 0
def test_get_session_active_returns_backward_paginated(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Active sessions (with running stream) return forward_paginated=False."""
from backend.copilot.stream_registry import ActiveSession
_make_paginated_messages(mocker)
active = MagicMock(spec=ActiveSession)
active.turn_id = "turn-1"
mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
new_callable=AsyncMock,
return_value=(active, "msg-1"),
)
response = client.get("/sessions/sess-1")
assert response.status_code == 200
data = response.json()
assert data["forward_paginated"] is False
assert data["active_stream"] is not None
assert data["active_stream"]["turn_id"] == "turn-1"
def test_get_session_after_sequence_returns_forward_paginated(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""after_sequence param returns forward_paginated=True; no stream check needed."""
_, mock_paginate = _make_paginated_messages(mocker)
response = client.get("/sessions/sess-1?after_sequence=10")
assert response.status_code == 200
data = response.json()
assert data["forward_paginated"] is True
call_kwargs = mock_paginate.call_args
assert call_kwargs.kwargs.get("after_sequence") == 10
assert call_kwargs.kwargs.get("before_sequence") is None
def test_get_session_both_cursors_returns_400(
test_user_id: str,
) -> None:
"""Sending both before_sequence and after_sequence returns 400."""
response = client.get("/sessions/sess-1?before_sequence=5&after_sequence=10")
assert response.status_code == 400
def test_get_session_toctou_refetch_when_session_completes_mid_request(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Race condition: session was active at pre-check but completes before DB fetch.
The route should detect the race via a post-fetch re-check, then re-fetch
from seq 0 so the initial prompt is always visible.
"""
from backend.copilot.stream_registry import ActiveSession
page, mock_paginate = _make_paginated_messages(mocker)
active = MagicMock(spec=ActiveSession)
active.turn_id = "turn-1"
# First call: session appears active. Second call: session has completed.
mock_get_active = mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
new_callable=AsyncMock,
side_effect=[(active, "msg-1"), (None, None)],
)
response = client.get("/sessions/sess-1")
assert response.status_code == 200
data = response.json()
# Post-race: session is now completed → forward_paginated=True, no stream
assert data["forward_paginated"] is True
assert data["active_stream"] is None
# The DB was queried twice: once newest-first, once from_start=True
assert mock_paginate.call_count == 2
assert mock_get_active.call_count == 2
second_call = mock_paginate.call_args_list[1]
assert second_call.kwargs.get("from_start") is True

View File

@@ -10,11 +10,9 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
ChatMessageCreateInput,
ChatMessageWhereInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
ChatSessionWhereInput,
FindManyChatMessageArgsFromChatSession,
)
from pydantic import BaseModel
@@ -32,8 +30,6 @@ from .model import get_chat_session as get_chat_session_cached
logger = logging.getLogger(__name__)
_BOUNDARY_SCAN_LIMIT = 10
class PaginatedMessages(BaseModel):
"""Result of a paginated message query."""
@@ -41,7 +37,6 @@ class PaginatedMessages(BaseModel):
messages: list[ChatMessage]
has_more: bool
oldest_sequence: int | None
newest_sequence: int | None
session: ChatSessionInfo
@@ -66,48 +61,32 @@ async def get_chat_messages_paginated(
session_id: str,
limit: int = 50,
before_sequence: int | None = None,
after_sequence: int | None = None,
from_start: bool = False,
user_id: str | None = None,
) -> PaginatedMessages | None:
"""Get paginated messages for a session.
"""Get paginated messages for a session, newest first.
Three modes:
Verifies session existence (and ownership when ``user_id`` is provided)
in parallel with the message query. Returns ``None`` when the session
is not found or does not belong to the user.
- ``before_sequence`` set: backward pagination (DESC), returns messages
with sequence < ``before_sequence``. Used for active sessions or manual
backward navigation.
- ``from_start=True`` or ``after_sequence`` set: forward pagination (ASC).
Returns messages from sequence 0 (``from_start``) or after
``after_sequence``. Used on initial load of completed sessions and for
loading subsequent forward pages.
- Both cursors ``None`` and ``from_start=False``: newest-first (DESC
without filter). Used for active sessions on initial load.
Verifies session existence (and ownership when ``user_id`` is provided).
Returns ``None`` when the session is not found or does not belong to the
user.
Args:
session_id: The chat session ID.
limit: Max messages to return.
before_sequence: Cursor — return messages with sequence < this value.
user_id: If provided, filters via ``Session.userId`` so only the
session owner's messages are returned (acts as an ownership guard).
"""
# Build session-existence / ownership check
session_where: ChatSessionWhereInput = {"id": session_id}
if user_id is not None:
session_where["userId"] = user_id
forward = from_start or after_sequence is not None
# Build message include — fetch paginated messages in the same query.
# Note: when both from_start=True and after_sequence is not None, the
# after_sequence filter takes precedence (the elif branch below is skipped).
# This combination is not reachable via the HTTP route (mutual exclusion is
# enforced there), so we rely on the documented priority here without an
# additional assertion.
msg_include: FindManyChatMessageArgsFromChatSession = {
"order_by": {"sequence": "asc" if forward else "desc"},
# Build message include — fetch paginated messages in the same query
msg_include: dict[str, Any] = {
"order_by": {"sequence": "desc"},
"take": limit + 1,
}
if after_sequence is not None:
msg_include["where"] = {"sequence": {"gt": after_sequence}}
elif before_sequence is not None:
if before_sequence is not None:
msg_include["where"] = {"sequence": {"lt": before_sequence}}
# Single query: session existence/ownership + paginated messages
@@ -125,96 +104,57 @@ async def get_chat_messages_paginated(
has_more = len(results) > limit
results = results[:limit]
if not forward:
# Backward mode: DB returned DESC; reverse to ascending order.
results.reverse()
# Reverse to ascending order
results.reverse()
# Tool-call boundary fix: if the oldest message is a tool message,
# expand backward to include the preceding assistant message that
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
# can pair them correctly.
if results and results[0].role == "tool":
boundary_where: ChatMessageWhereInput = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
# Tool-call boundary fix: if the oldest message is a tool message,
# expand backward to include the preceding assistant message that
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
# can pair them correctly.
_BOUNDARY_SCAN_LIMIT = 10
if results and results[0].role == "tool":
boundary_where: dict[str, Any] = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
if boundary_msgs:
results = boundary_msgs + results
# Only mark has_more if the expanded boundary isn't the
# very start of the conversation (sequence 0).
if boundary_msgs[0].sequence > 0:
has_more = True
else:
# Forward mode: DB returned ASC.
# Tool-call tail boundary fix: if the last message in this page is a
# tool message, the NEXT forward page would start after it and begin
# mid-tool-group — the owning assistant message is on this page but
# the following tool results are on the next page.
# Trim the current page so it ends on the owning assistant message,
# which keeps tool groups intact across page boundaries.
if results and results[-1].role == "tool":
# Walk backward through results to find the last non-tool message.
trim_idx = len(results) - 1
while trim_idx >= 0 and results[trim_idx].role == "tool":
trim_idx -= 1
if trim_idx >= 0:
# Trim results so the page ends at the owning assistant.
# Mark has_more=True so the client knows to fetch the rest.
results = results[: trim_idx + 1]
if boundary_msgs:
results = boundary_msgs + results
# Only mark has_more if the expanded boundary isn't the
# very start of the conversation (sequence 0).
if boundary_msgs[0].sequence > 0:
has_more = True
else:
# Entire page is tool messages with no visible owner — log and
# keep as-is so the caller is not stuck with an empty page.
logger.warning(
"Forward tail boundary: entire page is tool messages "
"for session=%s, no owning assistant found (%d msgs)",
session_id,
len(results),
)
messages = [ChatMessage.from_db(m) for m in results]
# oldest_sequence is only meaningful in backward mode (used as backward
# pagination cursor). In forward mode the page always starts near seq 0
# and clients should use newest_sequence as the forward cursor instead.
# Return None in forward mode so clients don't accidentally treat it as a
# backward cursor on a forward-paginated session.
oldest_sequence = messages[0].sequence if (messages and not forward) else None
# newest_sequence is only meaningful in forward mode; in backward mode it
# points to the last message of the page (not the session's newest message)
# which is not a valid forward cursor. Return None in backward mode so
# clients don't accidentally use it as one.
newest_sequence = messages[-1].sequence if (messages and forward) else None
oldest_sequence = messages[0].sequence if messages else None
return PaginatedMessages(
messages=messages,
has_more=has_more,
oldest_sequence=oldest_sequence,
newest_sequence=newest_sequence,
session=session_info,
)
@@ -608,6 +548,46 @@ async def update_message_content_by_sequence(
return False
async def update_message_tool_calls(
    session_id: str,
    sequence: int,
    tool_calls: list[dict],
) -> bool:
    """Patch the toolCalls column of an already-saved assistant message.

    Called when StreamToolInputAvailable arrives after an intermediate flush
    saved the assistant message with tool_calls=None.  The DB save is
    append-only (uses get_next_sequence), so the already-persisted row must
    be updated in-place to reflect the tool_calls that arrived later.

    This is best-effort: database errors are logged (with traceback) and
    reported as ``False`` instead of being raised to the caller.

    Args:
        session_id: The chat session ID.
        sequence: The 0-based sequence number of the assistant message to patch.
        tool_calls: The full list of tool call dicts to set on the row.

    Returns:
        True if the row was found and updated, False if no row matched or the
        update raised.
    """
    # Keep the try body down to the single call that can actually fail.
    try:
        updated = await PrismaChatMessage.prisma().update_many(
            where={"sessionId": session_id, "sequence": sequence},
            data={"toolCalls": SafeJson(tool_calls)},
        )
    except Exception as e:
        # exc_info=True preserves the traceback, which the plain message lost.
        logger.error(
            "update_message_tool_calls failed for session %s, sequence %s: %s",
            session_id,
            sequence,
            e,
            exc_info=True,
        )
        return False

    if updated == 0:
        # Nothing matched (session_id, sequence): the row was never flushed
        # or the caller passed a stale sequence.
        logger.warning(
            "update_message_tool_calls: no row found for session %s, sequence %s",
            session_id,
            sequence,
        )
        return False
    return True
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
"""Set durationMs on the last assistant message in a session.

View File

@@ -175,187 +175,6 @@ async def test_no_where_on_messages_without_before_sequence(
assert "where" not in include["Messages"]
# ---------- Forward pagination (from_start / after_sequence) ----------
@pytest.mark.asyncio
async def test_from_start_uses_asc_order_no_where(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""from_start=True queries messages in ASC order with no where filter."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
)
await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
call_kwargs = find_first.call_args
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
assert include["Messages"]["order_by"] == {"sequence": "asc"}
assert "where" not in include["Messages"]
@pytest.mark.asyncio
async def test_from_start_returns_messages_ascending(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""from_start=True returns messages in ascending sequence order."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
assert page is not None
assert [m.sequence for m in page.messages] == [0, 1, 2]
assert (
page.oldest_sequence is None
) # None in forward mode — not a valid backward cursor
assert page.newest_sequence == 2
assert page.has_more is False
@pytest.mark.asyncio
async def test_from_start_has_more_when_results_exceed_limit(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""from_start=True sets has_more when DB returns more than limit items."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=2, from_start=True)
assert page is not None
assert page.has_more is True
assert [m.sequence for m in page.messages] == [0, 1]
assert page.newest_sequence == 1
@pytest.mark.asyncio
async def test_after_sequence_uses_gt_filter_asc_order(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""after_sequence adds a sequence > N where clause and uses ASC order."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(11), _make_msg(12)],
)
await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
call_kwargs = find_first.call_args
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
assert include["Messages"]["order_by"] == {"sequence": "asc"}
assert include["Messages"]["where"] == {"sequence": {"gt": 10}}
@pytest.mark.asyncio
async def test_after_sequence_returns_messages_in_order(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""after_sequence returns only messages with sequence > cursor, ascending."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(11), _make_msg(12), _make_msg(13)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
assert page is not None
assert [m.sequence for m in page.messages] == [11, 12, 13]
assert (
page.oldest_sequence is None
) # None in forward mode — not a valid backward cursor
assert page.newest_sequence == 13
assert page.has_more is False
@pytest.mark.asyncio
async def test_newest_sequence_none_for_backward_mode(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""newest_sequence is None in backward mode — it is not a valid forward cursor."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(5), _make_msg(4), _make_msg(3)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is not None
assert page.newest_sequence is None
assert page.oldest_sequence == 3
@pytest.mark.asyncio
async def test_forward_mode_no_boundary_expansion(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Forward pagination never triggers backward boundary expansion."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(0, role="tool"), _make_msg(1, role="tool")],
)
await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
assert find_many.call_count == 0
@pytest.mark.asyncio
async def test_forward_tail_boundary_trims_trailing_tool_messages(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Forward pages that end with tool messages are trimmed to the owning
assistant so the next after_sequence page doesn't start mid-tool-group."""
find_first, _ = mock_db
# DB returns 4 messages ASC: assistant at 0, tool at 1, tool at 2, tool at 3
find_first.return_value = _make_session(
messages=[
_make_msg(0, role="assistant"),
_make_msg(1, role="tool"),
_make_msg(2, role="tool"),
_make_msg(3, role="tool"),
],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
assert page is not None
# Page should be trimmed to end at the assistant message
assert [m.sequence for m in page.messages] == [0]
assert page.newest_sequence == 0
# has_more must be True so the client fetches the tool messages on next page
assert page.has_more is True
@pytest.mark.asyncio
async def test_forward_tail_boundary_no_trim_when_last_not_tool(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Forward pages that end with a non-tool message are not trimmed."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[
_make_msg(0, role="user"),
_make_msg(1, role="assistant"),
_make_msg(2, role="tool"),
_make_msg(3, role="assistant"),
],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
assert page is not None
assert [m.sequence for m in page.messages] == [0, 1, 2, 3]
assert page.newest_sequence == 3
assert page.has_more is False
@pytest.mark.asyncio
async def test_user_id_filter_applied_to_session_where(
mock_db: tuple[AsyncMock, AsyncMock],

View File

@@ -0,0 +1,71 @@
"""Per-request idempotency lock for the /stream endpoint.
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
Lifecycle
---------
1. ``acquire_dedup_lock()`` — computes a stable hash of (session_id, message, file_ids)
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
``None`` when the key already exists (duplicate request).
2. ``release()`` — deletes the key. Must be called on turn completion or turn
error so the next legitimate send is never blocked.
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
the backend turn is still running, and releasing would reopen the duplicate
window for infra-level retries. The 30 s TTL is the safety net.
"""
import hashlib
import logging
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_KEY_PREFIX = "chat:msg_dedup"
_TTL_SECONDS = 30
class _DedupLock:
    """Handle for a dedup key that this request successfully acquired.

    Holds the Redis key and the client used to set it so the owner can delete
    the key again via ``release()``.
    """

    def __init__(self, key: str, redis) -> None:
        self._key = key
        self._redis = redis

    async def release(self) -> None:
        """Delete the dedup key (best effort).

        Never raises: a failed delete is logged at debug level and the key is
        left to expire through the TTL it was created with.
        """
        try:
            await self._redis.delete(self._key)
        except Exception:
            # Deliberately non-fatal, but leave a trace instead of hiding it.
            logger.debug(
                "Failed to release dedup lock %s; relying on TTL expiry",
                self._key,
                exc_info=True,
            )
def _content_hash(
    session_id: str,
    message: str | None,
    file_ids: list[str] | None,
) -> str:
    """Return a 16-hex-char digest identifying (session, message, files)."""
    digest = hashlib.sha256()
    # File IDs are sorted so the digest does not depend on their order.
    for part in (session_id, message or "", *sorted(file_ids or [])):
        encoded = part.encode()
        # Length-prefix every field: a plain ":"-joined string lets different
        # tuples collide, e.g. ("a", ["b", "c"]) and ("a:b", ["c"]).
        digest.update(len(encoded).to_bytes(8, "big"))
        digest.update(encoded)
    return digest.hexdigest()[:16]


async def acquire_dedup_lock(
    session_id: str,
    message: str | None,
    file_ids: list[str] | None,
) -> _DedupLock | None:
    """Acquire the idempotency lock for this (session, message, files) tuple.

    The lock is a Redis key set with NX and a ``_TTL_SECONDS`` expiry, so only
    the first of several identical requests obtains it.

    Args:
        session_id: The chat session the message belongs to.
        message: The user message text, if any.
        file_ids: IDs of attached files, if any; order is irrelevant.

    Returns:
        A ``_DedupLock`` when the lock is freshly acquired (first request).
        ``None`` when a duplicate is detected (lock already held).
        ``None`` when there is nothing to deduplicate (no message, no files).
    """
    if not message and not file_ids:
        return None

    content_hash = _content_hash(session_id, message, file_ids)
    key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"

    redis = await get_redis_async()
    acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
    if not acquired:
        logger.warning(
            "[STREAM] Duplicate user message blocked for session %s, "
            "hash=%s — returning empty SSE",
            session_id,
            content_hash,
        )
        return None
    return _DedupLock(key, redis)

View File

@@ -0,0 +1,94 @@
"""Unit tests for backend.copilot.message_dedup."""
from unittest.mock import AsyncMock
import pytest
import pytest_mock
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
    """Swap ``get_redis_async`` in the dedup module for a fake client.

    The fake's ``set`` coroutine resolves to ``set_returns``; the fake itself
    is handed back so tests can inspect the calls made on it.
    """
    redis_double = AsyncMock()
    redis_double.set = AsyncMock(return_value=set_returns)
    patch_options = {"new_callable": AsyncMock, "return_value": redis_double}
    mocker.patch("backend.copilot.message_dedup.get_redis_async", **patch_options)
    return redis_double
@pytest.mark.asyncio
async def test_acquire_returns_none_when_no_message_no_files(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """With neither a message nor files, None is returned and Redis SET is never issued."""
    redis_double = _patch_redis(mocker, set_returns=True)

    outcome = await acquire_dedup_lock("sess-1", None, None)

    assert outcome is None
    redis_double.set.assert_not_called()
@pytest.mark.asyncio
async def test_acquire_returns_lock_on_first_request(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """First request acquires the lock and returns a _DedupLock.

    Also pins the Redis contract the dedup relies on: the key is namespaced by
    prefix and session, and SET is issued with NX plus an expiry.
    """
    mock_redis = _patch_redis(mocker, set_returns=True)
    lock = await acquire_dedup_lock("sess-1", "hello", None)
    assert lock is not None
    mock_redis.set.assert_called_once()
    set_call = mock_redis.set.call_args
    assert set_call.args[0].startswith(f"{_KEY_PREFIX}:sess-1:")
    # Without NX the second request would overwrite instead of being rejected.
    assert set_call.kwargs.get("nx") is True
    # Without an expiry a missed release() would block the message forever.
    assert set_call.kwargs.get("ex")
@pytest.mark.asyncio
async def test_acquire_returns_none_on_duplicate(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """When the NX set reports failure, the caller gets None (duplicate)."""
    # A SET ... NX that finds the key already present resolves to None.
    _patch_redis(mocker, set_returns=None)

    outcome = await acquire_dedup_lock("sess-1", "hello", None)

    assert outcome is None
@pytest.mark.asyncio
async def test_acquire_key_stable_across_file_order(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """The same file IDs in a different order must produce the same Redis key."""
    observed_keys = []
    for ordering in (["b", "a"], ["a", "b"]):
        # Fresh fake per call so each call_args reflects exactly one acquire.
        redis_double = _patch_redis(mocker, set_returns=True)
        await acquire_dedup_lock("sess-1", "msg", ordering)
        observed_keys.append(redis_double.set.call_args.args[0])

    key_ab, key_ba = observed_keys
    assert key_ab == key_ba
@pytest.mark.asyncio
async def test_release_deletes_key(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """release() deletes exactly the key that acquire set, exactly once."""
    mock_redis = _patch_redis(mocker, set_returns=True)
    lock = await acquire_dedup_lock("sess-1", "hello", None)
    assert lock is not None
    acquired_key = mock_redis.set.call_args.args[0]
    await lock.release()
    # Checking the argument guards against release() deleting the wrong key.
    mock_redis.delete.assert_called_once_with(acquired_key)
@pytest.mark.asyncio
async def test_release_swallows_redis_error(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A failing Redis delete must not propagate out of release()."""
    redis_double = _patch_redis(mocker, set_returns=True)
    failing_delete = AsyncMock(side_effect=RuntimeError("redis down"))
    redis_double.delete = failing_delete

    held = await acquire_dedup_lock("sess-1", "hello", None)
    assert held is not None

    # The call below would raise if release() let the error escape.
    await held.release()
    failing_delete.assert_called_once()

View File

@@ -1,8 +1,9 @@
import asyncio
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from typing import Any, AsyncIterator, Self, cast
from typing import Any, Self, cast
from weakref import WeakValueDictionary
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
@@ -521,7 +522,10 @@ async def upsert_chat_session(
callers are aware of the persistence failure.
RedisError: If the cache write fails (after successful DB write).
"""
async with _get_session_lock(session.session_id) as _:
# Acquire session-specific lock to prevent concurrent upserts
lock = await _get_session_lock(session.session_id)
async with lock:
# Always query DB for existing message count to ensure consistency
existing_message_count = await chat_db().get_next_sequence(session.session_id)
@@ -647,50 +651,20 @@ async def _save_session_to_db(
msg.sequence = existing_message_count + i
async def append_and_save_message(
session_id: str, message: ChatMessage
) -> ChatSession | None:
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Returns the updated session, or None if the message was detected as a
duplicate (idempotency guard). Callers must check for None and skip any
downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected.
Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas.
The idempotency check below provides a last-resort guard when the lock degrades.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
async with _get_session_lock(session_id) as lock_acquired:
# When the lock degraded (Redis down or 2s timeout), bypass cache for
# the idempotency check. Stale cache could let two concurrent writers
# both see the old state, pass the check, and write the same message.
if lock_acquired:
session = await get_chat_session(session_id)
else:
session = await _get_session_from_db(session_id)
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
# Idempotency: skip if the trailing block of same-role messages already
# contains this content. Uses is_message_duplicate which checks all
# consecutive trailing messages of the same role, not just [-1].
#
# This collapses infra/nginx retries whether they land on the same pod
# (serialised by the Redis lock) or a different pod.
#
# Legit same-text messages are distinguished by the assistant turn
# between them: if the user said "yes", got a response, and says
# "yes" again, session.messages[-1] is the assistant reply, so the
# role check fails and the second message goes through normally.
#
# Edge case: if a turn dies without writing any assistant message,
# the user's next send of the same text is blocked here permanently.
# The fix is to ensure failed turns always write an error/timeout
# assistant message so the session always ends on an assistant turn.
if message.content is not None and is_message_duplicate(
session.messages, message.role, message.content
):
return None # duplicate — caller should skip enqueue
session.messages.append(message)
existing_message_count = await chat_db().get_next_sequence(session_id)
@@ -705,9 +679,6 @@ async def append_and_save_message(
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
# Invalidate the stale entry so future reads fall back to DB,
# preventing a retry from bypassing the idempotency check above.
await invalidate_session_cache(session_id)
return session
@@ -793,6 +764,10 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
except Exception as e:
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
@@ -857,38 +832,25 @@ async def update_session_title(
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
@asynccontextmanager
async def _get_session_lock(session_id: str) -> AsyncIterator[bool]:
"""Distributed Redis lock for a session, usable as an async context manager.
Yields True if the lock was acquired, False if it timed out or Redis was
unavailable. Callers should treat False as a degraded mode and prefer fresh
DB reads over cache to avoid acting on stale state.
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition
is atomic and release is owner-verified. Blocks up to 2s for a concurrent
writer to finish; the 10s TTL ensures a dead pod never holds the lock forever.
This was originally added to solve the specific problem of race conditions between
the session title thread and the conversation thread, which always occurs on the
same instance as we prevent rapid request sends on the frontend.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
"""
_lock_key = f"copilot:session_lock:{session_id}"
lock = None
acquired = False
try:
_redis = await get_redis_async()
lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2)
acquired = await lock.acquire(blocking=True)
if not acquired:
logger.warning(
"Could not acquire session lock for %s within 2s", session_id
)
except Exception as e:
logger.warning("Redis unavailable for session lock on %s: %s", session_id, e)
try:
yield acquired
finally:
if acquired and lock is not None:
try:
await lock.release()
except Exception:
pass # TTL will expire the key
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock

View File

@@ -11,13 +11,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from pytest_mock import MockerFixture
from .model import (
ChatMessage,
ChatSession,
Usage,
append_and_save_message,
get_chat_session,
is_message_duplicate,
maybe_append_user_message,
@@ -576,345 +574,3 @@ def test_maybe_append_assistant_skips_duplicate():
result = maybe_append_user_message(session, "dup", is_user_message=False)
assert result is False
assert len(session.messages) == 2
# --------------------------------------------------------------------------- #
# append_and_save_message #
# --------------------------------------------------------------------------- #
def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession:
s = ChatSession.new(user_id="u1", dry_run=False)
s.messages = list(msgs)
return s
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_returns_none_for_duplicate(
mocker: MockerFixture,
) -> None:
"""append_and_save_message returns None when the trailing message is a duplicate."""
session = _make_session_with_messages(
ChatMessage(role="user", content="hello"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
result = await append_and_save_message(
session.session_id, ChatMessage(role="user", content="hello")
)
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_appends_new_message(
mocker: MockerFixture,
) -> None:
"""append_and_save_message appends a non-duplicate message and returns the session."""
session = _make_session_with_messages(
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=2)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="second message")
result = await append_and_save_message(session.session_id, new_msg)
assert result is not None
assert result.messages[-1].content == "second message"
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_when_session_not_found(
mocker: MockerFixture,
) -> None:
"""append_and_save_message raises ValueError when the session does not exist."""
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=None,
)
with pytest.raises(ValueError, match="not found"):
await append_and_save_message(
"missing-session-id", ChatMessage(role="user", content="hi")
)
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_lock_degraded(
mocker: MockerFixture,
) -> None:
"""When the Redis lock times out (acquired=False), the fallback reads from DB."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=False)
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mock_get_from_db = mocker.patch(
"backend.copilot.model._get_session_from_db",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
# DB path was used (not cache-first)
mock_get_from_db.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_database_error_on_save_failure(
mocker: MockerFixture,
) -> None:
"""When _save_session_to_db fails, append_and_save_message raises DatabaseError."""
from backend.util.exceptions import DatabaseError
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
side_effect=RuntimeError("db down"),
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
with pytest.raises(DatabaseError):
await append_and_save_message(
session.session_id, ChatMessage(role="user", content="new msg")
)
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_invalidates_cache_on_cache_failure(
mocker: MockerFixture,
) -> None:
"""When cache_chat_session fails, invalidate_session_cache is called to avoid stale reads."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
side_effect=RuntimeError("redis write failed"),
)
mock_invalidate = mocker.patch(
"backend.copilot.model.invalidate_session_cache",
new_callable=mocker.AsyncMock,
)
result = await append_and_save_message(
session.session_id, ChatMessage(role="user", content="new msg")
)
# DB write succeeded, cache invalidation was called
mock_invalidate.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_redis_unavailable(
    mocker: MockerFixture,
) -> None:
    """With Redis unreachable the session must be loaded straight from the DB.

    ``get_redis_async`` is stubbed to raise ``ConnectionError``; the test
    expects ``_get_session_from_db`` to be called exactly once with the
    session id and the append to return a non-``None`` value.
    """
    session = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )

    # Redis is down: obtaining a client fails outright.
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        side_effect=ConnectionError("redis down"),
    )
    db_loader = mocker.patch(
        "backend.copilot.model._get_session_from_db",
        new_callable=mocker.AsyncMock,
        return_value=session,
    )

    # Persistence and caching succeed silently.
    for target in (
        "backend.copilot.model._save_session_to_db",
        "backend.copilot.model.cache_chat_session",
    ):
        mocker.patch(target, new_callable=mocker.AsyncMock)

    chat_db_stub = mocker.MagicMock(
        get_next_sequence=mocker.AsyncMock(return_value=1),
    )
    mocker.patch("backend.copilot.model.chat_db", return_value=chat_db_stub)

    incoming = ChatMessage(role="user", content="new msg")
    outcome = await append_and_save_message(session.session_id, incoming)

    db_loader.assert_called_once_with(session.session_id)
    assert outcome is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_lock_release_failure_is_ignored(
    mocker: MockerFixture,
) -> None:
    """An error raised by ``lock.release()`` must not surface to the caller.

    The lock is acquirable but its ``release`` raises ``RuntimeError``; the
    append is still expected to complete and return a non-``None`` value.
    """
    session = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )

    # Lock that can be taken but blows up when released.
    lock = mocker.AsyncMock()
    lock.acquire = mocker.AsyncMock(return_value=True)
    lock.release = mocker.AsyncMock(side_effect=RuntimeError("release failed"))
    redis_client = mocker.MagicMock(lock=mocker.MagicMock(return_value=lock))

    async_stubs = {
        "backend.copilot.model.get_redis_async": {"return_value": redis_client},
        "backend.copilot.model.get_chat_session": {"return_value": session},
        "backend.copilot.model._save_session_to_db": {},
        "backend.copilot.model.cache_chat_session": {},
    }
    for target, config in async_stubs.items():
        mocker.patch(target, new_callable=mocker.AsyncMock, **config)

    chat_db_stub = mocker.MagicMock(
        get_next_sequence=mocker.AsyncMock(return_value=1),
    )
    mocker.patch("backend.copilot.model.chat_db", return_value=chat_db_stub)

    incoming = ChatMessage(role="user", content="new msg")
    outcome = await append_and_save_message(session.session_id, incoming)

    assert outcome is not None

View File

@@ -44,6 +44,7 @@ from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings
from ..config import ChatConfig, CopilotLlmModel, CopilotMode
from ..db import update_message_tool_calls
from ..constants import (
COPILOT_ERROR_PREFIX,
COPILOT_RETRYABLE_ERROR_PREFIX,
@@ -2262,6 +2263,30 @@ async def _run_stream_attempt(
if dispatched is not None:
yield dispatched
# If tool calls arrived this batch AND the assistant message was
# already flushed to DB (sequence is set), patch the existing row
# so tool_calls are not lost. The append-only save (start_sequence)
# in _save_session_to_db never re-saves already-persisted rows, so
# without this patch the assistant row keeps tool_calls=null.
if acc.assistant_response.sequence is not None and any(
isinstance(r, StreamToolInputAvailable) for r in adapter_responses
):
try:
await asyncio.shield(
update_message_tool_calls(
ctx.session.session_id,
acc.assistant_response.sequence,
acc.accumulated_tool_calls,
)
)
except Exception as patch_err:
logger.warning(
"%s tool_calls DB patch failed (sequence=%d): %s",
ctx.log_prefix,
acc.assistant_response.sequence,
patch_err,
)
# Append assistant entry AFTER convert_message so that
# any stashed tool results from the previous turn are
# recorded first, preserving the required API order:

View File

@@ -20,7 +20,11 @@ from datetime import datetime, timezone
from unittest.mock import MagicMock
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import StreamStartStep, StreamTextDelta
from backend.copilot.response_model import (
StreamStartStep,
StreamTextDelta,
StreamToolInputAvailable,
)
from backend.copilot.sdk.service import _dispatch_response, _StreamAccumulator
_NOW = datetime(2024, 1, 1, tzinfo=timezone.utc)
@@ -215,3 +219,100 @@ class TestPreCreateAssistantMessage:
_simulate_pre_create(acc, ctx)
assert len(ctx.session.messages) == 0
class TestToolCallsLostAfterIntermediateFlush:
    """Regression tests for tool calls dropped after an intermediate flush.

    Scenario that used to lose data:

    1. ``StreamTextDelta`` arrives: the assistant message is appended with
       ``tool_calls=None``.
    2. An intermediate flush (time/count threshold) writes that row with
       ``tool_calls=null`` and back-fills ``acc.assistant_response.sequence``.
    3. ``StreamToolInputAvailable`` arrives: ``tool_calls`` is only mutated
       on the in-memory message.
    4. The final save is append-only, so the already-persisted row is never
       rewritten and its ``tool_calls`` stay null.

    Expected fix: when tool input arrives while the sequence is already set,
    the caller issues a DB UPDATE patching ``toolCalls`` on the existing row.
    """

    def test_text_delta_then_tool_input_sets_tool_calls_on_message(self) -> None:
        """Tool input after text populates in-memory ``tool_calls``, flushed or not."""
        session = _make_session()
        ctx = _make_ctx(session)
        state = _make_state()
        acc = _StreamAccumulator(
            assistant_response=ChatMessage(role="assistant", content=""),
            accumulated_tool_calls=[],
        )

        # Text arrives first and causes the assistant message to be appended.
        text_event = StreamTextDelta(id="t1", delta="Let me run that for you.")
        _dispatch_response(text_event, acc, ctx, state, False, "[test]")
        assert acc.has_appended_assistant
        assert session.messages[-1].tool_calls is None

        # Pretend an intermediate flush ran: _save_session_to_db back-fills this.
        acc.assistant_response.sequence = 1

        # Tool input shows up only after the (simulated) flush.
        tool_event = StreamToolInputAvailable(
            toolCallId="call_abc",
            toolName="bash_exec",
            input={"command": "ls"},
        )
        _dispatch_response(tool_event, acc, ctx, state, False, "[test]")

        # The in-memory message must now carry exactly that one tool call.
        recorded = acc.assistant_response.tool_calls
        assert recorded is not None
        assert len(recorded) == 1
        assert recorded[0]["id"] == "call_abc"

    def test_sequence_set_when_flush_occurred_before_tool_input(self) -> None:
        """A prior flush is detectable (sequence set) once tool calls arrive."""
        acc = _StreamAccumulator(
            assistant_response=ChatMessage(role="assistant", content="hello"),
            accumulated_tool_calls=[],
            has_appended_assistant=True,
        )
        # Simulate the sequence back-fill performed by a flush.
        acc.assistant_response.sequence = 3

        ctx = _make_ctx()
        state = _make_state()
        tool_event = StreamToolInputAvailable(
            toolCallId="call_xyz",
            toolName="run_block",
            input={},
        )
        _dispatch_response(tool_event, acc, ctx, state, False, "[test]")

        # This is the condition the caller checks before patching the DB row.
        already_flushed = acc.assistant_response.sequence is not None
        has_tool_calls = bool(acc.accumulated_tool_calls)
        needs_db_patch = already_flushed and has_tool_calls
        assert (
            needs_db_patch
        ), "Expected needs_db_patch=True when flush happened before tool calls arrived"

View File

@@ -423,33 +423,20 @@ async def subscribe_to_session(
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
# RACE CONDITION FIX: If session not found, retry with backoff.
# Duplicate requests skip create_session and subscribe immediately; the
# original request's create_session (a Redis hset) may not have completed
# yet. 3 × 100ms gives a 300ms window which covers DB-write latency on the
# original request before the hset even starts.
# RACE CONDITION FIX: If session not found, retry once after small delay
# This handles the case where subscribe_to_session is called immediately
# after create_session but before Redis propagates the write
if not meta:
_max_retries = 3
_retry_delay = 0.1 # 100ms per attempt
for attempt in range(_max_retries):
logger.warning(
f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), "
f"retrying after {int(_retry_delay * 1000)}ms",
extra={"json_fields": {**log_meta, "attempt": attempt + 1}},
)
await asyncio.sleep(_retry_delay)
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
logger.info(
f"[TIMING] Session found after {attempt + 1} retries",
extra={"json_fields": {**log_meta, "attempts": attempt + 1}},
)
break
else:
logger.warning(
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
extra={"json_fields": {**log_meta}},
)
await asyncio.sleep(0.05) # 50ms
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Session still not found in Redis after {_max_retries} retries "
f"({elapsed:.1f}ms total)",
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
extra={
"json_fields": {
**log_meta,
@@ -459,6 +446,10 @@ async def subscribe_to_session(
},
)
return None
logger.info(
"[TIMING] Session found after retry",
extra={"json_fields": {**log_meta}},
)
# Note: Redis client uses decode_responses=True, so keys are strings
session_status = meta.get("status", "")

View File

@@ -93,7 +93,6 @@ export function CopilotPage() {
hasMoreMessages,
isLoadingMore,
loadMore,
forwardPaginated,
// Mobile drawer
isMobile,
isDrawerOpen,
@@ -218,7 +217,6 @@ export function CopilotPage() {
hasMoreMessages={hasMoreMessages}
isLoadingMore={isLoadingMore}
onLoadMore={loadMore}
forwardPaginated={forwardPaginated}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
historicalDurations={historicalDurations}

View File

@@ -1,122 +0,0 @@
import { renderHook } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useChatSession } from "../useChatSession";
const mockUseGetV2GetSession = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useGetV2GetSession: (...args: unknown[]) => mockUseGetV2GetSession(...args),
usePostV2CreateSession: () => ({ mutateAsync: vi.fn(), isPending: false }),
getGetV2GetSessionQueryKey: (id: string) => ["session", id],
getGetV2ListSessionsQueryKey: () => ["sessions"],
}));
vi.mock("@tanstack/react-query", () => ({
useQueryClient: () => ({
invalidateQueries: vi.fn(),
setQueryData: vi.fn(),
}),
}));
vi.mock("nuqs", () => ({
parseAsString: { withDefault: (v: unknown) => v },
useQueryState: () => ["sess-1", vi.fn()],
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
convertChatSessionMessagesToUiMessages: vi.fn(() => ({
messages: [],
historicalDurations: new Map(),
})),
}));
vi.mock("../helpers", () => ({
resolveSessionDryRun: vi.fn(() => false),
}));
vi.mock("@sentry/nextjs", () => ({
captureException: vi.fn(),
}));
/**
 * Builds the object the mocked `useGetV2GetSession` hook returns.
 * A non-null `data` is wrapped as a status-200 response; `null` yields
 * `data: undefined`. All loading/error flags are false.
 */
function makeQueryResult(data: object | null) {
  const response = data ? { status: 200, data } : undefined;
  return {
    data: response,
    isLoading: false,
    isError: false,
    isFetching: false,
    refetch: vi.fn(),
  };
}
describe("useChatSession — newestSequence and forwardPaginated", () => {
  // Stubs the session query with `data`, renders the hook and hands back
  // the live `result` ref so each case can read `result.current`.
  function renderWithSession(data: object | null) {
    mockUseGetV2GetSession.mockReturnValue(makeQueryResult(data));
    return renderHook(() => useChatSession()).result;
  }

  beforeEach(() => {
    vi.clearAllMocks();
  });

  it("returns null / false when no session data", () => {
    const result = renderWithSession(null);

    expect(result.current.newestSequence).toBeNull();
    expect(result.current.forwardPaginated).toBe(false);
  });

  it("returns newestSequence from session data", () => {
    const result = renderWithSession({
      messages: [],
      has_more_messages: true,
      oldest_sequence: 0,
      newest_sequence: 99,
      forward_paginated: false,
      active_stream: null,
    });

    expect(result.current.newestSequence).toBe(99);
  });

  it("returns null for newestSequence when field is missing", () => {
    const result = renderWithSession({
      messages: [],
      has_more_messages: false,
      oldest_sequence: 0,
      newest_sequence: null,
      forward_paginated: false,
      active_stream: null,
    });

    expect(result.current.newestSequence).toBeNull();
  });

  it("returns forwardPaginated=true when session is forward-paginated", () => {
    const result = renderWithSession({
      messages: [],
      has_more_messages: true,
      oldest_sequence: 0,
      newest_sequence: 49,
      forward_paginated: true,
      active_stream: null,
    });

    expect(result.current.forwardPaginated).toBe(true);
  });

  it("returns forwardPaginated=false when session is backward-paginated", () => {
    const result = renderWithSession({
      messages: [],
      has_more_messages: true,
      oldest_sequence: 50,
      newest_sequence: 99,
      forward_paginated: false,
      active_stream: null,
    });

    expect(result.current.forwardPaginated).toBe(false);
  });
});

View File

@@ -1,202 +0,0 @@
import { act, renderHook, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useCopilotPage } from "../useCopilotPage";
const mockUseChatSession = vi.fn();
const mockUseCopilotStream = vi.fn();
const mockUseLoadMoreMessages = vi.fn();
vi.mock("../useChatSession", () => ({
useChatSession: (...args: unknown[]) => mockUseChatSession(...args),
}));
vi.mock("../useCopilotStream", () => ({
useCopilotStream: (...args: unknown[]) => mockUseCopilotStream(...args),
}));
vi.mock("../useLoadMoreMessages", () => ({
useLoadMoreMessages: (...args: unknown[]) => mockUseLoadMoreMessages(...args),
}));
vi.mock("../useCopilotNotifications", () => ({
useCopilotNotifications: () => undefined,
}));
vi.mock("../useWorkflowImportAutoSubmit", () => ({
useWorkflowImportAutoSubmit: () => undefined,
}));
vi.mock("../store", () => ({
useCopilotUIStore: () => ({
sessionToDelete: null,
setSessionToDelete: vi.fn(),
isDrawerOpen: false,
setDrawerOpen: vi.fn(),
copilotChatMode: "chat",
copilotLlmModel: null,
isDryRun: false,
}),
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
concatWithAssistantMerge: (a: unknown[], b: unknown[]) => [...a, ...b],
}));
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useDeleteV2DeleteSession: () => ({ mutate: vi.fn(), isPending: false }),
useGetV2ListSessions: () => ({ data: undefined, isLoading: false }),
getGetV2ListSessionsQueryKey: () => ["sessions"],
}));
vi.mock("@/components/molecules/Toast/use-toast", () => ({
toast: vi.fn(),
}));
vi.mock("@/lib/direct-upload", () => ({
uploadFileDirect: vi.fn(),
}));
vi.mock("@/lib/hooks/useBreakpoint", () => ({
useBreakpoint: () => "lg",
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: () => ({ isUserLoading: false, isLoggedIn: true }),
}));
vi.mock("@tanstack/react-query", () => ({
useQueryClient: () => ({ invalidateQueries: vi.fn() }),
}));
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" },
useGetFlag: () => false,
}));
/**
 * Default return value for the mocked `useChatSession` hook: an idle,
 * non-paginated session "sess-1". Keys in `overrides` replace the defaults.
 */
function makeBaseChatSession(overrides: Record<string, unknown> = {}) {
  const defaults = {
    sessionId: "sess-1",
    setSessionId: vi.fn(),
    hydratedMessages: [],
    rawSessionMessages: [],
    historicalDurations: new Map(),
    hasActiveStream: false,
    hasMoreMessages: false,
    oldestSequence: null,
    newestSequence: null,
    forwardPaginated: false,
    isLoadingSession: false,
    isSessionError: false,
    createSession: vi.fn(),
    isCreatingSession: false,
    refetchSession: vi.fn(),
    sessionDryRun: false,
  };
  return { ...defaults, ...overrides };
}
/**
 * Default return value for the mocked `useCopilotStream` hook: no messages,
 * status "ready", no error. Keys in `overrides` replace the defaults.
 */
function makeBaseCopilotStream(overrides: Record<string, unknown> = {}) {
  const defaults = {
    messages: [],
    sendMessage: vi.fn(),
    stop: vi.fn(),
    status: "ready",
    error: undefined,
    isReconnecting: false,
    isSyncing: false,
    isUserStoppingRef: { current: false },
    rateLimitMessage: null,
    dismissRateLimit: vi.fn(),
  };
  return { ...defaults, ...overrides };
}
/**
 * Default return value for the mocked `useLoadMoreMessages` hook: nothing
 * paged, nothing more to load. Keys in `overrides` replace the defaults.
 */
function makeBaseLoadMore(overrides: Record<string, unknown> = {}) {
  const defaults = {
    pagedMessages: [],
    hasMore: false,
    isLoadingMore: false,
    loadMore: vi.fn(),
    resetPaged: vi.fn(),
  };
  return { ...defaults, ...overrides };
}
describe("useCopilotPage — forwardPaginated message ordering", () => {
  // Wires the three mocked hooks in one go; each argument is the override
  // bag handed to the matching makeBase* factory.
  function primeHooks(
    session: Record<string, unknown>,
    stream: Record<string, unknown> = {},
    loadMore: Record<string, unknown> = {},
  ) {
    mockUseChatSession.mockReturnValue(makeBaseChatSession(session));
    mockUseCopilotStream.mockReturnValue(makeBaseCopilotStream(stream));
    mockUseLoadMoreMessages.mockReturnValue(makeBaseLoadMore(loadMore));
  }

  beforeEach(() => {
    vi.clearAllMocks();
  });

  it("prepends pagedMessages before currentMessages when forwardPaginated=false", () => {
    const pagedMsg = { id: "paged", role: "user" };
    const currentMsg = { id: "current", role: "assistant" };
    primeHooks(
      { forwardPaginated: false },
      { messages: [currentMsg] },
      { pagedMessages: [pagedMsg] },
    );

    const { result } = renderHook(() => useCopilotPage());

    // Backward pagination: the (older) paged messages lead the list.
    const [first, second] = result.current.messages;
    expect(first).toEqual(pagedMsg);
    expect(second).toEqual(currentMsg);
  });

  it("appends pagedMessages after currentMessages when forwardPaginated=true", () => {
    const pagedMsg = { id: "paged", role: "assistant" };
    const currentMsg = { id: "current", role: "user" };
    primeHooks(
      { forwardPaginated: true },
      { messages: [currentMsg] },
      { pagedMessages: [pagedMsg] },
    );

    const { result } = renderHook(() => useCopilotPage());

    // Forward pagination: the current messages (start of session) lead.
    const [first, second] = result.current.messages;
    expect(first).toEqual(currentMsg);
    expect(second).toEqual(pagedMsg);
  });

  it("calls resetPaged when forwardPaginated transitions false→true with paged messages", async () => {
    const mockResetPaged = vi.fn();
    const pagedMsg = { id: "paged", role: "user" };
    primeHooks(
      { forwardPaginated: false },
      {},
      { pagedMessages: [pagedMsg], resetPaged: mockResetPaged },
    );

    const { rerender } = renderHook(() => useCopilotPage());

    // Simulate the session completing: forwardPaginated flips to true.
    mockUseChatSession.mockReturnValue(
      makeBaseChatSession({ forwardPaginated: true }),
    );
    act(() => {
      rerender();
    });

    await waitFor(() => {
      expect(mockResetPaged).toHaveBeenCalled();
    });
  });

  it("does not call resetPaged when forwardPaginated is already true on mount", () => {
    const mockResetPaged = vi.fn();
    primeHooks(
      { forwardPaginated: true },
      {},
      { pagedMessages: [], resetPaged: mockResetPaged },
    );

    renderHook(() => useCopilotPage());

    expect(mockResetPaged).not.toHaveBeenCalled();
  });
});

View File

@@ -1,568 +0,0 @@
import { act, renderHook, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useLoadMoreMessages } from "../useLoadMoreMessages";
const mockGetV2GetSession = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
getV2GetSession: (...args: unknown[]) => mockGetV2GetSession(...args),
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
convertChatSessionMessagesToUiMessages: vi.fn(() => ({ messages: [] })),
extractToolOutputsFromRaw: vi.fn(() => []),
}));
// Default arguments passed to useLoadMoreMessages in these tests: session
// "sess-1", forward-paginated, sequence cursors 0..49, with more pages
// available and no raw messages on the initial page. Individual tests
// spread this object and override single fields.
const BASE_ARGS = {
sessionId: "sess-1",
initialOldestSequence: 0,
initialNewestSequence: 49,
initialHasMore: true,
forwardPaginated: true,
initialPageRawMessages: [],
};
/**
 * Builds a status-200 payload for the mocked `getV2GetSession` call.
 * Omitted fields fall back to: no messages, `has_more_messages: false`,
 * `oldest_sequence: 0`, `newest_sequence: 49`.
 */
function makeSuccessResponse(overrides: {
  messages?: unknown[];
  has_more_messages?: boolean;
  oldest_sequence?: number;
  newest_sequence?: number;
}) {
  const { messages, has_more_messages, oldest_sequence, newest_sequence } =
    overrides;
  const data = {
    messages: messages ?? [],
    has_more_messages: has_more_messages ?? false,
    oldest_sequence: oldest_sequence ?? 0,
    newest_sequence: newest_sequence ?? 49,
  };
  return { status: 200, data };
}
describe("useLoadMoreMessages", () => {
beforeEach(() => {
  vi.clearAllMocks();
});

it("initialises with empty pagedMessages and correct cursors", () => {
  const hook = renderHook(() => useLoadMoreMessages(BASE_ARGS));

  const { pagedMessages, hasMore, isLoadingMore } = hook.result.current;
  expect(pagedMessages).toHaveLength(0);
  expect(hasMore).toBe(true);
  expect(isLoadingMore).toBe(false);
});

it("resetPaged clears paged state and sets hasMore=false during transition", () => {
  const hook = renderHook(() => useLoadMoreMessages(BASE_ARGS));

  act(() => {
    hook.result.current.resetPaged();
  });

  const { pagedMessages, hasMore, isLoadingMore } = hook.result.current;
  expect(pagedMessages).toHaveLength(0);
  // While transitioning, hasMore has to be false so a forward loadMore
  // cannot fire on the now-active session before forwardPaginated updates.
  expect(hasMore).toBe(false);
  expect(isLoadingMore).toBe(false);
});

it("resetPaged exposes a fresh loadMore via incremented epoch", () => {
  const hook = renderHook(() => useLoadMoreMessages(BASE_ARGS));

  // Only checks that resetPaged can be invoked without throwing.
  const callReset = () => {
    act(() => {
      hook.result.current.resetPaged();
    });
  };
  expect(callReset).not.toThrow();
});

it("resets all state on sessionId change", () => {
  const hook = renderHook((props) => useLoadMoreMessages(props), {
    initialProps: BASE_ARGS,
  });

  // Switch to a different session that has nothing more to load.
  const nextSessionProps = {
    ...BASE_ARGS,
    sessionId: "sess-2",
    initialOldestSequence: 10,
    initialNewestSequence: 59,
    initialHasMore: false,
  };
  hook.rerender(nextSessionProps);

  const { pagedMessages, hasMore, isLoadingMore } = hook.result.current;
  expect(pagedMessages).toHaveLength(0);
  expect(hasMore).toBe(false);
  expect(isLoadingMore).toBe(false);
});
describe("loadMore — forward pagination", () => {
  // Renders the hook in forward-paginated mode with optional extra props.
  function renderForward(extra: Record<string, unknown> = {}) {
    return renderHook(() =>
      useLoadMoreMessages({ ...BASE_ARGS, ...extra, forwardPaginated: true }),
    );
  }

  it("calls getV2GetSession with after_sequence and updates newestSequence", async () => {
    const rawMsg = { role: "user", content: "hi", sequence: 50 };
    const nextPage = makeSuccessResponse({
      messages: [rawMsg],
      has_more_messages: true,
      newest_sequence: 99,
    });
    mockGetV2GetSession.mockResolvedValueOnce(nextPage);
    const { result } = renderForward();

    await act(async () => {
      await result.current.loadMore();
    });

    // The request is keyed on the initial newest sequence (49).
    expect(mockGetV2GetSession).toHaveBeenCalledWith(
      "sess-1",
      expect.objectContaining({ after_sequence: 49 }),
    );
    expect(result.current.hasMore).toBe(true);
    expect(result.current.isLoadingMore).toBe(false);
  });

  it("sets hasMore=false when response has no more messages", async () => {
    const lastPage = makeSuccessResponse({ has_more_messages: false });
    mockGetV2GetSession.mockResolvedValueOnce(lastPage);
    const { result } = renderForward();

    await act(async () => {
      await result.current.loadMore();
    });

    expect(result.current.hasMore).toBe(false);
  });

  it("is a no-op when hasMore is false", async () => {
    const { result } = renderForward({ initialHasMore: false });

    await act(async () => {
      await result.current.loadMore();
    });

    expect(mockGetV2GetSession).not.toHaveBeenCalled();
  });
});
describe("loadMore — backward pagination", () => {
  it("calls getV2GetSession with before_sequence", async () => {
    const olderPage = makeSuccessResponse({
      messages: [{ role: "user", content: "old", sequence: 0 }],
      has_more_messages: false,
      oldest_sequence: 0,
    });
    mockGetV2GetSession.mockResolvedValueOnce(olderPage);

    const hook = renderHook(() =>
      useLoadMoreMessages({
        ...BASE_ARGS,
        forwardPaginated: false,
        initialOldestSequence: 50,
      }),
    );

    await act(async () => {
      await hook.result.current.loadMore();
    });

    // Backward mode pages from the oldest known sequence (50).
    expect(mockGetV2GetSession).toHaveBeenCalledWith(
      "sess-1",
      expect.objectContaining({ before_sequence: 50 }),
    );
    expect(hook.result.current.hasMore).toBe(false);
  });
});
// Error-path behaviour of loadMore: rejected requests and non-200 responses
// are tolerated until a consecutive-error threshold is reached (3 per the
// test titles — the constant itself lives in the hook, not in this file),
// and errors from requests made before resetPaged are discarded.
describe("loadMore — error handling", () => {
it("does not set hasMore=false on first error", async () => {
mockGetV2GetSession.mockRejectedValueOnce(new Error("network error"));
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
await act(async () => {
await result.current.loadMore();
});
// First error — hasMore still true
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) errors", async () => {
// Every call rejects (mockRejectedValue, not ...Once).
mockGetV2GetSession.mockRejectedValue(new Error("network error"));
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
for (let i = 0; i < 3; i++) {
await act(async () => {
await result.current.loadMore();
});
// Reset the in-flight guard between calls
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
}
expect(result.current.hasMore).toBe(false);
});
it("ignores non-200 response and increments error count", async () => {
// A resolved promise with a non-200 status is expected to count as an error.
mockGetV2GetSession.mockResolvedValueOnce({ status: 500, data: {} });
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
await act(async () => {
await result.current.loadMore();
});
// One error, not yet at threshold — hasMore still true
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) non-200 responses", async () => {
mockGetV2GetSession.mockResolvedValue({ status: 503, data: {} });
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
for (let i = 0; i < 3; i++) {
await act(async () => {
await result.current.loadMore();
});
// Wait for the previous call to settle before issuing the next one.
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
}
expect(result.current.hasMore).toBe(false);
});
it("discards in-flight error when epoch changes mid-flight (resetPaged called)", async () => {
// Keep a handle on the reject function so the request can be failed
// manually after resetPaged has run.
let rejectRequest!: (e: Error) => void;
mockGetV2GetSession.mockReturnValueOnce(
new Promise((_, rej) => {
rejectRequest = rej;
}),
);
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
// Start the request without awaiting it so it stays in flight.
act(() => {
result.current.loadMore();
});
// Reset epoch mid-flight
act(() => {
result.current.resetPaged();
});
// Reject the in-flight request — stale error should be discarded
await act(async () => {
rejectRequest(new Error("network error"));
});
// hasMore is already false because resetPaged set it; the stale rejection
// must not leave the hook stuck in a loading state.
expect(result.current.hasMore).toBe(false); // false from resetPaged
expect(result.current.isLoadingMore).toBe(false);
});
});
describe("loadMore — forward pagination cursor advancement", () => {
  // Final page used as the second response in both cases below.
  const makeFinalPage = () =>
    makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 });

  it("advances newestSequence after a successful forward load", async () => {
    const firstPage = makeSuccessResponse({
      messages: [{ role: "user", content: "hi", sequence: 50 }],
      has_more_messages: true,
      newest_sequence: 99,
    });
    mockGetV2GetSession.mockResolvedValueOnce(firstPage);
    const { result } = renderHook(() =>
      useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
    );

    await act(async () => {
      await result.current.loadMore();
    });

    // The follow-up request must start from the advanced cursor (99).
    mockGetV2GetSession.mockResolvedValueOnce(makeFinalPage());
    await act(async () => {
      await result.current.loadMore();
    });

    expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
      "sess-1",
      expect.objectContaining({ after_sequence: 99 }),
    );
  });

  it("does not regress newestSequence when parent refetches after pages loaded", async () => {
    const firstPage = makeSuccessResponse({
      messages: [{ role: "user", content: "msg", sequence: 50 }],
      has_more_messages: true,
      newest_sequence: 99,
    });
    mockGetV2GetSession.mockResolvedValueOnce(firstPage);
    const { result, rerender } = renderHook(
      (props) => useLoadMoreMessages(props),
      { initialProps: { ...BASE_ARGS, forwardPaginated: true } },
    );

    // After one page the cursor sits at 99.
    await act(async () => {
      await result.current.loadMore();
    });

    // The parent refetches and reports a lower newest_sequence (49); the
    // hook is expected to keep its own, further-advanced cursor.
    rerender({
      ...BASE_ARGS,
      forwardPaginated: true,
      initialNewestSequence: 49,
    });

    mockGetV2GetSession.mockResolvedValueOnce(makeFinalPage());
    await act(async () => {
      await result.current.loadMore();
    });

    expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
      "sess-1",
      expect.objectContaining({ after_sequence: 99 }),
    );
  });
});
describe("loadMore — MAX_OLDER_MESSAGES truncation", () => {
  // Produces `count` raw messages with consecutive sequences starting at
  // `firstSequence`; the content is `msg <sequence>`.
  function buildMessages(role: string, count: number, firstSequence: number) {
    return Array.from({ length: count }, (_, i) => {
      const sequence = i + firstSequence;
      return { role, content: `msg ${sequence}`, sequence };
    });
  }

  it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => {
    // One response of 2001 messages crosses the limit on its own, so the
    // test does not depend on state accumulated across renders.
    const args = { ...BASE_ARGS, forwardPaginated: false };
    mockGetV2GetSession.mockResolvedValueOnce(
      makeSuccessResponse({
        messages: buildMessages("user", 2001, 0),
        has_more_messages: true,
        oldest_sequence: 0,
      }),
    );
    const { result } = renderHook(() => useLoadMoreMessages(args));

    await act(async () => {
      await result.current.loadMore();
    });

    expect(result.current.hasMore).toBe(false);
  });

  it("forward truncation keeps first MAX_OLDER_MESSAGES items (not last)", async () => {
    // Scenario: 1990 messages are paged first, then 20 more arrive, for a
    // total of 2010 > 2000. In forward mode the head of the conversation is
    // expected to be kept (first 2000), not the tail.
    const forwardNearLimitArgs = {
      ...BASE_ARGS,
      forwardPaginated: true,
      initialNewestSequence: 49,
      initialOldestSequence: 0,
      initialHasMore: true,
    };
    const { result } = renderHook((props) => useLoadMoreMessages(props), {
      initialProps: forwardNearLimitArgs,
    });

    // Page 1: sequences 50..2039 (1990 messages).
    mockGetV2GetSession.mockResolvedValueOnce(
      makeSuccessResponse({
        messages: buildMessages("assistant", 1990, 50),
        has_more_messages: true,
        newest_sequence: 2039,
      }),
    );
    await act(async () => {
      await result.current.loadMore();
    });

    // Page 2: sequences 2040..2059 (20 messages) pushes the total to 2010.
    // The server reports no further messages, but since the overflow was
    // dropped by truncation the hook is expected to keep hasMore=true so the
    // discarded items can be fetched again later.
    mockGetV2GetSession.mockResolvedValueOnce(
      makeSuccessResponse({
        messages: buildMessages("assistant", 20, 2040),
        has_more_messages: false,
        newest_sequence: 2059,
      }),
    );
    await act(async () => {
      await result.current.loadMore();
    });

    await waitFor(() => expect(result.current.hasMore).toBe(true));
  });
});
describe("loadMore — null cursor guard", () => {
  // Renders the hook with the given overrides, runs loadMore once and
  // asserts that no request was issued.
  async function expectNoRequest(overrides: Record<string, unknown>) {
    const { result } = renderHook(() =>
      useLoadMoreMessages({ ...BASE_ARGS, ...overrides }),
    );

    await act(async () => {
      await result.current.loadMore();
    });

    expect(mockGetV2GetSession).not.toHaveBeenCalled();
  }

  it("is a no-op when newestSequence is null (forwardPaginated=true)", async () => {
    await expectNoRequest({
      forwardPaginated: true,
      initialNewestSequence: null,
    });
  });

  it("is a no-op when oldestSequence is null (forwardPaginated=false)", async () => {
    await expectNoRequest({
      forwardPaginated: false,
      initialOldestSequence: null,
    });
  });
});
describe("pagedMessages — initialPageRawMessages extraToolOutputs", () => {
  // Resolves the (mocked) helper module so its spy can be inspected.
  async function getExtractSpy() {
    const { extractToolOutputsFromRaw } = await import(
      "../helpers/convertChatSessionToUiMessages"
    );
    return extractToolOutputsFromRaw;
  }

  it("calls extractToolOutputsFromRaw for backward pagination with non-empty initialPageRawMessages", async () => {
    const extractToolOutputsFromRaw = await getExtractSpy();
    const rawMsg = { role: "user", content: "old", sequence: 0 };
    mockGetV2GetSession.mockResolvedValueOnce(
      makeSuccessResponse({
        messages: [rawMsg],
        has_more_messages: false,
        oldest_sequence: 0,
      }),
    );
    const { result } = renderHook(() =>
      useLoadMoreMessages({
        ...BASE_ARGS,
        forwardPaginated: false,
        initialOldestSequence: 50,
        initialPageRawMessages: [{ role: "assistant", content: "response" }],
      }),
    );

    await act(async () => {
      await result.current.loadMore();
    });

    expect(extractToolOutputsFromRaw).toHaveBeenCalled();
  });

  it("does NOT call extractToolOutputsFromRaw for forward pagination", async () => {
    const extractToolOutputsFromRaw = await getExtractSpy();
    const rawMsg = { role: "assistant", content: "hi", sequence: 50 };
    mockGetV2GetSession.mockResolvedValueOnce(
      makeSuccessResponse({
        messages: [rawMsg],
        has_more_messages: false,
        newest_sequence: 99,
      }),
    );
    const { result } = renderHook(() =>
      useLoadMoreMessages({
        ...BASE_ARGS,
        forwardPaginated: true,
        initialPageRawMessages: [{ role: "user", content: "hello" }],
      }),
    );

    await act(async () => {
      await result.current.loadMore();
    });

    expect(extractToolOutputsFromRaw).not.toHaveBeenCalled();
  });
});
// Verifies that a response arriving after resetPaged() was called is ignored.
// The test depends on the exact ordering: start request -> reset -> resolve.
describe("loadMore — epoch / stale-response guard", () => {
it("discards response when epoch changes during flight (resetPaged called)", async () => {
// Definite-assignment (`!`) is safe: the Promise executor below runs
// synchronously and assigns the resolver before it is used.
let resolveRequest!: (v: unknown) => void;
// Keep the request pending until the test resolves it by hand.
mockGetV2GetSession.mockReturnValueOnce(
new Promise((res) => {
resolveRequest = res;
}),
);
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
// Start the request without awaiting
act(() => {
result.current.loadMore();
});
// Reset epoch mid-flight
act(() => {
result.current.resetPaged();
});
// Now resolve the in-flight request
await act(async () => {
resolveRequest(
makeSuccessResponse({ messages: [{ role: "user", content: "hi" }] }),
);
});
// Response discarded — pagedMessages stays empty, isLoadingMore stays false
expect(result.current.pagedMessages).toHaveLength(0);
expect(result.current.isLoadingMore).toBe(false);
});
});
});

View File

@@ -30,7 +30,6 @@ export interface ChatContainerProps {
hasMoreMessages?: boolean;
isLoadingMore?: boolean;
onLoadMore?: () => void;
forwardPaginated?: boolean;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
@@ -55,7 +54,6 @@ export const ChatContainer = ({
hasMoreMessages,
isLoadingMore,
onLoadMore,
forwardPaginated,
droppedFiles,
onDroppedFilesConsumed,
historicalDurations,
@@ -110,7 +108,6 @@ export const ChatContainer = ({
hasMoreMessages={hasMoreMessages}
isLoadingMore={isLoadingMore}
onLoadMore={onLoadMore}
forwardPaginated={forwardPaginated}
onRetry={handleRetry}
historicalDurations={historicalDurations}
/>

View File

@@ -43,10 +43,6 @@ interface Props {
hasMoreMessages?: boolean;
isLoadingMore?: boolean;
onLoadMore?: () => void;
/** When true the load-more sentinel is placed at the bottom (forward
* pagination for completed sessions). When false it is at the top
* (backward pagination for active sessions). */
forwardPaginated?: boolean;
onRetry?: () => void;
historicalDurations?: Map<string, number>;
}
@@ -140,25 +136,11 @@ export function LoadMoreSentinel({
isLoading,
messageCount,
onLoadMore,
rootMargin = "200px 0px 0px 0px",
adjustScroll = true,
forwardPaginated = false,
}: {
hasMore: boolean;
isLoading: boolean;
messageCount: number;
onLoadMore: () => void;
/** IntersectionObserver rootMargin. Top sentinel uses "200px 0px 0px 0px"
* (pre-trigger when approaching from above); bottom sentinel should use
* "0px 0px 200px 0px" (pre-trigger when approaching from below). */
rootMargin?: string;
/** Whether to adjust scrollTop after load to preserve visual position.
* True for backward pagination (prepend above); false for forward
* pagination (append below) where no adjustment is needed. */
adjustScroll?: boolean;
/** When true the button reads "Load newer messages" (forward pagination).
* When false (default) it reads "Load older messages". */
forwardPaginated?: boolean;
}) {
const sentinelRef = useRef<HTMLDivElement>(null);
const onLoadMoreRef = useRef(onLoadMore);
@@ -203,11 +185,11 @@ export function LoadMoreSentinel({
if (autoFillRoundsRef.current >= MAX_AUTO_FILL_ROUNDS) return;
captureAndLoad(true);
},
{ rootMargin },
{ rootMargin: "200px 0px 0px 0px" },
);
observer.observe(sentinelRef.current);
return () => observer.disconnect();
}, [hasMore, isLoading, rootMargin, scrollRef]);
}, [hasMore, isLoading, scrollRef]);
// After React commits new DOM nodes (prepended messages), adjust
// scrollTop so the user stays at the same visual position.
@@ -220,9 +202,7 @@ export function LoadMoreSentinel({
scrollSnapshotRef.current;
if (!el || prevHeight === 0) return;
const delta = el.scrollHeight - prevHeight;
// Only restore scroll position for backward pagination (content prepended
// above). Forward pagination appends below — no adjustment needed.
if (adjustScroll && delta > 0) {
if (delta > 0) {
el.scrollTop = prevTop + delta;
}
// Reset the auto-fill backoff whenever the container becomes
@@ -236,7 +216,7 @@ export function LoadMoreSentinel({
}
scrollSnapshotRef.current = { scrollHeight: 0, scrollTop: 0 };
autoTriggeredRef.current = false;
}, [adjustScroll, messageCount, scrollRef]);
}, [messageCount, scrollRef]);
return (
<div
@@ -255,7 +235,7 @@ export function LoadMoreSentinel({
size="small"
onClick={() => captureAndLoad(false)}
>
{forwardPaginated ? "Load newer messages" : "Load older messages"}
Load older messages
</Button>
)
)}
@@ -272,7 +252,6 @@ export function ChatMessagesContainer({
hasMoreMessages,
isLoadingMore,
onLoadMore,
forwardPaginated,
onRetry,
historicalDurations,
}: Props) {
@@ -351,7 +330,7 @@ export function ChatMessagesContainer({
}
>
<ConversationContent className="flex min-h-full flex-1 flex-col gap-6 px-3 py-6">
{hasMoreMessages && onLoadMore && !forwardPaginated && (
{hasMoreMessages && onLoadMore && (
<LoadMoreSentinel
hasMore={hasMoreMessages}
isLoading={!!isLoadingMore}
@@ -510,17 +489,6 @@ export function ChatMessagesContainer({
</pre>
</details>
)}
{hasMoreMessages && onLoadMore && forwardPaginated && (
<LoadMoreSentinel
hasMore={hasMoreMessages}
isLoading={!!isLoadingMore}
messageCount={messages.length}
onLoadMore={onLoadMore}
rootMargin="0px 0px 200px 0px"
adjustScroll={false}
forwardPaginated
/>
)}
</ConversationContent>
<ConversationScrollButton />
</Conversation>

View File

@@ -1,173 +0,0 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { ChatMessagesContainer } from "../ChatMessagesContainer";
// Mutable stand-in for the scroll container element. It is exposed to the
// component through the mocked useStickToBottomContext below, and its fields
// are reset in the suite's beforeEach.
const mockScrollEl = {
scrollHeight: 100,
scrollTop: 0,
clientHeight: 500,
};
// Replace the sticky-scroll library: the context hook returns a scrollRef
// pointing at mockScrollEl, and the layout components become pass-through
// wrappers that only render their children.
vi.mock("use-stick-to-bottom", () => ({
useStickToBottomContext: () => ({ scrollRef: { current: mockScrollEl } }),
Conversation: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationScrollButton: () => null,
}));
// Same pass-through treatment for the app's own conversation wrappers.
vi.mock("@/components/ai-elements/conversation", () => ({
Conversation: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationScrollButton: () => null,
}));
// Message primitives render only their children.
vi.mock("@/components/ai-elements/message", () => ({
Message: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
MessageContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
MessageActions: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
// Child components are stubbed to render nothing — presumably so the tests
// exercise only the load-more sentinel's placement and label.
vi.mock("../components/AssistantMessageActions", () => ({
AssistantMessageActions: () => null,
}));
vi.mock("../components/CopyButton", () => ({ CopyButton: () => null }));
vi.mock("../components/CollapsedToolGroup", () => ({
CollapsedToolGroup: () => null,
}));
vi.mock("../components/MessageAttachments", () => ({
MessageAttachments: () => null,
}));
vi.mock("../components/MessagePartRenderer", () => ({
MessagePartRenderer: () => null,
}));
vi.mock("../components/ReasoningCollapse", () => ({
ReasoningCollapse: () => null,
}));
vi.mock("../components/ThinkingIndicator", () => ({
ThinkingIndicator: () => null,
}));
vi.mock("../../JobStatsBar/TurnStatsBar", () => ({
TurnStatsBar: () => null,
}));
// The elapsed-time hook always reports zero seconds.
vi.mock("../../JobStatsBar/useElapsedTimer", () => ({
useElapsedTimer: () => ({ elapsedSeconds: 0 }),
}));
vi.mock("../../CopilotPendingReviews/CopilotPendingReviews", () => ({
CopilotPendingReviews: () => null,
}));
// Message helpers return empty/neutral results; splitReasoningAndResponse
// treats every part as a response part.
vi.mock("../helpers", () => ({
buildRenderSegments: () => [],
getTurnMessages: () => [],
parseSpecialMarkers: () => ({ markerType: null }),
splitReasoningAndResponse: (parts: unknown[]) => ({
reasoningParts: [],
responseParts: parts,
}),
}));
type ObserverCallback = (entries: { isIntersecting: boolean }[]) => void;

/**
 * Minimal IntersectionObserver replacement for tests.
 *
 * Observation methods do nothing. The callback passed to the most recently
 * constructed instance is published on the static `lastCallback` field so a
 * test can invoke it by hand to simulate an intersection.
 */
class MockIntersectionObserver {
  static lastCallback: ObserverCallback | null = null;

  // Inert values for the read-only properties of the real interface.
  root = null;
  rootMargin = "";
  thresholds = [];

  private callback: ObserverCallback;

  constructor(cb: ObserverCallback) {
    MockIntersectionObserver.lastCallback = cb;
    this.callback = cb;
  }

  observe() {}

  unobserve() {}

  disconnect() {}

  takeRecords() {
    return [];
  }
}
// Baseline props shared by every test in this file: an empty, idle
// conversation that reports more history is available and supplies an
// onLoadMore handler. Individual tests override hasMoreMessages, onLoadMore
// or forwardPaginated as needed.
// The vi.fn() callbacks are created once at module load and are shared by all
// tests.
const BASE_PROPS = {
messages: [],
status: "ready" as const,
error: undefined,
isLoading: false,
sessionID: "sess-1",
hasMoreMessages: true,
isLoadingMore: false,
onLoadMore: vi.fn(),
onRetry: vi.fn(),
};
// Covers where the load-more sentinel is rendered (top for backward
// pagination, bottom for forward pagination) and when it is hidden.
describe("ChatMessagesContainer", () => {
  // Matches the load-more button in either pagination direction. The
  // "hidden" tests render with forwardPaginated={true}, where the button reads
  // "Load newer messages"; asserting only on "load older messages" would pass
  // even if the sentinel were rendered, so they must match both labels.
  const ANY_LOAD_MORE_BUTTON = /load (older|newer) messages/i;

  beforeEach(() => {
    // Restore the shared scroll-element stand-in and install the observer mock.
    mockScrollEl.scrollHeight = 100;
    mockScrollEl.scrollTop = 0;
    mockScrollEl.clientHeight = 500;
    MockIntersectionObserver.lastCallback = null;
    vi.stubGlobal("IntersectionObserver", MockIntersectionObserver);
  });

  afterEach(() => {
    cleanup();
    vi.unstubAllGlobals();
  });

  it("renders top sentinel when forwardPaginated is false (backward pagination)", () => {
    render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={false} />);
    expect(
      screen.getByRole("button", { name: /load older messages/i }),
    ).toBeDefined();
  });

  it("renders top sentinel when forwardPaginated is undefined (default, backward)", () => {
    render(<ChatMessagesContainer {...BASE_PROPS} />);
    expect(
      screen.getByRole("button", { name: /load older messages/i }),
    ).toBeDefined();
  });

  it("renders bottom sentinel when forwardPaginated is true (forward pagination)", () => {
    render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={true} />);
    expect(
      screen.getByRole("button", { name: /load newer messages/i }),
    ).toBeDefined();
  });

  it("hides sentinel when hasMoreMessages is false", () => {
    render(
      <ChatMessagesContainer
        {...BASE_PROPS}
        hasMoreMessages={false}
        forwardPaginated={true}
      />,
    );
    // No load-more button of either direction may be present.
    expect(
      screen.queryByRole("button", { name: ANY_LOAD_MORE_BUTTON }),
    ).toBeNull();
  });

  it("hides sentinel when onLoadMore is not provided", () => {
    render(
      <ChatMessagesContainer
        {...BASE_PROPS}
        onLoadMore={undefined}
        forwardPaginated={true}
      />,
    );
    // No load-more button of either direction may be present.
    expect(
      screen.queryByRole("button", { name: ANY_LOAD_MORE_BUTTON }),
    ).toBeNull();
  });
});

View File

@@ -172,36 +172,6 @@ describe("LoadMoreSentinel", () => {
expect(mockScrollEl.scrollTop).toBe(200);
});
it("does NOT adjust scroll when adjustScroll=false (forward pagination)", () => {
  mockScrollEl.scrollHeight = 100;
  mockScrollEl.scrollTop = 50;

  // Props that stay the same across both renders; only messageCount changes.
  const sentinelProps = {
    hasMore: true,
    isLoading: false,
    onLoadMore: vi.fn(),
    adjustScroll: false,
  };
  const view = render(<LoadMoreSentinel {...sentinelProps} messageCount={5} />);

  // Trigger the observer so the component captures its scroll snapshot.
  MockIntersectionObserver.lastCallback?.([{ isIntersecting: true }]);

  // Newer messages were appended below: the container grows, then the
  // component re-renders with the larger message count.
  mockScrollEl.scrollHeight = 300;
  view.rerender(<LoadMoreSentinel {...sentinelProps} messageCount={10} />);

  // With adjustScroll disabled the scroll offset must be left untouched.
  expect(mockScrollEl.scrollTop).toBe(50);
});
it("ignores same-frame duplicate triggers until isLoading transitions", () => {
const onLoadMore = vi.fn();
const { rerender } = render(

View File

@@ -1,59 +0,0 @@
import { describe, expect, it } from "vitest";
import { convertChatSessionMessagesToUiMessages } from "../convertChatSessionToUiMessages";
const SESSION_ID = "sess-test";
describe("convertChatSessionMessagesToUiMessages", () => {
  // Converts a session containing exactly one raw message (sequence 0, session
  // marked complete) and returns the resulting UI messages.
  function convertSingle(role: string, content: string | null) {
    const converted = convertChatSessionMessagesToUiMessages(
      SESSION_ID,
      [{ role, content, sequence: 0 }],
      { isComplete: true },
    );
    return converted.messages;
  }

  it("does not drop user messages with null content", () => {
    const uiMessages = convertSingle("user", null);
    expect(uiMessages).toHaveLength(1);
    expect(uiMessages[0].role).toBe("user");
  });

  it("does not drop user messages with empty string content", () => {
    const uiMessages = convertSingle("user", "");
    expect(uiMessages).toHaveLength(1);
    expect(uiMessages[0].role).toBe("user");
  });

  it("still drops non-user messages with null content", () => {
    expect(convertSingle("assistant", null)).toHaveLength(0);
  });

  it("still drops non-user messages with empty string content", () => {
    expect(convertSingle("assistant", "")).toHaveLength(0);
  });

  it("includes user message with normal content", () => {
    const uiMessages = convertSingle("user", "hello");
    expect(uiMessages).toHaveLength(1);
    expect(uiMessages[0].role).toBe("user");
  });
});

View File

@@ -253,11 +253,6 @@ export function convertChatSessionMessagesToUiMessages(
}
}
// User messages must always be rendered, even with empty content, so the
// initial prompt is visible when reloading a session.
if (parts.length === 0 && msg.role === "user") {
parts.push({ type: "text", text: "", state: "done" });
}
if (parts.length === 0) return;
// Merge consecutive assistant messages into a single UIMessage

View File

@@ -86,16 +86,6 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
return sessionQuery.data.data.oldest_sequence ?? null;
}, [sessionQuery.data]);
const newestSequence = useMemo(() => {
if (sessionQuery.data?.status !== 200) return null;
return sessionQuery.data.data.newest_sequence ?? null;
}, [sessionQuery.data]);
const forwardPaginated = useMemo(() => {
if (sessionQuery.data?.status !== 200) return false;
return !!sessionQuery.data.data.forward_paginated;
}, [sessionQuery.data]);
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
// array reference every render. Re-derives only when query data changes.
// When the session is complete (no active stream), mark dangling tool
@@ -195,8 +185,6 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
hasActiveStream,
hasMoreMessages,
oldestSequence,
newestSequence,
forwardPaginated,
isLoadingSession: sessionQuery.isLoading,
isSessionError: sessionQuery.isError,
createSession,

View File

@@ -56,8 +56,6 @@ export function useCopilotPage() {
hasActiveStream,
hasMoreMessages,
oldestSequence,
newestSequence,
forwardPaginated,
isLoadingSession,
isSessionError,
createSession,
@@ -86,26 +84,18 @@ export function useCopilotPage() {
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
});
const { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged } =
const { olderMessages, hasMore, isLoadingMore, loadMore } =
useLoadMoreMessages({
sessionId,
initialOldestSequence: oldestSequence,
initialNewestSequence: newestSequence,
initialHasMore: hasMoreMessages,
forwardPaginated,
initialPageRawMessages: rawSessionMessages,
});
// Combine paginated messages with current page messages, merging consecutive
// assistant UIMessages at the page boundary so reasoning + response parts
// stay in a single bubble.
// Forward pagination (completed sessions): current page is the beginning,
// paged messages are newer pages appended after.
// Backward pagination (active sessions): paged messages are older history
// prepended before the current page.
const messages = forwardPaginated
? concatWithAssistantMerge(currentMessages, pagedMessages)
: concatWithAssistantMerge(pagedMessages, currentMessages);
// Combine older (paginated) messages with current page messages,
// merging consecutive assistant UIMessages at the page boundary so
// reasoning + response parts stay in a single bubble.
const messages = concatWithAssistantMerge(olderMessages, currentMessages);
useCopilotNotifications(sessionId);
@@ -180,23 +170,6 @@ export function useCopilotPage() {
}
}, [sessionId, pendingMessage, sendMessage]);
// --- Clear backward-paginated messages when session completes ---
// When a session transitions from active (forwardPaginated=false) to complete
// (forwardPaginated=true), any backward-paginated older messages would be
// appended after currentMessages instead of before, causing chronological
// disorder. Reset paged state so the completed session renders cleanly.
const prevForwardPaginatedRef = useRef(forwardPaginated);
useEffect(() => {
if (
!prevForwardPaginatedRef.current &&
forwardPaginated &&
pagedMessages.length > 0
) {
resetPaged();
}
prevForwardPaginatedRef.current = forwardPaginated;
}, [forwardPaginated, pagedMessages.length, resetPaged]);
// --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
useWorkflowImportAutoSubmit({
createSession,
@@ -278,15 +251,6 @@ export function useCopilotPage() {
isUserStoppingRef.current = false;
if (sessionId) {
// When continuing a completed session that had forward-paginated history
// loaded, the paged messages would appear in wrong position relative to
// the new streaming turn (pagedMessages are newer pages, so they'd end
// up after the streaming turn). Reset paged state so ordering is correct
// during streaming; the user can reload history afterward if needed.
if (forwardPaginated && pagedMessages.length > 0) {
resetPaged();
}
if (files && files.length > 0) {
setIsUploadingFiles(true);
try {
@@ -433,7 +397,6 @@ export function useCopilotPage() {
hasMoreMessages: hasMore,
isLoadingMore,
loadMore,
forwardPaginated,
// Mobile drawer
isMobile,
isDrawerOpen,

View File

@@ -9,11 +9,7 @@ import {
interface UseLoadMoreMessagesArgs {
sessionId: string | null;
initialOldestSequence: number | null;
initialNewestSequence: number | null;
initialHasMore: boolean;
/** True when the initial page was loaded from sequence 0 forward (completed
* sessions). False when loaded newest-first (active sessions). */
forwardPaginated: boolean;
/** Raw messages from the initial page, used for cross-page tool output matching. */
initialPageRawMessages: unknown[];
}
@@ -24,21 +20,16 @@ const MAX_OLDER_MESSAGES = 2000;
export function useLoadMoreMessages({
sessionId,
initialOldestSequence,
initialNewestSequence,
initialHasMore,
forwardPaginated,
initialPageRawMessages,
}: UseLoadMoreMessagesArgs) {
// Accumulated raw messages from all extra pages (ascending order).
// Store accumulated raw messages from all older pages (in ascending order).
// Re-converting them all together ensures tool outputs are matched across
// inter-page boundaries.
const [pagedRawMessages, setPagedRawMessages] = useState<unknown[]>([]);
const [olderRawMessages, setOlderRawMessages] = useState<unknown[]>([]);
const [oldestSequence, setOldestSequence] = useState<number | null>(
initialOldestSequence,
);
const [newestSequence, setNewestSequence] = useState<number | null>(
initialNewestSequence,
);
const [hasMore, setHasMore] = useState(initialHasMore);
const [isLoadingMore, setIsLoadingMore] = useState(false);
const isLoadingMoreRef = useRef(false);
@@ -55,7 +46,7 @@ export function useLoadMoreMessages({
// The parent's `initialOldestSequence` drifts forward every time the
// session query refetches (e.g. after a stream completes — see
// `useCopilotStream` invalidation on `streaming → ready`). If we
// wiped `pagedRawMessages` every time that happened, users who had
// wiped `olderRawMessages` every time that happened, users who had
// scrolled back would lose their loaded history on each new turn and
// subsequent `loadMore` calls would fetch messages that overlap with
// the AI SDK's retained state in `currentMessages`, producing visible
@@ -72,9 +63,8 @@ export function useLoadMoreMessages({
// Session changed — full reset
prevSessionIdRef.current = sessionId;
prevInitialOldestRef.current = initialOldestSequence;
setPagedRawMessages([]);
setOlderRawMessages([]);
setOldestSequence(initialOldestSequence);
setNewestSequence(initialNewestSequence);
setHasMore(initialHasMore);
setIsLoadingMore(false);
isLoadingMoreRef.current = false;
@@ -85,64 +75,49 @@ export function useLoadMoreMessages({
prevInitialOldestRef.current = initialOldestSequence;
// If we haven't paged yet, mirror the parent so the first
// If we haven't paged back yet, mirror the parent so the first
// `loadMore` starts from the correct cursor.
//
// When paged messages exist (pagedRawMessages.length > 0) we intentionally
// do NOT update `hasMore` or `newestSequence` from the parent. A parent
// refetch (e.g. after a new turn completes) may carry a fresh
// `initialHasMore=true` or a larger `initialNewestSequence`, but those
// reflect the *initial* page window, not the forward-paged window we have
// already advanced into. Overwriting the local cursor here would cause the
// next `loadMore` to re-fetch pages we already have. The local cursor is
// advanced correctly inside `loadMore` itself via `setNewestSequence`.
if (pagedRawMessages.length === 0) {
if (olderRawMessages.length === 0) {
setOldestSequence(initialOldestSequence);
// Only regress the forward cursor if we haven't paged ahead yet —
// otherwise a parent refetch would reset a cursor we already advanced.
setNewestSequence((prev) =>
prev !== null && prev > (initialNewestSequence ?? -1)
? prev
: initialNewestSequence,
);
setHasMore(initialHasMore);
}
}, [sessionId, initialOldestSequence, initialNewestSequence, initialHasMore]);
}, [sessionId, initialOldestSequence, initialHasMore]);
// Convert all accumulated raw messages in one pass so tool outputs
// are matched across inter-page boundaries.
// For backward pagination only: include initial page tool outputs so older
// paged pages can match tool calls whose outputs landed in the initial page.
// For forward pagination this is unnecessary — tool calls in newer paged
// pages cannot have their outputs in the older initial page.
const pagedMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
// are matched across inter-page boundaries. Initial page tool outputs
// are included via extraToolOutputs to handle the boundary between
// the last older page and the initial/streaming page.
const olderMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
useMemo(() => {
if (!sessionId || pagedRawMessages.length === 0) return [];
if (!sessionId || olderRawMessages.length === 0) return [];
const extraToolOutputs =
!forwardPaginated && initialPageRawMessages.length > 0
initialPageRawMessages.length > 0
? extractToolOutputsFromRaw(initialPageRawMessages)
: undefined;
return convertChatSessionMessagesToUiMessages(
sessionId,
pagedRawMessages,
olderRawMessages,
{ isComplete: true, extraToolOutputs },
).messages;
}, [sessionId, pagedRawMessages, initialPageRawMessages, forwardPaginated]);
}, [sessionId, olderRawMessages, initialPageRawMessages]);
async function loadMore() {
if (!sessionId || !hasMore || isLoadingMoreRef.current) return;
const cursor = forwardPaginated ? newestSequence : oldestSequence;
if (cursor === null) return;
if (
!sessionId ||
!hasMore ||
isLoadingMoreRef.current ||
oldestSequence === null
)
return;
const requestEpoch = epochRef.current;
isLoadingMoreRef.current = true;
setIsLoadingMore(true);
try {
const params = forwardPaginated
? { limit: 50, after_sequence: cursor }
: { limit: 50, before_sequence: cursor };
const response = await getV2GetSession(sessionId, params);
const response = await getV2GetSession(sessionId, {
limit: 50,
before_sequence: oldestSequence,
});
// Discard response if session/pagination was reset while awaiting
if (epochRef.current !== requestEpoch) return;
@@ -161,66 +136,18 @@ export function useLoadMoreMessages({
consecutiveErrorsRef.current = 0;
const newRaw = (response.data.messages ?? []) as unknown[];
// Estimate total after merge using the closure-captured pagedRawMessages.length.
// This is a safe approximation: worst case it's one page stale (one extra load
// allowed), but it avoids the React-18-batching pitfall where a functional
// updater's mutations are not visible until the next render.
const estimatedTotal = pagedRawMessages.length + newRaw.length;
setPagedRawMessages((prev) => {
// Forward: append to end. Backward: prepend to start.
const merged = forwardPaginated
? [...prev, ...newRaw]
: [...newRaw, ...prev];
setOlderRawMessages((prev) => {
const merged = [...newRaw, ...prev];
if (merged.length > MAX_OLDER_MESSAGES) {
// Backward: discard the oldest (front) items — user has scrolled far
// back and we shed the furthest history.
// Forward: discard the newest (tail) items — we only ever fetch
// forward, so the tail is the most recently appended page; shedding
// it means the sentinel stalls, which is safer than discarding the
// beginning of the conversation the user is here to read.
return forwardPaginated
? merged.slice(0, MAX_OLDER_MESSAGES)
: merged.slice(merged.length - MAX_OLDER_MESSAGES);
return merged.slice(merged.length - MAX_OLDER_MESSAGES);
}
return merged;
});
if (forwardPaginated) {
const willTruncateForward = estimatedTotal > MAX_OLDER_MESSAGES;
if (willTruncateForward) {
// Truncation shed the newest tail. Advance the cursor to the last KEPT
// item's sequence so the sentinel re-fetches the discarded items next
// time rather than jumping past them.
// lastKeptIdx: index within newRaw of the last item that survives.
// prev contributes pagedRawMessages.length items; total kept = MAX.
const lastKeptIdx = MAX_OLDER_MESSAGES - 1 - pagedRawMessages.length;
if (lastKeptIdx >= 0 && lastKeptIdx < newRaw.length) {
const lastKeptMsg = newRaw[lastKeptIdx] as { sequence?: number };
if (typeof lastKeptMsg?.sequence === "number") {
setNewestSequence(lastKeptMsg.sequence);
setHasMore(true); // Discarded items still exist — keep sentinel active
} else {
// Sequence unavailable — fall back; truncated items will be lost
setNewestSequence(response.data.newest_sequence ?? null);
setHasMore(!!response.data.has_more_messages);
}
} else {
// All of newRaw was dropped (already at MAX_OLDER_MESSAGES cap).
// Stop to avoid an infinite re-fetch loop at the display cap.
setHasMore(false);
}
} else {
setNewestSequence(response.data.newest_sequence ?? null);
setHasMore(!!response.data.has_more_messages);
}
setOldestSequence(response.data.oldest_sequence ?? null);
if (newRaw.length + olderRawMessages.length >= MAX_OLDER_MESSAGES) {
setHasMore(false);
} else {
setOldestSequence(response.data.oldest_sequence ?? null);
if (estimatedTotal >= MAX_OLDER_MESSAGES) {
// Backward: accumulated MAX_OLDER_MESSAGES — stop to avoid unbounded memory.
setHasMore(false);
} else {
setHasMore(!!response.data.has_more_messages);
}
setHasMore(!!response.data.has_more_messages);
}
} catch (error) {
if (epochRef.current !== requestEpoch) return;
@@ -237,22 +164,5 @@ export function useLoadMoreMessages({
}
}
function resetPaged() {
setPagedRawMessages([]);
setOldestSequence(initialOldestSequence);
setNewestSequence(initialNewestSequence);
// Set hasMore=false during the session-transition window so no loadMore
// fires with forward pagination (after_sequence) on the now-active session.
// The useEffect will restore hasMore from the parent after the refetch
// completes and forwardPaginated switches to false.
setHasMore(false);
// Clear the loading state so the spinner doesn't stay stuck if a loadMore
// was in flight when resetPaged was called.
setIsLoadingMore(false);
isLoadingMoreRef.current = false;
consecutiveErrorsRef.current = 0;
epochRef.current += 1;
}
return { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged };
return { olderMessages, hasMore, isLoadingMore, loadMore };
}

View File

@@ -1498,7 +1498,7 @@
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Session",
"description": "Retrieve the details of a specific chat session.\n\nSupports cursor-based pagination via ``limit``, ``before_sequence``, and\n``after_sequence``. The two cursor parameters are mutually exclusive.\n\nOn the initial load (no cursor provided) of a completed session, messages\nare returned in forward order starting from sequence 0 so the user always\nsees their initial prompt. Active sessions use the legacy newest-first\norder so streaming context is preserved.",
"description": "Retrieve the details of a specific chat session.\n\nSupports cursor-based pagination via ``limit`` and ``before_sequence``.\nWhen no pagination params are provided, returns the most recent messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The authenticated user's ID.\n limit: Maximum number of messages to return (1-200, default 50).\n before_sequence: Return messages with sequence < this value (cursor).\n\nReturns:\n SessionDetailResponse: Details for the requested session, including\n active_stream info and pagination metadata.",
"operationId": "getV2GetSession",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1516,11 +1516,9 @@
"type": "integer",
"maximum": 200,
"minimum": 1,
"description": "Maximum number of messages to return.",
"default": 50,
"title": "Limit"
},
"description": "Maximum number of messages to return."
}
},
{
"name": "before_sequence",
@@ -1531,24 +1529,8 @@
{ "type": "integer", "minimum": 0 },
{ "type": "null" }
],
"description": "Backward pagination cursor. Return messages with sequence number strictly less than this value. Used by active-session load-more. Mutually exclusive with after_sequence.",
"title": "Before Sequence"
},
"description": "Backward pagination cursor. Return messages with sequence number strictly less than this value. Used by active-session load-more. Mutually exclusive with after_sequence."
},
{
"name": "after_sequence",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{ "type": "integer", "minimum": 0 },
{ "type": "null" }
],
"description": "Forward pagination cursor. Return messages with sequence number strictly greater than this value. Used by completed-session load-more. Mutually exclusive with before_sequence.",
"title": "After Sequence"
},
"description": "Forward pagination cursor. Return messages with sequence number strictly greater than this value. Used by completed-session load-more. Mutually exclusive with before_sequence."
}
}
],
"responses": {
@@ -13291,15 +13273,6 @@
"anyOf": [{ "type": "integer" }, { "type": "null" }],
"title": "Oldest Sequence"
},
"newest_sequence": {
"anyOf": [{ "type": "integer" }, { "type": "null" }],
"title": "Newest Sequence"
},
"forward_paginated": {
"type": "boolean",
"title": "Forward Paginated",
"default": false
},
"total_prompt_tokens": {
"type": "integer",
"title": "Total Prompt Tokens",