Merge remote-tracking branch 'origin/dev' into feat/task-decomposition-copilot

# Conflicts:
#	autogpt_platform/backend/backend/copilot/model.py
This commit is contained in:
anvyle
2026-04-16 19:05:24 +02:00
115 changed files with 11115 additions and 2612 deletions

View File

@@ -18,7 +18,6 @@ from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.message_dedup import acquire_dedup_lock
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -192,6 +191,8 @@ class SessionDetailResponse(BaseModel):
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
has_more_messages: bool = False
oldest_sequence: int | None = None
newest_sequence: int | None = None
forward_paginated: bool = False
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
@@ -456,52 +457,113 @@ async def update_session_title_route(
async def get_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=200),
before_sequence: int | None = Query(default=None, ge=0),
limit: int = Query(
default=50,
ge=1,
le=200,
description="Maximum number of messages to return.",
),
before_sequence: int | None = Query(
default=None,
ge=0,
description=(
"Backward pagination cursor. Return messages with sequence number "
"strictly less than this value. Used by active-session load-more. "
"Mutually exclusive with after_sequence."
),
),
after_sequence: int | None = Query(
default=None,
ge=0,
description=(
"Forward pagination cursor. Return messages with sequence number "
"strictly greater than this value. Used by completed-session load-more. "
"Mutually exclusive with before_sequence."
),
),
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
When no pagination params are provided, returns the most recent messages.
Supports cursor-based pagination via ``limit``, ``before_sequence``, and
``after_sequence``. The two cursor parameters are mutually exclusive.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The authenticated user's ID.
limit: Maximum number of messages to return (1-200, default 50).
before_sequence: Return messages with sequence < this value (cursor).
Returns:
SessionDetailResponse: Details for the requested session, including
active_stream info and pagination metadata.
On the initial load (no cursor provided) of a completed session, messages
are returned in forward order starting from sequence 0 so the user always
sees their initial prompt. Active sessions use the legacy newest-first
order so streaming context is preserved.
"""
if before_sequence is not None and after_sequence is not None:
raise HTTPException(
status_code=400,
detail="before_sequence and after_sequence are mutually exclusive",
)
is_initial_load = before_sequence is None and after_sequence is None
# Check active stream before the DB query on initial loads so we can
# choose the correct pagination direction (forward for completed sessions,
# newest-first for active ones).
active_session = None
last_message_id = None
if is_initial_load:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
# Completed sessions on initial load start from sequence 0 so the user's
# initial prompt is always visible. Active sessions keep the legacy
# newest-first behavior to preserve streaming context.
from_start = is_initial_load and active_session is None
forward_paginated = from_start or after_sequence is not None
page = await get_chat_messages_paginated(
session_id, limit, before_sequence, user_id=user_id
session_id,
limit,
before_sequence=before_sequence,
after_sequence=after_sequence,
from_start=from_start,
user_id=user_id,
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
# Close the TOCTOU window: if the session was active at pre-check, re-verify
# after the DB fetch. The session may have completed between the two awaits,
# which would have caused messages to be fetched newest-first even though the
# session is now complete. Re-fetch from seq 0 so the initial prompt is
# always visible.
if is_initial_load and active_session is not None:
post_active, _ = await stream_registry.get_active_session(session_id, user_id)
if post_active is None:
active_session = None
last_message_id = None
from_start = True
forward_paginated = True
page = await get_chat_messages_paginated(
session_id,
limit,
before_sequence=None,
after_sequence=None,
from_start=True,
user_id=user_id,
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
messages = [
_strip_injected_context(message.model_dump()) for message in page.messages
]
# Only check active stream on initial load (not on "load more" requests)
active_stream_info = None
if before_sequence is None:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
if active_session and last_message_id is not None:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
# Skip session metadata on "load more" — frontend only needs messages
if before_sequence is not None:
if not is_initial_load:
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
@@ -511,6 +573,8 @@ async def get_session(
active_stream=None,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
newest_sequence=page.newest_sequence,
forward_paginated=forward_paginated,
total_prompt_tokens=0,
total_completion_tokens=0,
)
@@ -527,6 +591,8 @@ async def get_session(
active_stream=active_stream_info,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
newest_sequence=page.newest_sequence,
forward_paginated=forward_paginated,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=page.session.metadata,
@@ -872,9 +938,6 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
# Capture the original message text BEFORE any mutation (attachment enrichment)
# so the idempotency hash is stable across retries.
original_message = request.message
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
@@ -903,58 +966,36 @@ async def stream_chat_post(
)
request.message += files_block
# ── Idempotency guard ────────────────────────────────────────────────────
# Blocks duplicate executor tasks from concurrent/retried POSTs.
# See backend/copilot/message_dedup.py for the full lifecycle description.
dedup_lock = None
if request.is_user_message:
dedup_lock = await acquire_dedup_lock(
session_id, original_message, sanitized_file_ids
)
if dedup_lock is None and (original_message or sanitized_file_ids):
async def _empty_sse() -> AsyncGenerator[str, None]:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return StreamingResponse(
_empty_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
"x-vercel-ai-ui-message-stream": "v1",
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
#
# If any of these operations raises, release the dedup lock before propagating
# so subsequent retries are not blocked for 30 s.
try:
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
# saved yet. append_and_save_message returns None when a duplicate is
# detected — in that case skip enqueue to avoid processing the message twice.
is_duplicate_message = False
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
is_duplicate_message = (
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
) is None
logger.info(f"[STREAM] User message saved for session {session_id}")
if not is_duplicate_message and request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
# Create a task in the stream registry for reconnection support
# Create a task in the stream registry for reconnection support.
# For duplicate messages, skip create_session entirely so the infra-retry
# client subscribes to the *existing* turn's Redis stream and receives the
# in-progress executor output rather than an empty stream.
turn_id = ""
if not is_duplicate_message:
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
@@ -972,7 +1013,6 @@ async def stream_chat_post(
}
},
)
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
@@ -984,10 +1024,10 @@ async def stream_chat_post(
mode=request.mode,
model=request.model,
)
except Exception:
if dedup_lock:
await dedup_lock.release()
raise
else:
logger.info(
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
@@ -1011,12 +1051,6 @@ async def stream_chat_post(
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
# True for every exit path except GeneratorExit (client disconnect).
# On disconnect the backend turn is still running — releasing the lock
# there would reopen the infra-retry duplicate window. The 30 s TTL
# is the fallback. All other exits (normal finish, early return, error)
# should release so the user can re-send the same message.
release_dedup_lock_on_exit = True
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
@@ -1028,7 +1062,7 @@ async def stream_chat_post(
if subscriber_queue is None:
yield StreamFinish().to_sse()
return # finally releases dedup_lock
return
# Read from the subscriber queue and yield to SSE
logger.info(
@@ -1070,7 +1104,7 @@ async def stream_chat_post(
}
},
)
break # finally releases dedup_lock
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
@@ -1086,7 +1120,6 @@ async def stream_chat_post(
}
},
)
release_dedup_lock_on_exit = False
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
@@ -1101,10 +1134,7 @@ async def stream_chat_post(
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
# finally releases dedup_lock
finally:
if dedup_lock and release_dedup_lock_on_exit:
await dedup_lock.release()
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:

View File

@@ -133,21 +133,12 @@ def test_stream_chat_rejects_too_many_file_ids():
assert response.status_code == 422
def _mock_stream_internals(
mocker: pytest_mock.MockerFixture,
*,
redis_set_returns: object = True,
):
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing Redis/RabbitMQ.
Args:
redis_set_returns: Value returned by the mocked Redis ``set`` call.
``True`` (default) simulates a fresh key (new message);
``None`` simulates a collision (duplicate blocked).
validation and enrichment logic without needing RabbitMQ.
Returns:
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
A namespace with ``save`` and ``enqueue`` mock objects so
callers can make additional assertions about side-effects.
"""
import types
@@ -158,7 +149,7 @@ def _mock_stream_internals(
)
mock_save = mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = mocker.AsyncMock(return_value=None)
@@ -174,15 +165,9 @@ def _mock_stream_internals(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=redis_set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
return types.SimpleNamespace(
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
)
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
return ns
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
@@ -211,6 +196,29 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
assert response.status_code == 200
# ─── Duplicate message dedup ──────────────────────────────────────────
def test_stream_chat_skips_enqueue_for_duplicate_message(
    mocker: pytest_mock.MockerFixture,
):
    """When append_and_save_message returns None (duplicate detected),
    enqueue_copilot_turn and stream_registry.create_session must NOT be called
    to avoid double-processing and to prevent overwriting the active stream's
    turn_id in Redis (which would cause reconnecting clients to miss the response)."""
    internals = _mock_stream_internals(mocker)
    # A None result from the save helper is the duplicate-detected signal.
    internals.save.return_value = None

    resp = client.post(
        "/sessions/sess-1/stream",
        json={"message": "hello"},
    )

    assert resp.status_code == 200
    internals.enqueue.assert_not_called()
    internals.registry.create_session.assert_not_called()
# ─── UUID format filtering ─────────────────────────────────────────────
@@ -706,237 +714,6 @@ class TestStripInjectedContext:
assert result["content"] == "hello"
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
mocker: pytest_mock.MockerFixture,
) -> None:
"""A second POST with the same message within the 30-s window must return
an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
turn complete without creating a ghost response."""
# redis_set_returns=None simulates a collision: the NX key already exists.
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-dup/stream",
json={"message": "duplicate message", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
# The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
assert '"finish"' in body
assert "[DONE]" in body
# The empty SSE response must include the AI SDK protocol header so the
# frontend treats it as a valid stream and marks the turn complete.
assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
# The duplicate guard must prevent save/enqueue side effects.
ns.save.assert_not_called()
ns.enqueue.assert_not_called()
def test_stream_chat_first_post_proceeds_normally(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The first POST (Redis NX key set successfully) must proceed through the
normal streaming path — no early return."""
ns = _mock_stream_internals(mocker, redis_set_returns=True)
response = client.post(
"/sessions/sess-new/stream",
json={"message": "first message", "is_user_message": True},
)
assert response.status_code == 200
# Redis set must have been called once with the NX flag.
ns.redis.set.assert_called_once()
call_kwargs = ns.redis.set.call_args
assert call_kwargs.kwargs.get("nx") is True
def test_stream_chat_dedup_skipped_for_non_user_messages(
mocker: pytest_mock.MockerFixture,
) -> None:
"""System/assistant messages (is_user_message=False) bypass the dedup
guard — they are injected programmatically and must always be processed."""
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-sys/stream",
json={"message": "system context", "is_user_message": False},
)
# Even though redis_set_returns=None (would block a user message),
# the endpoint must proceed because is_user_message=False.
assert response.status_code == 200
ns.redis.set.assert_not_called()
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup hash must be computed from the original request message,
not the mutated version that has the [Attached files] block appended.
A file_id is sent so the route actually appends the [Attached files] block,
exercising the mutation path — the hash must still match the original text."""
import hashlib
ns = _mock_stream_internals(mocker, redis_set_returns=True)
file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
# Mock workspace + prisma so the attachment block is actually appended.
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
fake_file = type(
"F",
(),
{
"id": file_id,
"name": "doc.pdf",
"mimeType": "application/pdf",
"sizeBytes": 1024,
},
)()
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-hash/stream",
json={
"message": "plain message",
"is_user_message": True,
"file_ids": [file_id],
},
)
assert response.status_code == 200
ns.redis.set.assert_called_once()
call_args = ns.redis.set.call_args
dedup_key = call_args.args[0]
# Hash must use the original message + sorted file IDs, not the mutated text.
expected_hash = hashlib.sha256(
f"sess-hash:plain message:{file_id}".encode()
).hexdigest()[:16]
expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
assert dedup_key == expected_key, (
f"Dedup key {dedup_key!r} does not match expected {expected_key!r}"
"hash may be using mutated message or wrong inputs"
)
def test_stream_chat_dedup_key_released_after_stream_finish(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup Redis key must be deleted after the turn completes (when
subscriber_queue is None the route yields StreamFinish immediately and
should release the key so the user can re-send the same message)."""
from unittest.mock import AsyncMock as _AsyncMock
# Set up all internals manually so we can control subscribe_to_session.
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
# None → early-finish path: StreamFinish yielded immediately, dedup key released.
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
response = client.post(
"/sessions/sess-finish/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
assert '"finish"' in body
# The dedup key must be released so intentional re-sends are allowed.
mock_redis.delete.assert_called_once()
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The route must not crash when the dedup Redis delete fails on the
subscriber_queue-is-None early-finish path (except Exception: pass)."""
from unittest.mock import AsyncMock as _AsyncMock
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
# Make the delete raise so the except-pass branch is exercised.
mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
# Should not raise even though delete fails.
response = client.post(
"/sessions/sess-finish-err/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
assert '"finish"' in response.text
# delete must have been attempted — the except-pass branch silenced the error.
mock_redis.delete.assert_called_once()
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
@@ -980,3 +757,146 @@ def test_disconnect_stream_returns_404_when_session_missing(
assert response.status_code == 404
mock_disconnect.assert_not_awaited()
# ─── GET /sessions/{session_id} — forward/backward pagination ──────────────────
def _make_paginated_messages(
    mocker: pytest_mock.MockerFixture, *, has_more: bool = False
):
    """Build a fake PaginatedMessages page and patch the route's DB accessor.

    Returns the fake page together with the patch handle so callers can
    inspect how the route invoked the paginator.
    """
    from datetime import UTC, datetime

    from backend.copilot.db import PaginatedMessages
    from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata

    timestamp = datetime.now(UTC)
    fake_page = PaginatedMessages(
        messages=[ChatMessage(role="user", content="hello", sequence=0)],
        has_more=has_more,
        oldest_sequence=0,
        newest_sequence=0,
        session=ChatSessionInfo(
            session_id="sess-1",
            user_id=TEST_USER_ID,
            usage=[],
            started_at=timestamp,
            updated_at=timestamp,
            metadata=ChatSessionMetadata(),
        ),
    )
    patched = mocker.patch(
        "backend.api.features.chat.routes.get_chat_messages_paginated",
        new_callable=AsyncMock,
        return_value=fake_page,
    )
    return fake_page, patched
def test_get_session_completed_returns_forward_paginated(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """A session with no running stream is served oldest-first (forward)."""
    _make_paginated_messages(mocker)
    # No active stream: the registry reports (None, None).
    mocker.patch(
        "backend.api.features.chat.routes.stream_registry.get_active_session",
        new_callable=AsyncMock,
        return_value=(None, None),
    )

    resp = client.get("/sessions/sess-1")

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["forward_paginated"] is True
    assert payload["newest_sequence"] == 0
def test_get_session_active_returns_backward_paginated(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """A session with a live stream keeps the legacy newest-first order."""
    from backend.copilot.stream_registry import ActiveSession

    _make_paginated_messages(mocker)
    running = MagicMock(spec=ActiveSession)
    running.turn_id = "turn-1"
    mocker.patch(
        "backend.api.features.chat.routes.stream_registry.get_active_session",
        new_callable=AsyncMock,
        return_value=(running, "msg-1"),
    )

    resp = client.get("/sessions/sess-1")

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["forward_paginated"] is False
    assert payload["active_stream"] is not None
    assert payload["active_stream"]["turn_id"] == "turn-1"
def test_get_session_after_sequence_returns_forward_paginated(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """Passing after_sequence forces forward pagination without a stream check."""
    _, patched = _make_paginated_messages(mocker)

    resp = client.get("/sessions/sess-1?after_sequence=10")

    assert resp.status_code == 200
    assert resp.json()["forward_paginated"] is True
    # The cursor must be forwarded to the paginator as after_sequence only.
    kwargs = patched.call_args.kwargs
    assert kwargs.get("after_sequence") == 10
    assert kwargs.get("before_sequence") is None
def test_get_session_both_cursors_returns_400(
    test_user_id: str,
) -> None:
    """Supplying both pagination cursors at once is rejected with HTTP 400."""
    resp = client.get("/sessions/sess-1?before_sequence=5&after_sequence=10")
    assert resp.status_code == 400
def test_get_session_toctou_refetch_when_session_completes_mid_request(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """Race condition: session was active at pre-check but completes before DB fetch.

    The route should detect the race via a post-fetch re-check, then re-fetch
    from seq 0 so the initial prompt is always visible.
    """
    from backend.copilot.stream_registry import ActiveSession

    # The fake page itself is irrelevant here; only the patch handle is used.
    _, mock_paginate = _make_paginated_messages(mocker)
    active = MagicMock(spec=ActiveSession)
    active.turn_id = "turn-1"
    # First call: session appears active. Second call: session has completed.
    mock_get_active = mocker.patch(
        "backend.api.features.chat.routes.stream_registry.get_active_session",
        new_callable=AsyncMock,
        side_effect=[(active, "msg-1"), (None, None)],
    )

    response = client.get("/sessions/sess-1")

    assert response.status_code == 200
    data = response.json()
    # Post-race: session is now completed → forward_paginated=True, no stream.
    assert data["forward_paginated"] is True
    assert data["active_stream"] is None
    # The DB was queried twice: once newest-first, once with from_start=True.
    assert mock_paginate.call_count == 2
    assert mock_get_active.call_count == 2
    second_call = mock_paginate.call_args_list[1]
    assert second_call.kwargs.get("from_start") is True

View File

@@ -43,6 +43,25 @@ config = Config()
integration_creds_manager = IntegrationCredentialsManager()
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
    """Fetch execution counts per graph in a single batched query.

    Returns a mapping of graph ID → number of (non-deleted) executions the
    user has for that graph. Graphs with zero executions are absent from the
    result, so callers should use ``.get()`` with a default.
    """
    # Avoid a pointless round-trip when there is nothing to count.
    if not graph_ids:
        return {}
    grouped = await prisma.models.AgentGraphExecution.prisma().group_by(
        by=["agentGraphId"],
        where={
            "userId": user_id,
            "agentGraphId": {"in": graph_ids},
            "isDeleted": False,
        },
        count=True,
    )
    counts: dict[str, int] = {}
    for row in grouped:
        # The aggregate may be missing/None for edge rows; treat that as 0.
        count_info = row.get("_count") or {}
        counts[row["agentGraphId"]] = int(count_info.get("_all") or 0)
    return counts
async def list_library_agents(
user_id: str,
search_term: Optional[str] = None,
@@ -137,12 +156,18 @@ async def list_library_agents(
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts = await _fetch_execution_counts(user_id, graph_ids)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(agent)
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
@@ -214,12 +239,18 @@ async def list_favorite_library_agents(
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
)
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts = await _fetch_execution_counts(user_id, graph_ids)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(agent)
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error

View File

@@ -65,6 +65,11 @@ async def test_get_library_agents(mocker):
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={}),
)
# Call function
result = await db.list_library_agents("test-user")
@@ -353,3 +358,136 @@ async def test_create_library_agent_uses_upsert():
# Verify update branch restores soft-deleted/archived agents
assert data["update"]["isDeleted"] is False
assert data["update"]["isArchived"] is False
@pytest.mark.asyncio
async def test_list_favorite_library_agents(mocker):
    """Favorites listing maps DB rows into LibraryAgent models with pagination."""
    now = datetime.now()
    fake_rows = [
        prisma.models.LibraryAgent(
            id="fav1",
            userId="test-user",
            agentGraphId="agent-fav",
            settings="{}",  # type: ignore
            agentGraphVersion=1,
            isCreatedByUser=False,
            isDeleted=False,
            isArchived=False,
            createdAt=now,
            updatedAt=now,
            isFavorite=True,
            useGraphIsActiveVersion=True,
            AgentGraph=prisma.models.AgentGraph(
                id="agent-fav",
                version=1,
                name="Favorite Agent",
                description="My Favorite",
                userId="other-user",
                isActive=True,
                createdAt=now,
            ),
        )
    ]
    prisma_patch = mocker.patch("prisma.models.LibraryAgent.prisma")
    prisma_patch.return_value.find_many = mocker.AsyncMock(return_value=fake_rows)
    prisma_patch.return_value.count = mocker.AsyncMock(return_value=1)
    mocker.patch(
        "backend.api.features.library.db._fetch_execution_counts",
        new=mocker.AsyncMock(return_value={"agent-fav": 7}),
    )

    result = await db.list_favorite_library_agents("test-user")

    assert len(result.agents) == 1
    assert result.agents[0].id == "fav1"
    assert result.agents[0].name == "Favorite Agent"
    assert result.agents[0].graph_id == "agent-fav"
    assert result.pagination.total_items == 1
    assert result.pagination.total_pages == 1
    assert result.pagination.current_page == 1
    assert result.pagination.page_size == 50
@pytest.mark.asyncio
async def test_list_library_agents_skips_failed_agent(mocker):
    """Agents that fail parsing should be skipped — covers the except branch."""
    created = datetime.now()
    rows = [
        prisma.models.LibraryAgent(
            id="ua-bad",
            userId="test-user",
            agentGraphId="agent-bad",
            settings="{}",  # type: ignore
            agentGraphVersion=1,
            isCreatedByUser=False,
            isDeleted=False,
            isArchived=False,
            createdAt=created,
            updatedAt=created,
            isFavorite=False,
            useGraphIsActiveVersion=True,
            AgentGraph=prisma.models.AgentGraph(
                id="agent-bad",
                version=1,
                name="Bad Agent",
                description="",
                userId="other-user",
                isActive=True,
                createdAt=created,
            ),
        )
    ]
    prisma_patch = mocker.patch("prisma.models.LibraryAgent.prisma")
    prisma_patch.return_value.find_many = mocker.AsyncMock(return_value=rows)
    prisma_patch.return_value.count = mocker.AsyncMock(return_value=1)
    mocker.patch(
        "backend.api.features.library.db._fetch_execution_counts",
        new=mocker.AsyncMock(return_value={}),
    )
    # Force the model factory to blow up so the skip/except path runs.
    mocker.patch(
        "backend.api.features.library.model.LibraryAgent.from_db",
        side_effect=Exception("parse error"),
    )

    result = await db.list_library_agents("test-user")

    # The broken agent is dropped, but the total count still reflects the DB.
    assert len(result.agents) == 0
    assert result.pagination.total_items == 1
@pytest.mark.asyncio
async def test_fetch_execution_counts_empty_graph_ids():
    """An empty graph-ID list short-circuits to {} without touching the DB."""
    assert await db._fetch_execution_counts("user-1", []) == {}
@pytest.mark.asyncio
async def test_fetch_execution_counts_uses_group_by(mocker):
    """Counts come from a single group_by query scoped to the user's graphs."""
    prisma_patch = mocker.patch("prisma.models.AgentGraphExecution.prisma")
    prisma_patch.return_value.group_by = mocker.AsyncMock(
        return_value=[
            {"agentGraphId": "graph-1", "_count": {"_all": 5}},
            {"agentGraphId": "graph-2", "_count": {"_all": 2}},
        ]
    )

    counts = await db._fetch_execution_counts(
        "user-1", ["graph-1", "graph-2", "graph-3"]
    )

    # Graphs with no executions are simply absent from the mapping.
    assert counts == {"graph-1": 5, "graph-2": 2}
    prisma_patch.return_value.group_by.assert_called_once_with(
        by=["agentGraphId"],
        where={
            "userId": "user-1",
            "agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
            "isDeleted": False,
        },
        count=True,
    )

View File

@@ -223,6 +223,7 @@ class LibraryAgent(pydantic.BaseModel):
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
store_listing: Optional[prisma.models.StoreListing] = None,
profile: Optional[prisma.models.Profile] = None,
execution_count_override: Optional[int] = None,
) -> "LibraryAgent":
"""
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
@@ -258,10 +259,14 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
execution_count = len(executions)
execution_count = (
execution_count_override
if execution_count_override is not None
else len(executions)
)
success_rate: float | None = None
avg_correctness_score: float | None = None
if execution_count > 0:
if executions and execution_count > 0:
success_count = sum(
1
for e in executions

View File

@@ -1,11 +1,66 @@
import datetime
import prisma.enums
import prisma.models
import pytest
from . import model as library_model
def _make_library_agent(
    *,
    graph_id: str = "g1",
    executions: list | None = None,
) -> prisma.models.LibraryAgent:
    """Build a minimal LibraryAgent row with an attached AgentGraph for tests."""
    graph = prisma.models.AgentGraph(
        id=graph_id,
        version=1,
        name="Agent",
        description="Desc",
        userId="u1",
        isActive=True,
        createdAt=datetime.datetime.now(),
        Executions=executions,
    )
    return prisma.models.LibraryAgent(
        id="la1",
        userId="u1",
        agentGraphId=graph_id,
        settings="{}",  # type: ignore
        agentGraphVersion=1,
        isCreatedByUser=True,
        isDeleted=False,
        isArchived=False,
        createdAt=datetime.datetime.now(),
        updatedAt=datetime.datetime.now(),
        isFavorite=False,
        useGraphIsActiveVersion=True,
        AgentGraph=graph,
    )
def test_from_db_execution_count_override_covers_success_rate():
    """Covers execution_count_override is not None branch and executions/count > 0 block."""
    created = datetime.datetime.now(datetime.timezone.utc)
    completed_run = prisma.models.AgentGraphExecution(
        id="exec-1",
        agentGraphId="g1",
        agentGraphVersion=1,
        userId="u1",
        executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
        createdAt=created,
        updatedAt=created,
        isDeleted=False,
        isShared=False,
    )
    parsed = library_model.LibraryAgent.from_db(
        _make_library_agent(executions=[completed_run]),
        execution_count_override=1,
    )
    assert parsed.execution_count == 1
    # One completed run out of one counted execution → 100% success rate.
    assert parsed.success_rate is not None
    assert parsed.success_rate == 100.0
@pytest.mark.asyncio
async def test_agent_preset_from_db(test_user_id: str):
# Create mock DB agent

View File

@@ -5,7 +5,8 @@ import time
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Annotated, Any, Literal, Sequence, get_args
from typing import Annotated, Any, Literal, Sequence, cast, get_args
from urllib.parse import urlparse
import pydantic
import stripe
@@ -54,8 +55,11 @@ from backend.data.credit import (
cancel_stripe_subscription,
create_subscription_checkout,
get_auto_top_up,
get_proration_credit_cents,
get_subscription_price_id,
get_user_credit_model,
handle_subscription_payment_failure,
modify_stripe_subscription_for_tier,
set_auto_top_up,
set_subscription_tier,
sync_subscription_from_stripe,
@@ -699,9 +703,72 @@ class SubscriptionCheckoutResponse(BaseModel):
class SubscriptionStatusResponse(BaseModel):
tier: str
monthly_cost: int
tier_costs: dict[str, int]
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
monthly_cost: int # amount in cents (Stripe convention)
tier_costs: dict[str, int] # tier name -> amount in cents
proration_credit_cents: int # unused portion of current sub to convert on upgrade
def _validate_checkout_redirect_url(url: str) -> bool:
    """Return True if `url` matches the configured frontend origin.

    Prevents open-redirect: attackers must not be able to supply arbitrary
    success_url/cancel_url that Stripe will redirect users to after checkout.

    Pre-parse rejection rules (applied before urlparse):
    - Backslashes (``\\``) are normalised differently across parsers/browsers.
    - Control characters (U+0000 .. U+001F) are not valid in URLs and may
      confuse some URL-parsing implementations.
    """
    # Reject parser-confusing characters before any parsing happens.
    has_control_char = any(ord(ch) < 0x20 for ch in url)
    if "\\" in url or has_control_char:
        return False
    allowed_origin = (
        settings.config.frontend_base_url or settings.config.platform_base_url
    )
    if not allowed_origin:
        # No configured origin — refuse to validate rather than allow arbitrary URLs.
        return False
    try:
        candidate = urlparse(url)
        expected = urlparse(allowed_origin)
    except ValueError:
        return False
    if candidate.scheme not in ("http", "https"):
        return False
    # Reject ``user:pass@host`` authority tricks — ``@`` inside the netloc can
    # make browsers connect to a different host than the one displayed.
    # ``@`` in query/fragment is harmless and must be allowed.
    if "@" in candidate.netloc:
        return False
    return (
        candidate.scheme == expected.scheme and candidate.netloc == expected.netloc
    )
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
async def _get_stripe_price_amount(price_id: str) -> int | None:
    """Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.

    Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
    of caching the ``None`` sentinel so the next request retries Stripe instead
    of being served a stale "no price" for the rest of the TTL window. Callers
    should treat ``None`` as an unknown price and fall back to 0.

    Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip
    on every GET /credits/subscription page load and reduces quota consumption.
    """
    try:
        price_obj = await run_in_threadpool(stripe.Price.retrieve, price_id)
    except stripe.StripeError:
        logger.warning(
            "Failed to retrieve Stripe price %s — returning None (not cached)",
            price_id,
        )
        return None
    return price_obj.unit_amount or 0
@v1_router.get(
@@ -722,21 +789,26 @@ async def get_subscription_status(
*[get_subscription_price_id(t) for t in paid_tiers]
)
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
for t, price_id in zip(paid_tiers, price_ids):
cost = 0
if price_id:
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
cost = price.unit_amount or 0
except stripe.StripeError:
pass
tier_costs: dict[str, int] = {
SubscriptionTier.FREE.value: 0,
SubscriptionTier.ENTERPRISE.value: 0,
}
async def _cost(pid: str | None) -> int:
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
for t, cost in zip(paid_tiers, costs):
tier_costs[t.value] = cost
current_monthly_cost = tier_costs.get(tier.value, 0)
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
return SubscriptionStatusResponse(
tier=tier.value,
monthly_cost=tier_costs.get(tier.value, 0),
monthly_cost=current_monthly_cost,
tier_costs=tier_costs,
proration_credit_cents=proration_credit,
)
@@ -766,24 +838,125 @@ async def update_subscription_tier(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
# keeps their tier for the time they already paid for. The DB tier is NOT
# updated here when a subscription exists — the customer.subscription.deleted
# webhook fires at period end and downgrades to FREE then.
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
# tier), cancel_stripe_subscription returns False and we update the DB tier
# immediately since no webhook will ever fire.
# When payment is disabled entirely, update the DB tier directly.
if tier == SubscriptionTier.FREE:
if payment_enabled:
await cancel_stripe_subscription(user_id)
try:
had_subscription = await cancel_stripe_subscription(user_id)
except stripe.StripeError as e:
# Log full Stripe error server-side but return a generic message
# to the client — raw Stripe errors can leak customer/sub IDs and
# infrastructure config details.
logger.exception(
"Stripe error cancelling subscription for user %s: %s",
user_id,
e,
)
raise HTTPException(
status_code=502,
detail=(
"Unable to cancel your subscription right now. "
"Please try again or contact support."
),
)
if not had_subscription:
# No active Stripe subscription found — the user was on an
# admin-granted tier. Update DB immediately since the
# subscription.deleted webhook will never fire.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
# Beta users (payment not enabled) → update tier directly without Stripe.
# Paid tier changes require payment to be enabled — block self-service upgrades
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
if not payment_enabled:
await set_subscription_tier(user_id, tier)
raise HTTPException(
status_code=422,
detail=f"Subscription not available for tier {tier}",
)
# No-op short-circuit: if the user is already on the requested paid tier,
# do NOT create a new Checkout Session. Without this guard, a duplicate
# request (double-click, retried POST, stale page) creates a second
# subscription for the same price; the user would be charged for both
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
# which only fires after the second charge has cleared.
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
return SubscriptionCheckoutResponse(url="")
# Paid upgrade → create Stripe Checkout Session.
# Paid→paid tier change: if the user already has a Stripe subscription,
# modify it in-place with proration instead of creating a new Checkout
# Session. This preserves remaining paid time and avoids double-charging.
# The customer.subscription.updated webhook fires and updates the DB tier.
current_tier = user.subscription_tier or SubscriptionTier.FREE
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
try:
modified = await modify_stripe_subscription_for_tier(user_id, tier)
if modified:
return SubscriptionCheckoutResponse(url="")
# modify_stripe_subscription_for_tier returns False when no active
# Stripe subscription exists — i.e. the user has an admin-granted
# paid tier with no Stripe record. In that case, update the DB
# tier directly (same as the FREE-downgrade path for admin-granted
# users) rather than sending them through a new Checkout Session.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
),
)
# Paid upgrade from FREE → create Stripe Checkout Session.
if not request.success_url or not request.cancel_url:
raise HTTPException(
status_code=422,
detail="success_url and cancel_url are required for paid tier upgrades",
)
# Open-redirect protection: both URLs must point to the configured frontend
# origin, otherwise an attacker could use our Stripe integration as a
# redirector to arbitrary phishing sites.
#
# Fail early with a clear 503 if the server is misconfigured (neither
# frontend_base_url nor platform_base_url set), so operators get an
# actionable error instead of the misleading "must match the platform
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
# produce when `allowed` is empty.
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
logger.error(
"update_subscription_tier: neither frontend_base_url nor "
"platform_base_url is configured; cannot validate checkout redirect URLs"
)
raise HTTPException(
status_code=503,
detail=(
"Payment redirect URLs cannot be validated: "
"frontend_base_url or platform_base_url must be set on the server."
),
)
if not _validate_checkout_redirect_url(
request.success_url
) or not _validate_checkout_redirect_url(request.cancel_url):
raise HTTPException(
status_code=422,
detail="success_url and cancel_url must match the platform frontend origin",
)
try:
url = await create_subscription_checkout(
user_id=user_id,
@@ -791,8 +964,19 @@ async def update_subscription_tier(
success_url=request.success_url,
cancel_url=request.cancel_url,
)
except (ValueError, stripe.StripeError) as e:
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error creating checkout session for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to start checkout right now. "
"Please try again or contact support."
),
)
return SubscriptionCheckoutResponse(url=url)
@@ -801,44 +985,78 @@ async def update_subscription_tier(
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
)
async def stripe_webhook(request: Request):
webhook_secret = settings.secrets.stripe_webhook_secret
if not webhook_secret:
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
# signature over the same empty key). Reject all webhook calls when unconfigured.
logger.error(
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
"rejecting request to prevent signature bypass"
)
raise HTTPException(status_code=503, detail="Webhook not configured")
# Get the raw request body
payload = await request.body()
# Get the signature header
sig_header = request.headers.get("stripe-signature")
try:
event = stripe.Webhook.construct_event(
payload, sig_header, settings.secrets.stripe_webhook_secret
)
except ValueError as e:
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
except ValueError:
# Invalid payload
raise HTTPException(
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
)
except stripe.SignatureVerificationError as e:
raise HTTPException(status_code=400, detail="Invalid payload")
except stripe.SignatureVerificationError:
# Invalid signature
raise HTTPException(
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
raise HTTPException(status_code=400, detail="Invalid signature")
# Defensive payload extraction. A malformed payload (missing/non-dict
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
# AFTER signature verification — which Stripe interprets as a delivery
# failure and retries forever, while spamming Sentry with no useful info.
# Acknowledge with 200 and a warning so Stripe stops retrying.
event_type = event.get("type", "")
event_data = event.get("data") or {}
data_object = event_data.get("object") if isinstance(event_data, dict) else None
if not isinstance(data_object, dict):
logger.warning(
"stripe_webhook: %s missing or non-dict data.object; ignoring",
event_type,
)
return Response(status_code=200)
if (
event["type"] == "checkout.session.completed"
or event["type"] == "checkout.session.async_payment_succeeded"
if event_type in (
"checkout.session.completed",
"checkout.session.async_payment_succeeded",
):
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
session_id = data_object.get("id")
if not session_id:
logger.warning(
"stripe_webhook: %s missing data.object.id; ignoring", event_type
)
return Response(status_code=200)
await UserCredit().fulfill_checkout(session_id=session_id)
if event["type"] in (
if event_type in (
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
):
await sync_subscription_from_stripe(event["data"]["object"])
await sync_subscription_from_stripe(data_object)
if event["type"] == "charge.dispute.created":
await UserCredit().handle_dispute(event["data"]["object"])
if event_type == "invoice.payment_failed":
await handle_subscription_payment_failure(data_object)
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
await UserCredit().deduct_credits(event["data"]["object"])
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
# to satisfy the type checker without changing runtime behaviour.
if event_type == "charge.dispute.created":
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
if event_type == "refund.created" or event_type == "charge.dispute.closed":
await UserCredit().deduct_credits(
cast("stripe.Refund | stripe.Dispute", data_object)
)
return Response(status_code=200)

View File

@@ -106,7 +106,6 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
@classmethod
def _missing_(cls, value: object) -> "LlmModel | None":
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
@@ -203,6 +202,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
GROK_4_20 = "x-ai/grok-4.20"
GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
@@ -627,6 +628,18 @@ MODEL_METADATA = {
LlmModel.GROK_4_1_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_4_20: ModelMetadata(
"open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3
),
LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata(
"open_router",
2000000,
100000,
"Grok 4.20 Multi-Agent",
"OpenRouter",
"xAI",
3,
),
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
),
@@ -987,7 +1000,6 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "anthropic":
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
# Cache tool definitions alongside the system prompt.
# Placing cache_control on the last tool caches all tool schemas as a

View File

@@ -67,11 +67,15 @@ from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
TranscriptDownload,
detect_gap,
download_transcript,
extract_context_messages,
strip_for_upload,
upload_transcript,
validate_transcript,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util import json as util_json
from backend.util.exceptions import NotFoundError
from backend.util.prompt import (
compress_context,
@@ -699,81 +703,147 @@ async def _compress_session_messages(
return messages
def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool:
    """Return ``True`` when a download doesn't cover the current session.

    A transcript is stale when it has a known ``message_count`` and that
    count doesn't reach ``session_msg_count - 1`` (i.e. the session has
    already advanced beyond what the stored transcript captures).
    Loading a stale transcript would silently drop intermediate turns,
    so callers should treat stale as "skip load, skip upload".

    An unknown ``message_count`` (``0``) is treated as **not stale**
    because older transcripts uploaded before msg_count tracking
    existed must still be usable.
    """
    # Nothing downloaded, or the count is unknown — both count as usable.
    if dl is None or not dl.message_count:
        return False
    covered_through = dl.message_count
    return covered_through < session_msg_count - 1
def should_upload_transcript(
user_id: str | None, transcript_covers_prefix: bool
) -> bool:
def should_upload_transcript(user_id: str | None, upload_safe: bool) -> bool:
"""Return ``True`` when the caller should upload the final transcript.
Uploads require a logged-in user (for the storage key) *and* a
transcript that covered the session prefix when loaded — otherwise
we'd be overwriting a more complete version in storage with a
partial one built from just the current turn.
Uploads require a logged-in user (for the storage key) *and* a safe
upload signal from ``_load_prior_transcript`` — i.e. GCS does not hold a
newer version that we'd be overwriting.
"""
return bool(user_id) and transcript_covers_prefix
return bool(user_id) and upload_safe
def _append_gap_to_builder(
    gap: list[ChatMessage],
    builder: TranscriptBuilder,
) -> None:
    """Append gap messages from chat-db into the TranscriptBuilder.

    Converts ChatMessage (OpenAI format) to TranscriptBuilder entries
    (Claude CLI JSONL format) so the uploaded transcript covers all turns.

    Pre-condition: ``gap`` always starts at a user or assistant boundary
    (never mid-turn at a ``tool`` role), because ``detect_gap`` enforces
    ``session_messages[wm-1].role == 'assistant'`` before returning a non-empty
    gap. Any ``tool`` role messages within the gap always follow an assistant
    entry that already exists in the builder or in the gap itself.
    """
    for entry in gap:
        if entry.role == "user":
            builder.append_user(entry.content or "")
            continue
        if entry.role == "assistant":
            blocks: list[dict] = []
            if entry.content:
                blocks.append({"type": "text", "text": entry.content})
            for call in entry.tool_calls or []:
                is_dict = isinstance(call, dict)
                fn_spec = call.get("function", {}) if is_dict else {}
                blocks.append(
                    {
                        "type": "tool_use",
                        "id": call.get("id", "") if is_dict else "",
                        "name": fn_spec.get("name", "unknown"),
                        "input": util_json.loads(
                            fn_spec.get("arguments", "{}"), fallback={}
                        ),
                    }
                )
            if not blocks:
                # Fallback: ensure every assistant gap message produces an entry
                # so the builder's entry count matches the gap length.
                blocks.append({"type": "text", "text": ""})
            builder.append_assistant(content_blocks=blocks)
        elif entry.role == "tool":
            if entry.tool_call_id:
                builder.append_tool_result(
                    tool_use_id=entry.tool_call_id,
                    content=entry.content or "",
                )
            else:
                # Malformed tool message — no tool_call_id to link to an
                # assistant tool_use block. Skip to avoid an unmatched
                # tool_result entry in the builder (which would confuse --resume).
                logger.warning(
                    "[Baseline] Skipping tool gap message with no tool_call_id"
                )
async def _load_prior_transcript(
user_id: str,
session_id: str,
session_msg_count: int,
session_messages: list[ChatMessage],
transcript_builder: TranscriptBuilder,
) -> bool:
"""Download and load the prior transcript into ``transcript_builder``.
) -> tuple[bool, "TranscriptDownload | None"]:
"""Download and load the prior CLI session into ``transcript_builder``.
Returns ``True`` when the loaded transcript fully covers the session
prefix; ``False`` otherwise (stale, missing, invalid, or download
error). Callers should suppress uploads when this returns ``False``
to avoid overwriting a more complete version in storage.
Returns a tuple of (upload_safe, transcript_download):
- ``upload_safe`` is ``True`` when it is safe to upload at the end of this
turn. Upload is suppressed only for **download errors** (unknown GCS
state) — missing and invalid files return ``True`` because there is
nothing in GCS worth protecting against overwriting.
- ``transcript_download`` is a ``TranscriptDownload`` with str content
(pre-decoded and stripped) when available, or ``None`` when no valid
transcript could be loaded. Callers pass this to
``extract_context_messages`` to build the LLM context.
"""
try:
dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
except Exception as e:
logger.warning("[Baseline] Transcript download failed: %s", e)
return False
if dl is None:
logger.debug("[Baseline] No transcript available")
return False
if not validate_transcript(dl.content):
logger.warning("[Baseline] Downloaded transcript but invalid")
return False
if is_transcript_stale(dl, session_msg_count):
logger.warning(
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
dl.message_count,
session_msg_count,
restore = await download_transcript(
user_id, session_id, log_prefix="[Baseline]"
)
return False
except Exception as e:
logger.warning("[Baseline] Session restore failed: %s", e)
# Unknown GCS state — be conservative, skip upload.
return False, None
transcript_builder.load_previous(dl.content, log_prefix="[Baseline]")
if restore is None:
logger.debug("[Baseline] No CLI session available — will upload fresh")
# Nothing in GCS to protect; allow upload so the first baseline turn
# writes the initial transcript snapshot.
return True, None
content_bytes = restore.content
try:
raw_str = (
content_bytes.decode("utf-8")
if isinstance(content_bytes, bytes)
else content_bytes
)
except UnicodeDecodeError:
logger.warning("[Baseline] CLI session content is not valid UTF-8")
# Corrupt file in GCS; overwriting with a valid one is better.
return True, None
stripped = strip_for_upload(raw_str)
if not validate_transcript(stripped):
logger.warning("[Baseline] CLI session content invalid after strip")
# Corrupt file in GCS; overwriting with a valid one is better.
return True, None
transcript_builder.load_previous(stripped, log_prefix="[Baseline]")
logger.info(
"[Baseline] Loaded transcript: %dB, msg_count=%d",
len(dl.content),
dl.message_count,
"[Baseline] Loaded CLI session: %dB, msg_count=%d",
len(content_bytes) if isinstance(content_bytes, bytes) else len(raw_str),
restore.message_count,
)
return True
gap = detect_gap(restore, session_messages)
if gap:
_append_gap_to_builder(gap, transcript_builder)
logger.info(
"[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB",
restore.message_count,
len(gap),
)
# Return a str-content version so extract_context_messages receives a
# pre-decoded, stripped transcript (avoids redundant decode + strip).
# TranscriptDownload.content is typed as bytes | str; we pass str here
# to avoid a redundant encode + decode round-trip.
str_restore = TranscriptDownload(
content=stripped,
message_count=restore.message_count,
mode=restore.mode,
)
return True, str_restore
async def _upload_final_transcript(
@@ -807,10 +877,10 @@ async def _upload_final_transcript(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=content,
content=content.encode("utf-8"),
message_count=session_msg_count,
mode="baseline",
log_prefix="[Baseline]",
skip_strip=True,
)
)
_background_tasks.add(upload_task)
@@ -897,7 +967,7 @@ async def stream_chat_completion_baseline(
# --- Transcript support (feature parity with SDK path) ---
transcript_builder = TranscriptBuilder()
transcript_covers_prefix = True
transcript_upload_safe = True
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
@@ -914,15 +984,16 @@ async def stream_chat_completion_baseline(
# Run download + prompt build concurrently — both are independent I/O
# on the request critical path.
transcript_download: TranscriptDownload | None = None
if user_id and len(session.messages) > 1:
(
transcript_covers_prefix,
(transcript_upload_safe, transcript_download),
(base_system_prompt, understanding),
) = await asyncio.gather(
_load_prior_transcript(
user_id=user_id,
session_id=session_id,
session_msg_count=len(session.messages),
session_messages=session.messages,
transcript_builder=transcript_builder,
),
prompt_task,
@@ -962,9 +1033,14 @@ async def stream_chat_completion_baseline(
warm_ctx = await fetch_warm_context(user_id, message or "")
# Compress context if approaching the model's token limit
# Context path: transcript content (compacted, isCompactSummary preserved) +
# gap (DB messages after watermark) + current user turn.
# This avoids re-reading the full session history from DB on every turn.
# See extract_context_messages() in transcript.py for the shared primitive.
prior_context = extract_context_messages(transcript_download, session.messages)
messages_for_context = await _compress_session_messages(
session.messages, model=active_model
prior_context + ([session.messages[-1]] if session.messages else []),
model=active_model,
)
# Build OpenAI message list from session history.
@@ -1308,7 +1384,7 @@ async def stream_chat_completion_baseline(
stop_reason=STOP_REASON_END_TURN,
)
if user_id and should_upload_transcript(user_id, transcript_covers_prefix):
if user_id and should_upload_transcript(user_id, transcript_upload_safe):
await _upload_final_transcript(
user_id=user_id,
session_id=session_id,

View File

@@ -1,7 +1,7 @@
"""Integration tests for baseline transcript flow.
Exercises the real helpers in ``baseline/service.py`` that download,
validate, load, append to, backfill, and upload the transcript.
Exercises the real helpers in ``baseline/service.py`` that restore,
validate, load, append to, backfill, and upload the CLI session.
Storage is mocked via ``download_transcript`` / ``upload_transcript``
patches; no network access is required.
"""
@@ -12,13 +12,14 @@ from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.baseline.service import (
_append_gap_to_builder,
_load_prior_transcript,
_record_turn_to_transcript,
_resolve_baseline_model,
_upload_final_transcript,
is_transcript_stale,
should_upload_transcript,
)
from backend.copilot.model import ChatMessage
from backend.copilot.service import config
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
@@ -54,6 +55,13 @@ def _make_transcript_content(*roles: str) -> str:
return "\n".join(lines) + "\n"
def _make_session_messages(*roles: str) -> list[ChatMessage]:
    """Build a list of ChatMessage objects matching the given roles."""
    messages: list[ChatMessage] = []
    for index, role in enumerate(roles):
        messages.append(ChatMessage(role=role, content=f"{role} message {index}"))
    return messages
class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""
@@ -73,87 +81,102 @@ class TestResolveBaselineModel:
class TestLoadPriorTranscript:
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
"""``_load_prior_transcript`` wraps the CLI session restore + validate + load flow."""
@pytest.mark.asyncio
async def test_loads_fresh_transcript(self):
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=content, message_count=2)
restore = TranscriptDownload(
content=content.encode("utf-8"), message_count=2, mode="sdk"
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
assert dl is not None
assert dl.message_count == 2
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
@pytest.mark.asyncio
async def test_rejects_stale_transcript(self):
"""msg_count strictly less than session-1 is treated as stale."""
async def test_fills_gap_when_transcript_is_behind(self):
"""When transcript covers fewer messages than session, gap is filled from DB."""
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
# session has 6 messages, transcript only covers 2 → stale.
download = TranscriptDownload(content=content, message_count=2)
# transcript covers 2 messages, session has 4 (plus current user turn = 5)
restore = TranscriptDownload(
content=content.encode("utf-8"), message_count=2, mode="baseline"
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=6,
session_messages=_make_session_messages(
"user", "assistant", "user", "assistant", "user"
),
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
assert covers is True
assert dl is not None
# 2 from transcript + 2 gap messages (user+assistant at positions 2,3)
assert builder.entry_count == 4
@pytest.mark.asyncio
async def test_missing_transcript_returns_false(self):
async def test_missing_transcript_allows_upload(self):
"""Nothing in GCS → upload is safe; the turn writes the first snapshot."""
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
):
covers = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
assert covers is False
assert upload_safe is True
assert dl is None
assert builder.is_empty
@pytest.mark.asyncio
async def test_invalid_transcript_returns_false(self):
async def test_invalid_transcript_allows_upload(self):
"""Corrupt file in GCS → overwriting with a valid one is better."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content='{"type":"progress","uuid":"a"}\n',
restore = TranscriptDownload(
content=b'{"type":"progress","uuid":"a"}\n',
message_count=1,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
assert covers is False
assert upload_safe is True
assert dl is None
assert builder.is_empty
@pytest.mark.asyncio
@@ -163,36 +186,39 @@ class TestLoadPriorTranscript:
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
assert covers is False
assert dl is None
assert builder.is_empty
@pytest.mark.asyncio
async def test_zero_message_count_not_stale(self):
"""When msg_count is 0 (unknown), staleness check is skipped."""
"""When msg_count is 0 (unknown), gap detection is skipped."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
restore = TranscriptDownload(
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=0,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=20,
session_messages=_make_session_messages(*["user"] * 20),
transcript_builder=builder,
)
assert covers is True
assert dl is not None
assert builder.entry_count == 2
@@ -227,7 +253,7 @@ class TestUploadFinalTranscript:
assert call_kwargs["user_id"] == "user-1"
assert call_kwargs["session_id"] == "session-1"
assert call_kwargs["message_count"] == 2
assert "hello" in call_kwargs["content"]
assert b"hello" in call_kwargs["content"]
@pytest.mark.asyncio
async def test_skips_upload_when_builder_empty(self):
@@ -374,17 +400,19 @@ class TestRoundTrip:
@pytest.mark.asyncio
async def test_full_round_trip(self):
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
restore = TranscriptDownload(
content=prior.encode("utf-8"), message_count=2, mode="sdk"
)
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -424,11 +452,11 @@ class TestRoundTrip:
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "new question" in uploaded
assert "new answer" in uploaded
assert b"new question" in uploaded
assert b"new answer" in uploaded
# Original content preserved in the round trip.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
assert b"user message 0" in uploaded
assert b"assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_backfill_append_guard(self):
@@ -459,36 +487,6 @@ class TestRoundTrip:
assert builder.entry_count == initial_count
class TestIsTranscriptStale:
"""``is_transcript_stale`` gates prior-transcript loading."""
def test_none_download_is_not_stale(self):
assert is_transcript_stale(None, session_msg_count=5) is False
def test_zero_message_count_is_not_stale(self):
"""Legacy transcripts without msg_count tracking must remain usable."""
dl = TranscriptDownload(content="", message_count=0)
assert is_transcript_stale(dl, session_msg_count=20) is False
def test_stale_when_covers_less_than_prefix(self):
dl = TranscriptDownload(content="", message_count=2)
# session has 6 messages; transcript must cover at least 5 (6-1).
assert is_transcript_stale(dl, session_msg_count=6) is True
def test_fresh_when_covers_full_prefix(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_fresh_when_exceeds_prefix(self):
"""Race: transcript ahead of session count is still acceptable."""
dl = TranscriptDownload(content="", message_count=10)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_boundary_equal_to_prefix_minus_one(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
class TestShouldUploadTranscript:
"""``should_upload_transcript`` gates the final upload."""
@@ -510,7 +508,7 @@ class TestShouldUploadTranscript:
class TestTranscriptLifecycle:
"""End-to-end: download → validate → build → upload.
"""End-to-end: restore → validate → build → upload.
Simulates the full transcript lifecycle inside
``stream_chat_completion_baseline`` by mocking the storage layer and
@@ -519,27 +517,29 @@ class TestTranscriptLifecycle:
@pytest.mark.asyncio
async def test_full_lifecycle_happy_path(self):
"""Fresh download, append a turn, upload covers the session."""
"""Fresh restore, append a turn, upload covers the session."""
builder = TranscriptBuilder()
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
restore = TranscriptDownload(
content=prior.encode("utf-8"), message_count=2, mode="sdk"
)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
# --- 1. Download & load prior transcript ---
covers = await _load_prior_transcript(
# --- 1. Restore & load prior session ---
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -559,10 +559,7 @@ class TestTranscriptLifecycle:
# --- 3. Gate + upload ---
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is True
should_upload_transcript(user_id="user-1", upload_safe=covers) is True
)
await _upload_final_transcript(
user_id="user-1",
@@ -574,20 +571,21 @@ class TestTranscriptLifecycle:
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "follow-up question" in uploaded
assert "follow-up answer" in uploaded
assert b"follow-up question" in uploaded
assert b"follow-up answer" in uploaded
# Original prior-turn content preserved.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
assert b"user message 0" in uploaded
assert b"assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_lifecycle_stale_download_suppresses_upload(self):
"""Stale download → covers=False → upload must be skipped."""
async def test_lifecycle_stale_download_fills_gap(self):
"""When transcript covers fewer messages, gap is filled rather than rejected."""
builder = TranscriptBuilder()
# session has 10 msgs but stored transcript only covers 2 → stale.
# session has 5 msgs but stored transcript only covers 2 → gap filled.
stale = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=2,
mode="baseline",
)
upload_mock = AsyncMock(return_value=None)
@@ -601,20 +599,18 @@ class TestTranscriptLifecycle:
new=upload_mock,
),
):
covers = await _load_prior_transcript(
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=_make_session_messages(
"user", "assistant", "user", "assistant", "user"
),
transcript_builder=builder,
)
assert covers is False
# The caller's gate mirrors the production path.
assert (
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
is False
)
upload_mock.assert_not_awaited()
assert covers is True
# Gap was filled: 2 from transcript + 2 gap messages
assert builder.entry_count == 4
@pytest.mark.asyncio
async def test_lifecycle_anonymous_user_skips_upload(self):
@@ -627,15 +623,11 @@ class TestTranscriptLifecycle:
stop_reason=STOP_REASON_END_TURN,
)
assert (
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
is False
)
assert should_upload_transcript(user_id=None, upload_safe=True) is False
@pytest.mark.asyncio
async def test_lifecycle_missing_download_still_uploads_new_content(self):
"""No prior transcript → covers defaults to True in the service,
new turn should upload cleanly."""
"""No prior session → upload is safe; the turn writes the first snapshot."""
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with (
@@ -648,20 +640,117 @@ class TestTranscriptLifecycle:
new=upload_mock,
),
):
covers = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=1,
session_messages=_make_session_messages("user"),
transcript_builder=builder,
)
# No download: covers is False, so the production path would
# skip upload. This protects against overwriting a future
# more-complete transcript with a single-turn snapshot.
assert covers is False
# Nothing in GCS → upload is safe so the first baseline turn
# can write the initial transcript snapshot.
assert upload_safe is True
assert dl is None
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is False
should_upload_transcript(user_id="user-1", upload_safe=upload_safe)
is True
)
upload_mock.assert_not_awaited()
# ---------------------------------------------------------------------------
# _append_gap_to_builder
# ---------------------------------------------------------------------------
class TestAppendGapToBuilder:
    """``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries."""

    def test_user_message_appended(self):
        tb = TranscriptBuilder()
        _append_gap_to_builder([ChatMessage(role="user", content="hello")], tb)
        assert tb.entry_count == 1
        assert tb.last_entry_type == "user"

    def test_assistant_text_message_appended(self):
        tb = TranscriptBuilder()
        gap = [
            ChatMessage(role="user", content="q"),
            ChatMessage(role="assistant", content="answer"),
        ]
        _append_gap_to_builder(gap, tb)
        assert tb.entry_count == 2
        assert tb.last_entry_type == "assistant"
        assert "answer" in tb.to_jsonl()

    def test_assistant_with_tool_calls_appended(self):
        """Assistant tool_calls are recorded as tool_use blocks in the transcript."""
        tb = TranscriptBuilder()
        call = {
            "id": "tc-1",
            "type": "function",
            "function": {"name": "my_tool", "arguments": '{"key":"val"}'},
        }
        _append_gap_to_builder(
            [ChatMessage(role="assistant", content=None, tool_calls=[call])], tb
        )
        assert tb.entry_count == 1
        rendered = tb.to_jsonl()
        for expected in ("tool_use", "my_tool", "tc-1"):
            assert expected in rendered

    def test_assistant_invalid_json_args_uses_empty_dict(self):
        """Malformed JSON in tool_call arguments falls back to {}."""
        tb = TranscriptBuilder()
        call = {
            "id": "tc-bad",
            "type": "function",
            "function": {"name": "bad_tool", "arguments": "not-json"},
        }
        _append_gap_to_builder(
            [ChatMessage(role="assistant", content=None, tool_calls=[call])], tb
        )
        assert tb.entry_count == 1
        assert '"input":{}' in tb.to_jsonl()

    def test_assistant_empty_content_and_no_tools_uses_fallback(self):
        """Assistant with no content and no tool_calls gets a fallback empty text block."""
        tb = TranscriptBuilder()
        _append_gap_to_builder([ChatMessage(role="assistant", content=None)], tb)
        assert tb.entry_count == 1
        assert "text" in tb.to_jsonl()

    def test_tool_role_with_tool_call_id_appended(self):
        """Tool result messages are appended when tool_call_id is set."""
        tb = TranscriptBuilder()
        # A tool_result needs a preceding assistant tool_use entry to pair with.
        tb.append_user("use tool")
        tb.append_assistant(
            content_blocks=[
                {"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}}
            ]
        )
        _append_gap_to_builder(
            [ChatMessage(role="tool", tool_call_id="tc-1", content="result")], tb
        )
        assert tb.entry_count == 3
        assert "tool_result" in tb.to_jsonl()

    def test_tool_role_without_tool_call_id_skipped(self):
        """Tool messages without tool_call_id are silently skipped."""
        tb = TranscriptBuilder()
        _append_gap_to_builder(
            [ChatMessage(role="tool", tool_call_id=None, content="orphan")], tb
        )
        assert tb.entry_count == 0

    def test_tool_call_missing_function_key_uses_unknown_name(self):
        """A tool_call dict with no 'function' key uses 'unknown' as the tool name."""
        tb = TranscriptBuilder()
        _append_gap_to_builder(
            [ChatMessage(role="assistant", content=None, tool_calls=[{"id": "tc-x"}])],
            tb,
        )
        assert tb.entry_count == 1
        assert "unknown" in tb.to_jsonl()

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
# Allowed base directory for the Read tool. Public so service.py can use it
# for sweep operations without depending on a private implementation detail.
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
# _projects_base() function.
# projects_base() function.
# Resolved once at import time; a set, non-empty CLAUDE_CONFIG_DIR overrides
# the default ~/.claude location. realpath canonicalizes the allowed base
# (symlinks resolved) before it is used by the Read tool / sweep operations.
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))

View File

@@ -10,9 +10,11 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
ChatMessageCreateInput,
ChatMessageWhereInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
ChatSessionWhereInput,
FindManyChatMessageArgsFromChatSession,
)
from pydantic import BaseModel
@@ -30,6 +32,8 @@ from .model import get_chat_session as get_chat_session_cached
logger = logging.getLogger(__name__)
_BOUNDARY_SCAN_LIMIT = 10
class PaginatedMessages(BaseModel):
"""Result of a paginated message query."""
@@ -37,6 +41,7 @@ class PaginatedMessages(BaseModel):
messages: list[ChatMessage]
has_more: bool
oldest_sequence: int | None
newest_sequence: int | None
session: ChatSessionInfo
@@ -61,32 +66,48 @@ async def get_chat_messages_paginated(
session_id: str,
limit: int = 50,
before_sequence: int | None = None,
after_sequence: int | None = None,
from_start: bool = False,
user_id: str | None = None,
) -> PaginatedMessages | None:
"""Get paginated messages for a session, newest first.
"""Get paginated messages for a session.
Verifies session existence (and ownership when ``user_id`` is provided)
in parallel with the message query. Returns ``None`` when the session
is not found or does not belong to the user.
Three modes:
Args:
session_id: The chat session ID.
limit: Max messages to return.
before_sequence: Cursor — return messages with sequence < this value.
user_id: If provided, filters via ``Session.userId`` so only the
session owner's messages are returned (acts as an ownership guard).
- ``before_sequence`` set: backward pagination (DESC), returns messages
with sequence < ``before_sequence``. Used for active sessions or manual
backward navigation.
- ``from_start=True`` or ``after_sequence`` set: forward pagination (ASC).
Returns messages from sequence 0 (``from_start``) or after
``after_sequence``. Used on initial load of completed sessions and for
loading subsequent forward pages.
- Both cursors ``None`` and ``from_start=False``: newest-first (DESC
without filter). Used for active sessions on initial load.
Verifies session existence (and ownership when ``user_id`` is provided).
Returns ``None`` when the session is not found or does not belong to the
user.
"""
# Build session-existence / ownership check
session_where: ChatSessionWhereInput = {"id": session_id}
if user_id is not None:
session_where["userId"] = user_id
# Build message include — fetch paginated messages in the same query
msg_include: dict[str, Any] = {
"order_by": {"sequence": "desc"},
forward = from_start or after_sequence is not None
# Build message include — fetch paginated messages in the same query.
# Note: when both from_start=True and after_sequence is not None, the
# after_sequence filter takes precedence (the elif branch below is skipped).
# This combination is not reachable via the HTTP route (mutual exclusion is
# enforced there), so we rely on the documented priority here without an
# additional assertion.
msg_include: FindManyChatMessageArgsFromChatSession = {
"order_by": {"sequence": "asc" if forward else "desc"},
"take": limit + 1,
}
if before_sequence is not None:
if after_sequence is not None:
msg_include["where"] = {"sequence": {"gt": after_sequence}}
elif before_sequence is not None:
msg_include["where"] = {"sequence": {"lt": before_sequence}}
# Single query: session existence/ownership + paginated messages
@@ -104,57 +125,96 @@ async def get_chat_messages_paginated(
has_more = len(results) > limit
results = results[:limit]
# Reverse to ascending order
results.reverse()
if not forward:
# Backward mode: DB returned DESC; reverse to ascending order.
results.reverse()
# Tool-call boundary fix: if the oldest message is a tool message,
# expand backward to include the preceding assistant message that
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
# can pair them correctly.
_BOUNDARY_SCAN_LIMIT = 10
if results and results[0].role == "tool":
boundary_where: dict[str, Any] = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
# Tool-call boundary fix: if the oldest message is a tool message,
# expand backward to include the preceding assistant message that
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
# can pair them correctly.
if results and results[0].role == "tool":
boundary_where: ChatMessageWhereInput = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
)
if boundary_msgs:
results = boundary_msgs + results
# Only mark has_more if the expanded boundary isn't the
# very start of the conversation (sequence 0).
if boundary_msgs[0].sequence > 0:
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
if boundary_msgs:
results = boundary_msgs + results
# Only mark has_more if the expanded boundary isn't the
# very start of the conversation (sequence 0).
if boundary_msgs[0].sequence > 0:
has_more = True
else:
# Forward mode: DB returned ASC.
# Tool-call tail boundary fix: if the last message in this page is a
# tool message, the NEXT forward page would start after it and begin
# mid-tool-group — the owning assistant message is on this page but
# the following tool results are on the next page.
# Trim the current page so it ends on the owning assistant message,
# which keeps tool groups intact across page boundaries.
if results and results[-1].role == "tool":
# Walk backward through results to find the last non-tool message.
trim_idx = len(results) - 1
while trim_idx >= 0 and results[trim_idx].role == "tool":
trim_idx -= 1
if trim_idx >= 0:
# Trim results so the page ends at the owning assistant.
# Mark has_more=True so the client knows to fetch the rest.
results = results[: trim_idx + 1]
has_more = True
else:
# Entire page is tool messages with no visible owner — log and
# keep as-is so the caller is not stuck with an empty page.
logger.warning(
"Forward tail boundary: entire page is tool messages "
"for session=%s, no owning assistant found (%d msgs)",
session_id,
len(results),
)
messages = [ChatMessage.from_db(m) for m in results]
oldest_sequence = messages[0].sequence if messages else None
# oldest_sequence is only meaningful in backward mode (used as backward
# pagination cursor). In forward mode the page always starts near seq 0
# and clients should use newest_sequence as the forward cursor instead.
# Return None in forward mode so clients don't accidentally treat it as a
# backward cursor on a forward-paginated session.
oldest_sequence = messages[0].sequence if (messages and not forward) else None
# newest_sequence is only meaningful in forward mode; in backward mode it
# points to the last message of the page (not the session's newest message)
# which is not a valid forward cursor. Return None in backward mode so
# clients don't accidentally use it as one.
newest_sequence = messages[-1].sequence if (messages and forward) else None
return PaginatedMessages(
messages=messages,
has_more=has_more,
oldest_sequence=oldest_sequence,
newest_sequence=newest_sequence,
session=session_info,
)

View File

@@ -175,6 +175,187 @@ async def test_no_where_on_messages_without_before_sequence(
assert "where" not in include["Messages"]
# ---------- Forward pagination (from_start / after_sequence) ----------
@pytest.mark.asyncio
async def test_from_start_uses_asc_order_no_where(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """from_start=True queries messages in ASC order with no where filter."""
    session_query, _ = mock_db
    session_query.return_value = _make_session(
        messages=[_make_msg(i) for i in range(3)],
    )

    await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)

    # Inspect the include argument passed to the single session query.
    captured = session_query.call_args
    msg_args = (captured.kwargs.get("include") or captured[1].get("include"))[
        "Messages"
    ]
    assert msg_args["order_by"] == {"sequence": "asc"}
    assert "where" not in msg_args
@pytest.mark.asyncio
async def test_from_start_returns_messages_ascending(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """from_start=True returns messages in ascending sequence order."""
    session_query, _ = mock_db
    session_query.return_value = _make_session(
        messages=[_make_msg(i) for i in range(3)],
    )

    page = await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)

    assert page is not None
    assert [m.sequence for m in page.messages] == [0, 1, 2]
    # Forward mode: oldest_sequence is not a valid backward cursor, so it is None.
    assert page.oldest_sequence is None
    assert page.newest_sequence == 2
    assert page.has_more is False
@pytest.mark.asyncio
async def test_from_start_has_more_when_results_exceed_limit(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """from_start=True sets has_more when DB returns more than limit items."""
    session_query, _ = mock_db
    session_query.return_value = _make_session(
        messages=[_make_msg(i) for i in range(3)],
    )

    page = await get_chat_messages_paginated(SESSION_ID, limit=2, from_start=True)

    assert page is not None
    assert page.has_more is True
    # Only the first `limit` messages survive; the extra row only signals more.
    assert [m.sequence for m in page.messages] == [0, 1]
    assert page.newest_sequence == 1
@pytest.mark.asyncio
async def test_after_sequence_uses_gt_filter_asc_order(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """after_sequence adds a sequence > N where clause and uses ASC order."""
    session_query, _ = mock_db
    session_query.return_value = _make_session(
        messages=[_make_msg(11), _make_msg(12)],
    )

    await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)

    captured = session_query.call_args
    msg_args = (captured.kwargs.get("include") or captured[1].get("include"))[
        "Messages"
    ]
    assert msg_args["order_by"] == {"sequence": "asc"}
    assert msg_args["where"] == {"sequence": {"gt": 10}}
@pytest.mark.asyncio
async def test_after_sequence_returns_messages_in_order(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """after_sequence returns only messages with sequence > cursor, ascending."""
    session_query, _ = mock_db
    session_query.return_value = _make_session(
        messages=[_make_msg(s) for s in (11, 12, 13)],
    )

    page = await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)

    assert page is not None
    assert [m.sequence for m in page.messages] == [11, 12, 13]
    # Forward mode: oldest_sequence would be a bogus backward cursor → None.
    assert page.oldest_sequence is None
    assert page.newest_sequence == 13
    assert page.has_more is False
@pytest.mark.asyncio
async def test_newest_sequence_none_for_backward_mode(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """newest_sequence is None in backward mode — it is not a valid forward cursor."""
    session_query, _ = mock_db
    # Backward mode returns DESC from the DB.
    session_query.return_value = _make_session(
        messages=[_make_msg(s) for s in (5, 4, 3)],
    )

    page = await get_chat_messages_paginated(SESSION_ID, limit=50)

    assert page is not None
    assert page.newest_sequence is None
    assert page.oldest_sequence == 3
@pytest.mark.asyncio
async def test_forward_mode_no_boundary_expansion(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """Forward pagination never triggers backward boundary expansion."""
    session_query, extra_query = mock_db
    session_query.return_value = _make_session(
        messages=[_make_msg(0, role="tool"), _make_msg(1, role="tool")],
    )

    await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)

    # Backward boundary expansion would issue a find_many; forward mode must not.
    assert extra_query.call_count == 0
@pytest.mark.asyncio
async def test_forward_tail_boundary_trims_trailing_tool_messages(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """Forward pages that end with tool messages are trimmed to the owning
    assistant so the next after_sequence page doesn't start mid-tool-group."""
    session_query, _ = mock_db
    # ASC page: one assistant at seq 0 followed by its three tool results.
    session_query.return_value = _make_session(
        messages=[_make_msg(0, role="assistant")]
        + [_make_msg(s, role="tool") for s in (1, 2, 3)],
    )

    page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)

    assert page is not None
    # Trimmed to end on the assistant; the tool results move to the next page.
    assert [m.sequence for m in page.messages] == [0]
    assert page.newest_sequence == 0
    # has_more must flag that the trimmed tool messages are still pending.
    assert page.has_more is True
@pytest.mark.asyncio
async def test_forward_tail_boundary_no_trim_when_last_not_tool(
    mock_db: tuple[AsyncMock, AsyncMock],
):
    """Forward pages that end with a non-tool message are not trimmed."""
    session_query, _ = mock_db
    roles = ["user", "assistant", "tool", "assistant"]
    session_query.return_value = _make_session(
        messages=[_make_msg(i, role=r) for i, r in enumerate(roles)],
    )

    page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)

    assert page is not None
    assert [m.sequence for m in page.messages] == [0, 1, 2, 3]
    assert page.newest_sequence == 3
    assert page.has_more is False
@pytest.mark.asyncio
async def test_user_id_filter_applied_to_session_where(
mock_db: tuple[AsyncMock, AsyncMock],

View File

@@ -1,71 +0,0 @@
"""Per-request idempotency lock for the /stream endpoint.
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
Lifecycle
---------
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
``None`` when the key already exists (duplicate request).
2. ``release()`` — deletes the key. Must be called on turn completion or turn
error so the next legitimate send is never blocked.
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
the backend turn is still running, and releasing would reopen the duplicate
window for infra-level retries. The 30 s TTL is the safety net.
"""
import hashlib
import logging
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_KEY_PREFIX = "chat:msg_dedup"
_TTL_SECONDS = 30
class _DedupLock:
    """Handle for a held dedup key; ``release`` removes it best-effort."""

    def __init__(self, key: str, redis) -> None:
        # The exact Redis key that acquire_dedup_lock() set with NX.
        self._key = key
        self._redis = redis

    async def release(self) -> None:
        """Best-effort key deletion. The TTL handles failures silently."""
        try:
            await self._redis.delete(self._key)
        except Exception:
            # A failed delete is harmless: the key expires via its TTL.
            pass
async def acquire_dedup_lock(
    session_id: str,
    message: str | None,
    file_ids: list[str] | None,
) -> _DedupLock | None:
    """Acquire the idempotency lock for this (session, message, files) tuple.

    Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
    Returns ``None`` when a duplicate is detected (lock already held).
    Returns ``None`` when there is nothing to deduplicate (no message, no files).
    """
    if not message and not file_ids:
        return None

    # Sort file IDs so the hash is stable regardless of upload order.
    fingerprint = f"{session_id}:{message or ''}:{':'.join(sorted(file_ids or []))}"
    content_hash = hashlib.sha256(fingerprint.encode()).hexdigest()[:16]
    key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"

    redis = await get_redis_async()
    # SET NX is the atomic acquire: only the first request wins.
    if await redis.set(key, "1", ex=_TTL_SECONDS, nx=True):
        return _DedupLock(key, redis)

    logger.warning(
        f"[STREAM] Duplicate user message blocked for session {session_id}, "
        f"hash={content_hash} — returning empty SSE",
    )
    return None

View File

@@ -1,94 +0,0 @@
"""Unit tests for backend.copilot.message_dedup."""
from unittest.mock import AsyncMock
import pytest
import pytest_mock
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
    """Patch ``get_redis_async`` with an AsyncMock Redis whose ``set``
    resolves to *set_returns*; return the fake Redis for assertions."""
    fake_redis = AsyncMock()
    fake_redis.set = AsyncMock(return_value=set_returns)
    mocker.patch(
        "backend.copilot.message_dedup.get_redis_async",
        new_callable=AsyncMock,
        return_value=fake_redis,
    )
    return fake_redis
@pytest.mark.asyncio
async def test_acquire_returns_none_when_no_message_no_files(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Nothing to deduplicate — no Redis call made, None returned."""
    fake_redis = _patch_redis(mocker, set_returns=True)

    assert await acquire_dedup_lock("sess-1", None, None) is None
    fake_redis.set.assert_not_called()
@pytest.mark.asyncio
async def test_acquire_returns_lock_on_first_request(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """First request acquires the lock and returns a _DedupLock."""
    mock_redis = _patch_redis(mocker, set_returns=True)
    lock = await acquire_dedup_lock("sess-1", "hello", None)
    assert lock is not None
    mock_redis.set.assert_called_once()
    # The lock key is namespaced by prefix and session id; the trailing
    # component is the content hash.
    key_arg = mock_redis.set.call_args.args[0]
    assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
@pytest.mark.asyncio
async def test_acquire_returns_none_on_duplicate(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Duplicate request (NX fails) returns None to signal the caller."""
    # set_returns=None models Redis SET NX losing the race (key exists).
    _patch_redis(mocker, set_returns=None)
    result = await acquire_dedup_lock("sess-1", "hello", None)
    assert result is None
@pytest.mark.asyncio
async def test_acquire_key_stable_across_file_order(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """File IDs are sorted before hashing so order doesn't affect the key."""
    mock_redis_1 = _patch_redis(mocker, set_returns=True)
    await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
    key_ab = mock_redis_1.set.call_args.args[0]
    # Re-patch with a fresh mock so the second call's key is captured cleanly.
    mock_redis_2 = _patch_redis(mocker, set_returns=True)
    await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
    key_ba = mock_redis_2.set.call_args.args[0]
    assert key_ab == key_ba
@pytest.mark.asyncio
async def test_release_deletes_key(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """release() calls Redis delete exactly once."""
    mock_redis = _patch_redis(mocker, set_returns=True)
    lock = await acquire_dedup_lock("sess-1", "hello", None)
    assert lock is not None
    await lock.release()
    mock_redis.delete.assert_called_once()
@pytest.mark.asyncio
async def test_release_swallows_redis_error(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """release() must not raise even when Redis delete fails."""
    mock_redis = _patch_redis(mocker, set_returns=True)
    # Simulate an outage at delete time; the TTL is the fallback cleanup.
    mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
    lock = await acquire_dedup_lock("sess-1", "hello", None)
    assert lock is not None
    await lock.release()  # must not raise
    mock_redis.delete.assert_called_once()

View File

@@ -1,9 +1,8 @@
import asyncio
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from typing import Any, Callable, Self, cast
from weakref import WeakValueDictionary
from typing import Any, AsyncIterator, Callable, Self, cast
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
@@ -522,10 +521,7 @@ async def upsert_chat_session(
callers are aware of the persistence failure.
RedisError: If the cache write fails (after successful DB write).
"""
# Acquire session-specific lock to prevent concurrent upserts
lock = await _get_session_lock(session.session_id)
async with lock:
async with _get_session_lock(session.session_id) as _:
# Always query DB for existing message count to ensure consistency
existing_message_count = await chat_db().get_next_sequence(session.session_id)
@@ -651,20 +647,50 @@ async def _save_session_to_db(
msg.sequence = existing_message_count + i
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
async def append_and_save_message(
session_id: str, message: ChatMessage
) -> ChatSession | None:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
Returns the updated session, or None if the message was detected as a
duplicate (idempotency guard). Callers must check for None and skip any
downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected.
async with lock:
session = await get_chat_session(session_id)
Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas.
The idempotency check below provides a last-resort guard when the lock degrades.
"""
async with _get_session_lock(session_id) as lock_acquired:
# When the lock degraded (Redis down or 2s timeout), bypass cache for
# the idempotency check. Stale cache could let two concurrent writers
# both see the old state, pass the check, and write the same message.
if lock_acquired:
session = await get_chat_session(session_id)
else:
session = await _get_session_from_db(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
# Idempotency: skip if the trailing block of same-role messages already
# contains this content. Uses is_message_duplicate which checks all
# consecutive trailing messages of the same role, not just [-1].
#
# This collapses infra/nginx retries whether they land on the same pod
# (serialised by the Redis lock) or a different pod.
#
# Legit same-text messages are distinguished by the assistant turn
# between them: if the user said "yes", got a response, and says
# "yes" again, session.messages[-1] is the assistant reply, so the
# role check fails and the second message goes through normally.
#
# Edge case: if a turn dies without writing any assistant message,
# the user's next send of the same text is blocked here permanently.
# The fix is to ensure failed turns always write an error/timeout
# assistant message so the session always ends on an assistant turn.
if message.content is not None and is_message_duplicate(
session.messages, message.role, message.content
):
return None # duplicate — caller should skip enqueue
session.messages.append(message)
existing_message_count = await chat_db().get_next_sequence(session_id)
@@ -679,6 +705,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
# Invalidate the stale entry so future reads fall back to DB,
# preventing a retry from bypassing the idempotency check above.
await invalidate_session_cache(session_id)
return session
@@ -699,9 +728,7 @@ async def append_message_if(
Returns the updated session on append, or ``None`` if the predicate
rejected, the session no longer exists, or the append failed.
"""
lock = await _get_session_lock(session_id)
async with lock:
async with _get_session_lock(session_id) as _lock_acquired:
# Read from DB directly — the Redis cache can be stale because the
# executor's upsert_chat_session overwrites it with in-memory copies
# during streaming, which may not include messages appended by the
@@ -815,10 +842,6 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
except Exception as e:
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
@@ -883,25 +906,38 @@ async def update_session_title(
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
@asynccontextmanager
async def _get_session_lock(session_id: str) -> AsyncIterator[bool]:
"""Distributed Redis lock for a session, usable as an async context manager.
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Yields True if the lock was acquired, False if it timed out or Redis was
unavailable. Callers should treat False as a degraded mode and prefer fresh
DB reads over cache to avoid acting on stale state.
This was originally added to solve the specific problem of race conditions between
the session title thread and the conversation thread, which always occurs on the
same instance as we prevent rapid request sends on the frontend.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition
is atomic and release is owner-verified. Blocks up to 2s for a concurrent
writer to finish; the 10s TTL ensures a dead pod never holds the lock forever.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
_lock_key = f"copilot:session_lock:{session_id}"
lock = None
acquired = False
try:
_redis = await get_redis_async()
lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2)
acquired = await lock.acquire(blocking=True)
if not acquired:
logger.warning(
"Could not acquire session lock for %s within 2s", session_id
)
except Exception as e:
logger.warning("Redis unavailable for session lock on %s: %s", session_id, e)
try:
yield acquired
finally:
if acquired and lock is not None:
try:
await lock.release()
except Exception:
pass # TTL will expire the key

View File

@@ -11,11 +11,13 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from pytest_mock import MockerFixture
from .model import (
ChatMessage,
ChatSession,
Usage,
append_and_save_message,
get_chat_session,
is_message_duplicate,
maybe_append_user_message,
@@ -574,3 +576,345 @@ def test_maybe_append_assistant_skips_duplicate():
result = maybe_append_user_message(session, "dup", is_user_message=False)
assert result is False
assert len(session.messages) == 2
# --------------------------------------------------------------------------- #
# append_and_save_message #
# --------------------------------------------------------------------------- #
def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession:
    """Build a fresh test session whose history is exactly *msgs*."""
    session = ChatSession.new(user_id="u1", dry_run=False)
    session.messages = [*msgs]
    return session
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_returns_none_for_duplicate(
    mocker: MockerFixture,
) -> None:
    """append_and_save_message returns None when the trailing message is a duplicate."""
    session = _make_session_with_messages(
        ChatMessage(role="user", content="hello"),
    )
    # Redis lock acquires cleanly, so the normal (cache-first) read path runs.
    mock_redis_lock = mocker.AsyncMock()
    mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
    mock_redis_lock.release = mocker.AsyncMock()
    mock_redis_client = mocker.MagicMock()
    mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=mock_redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=session,
    )
    # Same role + same content as the trailing message → idempotency guard trips.
    result = await append_and_save_message(
        session.session_id, ChatMessage(role="user", content="hello")
    )
    assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_appends_new_message(
    mocker: MockerFixture,
) -> None:
    """append_and_save_message appends a non-duplicate message and returns the session."""
    # Trailing message is from the assistant, so a new user message cannot
    # match the same-role duplicate check.
    session = _make_session_with_messages(
        ChatMessage(role="user", content="hello"),
        ChatMessage(role="assistant", content="hi"),
    )
    mock_redis_lock = mocker.AsyncMock()
    mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
    mock_redis_lock.release = mocker.AsyncMock()
    mock_redis_client = mocker.MagicMock()
    mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=mock_redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=session,
    )
    # Stub out persistence so only the append logic is under test.
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=mocker.MagicMock(
            get_next_sequence=mocker.AsyncMock(return_value=2)
        ),
    )
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
    )
    new_msg = ChatMessage(role="user", content="second message")
    result = await append_and_save_message(session.session_id, new_msg)
    assert result is not None
    assert result.messages[-1].content == "second message"
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_when_session_not_found(
    mocker: MockerFixture,
) -> None:
    """append_and_save_message raises ValueError when the session does not exist."""
    mock_redis_lock = mocker.AsyncMock()
    mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
    mock_redis_lock.release = mocker.AsyncMock()
    mock_redis_client = mocker.MagicMock()
    mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=mock_redis_client,
    )
    # get_chat_session returning None models a missing/deleted session.
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=None,
    )
    with pytest.raises(ValueError, match="not found"):
        await append_and_save_message(
            "missing-session-id", ChatMessage(role="user", content="hi")
        )
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_lock_degraded(
    mocker: MockerFixture,
) -> None:
    """When the Redis lock times out (acquired=False), the fallback reads from DB."""
    session = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )
    # acquire returning False simulates a 2s blocking-timeout expiry.
    mock_redis_lock = mocker.AsyncMock()
    mock_redis_lock.acquire = mocker.AsyncMock(return_value=False)
    mock_redis_client = mocker.MagicMock()
    mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=mock_redis_client,
    )
    mock_get_from_db = mocker.patch(
        "backend.copilot.model._get_session_from_db",
        new_callable=mocker.AsyncMock,
        return_value=session,
    )
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=mocker.MagicMock(
            get_next_sequence=mocker.AsyncMock(return_value=1)
        ),
    )
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
    )
    new_msg = ChatMessage(role="user", content="new msg")
    result = await append_and_save_message(session.session_id, new_msg)
    # DB path was used (not cache-first)
    mock_get_from_db.assert_called_once_with(session.session_id)
    assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_database_error_on_save_failure(
    mocker: MockerFixture,
) -> None:
    """When _save_session_to_db fails, append_and_save_message raises DatabaseError."""
    from backend.util.exceptions import DatabaseError
    session = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )
    mock_redis_lock = mocker.AsyncMock()
    mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
    mock_redis_lock.release = mocker.AsyncMock()
    mock_redis_client = mocker.MagicMock()
    mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=mock_redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=session,
    )
    # A raw RuntimeError from the save layer must surface as DatabaseError.
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
        side_effect=RuntimeError("db down"),
    )
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=mocker.MagicMock(
            get_next_sequence=mocker.AsyncMock(return_value=1)
        ),
    )
    with pytest.raises(DatabaseError):
        await append_and_save_message(
            session.session_id, ChatMessage(role="user", content="new msg")
        )
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_invalidates_cache_on_cache_failure(
    mocker: MockerFixture,
) -> None:
    """When cache_chat_session fails, invalidate_session_cache is called to avoid stale reads."""
    session = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )
    mock_redis_lock = mocker.AsyncMock()
    mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
    mock_redis_lock.release = mocker.AsyncMock()
    mock_redis_client = mocker.MagicMock()
    mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=mock_redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=session,
    )
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=mocker.MagicMock(
            get_next_sequence=mocker.AsyncMock(return_value=1)
        ),
    )
    # Cache write blows up after a successful DB write — the session must
    # still be returned, with the stale cache entry invalidated.
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
        side_effect=RuntimeError("redis write failed"),
    )
    mock_invalidate = mocker.patch(
        "backend.copilot.model.invalidate_session_cache",
        new_callable=mocker.AsyncMock,
    )
    result = await append_and_save_message(
        session.session_id, ChatMessage(role="user", content="new msg")
    )
    # DB write succeeded, cache invalidation was called
    mock_invalidate.assert_called_once_with(session.session_id)
    assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_redis_unavailable(
    mocker: MockerFixture,
) -> None:
    """When get_redis_async raises, _get_session_lock yields False (degraded) and DB is read."""
    session = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )
    # Redis completely unreachable — lock acquisition cannot even start.
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        side_effect=ConnectionError("redis down"),
    )
    mock_get_from_db = mocker.patch(
        "backend.copilot.model._get_session_from_db",
        new_callable=mocker.AsyncMock,
        return_value=session,
    )
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=mocker.MagicMock(
            get_next_sequence=mocker.AsyncMock(return_value=1)
        ),
    )
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
    )
    new_msg = ChatMessage(role="user", content="new msg")
    result = await append_and_save_message(session.session_id, new_msg)
    # Degraded mode must bypass the cache and read straight from the DB.
    mock_get_from_db.assert_called_once_with(session.session_id)
    assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_lock_release_failure_is_ignored(
    mocker: MockerFixture,
) -> None:
    """If lock.release() raises, the exception is swallowed (TTL will clean up)."""
    session = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )
    mock_redis_lock = mocker.AsyncMock()
    mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
    # Failing release models a lock whose key already expired or Redis
    # going away mid-operation; the TTL is the cleanup of last resort.
    mock_redis_lock.release = mocker.AsyncMock(
        side_effect=RuntimeError("release failed")
    )
    mock_redis_client = mocker.MagicMock()
    mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=mock_redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=session,
    )
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=mocker.MagicMock(
            get_next_sequence=mocker.AsyncMock(return_value=1)
        ),
    )
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
    )
    new_msg = ChatMessage(role="user", content="new msg")
    result = await append_and_save_message(session.session_id, new_msg)
    assert result is not None

View File

@@ -174,6 +174,7 @@ sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
### GitHub CLI (`gh`) and git
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.

View File

@@ -8,7 +8,7 @@ Cross-mode transcript flow
==========================
Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking
mode) read and write the same JSONL transcript store via
mode) read and write the same CLI session store via
``backend.copilot.transcript.upload_transcript`` /
``download_transcript``.
@@ -250,8 +250,9 @@ class TestSdkToFastModeSwitch:
@pytest.mark.asyncio
async def test_scenario_s_baseline_loads_sdk_transcript(self):
"""Scenario S: SDK-written transcript is accepted by baseline's load helper."""
"""Scenario S: SDK-written CLI session is accepted by baseline's load helper."""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
@@ -267,33 +268,41 @@ class TestSdkToFastModeSwitch:
sdk_transcript = builder_sdk.to_jsonl()
# Baseline session now has those 2 SDK messages + 1 new baseline message.
download = TranscriptDownload(content=sdk_transcript, message_count=2)
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3, # 2 SDK + 1 new baseline
session_messages=[
ChatMessage(role="user", content="sdk-question"),
ChatMessage(role="assistant", content="sdk-answer"),
ChatMessage(role="user", content="baseline-question"),
],
transcript_builder=baseline_builder,
)
# Transcript is valid and covers the prefix.
# CLI session is valid and covers the prefix.
assert covers is True
assert dl is not None
assert baseline_builder.entry_count == 2
@pytest.mark.asyncio
async def test_scenario_s_stale_sdk_transcript_not_loaded(self):
"""Scenario S (stale): SDK transcript is stale — baseline does not load it.
"""Scenario S (stale): SDK CLI session is stale — baseline does not load it.
If SDK mode produced more turns than the transcript captured (e.g.
upload failed on one turn), the baseline rejects the stale transcript
If SDK mode produced more turns than the session captured (e.g.
upload failed on one turn), the baseline rejects the stale session
to avoid injecting an incomplete history.
"""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
@@ -306,21 +315,33 @@ class TestSdkToFastModeSwitch:
)
sdk_transcript = builder_sdk.to_jsonl()
# Transcript covers only 2 messages but session has 10 (many SDK turns).
download = TranscriptDownload(content=sdk_transcript, message_count=2)
# Session covers only 2 messages but session has 10 (many SDK turns).
# With watermark=2 and 10 total messages, detect_gap will fill the gap
# by appending messages 2..8 (positions 2 to total-2).
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)
# Build a session with 10 alternating user/assistant messages + current user
session_messages = [
ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}")
for i in range(10)
]
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=session_messages,
transcript_builder=baseline_builder,
)
# Stale transcript must be rejected.
assert covers is False
assert baseline_builder.is_empty
# With gap filling, covers is True and gap messages are appended.
assert covers is True
assert dl is not None
# 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn)
assert baseline_builder.entry_count == 9

View File

@@ -27,6 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.transcript import (
TranscriptDownload,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
@@ -999,14 +1000,15 @@ def _make_sdk_patches(
f"{_SVC}.download_transcript",
dict(
new_callable=AsyncMock,
return_value=MagicMock(content=original_transcript, message_count=2),
return_value=TranscriptDownload(
content=original_transcript.encode("utf-8"),
message_count=2,
mode="sdk",
),
),
),
(
f"{_SVC}.restore_cli_session",
dict(new_callable=AsyncMock, return_value=True),
),
(f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)),
(f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.validate_transcript", dict(return_value=True)),
(
f"{_SVC}.compact_transcript",
@@ -1037,7 +1039,6 @@ def _make_sdk_patches(
claude_agent_fallback_model=None,
),
),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
]
@@ -1914,14 +1915,14 @@ class TestStreamChatCompletionRetryIntegration:
compacted_transcript=None,
client_side_effect=_client_factory,
)
# Override restore_cli_session to return False (CLI native session unavailable)
# Override download_transcript to return None (CLI native session unavailable)
patches = [
(
(
f"{_SVC}.restore_cli_session",
dict(new_callable=AsyncMock, return_value=False),
f"{_SVC}.download_transcript",
dict(new_callable=AsyncMock, return_value=None),
)
if p[0] == f"{_SVC}.restore_cli_session"
if p[0] == f"{_SVC}.download_transcript"
else p
)
for p in patches
@@ -1944,7 +1945,7 @@ class TestStreamChatCompletionRetryIntegration:
# captured_options holds {"options": ClaudeAgentOptions}, so check
# the attribute directly rather than dict keys.
assert not getattr(captured_options.get("options"), "resume", None), (
f"--resume was set even though restore_cli_session returned False: "
f"--resume was set even though download_transcript returned None: "
f"{captured_options}"
)
assert any(isinstance(e, StreamStart) for e in events)

View File

@@ -365,7 +365,7 @@ def create_security_hooks(
trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
# Sanitize untrusted input: strip control chars for logging AND
# for the value passed downstream. read_compacted_entries()
# validates against _projects_base() as defence-in-depth, but
# validates against projects_base() as defence-in-depth, but
# sanitizing here prevents log injection and rejects obviously
# malformed paths early.
transcript_path = _sanitize(

View File

@@ -16,6 +16,7 @@ import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, NotRequired, cast
if TYPE_CHECKING:
@@ -92,12 +93,15 @@ from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from ..transcript import (
_run_compression,
TranscriptDownload,
cleanup_stale_project_dirs,
cli_session_path,
compact_transcript,
download_transcript,
extract_context_messages,
projects_base,
read_compacted_entries,
restore_cli_session,
upload_cli_session,
strip_for_upload,
upload_transcript,
validate_transcript,
)
@@ -121,7 +125,12 @@ config = ChatConfig()
class _SystemPromptPreset(SystemPromptPreset, total=False):
"""Extends SystemPromptPreset with fields added in claude-agent-sdk 0.1.59."""
"""Extends :class:`SystemPromptPreset` with ``exclude_dynamic_sections``.
The field was added to the upstream TypedDict in claude-agent-sdk 0.1.59.
Until the package is pinned to that version we declare it locally so Pyright
accepts the kwarg without a ``# type: ignore`` comment.
"""
exclude_dynamic_sections: NotRequired[bool]
@@ -849,6 +858,181 @@ def _make_sdk_cwd(session_id: str) -> str:
return cwd
def _write_cli_session_to_disk(
    content: bytes,
    sdk_cwd: str,
    session_id: str,
    log_prefix: str,
) -> bool:
    """Write downloaded CLI session bytes to disk so the CLI can --resume.

    Args:
        content: Raw JSONL session bytes to persist.
        sdk_cwd: SDK working directory the session belongs to.
        session_id: Chat session ID used to derive the session file path.
        log_prefix: Tag prepended to log lines to tie them to the calling turn.

    Returns True on success, False if the path is invalid or the write fails.
    Path-traversal guard: rejects paths outside the CLI projects base.
    """
    session_file = cli_session_path(sdk_cwd, session_id)
    # Resolve symlinks and ".." before the containment check so a crafted
    # session_id cannot escape the projects base.
    real_path = os.path.realpath(session_file)
    _pbase = projects_base()
    if not real_path.startswith(_pbase + os.sep):
        logger.warning(
            "%s CLI session restore path outside projects base: %s",
            log_prefix,
            # Log only the basename to avoid leaking full filesystem paths.
            os.path.basename(session_file),
        )
        return False
    try:
        os.makedirs(os.path.dirname(real_path), exist_ok=True)
        Path(real_path).write_bytes(content)
        logger.info(
            "%s Wrote CLI session to disk (%dB) for --resume",
            log_prefix,
            len(content),
        )
        return True
    except OSError as e:
        logger.warning(
            "%s Failed to write CLI session file %s: %s",
            log_prefix,
            os.path.basename(session_file),
            e.strerror or str(e),
        )
        return False
def read_cli_session_from_disk(
    sdk_cwd: str,
    session_id: str,
    log_prefix: str,
) -> bytes | None:
    """Read the CLI session JSONL file from disk after the SDK turn.

    Returns the file bytes (stripped of stale thinking blocks when possible),
    or None if the file is missing, outside the projects base, or unreadable.
    Path-traversal guard: rejects paths outside the CLI projects base.
    """
    session_file = cli_session_path(sdk_cwd, session_id)
    # Resolve symlinks and ".." before the containment check.
    real_path = os.path.realpath(session_file)
    _pbase = projects_base()
    if not real_path.startswith(_pbase + os.sep):
        logger.warning(
            "%s CLI session file outside projects base, skipping upload: %s",
            log_prefix,
            os.path.basename(real_path),
        )
        return None
    try:
        raw_bytes = Path(real_path).read_bytes()
    except FileNotFoundError:
        # Expected when the SDK turn produced no session file; debug-level only.
        logger.debug(
            "%s CLI session file not found, skipping upload: %s",
            log_prefix,
            os.path.basename(session_file),
        )
        return None
    except OSError as e:
        logger.warning(
            "%s Failed to read CLI session file %s: %s",
            log_prefix,
            os.path.basename(session_file),
            e.strerror or str(e),
        )
        return None
    # Strip stale thinking blocks and metadata entries before uploading.
    # Thinking blocks from non-last turns can be massive; keeping them causes
    # the CLI to auto-compact its session when the context window fills up,
    # silently losing conversation history.
    try:
        raw_text = raw_bytes.decode("utf-8")
        stripped_text = strip_for_upload(raw_text)
        stripped_bytes = stripped_text.encode("utf-8")
    except UnicodeDecodeError:
        logger.warning("%s CLI session is not valid UTF-8, uploading raw", log_prefix)
        return raw_bytes
    except (OSError, ValueError) as e:
        # OSError: encode/decode I/O failure; ValueError: malformed JSONL in strip.
        # Either way the raw bytes are still a usable upload payload.
        logger.warning(
            "%s Failed to strip CLI session, uploading raw: %s", log_prefix, e
        )
        return raw_bytes
    if len(stripped_bytes) < len(raw_bytes):
        # Write back locally so same-pod turns also benefit.
        try:
            Path(real_path).write_bytes(stripped_bytes)
            logger.info(
                "%s Stripped CLI session: %dB → %dB",
                log_prefix,
                len(raw_bytes),
                len(stripped_bytes),
            )
        except OSError as e:
            # write_bytes failed — stripped content is still valid for GCS upload even
            # though the local write-back failed (same-pod optimization silently skipped).
            logger.warning(
                "%s Failed to write back stripped CLI session: %s",
                log_prefix,
                e.strerror or str(e),
            )
    return stripped_bytes
def process_cli_restore(
    cli_restore: TranscriptDownload,
    sdk_cwd: str,
    session_id: str,
    log_prefix: str,
) -> tuple[str, bool]:
    """Validate and write a restored CLI session to disk.

    The restored payload is decoded from UTF-8 (when it arrives as bytes),
    run through ``strip_for_upload`` to drop progress entries and stale
    thinking blocks, validated, and finally persisted so the CLI can
    ``--resume`` from it.

    Returns:
        ``(stripped_content, success)``. ``success=False`` means the content
        was invalid or the disk write failed; the caller should skip --resume.
    """
    payload = cli_restore.content
    try:
        decoded = payload.decode("utf-8") if isinstance(payload, bytes) else payload
    except UnicodeDecodeError:
        logger.warning(
            "%s CLI session content is not valid UTF-8, skipping", log_prefix
        )
        return "", False

    cleaned = strip_for_upload(decoded)
    valid = validate_transcript(cleaned)

    def _line_count(text: str) -> int:
        # Non-empty JSONL line count; 0 for blank/empty input.
        return len(text.strip().split("\n")) if text.strip() else 0

    # Size is reported in characters of the decoded text (not raw bytes) so the
    # unit is consistent regardless of whether the input arrived as bytes.
    # "lines stripped" = lines removed by strip_for_upload.
    logger.info(
        "%s Restored CLI session: %dB raw, %d lines stripped, msg_count=%d, valid=%s",
        log_prefix,
        len(decoded),
        _line_count(decoded) - _line_count(cleaned),
        cli_restore.message_count,
        valid,
    )
    if not valid:
        logger.warning(
            "%s CLI session content invalid after strip — running without --resume",
            log_prefix,
        )
        return "", False

    if not _write_cli_session_to_disk(
        cleaned.encode("utf-8"), sdk_cwd, session_id, log_prefix
    ):
        return "", False
    return cleaned, True
async def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK session artifacts for a specific working directory.
@@ -922,8 +1106,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
result.append(block)
else:
logger.warning(
f"[SDK] Unknown content block type: {type(block).__name__}. "
f"This may indicate a new SDK version with additional block types."
"[SDK] Unknown content block type: %s."
" This may indicate a new SDK version with additional block types.",
type(block).__name__,
)
return result
@@ -978,10 +1163,11 @@ async def _compress_messages(
if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
"[SDK] Context compacted: %d -> %d tokens (%d summarized, %d dropped)",
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
# Convert compressed dicts back to ChatMessages
return [
@@ -1048,11 +1234,17 @@ def _session_messages_to_transcript(messages: list[ChatMessage]) -> str:
)
if blocks:
builder.append_assistant(blocks)
elif msg.role == "tool" and msg.tool_call_id:
builder.append_tool_result(
tool_use_id=msg.tool_call_id,
content=msg.content or "",
)
elif msg.role == "tool":
if msg.tool_call_id:
builder.append_tool_result(
tool_use_id=msg.tool_call_id,
content=msg.content or "",
)
else:
# Malformed tool message — no tool_call_id to link to an
# assistant tool_use block. Skip to avoid an unmatched
# tool_result entry in the builder (which would confuse --resume).
logger.warning("[SDK] Skipping tool gap message with no tool_call_id")
return builder.to_jsonl()
@@ -1098,6 +1290,7 @@ async def _build_query_message(
transcript_msg_count: int,
session_id: str,
target_tokens: int | None = None,
prior_messages: "list[ChatMessage] | None" = None,
) -> tuple[str, bool]:
"""Build the query message with appropriate context.
@@ -1203,15 +1396,16 @@ async def _build_query_message(
)
return current_message, False
source = prior_messages if prior_messages is not None else prior
logger.warning(
"[SDK] [%s] No --resume for %d-message session — compressing"
" full session history (pod affinity issue or first turn after"
" restore failure); target_tokens=%s",
"[SDK] [%s] No --resume for %d-message session — compressing context "
"(source=%s, target_tokens=%s)",
session_id[:8],
msg_count,
"transcript+gap" if prior_messages is not None else "full-db",
target_tokens,
)
compressed, was_compressed = await _compress_messages(prior, target_tokens)
compressed, was_compressed = await _compress_messages(source, target_tokens)
history_context = _format_conversation_context(compressed)
if history_context:
logger.info(
@@ -1228,7 +1422,7 @@ async def _build_query_message(
"[SDK] [%s] Fallback context empty after compression"
" (%d messages) — sending message without history",
session_id[:8],
len(prior),
len(source),
)
return current_message, False
@@ -2233,6 +2427,161 @@ async def _seed_transcript(
return _seeded, True, len(_prior)
@dataclass
class _RestoreResult:
    """Return value from ``_restore_cli_session_for_turn``."""

    # Stripped JSONL transcript that was loaded into the builder ("" when none).
    transcript_content: str = ""
    # Whether the builder state covers the full session prefix; the caller
    # consults this when deciding whether the post-turn transcript is safe
    # to upload.
    transcript_covers_prefix: bool = True
    # True when a valid SDK-written CLI session was restored and --resume
    # should be used for this turn.
    use_resume: bool = False
    # Identifier passed to the CLI's --resume (the session UUID), or None
    # when not resuming.
    resume_file: str | None = None
    # Watermark: number of DB messages the restored transcript covers.
    transcript_msg_count: int = 0
    # Baseline-mode download kept only for context extraction (never --resume).
    baseline_download: "TranscriptDownload | None" = None
    # Pre-built context messages (transcript + gap, or full-DB fallback) for
    # prompt injection; None when --resume supplies the context instead.
    context_messages: "list[ChatMessage] | None" = None
async def _restore_cli_session_for_turn(
    user_id: str | None,
    session_id: str,
    session: "ChatSession",
    sdk_cwd: str,
    transcript_builder: "TranscriptBuilder",
    log_prefix: str,
) -> _RestoreResult:
    """Download, validate and restore a CLI session for ``--resume`` on this turn.

    Performs a single GCS round-trip to fetch the session bytes + message_count
    watermark. Falls back to DB-message reconstruction when GCS has no session
    (first turn or upload missed).

    Returns a ``_RestoreResult`` with all transcript-related state ready for the
    caller to merge into its local variables.
    """
    result = _RestoreResult()
    # Resume only applies to authenticated sessions with prior history and the
    # feature flag enabled; otherwise return the inert default result.
    if not (config.claude_agent_use_resume and user_id and len(session.messages) > 1):
        return result
    try:
        cli_restore = await download_transcript(
            user_id, session_id, log_prefix=log_prefix
        )
    except Exception as restore_err:
        # Best-effort: a failed download must never fail the turn.
        logger.warning(
            "%s CLI session restore failed, continuing without --resume: %s",
            log_prefix,
            restore_err,
        )
        cli_restore = None
    # Only attempt --resume for SDK-written transcripts.
    # Baseline-written transcripts use TranscriptBuilder format (synthetic IDs,
    # stripped fields) that may not be valid for --resume.
    if cli_restore is not None and cli_restore.mode != "sdk":
        logger.info(
            "%s Transcript written by mode=%r — skipping --resume, "
            "will use transcript content + gap for context",
            log_prefix,
            cli_restore.mode,
        )
        result.baseline_download = cli_restore  # keep for extract_context_messages
        cli_restore = None
    # Validate, strip, and write to disk — delegate to helper to reduce
    # function complexity. Writing an invalid/corrupt file to disk then
    # falling back to "no --resume" would cause the CLI to fail with
    # "Session ID already in use" because the file exists at the expected
    # session path, so we validate BEFORE any disk write.
    stripped = ""
    if cli_restore is not None and sdk_cwd:
        stripped, ok = process_cli_restore(cli_restore, sdk_cwd, session_id, log_prefix)
        if not ok:
            result.transcript_covers_prefix = False
            cli_restore = None
    if cli_restore is None and sdk_cwd:
        # Validation failed or GCS returned no session. Delete any
        # existing local session file so the CLI doesn't reject the
        # session_id with "Session ID already in use". T1 may have
        # left a valid file at this path; we clear it so the fallback
        # path (session_id= without --resume) can create a new session.
        # realpath + projects_base() prefix check guards against deleting
        # anything outside the sandboxed projects directory.
        _stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id))
        if Path(_stale_path).exists() and _stale_path.startswith(
            projects_base() + os.sep
        ):
            try:
                Path(_stale_path).unlink()
                logger.debug(
                    "%s Removed stale local CLI session file for clean fallback",
                    log_prefix,
                )
            except OSError as _unlink_err:
                # Non-fatal: the CLI may still reject the session_id, but the
                # turn proceeds without --resume either way.
                logger.debug(
                    "%s Failed to remove stale local session file: %s",
                    log_prefix,
                    _unlink_err,
                )
    if cli_restore is not None:
        # Happy path: a valid SDK session file is now on disk — resume from it.
        result.transcript_content = stripped
        transcript_builder.load_previous(stripped, log_prefix=log_prefix)
        result.use_resume = True
        result.resume_file = session_id
        result.transcript_msg_count = cli_restore.message_count
        return result
    # No valid --resume source (mode="baseline" or no GCS file).
    # Build context from transcript content + gap, falling back to full DB.
    # extract_context_messages handles both: non-None baseline_download uses
    # the compacted transcript + gap; None falls back to all prior DB messages.
    context_msgs = extract_context_messages(result.baseline_download, session.messages)
    result.context_messages = context_msgs
    # Watermark: trust the download's message_count when positive, else assume
    # everything except the current (last) message is covered.
    result.transcript_msg_count = (
        result.baseline_download.message_count
        if result.baseline_download is not None
        and result.baseline_download.message_count > 0
        else len(session.messages) - 1
    )
    result.transcript_covers_prefix = True
    logger.info(
        "%s Context built from %s: %d messages (transcript watermark=%d, "
        "will inject as <conversation_history>)",
        log_prefix,
        (
            "baseline transcript + gap"
            if result.baseline_download is not None
            else "DB fallback"
        ),
        len(context_msgs),
        result.transcript_msg_count,
    )
    # Load baseline transcript content into builder so the upload path has accurate state.
    # Also sets result.transcript_content so the _seed_transcript guard in the caller
    # (``not transcript_content``) does not overwrite this builder state with a DB
    # reconstruction — which would duplicate entries since load_previous appends.
    if result.baseline_download is not None:
        try:
            raw_for_builder = result.baseline_download.content
            if isinstance(raw_for_builder, bytes):
                raw_for_builder = raw_for_builder.decode("utf-8")
            stripped = strip_for_upload(raw_for_builder)
            if validate_transcript(stripped):
                transcript_builder.load_previous(stripped, log_prefix=log_prefix)
                result.transcript_content = stripped
        except (UnicodeDecodeError, ValueError, OSError) as _load_err:
            # UnicodeDecodeError: non-UTF-8 content; ValueError: malformed JSONL in
            # strip_for_upload; OSError: encode/decode I/O failure. Unexpected
            # exceptions propagate so programming errors are not silently masked.
            logger.debug(
                "%s Could not load baseline transcript into builder: %s",
                log_prefix,
                _load_err,
            )
    return result
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
@@ -2427,28 +2776,9 @@ async def stream_chat_completion_sdk(
return sandbox
async def _fetch_transcript():
"""Download transcript for --resume if applicable."""
if not (
config.claude_agent_use_resume and user_id and len(session.messages) > 1
):
return None
try:
return await download_transcript(
user_id, session_id, log_prefix=log_prefix
)
except Exception as transcript_err:
logger.warning(
"%s Transcript download failed, continuing without --resume: %s",
log_prefix,
transcript_err,
)
return None
e2b_sandbox, (base_system_prompt, understanding), dl = await asyncio.gather(
e2b_sandbox, (base_system_prompt, understanding) = await asyncio.gather(
_setup_e2b(),
_build_system_prompt(user_id if not has_history else None),
_fetch_transcript(),
)
use_e2b = e2b_sandbox is not None
@@ -2473,95 +2803,17 @@ async def stream_chat_completion_sdk(
warm_ctx = await fetch_warm_context(user_id, message or "") or ""
# Process transcript download result and restore CLI native session.
# The CLI native session file (uploaded after each turn) is the
# source of truth for --resume. Our custom JSONL (TranscriptEntry)
# is loaded into the builder for future upload_transcript calls.
transcript_msg_count = 0
if dl:
is_valid = validate_transcript(dl.content)
dl_lines = dl.content.strip().split("\n") if dl.content else []
logger.info(
"%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s",
log_prefix,
len(dl.content),
len(dl_lines),
dl.message_count,
is_valid,
)
if is_valid:
# Load previous FULL context into builder for state tracking.
transcript_content = dl.content
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
# Restore CLI's native session file so --resume session_id works.
# Falls back gracefully if not available (first turn or upload missed).
# user_id is guaranteed non-None here: _fetch_transcript only sets dl
# when `config.claude_agent_use_resume and user_id` is truthy.
cli_restored = user_id is not None and await restore_cli_session(
user_id, session_id, sdk_cwd, log_prefix=log_prefix
)
if cli_restored:
use_resume = True
resume_file = session_id # CLI --resume expects UUID, not file path
transcript_msg_count = dl.message_count
logger.info(
"%s Using --resume %s (%dB transcript, msg_count=%d)",
log_prefix,
session_id[:8],
len(dl.content),
transcript_msg_count,
)
else:
# Builder loaded but CLI native session not available.
# --resume will not be used this turn; upload after turn
# will seed the native session for the next turn.
#
# Still record transcript_msg_count so _build_query_message
# can use the transcript-aware gap path (inject only new
# messages since the transcript end) instead of compressing
# the full DB history. This avoids prompt-too-long on
# large sessions where the CLI session is temporarily
# unavailable (e.g. mixed-version rolling deployment).
transcript_msg_count = dl.message_count
logger.info(
"%s CLI session not restored — running without"
" --resume this turn (transcript_msg_count=%d for"
" gap-aware fallback)",
log_prefix,
transcript_msg_count,
)
else:
logger.warning("%s Transcript downloaded but invalid", log_prefix)
transcript_covers_prefix = False
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
# No transcript in storage — reconstruct from DB messages as a
# last-resort fallback (e.g., first turn after a crash or transition).
# This path loses tool call IDs and structural fidelity but prevents
# a completely context-free response for established sessions.
prior = session.messages[:-1]
reconstructed = _session_messages_to_transcript(prior)
if reconstructed:
# Populate builder only; no --resume since there is no CLI
# native session to restore. The transcript builder state is
# still useful for the upload that seeds future native sessions.
transcript_content = reconstructed
transcript_builder.load_previous(reconstructed, log_prefix=log_prefix)
transcript_msg_count = len(prior)
transcript_covers_prefix = True
logger.info(
"%s Reconstructed transcript from %d session messages "
"(no CLI native session — running without --resume this turn)",
log_prefix,
len(prior),
)
else:
logger.warning(
"%s No transcript available and reconstruction produced empty"
" output (%d messages in session)",
log_prefix,
len(session.messages),
)
transcript_covers_prefix = False
# Restore CLI session — single GCS round-trip covers both --resume and builder state.
# message_count watermark lives in the companion .meta.json alongside the session file.
_restore = await _restore_cli_session_for_turn(
user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix
)
transcript_content = _restore.transcript_content
transcript_covers_prefix = _restore.transcript_covers_prefix
use_resume = _restore.use_resume
resume_file = _restore.resume_file
transcript_msg_count = _restore.transcript_msg_count
restore_context_messages = _restore.context_messages
yield StreamStart(messageId=message_id, sessionId=session_id)
@@ -2680,14 +2932,14 @@ async def stream_chat_completion_sdk(
else:
# Set session_id whenever NOT resuming so the CLI writes the
# native session file to a predictable path for
# upload_cli_session() after the turn. This covers:
# upload_transcript() after the turn. This covers:
# • T1 fresh: no prior history, first SDK turn.
# • Mode-switch T1: has_history=True (prior baseline turns in
# DB) but no CLI session file was ever uploaded — the CLI has
# never been invoked with this session_id before.
# • T2+ without --resume (restore failed): no session file was
# restored to local storage (restore_cli_session returned
# False), so no conflict with an existing file.
# restored to local storage (download_transcript returned
# None), so no conflict with an existing file.
# When --resume is active the session_id is already implied by
# the resume file; passing it again would be rejected by the CLI.
sdk_options_kwargs["session_id"] = session_id
@@ -2780,6 +3032,7 @@ async def stream_chat_completion_sdk(
use_resume,
transcript_msg_count,
session_id,
prior_messages=restore_context_messages,
)
# If files are attached, prepare them: images become vision
# content blocks in the user message, other files go to sdk_cwd.
@@ -2909,7 +3162,7 @@ async def stream_chat_completion_sdk(
elif "session_id" in sdk_options_kwargs:
# Initial invocation used session_id (T1 or mode-switch
# T1): keep it so the CLI writes the session file to the
# predictable path for upload_cli_session(). Storage is
# predictable path for upload_transcript(). Storage is
# ephemeral per invocation, so no "Session ID already in
# use" conflict occurs — no prior file was restored.
sdk_options_kwargs_retry.pop("resume", None)
@@ -2932,6 +3185,10 @@ async def stream_chat_completion_sdk(
system_prompt, cross_user_cache=_cross_user_retry
)
state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs
# Retry intentionally omits prior_messages (transcript+gap context) and
# falls back to full session.messages[:-1] from DB — the authoritative
# source. transcript+gap is an optimisation for the first attempt only;
# on retry the extra overhead of full-DB context is acceptable.
state.query_message, state.was_compacted = await _build_query_message(
current_message,
session,
@@ -3367,86 +3624,23 @@ async def stream_chat_completion_sdk(
_background_tasks.add(_ingest_task)
_ingest_task.add_done_callback(_background_tasks.discard)
# --- Upload transcript for next-turn --resume ---
# TranscriptBuilder is the single source of truth. It mirrors the
# CLI's active context: on compaction, replace_entries() syncs it
# with the compacted session file. No CLI file read needed here.
if skip_transcript_upload:
logger.warning(
"%s Skipping transcript upload — transcript was dropped "
"during prompt-too-long recovery",
log_prefix,
)
elif (
config.claude_agent_use_resume
and user_id
and session is not None
and state is not None
):
try:
transcript_upload_content = state.transcript_builder.to_jsonl()
entry_count = state.transcript_builder.entry_count
if not transcript_upload_content:
logger.warning(
"%s No transcript to upload (builder empty)", log_prefix
)
elif not validate_transcript(transcript_upload_content):
logger.warning(
"%s Transcript invalid, skipping upload (entries=%d)",
log_prefix,
entry_count,
)
elif not transcript_covers_prefix:
logger.warning(
"%s Skipping transcript upload — builder does not "
"cover full session prefix (entries=%d, session=%d)",
log_prefix,
entry_count,
len(session.messages),
)
else:
logger.info(
"%s Uploading transcript (entries=%d, bytes=%d)",
log_prefix,
entry_count,
len(transcript_upload_content),
)
await asyncio.shield(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=transcript_upload_content,
message_count=len(session.messages),
log_prefix=log_prefix,
)
)
except Exception as upload_err:
logger.error(
"%s Transcript upload failed in finally: %s",
log_prefix,
upload_err,
exc_info=True,
)
# --- Upload CLI native session file for cross-pod --resume ---
# The CLI writes its native session JSONL after each turn completes.
# Uploading it here enables --resume on any pod (no pod affinity needed).
# Runs after upload_transcript so both are available for the next turn.
# asyncio.shield: same pattern as upload_transcript above — if the
# outer finally-block coroutine is cancelled while awaiting shield,
# the CancelledError propagates (BaseException, not caught by
# `except Exception`) letting the caller handle cancellation, while
# the shielded inner coroutine continues running to completion so the
# upload is not lost. This is intentional and matches the pattern
# used for upload_transcript immediately above.
# The companion .meta.json carries the message_count watermark and mode
# so the next turn can restore both --resume context and gap-fill state
# in a single GCS round-trip via download_transcript().
# asyncio.shield: if the outer finally-block coroutine is cancelled
# while awaiting shield, the CancelledError propagates (BaseException,
# not caught by `except Exception`) letting the caller handle
# cancellation, while the shielded inner coroutine continues running
# to completion so the upload is not lost.
#
# NOTE: upload is attempted regardless of state.use_resume — even when
# this turn ran without --resume (restore failed or first T2+ on a new
# pod), the T1 session file at the expected path may still be present
# and should be re-uploaded so the next turn can resume from it.
# upload_cli_session silently skips when the file is absent, so this is
# always safe.
# read_cli_session_from_disk returns None when the file is absent, so
# this is always safe.
#
# Intentionally NOT gated on skip_transcript_upload: that flag is set
# when our custom JSONL transcript is dropped (transcript_lost=True on
@@ -3472,14 +3666,36 @@ async def stream_chat_completion_sdk(
skip_transcript_upload,
)
try:
await asyncio.shield(
upload_cli_session(
user_id=user_id,
session_id=session_id,
sdk_cwd=sdk_cwd,
log_prefix=log_prefix,
)
# Read the CLI's native session file from disk (written by the CLI
# after the turn), then upload the bytes to GCS.
_cli_content = read_cli_session_from_disk(
sdk_cwd, session_id, log_prefix
)
if _cli_content:
# Watermark = number of DB messages this transcript covers.
# len(session.messages) is accurate: the CLI session file
# was just written after the turn completed, so it covers
# all messages through this turn. Any gap from a prior
# missed upload was already detected by detect_gap and
# injected as context, so the model has the full history.
#
# Previously this used _final_tmsg_count + 2, which
# under-counted for tool-use turns (delta = 2 + 2*N_tool_calls),
# causing persistent spurious gap-fills on every subsequent turn.
# That concern was addressed by the inflated-watermark fix
# (using the GCS watermark as the anchor for gap detection),
# which makes len(session.messages) safe to use here.
_jsonl_covered = len(session.messages)
await asyncio.shield(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=_cli_content,
message_count=_jsonl_covered,
mode="sdk",
log_prefix=log_prefix,
)
)
except Exception as cli_upload_err:
logger.warning(
"%s CLI session upload failed in finally: %s",

View File

@@ -22,6 +22,7 @@ from .service import (
_iter_sdk_messages,
_normalize_model_name,
_reduce_context,
_restore_cli_session_for_turn,
_TokenUsage,
)
@@ -615,3 +616,340 @@ class TestSdkSessionIdSelection:
)
assert retry.get("resume") == self.SESSION_ID
assert "session_id" not in retry
# ---------------------------------------------------------------------------
# _restore_cli_session_for_turn — mode check
# ---------------------------------------------------------------------------
class TestRestoreCliSessionModeCheck:
    """SDK skips --resume when the transcript was written by the baseline mode."""

    @pytest.mark.asyncio
    async def test_baseline_mode_transcript_skips_gcs_content(self, tmp_path):
        """A transcript with mode='baseline' must not be used as the --resume source.

        The mode check discards the GCS baseline content and falls back to DB
        reconstruction from session.messages instead.
        """
        from datetime import UTC, datetime

        from backend.copilot.model import ChatMessage, ChatSession
        from backend.copilot.transcript import TranscriptDownload
        from backend.copilot.transcript_builder import TranscriptBuilder

        # Three DB messages: one completed user/assistant pair + the current turn.
        session = ChatSession(
            session_id="test-session",
            user_id="user-1",
            messages=[
                ChatMessage(role="user", content="hello-unique-marker"),
                ChatMessage(role="assistant", content="world-unique-marker"),
                ChatMessage(role="user", content="follow up"),
            ],
            title="test",
            usage=[],
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        builder = TranscriptBuilder()
        # Baseline content with a sentinel that must NOT appear in the final transcript
        baseline_restore = TranscriptDownload(
            content=b'{"type":"user","uuid":"bad-uuid","message":{"role":"user","content":"BASELINE_SENTINEL"}}\n',
            message_count=1,
            mode="baseline",
        )
        import backend.copilot.sdk.service as _svc_mod

        download_mock = AsyncMock(return_value=baseline_restore)
        with (
            patch(
                "backend.copilot.sdk.service.download_transcript",
                new=download_mock,
            ),
            patch.object(_svc_mod.config, "claude_agent_use_resume", True),
        ):
            result = await _restore_cli_session_for_turn(
                user_id="user-1",
                session_id="test-session",
                session=session,
                sdk_cwd=str(tmp_path),
                transcript_builder=builder,
                log_prefix="[Test]",
            )
        # download_transcript was called (attempted GCS restore)
        download_mock.assert_awaited_once()
        # use_resume must be False — baseline transcripts cannot be used with --resume
        assert result.use_resume is False
        # context_messages must be populated — new behaviour uses transcript content + gap
        # instead of full DB reconstruction.
        assert result.context_messages is not None
        # The baseline transcript has 1 user message (BASELINE_SENTINEL).
        # Watermark=1 but position 0 is 'user', not 'assistant', so detect_gap returns [].
        # Result: 1 message from transcript, no gap.
        assert len(result.context_messages) == 1
        assert "BASELINE_SENTINEL" in (result.context_messages[0].content or "")

    @pytest.mark.asyncio
    async def test_sdk_mode_transcript_allows_resume(self, tmp_path):
        """A valid SDK-written transcript is accepted for --resume."""
        import json as stdlib_json
        from datetime import UTC, datetime

        from backend.copilot.model import ChatMessage, ChatSession
        from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
        from backend.copilot.transcript_builder import TranscriptBuilder

        # Minimal valid SDK-format JSONL: a linked user → assistant pair.
        lines = [
            stdlib_json.dumps(
                {
                    "type": "user",
                    "uuid": "uid-0",
                    "parentUuid": "",
                    "message": {"role": "user", "content": "hi"},
                }
            ),
            stdlib_json.dumps(
                {
                    "type": "assistant",
                    "uuid": "uid-1",
                    "parentUuid": "uid-0",
                    "message": {
                        "role": "assistant",
                        "id": "msg_1",
                        "model": "test",
                        "type": "message",
                        "stop_reason": STOP_REASON_END_TURN,
                        "content": [{"type": "text", "text": "hello"}],
                    },
                }
            ),
        ]
        content = ("\n".join(lines) + "\n").encode("utf-8")
        session = ChatSession(
            session_id="test-session",
            user_id="user-1",
            messages=[
                ChatMessage(role="user", content="hi"),
                ChatMessage(role="assistant", content="hello"),
                ChatMessage(role="user", content="follow up"),
            ],
            title="test",
            usage=[],
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        builder = TranscriptBuilder()
        sdk_restore = TranscriptDownload(
            content=content,
            message_count=2,
            mode="sdk",
        )
        import backend.copilot.sdk.service as _svc_mod

        with (
            patch(
                "backend.copilot.sdk.service.download_transcript",
                new=AsyncMock(return_value=sdk_restore),
            ),
            patch.object(_svc_mod.config, "claude_agent_use_resume", True),
        ):
            result = await _restore_cli_session_for_turn(
                user_id="user-1",
                session_id="test-session",
                session=session,
                sdk_cwd=str(tmp_path),
                transcript_builder=builder,
                log_prefix="[Test]",
            )
        assert result.use_resume is True

    @pytest.mark.asyncio
    async def test_baseline_mode_context_messages_from_transcript_content(
        self, tmp_path
    ):
        """mode='baseline' → context_messages populated from transcript content + gap.

        When a baseline-mode transcript exists, extract_context_messages converts
        the JSONL content to ChatMessage objects and returns them in context_messages.
        use_resume must remain False.
        """
        import json as stdlib_json
        from datetime import UTC, datetime

        from backend.copilot.model import ChatMessage, ChatSession
        from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
        from backend.copilot.transcript_builder import TranscriptBuilder

        # Build a minimal valid JSONL transcript with 2 messages
        lines = [
            stdlib_json.dumps(
                {
                    "type": "user",
                    "uuid": "uid-0",
                    "parentUuid": "",
                    "message": {"role": "user", "content": "TRANSCRIPT_USER"},
                }
            ),
            stdlib_json.dumps(
                {
                    "type": "assistant",
                    "uuid": "uid-1",
                    "parentUuid": "uid-0",
                    "message": {
                        "role": "assistant",
                        "id": "msg_1",
                        "model": "test",
                        "type": "message",
                        "stop_reason": STOP_REASON_END_TURN,
                        "content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT"}],
                    },
                }
            ),
        ]
        content = ("\n".join(lines) + "\n").encode("utf-8")
        # DB content deliberately differs from transcript content so the
        # assertions can tell which source populated context_messages.
        session = ChatSession(
            session_id="test-session",
            user_id="user-1",
            messages=[
                ChatMessage(role="user", content="DB_USER"),
                ChatMessage(role="assistant", content="DB_ASSISTANT"),
                ChatMessage(role="user", content="current turn"),
            ],
            title="test",
            usage=[],
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        builder = TranscriptBuilder()
        baseline_restore = TranscriptDownload(
            content=content,
            message_count=2,
            mode="baseline",
        )
        import backend.copilot.sdk.service as _svc_mod

        with (
            patch(
                "backend.copilot.sdk.service.download_transcript",
                new=AsyncMock(return_value=baseline_restore),
            ),
            patch.object(_svc_mod.config, "claude_agent_use_resume", True),
        ):
            result = await _restore_cli_session_for_turn(
                user_id="user-1",
                session_id="test-session",
                session=session,
                sdk_cwd=str(tmp_path),
                transcript_builder=builder,
                log_prefix="[Test]",
            )
        assert result.use_resume is False
        assert result.context_messages is not None
        # Transcript content has 2 messages, no gap (watermark=2, session prior=2)
        assert len(result.context_messages) == 2
        assert result.context_messages[0].role == "user"
        assert result.context_messages[1].role == "assistant"
        assert "TRANSCRIPT_ASSISTANT" in (result.context_messages[1].content or "")
        # transcript_content must be non-empty so the _seed_transcript guard in
        # stream_chat_completion_sdk skips DB reconstruction (which would duplicate
        # builder entries since load_previous appends).
        assert result.transcript_content != ""

    @pytest.mark.asyncio
    async def test_baseline_mode_gap_present_context_includes_gap(self, tmp_path):
        """mode='baseline' + gap → context_messages includes transcript msgs and gap."""
        import json as stdlib_json
        from datetime import UTC, datetime

        from backend.copilot.model import ChatMessage, ChatSession
        from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
        from backend.copilot.transcript_builder import TranscriptBuilder

        # Transcript covers only 2 messages; session has 4 prior + current turn
        lines = [
            stdlib_json.dumps(
                {
                    "type": "user",
                    "uuid": "uid-0",
                    "parentUuid": "",
                    "message": {"role": "user", "content": "TRANSCRIPT_USER_0"},
                }
            ),
            stdlib_json.dumps(
                {
                    "type": "assistant",
                    "uuid": "uid-1",
                    "parentUuid": "uid-0",
                    "message": {
                        "role": "assistant",
                        "id": "msg_1",
                        "model": "test",
                        "type": "message",
                        "stop_reason": STOP_REASON_END_TURN,
                        "content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT_1"}],
                    },
                }
            ),
        ]
        content = ("\n".join(lines) + "\n").encode("utf-8")
        session = ChatSession(
            session_id="test-session",
            user_id="user-1",
            messages=[
                ChatMessage(role="user", content="DB_USER_0"),
                ChatMessage(role="assistant", content="DB_ASSISTANT_1"),
                ChatMessage(role="user", content="GAP_USER_2"),
                ChatMessage(role="assistant", content="GAP_ASSISTANT_3"),
                ChatMessage(role="user", content="current turn"),
            ],
            title="test",
            usage=[],
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        builder = TranscriptBuilder()
        baseline_restore = TranscriptDownload(
            content=content,
            message_count=2,  # watermark=2; session has 4 prior → gap of 2
            mode="baseline",
        )
        import backend.copilot.sdk.service as _svc_mod

        with (
            patch(
                "backend.copilot.sdk.service.download_transcript",
                new=AsyncMock(return_value=baseline_restore),
            ),
            patch.object(_svc_mod.config, "claude_agent_use_resume", True),
        ):
            result = await _restore_cli_session_for_turn(
                user_id="user-1",
                session_id="test-session",
                session=session,
                sdk_cwd=str(tmp_path),
                transcript_builder=builder,
                log_prefix="[Test]",
            )
        assert result.use_resume is False
        assert result.context_messages is not None
        # 2 from transcript + 2 gap messages = 4 total
        assert len(result.context_messages) == 4
        roles = [m.role for m in result.context_messages]
        assert roles == ["user", "assistant", "user", "assistant"]
        # Gap messages come from DB (ChatMessage objects)
        gap_user = result.context_messages[2]
        gap_asst = result.context_messages[3]
        assert gap_user.content == "GAP_USER_2"
        assert gap_asst.content == "GAP_ASSISTANT_3"

View File

@@ -0,0 +1,95 @@
"""Unit tests for the watermark-fix logic in stream_chat_completion_sdk.
The fix is at the upload step: when use_resume=True and transcript_msg_count>0
we set the JSONL coverage watermark to transcript_msg_count + 2 (the pair just
recorded) instead of len(session.messages). This prevents the "inflated
watermark" bug where a stale JSONL in GCS could hide missing context from
future gap-fill checks.
"""
from __future__ import annotations
def _compute_jsonl_covered(
use_resume: bool,
transcript_msg_count: int,
session_msg_count: int,
) -> int:
"""Mirror the watermark computation from ``stream_chat_completion_sdk``.
Extracted here so we can unit-test it independently without invoking the
full streaming stack.
"""
if use_resume and transcript_msg_count > 0:
return transcript_msg_count + 2
return session_msg_count
class TestWatermarkFix:
    """Watermark computation logic — mirrors the finally-block in SDK service."""

    def test_inflated_watermark_triggers_gap_fill(self):
        """Stale JSONL (T12) with high watermark (46) → after fix, watermark=14.

        Before the fix the watermark was set to the DB count (47), so the next
        turn's gap check (transcript_msg_count < db - 1) never fired
        (46 >= 47 - 1 = 46) and context loss was silent. After the fix the
        watermark is 12 + 2 = 14, the gap check fires (14 < 46), and the model
        receives the missing turns.
        """
        # use_resume=True, transcript covered T12 (12 msgs), DB now has 47.
        db_count = 47
        watermark = _compute_jsonl_covered(True, 12, db_count)

        # 12 + 2 — NOT the DB count the old code used.
        assert watermark == 14
        # Next turn's gap check (watermark < db_count - 1) now fires: 14 < 46.
        assert watermark < db_count - 1

    def test_no_false_positive_when_transcript_current(self):
        """Transcript current (watermark=46, DB=47) → gap stays 0.

        When the JSONL actually covers T46 (the most recent assistant turn),
        uploading watermark=46+2=48 means next turn's gap check sees
        48 >= 48-1=47 → no gap. Correct.
        """
        watermark = _compute_jsonl_covered(True, 46, 47)
        assert watermark == 48  # 46 + 2

        # Next turn the session holds 48 messages; 48 >= 48 - 1 → no gap.
        session_size_next_turn = 48
        assert watermark >= session_size_next_turn - 1

    def test_fresh_session_falls_back_to_db_count(self):
        """use_resume=False → watermark = len(session.messages) (original behaviour)."""
        db_count = 3
        assert _compute_jsonl_covered(False, 0, db_count) == db_count

    def test_old_format_meta_zero_count_falls_back_to_db(self):
        """transcript_msg_count=0 (old-format meta with no count field) → DB fallback."""
        # Zero coverage means the meta predates the count field (or isn't set
        # yet), so the helper must fall back to the DB count.
        db_count = 10
        assert _compute_jsonl_covered(True, 0, db_count) == db_count

View File

@@ -12,18 +12,20 @@ from backend.copilot.transcript import (
ENTRY_TYPE_MESSAGE,
STOP_REASON_END_TURN,
STRIPPABLE_TYPES,
TRANSCRIPT_STORAGE_PREFIX,
TranscriptDownload,
TranscriptMode,
cleanup_stale_project_dirs,
cli_session_path,
compact_transcript,
delete_transcript,
detect_gap,
download_transcript,
extract_context_messages,
projects_base,
read_compacted_entries,
restore_cli_session,
strip_for_upload,
strip_progress_entries,
strip_stale_thinking_blocks,
upload_cli_session,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -34,18 +36,20 @@ __all__ = [
"ENTRY_TYPE_MESSAGE",
"STOP_REASON_END_TURN",
"STRIPPABLE_TYPES",
"TRANSCRIPT_STORAGE_PREFIX",
"TranscriptDownload",
"TranscriptMode",
"cleanup_stale_project_dirs",
"cli_session_path",
"compact_transcript",
"delete_transcript",
"detect_gap",
"download_transcript",
"extract_context_messages",
"projects_base",
"read_compacted_entries",
"restore_cli_session",
"strip_for_upload",
"strip_progress_entries",
"strip_stale_thinking_blocks",
"upload_cli_session",
"upload_transcript",
"validate_transcript",
"write_transcript_to_tempfile",

View File

@@ -297,8 +297,8 @@ class TestStripProgressEntries:
class TestDeleteTranscript:
@pytest.mark.asyncio
async def test_deletes_both_jsonl_and_meta(self):
"""delete_transcript removes both the .jsonl and .meta.json files."""
async def test_deletes_cli_session_and_meta(self):
"""delete_transcript removes the CLI session .jsonl and .meta.json."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock()
@@ -309,7 +309,7 @@ class TestDeleteTranscript:
):
await delete_transcript("user-123", "session-456")
assert mock_storage.delete.call_count == 3
assert mock_storage.delete.call_count == 2
paths = [call.args[0] for call in mock_storage.delete.call_args_list]
assert any(p.endswith(".jsonl") for p in paths)
assert any(p.endswith(".meta.json") for p in paths)
@@ -319,7 +319,7 @@ class TestDeleteTranscript:
"""If .jsonl delete fails, .meta.json delete is still attempted."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock(
side_effect=[Exception("jsonl delete failed"), None, None]
side_effect=[Exception("jsonl delete failed"), None]
)
with patch(
@@ -330,14 +330,14 @@ class TestDeleteTranscript:
# Should not raise
await delete_transcript("user-123", "session-456")
assert mock_storage.delete.call_count == 3
assert mock_storage.delete.call_count == 2
@pytest.mark.asyncio
async def test_handles_meta_delete_failure(self):
"""If .meta.json delete fails, no exception propagates."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock(
side_effect=[None, Exception("meta delete failed"), None]
side_effect=[None, Exception("meta delete failed")]
)
with patch(
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs:
nonexistent = str(tmp_path / "does-not-exist" / "projects")
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: nonexistent,
)
@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1368,3 +1368,172 @@ class TestStripStaleThinkingBlocks:
# Both entries of last turn (msg_last) preserved
assert lines[1]["message"]["content"][0]["type"] == "thinking"
assert lines[2]["message"]["content"][0]["type"] == "text"
class TestProcessCliRestore:
    """``process_cli_restore`` validates, strips, and writes CLI session to disk."""
    def test_writes_stripped_bytes_not_raw(self, tmp_path):
        """Stripped bytes (not raw bytes) must be written to disk for --resume."""
        import os
        import re
        from pathlib import Path
        from unittest.mock import patch
        from backend.copilot.sdk.service import process_cli_restore
        from backend.copilot.transcript import TranscriptDownload
        # UUID-shaped session id — presumably required by the service's
        # id sanitization; TODO confirm against _sanitize_id.
        session_id = "12345678-0000-0000-0000-abcdef000001"
        sdk_cwd = str(tmp_path)
        projects_base_dir = str(tmp_path)
        # Build raw content with a strippable progress entry + a valid user/assistant pair
        raw_content = (
            '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
            '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
            '{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
        )
        raw_bytes = raw_content.encode("utf-8")
        restore = TranscriptDownload(content=raw_bytes, message_count=2, mode="sdk")
        # Patch projects_base in BOTH modules: the service references its own
        # import and the transcript module references the original, so a
        # single patch would leave one path pointing at the real directory.
        with (
            patch(
                "backend.copilot.sdk.service.projects_base",
                return_value=projects_base_dir,
            ),
            patch(
                "backend.copilot.transcript.projects_base",
                return_value=projects_base_dir,
            ),
        ):
            stripped_str, ok = process_cli_restore(
                restore, sdk_cwd, session_id, "[Test]"
            )
        assert ok, "Expected successful restore"
        # Find the written session file
        # NOTE(review): this re-implements the cwd encoding used by
        # cli_session_path (non-alphanumerics → "-"); keep in sync.
        encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
        session_file = Path(projects_base_dir) / encoded_cwd / f"{session_id}.jsonl"
        assert session_file.exists(), "Session file should have been written"
        written_bytes = session_file.read_bytes()
        # The written bytes must be the stripped version (no progress entry)
        assert (
            b"progress" not in written_bytes
        ), "Raw bytes with progress entry should not have been written"
        assert (
            b"hello" in written_bytes
        ), "Stripped content should still contain assistant turn"
        # Written bytes must equal the stripped string re-encoded
        assert written_bytes == stripped_str.encode(
            "utf-8"
        ), "Written bytes must equal stripped content"
    def test_invalid_content_returns_false(self):
        """Content that fails validation after strip returns (empty, False)."""
        from backend.copilot.sdk.service import process_cli_restore
        from backend.copilot.transcript import TranscriptDownload
        # A single progress-only entry — stripped result will be empty/invalid
        raw_content = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
        restore = TranscriptDownload(
            content=raw_content.encode("utf-8"), message_count=1, mode="sdk"
        )
        # No projects_base patch needed: validation fails before any disk write
        # is attempted, so the nonexistent cwd is never touched.
        stripped_str, ok = process_cli_restore(
            restore,
            "/tmp/nonexistent-sdk-cwd",
            "12345678-0000-0000-0000-000000000099",
            "[Test]",
        )
        assert not ok
        assert stripped_str == ""
class TestReadCliSessionFromDisk:
    """``read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
    def _build_session_file(self, tmp_path, session_id: str):
        """Build the session file path inside tmp_path using the same encoding as cli_session_path."""
        import os
        import re
        from pathlib import Path
        sdk_cwd = str(tmp_path)
        # Mirror of cli_session_path's cwd encoding (non-alphanumerics → "-");
        # keep in sync with the production helper.
        encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
        session_dir = Path(str(tmp_path)) / encoded_cwd
        session_dir.mkdir(parents=True, exist_ok=True)
        # Returns (sdk_cwd, path-to-<session_id>.jsonl) for the caller to populate.
        return sdk_cwd, session_dir / f"{session_id}.jsonl"
    def test_returns_raw_bytes_for_invalid_utf8(self, tmp_path):
        """Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback)."""
        from unittest.mock import patch
        from backend.copilot.sdk.service import read_cli_session_from_disk
        session_id = "12345678-0000-0000-0000-aabbccdd0001"
        projects_base_dir = str(tmp_path)
        sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
        # Write raw invalid UTF-8 bytes
        session_file.write_bytes(b"\xff\xfe invalid utf-8\n")
        # Patch projects_base in both the service and transcript modules so
        # path resolution stays inside tmp_path.
        with (
            patch(
                "backend.copilot.sdk.service.projects_base",
                return_value=projects_base_dir,
            ),
            patch(
                "backend.copilot.transcript.projects_base",
                return_value=projects_base_dir,
            ),
        ):
            result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
        # UnicodeDecodeError path returns the raw bytes (upload-raw fallback)
        assert result == b"\xff\xfe invalid utf-8\n"
    def test_write_back_oserror_still_returns_stripped_bytes(self, tmp_path):
        """OSError on write-back returns stripped bytes for GCS upload (not raw)."""
        from unittest.mock import patch
        from backend.copilot.sdk.service import read_cli_session_from_disk
        session_id = "12345678-0000-0000-0000-aabbccdd0002"
        projects_base_dir = str(tmp_path)
        sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
        # Content with a strippable progress entry so stripped_bytes < raw_bytes
        raw_content = (
            '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
            '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
            '{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
        )
        session_file.write_bytes(raw_content.encode("utf-8"))
        # Make the file read-only so write_bytes raises OSError on the write-back
        # NOTE(review): chmod-based read-only may not raise for root or on some
        # filesystems (e.g. Windows) — confirm this test's environment.
        session_file.chmod(0o444)
        try:
            with (
                patch(
                    "backend.copilot.sdk.service.projects_base",
                    return_value=projects_base_dir,
                ),
                patch(
                    "backend.copilot.transcript.projects_base",
                    return_value=projects_base_dir,
                ),
            ):
                result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
        finally:
            # Restore permissions so pytest can clean up tmp_path.
            session_file.chmod(0o644)
        # Must return stripped bytes (not raw, not None) so GCS gets the clean version
        assert result is not None
        assert (
            b"progress" not in result
        ), "Stripped bytes must not contain progress entry"
        assert b"hello" in result, "Stripped bytes should contain assistant turn"

View File

@@ -61,18 +61,23 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
# (CLI version, platform). When that happens, multi-turn still works
# via conversation compression (non-resume path), but we can't test
# the --resume round-trip.
transcript = None
cli_session = None
for _ in range(10):
await asyncio.sleep(0.5)
transcript = await download_transcript(test_user_id, session.session_id)
if transcript:
cli_session = await download_transcript(test_user_id, session.session_id)
# Wait until both the session bytes AND the message_count watermark are
# present — a session with message_count=0 means the .meta.json hasn't
# been uploaded yet, so --resume on the next turn would skip gap-fill.
if cli_session and cli_session.message_count > 0:
break
if not transcript:
if not cli_session:
return pytest.skip(
"CLI did not produce a usable transcript — "
"cannot test --resume round-trip in this environment"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
logger.info(
f"Turn 1 CLI session uploaded: {len(cli_session.content)} bytes, msg_count={cli_session.message_count}"
)
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)

View File

@@ -423,20 +423,33 @@ async def subscribe_to_session(
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
# RACE CONDITION FIX: If session not found, retry once after small delay
# This handles the case where subscribe_to_session is called immediately
# after create_session but before Redis propagates the write
# RACE CONDITION FIX: If session not found, retry with backoff.
# Duplicate requests skip create_session and subscribe immediately; the
# original request's create_session (a Redis hset) may not have completed
# yet. 3 × 100ms gives a 300ms window which covers DB-write latency on the
# original request before the hset even starts.
if not meta:
logger.warning(
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
extra={"json_fields": {**log_meta}},
)
await asyncio.sleep(0.05) # 50ms
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
_max_retries = 3
_retry_delay = 0.1 # 100ms per attempt
for attempt in range(_max_retries):
logger.warning(
f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), "
f"retrying after {int(_retry_delay * 1000)}ms",
extra={"json_fields": {**log_meta, "attempt": attempt + 1}},
)
await asyncio.sleep(_retry_delay)
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
logger.info(
f"[TIMING] Session found after {attempt + 1} retries",
extra={"json_fields": {**log_meta, "attempts": attempt + 1}},
)
break
else:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
f"[TIMING] Session still not found in Redis after {_max_retries} retries "
f"({elapsed:.1f}ms total)",
extra={
"json_fields": {
**log_meta,
@@ -446,10 +459,6 @@ async def subscribe_to_session(
},
)
return None
logger.info(
"[TIMING] Session found after retry",
extra={"json_fields": {**log_meta}},
)
# Note: Redis client uses decode_responses=True, so keys are strings
session_status = meta.get("status", "")

View File

@@ -1,10 +1,10 @@
"""JSONL transcript management for stateless multi-turn resume.
The Claude Code CLI persists conversations as JSONL files (one JSON object per
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
(progress entries, metadata), and upload the result to bucket storage. On the
next turn we download the transcript, write it to a temp file, and pass
``--resume`` so the CLI can reconstruct the full conversation.
line). When the SDK's ``Stop`` hook fires the caller reads this file, strips
bloat (progress entries, metadata), and uploads the result to bucket storage.
On the next turn the caller downloads the bytes and writes them to disk before
passing ``--resume`` so the CLI can reconstruct the full conversation.
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
@@ -20,6 +20,7 @@ import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from uuid import uuid4
from backend.util import json
@@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client
from backend.util.prompt import CompressResult, compress_context
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
if TYPE_CHECKING:
from .model import ChatMessage
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
@@ -44,17 +48,17 @@ STRIPPABLE_TYPES = frozenset(
)
TranscriptMode = Literal["sdk", "baseline"]
@dataclass
class TranscriptDownload:
"""Result of downloading a transcript with its metadata."""
content: str
message_count: int = 0 # session.messages length when uploaded
uploaded_at: float = 0.0 # epoch timestamp of upload
content: bytes | str
message_count: int = 0
# "sdk" = Claude CLI native, "baseline" = TranscriptBuilder
mode: TranscriptMode = "sdk"
# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
# Storage prefix for the CLI's native session JSONL files (for cross-pod --resume).
_CLI_SESSION_STORAGE_PREFIX = "cli-sessions"
@@ -363,7 +367,7 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _projects_base() -> str:
def projects_base() -> str:
"""Return the resolved path to the CLI's projects directory."""
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
return os.path.realpath(os.path.join(config_dir, "projects"))
@@ -390,8 +394,8 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
Returns the number of directories removed.
"""
projects_base = _projects_base()
if not os.path.isdir(projects_base):
_pbase = projects_base()
if not os.path.isdir(_pbase):
return 0
now = time.time()
@@ -399,7 +403,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
# Scoped mode: only clean up the one directory for the current session.
if encoded_cwd:
target = Path(projects_base) / encoded_cwd
target = Path(_pbase) / encoded_cwd
if not target.is_dir():
return 0
# Guard: only sweep copilot-generated dirs.
@@ -437,7 +441,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
# Only safe for single-tenant deployments; callers should prefer the
# scoped variant by passing encoded_cwd.
try:
entries = Path(projects_base).iterdir()
entries = Path(_pbase).iterdir()
except OSError as e:
logger.warning("[Transcript] Failed to list projects dir: %s", e)
return 0
@@ -490,9 +494,9 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
if not transcript_path:
return None
projects_base = _projects_base()
_pbase = projects_base()
real_path = os.path.realpath(transcript_path)
if not real_path.startswith(projects_base + os.sep):
if not real_path.startswith(_pbase + os.sep):
logger.warning(
"[Transcript] transcript_path outside projects base: %s", transcript_path
)
@@ -611,28 +615,6 @@ def validate_transcript(content: str | None) -> bool:
# ---------------------------------------------------------------------------
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript.
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
IDs are sanitized to hex+hyphen to prevent path traversal.
"""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.jsonl",
)
def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript metadata."""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.meta.json",
)
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
wid, fid, fname = parts
@@ -642,24 +624,12 @@ def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
return f"local://{wid}/{fid}/{fname}"
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects."""
return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend)
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path for the companion .meta.json file."""
return _build_path_from_parts(
_meta_storage_path_parts(user_id, session_id), backend
)
# ---------------------------------------------------------------------------
# CLI native session file — cross-pod --resume support
# ---------------------------------------------------------------------------
def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
def cli_session_path(sdk_cwd: str, session_id: str) -> str:
"""Expected path of the CLI's native session JSONL file.
The CLI resolves the working directory via ``os.path.realpath``, then
@@ -675,7 +645,7 @@ def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
"""
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
safe_id = _sanitize_id(session_id)
return os.path.join(_projects_base(), encoded_cwd, f"{safe_id}.jsonl")
return os.path.join(projects_base(), encoded_cwd, f"{safe_id}.jsonl")
def _cli_session_storage_path_parts(
@@ -689,235 +659,82 @@ def _cli_session_storage_path_parts(
)
async def upload_cli_session(
user_id: str,
session_id: str,
sdk_cwd: str,
log_prefix: str = "[Transcript]",
) -> None:
"""Upload the CLI's native session JSONL file to remote storage.
Called after each turn so the next turn can restore the file on any pod
(eliminating the pod-affinity requirement for --resume).
The CLI only writes the session file after the turn completes, so this
must run in the finally block, AFTER the SDK stream has finished.
"""
session_file = _cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
projects_base = _projects_base()
if not real_path.startswith(projects_base + os.sep):
logger.warning(
"%s CLI session file outside projects base, skipping upload: %s",
log_prefix,
os.path.basename(real_path),
)
return
try:
raw_bytes = Path(real_path).read_bytes()
except FileNotFoundError:
logger.debug(
"%s CLI session file not found, skipping upload: %s",
log_prefix,
session_file,
)
return
except OSError as e:
logger.warning("%s Failed to read CLI session file: %s", log_prefix, e)
return
# Strip stale thinking blocks and metadata entries (progress, file-history-snapshot,
# queue-operation) from the CLI session before writing it back locally and uploading
# to GCS. Thinking blocks from non-last assistant turns are not needed for --resume
# but can be massive (tens of thousands of tokens each), causing the CLI to auto-compact
# its session when the context window fills up. Stripping keeps the session well below
# the ~200K-token compaction threshold and prevents silent context loss.
try:
raw_text = raw_bytes.decode("utf-8")
stripped_text = strip_for_upload(raw_text)
stripped_bytes = stripped_text.encode("utf-8")
if len(stripped_bytes) < len(raw_bytes):
# Write the stripped version back locally so same-pod turns also benefit.
Path(real_path).write_bytes(stripped_bytes)
logger.info(
"%s Stripped CLI session file: %dB → %dB",
log_prefix,
len(raw_bytes),
len(stripped_bytes),
)
content = stripped_bytes
except Exception as e:
logger.warning(
"%s Failed to strip CLI session file, uploading raw: %s", log_prefix, e
)
content = raw_bytes
storage = await get_workspace_storage()
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
try:
await storage.store(
workspace_id=wid, file_id=fid, filename=fname, content=content
)
logger.info(
"%s Uploaded CLI session file (%dB) for cross-pod --resume",
log_prefix,
len(content),
)
except Exception as e:
logger.warning("%s Failed to upload CLI session file: %s", log_prefix, e)
async def restore_cli_session(
user_id: str,
session_id: str,
sdk_cwd: str,
log_prefix: str = "[Transcript]",
) -> bool:
"""Download and restore the CLI's native session file for --resume.
Returns True if the file was successfully restored and --resume can be
used with the session UUID. Returns False if not available (first turn
or upload failed), in which case the caller should not set --resume.
"""
session_file = _cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
projects_base = _projects_base()
if not real_path.startswith(projects_base + os.sep):
logger.warning(
"%s CLI session restore path outside projects base: %s",
log_prefix,
os.path.basename(session_file),
)
return False
# If the session file already exists locally (same-pod reuse), use it directly.
# Downloading from storage could overwrite a newer local version when a previous
# turn's upload failed: stored content is stale while the local file already
# contains extended history from that turn.
if Path(real_path).exists():
logger.debug(
"%s CLI session file already exists locally — using it for --resume",
log_prefix,
)
return True
storage = await get_workspace_storage()
path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for the CLI session meta file."""
return (
_CLI_SESSION_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.meta.json",
)
try:
content = await storage.retrieve(path)
except FileNotFoundError:
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
return False
except Exception as e:
logger.warning("%s Failed to download CLI session: %s", log_prefix, e)
return False
try:
os.makedirs(os.path.dirname(real_path), exist_ok=True)
Path(real_path).write_bytes(content)
logger.info(
"%s Restored CLI session file (%dB) for --resume",
log_prefix,
len(content),
)
return True
except OSError as e:
logger.warning("%s Failed to write CLI session file: %s", log_prefix, e)
return False
async def upload_transcript(
user_id: str,
session_id: str,
content: str,
content: bytes,
message_count: int = 0,
mode: TranscriptMode = "sdk",
log_prefix: str = "[Transcript]",
skip_strip: bool = False,
) -> None:
"""Strip progress entries and stale thinking blocks, then upload transcript.
"""Upload CLI session content to GCS with companion meta.json.
The transcript represents the FULL active context (atomic).
Each upload REPLACES the previous transcript entirely.
Pure GCS operation — no disk I/O. The caller is responsible for reading
the session file from disk before calling this function.
The executor holds a cluster lock per session, so concurrent uploads for
the same session cannot happen.
Also uploads a companion .meta.json with the message_count watermark so
download_transcript can return it without a separate fetch.
Args:
content: Complete JSONL transcript (from TranscriptBuilder).
message_count: ``len(session.messages)`` at upload time.
skip_strip: When ``True``, skip the strip + re-validate pass.
Safe for builder-generated content (baseline path) which
never emits progress entries or stale thinking blocks.
Called after each turn so the next turn can restore the file on any pod
(eliminating the pod-affinity requirement for --resume).
"""
if skip_strip:
# Caller guarantees the content is already clean and valid.
stripped = content
else:
# Strip metadata entries and stale thinking blocks in a single parse.
# SDK-built transcripts may have progress entries; strip for safety.
stripped = strip_for_upload(content)
if not skip_strip and not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types = [
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
for line in stripped.strip().split("\n")
]
logger.warning(
"%s Skipping upload — stripped content not valid "
"(types=%s, stripped_len=%d, raw_len=%d)",
log_prefix,
entry_types,
len(stripped),
len(content),
)
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
return
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id)
meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()}
meta_encoded = json.dumps(meta).encode("utf-8")
# Transcript + metadata are independent objects at different keys, so
# write them concurrently. ``return_exceptions`` keeps a metadata
# failure from sinking the transcript write.
transcript_result, metadata_result = await asyncio.gather(
storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
),
storage.store(
workspace_id=mwid,
file_id=mfid,
filename=mfname,
content=meta_encoded,
),
return_exceptions=True,
)
if isinstance(transcript_result, BaseException):
raise transcript_result
if isinstance(metadata_result, BaseException):
# Metadata is best-effort — the gap-fill logic in
# _build_query_message tolerates a missing metadata file.
logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result)
# Write JSONL first, meta second — sequential so a crash between the two
# leaves an orphaned JSONL (no meta) rather than an orphaned meta (wrong
# watermark / mode paired with stale or absent content).
# On any failure we roll back the other file so the pair is always absent
# together; download_transcript returns None when either file is missing.
try:
await storage.store(
workspace_id=wid, file_id=fid, filename=fname, content=content
)
except Exception as session_err:
logger.warning(
"%s Failed to upload CLI session file: %s", log_prefix, session_err
)
return
try:
await storage.store(
workspace_id=mwid, file_id=mfid, filename=mfname, content=meta_encoded
)
except Exception as meta_err:
logger.warning("%s Failed to upload CLI session meta: %s", log_prefix, meta_err)
# Roll back the JSONL so neither file exists — avoids orphaned JSONL being
# used with wrong mode/watermark defaults on the next restore.
try:
session_path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
)
await storage.delete(session_path)
except Exception as rollback_err:
logger.debug(
"%s Session rollback failed (harmless — download will return None): %s",
log_prefix,
rollback_err,
)
return
logger.info(
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
"%s Uploaded CLI session (%dB, msg_count=%d, mode=%s)",
log_prefix,
len(encoded),
len(content),
message_count,
mode,
)
@@ -926,83 +743,173 @@ async def download_transcript(
session_id: str,
log_prefix: str = "[Transcript]",
) -> TranscriptDownload | None:
"""Download transcript and metadata from bucket storage.
"""Download CLI session from GCS. Returns content + message_count + mode, or None if not found.
Returns a ``TranscriptDownload`` with the JSONL content and the
``message_count`` watermark from the upload, or ``None`` if not found.
Pure GCS operation — no disk I/O. The caller is responsible for writing
content to disk if --resume is needed.
The content and metadata fetches run concurrently since they are
independent objects in the bucket.
Returns a TranscriptDownload with the raw content, message_count watermark,
and mode on success, or None if not available (first turn or upload failed).
"""
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
meta_path = _build_meta_storage_path(user_id, session_id, storage)
path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
)
meta_path = _build_path_from_parts(
_cli_session_meta_path_parts(user_id, session_id), storage
)
content_task = asyncio.create_task(storage.retrieve(path))
meta_task = asyncio.create_task(storage.retrieve(meta_path))
content_result, meta_result = await asyncio.gather(
content_task, meta_task, return_exceptions=True
storage.retrieve(path),
storage.retrieve(meta_path),
return_exceptions=True,
)
if isinstance(content_result, FileNotFoundError):
logger.debug("%s No transcript in storage", log_prefix)
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
return None
if isinstance(content_result, BaseException):
logger.warning(
"%s Failed to download transcript: %s", log_prefix, content_result
"%s Failed to download CLI session: %s", log_prefix, content_result
)
return None
content = content_result.decode("utf-8")
content: bytes = content_result
# Metadata is best-effort — old transcripts won't have it.
# Parse message_count and mode from companion meta best-effort, defaults.
message_count = 0
uploaded_at = 0.0
mode: TranscriptMode = "sdk"
if isinstance(meta_result, FileNotFoundError):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
pass # No meta — old upload; default to "sdk"
elif isinstance(meta_result, BaseException):
logger.debug(
"%s Failed to load transcript metadata: %s", log_prefix, meta_result
)
logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result)
else:
meta = json.loads(meta_result.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
try:
meta_str = meta_result.decode("utf-8")
except UnicodeDecodeError:
logger.debug("%s CLI session meta is not valid UTF-8, ignoring", log_prefix)
meta_str = None
if meta_str is not None:
meta = json.loads(meta_str, fallback={})
if isinstance(meta, dict):
raw_count = meta.get("message_count", 0)
message_count = (
raw_count if isinstance(raw_count, int) and raw_count >= 0 else 0
)
raw_mode = meta.get("mode", "sdk")
mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk"
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
return TranscriptDownload(
content=content,
message_count=message_count,
uploaded_at=uploaded_at,
"%s Downloaded CLI session (%dB, msg_count=%d, mode=%s)",
log_prefix,
len(content),
message_count,
mode,
)
return TranscriptDownload(content=content, message_count=message_count, mode=mode)
def detect_gap(
    download: TranscriptDownload,
    session_messages: list[ChatMessage],
) -> list[ChatMessage]:
    """Return the chat-db messages that fall after the transcript watermark.

    The watermark (``download.message_count``) records how many DB messages the
    stored transcript already covers. Everything between the watermark and the
    current user turn (the final list element) is the "gap" that must be
    re-injected as context.

    Returns an empty list when:
      - the watermark is zero (nothing has been uploaded yet),
      - the transcript already covers everything up to the current turn, or
      - the message at the watermark boundary is not an assistant turn, which
        means the DB shifted underneath us (e.g. a deletion) and the watermark
        can no longer be trusted.
    """
    watermark = download.message_count
    if watermark == 0:
        return []

    message_total = len(session_messages)
    # Transcript is current: nothing sits between the watermark and the
    # current user turn.
    if watermark >= message_total - 1:
        return []

    # In normal operation the watermark is only written after a complete
    # user -> assistant exchange, so the last covered position must be an
    # assistant turn. Anything else indicates corruption or deletion — skip
    # the gap rather than inject the wrong context.
    if session_messages[watermark - 1].role != "assistant":
        return []

    # Exclude the current user turn at the tail; callers append it themselves.
    return list(session_messages[watermark : message_total - 1])
def extract_context_messages(
    download: "TranscriptDownload | None",
    session_messages: "list[ChatMessage]",
) -> "list[ChatMessage]":
    """Build the prior-conversation context for the current turn.

    Combines the stored transcript (when available) with the "gap" — DB
    messages written after the transcript's watermark. This is the shared
    context primitive for both the SDK path (``use_resume=False`` →
    ``<conversation_history>`` injection) and the baseline path (OpenAI
    messages array).

    Behavior:
    - ``download.content`` may be either ``bytes`` (raw GCS download) or
      ``str`` (the baseline path decodes + strips before returning); both are
      handled here.
    - Falls back to the raw DB history whenever the transcript is absent,
      empty, undecodable, or yields no entries (first turn, upload failure,
      or GCS unavailable).
    - Only *prior* messages are returned — the current user turn at
      ``session_messages[-1]`` is always excluded; callers that need it
      append it themselves.
    - Assistant messages reconstructed from the transcript JSONL have their
      tool calls flattened to plain text, while gap messages (from the DB)
      keep their structured ``tool_calls`` — the same trade-off as the old
      ``_compress_session_messages(session.messages)`` approach.

    Args:
        download: The ``TranscriptDownload`` from GCS, or ``None`` when no
            transcript is available.
        session_messages: All messages in the session, current user turn last.

    Returns:
        ``ChatMessage`` objects covering the prior conversation context,
        suitable for injection as conversation history.
    """
    from .model import ChatMessage as _ChatMessage  # runtime import

    history = session_messages[:-1]
    if download is None or not download.content:
        return history

    payload = download.content
    if isinstance(payload, bytes):
        try:
            decoded = payload.decode("utf-8")
        except UnicodeDecodeError:
            # Corrupt transcript bytes — fall back to plain DB history.
            return history
    else:
        # Baseline path hands us an already-decoded string.
        decoded = payload

    entries = _transcript_to_messages(decoded)
    if not entries:
        return history

    restored = [
        _ChatMessage(role=entry["role"], content=entry.get("content") or "")
        for entry in entries
    ]
    return restored + detect_gap(download, session_messages)
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript and its metadata from bucket storage.
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
"""
"""Delete CLI session JSONL and its companion .meta.json from bucket storage."""
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
try:
await storage.delete(path)
logger.info("[Transcript] Deleted transcript for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete transcript: %s", e)
# Also delete the companion .meta.json to avoid orphaned metadata.
try:
meta_path = _build_meta_storage_path(user_id, session_id, storage)
await storage.delete(meta_path)
logger.info("[Transcript] Deleted metadata for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete metadata: %s", e)
# Also delete the CLI native session file to prevent storage growth.
try:
cli_path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
@@ -1012,6 +919,15 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
except Exception as e:
logger.warning("[Transcript] Failed to delete CLI session: %s", e)
try:
cli_meta_path = _build_path_from_parts(
_cli_session_meta_path_parts(user_id, session_id), storage
)
await storage.delete(cli_meta_path)
logger.info("[Transcript] Deleted CLI session meta for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete CLI session meta: %s", e)
# ---------------------------------------------------------------------------
# Transcript compaction — LLM summarization for prompt-too-long recovery

File diff suppressed because it is too large Load Diff

View File

@@ -143,6 +143,8 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,
LlmModel.GROK_4_20: 5,
LlmModel.GROK_4_20_MULTI_AGENT: 5,
LlmModel.GROK_CODE_FAST_1: 1,
LlmModel.KIMI_K2: 1,
LlmModel.QWEN3_235B_A22B_THINKING: 1,

View File

@@ -1,10 +1,13 @@
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast
import stripe
from fastapi.concurrency import run_in_threadpool
from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
@@ -31,6 +34,7 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.util.cache import cached
from backend.util.exceptions import InsufficientBalanceError
from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled
from backend.util.json import SafeJson, dumps
@@ -432,7 +436,7 @@ class UserCreditBase(ABC):
current_balance, _ = await self._get_credits(user_id)
if current_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
)
# Single unified atomic operation for all transaction types using UserBalance
@@ -571,7 +575,7 @@ class UserCreditBase(ABC):
if amount < 0 and fail_insufficient_credits:
current_balance, _ = await self._get_credits(user_id)
raise InsufficientBalanceError(
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
user_id=user_id,
balance=current_balance,
amount=amount,
@@ -582,7 +586,6 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
@@ -734,7 +737,7 @@ class UserCredit(UserCreditBase):
)
if request.amount <= 0 or request.amount > transaction.amount:
raise AssertionError(
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
)
balance, _ = await self._add_transaction(
@@ -788,12 +791,12 @@ class UserCredit(UserCreditBase):
# If the user has enough balance, just let them win the dispute.
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
dispute.close()
return
logger.warning(
f"Adding extra info for dispute from {user_id} for ${amount/100}"
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
)
# Retrieve recent transaction history to support our evidence.
# This provides a concise timeline that shows service usage and proper credit application.
@@ -1237,14 +1240,23 @@ async def get_stripe_customer_id(user_id: str) -> str:
if user.stripe_customer_id:
return user.stripe_customer_id
customer = stripe.Customer.create(
# Race protection: two concurrent calls (e.g. user double-clicks "Upgrade",
# or any retried request) would each pass the check above and create their
# own Stripe Customer, leaving an orphaned billable customer in Stripe.
# Pass an idempotency_key so Stripe collapses concurrent + retried calls
# into the same Customer object server-side. The 24h Stripe idempotency
# window comfortably covers any realistic in-flight retry scenario.
customer = await run_in_threadpool(
stripe.Customer.create,
name=user.name or "",
email=user.email,
metadata={"user_id": user_id},
idempotency_key=f"customer-create-{user_id}",
)
await User.prisma().update(
where={"id": user_id}, data={"stripeCustomerId": customer.id}
)
get_user_by_id.cache_delete(user_id)
return customer.id
@@ -1263,23 +1275,203 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
data={"subscriptionTier": tier},
)
get_user_by_id.cache_delete(user_id)
# Also invalidate the rate-limit tier cache so CoPilot picks up the new
# tier immediately rather than waiting up to 5 minutes for the TTL to expire.
from backend.copilot.rate_limit import get_user_tier # local import avoids circular
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
async def cancel_stripe_subscription(user_id: str) -> None:
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
customer_id = await get_stripe_customer_id(user_id)
subscriptions = stripe.Subscription.list(
customer=customer_id, status="active", limit=10
)
for sub in subscriptions.auto_paging_iter():
try:
stripe.Subscription.cancel(sub["id"])
except stripe.StripeError:
logger.warning(
"cancel_stripe_subscription: failed to cancel sub %s for user %s",
sub["id"],
user_id,
async def _cancel_customer_subscriptions(
    customer_id: str,
    exclude_sub_id: str | None = None,
    at_period_end: bool = False,
) -> int:
    """Cancel a customer's billable Stripe subscriptions, optionally sparing one.

    Both ``active`` and ``trialing`` statuses are targeted: a trialing sub
    becomes billable once its trial ends, so it must be cleaned up on a
    downgrade/upgrade just like an active one to avoid double-charging.

    With ``at_period_end=True`` each subscription is merely scheduled to cancel
    when its billing period ends — the user keeps the tier they already paid
    for, and the ``customer.subscription.deleted`` webhook performs the FREE
    downgrade later. Otherwise cancellation is immediate.

    Every synchronous Stripe SDK call goes through ``run_in_threadpool`` so the
    event loop never blocks. ``stripe.StripeError`` from list/modify/cancel
    propagates to the caller; best-effort callers catch and log it themselves.

    Returns:
        The number of distinct subscriptions cancelled or scheduled to cancel.
    """
    # Stripe's list endpoint accepts one status filter at a time (no OR), and
    # we deliberately skip canceled/incomplete/past_due subs rather than pull
    # status="all" and filter client-side.
    handled: set[str] = set()
    for sub_status in ("active", "trialing"):
        page = await run_in_threadpool(
            stripe.Subscription.list,
            customer=customer_id,
            status=sub_status,
            limit=10,
        )
        # Only the first page (up to 10) is processed — auto_paging_iter would
        # fire additional *synchronous* HTTP calls inside the event loop.
        if page.has_more:
            logger.error(
                "_cancel_customer_subscriptions: customer %s has more than 10 %s"
                " subscriptions — only the first page was processed; remaining"
                " subscriptions were NOT cancelled",
                customer_id,
                sub_status,
            )
        for subscription in page.data:
            current_id = subscription["id"]
            should_skip = (exclude_sub_id and current_id == exclude_sub_id) or (
                current_id in handled
            )
            if should_skip:
                continue
            handled.add(current_id)
            if at_period_end:
                await run_in_threadpool(
                    stripe.Subscription.modify, current_id, cancel_at_period_end=True
                )
            else:
                await run_in_threadpool(stripe.Subscription.cancel, current_id)
    return len(handled)
async def cancel_stripe_subscription(user_id: str) -> bool:
    """Schedule all active/trialing Stripe subscriptions to cancel at period end.

    Each subscription stays active until its billing period ends, so the user
    keeps the tier they already paid for; the ``customer.subscription.deleted``
    webhook fires at period end and downgrades the DB tier to FREE.

    Returns:
        True when at least one subscription was found and scheduled for
        cancellation. False when the customer had no active/trialing
        subscriptions (e.g. an admin-granted tier with no Stripe record) —
        in that case no webhook will ever fire, so the caller must update the
        DB tier directly.

    Raises:
        stripe.StripeError: if any modification fails, so the caller can avoid
            updating the DB tier while Stripe is in an inconsistent state.
    """
    # Guard: never call get_stripe_customer_id here — for a user who has never
    # had a paid subscription that would create an orphaned, potentially
    # billable Stripe Customer. Returning False lets the caller downgrade the
    # DB tier directly instead.
    record = await get_user_by_id(user_id)
    stripe_customer = record.stripe_customer_id
    if not stripe_customer:
        return False

    try:
        scheduled = await _cancel_customer_subscriptions(
            stripe_customer, at_period_end=True
        )
    except stripe.StripeError:
        logger.warning(
            "cancel_stripe_subscription: Stripe error while cancelling subs for user %s",
            user_id,
        )
        raise
    return scheduled > 0
async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int:
    """Return the prorated credit (in cents) the user would receive if they upgraded now.

    Fetches the user's active Stripe subscription to determine how many seconds
    remain in the current billing period, then returns the unused portion of
    ``monthly_cost_cents``. Best-effort: returns 0 for FREE/ENTERPRISE users,
    when no active subscription is found, or on any Stripe/shape error.

    Args:
        user_id: The user to compute proration for.
        monthly_cost_cents: Full monthly cost of the user's current plan.

    Returns:
        The prorated credit in cents, or 0 when not applicable/computable.
    """
    if monthly_cost_cents <= 0:
        return 0
    # Guard: only query Stripe if the user already has a customer ID. Admin-granted
    # paid tiers have no Stripe record; calling get_stripe_customer_id would create an
    # orphaned customer on every billing-page load for those users.
    user = await get_user_by_id(user_id)
    if not user.stripe_customer_id:
        return 0
    try:
        customer_id = user.stripe_customer_id
        subscriptions = await run_in_threadpool(
            stripe.Subscription.list, customer=customer_id, status="active", limit=1
        )
        if not subscriptions.data:
            return 0
        sub = subscriptions.data[0]
        period_start: int = sub["current_period_start"]
        period_end: int = sub["current_period_end"]
        now = int(time.time())
        total_seconds = period_end - period_start
        remaining_seconds = max(period_end - now, 0)
        if total_seconds <= 0:
            return 0
        return int(monthly_cost_cents * remaining_seconds / total_seconds)
    except Exception:
        # Deliberate best-effort catch-all: proration is a display hint, not a
        # billing decision, so any failure degrades to "no credit shown".
        # FIX: use logger.exception (not warning) so the traceback is captured
        # and transient Stripe errors / unexpected payload shapes stay
        # diagnosable instead of being silently reduced to one log line.
        logger.exception(
            "get_proration_credit_cents: failed to compute proration for user %s",
            user_id,
        )
        return 0
async def modify_stripe_subscription_for_tier(
    user_id: str, tier: SubscriptionTier
) -> bool:
    """Modify an existing Stripe subscription to a new paid tier using proration.

    For paid→paid tier changes (e.g. PRO↔BUSINESS), modifying the existing
    subscription is preferable to cancelling + creating a new one via Checkout:
    Stripe handles proration automatically, crediting unused time on the old plan
    and charging the pro-rated amount for the new plan in the same billing cycle.

    Args:
        user_id: The user whose subscription should be re-priced.
        tier: The target paid tier; must have a Stripe price ID configured.

    Returns:
        True — a subscription was found and modified successfully.
        False — no active/trialing subscription exists (e.g. admin-granted tier or
            first-time paid signup); caller should fall back to Checkout.

    Raises:
        stripe.StripeError: on API failures so callers can propagate a 502.
        ValueError: when no Stripe price ID is configured for the tier.
    """
    price_id = await get_subscription_price_id(tier)
    if not price_id:
        raise ValueError(f"No Stripe price ID configured for tier {tier}")
    # Guard: only proceed if the user already has a Stripe customer ID. Calling
    # get_stripe_customer_id for a user with no Stripe record (e.g. admin-granted tier)
    # would create an orphaned customer object if the subsequent Subscription.list call
    # fails. Return False early so the API layer falls back to Checkout instead.
    user = await get_user_by_id(user_id)
    if not user.stripe_customer_id:
        return False
    customer_id = user.stripe_customer_id

    for status in ("active", "trialing"):
        subscriptions = await run_in_threadpool(
            stripe.Subscription.list, customer=customer_id, status=status, limit=1
        )
        if not subscriptions.data:
            continue
        sub = subscriptions.data[0]
        sub_id = sub["id"]
        items = sub.get("items", {}).get("data", [])
        if not items:
            # Defensive: a subscription with no line items cannot be re-priced.
            continue
        item_id = items[0]["id"]
        await run_in_threadpool(
            stripe.Subscription.modify,
            sub_id,
            items=[{"id": item_id, "price": price_id}],
            proration_behavior="create_prorations",
        )
        # BUG FIX: the original format string ended in "for user %s%s", which
        # glued user_id and the tier enum together in the log output
        # ("user abc123SubscriptionTier.PRO"). Use an explicit "to tier %s"
        # and tier.value, consistent with the other log lines in this module.
        logger.info(
            "modify_stripe_subscription_for_tier: modified sub %s for user %s to tier %s",
            sub_id,
            user_id,
            tier.value,
        )
        return True
    return False
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
@@ -1291,8 +1483,19 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
return AutoTopUpConfig.model_validate(user.top_up_config)
@cached(ttl_seconds=60, maxsize=8, cache_none=False)
async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
"""Return Stripe Price ID for a tier from LaunchDarkly. None = not configured."""
"""Return Stripe Price ID for a tier from LaunchDarkly, cached for 60 seconds.
Price IDs are LaunchDarkly flag values that change only at deploy time.
Caching for 60 seconds avoids hitting the LD SDK on every webhook delivery
and every GET /credits/subscription page load (called 2x per request).
``cache_none=False`` prevents a transient LD failure from caching ``None``
and blocking subscription upgrades for the full 60-second TTL window.
A tier with no configured flag (FREE, ENTERPRISE) returns ``None`` from an
O(1) dict lookup before hitting LD, so the extra LD call is never made.
"""
flag_map = {
SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO,
SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS,
@@ -1300,7 +1503,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
flag = flag_map.get(tier)
if flag is None:
return None
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
return price_id if isinstance(price_id, str) and price_id else None
@@ -1315,7 +1518,8 @@ async def create_subscription_checkout(
if not price_id:
raise ValueError(f"Subscription not available for tier {tier.value}")
customer_id = await get_stripe_customer_id(user_id)
session = stripe.checkout.Session.create(
session = await run_in_threadpool(
stripe.checkout.Session.create,
customer=customer_id,
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
@@ -1323,26 +1527,111 @@ async def create_subscription_checkout(
cancel_url=cancel_url,
subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
)
return session.url or ""
if not session.url:
# An empty checkout URL for a paid upgrade is always an error; surfacing it
# as ValueError means the API handler returns 422 instead of silently
# redirecting the client to an empty URL.
raise ValueError("Stripe did not return a checkout session URL")
return session.url
async def _cleanup_stale_subscriptions(customer_id: str, new_sub_id: str) -> None:
    """Best-effort: cancel every active sub for the customer except ``new_sub_id``.

    Invoked from the webhook handler after a new subscription becomes active.
    Stripe errors are swallowed (logged, never raised) so a transient failure
    cannot crash webhook processing — a periodic reconciliation job is the
    intended backstop for persistent drift.

    NOTE: until that reconcile job exists, a swallowed failure here means the
    user is silently billed for two simultaneous subscriptions. The handler
    below therefore logs via ``logger.exception`` so the event surfaces in
    Sentry with the customer/sub IDs needed for manual reconciliation, and the
    ``stripe_stale_subscription_cleanup_failed`` marker lets on-call alert on
    recurring drift.

    TODO(#stripe-reconcile-job): replace this best-effort cleanup with a
    periodic job that scans Stripe for customers holding >1 active sub.
    """
    try:
        await _cancel_customer_subscriptions(customer_id, exclude_sub_id=new_sub_id)
    except stripe.StripeError:
        # Error-level on purpose: any failure here can leave a paid-to-paid
        # upgrade with two simultaneously active subscriptions.
        logger.exception(
            "stripe_stale_subscription_cleanup_failed: customer=%s new_sub=%s"
            " user may be billed for two simultaneous subscriptions; manual"
            " reconciliation required",
            customer_id,
            new_sub_id,
        )
async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
"""Update User.subscriptionTier from a Stripe subscription object."""
customer_id = stripe_subscription["customer"]
"""Update User.subscriptionTier from a Stripe subscription object.
Expected shape of stripe_subscription (subset of Stripe's Subscription object):
customer: str — Stripe customer ID
status: str — "active" | "trialing" | "canceled" | ...
id: str — Stripe subscription ID
items.data[].price.id: str — Stripe price ID identifying the tier
"""
customer_id = stripe_subscription.get("customer")
if not customer_id:
logger.warning(
"sync_subscription_from_stripe: missing 'customer' field in event, "
"skipping (keys: %s)",
list(stripe_subscription.keys()),
)
return
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"sync_subscription_from_stripe: no user for customer %s", customer_id
)
return
# Cross-check: if the subscription carries a metadata.user_id (set during
# Checkout Session creation), verify it matches the user we found via
# stripeCustomerId. A mismatch indicates a customer↔user mapping
# inconsistency — updating the wrong user's tier would be a data-corruption
# bug, so we log loudly and bail out. Absence of metadata.user_id (e.g.
# subscriptions created outside the Checkout flow) is not an error — we
# simply skip the check and proceed with the customer-ID-based lookup.
metadata = stripe_subscription.get("metadata") or {}
metadata_user_id = metadata.get("user_id") if isinstance(metadata, dict) else None
if metadata_user_id and metadata_user_id != user.id:
logger.error(
"sync_subscription_from_stripe: metadata.user_id=%s does not match"
" user.id=%s found via stripeCustomerId=%s — refusing to update tier"
" to avoid corrupting the wrong user's subscription state",
metadata_user_id,
user.id,
customer_id,
)
return
# ENTERPRISE tiers are admin-managed. Never let a Stripe webhook flip an
# ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has
# a self-service Stripe sub, it's a data-consistency issue for an operator,
# not something the webhook should automatically "fix".
current_tier = user.subscriptionTier or SubscriptionTier.FREE
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier"
" for user %s (customer %s); event status=%s",
user.id,
customer_id,
stripe_subscription.get("status", ""),
)
return
status = stripe_subscription.get("status", "")
new_sub_id = stripe_subscription.get("id", "")
if status in ("active", "trialing"):
price_id = ""
items = stripe_subscription.get("items", {}).get("data", [])
if items:
price_id = items[0].get("price", {}).get("id", "")
pro_price = await get_subscription_price_id(SubscriptionTier.PRO)
biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS)
pro_price, biz_price = await asyncio.gather(
get_subscription_price_id(SubscriptionTier.PRO),
get_subscription_price_id(SubscriptionTier.BUSINESS),
)
if price_id and pro_price and price_id == pro_price:
tier = SubscriptionTier.PRO
elif price_id and biz_price and price_id == biz_price:
@@ -1359,10 +1648,206 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
)
return
else:
# A subscription was cancelled or ended. DO NOT unconditionally downgrade
# to FREE — Stripe does not guarantee webhook delivery order, so a
# `customer.subscription.deleted` for the OLD sub can arrive after we've
# already processed `customer.subscription.created` for a new paid sub.
# Ask Stripe whether any OTHER active/trialing subs exist for this
# customer; if they do, keep the user's current tier (the other sub's
# own event will/has already set the correct tier).
try:
other_subs_active, other_subs_trialing = await asyncio.gather(
run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status="active",
limit=10,
),
run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status="trialing",
limit=10,
),
)
except stripe.StripeError:
logger.warning(
"sync_subscription_from_stripe: could not verify other active"
" subs for customer %s on cancel event %s; preserving current"
" tier to avoid an unsafe downgrade",
customer_id,
new_sub_id,
)
return
# Filter out the cancelled subscription to check if other active subs
# exist. When new_sub_id is empty (malformed event with no 'id' field),
# we cannot safely exclude any sub — preserve current tier to avoid
# an unsafe downgrade on a malformed webhook payload.
if not new_sub_id:
logger.warning(
"sync_subscription_from_stripe: cancel event missing 'id' field"
" for customer %s; preserving current tier",
customer_id,
)
return
other_active_ids = {sub["id"] for sub in other_subs_active.data} - {new_sub_id}
other_trialing_ids = {sub["id"] for sub in other_subs_trialing.data} - {
new_sub_id
}
still_has_active_sub = bool(other_active_ids or other_trialing_ids)
if still_has_active_sub:
logger.info(
"sync_subscription_from_stripe: sub %s cancelled but customer %s"
" still has another active sub; keeping tier %s",
new_sub_id,
customer_id,
current_tier.value,
)
return
tier = SubscriptionTier.FREE
# Idempotency: Stripe retries webhooks on delivery failure, and several event
# types map to the same final tier. Skip the DB write + cache invalidation
# when the tier is already correct to avoid redundant writes on replay.
if current_tier == tier:
return
# When a new subscription becomes active (e.g. paid-to-paid tier upgrade
# via a fresh Checkout Session), cancel any OTHER active subscriptions for
# the same customer so the user isn't billed twice. We do this in the
# webhook rather than the API handler so that abandoning the checkout
# doesn't leave the user without a subscription.
# IMPORTANT: this runs AFTER the idempotency check above so that webhook
# replays for an already-applied event do NOT trigger another cleanup round
# (which could otherwise cancel a legitimately new subscription the user
# signed up for between the original event and its replay).
if status in ("active", "trialing") and new_sub_id:
# NOTE: paid-to-paid upgrade race (e.g. PRO → BUSINESS):
# _cleanup_stale_subscriptions cancels the old PRO sub before
# set_subscription_tier writes BUSINESS to the DB. If Stripe delivers
# the PRO `customer.subscription.deleted` event concurrently and it
# processes after the PRO cancel but before set_subscription_tier
# commits, the user could momentarily appear as FREE in the DB.
# This window is very short in practice (two sequential awaits),
# but is a known limitation of the current webhook-driven approach.
# A future improvement would be to write the new tier first, then
# cancel the old sub.
await _cleanup_stale_subscriptions(customer_id, new_sub_id)
await set_subscription_tier(user.id, tier)
async def handle_subscription_payment_failure(invoice: dict) -> None:
    """Handle a failed Stripe subscription payment.

    Tries to cover the invoice amount from the user's credit balance.
    - Balance sufficient → deduct from balance, then pay the Stripe invoice so
      Stripe stops retrying it. The sub stays intact and the user keeps their tier.
    - Balance insufficient → cancel Stripe sub immediately, downgrade to FREE.
      Cancelling here avoids further Stripe retries on an invoice we cannot cover.

    Args:
        invoice: Raw Stripe invoice event payload (dict). Only ``customer``,
            ``amount_due``, ``subscription`` and ``id`` are read here.
    """
    customer_id = invoice.get("customer")
    if not customer_id:
        # Without a customer id we cannot resolve a user; nothing to do.
        logger.warning(
            "handle_subscription_payment_failure: missing customer in invoice; skipping"
        )
        return
    user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
    if not user:
        logger.warning(
            "handle_subscription_payment_failure: no user found for customer %s",
            customer_id,
        )
        return
    current_tier = user.subscriptionTier or SubscriptionTier.FREE
    # ENTERPRISE is assigned by admins, so an automated downgrade here would
    # silently undo an admin decision — bail out early.
    if current_tier == SubscriptionTier.ENTERPRISE:
        logger.warning(
            "handle_subscription_payment_failure: skipping ENTERPRISE user %s"
            " (customer %s) — tier is admin-managed",
            user.id,
            customer_id,
        )
        return
    amount_due: int = invoice.get("amount_due", 0)  # in cents (per log below)
    sub_id: str = invoice.get("subscription", "")
    invoice_id: str = invoice.get("id", "")
    if amount_due <= 0:
        logger.info(
            "handle_subscription_payment_failure: amount_due=%d for user %s;"
            " nothing to deduct",
            amount_due,
            user.id,
        )
        return
    credit_model = UserCredit()
    try:
        # Attempt to cover the invoice from the internal credit balance.
        # Raises InsufficientBalanceError when the balance cannot cover it.
        await credit_model._add_transaction(
            user_id=user.id,
            amount=-amount_due,
            transaction_type=CreditTransactionType.SUBSCRIPTION,
            fail_insufficient_credits=True,
            # Use invoice_id as the idempotency key so that Stripe webhook retries
            # (e.g. on a transient stripe.Invoice.pay failure) do not double-charge.
            transaction_key=invoice_id or None,
            metadata=SafeJson(
                {
                    "stripe_customer_id": customer_id,
                    "stripe_subscription_id": sub_id,
                    "reason": "subscription_payment_failure_covered_by_balance",
                }
            ),
        )
        # Balance covered the invoice. Pay the Stripe invoice so Stripe's dunning
        # system stops retrying it — without this call Stripe would retry automatically
        # and re-trigger this webhook, causing double-deductions each retry cycle.
        if invoice_id:
            try:
                # stripe.Invoice.pay is a blocking SDK call; run off the event loop.
                await run_in_threadpool(stripe.Invoice.pay, invoice_id)
            except stripe.StripeError:
                # Deduction already happened; the transaction_key above makes a
                # retried webhook idempotent, so we only warn here.
                logger.warning(
                    "handle_subscription_payment_failure: balance deducted for user"
                    " %s but failed to mark invoice %s as paid; Stripe may retry",
                    user.id,
                    invoice_id,
                )
        logger.info(
            "handle_subscription_payment_failure: deducted %d cents from balance"
            " for user %s; Stripe invoice %s paid, sub %s intact, tier preserved",
            amount_due,
            user.id,
            invoice_id,
            sub_id,
        )
    except InsufficientBalanceError:
        # Balance insufficient — cancel Stripe subscription first, then downgrade DB.
        # Order matters: if we downgrade the DB first and the Stripe cancel fails, the
        # user is permanently stuck on FREE while Stripe continues billing them.
        # Cancelling Stripe first is safe: if the DB write then fails, the webhook
        # customer.subscription.deleted will fire and correct the tier eventually.
        logger.info(
            "handle_subscription_payment_failure: insufficient balance for user %s;"
            " cancelling Stripe sub %s then downgrading to FREE",
            user.id,
            sub_id,
        )
        try:
            await _cancel_customer_subscriptions(customer_id)
        except stripe.StripeError:
            logger.warning(
                "handle_subscription_payment_failure: failed to cancel Stripe sub %s"
                " for user %s (customer %s); skipping tier downgrade to avoid"
                " inconsistency — Stripe may continue retrying the invoice",
                sub_id,
                user.id,
                customer_id,
            )
            return
        await set_subscription_tier(user.id, SubscriptionTier.FREE)
async def admin_get_user_history(
page: int = 1,
page_size: int = 20,

View File

@@ -73,6 +73,31 @@ def _get_redis() -> Redis:
return r
class _MissingType:
    """Sentinel class whose single instance marks a cache miss.

    Distinct from ``None``, which is a perfectly valid cached value. A
    dedicated class (rather than a bare ``Any = object()``) lets mypy narrow
    ``result is _MISSING`` comparisons correctly and stops the sentinel from
    being used accidentally where a real value is expected.
    """

    _instance: "_MissingType | None" = None

    def __new__(cls) -> "_MissingType":
        # Lazily create the one shared instance; every later call returns it.
        existing = cls._instance
        if existing is not None:
            return existing
        created = super().__new__(cls)
        cls._instance = created
        return created

    def __repr__(self) -> str:
        return "<MISSING>"


# Sentinel returned by ``_get_from_memory`` / ``_get_from_redis`` to mean
# "no entry exists" — distinct from a cached ``None`` value, which is a
# valid result for callers that opt into caching it.
_MISSING = _MissingType()
@dataclass
class CachedValue:
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
@@ -160,6 +185,7 @@ def cached(
ttl_seconds: int,
shared_cache: bool = False,
refresh_ttl_on_get: bool = False,
cache_none: bool = True,
) -> Callable[[Callable[P, R]], CachedFunction[P, R]]:
"""
Thundering herd safe cache decorator for both sync and async functions.
@@ -172,6 +198,10 @@ def cached(
ttl_seconds: Time to live in seconds. Required - entries must expire.
shared_cache: If True, use Redis for cross-process caching
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
cache_none: If True (default) ``None`` is cached like any other value.
Set to ``False`` for functions that return ``None`` to signal a
transient error and should be re-tried on the next call without
poisoning the cache (e.g. external API calls that may fail).
Returns:
Decorated function with caching capabilities
@@ -184,6 +214,12 @@ def cached(
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
@cached(ttl_seconds=300, cache_none=False)
async def fetch_external(id: str) -> dict | None:
# Returns None on transient error — won't be stored,
# next call retries instead of returning the stale None.
...
"""
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
@@ -191,9 +227,14 @@ def cached(
cache_storage: dict[tuple, CachedValue] = {}
_event_loop_locks: dict[Any, asyncio.Lock] = {}
def _get_from_redis(redis_key: str) -> Any | None:
def _get_from_redis(redis_key: str) -> Any:
"""Get value from Redis, optionally refreshing TTL.
Returns the cached value (which may be ``None``) on a hit, or the
module-level ``_MISSING`` sentinel on a miss / corrupt entry.
Callers must compare with ``is _MISSING`` so cached ``None`` values
are not mistaken for misses.
Values are expected to carry an HMAC-SHA256 prefix for integrity
verification. Unsigned (legacy) or tampered entries are silently
discarded and treated as cache misses, so the caller recomputes and
@@ -213,11 +254,11 @@ def cached(
f"for {func_name}, discarding entry: "
"possible tampering or legacy unsigned value"
)
return None
return _MISSING
return pickle.loads(payload)
except Exception as e:
logger.error(f"Redis error during cache check for {func_name}: {e}")
return None
return _MISSING
def _set_to_redis(redis_key: str, value: Any) -> None:
"""Set HMAC-signed pickled value in Redis with TTL."""
@@ -227,8 +268,13 @@ def cached(
except Exception as e:
logger.error(f"Redis error storing cache for {func_name}: {e}")
def _get_from_memory(key: tuple) -> Any | None:
"""Get value from in-memory cache, checking TTL."""
def _get_from_memory(key: tuple) -> Any:
"""Get value from in-memory cache, checking TTL.
Returns the cached value (which may be ``None``) on a hit, or the
``_MISSING`` sentinel on a miss / TTL expiry. See
``_get_from_redis`` for the rationale.
"""
if key in cache_storage:
cached_data = cache_storage[key]
if time.time() - cached_data.timestamp < ttl_seconds:
@@ -236,7 +282,7 @@ def cached(
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
)
return cached_data.result
return None
return _MISSING
def _set_to_memory(key: tuple, value: Any) -> None:
"""Set value in in-memory cache with timestamp."""
@@ -270,11 +316,11 @@ def cached(
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Slow path: acquire lock for cache miss/expiry
@@ -282,22 +328,24 @@ def cached(
# Double-check: another coroutine might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {func_name}")
result = await target_func(*args, **kwargs)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
# Store result (skip ``None`` if the caller opted out of
# caching it — used for transient-error sentinels).
if cache_none or result is not None:
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result
@@ -315,11 +363,11 @@ def cached(
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Slow path: acquire lock for cache miss/expiry
@@ -327,22 +375,24 @@ def cached(
# Double-check: another thread might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {func_name}")
result = target_func(*args, **kwargs)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
# Store result (skip ``None`` if the caller opted out of
# caching it — used for transient-error sentinels).
if cache_none or result is not None:
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result

View File

@@ -1223,3 +1223,123 @@ class TestCacheHMAC:
assert call_count == 2
legacy_test_fn.cache_clear()
class TestCacheNoneHandling:
    """Tests for the ``cache_none`` parameter on the @cached decorator.

    Sentry bug PRRT_kwDOJKSTjM56RTEu (HIGH): the cache previously could not
    distinguish "no entry" from "entry is None", so any function returning
    ``None`` was effectively re-executed on every call. The fix is a
    sentinel-based check inside the wrappers, plus an opt-out
    ``cache_none=False`` flag for callers that *want* errors to retry.
    """

    @pytest.mark.asyncio
    async def test_async_none_is_cached_by_default(self):
        """With ``cache_none=True`` (default), cached ``None`` is returned
        from the cache instead of triggering re-execution."""
        call_count = 0

        @cached(ttl_seconds=300)
        async def maybe_none(x: int) -> int | None:
            nonlocal call_count
            call_count += 1
            return None

        assert await maybe_none(1) is None
        assert call_count == 1
        # Second call should hit the cache, not re-execute.
        assert await maybe_none(1) is None
        assert call_count == 1
        # Different argument is a different cache key — re-executes.
        assert await maybe_none(2) is None
        assert call_count == 2

    def test_sync_none_is_cached_by_default(self):
        """Sync wrapper: cached ``None`` is also served without re-execution."""
        call_count = 0

        @cached(ttl_seconds=300)
        def maybe_none(x: int) -> int | None:
            nonlocal call_count
            call_count += 1
            return None

        assert maybe_none(1) is None
        assert maybe_none(1) is None
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_async_cache_none_false_skips_storing_none(self):
        """``cache_none=False`` skips storing ``None`` so transient errors
        are retried on the next call instead of poisoning the cache."""
        call_count = 0
        # Scripted return values: two transient Nones, then a real value.
        results: list[int | None] = [None, None, 42]

        @cached(ttl_seconds=300, cache_none=False)
        async def maybe_none(x: int) -> int | None:
            nonlocal call_count
            result = results[call_count]
            call_count += 1
            return result

        # First call: returns None, NOT stored.
        assert await maybe_none(1) is None
        assert call_count == 1
        # Second call with same key: re-executes (None wasn't cached).
        assert await maybe_none(1) is None
        assert call_count == 2
        # Third call: returns 42, this time it IS stored.
        assert await maybe_none(1) == 42
        assert call_count == 3
        # Fourth call: cache hit on the stored 42.
        assert await maybe_none(1) == 42
        assert call_count == 3

    def test_sync_cache_none_false_skips_storing_none(self):
        """Sync wrapper honours ``cache_none=False`` the same way."""
        call_count = 0
        results: list[int | None] = [None, 99]

        @cached(ttl_seconds=300, cache_none=False)
        def maybe_none(x: int) -> int | None:
            nonlocal call_count
            result = results[call_count]
            call_count += 1
            return result

        assert maybe_none(1) is None
        assert call_count == 1
        # None was not stored — re-executes.
        assert maybe_none(1) == 99
        assert call_count == 2
        # 99 IS stored — no re-execution.
        assert maybe_none(1) == 99
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_async_shared_cache_none_is_cached_by_default(self):
        """Shared (Redis) cache also properly returns cached ``None`` values."""
        call_count = 0

        @cached(ttl_seconds=30, shared_cache=True)
        async def maybe_none_redis(x: int) -> int | None:
            nonlocal call_count
            call_count += 1
            return None

        # Clear first so a previous test run's Redis entry can't mask a miss.
        maybe_none_redis.cache_clear()
        assert await maybe_none_redis(1) is None
        assert call_count == 1
        assert await maybe_none_redis(1) is None
        assert call_count == 1
        maybe_none_redis.cache_clear()

View File

@@ -1,6 +1,7 @@
import contextlib
import logging
import os
import uuid
from enum import Enum
from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar
@@ -101,6 +102,12 @@ async def _fetch_user_context_data(user_id: str) -> Context:
"""
builder = Context.builder(user_id).kind("user").anonymous(True)
try:
uuid.UUID(user_id)
except ValueError:
# Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context.
return builder.build()
try:
from backend.util.clients import get_supabase

View File

@@ -88,17 +88,19 @@ async def cmd_download(session_ids: list[str]) -> None:
print(f"[{sid[:12]}] Not found in GCS")
continue
content_str = (
dl.content.decode("utf-8") if isinstance(dl.content, bytes) else dl.content
)
out = _transcript_path(sid)
with open(out, "w") as f:
f.write(dl.content)
f.write(content_str)
lines = len(dl.content.strip().split("\n"))
lines = len(content_str.strip().split("\n"))
meta = {
"session_id": sid,
"user_id": user_id,
"message_count": dl.message_count,
"uploaded_at": dl.uploaded_at,
"transcript_bytes": len(dl.content),
"transcript_bytes": len(content_str),
"transcript_lines": lines,
}
with open(_meta_path(sid), "w") as f:
@@ -106,7 +108,7 @@ async def cmd_download(session_ids: list[str]) -> None:
print(
f"[{sid[:12]}] Saved: {lines} entries, "
f"{len(dl.content)} bytes, msg_count={dl.message_count}"
f"{len(content_str)} bytes, msg_count={dl.message_count}"
)
print("\nDone. Run 'load' command to import into local dev environment.")
@@ -227,7 +229,7 @@ async def cmd_load(session_ids: list[str]) -> None:
await upload_transcript(
user_id=user_id,
session_id=sid,
content=content,
content=content.encode("utf-8"),
message_count=msg_count,
)
print(f"[{sid[:12]}] Stored transcript in local workspace storage")

View File

@@ -0,0 +1,140 @@
"""Unit tests for the transcript watermark (message_count) fix.
The bug: upload used message_count=len(session.messages) (DB count). When a
prior turn's GCS upload failed silently, the JSONL on GCS was stale (e.g.
covered only T1-T12) but the meta.json watermark matched the full DB count
(e.g. 46). The next turn's gap-fill check (transcript_msg_count < msg_count-1)
never triggered, so the model silently lost context for the skipped turns.
The fix: watermark = previous_coverage + 2 (current user+asst pair) when
use_resume=True and transcript_msg_count > 0. This ensures the watermark
reflects the JSONL content, not the DB count.
These tests exercise _build_query_message directly to verify that gap-fill
triggers with the corrected watermark but NOT with the inflated (buggy) one.
"""
from unittest.mock import MagicMock
import pytest
from backend.copilot.sdk.service import _build_query_message
def _make_messages(n_pairs: int, *, current_user: str = "current") -> list[MagicMock]:
    """Build a flat list of n_pairs*2 alternating user/asst messages, plus
    one trailing user message for the *current* turn."""

    def _mock_msg(role: str, content: str) -> MagicMock:
        # One fake chat message with just the attributes the code under test reads.
        msg = MagicMock()
        msg.role = role
        msg.content = content
        return msg

    history: list[MagicMock] = []
    for idx in range(n_pairs):
        history.append(_mock_msg("user", f"user message {idx}"))
        history.append(_mock_msg("assistant", f"assistant response {idx}"))
    # Current turn's user message
    history.append(_mock_msg("user", current_user))
    return history
def _make_session(messages: list[MagicMock]) -> MagicMock:
    """Wrap a message list in a fake session exposing only ``.messages``."""
    fake_session = MagicMock()
    fake_session.messages = messages
    return fake_session
@pytest.mark.asyncio
async def test_gap_fill_triggers_for_stale_jsonl():
    """Scenario: T1-T12 in JSONL (watermark=24), DB has T1-T22+Test (46 msgs).

    With the FIX: 'Test' uploaded watermark=26 (T12's 24 + 2 for 'Test').
    Next turn (T24) downloads watermark=26, DB has 47.
    Gap check: 26 < 47-1=46 → TRUE → gap fills T14-T23.
    """
    # T23 turns in DB (46 messages) + T24 user = 47
    msgs = _make_messages(23, current_user="memory test - recall all")
    assert len(msgs) == 47
    session = _make_session(msgs)
    # Watermark as uploaded by the FIX: T12 covered 24, 'Test' +2 = 26
    result_msg, _ = await _build_query_message(
        current_message="memory test - recall all",
        session=session,
        use_resume=True,
        transcript_msg_count=26,
        session_id="test-session-id",
    )
    assert "<conversation_history>" in result_msg, (
        "Expected gap-fill to inject <conversation_history> when "
        "watermark=26 < msg_count-1=46"
    )
@pytest.mark.asyncio
async def test_no_gap_fill_when_watermark_is_current():
    """When the JSONL is fully current (watermark = DB-1), no gap injected."""
    # T23 turns in DB (46 messages) + T24 user = 47
    msgs = _make_messages(23, current_user="next message")
    session = _make_session(msgs)
    result_msg, _ = await _build_query_message(
        current_message="next message",
        session=session,
        use_resume=True,
        transcript_msg_count=46,  # current — no gap
        session_id="test-session-id",
    )
    assert (
        "<conversation_history>" not in result_msg
    ), "No gap-fill expected when watermark is current"
    # With nothing to inject, the query message is just the raw user message.
    assert result_msg == "next message"
@pytest.mark.asyncio
async def test_inflated_watermark_suppresses_gap_fill():
    """Documents the original bug: inflated watermark suppresses gap-fill.

    'Test' uploaded watermark=len(session.messages)=46 even though only 26
    messages are in the JSONL. Next turn: 46 < 47-1=46 → FALSE → no gap fill.
    """
    msgs = _make_messages(23, current_user="memory test")
    session = _make_session(msgs)
    # Buggy watermark: inflated to DB count
    result_msg, _ = await _build_query_message(
        current_message="memory test",
        session=session,
        use_resume=True,
        transcript_msg_count=46,  # inflated — suppresses gap fill
        session_id="test-session-id",
    )
    assert (
        "<conversation_history>" not in result_msg
    ), "With inflated watermark, gap-fill is suppressed — this documents the bug"
@pytest.mark.asyncio
async def test_fixed_watermark_fills_same_gap():
    """Same scenario but with the FIXED watermark triggers gap-fill."""
    msgs = _make_messages(23, current_user="memory test")
    session = _make_session(msgs)
    result_msg, _ = await _build_query_message(
        current_message="memory test",
        session=session,
        use_resume=True,
        transcript_msg_count=26,  # fixed watermark
        session_id="test-session-id",
    )
    assert (
        "<conversation_history>" in result_msg
    ), "With fixed watermark=26, gap-fill triggers and injects missing turns"

View File

@@ -155,6 +155,7 @@
"@types/twemoji": "13.1.2",
"@vitejs/plugin-react": "5.1.2",
"@vitest/coverage-v8": "4.0.17",
"agentation": "3.0.2",
"axe-playwright": "2.2.2",
"chromatic": "13.3.3",
"concurrently": "9.2.1",

View File

@@ -376,6 +376,9 @@ importers:
'@vitest/coverage-v8':
specifier: 4.0.17
version: 4.0.17(vitest@4.0.17(@opentelemetry/api@1.9.0)(@types/node@24.10.0)(happy-dom@20.3.4)(jiti@2.6.1)(jsdom@27.4.0)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(terser@5.44.1)(yaml@2.8.2))
agentation:
specifier: 3.0.2
version: 3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
axe-playwright:
specifier: 2.2.2
version: 2.2.2(playwright@1.56.1)
@@ -4119,6 +4122,17 @@ packages:
resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==}
engines: {node: '>= 14'}
agentation@3.0.2:
resolution: {integrity: sha512-iGzBxFVTuZEIKzLY6AExSLAQH6i6SwxV4pAu7v7m3X6bInZ7qlZXAwrEqyc4+EfP4gM7z2RXBF6SF4DeH0f2lA==}
peerDependencies:
react: '>=18.0.0'
react-dom: '>=18.0.0'
peerDependenciesMeta:
react:
optional: true
react-dom:
optional: true
ai@6.0.134:
resolution: {integrity: sha512-YalNEaavld/kE444gOcsMKXdVVRGEe0SK77fAFcWYcqLg+a7xKnEet8bdfrEAJTfnMjj01rhgrIL10903w1a5Q==}
engines: {node: '>=18'}
@@ -13119,6 +13133,11 @@ snapshots:
agent-base@7.1.4:
optional: true
agentation@3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
optionalDependencies:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
ai@6.0.134(zod@3.25.76):
dependencies:
'@ai-sdk/gateway': 3.0.77(zod@3.25.76)

View File

@@ -110,7 +110,7 @@ export const Flow = () => {
event.preventDefault();
}}
maxZoom={2}
minZoom={0.1}
minZoom={0.05}
onDragOver={onDragOver}
onDrop={onDrop}
nodesDraggable={!isLocked}

View File

@@ -93,6 +93,7 @@ export function CopilotPage() {
hasMoreMessages,
isLoadingMore,
loadMore,
forwardPaginated,
// Mobile drawer
isMobile,
isDrawerOpen,
@@ -217,6 +218,7 @@ export function CopilotPage() {
hasMoreMessages={hasMoreMessages}
isLoadingMore={isLoadingMore}
onLoadMore={loadMore}
forwardPaginated={forwardPaginated}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
historicalDurations={historicalDurations}

View File

@@ -0,0 +1,122 @@
import { renderHook } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useChatSession } from "../useChatSession";
const mockUseGetV2GetSession = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useGetV2GetSession: (...args: unknown[]) => mockUseGetV2GetSession(...args),
usePostV2CreateSession: () => ({ mutateAsync: vi.fn(), isPending: false }),
getGetV2GetSessionQueryKey: (id: string) => ["session", id],
getGetV2ListSessionsQueryKey: () => ["sessions"],
}));
vi.mock("@tanstack/react-query", () => ({
useQueryClient: () => ({
invalidateQueries: vi.fn(),
setQueryData: vi.fn(),
}),
}));
vi.mock("nuqs", () => ({
parseAsString: { withDefault: (v: unknown) => v },
useQueryState: () => ["sess-1", vi.fn()],
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
convertChatSessionMessagesToUiMessages: vi.fn(() => ({
messages: [],
historicalDurations: new Map(),
})),
}));
vi.mock("../helpers", () => ({
resolveSessionDryRun: vi.fn(() => false),
}));
vi.mock("@sentry/nextjs", () => ({
captureException: vi.fn(),
}));
// Shapes a fake react-query result the way useGetV2GetSession exposes it:
// on "success" the payload is wrapped as { status: 200, data }; passing null
// simulates "no data yet" (data stays undefined).
function makeQueryResult(data: object | null) {
  return {
    data: data ? { status: 200, data } : undefined,
    isLoading: false,
    isError: false,
    isFetching: false,
    refetch: vi.fn(),
  };
}
// Verifies that useChatSession surfaces the new pagination fields
// (newest_sequence / forward_paginated) from the session API response.
describe("useChatSession — newestSequence and forwardPaginated", () => {
  beforeEach(() => {
    vi.clearAllMocks();
  });

  it("returns null / false when no session data", () => {
    mockUseGetV2GetSession.mockReturnValue(makeQueryResult(null));
    const { result } = renderHook(() => useChatSession());
    expect(result.current.newestSequence).toBeNull();
    expect(result.current.forwardPaginated).toBe(false);
  });

  it("returns newestSequence from session data", () => {
    mockUseGetV2GetSession.mockReturnValue(
      makeQueryResult({
        messages: [],
        has_more_messages: true,
        oldest_sequence: 0,
        newest_sequence: 99,
        forward_paginated: false,
        active_stream: null,
      }),
    );
    const { result } = renderHook(() => useChatSession());
    expect(result.current.newestSequence).toBe(99);
  });

  it("returns null for newestSequence when field is missing", () => {
    mockUseGetV2GetSession.mockReturnValue(
      makeQueryResult({
        messages: [],
        has_more_messages: false,
        oldest_sequence: 0,
        newest_sequence: null,
        forward_paginated: false,
        active_stream: null,
      }),
    );
    const { result } = renderHook(() => useChatSession());
    expect(result.current.newestSequence).toBeNull();
  });

  it("returns forwardPaginated=true when session is forward-paginated", () => {
    mockUseGetV2GetSession.mockReturnValue(
      makeQueryResult({
        messages: [],
        has_more_messages: true,
        oldest_sequence: 0,
        newest_sequence: 49,
        forward_paginated: true,
        active_stream: null,
      }),
    );
    const { result } = renderHook(() => useChatSession());
    expect(result.current.forwardPaginated).toBe(true);
  });

  it("returns forwardPaginated=false when session is backward-paginated", () => {
    mockUseGetV2GetSession.mockReturnValue(
      makeQueryResult({
        messages: [],
        has_more_messages: true,
        oldest_sequence: 50,
        newest_sequence: 99,
        forward_paginated: false,
        active_stream: null,
      }),
    );
    const { result } = renderHook(() => useChatSession());
    expect(result.current.forwardPaginated).toBe(false);
  });
});

View File

@@ -0,0 +1,202 @@
import { act, renderHook, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useCopilotPage } from "../useCopilotPage";
const mockUseChatSession = vi.fn();
const mockUseCopilotStream = vi.fn();
const mockUseLoadMoreMessages = vi.fn();
vi.mock("../useChatSession", () => ({
useChatSession: (...args: unknown[]) => mockUseChatSession(...args),
}));
vi.mock("../useCopilotStream", () => ({
useCopilotStream: (...args: unknown[]) => mockUseCopilotStream(...args),
}));
vi.mock("../useLoadMoreMessages", () => ({
useLoadMoreMessages: (...args: unknown[]) => mockUseLoadMoreMessages(...args),
}));
vi.mock("../useCopilotNotifications", () => ({
useCopilotNotifications: () => undefined,
}));
vi.mock("../useWorkflowImportAutoSubmit", () => ({
useWorkflowImportAutoSubmit: () => undefined,
}));
vi.mock("../store", () => ({
useCopilotUIStore: () => ({
sessionToDelete: null,
setSessionToDelete: vi.fn(),
isDrawerOpen: false,
setDrawerOpen: vi.fn(),
copilotChatMode: "chat",
copilotLlmModel: null,
isDryRun: false,
}),
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
concatWithAssistantMerge: (a: unknown[], b: unknown[]) => [...a, ...b],
}));
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useDeleteV2DeleteSession: () => ({ mutate: vi.fn(), isPending: false }),
useGetV2ListSessions: () => ({ data: undefined, isLoading: false }),
getGetV2ListSessionsQueryKey: () => ["sessions"],
}));
vi.mock("@/components/molecules/Toast/use-toast", () => ({
toast: vi.fn(),
}));
vi.mock("@/lib/direct-upload", () => ({
uploadFileDirect: vi.fn(),
}));
vi.mock("@/lib/hooks/useBreakpoint", () => ({
useBreakpoint: () => "lg",
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: () => ({ isUserLoading: false, isLoggedIn: true }),
}));
vi.mock("@tanstack/react-query", () => ({
useQueryClient: () => ({ invalidateQueries: vi.fn() }),
}));
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" },
useGetFlag: () => false,
}));
// Default return shape of the mocked useChatSession hook; individual tests
// override only the fields under test (e.g. forwardPaginated).
function makeBaseChatSession(overrides: Record<string, unknown> = {}) {
  return {
    sessionId: "sess-1",
    setSessionId: vi.fn(),
    hydratedMessages: [],
    rawSessionMessages: [],
    historicalDurations: new Map(),
    hasActiveStream: false,
    hasMoreMessages: false,
    oldestSequence: null,
    newestSequence: null,
    forwardPaginated: false,
    isLoadingSession: false,
    isSessionError: false,
    createSession: vi.fn(),
    isCreatingSession: false,
    refetchSession: vi.fn(),
    sessionDryRun: false,
    ...overrides,
  };
}
// Default return shape of the mocked useCopilotStream hook.
function makeBaseCopilotStream(overrides: Record<string, unknown> = {}) {
  return {
    messages: [],
    sendMessage: vi.fn(),
    stop: vi.fn(),
    status: "ready",
    error: undefined,
    isReconnecting: false,
    isSyncing: false,
    isUserStoppingRef: { current: false },
    rateLimitMessage: null,
    dismissRateLimit: vi.fn(),
    ...overrides,
  };
}
// Default return shape of the mocked useLoadMoreMessages hook.
function makeBaseLoadMore(overrides: Record<string, unknown> = {}) {
  return {
    pagedMessages: [],
    hasMore: false,
    isLoadingMore: false,
    loadMore: vi.fn(),
    resetPaged: vi.fn(),
    ...overrides,
  };
}
// Verifies that useCopilotPage orders paged vs. current messages according to
// the pagination direction, and resets paged state when the direction flips.
describe("useCopilotPage — forwardPaginated message ordering", () => {
  beforeEach(() => {
    vi.clearAllMocks();
  });

  it("prepends pagedMessages before currentMessages when forwardPaginated=false", () => {
    const pagedMsg = { id: "paged", role: "user" };
    const currentMsg = { id: "current", role: "assistant" };
    mockUseChatSession.mockReturnValue(
      makeBaseChatSession({ forwardPaginated: false }),
    );
    mockUseCopilotStream.mockReturnValue(
      makeBaseCopilotStream({ messages: [currentMsg] }),
    );
    mockUseLoadMoreMessages.mockReturnValue(
      makeBaseLoadMore({ pagedMessages: [pagedMsg] }),
    );
    const { result } = renderHook(() => useCopilotPage());
    // Backward: pagedMessages (older) come first
    expect(result.current.messages[0]).toEqual(pagedMsg);
    expect(result.current.messages[1]).toEqual(currentMsg);
  });

  it("appends pagedMessages after currentMessages when forwardPaginated=true", () => {
    const pagedMsg = { id: "paged", role: "assistant" };
    const currentMsg = { id: "current", role: "user" };
    mockUseChatSession.mockReturnValue(
      makeBaseChatSession({ forwardPaginated: true }),
    );
    mockUseCopilotStream.mockReturnValue(
      makeBaseCopilotStream({ messages: [currentMsg] }),
    );
    mockUseLoadMoreMessages.mockReturnValue(
      makeBaseLoadMore({ pagedMessages: [pagedMsg] }),
    );
    const { result } = renderHook(() => useCopilotPage());
    // Forward: currentMessages (beginning of session) come first
    expect(result.current.messages[0]).toEqual(currentMsg);
    expect(result.current.messages[1]).toEqual(pagedMsg);
  });

  it("calls resetPaged when forwardPaginated transitions false→true with paged messages", async () => {
    const mockResetPaged = vi.fn();
    const pagedMsg = { id: "paged", role: "user" };
    mockUseChatSession.mockReturnValue(
      makeBaseChatSession({ forwardPaginated: false }),
    );
    mockUseCopilotStream.mockReturnValue(makeBaseCopilotStream());
    mockUseLoadMoreMessages.mockReturnValue(
      makeBaseLoadMore({
        pagedMessages: [pagedMsg],
        resetPaged: mockResetPaged,
      }),
    );
    const { rerender } = renderHook(() => useCopilotPage());
    // Simulate session completing — forwardPaginated flips to true
    mockUseChatSession.mockReturnValue(
      makeBaseChatSession({ forwardPaginated: true }),
    );
    act(() => {
      rerender();
    });
    await waitFor(() => {
      expect(mockResetPaged).toHaveBeenCalled();
    });
  });

  it("does not call resetPaged when forwardPaginated is already true on mount", () => {
    const mockResetPaged = vi.fn();
    mockUseChatSession.mockReturnValue(
      makeBaseChatSession({ forwardPaginated: true }),
    );
    mockUseCopilotStream.mockReturnValue(makeBaseCopilotStream());
    mockUseLoadMoreMessages.mockReturnValue(
      makeBaseLoadMore({ pagedMessages: [], resetPaged: mockResetPaged }),
    );
    renderHook(() => useCopilotPage());
    // No transition happened, so the reset must not fire.
    expect(mockResetPaged).not.toHaveBeenCalled();
  });
});

View File

@@ -0,0 +1,568 @@
import { act, renderHook, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useLoadMoreMessages } from "../useLoadMoreMessages";
const mockGetV2GetSession = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
getV2GetSession: (...args: unknown[]) => mockGetV2GetSession(...args),
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
convertChatSessionMessagesToUiMessages: vi.fn(() => ({ messages: [] })),
extractToolOutputsFromRaw: vi.fn(() => []),
}));
// Default hook arguments: a forward-paginated session whose initial page
// covers sequences 0–49 with more messages available.
const BASE_ARGS = {
  sessionId: "sess-1",
  initialOldestSequence: 0,
  initialNewestSequence: 49,
  initialHasMore: true,
  forwardPaginated: true,
  initialPageRawMessages: [],
};

// Builds a successful getV2GetSession response; callers override only the
// pagination fields relevant to their test.
function makeSuccessResponse(overrides: {
  messages?: unknown[];
  has_more_messages?: boolean;
  oldest_sequence?: number;
  newest_sequence?: number;
}) {
  return {
    status: 200,
    data: {
      messages: overrides.messages ?? [],
      has_more_messages: overrides.has_more_messages ?? false,
      oldest_sequence: overrides.oldest_sequence ?? 0,
      newest_sequence: overrides.newest_sequence ?? 49,
    },
  };
}
describe("useLoadMoreMessages", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("initialises with empty pagedMessages and correct cursors", () => {
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
expect(result.current.pagedMessages).toHaveLength(0);
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
it("resetPaged clears paged state and sets hasMore=false during transition", () => {
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
act(() => {
result.current.resetPaged();
});
expect(result.current.pagedMessages).toHaveLength(0);
// hasMore must be false during transition to prevent forward loadMore
// from firing on the now-active session before forwardPaginated updates.
expect(result.current.hasMore).toBe(false);
expect(result.current.isLoadingMore).toBe(false);
});
it("resetPaged exposes a fresh loadMore via incremented epoch", () => {
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
// Just verify resetPaged is callable and doesn't throw.
expect(() => {
act(() => {
result.current.resetPaged();
});
}).not.toThrow();
});
it("resets all state on sessionId change", () => {
const { result, rerender } = renderHook(
(props) => useLoadMoreMessages(props),
{ initialProps: BASE_ARGS },
);
rerender({
...BASE_ARGS,
sessionId: "sess-2",
initialOldestSequence: 10,
initialNewestSequence: 59,
initialHasMore: false,
});
expect(result.current.pagedMessages).toHaveLength(0);
expect(result.current.hasMore).toBe(false);
expect(result.current.isLoadingMore).toBe(false);
});
describe("loadMore — forward pagination", () => {
it("calls getV2GetSession with after_sequence and updates newestSequence", async () => {
const rawMsg = { role: "user", content: "hi", sequence: 50 };
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [rawMsg],
has_more_messages: true,
newest_sequence: 99,
}),
);
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).toHaveBeenCalledWith(
"sess-1",
expect.objectContaining({ after_sequence: 49 }),
);
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
it("sets hasMore=false when response has no more messages", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({ has_more_messages: false }),
);
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
);
await act(async () => {
await result.current.loadMore();
});
expect(result.current.hasMore).toBe(false);
});
it("is a no-op when hasMore is false", async () => {
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
initialHasMore: false,
forwardPaginated: true,
}),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).not.toHaveBeenCalled();
});
});
describe("loadMore — backward pagination", () => {
it("calls getV2GetSession with before_sequence", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [{ role: "user", content: "old", sequence: 0 }],
has_more_messages: false,
oldest_sequence: 0,
}),
);
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
forwardPaginated: false,
initialOldestSequence: 50,
}),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).toHaveBeenCalledWith(
"sess-1",
expect.objectContaining({ before_sequence: 50 }),
);
expect(result.current.hasMore).toBe(false);
});
});
describe("loadMore — error handling", () => {
it("does not set hasMore=false on first error", async () => {
mockGetV2GetSession.mockRejectedValueOnce(new Error("network error"));
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
await act(async () => {
await result.current.loadMore();
});
// First error — hasMore still true
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) errors", async () => {
mockGetV2GetSession.mockRejectedValue(new Error("network error"));
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
for (let i = 0; i < 3; i++) {
await act(async () => {
await result.current.loadMore();
});
// Reset the in-flight guard between calls
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
}
expect(result.current.hasMore).toBe(false);
});
it("ignores non-200 response and increments error count", async () => {
mockGetV2GetSession.mockResolvedValueOnce({ status: 500, data: {} });
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
await act(async () => {
await result.current.loadMore();
});
// One error, not yet at threshold — hasMore still true
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) non-200 responses", async () => {
mockGetV2GetSession.mockResolvedValue({ status: 503, data: {} });
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
for (let i = 0; i < 3; i++) {
await act(async () => {
await result.current.loadMore();
});
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
}
expect(result.current.hasMore).toBe(false);
});
it("discards in-flight error when epoch changes mid-flight (resetPaged called)", async () => {
let rejectRequest!: (e: Error) => void;
mockGetV2GetSession.mockReturnValueOnce(
new Promise((_, rej) => {
rejectRequest = rej;
}),
);
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
act(() => {
result.current.loadMore();
});
// Reset epoch mid-flight
act(() => {
result.current.resetPaged();
});
// Reject the in-flight request — stale error should be discarded
await act(async () => {
rejectRequest(new Error("network error"));
});
// State unchanged: no hasMore=false, no errorCount, isLoadingMore cleared
expect(result.current.hasMore).toBe(false); // false from resetPaged
expect(result.current.isLoadingMore).toBe(false);
});
});
describe("loadMore — forward pagination cursor advancement", () => {
it("advances newestSequence after a successful forward load", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [{ role: "user", content: "hi", sequence: 50 }],
has_more_messages: true,
newest_sequence: 99,
}),
);
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
);
await act(async () => {
await result.current.loadMore();
});
// A second loadMore should use after_sequence: 99 (advanced cursor)
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
"sess-1",
expect.objectContaining({ after_sequence: 99 }),
);
});
it("does not regress newestSequence when parent refetches after pages loaded", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [{ role: "user", content: "msg", sequence: 50 }],
has_more_messages: true,
newest_sequence: 99,
}),
);
const { result, rerender } = renderHook(
(props) => useLoadMoreMessages(props),
{ initialProps: { ...BASE_ARGS, forwardPaginated: true } },
);
// Load one page — newestSequence advances to 99
await act(async () => {
await result.current.loadMore();
});
// Parent refetches with a lower newest_sequence (49) — should NOT regress cursor
rerender({
...BASE_ARGS,
forwardPaginated: true,
initialNewestSequence: 49,
});
// Next loadMore should still use the advanced cursor (99)
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
"sess-1",
expect.objectContaining({ after_sequence: 99 }),
);
});
});
describe("loadMore — MAX_OLDER_MESSAGES truncation", () => {
it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => {
// Single load of 2001 messages exceeds the limit in one shot.
// This avoids relying on cross-render closure staleness: estimatedTotal =
// pagedRawMessages.length (0, fresh) + 2001 = 2001 >= 2000 → hasMore=false.
const args = { ...BASE_ARGS, forwardPaginated: false };
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: Array.from({ length: 2001 }, (_, i) => ({
role: "user",
content: `msg ${i}`,
sequence: i,
})),
has_more_messages: true,
oldest_sequence: 0,
}),
);
const { result } = renderHook(() => useLoadMoreMessages(args));
await act(async () => {
await result.current.loadMore();
});
expect(result.current.hasMore).toBe(false);
});
it("forward truncation keeps first MAX_OLDER_MESSAGES items (not last)", async () => {
// 1990 messages already paged; load 20 more forward — total 2010 > 2000.
// Forward truncation must keep slice(0, 2000), not slice(-2000),
// to preserve the beginning of the conversation.
const forwardNearLimitArgs = {
...BASE_ARGS,
forwardPaginated: true,
initialNewestSequence: 49,
initialOldestSequence: 0,
initialHasMore: true,
};
const { result } = renderHook((props) => useLoadMoreMessages(props), {
initialProps: forwardNearLimitArgs,
});
// First load: 1990 messages — advances newestSequence to 2039
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: Array.from({ length: 1990 }, (_, i) => ({
role: "assistant",
content: `msg ${i + 50}`,
sequence: i + 50,
})),
has_more_messages: true,
newest_sequence: 2039,
}),
);
await act(async () => {
await result.current.loadMore();
});
// Second load: 20 more messages pushes total to 2010 > 2000.
// Truncation keeps seq 50..2049 (2000 items); discards seq 2050..2059 (10 items).
// Even though the server says has_more_messages=false, hasMore stays true
// because there are discarded items that need to be re-fetched.
// The cursor (newestSequence) advances to 2049 — the last kept item's sequence.
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: Array.from({ length: 20 }, (_, i) => ({
role: "assistant",
content: `msg ${i + 2040}`,
sequence: i + 2040,
})),
has_more_messages: false,
newest_sequence: 2059,
}),
);
await act(async () => {
await result.current.loadMore();
});
// Truncation occurred (2010 > 2000): hasMore=true so discarded items can be fetched.
// Cursor advances to last kept item (seq 2049), not the server's newest (2059).
await waitFor(() => expect(result.current.hasMore).toBe(true));
});
});
  // Cursor guards: loadMore must bail out without issuing a network request
  // when the cursor for the active pagination direction is null — there is
  // no sequence number to paginate from.
  describe("loadMore — null cursor guard", () => {
    it("is a no-op when newestSequence is null (forwardPaginated=true)", async () => {
      const { result } = renderHook(() =>
        useLoadMoreMessages({
          ...BASE_ARGS,
          forwardPaginated: true,
          initialNewestSequence: null,
        }),
      );
      await act(async () => {
        await result.current.loadMore();
      });
      // Forward pagination needs newestSequence to build the after_sequence
      // cursor — with null it must not call the endpoint at all.
      expect(mockGetV2GetSession).not.toHaveBeenCalled();
    });
    it("is a no-op when oldestSequence is null (forwardPaginated=false)", async () => {
      const { result } = renderHook(() =>
        useLoadMoreMessages({
          ...BASE_ARGS,
          forwardPaginated: false,
          initialOldestSequence: null,
        }),
      );
      await act(async () => {
        await result.current.loadMore();
      });
      // Backward pagination needs oldestSequence for before_sequence —
      // likewise no request may be made when it is null.
      expect(mockGetV2GetSession).not.toHaveBeenCalled();
    });
  });
describe("pagedMessages — initialPageRawMessages extraToolOutputs", () => {
it("calls extractToolOutputsFromRaw for backward pagination with non-empty initialPageRawMessages", async () => {
const { extractToolOutputsFromRaw } = await import(
"../helpers/convertChatSessionToUiMessages"
);
const rawMsg = { role: "user", content: "old", sequence: 0 };
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [rawMsg],
has_more_messages: false,
oldest_sequence: 0,
}),
);
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
forwardPaginated: false,
initialOldestSequence: 50,
initialPageRawMessages: [{ role: "assistant", content: "response" }],
}),
);
await act(async () => {
await result.current.loadMore();
});
expect(extractToolOutputsFromRaw).toHaveBeenCalled();
});
it("does NOT call extractToolOutputsFromRaw for forward pagination", async () => {
const { extractToolOutputsFromRaw } = await import(
"../helpers/convertChatSessionToUiMessages"
);
const rawMsg = { role: "assistant", content: "hi", sequence: 50 };
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [rawMsg],
has_more_messages: false,
newest_sequence: 99,
}),
);
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
forwardPaginated: true,
initialPageRawMessages: [{ role: "user", content: "hello" }],
}),
);
await act(async () => {
await result.current.loadMore();
});
expect(extractToolOutputsFromRaw).not.toHaveBeenCalled();
});
});
describe("loadMore — epoch / stale-response guard", () => {
it("discards response when epoch changes during flight (resetPaged called)", async () => {
let resolveRequest!: (v: unknown) => void;
mockGetV2GetSession.mockReturnValueOnce(
new Promise((res) => {
resolveRequest = res;
}),
);
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
// Start the request without awaiting
act(() => {
result.current.loadMore();
});
// Reset epoch mid-flight
act(() => {
result.current.resetPaged();
});
// Now resolve the in-flight request
await act(async () => {
resolveRequest(
makeSuccessResponse({ messages: [{ role: "user", content: "hi" }] }),
);
});
// Response discarded — pagedMessages stays empty, isLoadingMore stays false
expect(result.current.pagedMessages).toHaveLength(0);
expect(result.current.isLoadingMore).toBe(false);
});
});
});

View File

@@ -30,6 +30,7 @@ export interface ChatContainerProps {
hasMoreMessages?: boolean;
isLoadingMore?: boolean;
onLoadMore?: () => void;
forwardPaginated?: boolean;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
@@ -54,6 +55,7 @@ export const ChatContainer = ({
hasMoreMessages,
isLoadingMore,
onLoadMore,
forwardPaginated,
droppedFiles,
onDroppedFilesConsumed,
historicalDurations,
@@ -108,6 +110,7 @@ export const ChatContainer = ({
hasMoreMessages={hasMoreMessages}
isLoadingMore={isLoadingMore}
onLoadMore={onLoadMore}
forwardPaginated={forwardPaginated}
onRetry={handleRetry}
historicalDurations={historicalDurations}
/>

View File

@@ -86,11 +86,11 @@ export function ChatInput({
title:
next === "advanced"
? "Switched to Advanced model"
: "Switched to Standard model",
: "Switched to Balanced model",
description:
next === "advanced"
? "Using the highest-capability model."
: "Using the balanced standard model.",
: "Using the balanced default model.",
});
}

View File

@@ -162,10 +162,15 @@ describe("ChatInput mode toggle", () => {
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
});
it("hides toggle button when streaming", () => {
it("hides toggle buttons when streaming", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} isStreaming />);
expect(screen.queryByLabelText(/switch to/i)).toBeNull();
expect(
screen.queryByLabelText(/switch to (fast|extended thinking) mode/i),
).toBeNull();
expect(
screen.queryByLabelText(/switch to (advanced|balanced|standard) model/i),
).toBeNull();
});
it("shows mode toggle when hasSession is true and not streaming", () => {
@@ -234,7 +239,7 @@ describe("ChatInput model toggle", () => {
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
});
@@ -288,10 +293,10 @@ describe("ChatInput model toggle", () => {
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to standard model/i),
title: expect.stringMatching(/switched to balanced model/i),
}),
);
});

View File

@@ -2,6 +2,11 @@
import { cn } from "@/lib/utils";
import { Flask } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
// This button is only rendered on NEW chats (no active session).
// Once a session exists, it is hidden — the session's dry_run flag is
@@ -14,27 +19,31 @@ interface Props {
export function DryRunToggleButton({ isDryRun, onToggle }: Props) {
return (
<button
type="button"
aria-pressed={isDryRun}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isDryRun
? "bg-amber-100 text-amber-900 hover:bg-amber-200"
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
)}
aria-label={
isDryRun ? "Test mode active — click to disable" : "Enable Test mode"
}
title={
isDryRun
? "Test mode ON — new chats run agents as simulation (click to disable)"
: "Enable Test mode — new chats will run agents as simulation"
}
>
<Flask size={14} />
{isDryRun && "Test"}
</button>
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-pressed={isDryRun}
onClick={onToggle}
className={cn(
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
isDryRun
? "text-amber-900"
: "text-neutral-500 hover:text-neutral-700",
)}
aria-label={isDryRun ? "Test mode active" : "Enable Test mode"}
>
<Flask size={14} />
<span className="hidden sm:inline">
{isDryRun ? "Test mode enabled" : "Enable test mode"}
</span>
</button>
</TooltipTrigger>
<TooltipContent>
{isDryRun
? "Test mode on — new sessions run without performing real actions (click to turn off)."
: "Turn on test mode to try prompts without performing real actions."}
</TooltipContent>
</Tooltip>
);
}

View File

@@ -2,6 +2,11 @@
import { cn } from "@/lib/utils";
import { Brain, Lightning } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import type { CopilotMode } from "../../../store";
interface Props {
@@ -11,37 +16,42 @@ interface Props {
export function ModeToggleButton({ mode, onToggle }: Props) {
const isExtended = mode === "extended_thinking";
const tooltipText = isExtended
? "Extended Thinking — deeper reasoning (click to switch to Fast)"
: "Fast mode — quicker responses (click to switch to Thinking)";
return (
<button
type="button"
aria-pressed={isExtended}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isExtended
? "bg-purple-100 text-purple-900 hover:bg-purple-200"
: "bg-amber-100 text-amber-900 hover:bg-amber-200",
)}
aria-label={
isExtended ? "Switch to Fast mode" : "Switch to Extended Thinking mode"
}
title={
isExtended
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
}
>
{isExtended ? (
<>
<Brain size={14} />
Thinking
</>
) : (
<>
<Lightning size={14} />
Fast
</>
)}
</button>
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-pressed={isExtended}
onClick={onToggle}
className={cn(
"ml-2 inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
isExtended ? "text-purple-900" : "text-amber-900",
)}
aria-label={
isExtended
? "Switch to Fast mode"
: "Switch to Extended Thinking mode"
}
>
{isExtended ? (
<>
<Brain size={14} />
Thinking
</>
) : (
<>
<Lightning size={14} />
Fast
</>
)}
</button>
</TooltipTrigger>
<TooltipContent>{tooltipText}</TooltipContent>
</Tooltip>
);
}

View File

@@ -2,6 +2,11 @@
import { cn } from "@/lib/utils";
import { Cpu } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import type { CopilotLlmModel } from "../../../store";
interface Props {
@@ -12,27 +17,33 @@ interface Props {
export function ModelToggleButton({ model, onToggle }: Props) {
const isAdvanced = model === "advanced";
return (
<button
type="button"
aria-pressed={isAdvanced}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isAdvanced
? "bg-sky-100 text-sky-900 hover:bg-sky-200"
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
)}
aria-label={
isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
}
title={
isAdvanced
? "Advanced model — highest capability (click to switch to Standard)"
: "Standard model — click to switch to Advanced"
}
>
<Cpu size={14} />
{isAdvanced && "Advanced"}
</button>
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-pressed={isAdvanced}
onClick={onToggle}
className={cn(
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
isAdvanced
? "text-sky-900"
: "text-neutral-500 hover:text-neutral-700",
)}
aria-label={
isAdvanced ? "Switch to Balanced model" : "Switch to Advanced model"
}
>
<Cpu size={14} />
<span className="hidden sm:inline">
{isAdvanced ? "Advanced" : "Balanced"}
</span>
</button>
</TooltipTrigger>
<TooltipContent>
{isAdvanced
? "Using the highest-capability model (click to switch to Balanced)."
: "Using the balanced default model (click to switch to Advanced)."}
</TooltipContent>
</Tooltip>
);
}

View File

@@ -1,21 +1,32 @@
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import {
render as rtlRender,
screen,
fireEvent,
cleanup,
} from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import type { ReactElement } from "react";
import { TooltipProvider } from "@/components/ui/tooltip";
import { DryRunToggleButton } from "../DryRunToggleButton";
afterEach(cleanup);
// Render helper: tooltip content only mounts under a TooltipProvider, so wrap
// every component under test in one before delegating to RTL's render.
function render(ui: ReactElement) {
  return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
}
// DryRunToggleButton only appears on new chats (no active session).
// It has no readOnly/isStreaming props — those scenarios are handled by hiding
// the button entirely at the ChatInput level when hasSession is true.
describe("DryRunToggleButton", () => {
it("shows Test label when isDryRun is true", () => {
it("shows enabled label when isDryRun is true", () => {
render(<DryRunToggleButton isDryRun={true} onToggle={vi.fn()} />);
expect(screen.getByText("Test")).toBeTruthy();
expect(screen.getByText("Test mode enabled")).toBeTruthy();
});
it("shows no text label when isDryRun is false", () => {
it("shows enable label when isDryRun is false", () => {
render(<DryRunToggleButton isDryRun={false} onToggle={vi.fn()} />);
expect(screen.queryByText("Test")).toBeNull();
expect(screen.getByText("Enable test mode")).toBeTruthy();
});
it("calls onToggle when clicked", () => {

View File

@@ -1,9 +1,20 @@
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import {
render as rtlRender,
screen,
fireEvent,
cleanup,
} from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import type { ReactElement } from "react";
import { TooltipProvider } from "@/components/ui/tooltip";
import { ModelToggleButton } from "../ModelToggleButton";
afterEach(cleanup);
// Render helper: tooltip content only mounts under a TooltipProvider, so wrap
// every component under test in one before delegating to RTL's render.
function render(ui: ReactElement) {
  return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
}
describe("ModelToggleButton", () => {
it("shows no text label when model is standard", () => {
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
@@ -31,7 +42,7 @@ describe("ModelToggleButton", () => {
it("sets aria-pressed=true for advanced", () => {
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
const btn = screen.getByLabelText("Switch to Standard model");
const btn = screen.getByLabelText("Switch to Balanced model");
expect(btn.getAttribute("aria-pressed")).toBe("true");
});
});

View File

@@ -43,6 +43,10 @@ interface Props {
hasMoreMessages?: boolean;
isLoadingMore?: boolean;
onLoadMore?: () => void;
/** When true the load-more sentinel is placed at the bottom (forward
* pagination for completed sessions). When false it is at the top
* (backward pagination for active sessions). */
forwardPaginated?: boolean;
onRetry?: () => void;
historicalDurations?: Map<string, number>;
}
@@ -140,11 +144,25 @@ export function LoadMoreSentinel({
isLoading,
messageCount,
onLoadMore,
rootMargin = "200px 0px 0px 0px",
adjustScroll = true,
forwardPaginated = false,
}: {
hasMore: boolean;
isLoading: boolean;
messageCount: number;
onLoadMore: () => void;
/** IntersectionObserver rootMargin. Top sentinel uses "200px 0px 0px 0px"
* (pre-trigger when approaching from above); bottom sentinel should use
* "0px 0px 200px 0px" (pre-trigger when approaching from below). */
rootMargin?: string;
/** Whether to adjust scrollTop after load to preserve visual position.
* True for backward pagination (prepend above); false for forward
* pagination (append below) where no adjustment is needed. */
adjustScroll?: boolean;
/** When true the button reads "Load newer messages" (forward pagination).
* When false (default) it reads "Load older messages". */
forwardPaginated?: boolean;
}) {
const sentinelRef = useRef<HTMLDivElement>(null);
const onLoadMoreRef = useRef(onLoadMore);
@@ -189,11 +207,11 @@ export function LoadMoreSentinel({
if (autoFillRoundsRef.current >= MAX_AUTO_FILL_ROUNDS) return;
captureAndLoad(true);
},
{ rootMargin: "200px 0px 0px 0px" },
{ rootMargin },
);
observer.observe(sentinelRef.current);
return () => observer.disconnect();
}, [hasMore, isLoading, scrollRef]);
}, [hasMore, isLoading, rootMargin, scrollRef]);
// After React commits new DOM nodes (prepended messages), adjust
// scrollTop so the user stays at the same visual position.
@@ -206,7 +224,9 @@ export function LoadMoreSentinel({
scrollSnapshotRef.current;
if (!el || prevHeight === 0) return;
const delta = el.scrollHeight - prevHeight;
if (delta > 0) {
// Only restore scroll position for backward pagination (content prepended
// above). Forward pagination appends below — no adjustment needed.
if (adjustScroll && delta > 0) {
el.scrollTop = prevTop + delta;
}
// Reset the auto-fill backoff whenever the container becomes
@@ -220,7 +240,7 @@ export function LoadMoreSentinel({
}
scrollSnapshotRef.current = { scrollHeight: 0, scrollTop: 0 };
autoTriggeredRef.current = false;
}, [messageCount, scrollRef]);
}, [adjustScroll, messageCount, scrollRef]);
return (
<div
@@ -239,7 +259,7 @@ export function LoadMoreSentinel({
size="small"
onClick={() => captureAndLoad(false)}
>
Load older messages
{forwardPaginated ? "Load newer messages" : "Load older messages"}
</Button>
)
)}
@@ -256,6 +276,7 @@ export function ChatMessagesContainer({
hasMoreMessages,
isLoadingMore,
onLoadMore,
forwardPaginated,
onRetry,
historicalDurations,
}: Props) {
@@ -334,7 +355,7 @@ export function ChatMessagesContainer({
}
>
<ConversationContent className="flex min-h-full flex-1 flex-col gap-6 px-3 py-6">
{hasMoreMessages && onLoadMore && (
{hasMoreMessages && onLoadMore && !forwardPaginated && (
<LoadMoreSentinel
hasMore={hasMoreMessages}
isLoading={!!isLoadingMore}
@@ -497,6 +518,17 @@ export function ChatMessagesContainer({
</pre>
</details>
)}
{hasMoreMessages && onLoadMore && forwardPaginated && (
<LoadMoreSentinel
hasMore={hasMoreMessages}
isLoading={!!isLoadingMore}
messageCount={messages.length}
onLoadMore={onLoadMore}
rootMargin="0px 0px 200px 0px"
adjustScroll={false}
forwardPaginated
/>
)}
</ConversationContent>
<ConversationScrollButton />
</Conversation>

View File

@@ -0,0 +1,173 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { ChatMessagesContainer } from "../ChatMessagesContainer";
const mockScrollEl = {
scrollHeight: 100,
scrollTop: 0,
clientHeight: 500,
};
vi.mock("use-stick-to-bottom", () => ({
useStickToBottomContext: () => ({ scrollRef: { current: mockScrollEl } }),
Conversation: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationScrollButton: () => null,
}));
vi.mock("@/components/ai-elements/conversation", () => ({
Conversation: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationScrollButton: () => null,
}));
vi.mock("@/components/ai-elements/message", () => ({
Message: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
MessageContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
MessageActions: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
vi.mock("../components/AssistantMessageActions", () => ({
AssistantMessageActions: () => null,
}));
vi.mock("../components/CopyButton", () => ({ CopyButton: () => null }));
vi.mock("../components/CollapsedToolGroup", () => ({
CollapsedToolGroup: () => null,
}));
vi.mock("../components/MessageAttachments", () => ({
MessageAttachments: () => null,
}));
vi.mock("../components/MessagePartRenderer", () => ({
MessagePartRenderer: () => null,
}));
vi.mock("../components/ReasoningCollapse", () => ({
ReasoningCollapse: () => null,
}));
vi.mock("../components/ThinkingIndicator", () => ({
ThinkingIndicator: () => null,
}));
vi.mock("../../JobStatsBar/TurnStatsBar", () => ({
TurnStatsBar: () => null,
}));
vi.mock("../../JobStatsBar/useElapsedTimer", () => ({
useElapsedTimer: () => ({ elapsedSeconds: 0 }),
}));
vi.mock("../../CopilotPendingReviews/CopilotPendingReviews", () => ({
CopilotPendingReviews: () => null,
}));
vi.mock("../helpers", () => ({
buildRenderSegments: () => [],
getTurnMessages: () => [],
parseSpecialMarkers: () => ({ markerType: null }),
splitReasoningAndResponse: (parts: unknown[]) => ({
reasoningParts: [],
responseParts: parts,
}),
}));
// Shape of the IntersectionObserver callback as used by these tests
// (only `isIntersecting` is consulted).
type ObserverCallback = (entries: { isIntersecting: boolean }[]) => void;
// Test double for IntersectionObserver: captures the most recently
// constructed observer's callback in the static `lastCallback` slot so tests
// can fire intersection events manually. NOTE: because the slot is static,
// only the last-created observer is reachable — reset it in beforeEach.
class MockIntersectionObserver {
  static lastCallback: ObserverCallback | null = null;
  private callback: ObserverCallback;
  constructor(cb: ObserverCallback) {
    this.callback = cb;
    MockIntersectionObserver.lastCallback = cb;
  }
  // Observation methods are no-op stubs — intersections are driven through
  // `lastCallback` directly.
  observe() {}
  disconnect() {}
  unobserve() {}
  takeRecords() {
    return [];
  }
  root = null;
  rootMargin = "";
  thresholds = [];
}
// Baseline ChatMessagesContainer props: an idle, empty session that still has
// more messages available, so the load-more sentinel is rendered.
const BASE_PROPS = {
  messages: [],
  status: "ready" as const,
  error: undefined,
  isLoading: false,
  sessionID: "sess-1",
  hasMoreMessages: true,
  isLoadingMore: false,
  onLoadMore: vi.fn(),
  onRetry: vi.fn(),
};
describe("ChatMessagesContainer", () => {
beforeEach(() => {
mockScrollEl.scrollHeight = 100;
mockScrollEl.scrollTop = 0;
mockScrollEl.clientHeight = 500;
MockIntersectionObserver.lastCallback = null;
vi.stubGlobal("IntersectionObserver", MockIntersectionObserver);
});
afterEach(() => {
cleanup();
vi.unstubAllGlobals();
});
it("renders top sentinel when forwardPaginated is false (backward pagination)", () => {
render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={false} />);
expect(
screen.getByRole("button", { name: /load older messages/i }),
).toBeDefined();
});
it("renders top sentinel when forwardPaginated is undefined (default, backward)", () => {
render(<ChatMessagesContainer {...BASE_PROPS} />);
expect(
screen.getByRole("button", { name: /load older messages/i }),
).toBeDefined();
});
it("renders bottom sentinel when forwardPaginated is true (forward pagination)", () => {
render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={true} />);
expect(
screen.getByRole("button", { name: /load newer messages/i }),
).toBeDefined();
});
it("hides sentinel when hasMoreMessages is false", () => {
render(
<ChatMessagesContainer
{...BASE_PROPS}
hasMoreMessages={false}
forwardPaginated={true}
/>,
);
expect(
screen.queryByRole("button", { name: /load older messages/i }),
).toBeNull();
});
it("hides sentinel when onLoadMore is not provided", () => {
render(
<ChatMessagesContainer
{...BASE_PROPS}
onLoadMore={undefined}
forwardPaginated={true}
/>,
);
expect(
screen.queryByRole("button", { name: /load older messages/i }),
).toBeNull();
});
});

View File

@@ -172,6 +172,36 @@ describe("LoadMoreSentinel", () => {
expect(mockScrollEl.scrollTop).toBe(200);
});
// Verifies that forward pagination skips the scroll-compensation logic:
// when newer messages are appended *below* the viewport, scrollTop must be
// left untouched (no jump), unlike backward pagination which compensates
// for content prepended above.
it("does NOT adjust scroll when adjustScroll=false (forward pagination)", () => {
  mockScrollEl.scrollHeight = 100;
  mockScrollEl.scrollTop = 50;
  const onLoadMore = vi.fn();
  const { rerender } = render(
    <LoadMoreSentinel
      hasMore={true}
      isLoading={false}
      messageCount={5}
      onLoadMore={onLoadMore}
      adjustScroll={false}
    />,
  );
  // Fire observer to capture snapshot.
  MockIntersectionObserver.lastCallback?.([{ isIntersecting: true }]);
  // Simulate DOM growing from appended newer messages (forward load-more).
  mockScrollEl.scrollHeight = 300;
  rerender(
    <LoadMoreSentinel
      hasMore={true}
      isLoading={false}
      messageCount={10}
      onLoadMore={onLoadMore}
      adjustScroll={false}
    />,
  );
  // scrollTop should remain unchanged — no jump for forward pagination.
  expect(mockScrollEl.scrollTop).toBe(50);
});
it("ignores same-frame duplicate triggers until isLoading transitions", () => {
const onLoadMore = vi.fn();
const { rerender } = render(

View File

@@ -13,6 +13,10 @@ import {
getSuggestionThemes,
} from "./helpers";
import { SuggestionThemes } from "./components/SuggestionThemes/SuggestionThemes";
import { PulseChips } from "../PulseChips/PulseChips";
import { usePulseChips } from "../PulseChips/usePulseChips";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { EditNameDialog } from "./components/EditNameDialog/EditNameDialog";
interface Props {
inputLayoutId: string;
@@ -34,6 +38,8 @@ export function EmptySession({
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
const pulseChips = usePulseChips();
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
useGetV2GetSuggestedPrompts({
@@ -75,11 +81,16 @@ export function EmptySession({
<div className="mx-auto max-w-[52rem]">
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
Hey, <span className="text-violet-600">{greetingName}</span>
<EditNameDialog currentName={greetingName} />
</Text>
<Text variant="h3" className="mb-8 !font-normal">
Tell me about your work I&apos;ll find what to automate.
</Text>
{isAgentBriefingEnabled && (
<PulseChips chips={pulseChips} onChipClick={onSend} />
)}
<div className="mb-6">
<motion.div
layoutId={inputLayoutId}

View File

@@ -0,0 +1,107 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { PencilSimpleIcon } from "@phosphor-icons/react";
import { useState } from "react";
interface Props {
currentName: string;
}
/**
 * Inline pencil-trigger dialog for editing the user's display name.
 *
 * Saves via PUT /api/auth/user, then refreshes the Supabase session so the
 * new name is reflected immediately. Failures (network, API error, session
 * refresh) surface as destructive toasts.
 */
export function EditNameDialog({ currentName }: Props) {
  const [isOpen, setIsOpen] = useState(false);
  const [name, setName] = useState(currentName);
  const [isSaving, setIsSaving] = useState(false);
  const { refreshSession } = useSupabase();
  const { toast } = useToast();

  function handleOpenChange(open: boolean) {
    // Re-seed the input with the latest name each time the dialog opens so a
    // previously-abandoned edit doesn't linger.
    if (open) setName(currentName);
    setIsOpen(open);
  }

  async function handleSave() {
    const trimmed = name.trim();
    if (!trimmed) return;
    // FIX: guard against re-entry. Enter in the input calls handleSave
    // directly and is not gated by the Save button's disabled state, so a
    // second keypress mid-flight would fire a duplicate request.
    if (isSaving) return;
    setIsSaving(true);
    try {
      let res: Response;
      try {
        res = await fetch("/api/auth/user", {
          method: "PUT",
          headers: { "Content-Type": "application/json" },
          body: JSON.stringify({ full_name: trimmed }),
        });
      } catch {
        // FIX: fetch rejects on network failure; previously the rejection
        // escaped the event handler as an unhandled promise rejection.
        toast({
          title: "Failed to update name",
          description: "Network error",
          variant: "destructive",
        });
        return;
      }
      if (!res.ok) {
        // FIX: the error body may not be valid JSON (e.g. an HTML error page
        // from a proxy), so parse defensively instead of letting json() throw.
        let description = "Unknown error";
        try {
          const body = await res.json();
          description = body.error ?? "Unknown error";
        } catch {
          // keep the fallback description
        }
        toast({
          title: "Failed to update name",
          description,
          variant: "destructive",
        });
        return;
      }
      const session = await refreshSession();
      if (session?.error) {
        // Name persisted server-side, but the client session is stale —
        // close anyway and warn the user.
        toast({
          title: "Name saved, but session refresh failed",
          description: session.error,
          variant: "destructive",
        });
        setIsOpen(false);
        return;
      }
      setIsOpen(false);
      toast({ title: "Name updated" });
    } finally {
      setIsSaving(false);
    }
  }

  return (
    <Dialog
      title="Edit display name"
      styling={{ maxWidth: "24rem" }}
      controlled={{ isOpen, set: handleOpenChange }}
    >
      <Dialog.Trigger>
        <button
          type="button"
          className="ml-1 inline-flex items-center text-violet-500 transition-colors hover:text-violet-700"
        >
          <PencilSimpleIcon size={16} />
        </button>
      </Dialog.Trigger>
      <Dialog.Content>
        <div className="flex flex-col gap-4 px-1">
          <Input
            id="display-name"
            label="Display name"
            placeholder="Your name"
            value={name}
            onChange={(e) => setName(e.target.value)}
            onKeyDown={(e) => {
              if (e.key === "Enter") {
                e.preventDefault();
                handleSave();
              }
            }}
          />
          <Button
            variant="primary"
            onClick={handleSave}
            disabled={!name.trim() || isSaving}
            loading={isSaving}
          >
            Save
          </Button>
        </div>
      </Dialog.Content>
    </Dialog>
  );
}

View File

@@ -0,0 +1,135 @@
import { beforeEach, describe, expect, test, vi } from "vitest";
import {
fireEvent,
render,
screen,
waitFor,
} from "@/tests/integrations/test-utils";
import { server } from "@/mocks/mock-server";
import { http, HttpResponse } from "msw";
import { EditNameDialog } from "../EditNameDialog";
const mockToast = vi.hoisted(() => vi.fn());
const mockRefreshSession = vi.hoisted(() => vi.fn());
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: () => ({ toast: mockToast }),
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: () => ({
refreshSession: mockRefreshSession,
}),
}));
// Stub the name-update endpoint to respond 200 with a minimal user payload.
function mockUpdateNameSuccess() {
  const handler = http.put("/api/auth/user", () =>
    HttpResponse.json({ user: { id: "u1" } }),
  );
  server.use(handler);
}
// Stub the name-update endpoint to fail with a 400 and the given message.
function mockUpdateNameError(message = "Network error") {
  const handler = http.put("/api/auth/user", () =>
    HttpResponse.json({ error: message }, { status: 400 }),
  );
  server.use(handler);
}
// Open the dialog via its trigger button and return the display-name input.
// Uses a direct DOM id query for the input because portaled dialog copies
// can leave multiple label matches in the document; takes the first match.
async function openDialogAndGetInput() {
  const trigger = screen.getByRole("button");
  fireEvent.click(trigger);
  await screen.findAllByLabelText(/display name/i);
  const inputs =
    document.querySelectorAll<HTMLInputElement>("input#display-name");
  return inputs[0];
}
// Portaled dialogs can leave stale copies in the DOM, so more than one Save
// button may match; the first one belongs to the active dialog.
function getSaveButton() {
  const [firstSave] = screen.getAllByRole("button", { name: /save/i });
  return firstSave as HTMLButtonElement;
}
// End-to-end dialog behavior: prefill, successful save + session refresh,
// API error surfacing, refresh-failure warning, and empty-input gating.
describe("EditNameDialog", () => {
  beforeEach(() => {
    mockToast.mockReset();
    mockRefreshSession.mockReset();
    // Default: session refresh succeeds; individual tests override.
    mockRefreshSession.mockResolvedValue({ user: { id: "u1" } });
  });
  test("opens dialog with current name prefilled", async () => {
    mockUpdateNameSuccess();
    render(<EditNameDialog currentName="Alice" />);
    const input = await openDialogAndGetInput();
    expect(input.value).toBe("Alice");
  });
  test("saves name via API route and closes dialog", async () => {
    mockUpdateNameSuccess();
    render(<EditNameDialog currentName="Alice" />);
    const input = await openDialogAndGetInput();
    fireEvent.change(input, { target: { value: "Bob" } });
    fireEvent.click(getSaveButton());
    // Session refresh is the last async step before the success toast.
    await waitFor(() => {
      expect(mockRefreshSession).toHaveBeenCalled();
    });
    expect(mockToast).toHaveBeenCalledWith({ title: "Name updated" });
  });
  test("shows error toast when API returns error", async () => {
    mockUpdateNameError("Network error");
    render(<EditNameDialog currentName="Alice" />);
    const input = await openDialogAndGetInput();
    fireEvent.change(input, { target: { value: "Bob" } });
    fireEvent.click(getSaveButton());
    await waitFor(() => {
      expect(mockToast).toHaveBeenCalledWith(
        expect.objectContaining({
          title: "Failed to update name",
          description: "Network error",
          variant: "destructive",
        }),
      );
    });
    // On API failure the session must NOT be refreshed.
    expect(mockRefreshSession).not.toHaveBeenCalled();
  });
  test("shows warning toast when refreshSession returns an error", async () => {
    mockUpdateNameSuccess();
    mockRefreshSession.mockResolvedValue({ error: "refresh failed" });
    render(<EditNameDialog currentName="Alice" />);
    const input = await openDialogAndGetInput();
    fireEvent.change(input, { target: { value: "Bob" } });
    fireEvent.click(getSaveButton());
    await waitFor(() => {
      expect(mockToast).toHaveBeenCalledWith(
        expect.objectContaining({
          title: "Name saved, but session refresh failed",
          description: "refresh failed",
          variant: "destructive",
        }),
      );
    });
    // The plain success toast must not also fire on the warning path.
    expect(mockToast).not.toHaveBeenCalledWith({ title: "Name updated" });
  });
  test("disables Save button while input is empty", async () => {
    mockUpdateNameSuccess();
    render(<EditNameDialog currentName="Alice" />);
    const input = await openDialogAndGetInput();
    // Whitespace-only input counts as empty (name is trimmed before save).
    fireEvent.change(input, { target: { value: " " } });
    expect(getSaveButton().disabled).toBe(true);
  });
});

View File

@@ -0,0 +1,93 @@
/* Panel with an animated 1px gradient "glass" border: the ::before overlay
   paints a rotating conic gradient, then masks out everything except the
   1px padding ring so only the border is visible. */
.glassPanel {
  position: relative;
  isolation: isolate;
}
.glassPanel::before {
  content: "";
  position: absolute;
  inset: 0;
  border-radius: inherit;
  /* 1px padding defines the ring that survives the mask below. */
  padding: 1px;
  background: conic-gradient(
    from var(--border-angle, 0deg),
    rgba(129, 120, 228, 0.08),
    rgba(129, 120, 228, 0.28),
    rgba(168, 130, 255, 0.18),
    rgba(129, 120, 228, 0.08),
    rgba(99, 102, 241, 0.24),
    rgba(129, 120, 228, 0.08)
  );
  /* Two stacked masks composited with xor/exclude: full box minus
     content box leaves only the 1px border ring painted. */
  -webkit-mask:
    linear-gradient(#000 0 0) content-box,
    linear-gradient(#000 0 0);
  mask:
    linear-gradient(#000 0 0) content-box,
    linear-gradient(#000 0 0);
  -webkit-mask-composite: xor;
  mask-composite: exclude;
  animation: rotate-border 6s linear infinite;
  pointer-events: none;
  z-index: -1;
}
/* Registering the custom property lets the browser interpolate the angle;
   unregistered custom properties cannot be animated. */
@property --border-angle {
  syntax: "<angle>";
  initial-value: 0deg;
  inherits: false;
}
@keyframes rotate-border {
  to {
    --border-angle: 360deg;
  }
}
.chip {
  overflow: hidden;
}
/* Pointer devices: actions slide in on hover, so reserve little space. */
@media (hover: hover) {
  .chip {
    padding-bottom: 0.9rem;
  }
}
/* Touch devices: actions are always visible, so reserve room for them. */
@media (hover: none) {
  .chip {
    padding-bottom: 2.25rem;
  }
}
/* Action bar pinned to the bottom edge of a chip. */
.chipActions {
  position: absolute;
  inset-inline: 0;
  bottom: 0;
  background: rgba(255, 255, 255, 0.95);
  backdrop-filter: blur(4px);
  -webkit-backdrop-filter: blur(4px);
}
@media (hover: hover) {
  /* Hidden by default; slides up and fades in on chip hover while the
     chip content blurs and dims behind it. */
  .chipActions {
    opacity: 0;
    transform: translateY(100%);
    transition:
      opacity 0.2s ease-out,
      transform 0.2s ease-out;
  }
  .chip:hover .chipActions {
    opacity: 1;
    transform: translateY(0);
  }
  .chipContent {
    transition: filter 0.2s ease-out;
  }
  .chip:hover .chipContent {
    filter: blur(2px);
    opacity: 0.5;
  }
}

View File

@@ -0,0 +1,116 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import {
ArrowRightIcon,
EyeIcon,
ChatCircleDotsIcon,
} from "@phosphor-icons/react";
import NextLink from "next/link";
import { StatusBadge } from "@/app/(platform)/library/components/StatusBadge/StatusBadge";
import styles from "./PulseChips.module.css";
import type { PulseChipData } from "./types";
interface Props {
chips: PulseChipData[];
onChipClick?: (prompt: string) => void;
}
/**
 * Horizontal strip of agent-status chips for the empty copilot session.
 * Renders nothing at all when there is no agent activity to surface.
 *
 * @param chips - Pre-built chip view-models (see usePulseChips).
 * @param onChipClick - Receives a generated natural-language prompt when a
 *   chip's Ask action is clicked.
 */
export function PulseChips({ chips, onChipClick }: Props) {
  if (chips.length === 0) return null;
  return (
    <div
      className={`${styles.glassPanel} mx-[0.6875rem] mb-5 rounded-large p-5`}
    >
      <div className="mb-3 flex items-center gap-3">
        <Text variant="body-medium" className="text-zinc-600">
          What&apos;s happening with your agents
        </Text>
        <NextLink
          href="/library"
          className="flex items-center gap-1 text-xs text-zinc-500 hover:text-zinc-700"
        >
          View all <ArrowRightIcon size={12} />
        </NextLink>
      </div>
      {/* Horizontally scrollable row; each chip is fixed-width and non-shrinking. */}
      <div className="flex gap-2 overflow-x-auto pb-1 scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300">
        {chips.map((chip) => (
          <PulseChip key={chip.id} chip={chip} onAsk={onChipClick} />
        ))}
      </div>
    </div>
  );
}
interface ChipProps {
chip: PulseChipData;
onAsk?: (prompt: string) => void;
}
/** Single agent chip: status badge, name/message, and See/Ask actions. */
function PulseChip({ chip, onAsk }: ChipProps) {
  // Forward a chip-specific natural-language prompt to the chat input.
  function handleAsk() {
    const prompt = buildChipPrompt(chip);
    onAsk?.(prompt);
  }
  return (
    <div
      className={`${styles.chip} relative flex w-[15rem] shrink-0 flex-col items-start gap-2 rounded-medium border border-zinc-100 bg-white px-3 py-2`}
    >
      <div className={`${styles.chipContent} w-full text-left`}>
        {/* "success" priority gets a custom Completed pill; everything else
            falls back to the shared StatusBadge. */}
        {chip.priority === "success" ? (
          <span className="inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium text-emerald-600">
            <span className="h-1.5 w-1.5 rounded-full bg-emerald-500" />
            Completed
          </span>
        ) : (
          <StatusBadge status={chip.status} />
        )}
        <div className="mt-2 min-w-0">
          <Text variant="small-medium" className="truncate text-zinc-900">
            {chip.name}
          </Text>
          <Text variant="small" className="truncate text-zinc-500">
            {chip.shortMessage}
          </Text>
        </div>
      </div>
      {/* Action bar (styles.chipActions): revealed on hover for pointer
          devices, always visible on touch — see PulseChips.module.css. */}
      <div
        className={`${styles.chipActions} flex items-center justify-center gap-1.5 rounded-b-medium px-3 py-1.5`}
      >
        <NextLink
          href={`/library/agents/${chip.agentID}`}
          className="flex items-center gap-1 rounded-md px-2 py-1 text-xs text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
        >
          <EyeIcon size={14} />
          See
        </NextLink>
        <button
          type="button"
          onClick={handleAsk}
          className="flex items-center gap-1 rounded-md px-2 py-1 text-xs text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
        >
          <ChatCircleDotsIcon size={14} />
          Ask
        </button>
      </div>
    </div>
  );
}
/**
 * Build the natural-language prompt sent to the chat when a chip's Ask
 * action is clicked. A "success" priority takes precedence over status;
 * otherwise the status selects a tailored question, with a generic fallback.
 */
function buildChipPrompt(chip: PulseChipData): string {
  const { name, priority, status } = chip;
  if (priority === "success") {
    return `${name} just finished a run — can you summarize what it did?`;
  }
  const promptsByStatus: Record<string, string> = {
    error: `What happened with ${name}? It has an error — can you check?`,
    running: `Give me a status update on ${name} — what has it done so far?`,
    idle: `${name} hasn't run recently. Should I keep it or update and re-run it?`,
  };
  return promptsByStatus[status] ?? `Tell me about ${name} — what's its current status?`;
}

View File

@@ -0,0 +1,105 @@
import { describe, expect, test, vi } from "vitest";
import { render, screen, fireEvent } from "@/tests/integrations/test-utils";
import { PulseChips } from "../PulseChips";
import type { PulseChipData } from "../types";
// Build a fully-populated chip fixture; callers override only the fields
// relevant to the scenario under test.
function makeChip(overrides: Partial<PulseChipData> = {}): PulseChipData {
  const base: PulseChipData = {
    id: "chip-1",
    agentID: "agent-1",
    name: "Test Agent",
    status: "running",
    priority: "running",
    shortMessage: "Doing work…",
  };
  return { ...base, ...overrides };
}
// Covers: empty-state short-circuit, chip content rendering, section chrome,
// the Completed badge for success chips, the generated Ask prompts, and the
// See link target.
describe("PulseChips", () => {
  test("renders nothing when chips array is empty", () => {
    const { container } = render(<PulseChips chips={[]} />);
    expect(container.innerHTML).toBe("");
  });
  test("renders chip names and messages", () => {
    const chips = [
      makeChip({ id: "1", name: "Alpha Bot", shortMessage: "Running task A" }),
      makeChip({ id: "2", name: "Beta Bot", shortMessage: "Running task B" }),
    ];
    render(<PulseChips chips={chips} />);
    expect(screen.getByText("Alpha Bot")).toBeDefined();
    expect(screen.getByText("Running task A")).toBeDefined();
    expect(screen.getByText("Beta Bot")).toBeDefined();
    expect(screen.getByText("Running task B")).toBeDefined();
  });
  test("renders section heading and View all link", () => {
    render(<PulseChips chips={[makeChip()]} />);
    expect(screen.getByText("What's happening with your agents")).toBeDefined();
    expect(screen.getByText("View all")).toBeDefined();
  });
  test("shows Completed badge for success priority chips", () => {
    // priority takes precedence over status for the badge choice.
    render(
      <PulseChips
        chips={[makeChip({ priority: "success", status: "idle" })]}
      />,
    );
    expect(screen.getByText("Completed")).toBeDefined();
  });
  test("calls onChipClick with generated prompt when Ask is clicked", () => {
    const onChipClick = vi.fn();
    render(
      <PulseChips
        chips={[
          makeChip({
            name: "Error Agent",
            status: "error",
            priority: "error",
          }),
        ]}
        onChipClick={onChipClick}
      />,
    );
    fireEvent.click(screen.getByText("Ask"));
    expect(onChipClick).toHaveBeenCalledWith(
      "What happened with Error Agent? It has an error — can you check?",
    );
  });
  test("generates success prompt for completed chips", () => {
    const onChipClick = vi.fn();
    render(
      <PulseChips
        chips={[
          makeChip({
            name: "Done Agent",
            priority: "success",
            status: "idle",
          }),
        ]}
        onChipClick={onChipClick}
      />,
    );
    fireEvent.click(screen.getByText("Ask"));
    expect(onChipClick).toHaveBeenCalledWith(
      "Done Agent just finished a run — can you summarize what it did?",
    );
  });
  test("renders See link pointing to agent detail page", () => {
    render(<PulseChips chips={[makeChip({ agentID: "agent-xyz" })]} />);
    const seeLink = screen.getByText("See").closest("a");
    expect(seeLink?.getAttribute("href")).toBe("/library/agents/agent-xyz");
  });
});

View File

@@ -0,0 +1,13 @@
import type {
AgentStatus,
SitrepPriority,
} from "@/app/(platform)/library/types";
/** View-model for a single agent-status chip rendered by PulseChips. */
export interface PulseChipData {
  // Unique key for React list rendering.
  id: string;
  // Library agent the chip links to (/library/agents/:agentID).
  agentID: string;
  // Display name of the agent.
  name: string;
  status: AgentStatus;
  // Drives badge choice; "success" takes precedence over status.
  priority: SitrepPriority;
  // One-line summary shown under the agent name.
  shortMessage: string;
}

View File

@@ -0,0 +1,23 @@
"use client";
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
import { useSitrepItems } from "@/app/(platform)/library/components/SitrepItem/useSitrepItems";
import type { PulseChipData } from "./types";
import { useMemo } from "react";
/**
 * Derive the chip view-models for PulseChips from the user's library agents,
 * limited to the top 5 sitrep items. Memoized so consumers get a stable
 * array reference until the underlying sitrep items change.
 */
export function usePulseChips(): PulseChipData[] {
  const { agents } = useLibraryAgents();
  const sitrepItems = useSitrepItems(agents, 5);
  return useMemo(
    () =>
      sitrepItems.map(
        ({ id, agentID, agentName, status, priority, message }) => ({
          id,
          agentID,
          name: agentName,
          status,
          priority,
          shortMessage: message,
        }),
      ),
    [sitrepItems],
  );
}

View File

@@ -6,6 +6,9 @@ import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useRouter } from "next/navigation";
import { useEffect, useRef } from "react";
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
import { formatCents } from "../usageHelpers";
export { formatCents };
interface Props {
isOpen: boolean;
@@ -18,10 +21,6 @@ interface Props {
onCreditChange?: () => void;
}
export function formatCents(cents: number): string {
return `$${(cents / 100).toFixed(2)}`;
}
export function RateLimitResetDialog({
isOpen,
onClose,

View File

@@ -1,35 +1,10 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { Button } from "@/components/atoms/Button/Button";
import Link from "next/link";
import { formatCents } from "../RateLimitResetDialog/RateLimitResetDialog";
import { formatCents, formatResetTime } from "../usageHelpers";
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
export function formatResetTime(
resetsAt: Date | string,
now: Date = new Date(),
): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / (1000 * 60 * 60));
// Under 24h: show relative time ("in 4h 23m")
if (hours < 24) {
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
export { formatResetTime };
function UsageBar({
label,

View File

@@ -0,0 +1,28 @@
/** Format an integer amount of cents as a dollar string, e.g. 1234 → "$12.34". */
export function formatCents(cents: number): string {
  const dollars = (cents / 100).toFixed(2);
  return `$${dollars}`;
}
/**
 * Human-readable description of when a rate limit resets.
 *
 * - Already past: "now"
 * - Under 24h away: relative, e.g. "in 4h 23m" or "in 12m"
 * - 24h or more: local day/time, e.g. "Mon 12:00 AM PST"
 *
 * @param resetsAt - Reset instant as a Date or ISO string.
 * @param now - Reference time, injectable for testing (defaults to current time).
 */
export function formatResetTime(
  resetsAt: Date | string,
  now: Date = new Date(),
): string {
  const resetDate =
    typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
  const msRemaining = resetDate.getTime() - now.getTime();
  if (msRemaining <= 0) return "now";

  const HOUR_MS = 1000 * 60 * 60;
  const MINUTE_MS = 1000 * 60;
  const wholeHours = Math.floor(msRemaining / HOUR_MS);

  // 24h or more away: an absolute local timestamp reads better than "in 37h".
  if (wholeHours >= 24) {
    return resetDate.toLocaleString(undefined, {
      weekday: "short",
      hour: "numeric",
      minute: "2-digit",
      timeZoneName: "short",
    });
  }

  const wholeMinutes = Math.floor((msRemaining % HOUR_MS) / MINUTE_MS);
  return wholeHours > 0
    ? `in ${wholeHours}h ${wholeMinutes}m`
    : `in ${wholeMinutes}m`;
}

View File

@@ -0,0 +1,59 @@
import { describe, expect, it } from "vitest";
import { convertChatSessionMessagesToUiMessages } from "../convertChatSessionToUiMessages";
const SESSION_ID = "sess-test";
// Regression tests: user messages must survive conversion even with empty or
// null content (so the initial prompt is visible on session reload), while
// empty non-user messages are still dropped.
describe("convertChatSessionMessagesToUiMessages", () => {
  // Convert a single message with the given role/content.
  const convertOne = (role: string, content: string | null) =>
    convertChatSessionMessagesToUiMessages(
      SESSION_ID,
      [{ role, content, sequence: 0 }],
      { isComplete: true },
    );

  it("does not drop user messages with null content", () => {
    const result = convertOne("user", null);
    expect(result.messages).toHaveLength(1);
    expect(result.messages[0].role).toBe("user");
  });
  it("does not drop user messages with empty string content", () => {
    const result = convertOne("user", "");
    expect(result.messages).toHaveLength(1);
    expect(result.messages[0].role).toBe("user");
  });
  it("still drops non-user messages with null content", () => {
    expect(convertOne("assistant", null).messages).toHaveLength(0);
  });
  it("still drops non-user messages with empty string content", () => {
    expect(convertOne("assistant", "").messages).toHaveLength(0);
  });
  it("includes user message with normal content", () => {
    const result = convertOne("user", "hello");
    expect(result.messages).toHaveLength(1);
    expect(result.messages[0].role).toBe("user");
  });
});

View File

@@ -253,6 +253,11 @@ export function convertChatSessionMessagesToUiMessages(
}
}
// User messages must always be rendered, even with empty content, so the
// initial prompt is visible when reloading a session.
if (parts.length === 0 && msg.role === "user") {
parts.push({ type: "text", text: "", state: "done" });
}
if (parts.length === 0) return;
// Merge consecutive assistant messages into a single UIMessage

View File

@@ -86,6 +86,16 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
return sessionQuery.data.data.oldest_sequence ?? null;
}, [sessionQuery.data]);
const newestSequence = useMemo(() => {
if (sessionQuery.data?.status !== 200) return null;
return sessionQuery.data.data.newest_sequence ?? null;
}, [sessionQuery.data]);
const forwardPaginated = useMemo(() => {
if (sessionQuery.data?.status !== 200) return false;
return !!sessionQuery.data.data.forward_paginated;
}, [sessionQuery.data]);
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
// array reference every render. Re-derives only when query data changes.
// When the session is complete (no active stream), mark dangling tool
@@ -185,6 +195,8 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
hasActiveStream,
hasMoreMessages,
oldestSequence,
newestSequence,
forwardPaginated,
isLoadingSession: sessionQuery.isLoading,
isSessionError: sessionQuery.isError,
createSession,

View File

@@ -56,6 +56,8 @@ export function useCopilotPage() {
hasActiveStream,
hasMoreMessages,
oldestSequence,
newestSequence,
forwardPaginated,
isLoadingSession,
isSessionError,
createSession,
@@ -84,18 +86,26 @@ export function useCopilotPage() {
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
});
const { olderMessages, hasMore, isLoadingMore, loadMore } =
const { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged } =
useLoadMoreMessages({
sessionId,
initialOldestSequence: oldestSequence,
initialNewestSequence: newestSequence,
initialHasMore: hasMoreMessages,
forwardPaginated,
initialPageRawMessages: rawSessionMessages,
});
// Combine older (paginated) messages with current page messages,
// merging consecutive assistant UIMessages at the page boundary so
// reasoning + response parts stay in a single bubble.
const messages = concatWithAssistantMerge(olderMessages, currentMessages);
// Combine paginated messages with current page messages, merging consecutive
// assistant UIMessages at the page boundary so reasoning + response parts
// stay in a single bubble.
// Forward pagination (completed sessions): current page is the beginning,
// paged messages are newer pages appended after.
// Backward pagination (active sessions): paged messages are older history
// prepended before the current page.
const messages = forwardPaginated
? concatWithAssistantMerge(currentMessages, pagedMessages)
: concatWithAssistantMerge(pagedMessages, currentMessages);
useCopilotNotifications(sessionId);
@@ -170,6 +180,23 @@ export function useCopilotPage() {
}
}, [sessionId, pendingMessage, sendMessage]);
// --- Clear backward-paginated messages when session completes ---
// When a session transitions from active (forwardPaginated=false) to complete
// (forwardPaginated=true), any backward-paginated older messages would be
// appended after currentMessages instead of before, causing chronological
// disorder. Reset paged state so the completed session renders cleanly.
const prevForwardPaginatedRef = useRef(forwardPaginated);
useEffect(() => {
if (
!prevForwardPaginatedRef.current &&
forwardPaginated &&
pagedMessages.length > 0
) {
resetPaged();
}
prevForwardPaginatedRef.current = forwardPaginated;
}, [forwardPaginated, pagedMessages.length, resetPaged]);
// --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
useWorkflowImportAutoSubmit({
createSession,
@@ -251,6 +278,15 @@ export function useCopilotPage() {
isUserStoppingRef.current = false;
if (sessionId) {
// When continuing a completed session that had forward-paginated history
// loaded, the paged messages would appear in wrong position relative to
// the new streaming turn (pagedMessages are newer pages, so they'd end
// up after the streaming turn). Reset paged state so ordering is correct
// during streaming; the user can reload history afterward if needed.
if (forwardPaginated && pagedMessages.length > 0) {
resetPaged();
}
if (files && files.length > 0) {
setIsUploadingFiles(true);
try {
@@ -397,6 +433,7 @@ export function useCopilotPage() {
hasMoreMessages: hasMore,
isLoadingMore,
loadMore,
forwardPaginated,
// Mobile drawer
isMobile,
isDrawerOpen,

View File

@@ -9,7 +9,11 @@ import {
interface UseLoadMoreMessagesArgs {
sessionId: string | null;
initialOldestSequence: number | null;
initialNewestSequence: number | null;
initialHasMore: boolean;
/** True when the initial page was loaded from sequence 0 forward (completed
* sessions). False when loaded newest-first (active sessions). */
forwardPaginated: boolean;
/** Raw messages from the initial page, used for cross-page tool output matching. */
initialPageRawMessages: unknown[];
}
@@ -20,16 +24,21 @@ const MAX_OLDER_MESSAGES = 2000;
export function useLoadMoreMessages({
sessionId,
initialOldestSequence,
initialNewestSequence,
initialHasMore,
forwardPaginated,
initialPageRawMessages,
}: UseLoadMoreMessagesArgs) {
// Store accumulated raw messages from all older pages (in ascending order).
// Accumulated raw messages from all extra pages (ascending order).
// Re-converting them all together ensures tool outputs are matched across
// inter-page boundaries.
const [olderRawMessages, setOlderRawMessages] = useState<unknown[]>([]);
const [pagedRawMessages, setPagedRawMessages] = useState<unknown[]>([]);
const [oldestSequence, setOldestSequence] = useState<number | null>(
initialOldestSequence,
);
const [newestSequence, setNewestSequence] = useState<number | null>(
initialNewestSequence,
);
const [hasMore, setHasMore] = useState(initialHasMore);
const [isLoadingMore, setIsLoadingMore] = useState(false);
const isLoadingMoreRef = useRef(false);
@@ -46,7 +55,7 @@ export function useLoadMoreMessages({
// The parent's `initialOldestSequence` drifts forward every time the
// session query refetches (e.g. after a stream completes — see
// `useCopilotStream` invalidation on `streaming → ready`). If we
// wiped `olderRawMessages` every time that happened, users who had
// wiped `pagedRawMessages` every time that happened, users who had
// scrolled back would lose their loaded history on each new turn and
// subsequent `loadMore` calls would fetch messages that overlap with
// the AI SDK's retained state in `currentMessages`, producing visible
@@ -63,8 +72,9 @@ export function useLoadMoreMessages({
// Session changed — full reset
prevSessionIdRef.current = sessionId;
prevInitialOldestRef.current = initialOldestSequence;
setOlderRawMessages([]);
setPagedRawMessages([]);
setOldestSequence(initialOldestSequence);
setNewestSequence(initialNewestSequence);
setHasMore(initialHasMore);
setIsLoadingMore(false);
isLoadingMoreRef.current = false;
@@ -75,49 +85,64 @@ export function useLoadMoreMessages({
prevInitialOldestRef.current = initialOldestSequence;
// If we haven't paged back yet, mirror the parent so the first
// If we haven't paged yet, mirror the parent so the first
// `loadMore` starts from the correct cursor.
if (olderRawMessages.length === 0) {
//
// When paged messages exist (pagedRawMessages.length > 0) we intentionally
// do NOT update `hasMore` or `newestSequence` from the parent. A parent
// refetch (e.g. after a new turn completes) may carry a fresh
// `initialHasMore=true` or a larger `initialNewestSequence`, but those
// reflect the *initial* page window, not the forward-paged window we have
// already advanced into. Overwriting the local cursor here would cause the
// next `loadMore` to re-fetch pages we already have. The local cursor is
// advanced correctly inside `loadMore` itself via `setNewestSequence`.
if (pagedRawMessages.length === 0) {
setOldestSequence(initialOldestSequence);
// Only regress the forward cursor if we haven't paged ahead yet —
// otherwise a parent refetch would reset a cursor we already advanced.
setNewestSequence((prev) =>
prev !== null && prev > (initialNewestSequence ?? -1)
? prev
: initialNewestSequence,
);
setHasMore(initialHasMore);
}
}, [sessionId, initialOldestSequence, initialHasMore]);
}, [sessionId, initialOldestSequence, initialNewestSequence, initialHasMore]);
// Convert all accumulated raw messages in one pass so tool outputs
// are matched across inter-page boundaries. Initial page tool outputs
// are included via extraToolOutputs to handle the boundary between
// the last older page and the initial/streaming page.
const olderMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
// are matched across inter-page boundaries.
// For backward pagination only: include initial page tool outputs so older
// paged pages can match tool calls whose outputs landed in the initial page.
// For forward pagination this is unnecessary — tool calls in newer paged
// pages cannot have their outputs in the older initial page.
const pagedMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
useMemo(() => {
if (!sessionId || olderRawMessages.length === 0) return [];
if (!sessionId || pagedRawMessages.length === 0) return [];
const extraToolOutputs =
initialPageRawMessages.length > 0
!forwardPaginated && initialPageRawMessages.length > 0
? extractToolOutputsFromRaw(initialPageRawMessages)
: undefined;
return convertChatSessionMessagesToUiMessages(
sessionId,
olderRawMessages,
pagedRawMessages,
{ isComplete: true, extraToolOutputs },
).messages;
}, [sessionId, olderRawMessages, initialPageRawMessages]);
}, [sessionId, pagedRawMessages, initialPageRawMessages, forwardPaginated]);
async function loadMore() {
if (
!sessionId ||
!hasMore ||
isLoadingMoreRef.current ||
oldestSequence === null
)
return;
if (!sessionId || !hasMore || isLoadingMoreRef.current) return;
const cursor = forwardPaginated ? newestSequence : oldestSequence;
if (cursor === null) return;
const requestEpoch = epochRef.current;
isLoadingMoreRef.current = true;
setIsLoadingMore(true);
try {
const response = await getV2GetSession(sessionId, {
limit: 50,
before_sequence: oldestSequence,
});
const params = forwardPaginated
? { limit: 50, after_sequence: cursor }
: { limit: 50, before_sequence: cursor };
const response = await getV2GetSession(sessionId, params);
// Discard response if session/pagination was reset while awaiting
if (epochRef.current !== requestEpoch) return;
@@ -136,18 +161,66 @@ export function useLoadMoreMessages({
consecutiveErrorsRef.current = 0;
const newRaw = (response.data.messages ?? []) as unknown[];
setOlderRawMessages((prev) => {
const merged = [...newRaw, ...prev];
// Estimate total after merge using the closure-captured pagedRawMessages.length.
// This is a safe approximation: worst case it's one page stale (one extra load
// allowed), but it avoids the React-18-batching pitfall where a functional
// updater's mutations are not visible until the next render.
const estimatedTotal = pagedRawMessages.length + newRaw.length;
setPagedRawMessages((prev) => {
// Forward: append to end. Backward: prepend to start.
const merged = forwardPaginated
? [...prev, ...newRaw]
: [...newRaw, ...prev];
if (merged.length > MAX_OLDER_MESSAGES) {
return merged.slice(merged.length - MAX_OLDER_MESSAGES);
// Backward: discard the oldest (front) items — user has scrolled far
// back and we shed the furthest history.
// Forward: discard the newest (tail) items — we only ever fetch
// forward, so the tail is the most recently appended page; shedding
// it means the sentinel stalls, which is safer than discarding the
// beginning of the conversation the user is here to read.
return forwardPaginated
? merged.slice(0, MAX_OLDER_MESSAGES)
: merged.slice(merged.length - MAX_OLDER_MESSAGES);
}
return merged;
});
setOldestSequence(response.data.oldest_sequence ?? null);
if (newRaw.length + olderRawMessages.length >= MAX_OLDER_MESSAGES) {
setHasMore(false);
if (forwardPaginated) {
const willTruncateForward = estimatedTotal > MAX_OLDER_MESSAGES;
if (willTruncateForward) {
// Truncation shed the newest tail. Advance the cursor to the last KEPT
// item's sequence so the sentinel re-fetches the discarded items next
// time rather than jumping past them.
// lastKeptIdx: index within newRaw of the last item that survives.
// prev contributes pagedRawMessages.length items; total kept = MAX.
const lastKeptIdx = MAX_OLDER_MESSAGES - 1 - pagedRawMessages.length;
if (lastKeptIdx >= 0 && lastKeptIdx < newRaw.length) {
const lastKeptMsg = newRaw[lastKeptIdx] as { sequence?: number };
if (typeof lastKeptMsg?.sequence === "number") {
setNewestSequence(lastKeptMsg.sequence);
setHasMore(true); // Discarded items still exist — keep sentinel active
} else {
// Sequence unavailable — fall back; truncated items will be lost
setNewestSequence(response.data.newest_sequence ?? null);
setHasMore(!!response.data.has_more_messages);
}
} else {
// All of newRaw was dropped (already at MAX_OLDER_MESSAGES cap).
// Stop to avoid an infinite re-fetch loop at the display cap.
setHasMore(false);
}
} else {
setNewestSequence(response.data.newest_sequence ?? null);
setHasMore(!!response.data.has_more_messages);
}
} else {
setHasMore(!!response.data.has_more_messages);
setOldestSequence(response.data.oldest_sequence ?? null);
if (estimatedTotal >= MAX_OLDER_MESSAGES) {
// Backward: accumulated MAX_OLDER_MESSAGES — stop to avoid unbounded memory.
setHasMore(false);
} else {
setHasMore(!!response.data.has_more_messages);
}
}
} catch (error) {
if (epochRef.current !== requestEpoch) return;
@@ -164,5 +237,22 @@ export function useLoadMoreMessages({
}
}
return { olderMessages, hasMore, isLoadingMore, loadMore };
// Reset all pagination state back to its initial values. Intended to be
// called by the consumer on a session change/restart so that stale pages,
// cursors, and in-flight loads from the previous session cannot leak into
// the new one.
function resetPaged() {
  setPagedRawMessages([]);
  // Restore both cursors to the parent-supplied initial sequences.
  setOldestSequence(initialOldestSequence);
  setNewestSequence(initialNewestSequence);
  // Set hasMore=false during the session-transition window so no loadMore
  // fires with forward pagination (after_sequence) on the now-active session.
  // The useEffect will restore hasMore from the parent after the refetch
  // completes and forwardPaginated switches to false.
  setHasMore(false);
  // Clear the loading state so the spinner doesn't stay stuck if a loadMore
  // was in flight when resetPaged was called.
  setIsLoadingMore(false);
  isLoadingMoreRef.current = false;
  consecutiveErrorsRef.current = 0;
  // Bumping the epoch invalidates any in-flight loadMore: its response is
  // discarded when epochRef.current no longer matches the captured epoch.
  epochRef.current += 1;
}
return { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged };
}

View File

@@ -2,14 +2,17 @@ import { Navbar } from "@/components/layout/Navbar/Navbar";
import { NetworkStatusMonitor } from "@/services/network-status/NetworkStatusMonitor";
import { ReactNode } from "react";
import { AdminImpersonationBanner } from "./admin/components/AdminImpersonationBanner";
import { AutoPilotBridgeProvider } from "@/contexts/AutoPilotBridgeContext";
export default function PlatformLayout({ children }: { children: ReactNode }) {
return (
<main className="flex h-screen w-full flex-col">
<NetworkStatusMonitor />
<Navbar />
<AdminImpersonationBanner />
<section className="flex-1">{children}</section>
</main>
<AutoPilotBridgeProvider>
<main className="flex h-screen w-full flex-col">
<NetworkStatusMonitor />
<Navbar />
<AdminImpersonationBanner />
<section className="flex-1">{children}</section>
</main>
</AutoPilotBridgeProvider>
);
}

View File

@@ -137,8 +137,10 @@ describe("LibraryPage", () => {
user_id: "test-user",
name: "Work Agents",
agent_count: 3,
subfolder_count: 0,
color: null,
icon: null,
parent_id: null,
created_at: new Date(),
updated_at: new Date(),
},
@@ -147,8 +149,10 @@ describe("LibraryPage", () => {
user_id: "test-user",
name: "Personal",
agent_count: 1,
subfolder_count: 0,
color: null,
icon: null,
parent_id: null,
created_at: new Date(),
updated_at: new Date(),
},
@@ -158,12 +162,14 @@ describe("LibraryPage", () => {
render(<LibraryPage />);
await waitForAgentsToLoad();
expect(await screen.findByText("Work Agents")).toBeDefined();
expect(screen.getByText("Personal")).toBeDefined();
expect(screen.getAllByTestId("library-folder")).toHaveLength(2);
});
test("shows See runs link on agent card", async () => {
test("shows See tasks link on agent card", async () => {
setupHandlers({
agents: [makeAgent({ name: "Linked Agent", can_access_graph: true })],
});
@@ -172,7 +178,7 @@ describe("LibraryPage", () => {
await screen.findByText("Linked Agent");
const runLinks = screen.getAllByText("See runs");
const runLinks = screen.getAllByText("See tasks");
expect(runLinks.length).toBeGreaterThan(0);
});
@@ -190,7 +196,7 @@ describe("LibraryPage", () => {
expect(importButtons.length).toBeGreaterThan(0);
});
test("renders Jump Back In when there is an active execution", async () => {
test("renders running agent card when execution is active", async () => {
const agent = makeAgent({
id: "lib-1",
graph_id: "g-1",
@@ -218,6 +224,6 @@ describe("LibraryPage", () => {
render(<LibraryPage />);
expect(await screen.findByText("Jump Back In")).toBeDefined();
expect(await screen.findByText("Running Agent")).toBeDefined();
});
});

View File

@@ -0,0 +1,44 @@
/* Animated gradient border for the briefing panel.
   `isolation: isolate` creates a stacking context so the z-index:-1
   pseudo-element below sits behind the panel's content without falling
   behind the rest of the page. */
.glassPanel {
  position: relative;
  isolation: isolate;
}

/* 1px rotating conic-gradient ring. The gradient fills the whole box; the
   two mask layers (content-box XOR full box) punch out the interior so only
   the 1px `padding` frame of the gradient remains visible. */
.glassPanel::before {
  content: "";
  position: absolute;
  inset: 0;
  border-radius: inherit;
  padding: 1px;
  background: conic-gradient(
    from var(--border-angle, 0deg),
    rgba(129, 120, 228, 0.04),
    rgba(129, 120, 228, 0.14),
    rgba(168, 130, 255, 0.09),
    rgba(129, 120, 228, 0.04),
    rgba(99, 102, 241, 0.12),
    rgba(129, 120, 228, 0.04)
  );
  -webkit-mask:
    linear-gradient(#000 0 0) content-box,
    linear-gradient(#000 0 0);
  mask:
    linear-gradient(#000 0 0) content-box,
    linear-gradient(#000 0 0);
  -webkit-mask-composite: xor;
  mask-composite: exclude;
  animation: rotate-border 6s linear infinite;
  /* Decorative only — never intercept clicks meant for the panel. */
  pointer-events: none;
  z-index: -1;
}

/* Registering --border-angle as a typed custom property lets the keyframe
   animation interpolate it; untyped custom properties cannot animate. */
@property --border-angle {
  syntax: "<angle>";
  initial-value: 0deg;
  inherits: false;
}

@keyframes rotate-border {
  to {
    --border-angle: 360deg;
  }
}

View File

@@ -0,0 +1,36 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useState } from "react";
import type { FleetSummary, AgentStatusFilter } from "../../types";
import { BriefingTabContent } from "./BriefingTabContent";
import { StatsGrid } from "./StatsGrid";
import styles from "./AgentBriefingPanel.module.css";
interface Props {
summary: FleetSummary;
agents: LibraryAgent[];
}
/**
 * Top-of-library briefing card: a stats grid of fleet metrics plus a tab
 * body that reflects the selected metric.
 */
export function AgentBriefingPanel({ summary, agents }: Props) {
  // Tab the user explicitly clicked; null until they pick one.
  const [selectedTab, setSelectedTab] = useState<AgentStatusFilter | null>(
    null,
  );

  // Default view: surface "running" whenever agents are active, otherwise
  // fall back to the overview tab.
  let activeTab: AgentStatusFilter;
  if (selectedTab !== null) {
    activeTab = selectedTab;
  } else {
    activeTab = summary.running > 0 ? "running" : "all";
  }

  return (
    <div
      className={`${styles.glassPanel} min-h-[14.75rem] rounded-large bg-gradient-to-br from-indigo-50/30 via-white/90 to-purple-50/25 px-5 pb-5 pt-[1.125rem] shadow-sm backdrop-blur-md`}
    >
      <Text variant="h5">Agent Briefing</Text>
      <div className="mt-4 space-y-5">
        <StatsGrid
          summary={summary}
          activeTab={activeTab}
          onTabChange={setSelectedTab}
        />
        <BriefingTabContent activeTab={activeTab} agents={agents} />
      </div>
    </div>
  );
}

View File

@@ -0,0 +1,347 @@
"use client";
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
import {
formatResetTime,
formatCents,
} from "@/app/(platform)/copilot/components/usageHelpers";
import { useResetRateLimit } from "@/app/(platform)/copilot/hooks/useResetRateLimit";
import { Button } from "@/components/atoms/Button/Button";
import { Badge } from "@/components/atoms/Badge/Badge";
import useCredits from "@/hooks/useCredits";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useSitrepItems } from "../SitrepItem/useSitrepItems";
import { SitrepItem } from "../SitrepItem/SitrepItem";
import { useAgentStatusMap } from "../../hooks/useAgentStatus";
import type { AgentStatusFilter } from "../../types";
import { Text } from "@/components/atoms/Text/Text";
import Link from "next/link";
import { useState } from "react";
interface Props {
activeTab: AgentStatusFilter;
agents: LibraryAgent[];
}
/**
 * Dispatches the briefing tab body: usage meters for the overview tab,
 * execution lists for run-oriented tabs, and agent lists for the rest.
 */
export function BriefingTabContent({ activeTab, agents }: Props) {
  if (activeTab === "all") {
    return <UsageSection />;
  }

  const executionTabs: AgentStatusFilter[] = [
    "running",
    "attention",
    "completed",
  ];
  if (executionTabs.includes(activeTab)) {
    return <ExecutionListSection activeTab={activeTab} agents={agents} />;
  }

  return <AgentListSection activeTab={activeTab} agents={agents} />;
}
// Overview ("all") tab body: copilot usage meters plus the plan badge and
// billing link. Renders nothing until both usage windows have loaded.
function UsageSection() {
  // Poll usage every 30s; treat data as fresh for 10s to avoid refetch storms.
  const { data: usage } = useGetV2GetCopilotUsage({
    query: {
      select: (res) => res.data as CoPilotUsageStatus,
      refetchInterval: 30000,
      staleTime: 10000,
    },
  });
  const isBillingEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
  const { credits, fetchCredits } = useCredits({ fetchInitialCredits: true });
  const resetCost = usage?.reset_cost;
  // `!= null` deliberately covers both null and undefined reset costs.
  const hasInsufficientCredits =
    credits !== null && resetCost != null && credits < resetCost;
  // All hooks above run unconditionally; only the render bails out early.
  if (!usage?.daily || !usage?.weekly) return null;
  return (
    <div className="py-2">
      <div className="flex items-center gap-2">
        <Text variant="h5" className="text-neutral-800">
          Usage limits
        </Text>
        {usage.tier && (
          <Badge variant="info" size="small" className="bg-[rgb(224,237,255)]">
            {usage.tier.charAt(0) + usage.tier.slice(1).toLowerCase()} plan
          </Badge>
        )}
        <div className="flex-1" />
        {isBillingEnabled && (
          <Link
            href="/profile/credits"
            className="text-sm text-blue-600 hover:underline"
          >
            Manage billing
          </Link>
        )}
      </div>
      <div className="mt-4 grid grid-cols-1 gap-6 sm:grid-cols-2">
        {/* Meters are hidden for non-positive limits (unlimited/disabled). */}
        {usage.daily.limit > 0 && (
          <UsageMeter
            label="Today"
            used={usage.daily.used}
            limit={usage.daily.limit}
            resetsAt={usage.daily.resets_at}
          />
        )}
        {usage.weekly.limit > 0 && (
          <UsageMeter
            label="This week"
            used={usage.weekly.used}
            limit={usage.weekly.limit}
            resetsAt={usage.weekly.resets_at}
          />
        )}
      </div>
      <UsageFooter
        usage={usage}
        hasInsufficientCredits={hasInsufficientCredits}
        onCreditChange={fetchCredits}
      />
    </div>
  );
}
// Number of items shown before the "Show all" toggle appears.
const MAX_VISIBLE = 6;

/** Lists recent sitrep items whose priority matches the active tab. */
function ExecutionListSection({
  activeTab,
  agents,
}: {
  activeTab: AgentStatusFilter;
  agents: LibraryAgent[];
}) {
  const allItems = useSitrepItems(agents, 50);
  const [expanded, setExpanded] = useState(false);

  // Each execution-oriented tab maps to exactly one sitrep priority.
  const tabToPriority: Record<string, string> = {
    running: "running",
    attention: "error",
    completed: "success",
  };
  const filtered = allItems.filter(
    (item) => item.priority === tabToPriority[activeTab],
  );

  if (filtered.length === 0) {
    return <EmptyMessage tab={activeTab} />;
  }

  const shown = expanded ? filtered : filtered.slice(0, MAX_VISIBLE);
  const hasOverflow = filtered.length > MAX_VISIBLE;

  return (
    <div>
      <div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
        {shown.map((item) => (
          <SitrepItem key={item.id} item={item} />
        ))}
      </div>
      {hasOverflow && (
        <div className="mt-3 flex justify-center">
          <Button
            variant="secondary"
            size="small"
            onClick={() => setExpanded(!expanded)}
          >
            {expanded ? "Collapse" : `Show all (${filtered.length})`}
          </Button>
        </div>
      )}
    </div>
  );
}
// Human-readable status line shown on each agent tile, keyed by tab.
const TAB_STATUS_LABEL: Record<string, string> = {
  listening: "Waiting for trigger event",
  scheduled: "Has a scheduled run",
  idle: "No recent activity",
};

/** Lists agents whose computed status matches the active (non-execution) tab. */
function AgentListSection({
  activeTab,
  agents,
}: {
  activeTab: AgentStatusFilter;
  agents: LibraryAgent[];
}) {
  const [expanded, setExpanded] = useState(false);
  const statusMap = useAgentStatusMap(agents);

  const matching = agents.filter((agent) => {
    const agentStatus = statusMap.get(agent.graph_id)?.status;
    return (
      (activeTab === "listening" && agentStatus === "listening") ||
      (activeTab === "scheduled" && agentStatus === "scheduled") ||
      (activeTab === "idle" && agentStatus === "idle")
    );
  });

  if (matching.length === 0) {
    return <EmptyMessage tab={activeTab} />;
  }

  // Narrow the tab into the status literal used by SitrepItem.
  const status: "listening" | "scheduled" | "idle" =
    activeTab === "listening" || activeTab === "scheduled"
      ? activeTab
      : "idle";

  const shown = expanded ? matching : matching.slice(0, MAX_VISIBLE);
  const hasOverflow = matching.length > MAX_VISIBLE;

  return (
    <div>
      <div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
        {shown.map((agent) => (
          <SitrepItem
            key={agent.id}
            item={{
              id: agent.id,
              agentID: agent.id,
              agentName: agent.name,
              agentImageUrl: agent.image_url,
              priority: status,
              message: TAB_STATUS_LABEL[activeTab] ?? "",
              status,
            }}
          />
        ))}
      </div>
      {hasOverflow && (
        <div className="mt-3 flex justify-center">
          <Button
            variant="secondary"
            size="small"
            onClick={() => setExpanded(!expanded)}
          >
            {expanded ? "Collapse" : `Show all (${matching.length})`}
          </Button>
        </div>
      )}
    </div>
  );
}
/**
 * Footer action row for the usage section: a paid daily-limit reset button,
 * or an "add credits" link when the user cannot afford the reset.
 * Renders nothing when no action applies.
 */
function UsageFooter({
  usage,
  hasInsufficientCredits,
  onCreditChange,
}: {
  usage: CoPilotUsageStatus;
  hasInsufficientCredits: boolean;
  onCreditChange?: () => void;
}) {
  // A window counts as exhausted only when a positive limit has been reached.
  const dailyMaxed =
    usage.daily.limit > 0 && usage.daily.used >= usage.daily.limit;
  const weeklyMaxed =
    usage.weekly.limit > 0 && usage.weekly.used >= usage.weekly.limit;
  const resetCost = usage.reset_cost ?? 0;
  const { resetUsage, isPending } = useResetRateLimit({ onCreditChange });

  // A reset only makes sense when the day is spent but the week is not.
  const canPayForReset =
    dailyMaxed && !weeklyMaxed && resetCost > 0 && !hasInsufficientCredits;
  const needsMoreCredits =
    dailyMaxed && !weeklyMaxed && hasInsufficientCredits;

  if (!canPayForReset && !needsMoreCredits) return null;

  return (
    <div className="mt-4 flex items-center gap-3">
      {canPayForReset && (
        <Button
          variant="primary"
          size="small"
          onClick={() => resetUsage()}
          loading={isPending}
        >
          {isPending
            ? "Resetting..."
            : `Reset daily limit for ${formatCents(resetCost)}`}
        </Button>
      )}
      {needsMoreCredits && (
        <Link
          href="/profile/credits"
          className="inline-flex items-center justify-center rounded-md bg-primary px-3 py-1.5 text-sm font-medium text-primary-foreground hover:bg-primary/90"
        >
          Add credits to reset
        </Link>
      )}
    </div>
  );
}
/** Progress bar for one usage window (e.g. "Today"), with reset time. */
function UsageMeter({
  label,
  used,
  limit,
  resetsAt,
}: {
  label: string;
  used: number;
  limit: number;
  resetsAt: Date | string;
}) {
  // Defensive: a non-positive quota renders nothing (callers already filter).
  if (limit <= 0) return null;

  const percent = Math.min(100, Math.round((used / limit) * 100));
  const isHigh = percent >= 80;

  // Non-zero usage that rounds down to 0% is labelled "<1%" instead of "0%".
  let percentLabel: string;
  if (used > 0 && percent === 0) {
    percentLabel = "<1% used";
  } else {
    percentLabel = `${percent}% used`;
  }

  return (
    <div className="flex flex-col gap-2">
      <div className="flex items-baseline justify-between">
        <Text variant="body-medium" className="text-neutral-700">
          {label}
        </Text>
        <Text variant="body" className="tabular-nums text-neutral-500">
          {percentLabel}
        </Text>
      </div>
      <div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
        {/* Keep at least a 1%-wide sliver whenever anything has been used. */}
        <div
          className={`h-full rounded-full transition-[width] duration-300 ease-out ${
            isHigh ? "bg-orange-500" : "bg-blue-500"
          }`}
          style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
        />
      </div>
      <div className="flex items-baseline justify-between">
        <Text variant="small" className="tabular-nums text-neutral-500">
          {used.toLocaleString()} / {limit.toLocaleString()}
        </Text>
        <Text variant="small" className="text-neutral-400">
          Resets {formatResetTime(resetsAt)}
        </Text>
      </div>
    </div>
  );
}
// Per-tab empty-state copy, with a generic fallback for unknown tabs.
const EMPTY_MESSAGES: Record<string, string> = {
  running: "No agents running right now",
  attention: "No agents that need attention",
  completed: "No recently completed runs",
  listening: "No agents listening for events",
  scheduled: "No agents with scheduled runs",
  idle: "No idle agents",
};

/** Centered placeholder shown when a tab has nothing to list. */
function EmptyMessage({ tab }: { tab: AgentStatusFilter }) {
  const message = EMPTY_MESSAGES[tab] ?? "No agents in this category";
  return (
    <div className="flex items-center justify-center pt-4">
      <Text variant="body-medium" className="text-zinc-600">
        {message}
      </Text>
    </div>
  );
}

View File

@@ -0,0 +1,102 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import { OverflowText } from "@/components/atoms/OverflowText/OverflowText";
import { Emoji } from "@/components/atoms/Emoji/Emoji";
import { cn } from "@/lib/utils";
import type { FleetSummary, AgentStatusFilter } from "../../types";
interface Props {
summary: FleetSummary;
activeTab: AgentStatusFilter;
onTabChange: (tab: AgentStatusFilter) => void;
}
// One tile per fleet metric. `key` selects the value from FleetSummary,
// `filter` is the tab activated when the tile is clicked, `emoji`/`color`
// style the tile, and the optional `format` turns the raw number into a
// display string (e.g. a dollar amount).
const TILES: {
  label: string;
  key: keyof FleetSummary;
  format?: (v: number) => string;
  filter: AgentStatusFilter;
  emoji: string;
  color: string;
}[] = [
  {
    label: "Spent this month",
    key: "monthlySpend",
    format: (v) => `$${v.toLocaleString()}`,
    filter: "all",
    emoji: "💵",
    color: "text-zinc-700",
  },
  {
    label: "Running now",
    key: "running",
    filter: "running",
    emoji: "🚩",
    color: "text-blue-600",
  },
  {
    label: "Recently completed",
    key: "completed",
    filter: "completed",
    emoji: "🗃️",
    color: "text-green-600",
  },
  {
    label: "Needs attention",
    key: "error",
    filter: "attention",
    emoji: "⚠️",
    color: "text-red-500",
  },
  {
    label: "Scheduled",
    key: "scheduled",
    filter: "scheduled",
    emoji: "📅",
    color: "text-yellow-600",
  },
  {
    label: "Idle",
    key: "idle",
    filter: "idle",
    emoji: "💤",
    color: "text-zinc-400",
  },
];
/**
 * Clickable grid of fleet-metric tiles; clicking a tile switches the
 * briefing tab via `onTabChange`.
 */
export function StatsGrid({ summary, activeTab, onTabChange }: Props) {
  return (
    <div className="grid grid-cols-1 gap-3 min-[450px]:grid-cols-2 sm:grid-cols-3 lg:grid-cols-6">
      {TILES.map((tile) => {
        const count = summary[tile.key];
        const display = tile.format ? tile.format(count) : count;
        const selected = activeTab === tile.filter;
        return (
          <button
            key={tile.label}
            type="button"
            onClick={() => onTabChange(tile.filter)}
            className={cn(
              "flex min-w-0 flex-col gap-1 rounded-medium border p-3 text-left shadow-md transition-all hover:shadow-lg",
              selected
                ? "border-zinc-900 bg-zinc-50"
                : "border-zinc-100 bg-white",
            )}
          >
            <div className="flex min-w-0 items-center gap-1.5">
              <Emoji text={tile.emoji} size={18} />
              <OverflowText
                value={tile.label}
                variant="body"
                className="text-zinc-800"
              />
            </div>
            <Text variant="h4">{display}</Text>
          </button>
        );
      })}
    </div>
  );
}

View File

@@ -0,0 +1,52 @@
"use client";
import type { SelectOption } from "@/components/atoms/Select/Select";
import { Select } from "@/components/atoms/Select/Select";
import { FunnelIcon } from "@phosphor-icons/react";
import type { AgentStatusFilter, FleetSummary } from "../../types";
interface Props {
value: AgentStatusFilter;
onChange: (value: AgentStatusFilter) => void;
summary: FleetSummary;
}
/** Builds the filter dropdown's options, embedding live counts per status. */
function buildOptions(summary: FleetSummary): SelectOption[] {
  const { running, error, listening, scheduled, idle } = summary;
  return [
    { value: "all", label: "All Agents" },
    { value: "running", label: `Running (${running})` },
    { value: "attention", label: `Needs Attention (${error})` },
    { value: "listening", label: `Listening (${listening})` },
    { value: "scheduled", label: `Scheduled (${scheduled})` },
    { value: "idle", label: `Idle / Stale (${idle})` },
    { value: "healthy", label: "Healthy" },
  ];
}
export function AgentFilterMenu({ value, onChange, summary }: Props) {
function handleChange(val: string) {
onChange(val as AgentStatusFilter);
}
const options = buildOptions(summary);
return (
<div className="flex items-center" data-testid="agent-filter-dropdown">
<span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
filter
</span>
<FunnelIcon className="ml-1 h-4 w-4 sm:hidden" />
<Select
id="agent-status-filter"
label="Filter agents"
hideLabel
value={value}
onValueChange={handleChange}
options={options}
size="small"
className="ml-1 w-fit border-none !bg-transparent text-sm underline underline-offset-4 shadow-none"
wrapperClassName="mb-0"
/>
</div>
);
}

View File

@@ -0,0 +1,68 @@
"use client";
import {
EyeIcon,
ArrowsClockwiseIcon,
MonitorPlayIcon,
PlayIcon,
ArrowCounterClockwiseIcon,
} from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { useRouter } from "next/navigation";
import type { AgentStatus } from "../../types";
interface Props {
status: AgentStatus;
agentID: string;
executionID?: string;
className?: string;
}
/**
 * Small status-dependent action button (e.g. "Watch live", "Start") that
 * navigates to the agent's library page, optionally deep-linking to an
 * execution via the `activeItem` query param.
 */
export function ContextualActionButton({
  status,
  agentID,
  executionID,
  className,
}: Props) {
  const router = useRouter();

  // Statuses without a config entry render no action at all.
  const config = ACTION_CONFIG[status];
  if (!config) return null;
  const Icon = config.icon;

  function handleClick(event: React.MouseEvent) {
    // Stop the click from also triggering the parent card's navigation.
    event.preventDefault();
    event.stopPropagation();
    const search = new URLSearchParams();
    if (executionID) search.set("activeItem", executionID);
    const query = search.toString();
    const suffix = query ? `?${query}` : "";
    router.push(`/library/agents/${agentID}${suffix}`);
  }

  return (
    <button
      type="button"
      onClick={handleClick}
      className={cn(
        "inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800",
        className,
      )}
    >
      <Icon size={12} className="shrink-0" />
      {config.label}
    </button>
  );
}
// Label + icon for the contextual action shown per agent status. Statuses
// without an entry render no button.
const ACTION_CONFIG: Record<
  AgentStatus,
  { label: string; icon: typeof EyeIcon }
> = {
  error: { label: "View error", icon: EyeIcon },
  listening: { label: "Reconnect", icon: ArrowsClockwiseIcon },
  running: { label: "Watch live", icon: MonitorPlayIcon },
  idle: { label: "Start", icon: PlayIcon },
  scheduled: { label: "Start", icon: ArrowCounterClockwiseIcon },
};

View File

@@ -1,46 +0,0 @@
"use client";
import { ArrowRight, Lightning } from "@phosphor-icons/react";
import NextLink from "next/link";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { useJumpBackIn } from "./useJumpBackIn";
export function JumpBackIn() {
const { execution, isLoading } = useJumpBackIn();
if (isLoading || !execution) {
return null;
}
const href = execution.libraryAgentId
? `/library/agents/${execution.libraryAgentId}?activeTab=runs&activeItem=${execution.id}`
: "#";
return (
<div className="rounded-large bg-gradient-to-r from-zinc-200 via-zinc-200/60 to-indigo-200/50 p-[1px]">
<div className="flex items-center justify-between rounded-large bg-[#F6F7F8] px-5 py-4">
<div className="flex items-center gap-3">
<div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
<Lightning size={18} weight="fill" className="text-white" />
</div>
<div className="flex flex-col">
<Text variant="small" className="text-zinc-500">
{execution.statusLabel} · {execution.duration}
</Text>
<Text variant="body-medium" className="text-zinc-900">
{execution.agentName}
</Text>
</div>
</div>
<NextLink href={href}>
<Button variant="secondary" size="small" className="gap-1.5">
Jump Back In
<ArrowRight size={16} />
</Button>
</NextLink>
</div>
</div>
);
}

View File

@@ -1,82 +0,0 @@
"use client";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { okData } from "@/app/api/helpers";
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
import { useMemo } from "react";
function isActive(status: AgentExecutionStatus) {
return (
status === AgentExecutionStatus.RUNNING ||
status === AgentExecutionStatus.QUEUED ||
status === AgentExecutionStatus.REVIEW
);
}
/**
 * Human-friendly elapsed time since `startedAt` (e.g. "42s", "3m 10s",
 * "2h 5m"). Returns "" for missing, unparseable, or future start times.
 */
function formatDuration(startedAt: Date | string | null | undefined): string {
  if (!startedAt) return "";
  const startMs = new Date(startedAt).getTime();
  if (Number.isNaN(startMs)) return "";

  const elapsedMs = Date.now() - startMs;
  if (elapsedMs < 0) return "";

  const totalSec = Math.floor(elapsedMs / 1000);
  if (totalSec < 5) return "a few seconds";
  if (totalSec < 60) return `${totalSec}s`;

  const totalMin = Math.floor(totalSec / 60);
  if (totalMin < 60) return `${totalMin}m ${totalSec % 60}s`;

  const totalHr = Math.floor(totalMin / 60);
  return `${totalHr}h ${totalMin % 60}m`;
}
function getStatusLabel(status: AgentExecutionStatus) {
if (status === AgentExecutionStatus.RUNNING) return "Running";
if (status === AgentExecutionStatus.QUEUED) return "Queued";
if (status === AgentExecutionStatus.REVIEW) return "Awaiting approval";
return "";
}
export function useJumpBackIn() {
const { data: executions, isLoading: executionsLoading } =
useGetV1ListAllExecutions({
query: { select: okData },
});
const { agentInfoMap, isRefreshing: agentsLoading } = useLibraryAgents();
const activeExecution = useMemo(() => {
if (!executions) return null;
const active = executions
.filter((e) => isActive(e.status))
.sort((a, b) => {
const aTime = a.started_at ? new Date(a.started_at).getTime() : 0;
const bTime = b.started_at ? new Date(b.started_at).getTime() : 0;
return bTime - aTime;
});
return active[0] ?? null;
}, [executions]);
const enriched = useMemo(() => {
if (!activeExecution) return null;
const info = agentInfoMap.get(activeExecution.graph_id);
return {
id: activeExecution.id,
agentName: info?.name ?? "Unknown Agent",
libraryAgentId: info?.library_agent_id,
status: activeExecution.status,
statusLabel: getStatusLabel(activeExecution.status),
duration: formatDuration(activeExecution.started_at),
};
}, [activeExecution, agentInfoMap]);
return {
execution: enriched,
isLoading: executionsLoading || agentsLoading,
};
}

View File

@@ -8,7 +8,7 @@ interface Props {
export function LibraryActionHeader({ setSearchTerm }: Props) {
return (
<>
<div className="mb-[32px] hidden items-center justify-center gap-4 md:flex">
<div className="mb-7 hidden items-center justify-center gap-4 md:flex">
<LibrarySearchBar setSearchTerm={setSearchTerm} />
<LibraryImportDialog />
</div>

View File

@@ -1,29 +1,40 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import { CaretCircleRightIcon } from "@phosphor-icons/react";
import { EyeIcon, ChatCircleDotsIcon } from "@phosphor-icons/react";
import Image from "next/image";
import NextLink from "next/link";
import { useRouter } from "next/navigation";
import { motion } from "framer-motion";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import Avatar, {
AvatarFallback,
AvatarImage,
} from "@/components/atoms/Avatar/Avatar";
import { Link } from "@/components/atoms/Link/Link";
import { cn } from "@/lib/utils";
import { AgentCardMenu } from "./components/AgentCardMenu";
import { FavoriteButton } from "./components/FavoriteButton";
import { useLibraryAgentCard } from "./useLibraryAgentCard";
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
import { StatusBadge } from "../StatusBadge/StatusBadge";
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
import type { AgentStatusInfo } from "../../types";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
interface Props {
agent: LibraryAgent;
statusInfo: AgentStatusInfo;
draggable?: boolean;
}
export function LibraryAgentCard({ agent, draggable = true }: Props) {
const { id, name, graph_id, can_access_graph, image_url } = agent;
export function LibraryAgentCard({
agent,
statusInfo,
draggable = true,
}: Props) {
const { id, name, image_url } = agent;
const router = useRouter();
const { triggerFavoriteAnimation } = useFavoriteAnimation();
function handleDragStart(e: React.DragEvent<HTMLDivElement>) {
@@ -31,18 +42,14 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
e.dataTransfer.effectAllowed = "move";
}
const {
isFromMarketplace,
isFavorite,
profile,
creator_image_url,
handleToggleFavorite,
} = useLibraryAgentCard({
const { isFavorite, handleToggleFavorite } = useLibraryAgentCard({
agent,
onFavoriteAdd: triggerFavoriteAnimation,
});
return (
const hasError = statusInfo.status === "error";
const card = (
<div
draggable={draggable}
onDragStart={handleDragStart}
@@ -52,7 +59,10 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
layoutId={`agent-card-${id}`}
data-testid="library-agent-card"
data-agent-id={id}
className="group relative inline-flex h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border border-zinc-100 bg-white hover:shadow-md"
className={cn(
"group relative inline-flex h-auto min-h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border bg-white hover:shadow-md",
hasError ? "border-red-400" : "border-zinc-100",
)}
transition={{
type: "spring",
damping: 25,
@@ -61,23 +71,10 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
style={{ willChange: "transform" }}
>
<NextLink href={`/library/agents/${id}`} className="flex-shrink-0">
<div className="relative flex items-center gap-2 px-4 pt-3">
<Avatar className="h-4 w-4 rounded-full">
<AvatarImage
src={
isFromMarketplace
? creator_image_url || "/avatar-placeholder.png"
: profile?.avatar_url || "/avatar-placeholder.png"
}
alt={`${name} creator avatar`}
/>
<AvatarFallback size={48}>{name.charAt(0)}</AvatarFallback>
</Avatar>
<Text
variant="small-medium"
className="uppercase tracking-wide text-zinc-400"
>
{isFromMarketplace ? "FROM MARKETPLACE" : "Built by you"}
<div className="relative flex items-center gap-3 pl-2 pr-4 pt-3">
<StatusBadge status={statusInfo.status} />
<Text variant="small" className="text-zinc-400">
{statusInfo.totalRuns} tasks
</Text>
</div>
</NextLink>
@@ -89,7 +86,7 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
<AgentCardMenu agent={agent} />
<div className="flex w-full flex-1 flex-col px-4 pb-2">
<Link
<NextLink
href={`/library/agents/${id}`}
className="flex w-full items-start justify-between gap-2 no-underline hover:no-underline focus:ring-0"
>
@@ -126,30 +123,52 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
className="flex-shrink-0 rounded-small object-cover"
/>
)}
</Link>
</NextLink>
<div className="mt-auto flex w-full justify-start gap-6 border-t border-zinc-100 pb-1 pt-3">
<Link
href={`/library/agents/${id}`}
<div className="mt-4 flex w-full items-center justify-end gap-1 border-t border-zinc-100 pb-0 pt-2">
<button
type="button"
onClick={() => router.push(`/library/agents/${id}`)}
data-testid="library-agent-card-see-runs-link"
className="flex items-center gap-1 text-[13px]"
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
>
See runs <CaretCircleRightIcon size={20} />
</Link>
{can_access_graph && (
<Link
href={`/build?flowID=${graph_id}`}
data-testid="library-agent-card-open-in-builder-link"
className="flex items-center gap-1 text-[13px]"
isExternal
>
Open in builder <CaretCircleRightIcon size={20} />
</Link>
)}
<EyeIcon size={14} className="shrink-0" />
See tasks
</button>
<ContextualActionButton
status={statusInfo.status}
agentID={id}
executionID={statusInfo.activeExecutionID ?? undefined}
/>
<button
type="button"
onClick={() => {
const prompt = encodeURIComponent(
`Tell me about ${name}, its current status, recent runs and how can I get the most out of it`,
);
router.push(`/copilot?autosubmit=true#prompt=${prompt}`);
}}
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
>
<ChatCircleDotsIcon size={14} className="shrink-0" />
Chat
</button>
</div>
</div>
</motion.div>
</div>
);
if (hasError && statusInfo.lastError) {
return (
<Tooltip>
<TooltipTrigger asChild>{card}</TooltipTrigger>
<TooltipContent className="max-w-xs text-red-600">
{statusInfo.lastError}
</TooltipContent>
</Tooltip>
);
}
return card;
}

View File

@@ -169,6 +169,7 @@ export function AgentCardMenu({ agent }: AgentCardMenuProps) {
href={`/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`}
target="_blank"
className="flex items-center gap-2"
data-testid="library-agent-card-open-in-builder-link"
onClick={(e) => e.stopPropagation()}
>
Edit agent

View File

@@ -1,6 +1,7 @@
"use client";
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
import { LibraryAgentCard } from "../LibraryAgentCard/LibraryAgentCard";
@@ -16,8 +17,11 @@ import {
} from "framer-motion";
import { LibraryFolderEditDialog } from "../LibraryFolderEditDialog/LibraryFolderEditDialog";
import { LibraryFolderDeleteDialog } from "../LibraryFolderDeleteDialog/LibraryFolderDeleteDialog";
import { LibraryTab } from "../../types";
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
import { useLibraryAgentList } from "./useLibraryAgentList";
import { AgentBriefingPanel } from "../AgentBriefingPanel/AgentBriefingPanel";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useAgentStatusMap, getAgentStatus } from "../../hooks/useAgentStatus";
// cancels the current spring and starts a new one from current state.
const containerVariants = {
@@ -70,6 +74,10 @@ interface Props {
tabs: LibraryTab[];
activeTab: string;
onTabChange: (tabId: string) => void;
statusFilter?: AgentStatusFilter;
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
fleetSummary?: FleetSummary;
briefingAgents?: LibraryAgent[];
}
export function LibraryAgentList({
@@ -81,7 +89,12 @@ export function LibraryAgentList({
tabs,
activeTab,
onTabChange,
statusFilter = "all",
onStatusFilterChange,
fleetSummary,
briefingAgents,
}: Props) {
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
const shouldReduceMotion = useReducedMotion();
const activeContainerVariants = shouldReduceMotion
? reducedContainerVariants
@@ -95,7 +108,7 @@ export function LibraryAgentList({
const {
isFavoritesTab,
agentLoading,
allAgentsCount,
displayedCount,
favoritesCount,
agents,
hasNextPage,
@@ -116,18 +129,37 @@ export function LibraryAgentList({
selectedFolderId,
onFolderSelect,
activeTab,
statusFilter,
});
const agentStatusMap = useAgentStatusMap(agents);
return (
<>
{isAgentBriefingEnabled &&
!selectedFolderId &&
fleetSummary &&
briefingAgents &&
briefingAgents.length > 0 && (
<div className="mb-4">
<AgentBriefingPanel
summary={fleetSummary}
agents={briefingAgents}
/>
</div>
)}
{!selectedFolderId && (
<LibrarySubSection
tabs={tabs}
activeTab={activeTab}
onTabChange={onTabChange}
allCount={allAgentsCount}
allCount={displayedCount}
favoritesCount={favoritesCount}
setLibrarySort={setLibrarySort}
statusFilter={statusFilter}
onStatusFilterChange={onStatusFilterChange}
fleetSummary={fleetSummary}
/>
)}
@@ -219,7 +251,13 @@ export function LibraryAgentList({
0.04,
}}
>
<LibraryAgentCard agent={agent} />
<LibraryAgentCard
agent={agent}
statusInfo={getAgentStatus(
agentStatusMap,
agent.graph_id,
)}
/>
</motion.div>
))}
</motion.div>

View File

@@ -21,7 +21,12 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
import { useFavoriteAgents } from "../../hooks/useFavoriteAgents";
import { getQueryClient } from "@/lib/react-query/queryClient";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef, useState } from "react";
import { useEffect, useMemo, useRef, useState } from "react";
import type { AgentStatusFilter } from "../../types";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
const FILTER_EXHAUST_THRESHOLD = 3;
interface Props {
searchTerm: string;
@@ -29,6 +34,7 @@ interface Props {
selectedFolderId: string | null;
onFolderSelect: (folderId: string | null) => void;
activeTab: string;
statusFilter?: AgentStatusFilter;
}
export function useLibraryAgentList({
@@ -37,12 +43,16 @@ export function useLibraryAgentList({
selectedFolderId,
onFolderSelect,
activeTab,
statusFilter = "all",
}: Props) {
const isFavoritesTab = activeTab === "favorites";
const { toast } = useToast();
const stableQueryClient = getQueryClient();
const queryClient = useQueryClient();
const prevSortRef = useRef<LibraryAgentSort | null>(null);
const [consecutiveEmptyPages, setConsecutiveEmptyPages] = useState(0);
const prevFilteredLengthRef = useRef(0);
const prevAgentsLengthRef = useRef(0);
const [editingFolder, setEditingFolder] = useState<LibraryFolder | null>(
null,
@@ -199,6 +209,90 @@ export function useLibraryAgentList({
const showFolders = !isFavoritesTab;
const { data: executions } = useGetV1ListAllExecutions({
query: { select: okData },
});
const { activeGraphIds, errorGraphIds, completedGraphIds } = useMemo(() => {
const active = new Set<string>();
const errors = new Set<string>();
const completed = new Set<string>();
const cutoff = Date.now() - 72 * 60 * 60 * 1000;
for (const exec of executions ?? []) {
if (
exec.status === AgentExecutionStatus.RUNNING ||
exec.status === AgentExecutionStatus.QUEUED ||
exec.status === AgentExecutionStatus.REVIEW
) {
active.add(exec.graph_id);
}
const endedTs = exec.ended_at
? exec.ended_at instanceof Date
? exec.ended_at.getTime()
: new Date(String(exec.ended_at)).getTime()
: 0;
if (
(exec.status === AgentExecutionStatus.FAILED ||
exec.status === AgentExecutionStatus.TERMINATED) &&
endedTs > cutoff
) {
errors.add(exec.graph_id);
}
if (exec.status === AgentExecutionStatus.COMPLETED && endedTs > cutoff) {
completed.add(exec.graph_id);
}
}
return {
activeGraphIds: active,
errorGraphIds: errors,
completedGraphIds: completed,
};
}, [executions]);
const filteredAgents = filterAgentsByStatus(
agents,
statusFilter,
activeGraphIds,
errorGraphIds,
completedGraphIds,
);
useEffect(() => {
if (statusFilter === "all") {
setConsecutiveEmptyPages(0);
prevFilteredLengthRef.current = filteredAgents.length;
prevAgentsLengthRef.current = agents.length;
return;
}
if (agents.length > prevAgentsLengthRef.current) {
const newFilteredCount = filteredAgents.length;
const previousCount = prevFilteredLengthRef.current;
if (newFilteredCount > previousCount) {
setConsecutiveEmptyPages(0);
} else {
setConsecutiveEmptyPages((prev) => prev + 1);
}
}
prevAgentsLengthRef.current = agents.length;
prevFilteredLengthRef.current = filteredAgents.length;
}, [agents.length, filteredAgents.length, statusFilter]);
useEffect(() => {
setConsecutiveEmptyPages(0);
prevFilteredLengthRef.current = 0;
prevAgentsLengthRef.current = 0;
}, [statusFilter]);
const filteredExhausted =
statusFilter !== "all" && consecutiveEmptyPages >= FILTER_EXHAUST_THRESHOLD;
// When a filter is active, show the filtered count instead of the API total.
const displayedCount =
statusFilter === "all" ? allAgentsCount : filteredAgents.length;
function handleFolderDeleted() {
if (selectedFolderId === deletingFolder?.id) {
onFolderSelect(null);
@@ -210,9 +304,10 @@ export function useLibraryAgentList({
agentLoading,
agentCount,
allAgentsCount,
displayedCount,
favoritesCount: favoriteAgentsData.agentCount,
agents,
hasNextPage: agentsHasNextPage,
agents: filteredAgents,
hasNextPage: agentsHasNextPage && !filteredExhausted,
isFetchingNextPage: agentsIsFetchingNextPage,
fetchNextPage: agentsFetchNextPage,
foldersData,
@@ -226,3 +321,46 @@ export function useLibraryAgentList({
handleFolderDeleted,
};
}
/**
 * Filter agents by the selected status, using the pre-computed graph-ID sets
 * for running/errored/completed executions. "all" returns the input unchanged.
 */
function filterAgentsByStatus<
  T extends {
    graph_id: string;
    has_external_trigger: boolean;
    recommended_schedule_cron?: string | null;
  },
>(
  agents: T[],
  statusFilter: AgentStatusFilter,
  activeGraphIds: Set<string>,
  errorGraphIds: Set<string>,
  completedGraphIds: Set<string>,
): T[] {
  if (statusFilter === "all") return agents;

  return agents.filter((agent) => {
    const running = activeGraphIds.has(agent.graph_id);
    const failing = errorGraphIds.has(agent.graph_id);

    switch (statusFilter) {
      case "running":
        return running;
      case "attention":
        // Errors only matter when the agent isn't already running again.
        return failing && !running;
      case "completed":
        return completedGraphIds.has(agent.graph_id);
      case "listening":
        return !running && !failing && agent.has_external_trigger;
      case "scheduled":
        return (
          !running &&
          !failing &&
          !agent.has_external_trigger &&
          !!agent.recommended_schedule_cron
        );
      case "idle":
        return (
          !running &&
          !failing &&
          !agent.has_external_trigger &&
          !agent.recommended_schedule_cron
        );
      case "healthy":
        return !failing;
      default:
        return true;
    }
  });
}

View File

@@ -2,14 +2,11 @@
import { Text } from "@/components/atoms/Text/Text";
import { Button } from "@/components/atoms/Button/Button";
import {
FolderIcon,
FolderColor,
folderCardStyles,
resolveColor,
} from "./FolderIcon";
import { FolderIcon, FolderColor } from "./FolderIcon";
import { useState } from "react";
import { PencilSimpleIcon, TrashIcon } from "@phosphor-icons/react";
import type { AgentStatus } from "../../types";
import { StatusBadge } from "../StatusBadge/StatusBadge";
interface Props {
id: string;
@@ -21,6 +18,8 @@ interface Props {
onDelete?: () => void;
onAgentDrop?: (agentId: string, folderId: string) => void;
onClick?: () => void;
/** Worst status among child agents (optional, for status aggregation). */
worstStatus?: AgentStatus;
}
export function LibraryFolder({
@@ -33,11 +32,10 @@ export function LibraryFolder({
onDelete,
onAgentDrop,
onClick,
worstStatus,
}: Props) {
const [isHovered, setIsHovered] = useState(false);
const [isDragOver, setIsDragOver] = useState(false);
const resolvedColor = resolveColor(color);
const cardStyle = folderCardStyles[resolvedColor];
function handleDragOver(e: React.DragEvent<HTMLDivElement>) {
if (e.dataTransfer.types.includes("application/agent-id")) {
@@ -64,10 +62,10 @@ export function LibraryFolder({
<div
data-testid="library-folder"
data-folder-id={id}
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 transition-all duration-200 hover:shadow-md ${
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 shadow-sm backdrop-blur-md transition-all duration-200 hover:shadow-md ${
isDragOver
? "border-blue-400 bg-blue-50 ring-2 ring-blue-200"
: `${cardStyle.border} ${cardStyle.bg}`
: "border-indigo-200/40 bg-gradient-to-br from-indigo-50/40 via-white/70 to-purple-50/30"
}`}
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
@@ -76,7 +74,7 @@ export function LibraryFolder({
onDrop={handleDrop}
onClick={onClick}
>
<div className="flex w-full items-start justify-between gap-4">
<div className="flex w-full items-center justify-between gap-4">
{/* Left side - Folder name and agent count */}
<div className="flex flex-1 flex-col gap-2">
<Text
@@ -86,17 +84,22 @@ export function LibraryFolder({
>
{name}
</Text>
<Text
variant="small"
className="text-zinc-500"
data-testid="library-folder-agent-count"
>
{agentCount} {agentCount === 1 ? "agent" : "agents"}
</Text>
<div className="flex items-center gap-2">
<Text
variant="small"
className="text-zinc-500"
data-testid="library-folder-agent-count"
>
{agentCount} {agentCount === 1 ? "agent" : "agents"}
</Text>
{worstStatus && worstStatus !== "idle" && (
<StatusBadge status={worstStatus} />
)}
</div>
</div>
{/* Right side - Custom folder icon */}
<div className="flex-shrink-0">
<div className="relative top-5 flex flex-shrink-0 items-center">
<FolderIcon isOpen={isHovered} color={color} icon={icon} />
</div>
</div>
@@ -114,7 +117,7 @@ export function LibraryFolder({
e.stopPropagation();
onEdit?.();
}}
className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
>
<PencilSimpleIcon className="h-4 w-4" />
</Button>
@@ -126,7 +129,7 @@ export function LibraryFolder({
e.stopPropagation();
onDelete?.();
}}
className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
>
<TrashIcon className="h-4 w-4" />
</Button>

View File

@@ -19,11 +19,11 @@ export function LibrarySortMenu({ setLibrarySort }: Props) {
const { handleSortChange } = useLibrarySortMenu({ setLibrarySort });
return (
<div className="flex items-center" data-testid="sort-by-dropdown">
<span className="hidden whitespace-nowrap text-sm sm:inline">
<span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
sort by
</span>
<Select onValueChange={handleSortChange}>
<SelectTrigger className="ml-1 w-fit space-x-1 border-none px-0 text-sm underline underline-offset-4 shadow-none">
<SelectTrigger className="!m-0 ml-1 w-fit space-x-1 border-none !bg-transparent px-[1rem] text-sm underline underline-offset-4 !shadow-none !ring-offset-transparent">
<ArrowDownNarrowWideIcon className="h-4 w-4 sm:hidden" />
<SelectValue placeholder="Last Modified" />
</SelectTrigger>

View File

@@ -6,9 +6,10 @@ import {
} from "@/components/molecules/TabsLine/TabsLine";
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
import { LibraryTab } from "../../types";
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
import LibraryFolderCreationDialog from "../LibraryFolderCreationDialog/LibraryFolderCreationDialog";
import { LibrarySortMenu } from "../LibrarySortMenu/LibrarySortMenu";
import { AgentFilterMenu } from "../AgentFilterMenu/AgentFilterMenu";
interface Props {
tabs: LibraryTab[];
@@ -17,6 +18,9 @@ interface Props {
allCount: number;
favoritesCount: number;
setLibrarySort: (value: LibraryAgentSort) => void;
statusFilter?: AgentStatusFilter;
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
fleetSummary?: FleetSummary;
}
export function LibrarySubSection({
@@ -26,6 +30,9 @@ export function LibrarySubSection({
allCount,
favoritesCount,
setLibrarySort,
statusFilter = "all",
onStatusFilterChange,
fleetSummary,
}: Props) {
const { registerFavoritesTabRef } = useFavoriteAnimation();
const favoritesRef = useRef<HTMLButtonElement>(null);
@@ -68,8 +75,15 @@ export function LibrarySubSection({
))}
</TabsLineList>
</TabsLine>
<div className="hidden items-center gap-6 md:flex">
<div className="relative top-1.5 hidden items-center gap-6 md:flex">
<LibraryFolderCreationDialog />
{fleetSummary && onStatusFilterChange && (
<AgentFilterMenu
value={statusFilter}
onChange={onStatusFilterChange}
summary={fleetSummary}
/>
)}
<LibrarySortMenu setLibrarySort={setLibrarySort} />
</div>
</div>

View File

@@ -0,0 +1,17 @@
/* Ring spinner drawn entirely with CSS gradients:
   - the radial-gradient paints a 3px "head" dot at the top edge,
   - the conic-gradient paints the fading ring tail,
   - the mask hollows out the center, leaving a 3px-thick ring,
   - the keyframes rotate the whole element once per second. */
.spinner {
  aspect-ratio: 1;
  border-radius: 50%;
  background:
    radial-gradient(farthest-side, currentColor 94%, #0000) top/3px 3px
      no-repeat,
    conic-gradient(#0000 30%, currentColor);
  /* -webkit- prefix kept for older Safari mask support */
  -webkit-mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
  mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
  animation: spin 1s infinite linear;
}

@keyframes spin {
  100% {
    transform: rotate(1turn);
  }
}

View File

@@ -0,0 +1,172 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import {
WarningCircleIcon,
ClockCountdownIcon,
CheckCircleIcon,
ChatCircleDotsIcon,
EarIcon,
CalendarDotsIcon,
MoonIcon,
EyeIcon,
} from "@phosphor-icons/react";
import NextLink from "next/link";
import { cn } from "@/lib/utils";
import { useRouter } from "next/navigation";
import type { SitrepItemData, SitrepPriority } from "../../types";
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
import styles from "./SitrepItem.module.css";
// Props for one situation-report row.
interface Props {
  item: SitrepItemData;
}

// Visual treatment per sitrep priority: an icon (or a pure-CSS spinner for
// "running"), its text color class, and the icon badge background class.
const PRIORITY_CONFIG: Record<
  SitrepPriority,
  {
    icon?: typeof WarningCircleIcon;
    color: string;
    bg: string;
    cssSpinner?: boolean;
  }
> = {
  error: {
    icon: WarningCircleIcon,
    color: "text-red-500",
    bg: "bg-red-50",
  },
  // "running" renders the animated CSS spinner instead of a static icon.
  running: {
    color: "text-zinc-800",
    bg: "",
    cssSpinner: true,
  },
  stale: {
    icon: ClockCountdownIcon,
    color: "text-yellow-600",
    bg: "bg-yellow-50",
  },
  success: {
    icon: CheckCircleIcon,
    color: "text-green-600",
    bg: "bg-green-50",
  },
  listening: {
    icon: EarIcon,
    color: "text-purple-500",
    bg: "bg-purple-50",
  },
  scheduled: {
    icon: CalendarDotsIcon,
    color: "text-yellow-600",
    bg: "bg-yellow-50",
  },
  idle: {
    icon: MoonIcon,
    color: "text-zinc-400",
    bg: "bg-zinc-100",
  },
};
// One row of the situation report: agent image (or status icon/spinner),
// agent name + status message, and contextual action buttons on the right.
export function SitrepItem({ item }: Props) {
  const config = PRIORITY_CONFIG[item.priority];
  const router = useRouter();

  // Navigate to the copilot with a prefilled, auto-submitted prompt about
  // this item (prompt is carried in the URL fragment, see buildAutoPilotPrompt).
  function handleAskAutoPilot() {
    const prompt = buildAutoPilotPrompt(item);
    const encoded = encodeURIComponent(prompt);
    router.push(`/copilot?autosubmit=true#prompt=${encoded}`);
  }

  return (
    <div
      className={cn(
        "flex flex-col gap-2 rounded-medium border border-zinc-200/50 bg-transparent p-2 sm:flex-row sm:items-center sm:gap-3",
      )}
    >
      <div className="flex min-w-0 flex-1 items-center gap-3">
        {item.agentImageUrl ? (
          <img
            src={item.agentImageUrl}
            alt={item.agentName}
            className="h-6 w-6 flex-shrink-0 rounded-full object-cover"
          />
        ) : (
          <div
            className={cn(
              "flex h-6 w-6 flex-shrink-0 items-center justify-center rounded-full",
              config.bg,
            )}
          >
            {config.cssSpinner ? (
              <div
                className={cn(
                  styles.spinner,
                  "h-[21px] w-[21px] text-zinc-800",
                )}
              />
            ) : (
              config.icon && (
                <config.icon size={14} className={config.color} weight="fill" />
              )
            )}
          </div>
        )}
        <div className="min-w-0 flex-1">
          <Text variant="body-medium" className="leading-tight text-zinc-900">
            {item.agentName}
          </Text>
          <Text variant="small" className="leading-tight text-zinc-500">
            {item.message}
          </Text>
        </div>
      </div>
      <div className="flex flex-shrink-0 flex-wrap items-center justify-center gap-1.5 sm:flex-nowrap sm:justify-end">
        {item.priority === "success" ? (
          <NextLink
            href={`/library/agents/${item.agentID}${item.executionID ? `?activeItem=${item.executionID}` : ""}`}
            className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
          >
            <EyeIcon size={14} className="shrink-0" />
            See task
          </NextLink>
        ) : (
          <ContextualActionButton
            status={item.status}
            agentID={item.agentID}
            executionID={item.executionID}
          />
        )}
        <button
          type="button"
          onClick={handleAskAutoPilot}
          className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
        >
          <ChatCircleDotsIcon size={14} className="shrink-0" />
          Ask AutoPilot
        </button>
      </div>
    </div>
  );
}
function buildAutoPilotPrompt(item: SitrepItemData): string {
switch (item.priority) {
case "error":
return `What happened with ${item.agentName}? It says "${item.message}" — can you check the logs and tell me what to fix?`;
case "running":
return `Give me a status update on the ${item.agentName} run — what has it found so far?`;
case "stale":
return `${item.agentName} hasn't run recently. Should I keep it or update and re-run it?`;
case "success":
return `Show me what ${item.agentName} found in its last run — summarize the results and any key takeaways.`;
case "listening":
return `What is ${item.agentName} listening for? Give me a summary of its trigger configuration.`;
case "scheduled":
return `When is ${item.agentName} scheduled to run next?`;
case "idle":
return `${item.agentName} has been idle. Should I keep it or update and re-run it?`;
}
}

View File

@@ -0,0 +1,34 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import { ClockCounterClockwise } from "@phosphor-icons/react";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useSitrepItems } from "./useSitrepItems";
import { SitrepItem } from "./SitrepItem";
interface Props {
agents: LibraryAgent[];
maxItems?: number;
}
/**
 * "Recent tasks" panel: renders the prioritized sitrep entries for the given
 * agents in a two-column grid, or nothing at all when there is no entry.
 */
export function SitrepList({ agents, maxItems = 10 }: Props) {
  const sitrepEntries = useSitrepItems(agents, maxItems);

  if (!sitrepEntries.length) return null;

  return (
    <div>
      <div className="mb-2 flex items-center gap-1.5">
        <ClockCounterClockwise size={16} className="text-zinc-700" />
        <Text variant="body-medium" className="text-zinc-700">
          Recent tasks
        </Text>
      </div>
      <div className="grid grid-cols-1 gap-1 lg:grid-cols-2">
        {sitrepEntries.map((entry) => (
          <SitrepItem key={entry.id} item={entry} />
        ))}
      </div>
    </div>
  );
}

View File

@@ -0,0 +1,133 @@
"use client";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import { useMemo } from "react";
import type { SitrepItemData, SitrepPriority } from "../../types";
import {
isActive,
isFailed,
toEndTime,
endedAfter,
runningMessage,
SEVENTY_TWO_HOURS_MS,
} from "../../hooks/executionHelpers";
// Turn the agent list plus the user's global execution history into a
// prioritized, capped list of sitrep entries (most urgent first).
export function useSitrepItems(
  agents: LibraryAgent[],
  maxItems: number,
): SitrepItemData[] {
  // All executions across the user's graphs, unwrapped from the API envelope.
  const { data: executions } = useGetV1ListAllExecutions({
    query: { select: okData },
  });
  return useMemo(() => {
    if (!executions || agents.length === 0) return [];
    const graphIdToAgent = new Map(agents.map((a) => [a.graph_id, a]));
    const agentExecutions = groupByAgent(executions, graphIdToAgent);
    // At most one sitrep entry per agent.
    const items: SitrepItemData[] = [];
    for (const [agent, execs] of agentExecutions) {
      const item = buildSitrepFromExecutions(agent, execs);
      if (item) items.push(item);
    }
    // Lower number = more urgent; errors surface first, idle agents last.
    const order: Record<SitrepPriority, number> = {
      error: 0,
      running: 1,
      stale: 2,
      success: 3,
      listening: 4,
      scheduled: 5,
      idle: 6,
    };
    items.sort((a, b) => order[a.priority] - order[b.priority]);
    return items.slice(0, maxItems);
  }, [agents, executions, maxItems]);
}
/**
 * Bucket executions under their owning library agent. Executions whose
 * graph_id has no matching agent in the lookup are silently dropped.
 */
function groupByAgent(
  executions: GraphExecutionMeta[],
  graphIdToAgent: Map<string, LibraryAgent>,
): Map<LibraryAgent, GraphExecutionMeta[]> {
  const grouped = new Map<LibraryAgent, GraphExecutionMeta[]>();
  for (const execution of executions) {
    const owner = graphIdToAgent.get(execution.graph_id);
    if (!owner) continue;
    let bucket = grouped.get(owner);
    if (!bucket) {
      bucket = [];
      grouped.set(owner, bucket);
    }
    bucket.push(execution);
  }
  return grouped;
}
/**
 * Build the sitrep entry for one agent, or null when there is nothing worth
 * reporting.
 *
 * Precedence: an active execution ("running") beats a recent failure
 * ("error"), which beats a recent success ("success"). "Recent" means the
 * execution ended within the last 72 hours.
 */
function buildSitrepFromExecutions(
  agent: LibraryAgent,
  executions: GraphExecutionMeta[],
): SitrepItemData | null {
  const active = executions.find((e) => isActive(e.status));
  if (active) {
    // stats.activity_status is loosely typed — only use it when it is an
    // actual string, same guard as the failed/completed branches below
    // (previously a non-string truthy value would be passed through raw).
    const activity = active.stats?.activity_status;
    return {
      id: `${agent.id}-${active.id}`,
      agentID: agent.id,
      agentName: agent.name,
      executionID: active.id,
      priority: "running",
      message:
        typeof activity === "string"
          ? activity
          : runningMessage(active.status, active.started_at),
      status: "running",
    };
  }
  // Recent (<72h), most-recently-ended first.
  const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
  const recent = executions
    .filter((e) => endedAfter(e, cutoff))
    .sort((a, b) => toEndTime(b) - toEndTime(a));
  const lastFailed = recent.find((e) => isFailed(e.status));
  if (lastFailed) {
    const errorMsg =
      lastFailed.stats?.error ??
      lastFailed.stats?.activity_status ??
      "Execution failed";
    return {
      id: `${agent.id}-${lastFailed.id}`,
      agentID: agent.id,
      agentName: agent.name,
      executionID: lastFailed.id,
      priority: "error",
      message: typeof errorMsg === "string" ? errorMsg : "Execution failed",
      status: "error",
    };
  }
  const lastCompleted = recent.find(
    (e) => e.status === AgentExecutionStatus.COMPLETED,
  );
  if (lastCompleted) {
    const summary =
      lastCompleted.stats?.activity_status ?? "Completed successfully";
    return {
      id: `${agent.id}-${lastCompleted.id}`,
      agentID: agent.id,
      agentName: agent.name,
      executionID: lastCompleted.id,
      priority: "success",
      message: typeof summary === "string" ? summary : "Completed successfully",
      status: "idle",
    };
  }
  // No active, failed, or recently-completed execution — nothing to show.
  return null;
}

View File

@@ -0,0 +1,84 @@
"use client";
import { cn } from "@/lib/utils";
import type { AgentStatus } from "../../types";
// Badge styling per agent status: label text, optional background class,
// text color class, and whether the status dot should pulse.
// Note: bg is currently "" for every status; the field is kept so a
// background can be introduced per-status without changing the shape.
const STATUS_CONFIG: Record<
  AgentStatus,
  { label: string; bg: string; text: string; pulse: boolean }
> = {
  running: {
    label: "Running",
    bg: "",
    text: "text-blue-600",
    pulse: true,
  },
  error: {
    label: "Error",
    bg: "",
    text: "text-red-500",
    pulse: false,
  },
  listening: {
    label: "Listening",
    bg: "",
    text: "text-purple-500",
    pulse: true,
  },
  scheduled: {
    label: "Scheduled",
    bg: "",
    text: "text-yellow-600",
    pulse: false,
  },
  idle: {
    label: "Idle",
    bg: "",
    text: "text-zinc-500",
    pulse: false,
  },
};
interface Props {
  // Status to render; drives label, colors, and dot animation.
  status: AgentStatus;
  className?: string;
}

// Small pill badge showing an agent's status: a colored (optionally pulsing)
// dot followed by the status label.
export function StatusBadge({ status, className }: Props) {
  const config = STATUS_CONFIG[status];
  return (
    <span
      className={cn(
        "inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium",
        config.bg,
        config.text,
        className,
      )}
    >
      <span
        className={cn(
          "inline-block h-1.5 w-1.5 rounded-full",
          config.pulse && "animate-pulse",
          statusDotColor(status),
        )}
      />
      {config.label}
    </span>
  );
}
/** Tailwind background class for the status dot of a given agent status. */
function statusDotColor(status: AgentStatus): string {
  // Exhaustive record lookup — the Record type forces a color for every
  // AgentStatus member, replacing the switch.
  const dotColors: Record<AgentStatus, string> = {
    running: "bg-blue-500",
    error: "bg-red-500",
    listening: "bg-purple-500",
    scheduled: "bg-yellow-500",
    idle: "bg-zinc-400",
  };
  return dotColors[status];
}

View File

@@ -0,0 +1,59 @@
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
export const SEVENTY_TWO_HOURS_MS = 72 * 60 * 60 * 1000;
/** True while an execution is queued, running, or awaiting human review. */
export function isActive(status: string): boolean {
  switch (status) {
    case AgentExecutionStatus.RUNNING:
    case AgentExecutionStatus.QUEUED:
    case AgentExecutionStatus.REVIEW:
      return true;
    default:
      return false;
  }
}
/** True for failed outcomes: explicit failures and terminated runs. */
export function isFailed(status: string): boolean {
  switch (status) {
    case AgentExecutionStatus.FAILED:
    case AgentExecutionStatus.TERMINATED:
      return true;
    default:
      return false;
  }
}
/** Normalize ended_at (Date | string | absent) to epoch millis; 0 = never ended. */
export function toEndTime(exec: GraphExecutionMeta): number {
  const ended = exec.ended_at;
  if (!ended) return 0;
  if (ended instanceof Date) return ended.getTime();
  return new Date(ended).getTime();
}
/** True when the execution actually ended and did so after `cutoff` (millis). */
export function endedAfter(exec: GraphExecutionMeta, cutoff: number): boolean {
  // Executions without an end timestamp can never be "after" any cutoff.
  return exec.ended_at ? toEndTime(exec) > cutoff : false;
}
/**
 * Human-readable message for an active execution: fixed labels for queued /
 * in-review states, otherwise the elapsed runtime since `startedAt`.
 */
export function runningMessage(
  status: string,
  startedAt?: string | Date | null,
): string {
  if (status === AgentExecutionStatus.QUEUED) return "Queued for execution";
  if (status === AgentExecutionStatus.REVIEW) return "Awaiting review";
  // Without a start timestamp we can only say that it is running.
  if (!startedAt) return "Currently executing";
  const startMs =
    startedAt instanceof Date
      ? startedAt.getTime()
      : new Date(startedAt).getTime();
  return `Running for ${formatRelativeDuration(Date.now() - startMs)}`;
}
/**
 * Coarse human-readable duration: "a few seconds", "Nm", "Nh" / "Nh Mm",
 * "Nd Hh". Sub-minute durations are deliberately vague.
 */
export function formatRelativeDuration(ms: number): string {
  const totalSeconds = Math.floor(ms / 1000);
  if (totalSeconds < 60) return "a few seconds";

  const totalMinutes = Math.floor(totalSeconds / 60);
  if (totalMinutes < 60) return `${totalMinutes}m`;

  const totalHours = Math.floor(totalMinutes / 60);
  if (totalHours < 24) {
    const minutesPart = totalMinutes % 60;
    return minutesPart > 0 ? `${totalHours}h ${minutesPart}m` : `${totalHours}h`;
  }

  const dayCount = Math.floor(totalHours / 24);
  return `${dayCount}d ${totalHours % 24}h`;
}

View File

@@ -0,0 +1,213 @@
"use client";
import { useMemo } from "react";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import type {
AgentStatus,
AgentHealth,
AgentStatusInfo,
FleetSummary,
} from "../types";
import {
isActive,
isFailed,
toEndTime,
SEVENTY_TWO_HOURS_MS,
} from "./executionHelpers";
/**
 * Map a status + last-run timestamp to a health bucket: "error" always
 * demands attention; an idle agent whose last run is more than 14 days old
 * counts as stale; everything else is good.
 */
function deriveHealth(
  status: AgentStatus,
  lastRunAt: string | null,
): AgentHealth {
  if (status === "error") return "attention";
  if (status === "idle" && lastRunAt) {
    const msPerDay = 1000 * 60 * 60 * 24;
    const ageDays = (Date.now() - new Date(lastRunAt).getTime()) / msPerDay;
    return ageDays > 14 ? "stale" : "good";
  }
  return "good";
}
/**
 * Derive a single status snapshot for one library agent from its executions.
 *
 * Priority: an active execution wins ("running"); otherwise a failure that
 * ended within the last 72h yields "error"; otherwise the agent's trigger
 * configuration decides between "listening", "scheduled", and "idle".
 */
function computeAgentStatus(
  agent: LibraryAgent,
  agentExecutions: GraphExecutionMeta[],
): AgentStatusInfo {
  const activeExec = agentExecutions.find((e) => isActive(e.status));
  let status: AgentStatus;
  let lastError: string | null = null;
  let lastRunAt: string | null = null;
  const activeExecutionID = activeExec?.id ?? null;
  if (activeExec) {
    status = "running";
  } else {
    const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
    // toEndTime() yields 0 for executions that never ended, which can never
    // exceed the cutoff, so no separate ended_at check is needed.
    const recentFailed = agentExecutions.find(
      (e) => isFailed(e.status) && toEndTime(e) > cutoff,
    );
    if (recentFailed) {
      status = "error";
      // stats fields are loosely typed — only accept actual strings instead
      // of blind `as string` casts, matching the sitrep builder's guards.
      const errText = recentFailed.stats?.error;
      const activityText = recentFailed.stats?.activity_status;
      lastError =
        typeof errText === "string"
          ? errText
          : typeof activityText === "string"
            ? activityText
            : "Execution failed";
    } else if (agent.has_external_trigger) {
      status = "listening";
    } else if (agent.recommended_schedule_cron) {
      status = "scheduled";
    } else {
      status = "idle";
    }
  }
  // Most recently ended execution (of any status) gives lastRunAt.
  const completedExecs = agentExecutions.filter((e) => e.ended_at);
  if (completedExecs.length > 0) {
    const sorted = completedExecs.sort((a, b) => toEndTime(b) - toEndTime(a));
    const endedAt = sorted[0].ended_at;
    lastRunAt =
      endedAt instanceof Date ? endedAt.toISOString() : String(endedAt);
  }
  const totalRuns = agent.execution_count ?? agentExecutions.length;
  return {
    status,
    health: deriveHealth(status, lastRunAt),
    progress: null,
    totalRuns,
    lastRunAt,
    lastError,
    activeExecutionID,
    monthlySpend: 0, // spend tracking not wired up yet
    nextScheduledRun: null,
    triggerType: agent.has_external_trigger ? "webhook" : null,
  };
}
// Map graph_id -> status snapshot for every agent, derived from the user's
// global executions list. Recomputed only when agents or executions change.
export function useAgentStatusMap(
  agents: LibraryAgent[],
): Map<string, AgentStatusInfo> {
  const { data: executions } = useGetV1ListAllExecutions({
    query: { select: okData },
  });
  return useMemo(() => {
    const map = new Map<string, AgentStatusInfo>();
    // Group executions by graph so each agent's status is computed from
    // only its own runs.
    const execsByGraph = new Map<string, GraphExecutionMeta[]>();
    for (const exec of executions ?? []) {
      const list = execsByGraph.get(exec.graph_id);
      if (list) {
        list.push(exec);
      } else {
        execsByGraph.set(exec.graph_id, [exec]);
      }
    }
    // Every agent gets an entry, even with zero executions.
    for (const agent of agents) {
      const agentExecs = execsByGraph.get(agent.graph_id) ?? [];
      map.set(agent.graph_id, computeAgentStatus(agent, agentExecs));
    }
    return map;
  }, [agents, executions]);
}
// Fallback snapshot for agents without a computed status (e.g. before the
// executions query has resolved).
const DEFAULT_STATUS: AgentStatusInfo = {
  status: "idle",
  health: "good",
  progress: null,
  totalRuns: 0,
  lastRunAt: null,
  lastError: null,
  activeExecutionID: null,
  monthlySpend: 0,
  nextScheduledRun: null,
  triggerType: null,
};

/** Look up an agent's status by graph ID, falling back to DEFAULT_STATUS. */
export function getAgentStatus(
  statusMap: Map<string, AgentStatusInfo>,
  graphID: string,
): AgentStatusInfo {
  return statusMap.get(graphID) ?? DEFAULT_STATUS;
}
// Aggregate per-status counts across the whole agent library ("fleet").
// Each agent lands in exactly one of running/error/listening/scheduled/idle;
// "completed" is counted separately and may overlap the other buckets.
export function useFleetSummary(agents: LibraryAgent[]): FleetSummary {
  const { data: executions } = useGetV1ListAllExecutions({
    query: { select: okData },
  });
  return useMemo(() => {
    const counts: FleetSummary = {
      running: 0,
      error: 0,
      completed: 0,
      listening: 0,
      scheduled: 0,
      idle: 0,
      monthlySpend: 0,
    };
    // First pass: classify graphs by their executions.
    const activeGraphIds = new Set<string>();
    const errorGraphIds = new Set<string>();
    const completedGraphIds = new Set<string>();
    if (executions) {
      // Failures/completions only count within the last 72 hours.
      const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
      for (const exec of executions) {
        if (isActive(exec.status)) {
          activeGraphIds.add(exec.graph_id);
        }
        // 0 (never ended) can never exceed the cutoff.
        const endedTs = exec.ended_at
          ? new Date(
              exec.ended_at instanceof Date
                ? exec.ended_at.getTime()
                : exec.ended_at,
            ).getTime()
          : 0;
        if (isFailed(exec.status) && endedTs > cutoff) {
          errorGraphIds.add(exec.graph_id);
        }
        if (
          exec.status === AgentExecutionStatus.COMPLETED &&
          endedTs > cutoff
        ) {
          completedGraphIds.add(exec.graph_id);
        }
      }
    }
    // Second pass: bucket each agent (running > error > trigger config).
    for (const agent of agents) {
      if (activeGraphIds.has(agent.graph_id)) {
        counts.running += 1;
      } else if (errorGraphIds.has(agent.graph_id)) {
        counts.error += 1;
      } else if (agent.has_external_trigger) {
        counts.listening += 1;
      } else if (agent.recommended_schedule_cron) {
        counts.scheduled += 1;
      } else {
        counts.idle += 1;
      }
      if (completedGraphIds.has(agent.graph_id)) {
        counts.completed += 1;
      }
    }
    return counts;
  }, [agents, executions]);
}
export { deriveHealth };

View File

@@ -0,0 +1,116 @@
"use client";
import {
getGetV1ListAllExecutionsQueryKey,
useGetV1ListAllExecutions,
} from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import { useExecutionEvents } from "@/hooks/useExecutionEvents";
import { useQueryClient } from "@tanstack/react-query";
import { useCallback, useMemo } from "react";
import type { FleetSummary } from "../types";
import { isActive, isFailed, SEVENTY_TWO_HOURS_MS } from "./executionHelpers";
/**
 * True when the execution both ended in a failed status and finished
 * within the recency window (72 hours). Missing/null end timestamps are
 * never considered recent.
 */
function isRecentFailure(
  status: string,
  endedAt?: string | Date | null,
): boolean {
  if (!isFailed(status) || !endedAt) {
    return false;
  }
  const endedMs =
    endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime();
  const age = Date.now() - endedMs;
  return age < SEVENTY_TWO_HOURS_MS;
}
/**
 * True when the execution completed successfully and finished within the
 * recency window (72 hours). Missing/null end timestamps are never
 * considered recent.
 */
function isRecentCompletion(
  status: string,
  endedAt?: string | Date | null,
): boolean {
  if (status !== AgentExecutionStatus.COMPLETED || !endedAt) {
    return false;
  }
  const endedMs =
    endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime();
  const age = Date.now() - endedMs;
  return age < SEVENTY_TWO_HOURS_MS;
}
/**
 * Fleet-level status roll-up for the library page.
 *
 * Buckets every agent into exactly one of running / error / listening /
 * scheduled / idle (priority in that order), driven by the executions
 * list. The `completed` counter is kept mutually exclusive with running
 * and error (matching the sitrep priority order used by the "Recently
 * completed" tab) but is independent of listening/scheduled/idle.
 *
 * Subscribes to execution events for the agents' graphs and invalidates
 * the executions query on every update so the summary stays live.
 * Returns undefined until the executions query has succeeded.
 */
export function useLibraryFleetSummary(
  agents: LibraryAgent[],
): FleetSummary | undefined {
  const queryClient = useQueryClient();
  const { data: executions, isSuccess } = useGetV1ListAllExecutions({
    query: { select: okData },
  });
  const graphIDs = useMemo(() => agents.map((a) => a.graph_id), [agents]);
  const refreshExecutions = useCallback(() => {
    queryClient.invalidateQueries({
      queryKey: getGetV1ListAllExecutionsQueryKey(),
    });
  }, [queryClient]);
  useExecutionEvents({
    graphIds: graphIDs.length > 0 ? graphIDs : undefined,
    enabled: graphIDs.length > 0,
    onExecutionUpdate: refreshExecutions,
  });
  return useMemo(() => {
    if (!isSuccess || !executions) return undefined;
    // Classify executions once into per-graph sets.
    const activeGraphs = new Set<string>();
    const recentlyFailedGraphs = new Set<string>();
    const recentlyCompletedGraphs = new Set<string>();
    for (const run of executions) {
      if (isActive(run.status)) {
        activeGraphs.add(run.graph_id);
      }
      if (isRecentFailure(run.status, run.ended_at)) {
        recentlyFailedGraphs.add(run.graph_id);
      }
      if (isRecentCompletion(run.status, run.ended_at)) {
        recentlyCompletedGraphs.add(run.graph_id);
      }
    }
    const tally: FleetSummary = {
      running: 0,
      error: 0,
      completed: 0,
      listening: 0,
      scheduled: 0,
      idle: 0,
      monthlySpend: 0,
    };
    for (const agent of agents) {
      const id = agent.graph_id;
      if (activeGraphs.has(id)) {
        tally.running += 1;
      } else if (recentlyFailedGraphs.has(id)) {
        tally.error += 1;
      } else if (agent.has_external_trigger) {
        tally.listening += 1;
      } else if (agent.recommended_schedule_cron) {
        tally.scheduled += 1;
      } else {
        tally.idle += 1;
      }
      // Only count a completion for agents not already claimed by the
      // running or error buckets; listening/scheduled/idle agents may
      // still count here.
      const claimedElsewhere =
        activeGraphs.has(id) || recentlyFailedGraphs.has(id);
      if (!claimedElsewhere && recentlyCompletedGraphs.has(id)) {
        tally.completed += 1;
      }
    }
    return tally;
  }, [agents, executions, isSuccess]);
}

View File

@@ -2,12 +2,14 @@
import { useEffect, useState, useCallback } from "react";
import { HeartIcon, ListIcon } from "@phosphor-icons/react";
import { JumpBackIn } from "./components/JumpBackIn/JumpBackIn";
import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryActionHeader";
import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList";
import { useLibraryListPage } from "./components/useLibraryListPage";
import { FavoriteAnimationProvider } from "./context/FavoriteAnimationContext";
import { LibraryTab } from "./types";
import type { LibraryTab, AgentStatusFilter } from "./types";
import { useLibraryFleetSummary } from "./hooks/useLibraryFleetSummary";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
const LIBRARY_TABS: LibraryTab[] = [
{ id: "all", title: "All", icon: ListIcon },
@@ -19,6 +21,10 @@ export default function LibraryPage() {
useLibraryListPage();
const [selectedFolderId, setSelectedFolderId] = useState<string | null>(null);
const [activeTab, setActiveTab] = useState(LIBRARY_TABS[0].id);
const [statusFilter, setStatusFilter] = useState<AgentStatusFilter>("all");
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
const { agents } = useLibraryAgents();
const fleetSummary = useLibraryFleetSummary(agents);
useEffect(() => {
document.title = "Library AutoGPT Platform";
@@ -40,7 +46,6 @@ export default function LibraryPage() {
>
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
<LibraryActionHeader setSearchTerm={setSearchTerm} />
<JumpBackIn />
<LibraryAgentList
searchTerm={searchTerm}
librarySort={librarySort}
@@ -50,6 +55,10 @@ export default function LibraryPage() {
tabs={LIBRARY_TABS}
activeTab={activeTab}
onTabChange={handleTabChange}
statusFilter={statusFilter}
onStatusFilterChange={setStatusFilter}
fleetSummary={isAgentBriefingEnabled ? fleetSummary : undefined}
briefingAgents={isAgentBriefingEnabled ? agents : undefined}
/>
</main>
</FavoriteAnimationProvider>

Some files were not shown because too many files have changed in this diff Show More