Compare commits


31 Commits

Author SHA1 Message Date
Zamil Majdy
ce0bad96a3 test: add E2E screenshots for PR #12796 2026-04-16 18:04:29 +07:00
Zamil Majdy
51286cc0a9 fix(backend/copilot): address review minor points
- db.py: add comment clarifying from_start + after_sequence precedence
  when both are set (only reachable via internal API, route enforces
  mutual exclusion)
- useLoadMoreMessages.ts: fix hasMore logic after forward truncation —
  forward mode should keep using the server's has_more_messages so the
  sentinel continues fetching (truncation sheds display items but cursor
  advances; backward mode still caps at MAX_OLDER_MESSAGES)
2026-04-16 17:57:10 +07:00
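The hasMore rule this commit describes can be sketched as follows — a minimal Python analogue, with names (MAX_OLDER_MESSAGES, has_more_messages, the forward/backward mode flag) modeled on useLoadMoreMessages.ts but otherwise hypothetical:

```python
# Hypothetical sketch of the load-more decision described above.
MAX_OLDER_MESSAGES = 100  # assumed cap; the real constant lives in the frontend

def next_has_more(direction: str, server_has_more: bool, display_count: int) -> bool:
    """Decide whether the load-more sentinel should keep fetching."""
    if direction == "forward":
        # Forward mode: truncation only sheds display items — the cursor keeps
        # advancing, so keep trusting the server's has_more_messages flag.
        return server_has_more
    # Backward mode: stop once the local cap of older messages is reached.
    return server_has_more and display_count < MAX_OLDER_MESSAGES
```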
Zamil Majdy
b7d5a59f9d fix(backend/copilot): fix 3 review bugs in forward pagination
- db.py: trim forward page at tail if it ends on tool messages, so the
  next after_sequence page doesn't start mid-tool-group (forward analogue
  of the existing backward boundary scan)
- db.py: return newest_sequence=None in backward mode — it is not a valid
  forward cursor (was returning messages[-1].sequence, i.e. the oldest-of-
  page-in-DESC, not the session's newest)
- useLoadMoreMessages.ts: fix MAX_OLDER_MESSAGES truncation direction in
  forward mode — slice(0, MAX) keeps the beginning of the conversation;
  old code used slice(-MAX) which discarded the initial prompt
- Add db_test.py tests for forward tail boundary and None newest_sequence
- Add useLoadMoreMessages.test.ts test for forward truncation direction
2026-04-16 17:54:41 +07:00
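The truncation-direction fix is easiest to see with slices. Python slicing behaves the same way as the JavaScript slice calls named above (MAX here is an arbitrary illustrative cap):

```python
# Illustrative analogue of the truncation fix: keep the beginning of a
# forward-paginated conversation, not the tail.
MAX = 3
messages = ["initial prompt", "m1", "m2", "m3", "m4"]

# Forward mode (fixed): slice(0, MAX) keeps the start — the initial prompt survives.
forward_kept = messages[:MAX]   # ['initial prompt', 'm1', 'm2']

# Old buggy behavior: slice(-MAX) keeps the end and discards the initial prompt.
buggy_kept = messages[-MAX:]    # ['m2', 'm3', 'm4']
```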
Zamil Majdy
8c3bdb0315 fix(backend/copilot): update transcript_test to use strip_for_upload after upload_cli_session removal 2026-04-16 16:10:01 +07:00
Zamil Majdy
8524091a5f fix(backend/copilot): replace dict[str, Any] Prisma args with proper typed inputs
Replace msg_include and boundary_where with FindManyChatMessageArgsFromChatSession
and ChatMessageWhereInput respectively for proper Prisma type safety.
2026-04-16 15:47:51 +07:00
Zamil Majdy
22c5d6f86c Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into fix/copilot-pagination-initial-load-missing-messages 2026-04-16 15:42:52 +07:00
Zamil Majdy
2f6a02a7fa fix(frontend/copilot): skip extraToolOutputs extraction for forward pagination 2026-04-16 00:11:26 +07:00
Zamil Majdy
ddf8bb7d8b test(frontend/copilot): add coverage for null cursor guard and initialPageRawMessages path 2026-04-15 23:57:27 +07:00
Zamil Majdy
d635844412 fix(frontend/copilot): fix lint — remove unused rerender, format test file 2026-04-15 23:43:14 +07:00
Zamil Majdy
6f46bce634 test(frontend/copilot): add coverage for forward pagination cursor advancement and message truncation 2026-04-15 23:39:36 +07:00
Zamil Majdy
3056be165f fix(frontend/copilot): sync openapi.json with backend schema (query param descriptions) 2026-04-15 23:21:16 +07:00
Zamil Majdy
7e8a68a5c0 fix(frontend/copilot): update test assertion for forward-paginated load-more label 2026-04-15 23:10:09 +07:00
Zamil Majdy
54f585b8c4 fix(frontend/copilot): dynamic load-more label for forward vs backward pagination 2026-04-15 23:04:33 +07:00
Zamil Majdy
9ea7e61652 fix(backend/copilot): move Query param docs to description fields for OpenAPI schema 2026-04-15 21:40:25 +07:00
Zamil Majdy
6d1688b0f0 test(frontend): add useChatSession + ChatMessagesContainer coverage tests
- useChatSession.test.ts: 5 tests for newestSequence and forwardPaginated memos
- ChatMessagesContainer.test.tsx: 5 tests verifying top/bottom sentinel placement
  based on forwardPaginated prop
2026-04-15 21:36:44 +07:00
Zamil Majdy
ce22b21824 fix(backend/copilot): remove noisy GET_SESSION debug log 2026-04-15 21:27:34 +07:00
Zamil Majdy
c73c5b380c fix(frontend): fix stuck isLoadingMore on resetPaged + add loadMore coverage tests
- resetPaged() now calls setIsLoadingMore(false) so the spinner clears
  immediately if a loadMore flight is in progress when the session resets
- Extend useLoadMoreMessages.test.ts with 8 additional tests covering:
  forward/backward loadMore paths, error handling (1 error vs 3 errors),
  non-200 responses, and epoch/stale-response discard
2026-04-15 21:24:03 +07:00
Zamil Majdy
8f93942ee5 fix(frontend): prevent loadMore race in resetPaged + add coverage tests
Set hasMore=false in resetPaged() so the forward-pagination sentinel cannot
trigger a loadMore() with after_sequence on an active session during the
transition window before forwardPaginated updates.

Add tests:
- useLoadMoreMessages: resetPaged sets hasMore=false, session-change reset
- LoadMoreSentinel: adjustScroll=false leaves scrollTop unchanged (forward)
2026-04-15 21:07:57 +07:00
Zamil Majdy
60b1aba221 test(frontend): fix test title and add empty string case for non-user drop
Rename 'still drops non-user messages with empty content' to 'null content'
to match the fixture. Add a separate test for content: "" on non-user roles
to cover both null and empty-string branches.
2026-04-15 20:58:50 +07:00
Zamil Majdy
f9a33f2aa6 refactor(backend): move _BOUNDARY_SCAN_LIMIT to module level, logger.info -> debug
Move _BOUNDARY_SCAN_LIMIT from function body to module level to avoid
redefining it on every backward-pagination call. Downgrade GET_SESSION
per-request log from INFO to DEBUG to reduce noise at scale.
2026-04-15 20:57:46 +07:00
Zamil Majdy
df3c4b381c fix(frontend): pass extraToolOutputs for both forward and backward pagination
Forward pagination (completed sessions) can also have tool call/output pairs
that span the initial-page/paged-page boundary. Remove the !forwardPaginated
guard so initial page tool outputs are always available for cross-page matching.
2026-04-15 20:55:49 +07:00
Zamil Majdy
7bf5a8c226 fix(frontend): skip scroll adjustment for forward pagination sentinel
The useLayoutEffect scroll-position restore in LoadMoreSentinel is designed
for backward pagination (content prepended above current position). When
appending below for forward pagination, applying prevTop + delta pushes
the viewport past newly loaded content causing a jarring jump.

Add adjustScroll prop (default true for backward, false for forward) and
skip the scrollTop mutation when appending below.
2026-04-15 20:53:46 +07:00
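The scroll-restore arithmetic behind the adjustScroll prop can be sketched like this (a simplified model; function and parameter names are hypothetical, not the component's actual API):

```python
# Sketch of the scroll-position restore described above.
def restore_scroll(prev_top: float, prev_height: float, new_height: float,
                   adjust_scroll: bool) -> float:
    """Return the scrollTop to apply after a page of messages loads."""
    if not adjust_scroll:
        # Forward pagination appends below the viewport; leaving scrollTop
        # alone keeps the user anchored on what they were reading.
        return prev_top
    # Backward pagination prepends above: shift down by the added height so
    # the same content stays in place under the viewport.
    return prev_top + (new_height - prev_height)
```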
Zamil Majdy
a0f149fcb2 fix(frontend): reset paged state on send to prevent misordering during streaming
When a completed session (forwardPaginated=true) has forward-loaded history
pages and the user sends a new message, pagedMessages would appear after the
new streaming turn instead of before it. Reset pagedRawMessages before calling
sendMessage so ordering is always correct during the active stream.

Expose resetPaged() from useLoadMoreMessages for this purpose.
2026-04-15 20:50:13 +07:00
Zamil Majdy
f35791170a fix(frontend): resolve merge conflict with dev in useLoadMoreMessages
Adopt dev's cleaner useEffect structure (early return after session-changed
reset, preserve paged state across refetches) while keeping forward pagination
additions (newestSequence guard, forwardPaginated support) from this PR.
2026-04-15 20:45:27 +07:00
Zamil Majdy
3771bfad9c fix(backend): close TOCTOU race in get_session pagination direction check
Re-verify active session after DB fetch so that if a session completes
between the pre-check and the query, the route re-fetches from seq 0
instead of returning messages in newest-first order.

Add test: test_get_session_toctou_refetch_when_session_completes_mid_request
2026-04-15 20:39:17 +07:00
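The TOCTOU re-verification pattern this commit adds can be sketched as below, with stand-in async callables (get_active, fetch_page) in place of the real stream_registry and db calls:

```python
# Minimal sketch of the check / re-check pattern described above.
import asyncio

async def get_session_messages(session_id, get_active, fetch_page):
    active = await get_active(session_id)
    from_start = active is None          # completed sessions read from seq 0
    page = await fetch_page(session_id, from_start=from_start)
    # Re-check: the session may have completed between the two awaits, in
    # which case the newest-first page is wrong — re-fetch from the start.
    if active is not None and await get_active(session_id) is None:
        page = await fetch_page(session_id, from_start=True)
    return page
```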
Zamil Majdy
2e2f518c58 fix(frontend): regenerate openapi.json from updated backend docstring 2026-04-15 20:32:34 +07:00
Zamil Majdy
89f2dcc338 test(frontend/copilot): add unit tests for convertChatSessionToUiMessages null-content fix
Covers the user-message-with-null/empty-content fix (lines 258-260) which is the
primary user-visible fix in this PR. Tests verify that user messages with null or
empty content are preserved (not silently dropped), while non-user messages with
empty content continue to be filtered out.
2026-04-15 20:22:56 +07:00
Zamil Majdy
3d3aef58ac fix(frontend/copilot): add rootMargin prop to LoadMoreSentinel for bottom pagination
The bottom sentinel (forward pagination) needs rootMargin "0px 0px 200px 0px"
to pre-trigger loading 200px before the user hits the absolute bottom.
The top sentinel (backward pagination) keeps its "200px 0px 0px 0px" default.
2026-04-15 20:19:41 +07:00
Zamil Majdy
e85c042eb6 fix(backend/copilot): add route tests, mutual-exclusive cursor validation, newestSequence guard
- routes_test.py: 4 new tests for get_session — forward_paginated=True for
  completed sessions, False for active, after_sequence wiring, 400 for
  conflicting cursors
- routes.py: reject 400 when both before_sequence and after_sequence are sent
- useLoadMoreMessages.ts: guard newestSequence in else-branch so a parent
  refetch never reverts a cursor already advanced by paging
2026-04-15 20:13:39 +07:00
Zamil Majdy
e7b621f0b0 fix(backend/copilot): use %s logging, consistent forward_paginated, compact openapi.json
- routes.py: replace f-string log with %s lazy formatting
- routes.py: compute forward_paginated once and reuse across both return paths
- openapi.json: revert reformatting noise, keep only the 3 new field additions
  (newest_sequence, forward_paginated on SessionDetailResponse;
   after_sequence query param on GET /sessions/{id})
2026-04-15 19:55:48 +07:00
Zamil Majdy
e8c356a728 fix(backend/copilot): fix initial load missing messages + forward pagination for completed sessions
Completed copilot sessions with many messages were showing an empty view
because the backend returned only the newest 50 (all tool calls, no user
messages) and the frontend silently dropped messages with empty content.

Backend changes:
- get_chat_messages_paginated: add from_start (ASC) and after_sequence
  (forward cursor) modes alongside the existing before_sequence (DESC)
  backward mode
- PaginatedMessages: expose newest_sequence for forward-pagination cursors
- routes.py: detect completed sessions on initial load (no active stream)
  and use from_start=True; expose newest_sequence + forward_paginated in
  SessionDetailResponse; accept after_sequence query param
- openapi.json: add after_sequence param + newest_sequence / forward_paginated
  fields to SessionDetailResponse schema

Frontend changes:
- convertChatSessionToUiMessages: never drop user messages with empty content
- useLoadMoreMessages: support forward pagination via after_sequence cursor;
  append pages to end rather than prepending for completed sessions
- ChatMessagesContainer: move LoadMoreSentinel to bottom for forward pagination
- useChatSession / useCopilotPage / ChatContainer: wire up newestSequence and
  forwardPaginated props end-to-end

Tests: add 9 new unit tests for from_start and after_sequence pagination modes
2026-04-15 19:26:19 +07:00
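The three pagination modes this commit introduces can be sketched over an in-memory list standing in for get_chat_messages_paginated. This is a hedged model, not the real implementation: it returns pages in ascending order throughout for simplicity, whereas the actual backward mode queries DESC.

```python
# Sketch of from_start / after_sequence / before_sequence cursor modes.
def paginate(messages, limit, before_sequence=None, after_sequence=None,
             from_start=False):
    """messages: list of (sequence, content) tuples sorted ascending."""
    if before_sequence is not None and after_sequence is not None:
        raise ValueError("before_sequence and after_sequence are mutually exclusive")
    if from_start:                        # initial load of a completed session
        return messages[:limit]
    if after_sequence is not None:        # forward cursor: strictly greater
        return [m for m in messages if m[0] > after_sequence][:limit]
    if before_sequence is not None:       # backward cursor: strictly less
        return [m for m in messages if m[0] < before_sequence][-limit:]
    return messages[-limit:]              # legacy default: newest messages
```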
141 changed files with 2567 additions and 8459 deletions


@@ -18,6 +18,7 @@ from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.message_dedup import acquire_dedup_lock
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -190,6 +191,8 @@ class SessionDetailResponse(BaseModel):
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
has_more_messages: bool = False
oldest_sequence: int | None = None
newest_sequence: int | None = None
forward_paginated: bool = False
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
@@ -454,39 +457,113 @@ async def update_session_title_route(
async def get_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=200),
before_sequence: int | None = Query(default=None, ge=0),
limit: int = Query(
default=50,
ge=1,
le=200,
description="Maximum number of messages to return.",
),
before_sequence: int | None = Query(
default=None,
ge=0,
description=(
"Backward pagination cursor. Return messages with sequence number "
"strictly less than this value. Used by active-session load-more. "
"Mutually exclusive with after_sequence."
),
),
after_sequence: int | None = Query(
default=None,
ge=0,
description=(
"Forward pagination cursor. Return messages with sequence number "
"strictly greater than this value. Used by completed-session load-more. "
"Mutually exclusive with before_sequence."
),
),
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
When no pagination params are provided, returns the most recent messages.
Supports cursor-based pagination via ``limit``, ``before_sequence``, and
``after_sequence``. The two cursor parameters are mutually exclusive.
On the initial load (no cursor provided) of a completed session, messages
are returned in forward order starting from sequence 0 so the user always
sees their initial prompt. Active sessions use the legacy newest-first
order so streaming context is preserved.
"""
if before_sequence is not None and after_sequence is not None:
raise HTTPException(
status_code=400,
detail="before_sequence and after_sequence are mutually exclusive",
)
is_initial_load = before_sequence is None and after_sequence is None
# Check active stream before the DB query on initial loads so we can
# choose the correct pagination direction (forward for completed sessions,
# newest-first for active ones).
active_session = None
last_message_id = None
if is_initial_load:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
# Completed sessions on initial load start from sequence 0 so the user's
# initial prompt is always visible. Active sessions keep the legacy
# newest-first behavior to preserve streaming context.
from_start = is_initial_load and active_session is None
forward_paginated = from_start or after_sequence is not None
page = await get_chat_messages_paginated(
session_id, limit, before_sequence, user_id=user_id
session_id,
limit,
before_sequence=before_sequence,
after_sequence=after_sequence,
from_start=from_start,
user_id=user_id,
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
# Close the TOCTOU window: if the session was active at pre-check, re-verify
# after the DB fetch. The session may have completed between the two awaits,
# which would have caused messages to be fetched newest-first even though the
# session is now complete. Re-fetch from seq 0 so the initial prompt is
# always visible.
if is_initial_load and active_session is not None:
post_active, _ = await stream_registry.get_active_session(session_id, user_id)
if post_active is None:
active_session = None
last_message_id = None
from_start = True
forward_paginated = True
page = await get_chat_messages_paginated(
session_id,
limit,
before_sequence=None,
after_sequence=None,
from_start=True,
user_id=user_id,
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
messages = [
_strip_injected_context(message.model_dump()) for message in page.messages
]
# Only check active stream on initial load (not on "load more" requests)
active_stream_info = None
if before_sequence is None:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
if active_session and last_message_id is not None:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
if active_session:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
# Skip session metadata on "load more" — frontend only needs messages
if before_sequence is not None:
if not is_initial_load:
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
@@ -496,6 +573,8 @@ async def get_session(
active_stream=None,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
newest_sequence=page.newest_sequence,
forward_paginated=forward_paginated,
total_prompt_tokens=0,
total_completion_tokens=0,
)
@@ -512,6 +591,8 @@ async def get_session(
active_stream=active_stream_info,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
newest_sequence=page.newest_sequence,
forward_paginated=forward_paginated,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=page.session.metadata,
@@ -832,6 +913,9 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
# Capture the original message text BEFORE any mutation (attachment enrichment)
# so the idempotency hash is stable across retries.
original_message = request.message
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
@@ -860,36 +944,58 @@ async def stream_chat_post(
)
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message returns None when a duplicate is
# detected — in that case skip enqueue to avoid processing the message twice.
is_duplicate_message = False
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
# ── Idempotency guard ────────────────────────────────────────────────────
# Blocks duplicate executor tasks from concurrent/retried POSTs.
# See backend/copilot/message_dedup.py for the full lifecycle description.
dedup_lock = None
if request.is_user_message:
dedup_lock = await acquire_dedup_lock(
session_id, original_message, sanitized_file_ids
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
is_duplicate_message = (
await append_and_save_message(session_id, message)
) is None
logger.info(f"[STREAM] User message saved for session {session_id}")
if not is_duplicate_message and request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
if dedup_lock is None and (original_message or sanitized_file_ids):
async def _empty_sse() -> AsyncGenerator[str, None]:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return StreamingResponse(
_empty_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
"x-vercel-ai-ui-message-stream": "v1",
},
)
# Create a task in the stream registry for reconnection support.
# For duplicate messages, skip create_session entirely so the infra-retry
# client subscribes to the *existing* turn's Redis stream and receives the
# in-progress executor output rather than an empty stream.
turn_id = ""
if not is_duplicate_message:
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
#
# If any of these operations raises, release the dedup lock before propagating
# so subsequent retries are not blocked for 30 s.
try:
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
@@ -907,6 +1013,7 @@ async def stream_chat_post(
}
},
)
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
@@ -918,10 +1025,10 @@ async def stream_chat_post(
mode=request.mode,
model=request.model,
)
else:
logger.info(
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
)
except Exception:
if dedup_lock:
await dedup_lock.release()
raise
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
@@ -945,6 +1052,12 @@ async def stream_chat_post(
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
# True for every exit path except GeneratorExit (client disconnect).
# On disconnect the backend turn is still running — releasing the lock
# there would reopen the infra-retry duplicate window. The 30 s TTL
# is the fallback. All other exits (normal finish, early return, error)
# should release so the user can re-send the same message.
release_dedup_lock_on_exit = True
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
@@ -956,7 +1069,7 @@ async def stream_chat_post(
if subscriber_queue is None:
yield StreamFinish().to_sse()
return
return # finally releases dedup_lock
# Read from the subscriber queue and yield to SSE
logger.info(
@@ -998,7 +1111,7 @@ async def stream_chat_post(
}
},
)
break
break # finally releases dedup_lock
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
@@ -1014,6 +1127,7 @@ async def stream_chat_post(
}
},
)
release_dedup_lock_on_exit = False
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
@@ -1028,7 +1142,10 @@ async def stream_chat_post(
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
# finally releases dedup_lock
finally:
if dedup_lock and release_dedup_lock_on_exit:
await dedup_lock.release()
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:


@@ -133,12 +133,21 @@ def test_stream_chat_rejects_too_many_file_ids():
assert response.status_code == 422
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
def _mock_stream_internals(
mocker: pytest_mock.MockerFixture,
*,
redis_set_returns: object = True,
):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing RabbitMQ.
validation and enrichment logic without needing Redis/RabbitMQ.
Args:
redis_set_returns: Value returned by the mocked Redis ``set`` call.
``True`` (default) simulates a fresh key (new message);
``None`` simulates a collision (duplicate blocked).
Returns:
A namespace with ``save`` and ``enqueue`` mock objects so
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
callers can make additional assertions about side-effects.
"""
import types
@@ -149,7 +158,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
)
mock_save = mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = mocker.AsyncMock(return_value=None)
@@ -165,9 +174,15 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
return types.SimpleNamespace(
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=redis_set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
return ns
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
@@ -196,29 +211,6 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
assert response.status_code == 200
# ─── Duplicate message dedup ──────────────────────────────────────────
def test_stream_chat_skips_enqueue_for_duplicate_message(
mocker: pytest_mock.MockerFixture,
):
"""When append_and_save_message returns None (duplicate detected),
enqueue_copilot_turn and stream_registry.create_session must NOT be called
to avoid double-processing and to prevent overwriting the active stream's
turn_id in Redis (which would cause reconnecting clients to miss the response)."""
mocks = _mock_stream_internals(mocker)
# Override save to return None — signalling a duplicate
mocks.save.return_value = None
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 200
mocks.enqueue.assert_not_called()
mocks.registry.create_session.assert_not_called()
# ─── UUID format filtering ─────────────────────────────────────────────
@@ -714,6 +706,237 @@ class TestStripInjectedContext:
assert result["content"] == "hello"
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
mocker: pytest_mock.MockerFixture,
) -> None:
"""A second POST with the same message within the 30-s window must return
an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
turn complete without creating a ghost response."""
# redis_set_returns=None simulates a collision: the NX key already exists.
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-dup/stream",
json={"message": "duplicate message", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
# The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
assert '"finish"' in body
assert "[DONE]" in body
# The empty SSE response must include the AI SDK protocol header so the
# frontend treats it as a valid stream and marks the turn complete.
assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
# The duplicate guard must prevent save/enqueue side effects.
ns.save.assert_not_called()
ns.enqueue.assert_not_called()
def test_stream_chat_first_post_proceeds_normally(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The first POST (Redis NX key set successfully) must proceed through the
normal streaming path — no early return."""
ns = _mock_stream_internals(mocker, redis_set_returns=True)
response = client.post(
"/sessions/sess-new/stream",
json={"message": "first message", "is_user_message": True},
)
assert response.status_code == 200
# Redis set must have been called once with the NX flag.
ns.redis.set.assert_called_once()
call_kwargs = ns.redis.set.call_args
assert call_kwargs.kwargs.get("nx") is True
def test_stream_chat_dedup_skipped_for_non_user_messages(
mocker: pytest_mock.MockerFixture,
) -> None:
"""System/assistant messages (is_user_message=False) bypass the dedup
guard — they are injected programmatically and must always be processed."""
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-sys/stream",
json={"message": "system context", "is_user_message": False},
)
# Even though redis_set_returns=None (would block a user message),
# the endpoint must proceed because is_user_message=False.
assert response.status_code == 200
ns.redis.set.assert_not_called()
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup hash must be computed from the original request message,
not the mutated version that has the [Attached files] block appended.
A file_id is sent so the route actually appends the [Attached files] block,
exercising the mutation path — the hash must still match the original text."""
import hashlib
ns = _mock_stream_internals(mocker, redis_set_returns=True)
file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
# Mock workspace + prisma so the attachment block is actually appended.
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
fake_file = type(
"F",
(),
{
"id": file_id,
"name": "doc.pdf",
"mimeType": "application/pdf",
"sizeBytes": 1024,
},
)()
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-hash/stream",
json={
"message": "plain message",
"is_user_message": True,
"file_ids": [file_id],
},
)
assert response.status_code == 200
ns.redis.set.assert_called_once()
call_args = ns.redis.set.call_args
dedup_key = call_args.args[0]
# Hash must use the original message + sorted file IDs, not the mutated text.
expected_hash = hashlib.sha256(
f"sess-hash:plain message:{file_id}".encode()
).hexdigest()[:16]
expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
assert dedup_key == expected_key, (
f"Dedup key {dedup_key!r} does not match expected {expected_key!r}; "
"hash may be using mutated message or wrong inputs"
)
def test_stream_chat_dedup_key_released_after_stream_finish(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup Redis key must be deleted after the turn completes (when
subscriber_queue is None the route yields StreamFinish immediately and
should release the key so the user can re-send the same message)."""
from unittest.mock import AsyncMock as _AsyncMock
# Set up all internals manually so we can control subscribe_to_session.
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
# None → early-finish path: StreamFinish yielded immediately, dedup key released.
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
response = client.post(
"/sessions/sess-finish/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
assert '"finish"' in body
# The dedup key must be released so intentional re-sends are allowed.
mock_redis.delete.assert_called_once()
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The route must not crash when the dedup Redis delete fails on the
subscriber_queue-is-None early-finish path (except Exception: pass)."""
from unittest.mock import AsyncMock as _AsyncMock
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
# Make the delete raise so the except-pass branch is exercised.
mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
# Should not raise even though delete fails.
response = client.post(
"/sessions/sess-finish-err/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
assert '"finish"' in response.text
# delete must have been attempted — the except-pass branch silenced the error.
mock_redis.delete.assert_called_once()
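The two tests above pin down the same guarantee: the dedup key is always released, and a failed Redis delete is swallowed rather than crashing the stream route. A minimal standalone sketch of that release pattern (the helper name, signature, and fake Redis classes here are illustrative, not the route's actual code):

```python
import asyncio


async def release_dedup_key(redis, key: str) -> bool:
    """Best-effort release of a dedup key. A failing delete is swallowed
    (mirroring the route's except-pass branch) and reported as False —
    losing the key early only permits a re-send, the safe direction."""
    try:
        await redis.delete(key)
        return True
    except Exception:
        return False


class _GoodRedis:
    async def delete(self, key: str) -> int:
        return 1


class _FlakyRedis:
    async def delete(self, key: str) -> int:
        raise RuntimeError("redis gone")


released_ok = asyncio.run(release_dedup_key(_GoodRedis(), "dedup:sess-1"))
released_flaky = asyncio.run(release_dedup_key(_FlakyRedis(), "dedup:sess-1"))
```

Either way the caller proceeds; only the return value records whether the delete actually landed.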
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
@@ -759,7 +982,7 @@ def test_disconnect_stream_returns_404_when_session_missing(
mock_disconnect.assert_not_awaited()
# ─── GET /sessions/{session_id} — backward pagination ─────────────────────────
# ─── GET /sessions/{session_id} — forward/backward pagination ──────────────────
def _make_paginated_messages(
@@ -784,6 +1007,7 @@ def _make_paginated_messages(
messages=[ChatMessage(role="user", content="hello", sequence=0)],
has_more=has_more,
oldest_sequence=0,
newest_sequence=0,
session=session_info,
)
mock_paginate = mocker.patch(
@@ -794,11 +1018,11 @@ def _make_paginated_messages(
return page, mock_paginate
def test_get_session_returns_backward_paginated(
def test_get_session_completed_returns_forward_paginated(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""All sessions use backward (newest-first) pagination."""
"""Completed sessions (no active stream) return forward_paginated=True."""
_make_paginated_messages(mocker)
mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
@@ -810,6 +1034,92 @@ def test_get_session_returns_backward_paginated(
assert response.status_code == 200
data = response.json()
assert data["oldest_sequence"] == 0
assert "forward_paginated" not in data
assert "newest_sequence" not in data
assert data["forward_paginated"] is True
assert data["newest_sequence"] == 0
def test_get_session_active_returns_backward_paginated(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Active sessions (with running stream) return forward_paginated=False."""
from backend.copilot.stream_registry import ActiveSession
_make_paginated_messages(mocker)
active = MagicMock(spec=ActiveSession)
active.turn_id = "turn-1"
mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
new_callable=AsyncMock,
return_value=(active, "msg-1"),
)
response = client.get("/sessions/sess-1")
assert response.status_code == 200
data = response.json()
assert data["forward_paginated"] is False
assert data["active_stream"] is not None
assert data["active_stream"]["turn_id"] == "turn-1"
def test_get_session_after_sequence_returns_forward_paginated(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""after_sequence param returns forward_paginated=True; no stream check needed."""
_, mock_paginate = _make_paginated_messages(mocker)
response = client.get("/sessions/sess-1?after_sequence=10")
assert response.status_code == 200
data = response.json()
assert data["forward_paginated"] is True
call_kwargs = mock_paginate.call_args
assert call_kwargs.kwargs.get("after_sequence") == 10
assert call_kwargs.kwargs.get("before_sequence") is None
def test_get_session_both_cursors_returns_400(
test_user_id: str,
) -> None:
"""Sending both before_sequence and after_sequence returns 400."""
response = client.get("/sessions/sess-1?before_sequence=5&after_sequence=10")
assert response.status_code == 400
def test_get_session_toctou_refetch_when_session_completes_mid_request(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Race condition: session was active at pre-check but completes before DB fetch.
The route should detect the race via a post-fetch re-check, then re-fetch
from seq 0 so the initial prompt is always visible.
"""
from backend.copilot.stream_registry import ActiveSession
page, mock_paginate = _make_paginated_messages(mocker)
active = MagicMock(spec=ActiveSession)
active.turn_id = "turn-1"
# First call: session appears active. Second call: session has completed.
mock_get_active = mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
new_callable=AsyncMock,
side_effect=[(active, "msg-1"), (None, None)],
)
response = client.get("/sessions/sess-1")
assert response.status_code == 200
data = response.json()
# Post-race: session is now completed → forward_paginated=True, no stream
assert data["forward_paginated"] is True
assert data["active_stream"] is None
# The DB was queried twice: once newest-first, once from_start=True
assert mock_paginate.call_count == 2
assert mock_get_active.call_count == 2
second_call = mock_paginate.call_args_list[1]
assert second_call.kwargs.get("from_start") is True
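The TOCTOU pattern this test verifies can be sketched independently of the route (function names and the `from_start` flag wiring here are illustrative): check activity, fetch, re-check, and re-fetch from the start if the session completed in between.

```python
import asyncio

calls = []


async def get_active_flaky():
    # Active on the first check, completed by the second — the race window.
    calls.append("check")
    return len(calls) == 1


async def paginate(from_start):
    calls.append(("paginate", from_start))
    return {"from_start": from_start}


async def fetch_with_recheck(get_active, paginate):
    """Pre-check activity, fetch, then re-check: if the session completed
    mid-request, re-fetch from seq 0 so the initial prompt stays visible."""
    active_before = await get_active()
    page = await paginate(from_start=not active_before)
    active_after = await get_active()
    if active_before and not active_after:
        # Completed between the pre-check and the DB fetch — re-fetch.
        page = await paginate(from_start=True)
    return page, active_after


page, still_active = asyncio.run(fetch_with_recheck(get_active_flaky, paginate))
```

The second paginate call is the re-fetch the test asserts on (`mock_paginate.call_count == 2` with `from_start=True`).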

View File

@@ -12,7 +12,6 @@ import prisma.models
import backend.api.features.library.model as library_model
import backend.data.graph as graph_db
from backend.api.features.library.db import _fetch_schedule_info
from backend.data.graph import GraphModel, GraphSettings
from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
@@ -118,5 +117,4 @@ async def add_graph_to_library(
f"for store listing version #{store_listing_version_id} "
f"to library for user #{user_id}"
)
schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id)
return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info)
return library_model.LibraryAgent.from_db(added_agent)

View File

@@ -21,17 +21,13 @@ async def test_add_graph_to_library_create_new_agent() -> None:
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
patch(
"backend.api.features.library._add_to_library._fetch_schedule_info",
new=AsyncMock(return_value={}),
),
):
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(created_agent, schedule_info={})
mock_from_db.assert_called_once_with(created_agent)
# Verify create was called with correct data
create_call = mock_prisma.return_value.create.call_args
create_data = create_call.kwargs["data"]
@@ -58,10 +54,6 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
patch(
"backend.api.features.library._add_to_library._fetch_schedule_info",
new=AsyncMock(return_value={}),
),
):
mock_prisma.return_value.create = AsyncMock(
side_effect=prisma.errors.UniqueViolationError(
@@ -73,7 +65,7 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(updated_agent, schedule_info={})
mock_from_db.assert_called_once_with(updated_agent)
# Verify update was called with correct where and data
update_call = mock_prisma.return_value.update.call_args
assert update_call.kwargs["where"] == {

View File

@@ -1,7 +1,6 @@
import asyncio
import itertools
import logging
from datetime import datetime, timezone
from typing import Literal, Optional
import fastapi
@@ -44,65 +43,6 @@ config = Config()
integration_creds_manager = IntegrationCredentialsManager()
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
"""Fetch execution counts per graph in a single batched query."""
if not graph_ids:
return {}
rows = await prisma.models.AgentGraphExecution.prisma().group_by(
by=["agentGraphId"],
where={
"userId": user_id,
"agentGraphId": {"in": graph_ids},
"isDeleted": False,
},
count=True,
)
return {
row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
for row in rows
}
async def _fetch_schedule_info(
user_id: str, graph_id: Optional[str] = None
) -> dict[str, str]:
"""Fetch a map of graph_id → earliest next_run_time ISO string.
When `graph_id` is provided, the scheduler query is narrowed to that graph,
which is cheaper for single-agent lookups (detail page, post-update, etc.).
"""
try:
scheduler_client = get_scheduler_client()
schedules = await scheduler_client.get_execution_schedules(
graph_id=graph_id,
user_id=user_id,
)
earliest: dict[str, tuple[datetime, str]] = {}
for s in schedules:
parsed = _parse_iso_datetime(s.next_run_time)
if parsed is None:
continue
current = earliest.get(s.graph_id)
if current is None or parsed < current[0]:
earliest[s.graph_id] = (parsed, s.next_run_time)
return {graph_id: iso for graph_id, (_, iso) in earliest.items()}
except Exception:
logger.warning("Failed to fetch schedules for library agents", exc_info=True)
return {}
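The earliest-per-graph reduction inside the removed `_fetch_schedule_info` can be exercised on its own. This sketch takes pre-parsed `(graph_id, parsed_dt, iso_string)` tuples instead of scheduler objects; the function name is illustrative:

```python
from datetime import datetime, timezone


def earliest_next_runs(schedules: list[tuple[str, datetime, str]]) -> dict[str, str]:
    """Keep the earliest next_run_time per graph, returning the original
    ISO string — the same min-by-parsed-datetime fold as the removed code."""
    earliest: dict[str, tuple[datetime, str]] = {}
    for graph_id, parsed, iso in schedules:
        current = earliest.get(graph_id)
        if current is None or parsed < current[0]:
            earliest[graph_id] = (parsed, iso)
    return {gid: iso for gid, (_, iso) in earliest.items()}


t1 = datetime(2026, 1, 1, tzinfo=timezone.utc)
t2 = datetime(2026, 1, 2, tzinfo=timezone.utc)
runs = earliest_next_runs([
    ("g1", t2, "2026-01-02T00:00:00+00:00"),
    ("g1", t1, "2026-01-01T00:00:00+00:00"),
    ("g2", t2, "2026-01-02T00:00:00+00:00"),
])
```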
def _parse_iso_datetime(value: str) -> Optional[datetime]:
"""Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC)."""
try:
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
logger.warning("Failed to parse schedule next_run_time: %s", value)
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed
async def list_library_agents(
user_id: str,
search_term: Optional[str] = None,
@@ -197,22 +137,12 @@ async def list_library_agents(
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts, schedule_info = await asyncio.gather(
_fetch_execution_counts(user_id, graph_ids),
_fetch_schedule_info(user_id),
)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
schedule_info=schedule_info,
)
library_agent = library_model.LibraryAgent.from_db(agent)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
@@ -284,22 +214,12 @@ async def list_favorite_library_agents(
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
)
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts, schedule_info = await asyncio.gather(
_fetch_execution_counts(user_id, graph_ids),
_fetch_schedule_info(user_id),
)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
schedule_info=schedule_info,
)
library_agent = library_model.LibraryAgent.from_db(agent)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
@@ -365,12 +285,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
where={"userId": store_listing.owningUserId}
)
schedule_info = (
await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id)
if library_agent.AgentGraph
else {}
)
return library_model.LibraryAgent.from_db(
library_agent,
sub_graphs=(
@@ -380,7 +294,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
),
store_listing=store_listing,
profile=profile,
schedule_info=schedule_info,
)
@@ -416,10 +329,7 @@ async def get_library_agent_by_store_version_id(
},
include=library_agent_include(user_id),
)
if not agent:
return None
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
return library_model.LibraryAgent.from_db(agent) if agent else None
async def get_library_agent_by_graph_id(
@@ -448,10 +358,7 @@ async def get_library_agent_by_graph_id(
assert agent.AgentGraph # make type checker happy
# Include sub-graphs so we can make a full credentials input schema
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
return library_model.LibraryAgent.from_db(
agent, sub_graphs=sub_graphs, schedule_info=schedule_info
)
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
async def add_generated_agent_image(
@@ -593,11 +500,7 @@ async def create_library_agent(
for agent, graph in zip(library_agents, graph_entries):
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
schedule_info = await _fetch_schedule_info(user_id)
return [
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
for agent in library_agents
]
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
async def update_agent_version_in_library(
@@ -659,8 +562,7 @@ async def update_agent_version_in_library(
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
)
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id)
return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info)
return library_model.LibraryAgent.from_db(lib)
async def create_graph_in_library(
@@ -1565,11 +1467,7 @@ async def bulk_move_agents_to_folder(
),
)
schedule_info = await _fetch_schedule_info(user_id)
return [
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
for agent in agents
]
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
def collect_tree_ids(

View File

@@ -65,11 +65,6 @@ async def test_get_library_agents(mocker):
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={}),
)
# Call function
result = await db.list_library_agents("test-user")
@@ -358,136 +353,3 @@ async def test_create_library_agent_uses_upsert():
# Verify update branch restores soft-deleted/archived agents
assert data["update"]["isDeleted"] is False
assert data["update"]["isArchived"] is False
@pytest.mark.asyncio
async def test_list_favorite_library_agents(mocker):
mock_library_agents = [
prisma.models.LibraryAgent(
id="fav1",
userId="test-user",
agentGraphId="agent-fav",
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=True,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id="agent-fav",
version=1,
name="Favorite Agent",
description="My Favorite",
userId="other-user",
isActive=True,
createdAt=datetime.now(),
),
)
]
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={"agent-fav": 7}),
)
result = await db.list_favorite_library_agents("test-user")
assert len(result.agents) == 1
assert result.agents[0].id == "fav1"
assert result.agents[0].name == "Favorite Agent"
assert result.agents[0].graph_id == "agent-fav"
assert result.pagination.total_items == 1
assert result.pagination.total_pages == 1
assert result.pagination.current_page == 1
assert result.pagination.page_size == 50
@pytest.mark.asyncio
async def test_list_library_agents_skips_failed_agent(mocker):
"""Agents that fail parsing should be skipped — covers the except branch."""
mock_library_agents = [
prisma.models.LibraryAgent(
id="ua-bad",
userId="test-user",
agentGraphId="agent-bad",
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id="agent-bad",
version=1,
name="Bad Agent",
description="",
userId="other-user",
isActive=True,
createdAt=datetime.now(),
),
)
]
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={}),
)
mocker.patch(
"backend.api.features.library.model.LibraryAgent.from_db",
side_effect=Exception("parse error"),
)
result = await db.list_library_agents("test-user")
assert len(result.agents) == 0
assert result.pagination.total_items == 1
@pytest.mark.asyncio
async def test_fetch_execution_counts_empty_graph_ids():
result = await db._fetch_execution_counts("user-1", [])
assert result == {}
@pytest.mark.asyncio
async def test_fetch_execution_counts_uses_group_by(mocker):
mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma")
mock_prisma.return_value.group_by = mocker.AsyncMock(
return_value=[
{"agentGraphId": "graph-1", "_count": {"_all": 5}},
{"agentGraphId": "graph-2", "_count": {"_all": 2}},
]
)
result = await db._fetch_execution_counts(
"user-1", ["graph-1", "graph-2", "graph-3"]
)
assert result == {"graph-1": 5, "graph-2": 2}
mock_prisma.return_value.group_by.assert_called_once_with(
by=["agentGraphId"],
where={
"userId": "user-1",
"agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
"isDeleted": False,
},
count=True,
)
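The defensive comprehension in `_fetch_execution_counts` — folding a missing or empty `_count` to 0 instead of raising — can be exercised standalone (the wrapper function name is illustrative):

```python
def counts_from_group_by(rows: list[dict]) -> dict[str, int]:
    """Mirror of the comprehension in _fetch_execution_counts: rows with a
    missing or None _count map to 0 rather than raising KeyError/TypeError."""
    return {
        row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
        for row in rows
    }


result_counts = counts_from_group_by([
    {"agentGraphId": "graph-1", "_count": {"_all": 5}},
    {"agentGraphId": "graph-2", "_count": None},
    {"agentGraphId": "graph-3"},
])
```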

View File

@@ -214,14 +214,6 @@ class LibraryAgent(pydantic.BaseModel):
folder_name: str | None = None # Denormalized for display
recommended_schedule_cron: str | None = None
is_scheduled: bool = pydantic.Field(
default=False,
description="Whether this agent has active execution schedules",
)
next_scheduled_run: str | None = pydantic.Field(
default=None,
description="ISO 8601 timestamp of the next scheduled run, if any",
)
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
marketplace_listing: Optional["MarketplaceListing"] = None
@@ -231,8 +223,6 @@ class LibraryAgent(pydantic.BaseModel):
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
store_listing: Optional[prisma.models.StoreListing] = None,
profile: Optional[prisma.models.Profile] = None,
execution_count_override: Optional[int] = None,
schedule_info: Optional[dict[str, str]] = None,
) -> "LibraryAgent":
"""
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
@@ -268,14 +258,10 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
execution_count = (
execution_count_override
if execution_count_override is not None
else len(executions)
)
execution_count = len(executions)
success_rate: float | None = None
avg_correctness_score: float | None = None
if executions and execution_count > 0:
if execution_count > 0:
success_count = sum(
1
for e in executions
@@ -368,10 +354,6 @@ class LibraryAgent(pydantic.BaseModel):
folder_id=agent.folderId,
folder_name=agent.Folder.name if agent.Folder else None,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info),
next_scheduled_run=(
schedule_info.get(agent.agentGraphId) if schedule_info else None
),
settings=_parse_settings(agent.settings),
marketplace_listing=marketplace_listing_data,
)

View File

@@ -1,66 +1,11 @@
import datetime
import prisma.enums
import prisma.models
import pytest
from . import model as library_model
def _make_library_agent(
*,
graph_id: str = "g1",
executions: list | None = None,
) -> prisma.models.LibraryAgent:
return prisma.models.LibraryAgent(
id="la1",
userId="u1",
agentGraphId=graph_id,
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=True,
isDeleted=False,
isArchived=False,
createdAt=datetime.datetime.now(),
updatedAt=datetime.datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id=graph_id,
version=1,
name="Agent",
description="Desc",
userId="u1",
isActive=True,
createdAt=datetime.datetime.now(),
Executions=executions,
),
)
def test_from_db_execution_count_override_covers_success_rate():
"""Covers execution_count_override is not None branch and executions/count > 0 block."""
now = datetime.datetime.now(datetime.timezone.utc)
exec1 = prisma.models.AgentGraphExecution(
id="exec-1",
agentGraphId="g1",
agentGraphVersion=1,
userId="u1",
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
createdAt=now,
updatedAt=now,
isDeleted=False,
isShared=False,
)
agent = _make_library_agent(executions=[exec1])
result = library_model.LibraryAgent.from_db(agent, execution_count_override=1)
assert result.execution_count == 1
assert result.success_rate is not None
assert result.success_rate == 100.0
@pytest.mark.asyncio
async def test_agent_preset_from_db(test_user_id: str):
# Create mock DB agent

View File

@@ -5,8 +5,7 @@ import time
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Annotated, Any, Literal, Sequence, cast, get_args
from urllib.parse import urlparse
from typing import Annotated, Any, Literal, Sequence, get_args
import pydantic
import stripe
@@ -55,11 +54,8 @@ from backend.data.credit import (
cancel_stripe_subscription,
create_subscription_checkout,
get_auto_top_up,
get_proration_credit_cents,
get_subscription_price_id,
get_user_credit_model,
handle_subscription_payment_failure,
modify_stripe_subscription_for_tier,
set_auto_top_up,
set_subscription_tier,
sync_subscription_from_stripe,
@@ -703,72 +699,9 @@ class SubscriptionCheckoutResponse(BaseModel):
class SubscriptionStatusResponse(BaseModel):
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
monthly_cost: int # amount in cents (Stripe convention)
tier_costs: dict[str, int] # tier name -> amount in cents
proration_credit_cents: int # unused portion of current sub to convert on upgrade
def _validate_checkout_redirect_url(url: str) -> bool:
"""Return True if `url` matches the configured frontend origin.
Prevents open-redirect: attackers must not be able to supply arbitrary
success_url/cancel_url that Stripe will redirect users to after checkout.
Pre-parse rejection rules (applied before urlparse):
- Backslashes (``\\``) are normalised differently across parsers/browsers.
    - Control characters (U+0000–U+001F) are not valid in URLs and may confuse
some URL-parsing implementations.
"""
# Reject characters that can confuse URL parsers before any parsing.
if "\\" in url:
return False
if any(ord(c) < 0x20 for c in url):
return False
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
if not allowed:
# No configured origin — refuse to validate rather than allow arbitrary URLs.
return False
try:
parsed = urlparse(url)
allowed_parsed = urlparse(allowed)
except ValueError:
return False
if parsed.scheme not in ("http", "https"):
return False
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
# can trick browsers into connecting to a different host than displayed.
# ``@`` in query/fragment is harmless and must be allowed.
if "@" in parsed.netloc:
return False
return (
parsed.scheme == allowed_parsed.scheme
and parsed.netloc == allowed_parsed.netloc
)
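Assuming an allowed origin of `https://app.example.com`, the checks above behave as follows. This is a standalone restatement of `_validate_checkout_redirect_url` with the allowed origin passed explicitly instead of read from settings:

```python
from urllib.parse import urlparse


def validate_redirect(url: str, allowed: str) -> bool:
    """Same rules as _validate_checkout_redirect_url: reject backslashes and
    control chars pre-parse, require http(s), reject @ in the authority, and
    require an exact scheme+netloc match with the allowed origin."""
    if "\\" in url or any(ord(c) < 0x20 for c in url):
        return False
    if not allowed:
        return False
    try:
        parsed = urlparse(url)
        allowed_parsed = urlparse(allowed)
    except ValueError:
        return False
    if parsed.scheme not in ("http", "https"):
        return False
    if "@" in parsed.netloc:
        return False
    return (
        parsed.scheme == allowed_parsed.scheme
        and parsed.netloc == allowed_parsed.netloc
    )


ORIGIN = "https://app.example.com"
same_origin = validate_redirect("https://app.example.com/checkout/done", ORIGIN)
other_host = validate_redirect("https://evil.com/phish", ORIGIN)
userinfo_trick = validate_redirect("https://app.example.com@evil.com/", ORIGIN)
bad_scheme = validate_redirect("javascript:alert(1)", ORIGIN)
```

Note the userinfo case: `urlparse` puts `app.example.com@evil.com` in `netloc`, so the `@` check rejects it before the origin comparison ever runs.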
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
async def _get_stripe_price_amount(price_id: str) -> int | None:
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
of caching the ``None`` sentinel so the next request retries Stripe instead
of being served a stale "no price" for the rest of the TTL window. Callers
should treat ``None`` as an unknown price and fall back to 0.
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
every GET /credits/subscription page load and reduces quota consumption.
"""
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
return price.unit_amount or 0
except stripe.StripeError:
logger.warning(
"Failed to retrieve Stripe price %s — returning None (not cached)",
price_id,
)
return None
tier: str
monthly_cost: int
tier_costs: dict[str, int]
@v1_router.get(
@@ -789,26 +722,21 @@ async def get_subscription_status(
*[get_subscription_price_id(t) for t in paid_tiers]
)
tier_costs: dict[str, int] = {
SubscriptionTier.FREE.value: 0,
SubscriptionTier.ENTERPRISE.value: 0,
}
async def _cost(pid: str | None) -> int:
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
for t, cost in zip(paid_tiers, costs):
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
for t, price_id in zip(paid_tiers, price_ids):
cost = 0
if price_id:
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
cost = price.unit_amount or 0
except stripe.StripeError:
pass
tier_costs[t.value] = cost
current_monthly_cost = tier_costs.get(tier.value, 0)
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
return SubscriptionStatusResponse(
tier=tier.value,
monthly_cost=current_monthly_cost,
monthly_cost=tier_costs.get(tier.value, 0),
tier_costs=tier_costs,
proration_credit_cents=proration_credit,
)
@@ -838,125 +766,24 @@ async def update_subscription_tier(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
# keeps their tier for the time they already paid for. The DB tier is NOT
# updated here when a subscription exists — the customer.subscription.deleted
# webhook fires at period end and downgrades to FREE then.
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
# tier), cancel_stripe_subscription returns False and we update the DB tier
# immediately since no webhook will ever fire.
# When payment is disabled entirely, update the DB tier directly.
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
if tier == SubscriptionTier.FREE:
if payment_enabled:
try:
had_subscription = await cancel_stripe_subscription(user_id)
except stripe.StripeError as e:
# Log full Stripe error server-side but return a generic message
# to the client — raw Stripe errors can leak customer/sub IDs and
# infrastructure config details.
logger.exception(
"Stripe error cancelling subscription for user %s: %s",
user_id,
e,
)
raise HTTPException(
status_code=502,
detail=(
"Unable to cancel your subscription right now. "
"Please try again or contact support."
),
)
if not had_subscription:
# No active Stripe subscription found — the user was on an
# admin-granted tier. Update DB immediately since the
# subscription.deleted webhook will never fire.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
await cancel_stripe_subscription(user_id)
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
# Paid tier changes require payment to be enabled — block self-service upgrades
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
# Beta users (payment not enabled) → update tier directly without Stripe.
if not payment_enabled:
raise HTTPException(
status_code=422,
detail=f"Subscription not available for tier {tier}",
)
# No-op short-circuit: if the user is already on the requested paid tier,
# do NOT create a new Checkout Session. Without this guard, a duplicate
# request (double-click, retried POST, stale page) creates a second
# subscription for the same price; the user would be charged for both
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
# which only fires after the second charge has cleared.
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
# Paid→paid tier change: if the user already has a Stripe subscription,
# modify it in-place with proration instead of creating a new Checkout
# Session. This preserves remaining paid time and avoids double-charging.
# The customer.subscription.updated webhook fires and updates the DB tier.
current_tier = user.subscription_tier or SubscriptionTier.FREE
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
try:
modified = await modify_stripe_subscription_for_tier(user_id, tier)
if modified:
return SubscriptionCheckoutResponse(url="")
# modify_stripe_subscription_for_tier returns False when no active
# Stripe subscription exists — i.e. the user has an admin-granted
# paid tier with no Stripe record. In that case, update the DB
# tier directly (same as the FREE-downgrade path for admin-granted
# users) rather than sending them through a new Checkout Session.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
),
)
# Paid upgrade from FREE → create Stripe Checkout Session.
# Paid upgrade → create Stripe Checkout Session.
if not request.success_url or not request.cancel_url:
raise HTTPException(
status_code=422,
detail="success_url and cancel_url are required for paid tier upgrades",
)
# Open-redirect protection: both URLs must point to the configured frontend
# origin, otherwise an attacker could use our Stripe integration as a
# redirector to arbitrary phishing sites.
#
# Fail early with a clear 503 if the server is misconfigured (neither
# frontend_base_url nor platform_base_url set), so operators get an
# actionable error instead of the misleading "must match the platform
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
# produce when `allowed` is empty.
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
logger.error(
"update_subscription_tier: neither frontend_base_url nor "
"platform_base_url is configured; cannot validate checkout redirect URLs"
)
raise HTTPException(
status_code=503,
detail=(
"Payment redirect URLs cannot be validated: "
"frontend_base_url or platform_base_url must be set on the server."
),
)
if not _validate_checkout_redirect_url(
request.success_url
) or not _validate_checkout_redirect_url(request.cancel_url):
raise HTTPException(
status_code=422,
detail="success_url and cancel_url must match the platform frontend origin",
)
try:
url = await create_subscription_checkout(
user_id=user_id,
@@ -964,19 +791,8 @@ async def update_subscription_tier(
success_url=request.success_url,
cancel_url=request.cancel_url,
)
except ValueError as e:
except (ValueError, stripe.StripeError) as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error creating checkout session for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to start checkout right now. "
"Please try again or contact support."
),
)
return SubscriptionCheckoutResponse(url=url)
@@ -985,78 +801,44 @@ async def update_subscription_tier(
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
)
async def stripe_webhook(request: Request):
-webhook_secret = settings.secrets.stripe_webhook_secret
-if not webhook_secret:
-# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
-# signature over the same empty key). Reject all webhook calls when unconfigured.
-logger.error(
-"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
-"rejecting request to prevent signature bypass"
-)
-raise HTTPException(status_code=503, detail="Webhook not configured")
# Get the raw request body
payload = await request.body()
# Get the signature header
sig_header = request.headers.get("stripe-signature")
try:
-event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
-except ValueError:
-# Invalid payload
-raise HTTPException(status_code=400, detail="Invalid payload")
-except stripe.SignatureVerificationError:
-# Invalid signature
-raise HTTPException(status_code=400, detail="Invalid signature")
-# Defensive payload extraction. A malformed payload (missing/non-dict
-# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
-# AFTER signature verification — which Stripe interprets as a delivery
-# failure and retries forever, while spamming Sentry with no useful info.
-# Acknowledge with 200 and a warning so Stripe stops retrying.
-event_type = event.get("type", "")
-event_data = event.get("data") or {}
-data_object = event_data.get("object") if isinstance(event_data, dict) else None
-if not isinstance(data_object, dict):
-logger.warning(
-"stripe_webhook: %s missing or non-dict data.object; ignoring",
-event_type,
-)
-return Response(status_code=200)
+event = stripe.Webhook.construct_event(
+payload, sig_header, settings.secrets.stripe_webhook_secret
+)
+except ValueError as e:
+# Invalid payload
+raise HTTPException(
+status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
+)
+except stripe.SignatureVerificationError as e:
+# Invalid signature
+raise HTTPException(
+status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
+)
-if event_type in (
-"checkout.session.completed",
-"checkout.session.async_payment_succeeded",
-):
-session_id = data_object.get("id")
-if not session_id:
-logger.warning(
-"stripe_webhook: %s missing data.object.id; ignoring", event_type
-)
-return Response(status_code=200)
-await UserCredit().fulfill_checkout(session_id=session_id)
+if (
+event["type"] == "checkout.session.completed"
+or event["type"] == "checkout.session.async_payment_succeeded"
+):
+await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
-if event_type in (
+if event["type"] in (
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
):
-await sync_subscription_from_stripe(data_object)
+await sync_subscription_from_stripe(event["data"]["object"])
-if event_type == "invoice.payment_failed":
-await handle_subscription_payment_failure(data_object)
-# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
-# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
-# StripeObject (a dict subclass) carrying that runtime shape, so we cast
-# to satisfy the type checker without changing runtime behaviour.
-if event_type == "charge.dispute.created":
-await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
+if event["type"] == "charge.dispute.created":
+await UserCredit().handle_dispute(event["data"]["object"])
-if event_type == "refund.created" or event_type == "charge.dispute.closed":
-await UserCredit().deduct_credits(
-cast("stripe.Refund | stripe.Dispute", data_object)
-)
+if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
+await UserCredit().deduct_credits(event["data"]["object"])
return Response(status_code=200)

View File

@@ -106,6 +106,7 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
@classmethod
def _missing_(cls, value: object) -> "LlmModel | None":
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
@@ -202,8 +203,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
-GROK_4_20 = "x-ai/grok-4.20"
-GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
@@ -628,18 +627,6 @@ MODEL_METADATA = {
LlmModel.GROK_4_1_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
),
-LlmModel.GROK_4_20: ModelMetadata(
-"open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3
-),
-LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata(
-"open_router",
-2000000,
-100000,
-"Grok 4.20 Multi-Agent",
-"OpenRouter",
-"xAI",
-3,
-),
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
),
@@ -1000,6 +987,7 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "anthropic":
+an_tools = convert_openai_tool_fmt_to_anthropic(tools)
# Cache tool definitions alongside the system prompt.
# Placing cache_control on the last tool caches all tool schemas as a
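The `_missing_` hook in the hunk above accepts provider-prefixed model names. A self-contained sketch on a hypothetical two-member enum (the real `LlmModel` uses a custom metaclass and many more members) shows the mechanism: when a value lookup fails, `_missing_` retries with the provider prefix stripped:

```python
from enum import Enum


class Model(str, Enum):
    # Hypothetical mini-enum for illustration only.
    CLAUDE_SONNET = "claude-sonnet-4-6"
    GROK_4 = "x-ai/grok-4"

    @classmethod
    def _missing_(cls, value: object) -> "Model | None":
        # Accept names like "anthropic/claude-sonnet-4-6" by dropping
        # everything up to the first "/" and retrying the lookup.
        if isinstance(value, str) and "/" in value:
            _, _, bare = value.partition("/")
            for member in cls:
                if member.value == bare:
                    return member
        return None


assert Model("anthropic/claude-sonnet-4-6") is Model.CLAUDE_SONNET
# Values that already contain "/" still resolve directly, before _missing_ runs.
assert Model("x-ai/grok-4") is Model.GROK_4
```

Note the ordering: `Enum.__call__` only invokes `_missing_` after the direct value lookup fails, so members whose canonical value already contains a slash are unaffected.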

View File

@@ -41,6 +41,7 @@ class PaginatedMessages(BaseModel):
messages: list[ChatMessage]
has_more: bool
oldest_sequence: int | None
+newest_sequence: int | None
session: ChatSessionInfo
@@ -65,30 +66,48 @@ async def get_chat_messages_paginated(
session_id: str,
limit: int = 50,
before_sequence: int | None = None,
+after_sequence: int | None = None,
+from_start: bool = False,
user_id: str | None = None,
) -> PaginatedMessages | None:
"""Get paginated messages for a session, newest first.
"""Get paginated messages for a session.
Verifies session existence (and ownership when ``user_id`` is provided)
in parallel with the message query. Returns ``None`` when the session
is not found or does not belong to the user.
Three modes:
After fetching, a visibility guarantee ensures the page contains at least
one user or assistant message. If the entire page is tool messages (which
are hidden in the UI), it expands backward until a visible message is found
so the chat never appears blank.
- ``before_sequence`` set: backward pagination (DESC), returns messages
with sequence < ``before_sequence``. Used for active sessions or manual
backward navigation.
- ``from_start=True`` or ``after_sequence`` set: forward pagination (ASC).
Returns messages from sequence 0 (``from_start``) or after
``after_sequence``. Used on initial load of completed sessions and for
loading subsequent forward pages.
- Both cursors ``None`` and ``from_start=False``: newest-first (DESC
without filter). Used for active sessions on initial load.
Verifies session existence (and ownership when ``user_id`` is provided).
Returns ``None`` when the session is not found or does not belong to the
user.
"""
# Build session-existence / ownership check
session_where: ChatSessionWhereInput = {"id": session_id}
if user_id is not None:
session_where["userId"] = user_id
-# Build message include — fetch paginated messages in the same query
+forward = from_start or after_sequence is not None
+# Build message include — fetch paginated messages in the same query.
+# Note: when both from_start=True and after_sequence is not None, the
+# after_sequence filter takes precedence (the elif branch below is skipped).
+# This combination is not reachable via the HTTP route (mutual exclusion is
+# enforced there), so we rely on the documented priority here without an
+# additional assertion.
msg_include: FindManyChatMessageArgsFromChatSession = {
"order_by": {"sequence": "desc"},
"order_by": {"sequence": "asc" if forward else "desc"},
"take": limit + 1,
}
-if before_sequence is not None:
+if after_sequence is not None:
+msg_include["where"] = {"sequence": {"gt": after_sequence}}
+elif before_sequence is not None:
msg_include["where"] = {"sequence": {"lt": before_sequence}}
# Single query: session existence/ownership + paginated messages
@@ -106,129 +125,95 @@ async def get_chat_messages_paginated(
has_more = len(results) > limit
results = results[:limit]
-# Reverse to ascending order
-results.reverse()
+if not forward:
+# Backward mode: DB returned DESC; reverse to ascending order.
+results.reverse()
+# Tool-call boundary fix: if the oldest message is a tool message,
+# expand backward to include the preceding assistant message that
+# owns the tool_calls, so convertChatSessionMessagesToUiMessages
+# can pair them correctly.
+if results and results[0].role == "tool":
+results, has_more = await _expand_tool_boundary(
+session_id, results, has_more, user_id
+)
-# Tool-call boundary fix: if the oldest message is a tool message,
-# expand backward to include the preceding assistant message that
-# owns the tool_calls, so convertChatSessionMessagesToUiMessages
-# can pair them correctly.
-if results and results[0].role == "tool":
-boundary_where: ChatMessageWhereInput = {
-"sessionId": session_id,
-"sequence": {"lt": results[0].sequence},
-}
-if user_id is not None:
-boundary_where["Session"] = {"is": {"userId": user_id}}
-extra = await PrismaChatMessage.prisma().find_many(
-where=boundary_where,
-order={"sequence": "desc"},
-take=_BOUNDARY_SCAN_LIMIT,
-)
-# Find the first non-tool message (should be the assistant)
-boundary_msgs = []
-found_owner = False
-for msg in extra:
-boundary_msgs.append(msg)
-if msg.role != "tool":
-found_owner = True
-break
-boundary_msgs.reverse()
-if not found_owner:
-logger.warning(
-"Boundary expansion did not find owning assistant message "
-"for session=%s before sequence=%s (%d msgs scanned)",
-session_id,
-results[0].sequence,
-len(extra),
-)
-if boundary_msgs:
-results = boundary_msgs + results
-# Only mark has_more if the expanded boundary isn't the
-# very start of the conversation (sequence 0).
-if boundary_msgs[0].sequence > 0:
-has_more = True
-# Visibility guarantee: if the entire page has no user/assistant messages
-# (all tool messages), the chat would appear blank. Expand backward
-# until we find at least one visible message.
-if results and not any(m.role in ("user", "assistant") for m in results):
-results, has_more = await _expand_for_visibility(
-session_id, results, has_more, user_id
-)
+else:
+# Forward mode: DB returned ASC.
+# Tool-call tail boundary fix: if the last message in this page is a
+# tool message, the NEXT forward page would start after it and begin
+# mid-tool-group — the owning assistant message is on this page but
+# the following tool results are on the next page.
+# Trim the current page so it ends on the owning assistant message,
+# which keeps tool groups intact across page boundaries.
+if results and results[-1].role == "tool":
+# Walk backward through results to find the last non-tool message.
+trim_idx = len(results) - 1
+while trim_idx >= 0 and results[trim_idx].role == "tool":
+trim_idx -= 1
+if trim_idx >= 0:
+# Trim results so the page ends at the owning assistant.
+# Mark has_more=True so the client knows to fetch the rest.
+results = results[: trim_idx + 1]
+has_more = True
+else:
+# Entire page is tool messages with no visible owner — log and
+# keep as-is so the caller is not stuck with an empty page.
+logger.warning(
+"Forward tail boundary: entire page is tool messages "
+"for session=%s, no owning assistant found (%d msgs)",
+session_id,
+len(results),
+)
messages = [ChatMessage.from_db(m) for m in results]
oldest_sequence = messages[0].sequence if messages else None
+# newest_sequence is only meaningful in forward mode; in backward mode it
+# points to the last message of the page (not the session's newest message)
+# which is not a valid forward cursor. Return None in backward mode so
+# clients don't accidentally use it as one.
+newest_sequence = messages[-1].sequence if (messages and forward) else None
return PaginatedMessages(
messages=messages,
has_more=has_more,
oldest_sequence=oldest_sequence,
+newest_sequence=newest_sequence,
session=session_info,
)
+async def _expand_tool_boundary(
+session_id: str,
+results: list[Any],
+has_more: bool,
+user_id: str | None,
+) -> tuple[list[Any], bool]:
+"""Expand backward from the oldest message to include the owning assistant
+message when the page starts mid-tool-group."""
+boundary_where: ChatMessageWhereInput = {
+"sessionId": session_id,
+"sequence": {"lt": results[0].sequence},
+}
+if user_id is not None:
+boundary_where["Session"] = {"is": {"userId": user_id}}
+extra = await PrismaChatMessage.prisma().find_many(
+where=boundary_where,
+order={"sequence": "desc"},
+take=_BOUNDARY_SCAN_LIMIT,
+)
+# Find the first non-tool message (should be the assistant)
+boundary_msgs = []
+found_owner = False
+for msg in extra:
+boundary_msgs.append(msg)
+if msg.role != "tool":
+found_owner = True
+break
+boundary_msgs.reverse()
+if not found_owner:
+logger.warning(
+"Boundary expansion did not find owning assistant message "
+"for session=%s before sequence=%s (%d msgs scanned)",
+session_id,
+results[0].sequence,
+len(extra),
+)
+if boundary_msgs:
+results = boundary_msgs + results
+has_more = boundary_msgs[0].sequence > 0
+return results, has_more
-_VISIBILITY_EXPAND_LIMIT = 200
-async def _expand_for_visibility(
-session_id: str,
-results: list[Any],
-has_more: bool,
-user_id: str | None,
-) -> tuple[list[Any], bool]:
-"""Expand backward until the page contains at least one user or assistant
-message, so the chat is never blank."""
-expand_where: ChatMessageWhereInput = {
-"sessionId": session_id,
-"sequence": {"lt": results[0].sequence},
-}
-if user_id is not None:
-expand_where["Session"] = {"is": {"userId": user_id}}
-extra = await PrismaChatMessage.prisma().find_many(
-where=expand_where,
-order={"sequence": "desc"},
-take=_VISIBILITY_EXPAND_LIMIT,
-)
-if not extra:
-return results, has_more
-# Collect messages until we find a visible one (user/assistant)
-prepend = []
-found_visible = False
-for msg in extra:
-prepend.append(msg)
-if msg.role in ("user", "assistant"):
-found_visible = True
-break
-if not found_visible:
-logger.warning(
-"Visibility expansion did not find any user/assistant message "
-"for session=%s before sequence=%s (%d msgs scanned)",
-session_id,
-results[0].sequence,
-len(extra),
-)
-prepend.reverse()
-if prepend:
-results = prepend + results
-has_more = prepend[0].sequence > 0
-return results, has_more
async def create_chat_session(
session_id: str,
user_id: str,

View File

@@ -175,136 +175,181 @@ async def test_no_where_on_messages_without_before_sequence(
assert "where" not in include["Messages"]
-# ---------- Visibility guarantee ----------
+# ---------- Forward pagination (from_start / after_sequence) ----------
@pytest.mark.asyncio
-async def test_visibility_expands_when_all_tool_messages(
+async def test_from_start_uses_asc_order_no_where(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When the entire page is tool messages, expand backward to find
at least one visible (user/assistant) message so the chat isn't blank."""
find_first, find_many = mock_db
# Newest 3 messages are all tool messages (DESC → reversed to ASC)
"""from_start=True queries messages in ASC order with no where filter."""
find_first, _ = mock_db
find_first.return_value = _make_session(
-messages=[
-_make_msg(12, role="tool"),
-_make_msg(11, role="tool"),
-_make_msg(10, role="tool"),
-],
+messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
)
-# Boundary expansion finds the owning assistant first (boundary fix),
-# then visibility expansion finds a user message further back
-find_many.side_effect = [
-# First call: boundary fix (oldest msg is tool → find owner)
-[_make_msg(9, role="tool"), _make_msg(8, role="tool")],
-# Second call: visibility expansion (still all tool → find visible)
-[_make_msg(7, role="tool"), _make_msg(6, role="assistant")],
-]
-page = await get_chat_messages_paginated(SESSION_ID, limit=3)
+await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
-assert page is not None
-# Should include the expanded messages + original tool messages
-roles = [m.role for m in page.messages]
-assert "assistant" in roles or "user" in roles
-assert page.has_more is True
+call_kwargs = find_first.call_args
+include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
+assert include["Messages"]["order_by"] == {"sequence": "asc"}
+assert "where" not in include["Messages"]
@pytest.mark.asyncio
-async def test_no_visibility_expansion_when_visible_messages_present(
+async def test_from_start_returns_messages_ascending(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""No visibility expansion needed when page already has visible messages."""
find_first, find_many = mock_db
# Page has an assistant message among tool messages
"""from_start=True returns messages in ascending sequence order."""
find_first, _ = mock_db
find_first.return_value = _make_session(
-messages=[
-_make_msg(5, role="tool"),
-_make_msg(4, role="assistant"),
-_make_msg(3, role="user"),
-],
+messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
)
-page = await get_chat_messages_paginated(SESSION_ID, limit=3)
+page = await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
assert page is not None
-# Boundary expansion might fire (oldest is tool), but NOT visibility
-assert [m.sequence for m in page.messages][0] <= 3
@pytest.mark.asyncio
-async def test_visibility_no_expansion_when_no_earlier_messages(
-mock_db: tuple[AsyncMock, AsyncMock],
-):
-"""When the page is all tool messages but there are no earlier messages
-in the DB, visibility expansion returns early without changes."""
-find_first, find_many = mock_db
-find_first.return_value = _make_session(
-messages=[_make_msg(1, role="tool"), _make_msg(0, role="tool")],
-)
-# Boundary expansion: no earlier messages
-# Visibility expansion: no earlier messages
-find_many.side_effect = [[], []]
-page = await get_chat_messages_paginated(SESSION_ID, limit=2)
-assert page is not None
-assert all(m.role == "tool" for m in page.messages)
@pytest.mark.asyncio
-async def test_visibility_expansion_reaches_seq_zero(
-mock_db: tuple[AsyncMock, AsyncMock],
-):
-"""When visibility expansion finds a visible message at sequence 0,
-has_more should be False."""
-find_first, find_many = mock_db
-find_first.return_value = _make_session(
-messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
-)
-find_many.side_effect = [
-# Boundary expansion
-[_make_msg(3, role="tool")],
-# Visibility expansion — finds user at seq 0
-[
-_make_msg(2, role="tool"),
-_make_msg(1, role="tool"),
-_make_msg(0, role="user"),
-],
-]
-page = await get_chat_messages_paginated(SESSION_ID, limit=2)
-assert page is not None
-assert page.messages[0].role == "user"
-assert page.messages[0].sequence == 0
+assert [m.sequence for m in page.messages] == [0, 1, 2]
+assert page.oldest_sequence == 0
+assert page.newest_sequence == 2
assert page.has_more is False
@pytest.mark.asyncio
-async def test_visibility_expansion_with_user_id(
+async def test_from_start_has_more_when_results_exceed_limit(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Visibility expansion passes user_id filter to the boundary query."""
"""from_start=True sets has_more when DB returns more than limit items."""
+find_first, _ = mock_db
+find_first.return_value = _make_session(
+messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
+)
+page = await get_chat_messages_paginated(SESSION_ID, limit=2, from_start=True)
+assert page is not None
+assert page.has_more is True
+assert [m.sequence for m in page.messages] == [0, 1]
+assert page.newest_sequence == 1
@pytest.mark.asyncio
+async def test_after_sequence_uses_gt_filter_asc_order(
+mock_db: tuple[AsyncMock, AsyncMock],
+):
+"""after_sequence adds a sequence > N where clause and uses ASC order."""
+find_first, _ = mock_db
+find_first.return_value = _make_session(
+messages=[_make_msg(11), _make_msg(12)],
+)
+await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
+call_kwargs = find_first.call_args
+include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
+assert include["Messages"]["order_by"] == {"sequence": "asc"}
+assert include["Messages"]["where"] == {"sequence": {"gt": 10}}
@pytest.mark.asyncio
+async def test_after_sequence_returns_messages_in_order(
+mock_db: tuple[AsyncMock, AsyncMock],
+):
+"""after_sequence returns only messages with sequence > cursor, ascending."""
+find_first, _ = mock_db
+find_first.return_value = _make_session(
+messages=[_make_msg(11), _make_msg(12), _make_msg(13)],
+)
+page = await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
+assert page is not None
+assert [m.sequence for m in page.messages] == [11, 12, 13]
+assert page.oldest_sequence == 11
+assert page.newest_sequence == 13
+assert page.has_more is False
@pytest.mark.asyncio
+async def test_newest_sequence_none_for_backward_mode(
+mock_db: tuple[AsyncMock, AsyncMock],
+):
+"""newest_sequence is None in backward mode — it is not a valid forward cursor."""
+find_first, _ = mock_db
+find_first.return_value = _make_session(
+messages=[_make_msg(5), _make_msg(4), _make_msg(3)],
+)
+page = await get_chat_messages_paginated(SESSION_ID, limit=50)
+assert page is not None
+assert page.newest_sequence is None
+assert page.oldest_sequence == 3
@pytest.mark.asyncio
+async def test_forward_mode_no_boundary_expansion(
+mock_db: tuple[AsyncMock, AsyncMock],
+):
+"""Forward pagination never triggers backward boundary expansion."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(10, role="tool")],
messages=[_make_msg(0, role="tool"), _make_msg(1, role="tool")],
)
-find_many.side_effect = [
-# Boundary expansion
-[_make_msg(9, role="tool")],
-# Visibility expansion
-[_make_msg(8, role="assistant")],
-]
-await get_chat_messages_paginated(SESSION_ID, limit=1, user_id="user-abc")
+await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
-# Both find_many calls should include the user_id session filter
-for call in find_many.call_args_list:
-where = call.kwargs.get("where") or call[1].get("where")
-assert "Session" in where
-assert where["Session"] == {"is": {"userId": "user-abc"}}
+assert find_many.call_count == 0
@pytest.mark.asyncio
+async def test_forward_tail_boundary_trims_trailing_tool_messages(
+mock_db: tuple[AsyncMock, AsyncMock],
+):
+"""Forward pages that end with tool messages are trimmed to the owning
+assistant so the next after_sequence page doesn't start mid-tool-group."""
+find_first, _ = mock_db
+# DB returns 4 messages ASC: assistant at 0, tool at 1, tool at 2, tool at 3
+find_first.return_value = _make_session(
+messages=[
+_make_msg(0, role="assistant"),
+_make_msg(1, role="tool"),
+_make_msg(2, role="tool"),
+_make_msg(3, role="tool"),
+],
+)
+page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
+assert page is not None
+# Page should be trimmed to end at the assistant message
+assert [m.sequence for m in page.messages] == [0]
+assert page.newest_sequence == 0
+# has_more must be True so the client fetches the tool messages on next page
+assert page.has_more is True
@pytest.mark.asyncio
+async def test_forward_tail_boundary_no_trim_when_last_not_tool(
+mock_db: tuple[AsyncMock, AsyncMock],
+):
+"""Forward pages that end with a non-tool message are not trimmed."""
+find_first, _ = mock_db
+find_first.return_value = _make_session(
+messages=[
+_make_msg(0, role="user"),
+_make_msg(1, role="assistant"),
+_make_msg(2, role="tool"),
+_make_msg(3, role="assistant"),
+],
+)
+page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
+assert page is not None
+assert [m.sequence for m in page.messages] == [0, 1, 2, 3]
+assert page.newest_sequence == 3
+assert page.has_more is False
@pytest.mark.asyncio
@@ -461,8 +506,7 @@ async def test_boundary_expansion_warns_when_no_owner_found(
with patch("backend.copilot.db.logger") as mock_logger:
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
-# Two warnings: boundary expansion + visibility expansion (all tool msgs)
-assert mock_logger.warning.call_count == 2
+mock_logger.warning.assert_called_once()
assert page is not None
assert page.messages[0].role == "tool"

View File

@@ -0,0 +1,71 @@
"""Per-request idempotency lock for the /stream endpoint.
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
Lifecycle
---------
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
``None`` when the key already exists (duplicate request).
2. ``release()`` — deletes the key. Must be called on turn completion or turn
error so the next legitimate send is never blocked.
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
the backend turn is still running, and releasing would reopen the duplicate
window for infra-level retries. The 30 s TTL is the safety net.
"""
import hashlib
import logging
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_KEY_PREFIX = "chat:msg_dedup"
_TTL_SECONDS = 30
class _DedupLock:
def __init__(self, key: str, redis) -> None:
self._key = key
self._redis = redis
async def release(self) -> None:
"""Best-effort key deletion. The TTL handles failures silently."""
try:
await self._redis.delete(self._key)
except Exception:
pass
async def acquire_dedup_lock(
session_id: str,
message: str | None,
file_ids: list[str] | None,
) -> _DedupLock | None:
"""Acquire the idempotency lock for this (session, message, files) tuple.
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
Returns ``None`` when a duplicate is detected (lock already held).
Returns ``None`` when there is nothing to deduplicate (no message, no files).
"""
if not message and not file_ids:
return None
sorted_ids = ":".join(sorted(file_ids or []))
content_hash = hashlib.sha256(
f"{session_id}:{message or ''}:{sorted_ids}".encode()
).hexdigest()[:16]
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
redis = await get_redis_async()
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
if not acquired:
logger.warning(
f"[STREAM] Duplicate user message blocked for session {session_id}, "
f"hash={content_hash} — returning empty SSE",
)
return None
return _DedupLock(key, redis)

View File

@@ -0,0 +1,94 @@
"""Unit tests for backend.copilot.message_dedup."""
from unittest.mock import AsyncMock
import pytest
import pytest_mock
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
return mock_redis
@pytest.mark.asyncio
async def test_acquire_returns_none_when_no_message_no_files(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Nothing to deduplicate — no Redis call made, None returned."""
mock_redis = _patch_redis(mocker, set_returns=True)
result = await acquire_dedup_lock("sess-1", None, None)
assert result is None
mock_redis.set.assert_not_called()
@pytest.mark.asyncio
async def test_acquire_returns_lock_on_first_request(
mocker: pytest_mock.MockerFixture,
) -> None:
"""First request acquires the lock and returns a _DedupLock."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
mock_redis.set.assert_called_once()
key_arg = mock_redis.set.call_args.args[0]
assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
@pytest.mark.asyncio
async def test_acquire_returns_none_on_duplicate(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Duplicate request (NX fails) returns None to signal the caller."""
_patch_redis(mocker, set_returns=None)
result = await acquire_dedup_lock("sess-1", "hello", None)
assert result is None
@pytest.mark.asyncio
async def test_acquire_key_stable_across_file_order(
mocker: pytest_mock.MockerFixture,
) -> None:
"""File IDs are sorted before hashing so order doesn't affect the key."""
mock_redis_1 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
key_ab = mock_redis_1.set.call_args.args[0]
mock_redis_2 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
key_ba = mock_redis_2.set.call_args.args[0]
assert key_ab == key_ba
@pytest.mark.asyncio
async def test_release_deletes_key(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() calls Redis delete exactly once."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release()
mock_redis.delete.assert_called_once()
@pytest.mark.asyncio
async def test_release_swallows_redis_error(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() must not raise even when Redis delete fails."""
mock_redis = _patch_redis(mocker, set_returns=True)
mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release() # must not raise
mock_redis.delete.assert_called_once()

View File

@@ -1,8 +1,9 @@
import asyncio
import logging
import uuid
-from contextlib import asynccontextmanager
from datetime import UTC, datetime
-from typing import Any, AsyncIterator, Self, cast
+from typing import Any, Self, cast
+from weakref import WeakValueDictionary
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
@@ -521,7 +522,10 @@ async def upsert_chat_session(
callers are aware of the persistence failure.
RedisError: If the cache write fails (after successful DB write).
"""
-async with _get_session_lock(session.session_id) as _:
+# Acquire session-specific lock to prevent concurrent upserts
+lock = await _get_session_lock(session.session_id)
+async with lock:
# Always query DB for existing message count to ensure consistency
existing_message_count = await chat_db().get_next_sequence(session.session_id)
@@ -647,50 +651,20 @@ async def _save_session_to_db(
msg.sequence = existing_message_count + i
-async def append_and_save_message(
-session_id: str, message: ChatMessage
-) -> ChatSession | None:
+async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
-Returns the updated session, or None if the message was detected as a
-duplicate (idempotency guard). Callers must check for None and skip any
-downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected.
-Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas.
-The idempotency check below provides a last-resort guard when the lock degrades.
+Acquires the session lock, re-fetches the latest session state,
+appends the message, and saves — preventing message loss when
+concurrent requests modify the same session.
"""
-async with _get_session_lock(session_id) as lock_acquired:
-# When the lock degraded (Redis down or 2s timeout), bypass cache for
-# the idempotency check. Stale cache could let two concurrent writers
-# both see the old state, pass the check, and write the same message.
-if lock_acquired:
-session = await get_chat_session(session_id)
-else:
-session = await _get_session_from_db(session_id)
+lock = await _get_session_lock(session_id)
+async with lock:
+session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
-# Idempotency: skip if the trailing block of same-role messages already
-# contains this content. Uses is_message_duplicate which checks all
-# consecutive trailing messages of the same role, not just [-1].
-#
-# This collapses infra/nginx retries whether they land on the same pod
-# (serialised by the Redis lock) or a different pod.
-#
-# Legit same-text messages are distinguished by the assistant turn
-# between them: if the user said "yes", got a response, and says
-# "yes" again, session.messages[-1] is the assistant reply, so the
-# role check fails and the second message goes through normally.
-#
-# Edge case: if a turn dies without writing any assistant message,
-# the user's next send of the same text is blocked here permanently.
-# The fix is to ensure failed turns always write an error/timeout
-# assistant message so the session always ends on an assistant turn.
-if message.content is not None and is_message_duplicate(
-session.messages, message.role, message.content
-):
-return None  # duplicate — caller should skip enqueue
session.messages.append(message)
existing_message_count = await chat_db().get_next_sequence(session_id)
@@ -705,9 +679,6 @@ async def append_and_save_message(
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
-# Invalidate the stale entry so future reads fall back to DB,
-# preventing a retry from bypassing the idempotency check above.
-await invalidate_session_cache(session_id)
return session
@@ -793,6 +764,10 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
except Exception as e:
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
@@ -857,38 +832,25 @@ async def update_session_title(
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
@asynccontextmanager
async def _get_session_lock(session_id: str) -> AsyncIterator[bool]:
"""Distributed Redis lock for a session, usable as an async context manager.
Yields True if the lock was acquired, False if it timed out or Redis was
unavailable. Callers should treat False as a degraded mode and prefer fresh
DB reads over cache to avoid acting on stale state.
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition
is atomic and release is owner-verified. Blocks up to 2s for a concurrent
writer to finish; the 10s TTL ensures a dead pod never holds the lock forever.
This was originally added to prevent race conditions between the session title
thread and the conversation thread; that race always occurs on the same
instance because the frontend prevents rapid request sends.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
"""
_lock_key = f"copilot:session_lock:{session_id}"
lock = None
acquired = False
try:
_redis = await get_redis_async()
lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2)
acquired = await lock.acquire(blocking=True)
if not acquired:
logger.warning(
"Could not acquire session lock for %s within 2s", session_id
)
except Exception as e:
logger.warning("Redis unavailable for session lock on %s: %s", session_id, e)
try:
yield acquired
finally:
if acquired and lock is not None:
try:
await lock.release()
except Exception:
pass # TTL will expire the key
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock

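The WeakValueDictionary lock pattern in the hunk above can be sketched standalone (simplified, hypothetical names — not the module's actual code): locks are created on demand, handed out under a registry mutex, and garbage-collected once no coroutine holds a reference.

```python
import asyncio
from weakref import WeakValueDictionary

# Registry of per-session locks; entries vanish automatically when the
# last reference to a lock is dropped, so the dict never grows unboundedly.
_locks: "WeakValueDictionary[str, asyncio.Lock]" = WeakValueDictionary()
_registry_mutex = asyncio.Lock()


async def get_lock(session_id: str) -> asyncio.Lock:
    # The registry mutex makes get-or-create atomic, so two coroutines
    # asking for the same session always receive the same lock object.
    async with _registry_mutex:
        lock = _locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _locks[session_id] = lock
        return lock


async def demo() -> list[str]:
    events: list[str] = []

    async def writer(name: str) -> None:
        lock = await get_lock("session-1")
        async with lock:
            events.append(f"{name}-start")
            await asyncio.sleep(0.01)
            events.append(f"{name}-end")

    # Both writers target the same session, so their critical sections
    # never interleave: each start is immediately followed by its end.
    await asyncio.gather(writer("a"), writer("b"))
    return events
```

The explicit cleanup in `delete_chat_session()` is belt-and-suspenders on top of this: the weak references already reclaim idle locks.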
View File

@@ -11,13 +11,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from pytest_mock import MockerFixture
from .model import (
ChatMessage,
ChatSession,
Usage,
append_and_save_message,
get_chat_session,
is_message_duplicate,
maybe_append_user_message,
@@ -576,345 +574,3 @@ def test_maybe_append_assistant_skips_duplicate():
result = maybe_append_user_message(session, "dup", is_user_message=False)
assert result is False
assert len(session.messages) == 2
# --------------------------------------------------------------------------- #
# append_and_save_message #
# --------------------------------------------------------------------------- #
def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession:
s = ChatSession.new(user_id="u1", dry_run=False)
s.messages = list(msgs)
return s
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_returns_none_for_duplicate(
mocker: MockerFixture,
) -> None:
"""append_and_save_message returns None when the trailing message is a duplicate."""
session = _make_session_with_messages(
ChatMessage(role="user", content="hello"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
result = await append_and_save_message(
session.session_id, ChatMessage(role="user", content="hello")
)
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_appends_new_message(
mocker: MockerFixture,
) -> None:
"""append_and_save_message appends a non-duplicate message and returns the session."""
session = _make_session_with_messages(
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=2)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="second message")
result = await append_and_save_message(session.session_id, new_msg)
assert result is not None
assert result.messages[-1].content == "second message"
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_when_session_not_found(
mocker: MockerFixture,
) -> None:
"""append_and_save_message raises ValueError when the session does not exist."""
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=None,
)
with pytest.raises(ValueError, match="not found"):
await append_and_save_message(
"missing-session-id", ChatMessage(role="user", content="hi")
)
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_lock_degraded(
mocker: MockerFixture,
) -> None:
"""When the Redis lock times out (acquired=False), the fallback reads from DB."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=False)
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mock_get_from_db = mocker.patch(
"backend.copilot.model._get_session_from_db",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
# DB path was used (not cache-first)
mock_get_from_db.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_database_error_on_save_failure(
mocker: MockerFixture,
) -> None:
"""When _save_session_to_db fails, append_and_save_message raises DatabaseError."""
from backend.util.exceptions import DatabaseError
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
side_effect=RuntimeError("db down"),
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
with pytest.raises(DatabaseError):
await append_and_save_message(
session.session_id, ChatMessage(role="user", content="new msg")
)
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_invalidates_cache_on_cache_failure(
mocker: MockerFixture,
) -> None:
"""When cache_chat_session fails, invalidate_session_cache is called to avoid stale reads."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
side_effect=RuntimeError("redis write failed"),
)
mock_invalidate = mocker.patch(
"backend.copilot.model.invalidate_session_cache",
new_callable=mocker.AsyncMock,
)
result = await append_and_save_message(
session.session_id, ChatMessage(role="user", content="new msg")
)
# DB write succeeded, cache invalidation was called
mock_invalidate.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_redis_unavailable(
mocker: MockerFixture,
) -> None:
"""When get_redis_async raises, _get_session_lock yields False (degraded) and DB is read."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
side_effect=ConnectionError("redis down"),
)
mock_get_from_db = mocker.patch(
"backend.copilot.model._get_session_from_db",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
mock_get_from_db.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_lock_release_failure_is_ignored(
mocker: MockerFixture,
) -> None:
"""If lock.release() raises, the exception is swallowed (TTL will clean up)."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock(
side_effect=RuntimeError("release failed")
)
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
assert result is not None

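The trailing-block duplicate check these tests exercise can be sketched as follows (`Msg` and `is_dup` are hypothetical stand-ins for `ChatMessage` and `is_message_duplicate`, with the same trailing-run semantics described in the comments above):

```python
from dataclasses import dataclass


@dataclass
class Msg:
    role: str
    content: str


def is_dup(messages: list[Msg], role: str, content: str) -> bool:
    # Walk backwards over the consecutive trailing messages of the same
    # role; a resend is a duplicate only if it matches one of them.
    for m in reversed(messages):
        if m.role != role:
            # The run ended: an intervening turn (e.g. an assistant
            # reply) makes an identical resend legitimate.
            return False
        if m.content == content:
            return True
    return False
```

So a retried "yes" that lands before any assistant reply is collapsed, while "yes" → assistant reply → "yes" goes through normally.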
View File

@@ -125,12 +125,7 @@ config = ChatConfig()
class _SystemPromptPreset(SystemPromptPreset, total=False):
"""Extends :class:`SystemPromptPreset` with ``exclude_dynamic_sections``.
The field was added to the upstream TypedDict in claude-agent-sdk 0.1.59.
Until the package is pinned to that version we declare it locally so Pyright
accepts the kwarg without a ``# type: ignore`` comment.
"""
"""Extends SystemPromptPreset with fields added in claude-agent-sdk 0.1.59."""
exclude_dynamic_sections: NotRequired[bool]
@@ -898,7 +893,7 @@ def _write_cli_session_to_disk(
return False
def read_cli_session_from_disk(
def _read_cli_session_from_disk(
sdk_cwd: str,
session_id: str,
log_prefix: str,
@@ -978,7 +973,7 @@ def read_cli_session_from_disk(
return stripped_bytes
def process_cli_restore(
def _process_cli_restore(
cli_restore: TranscriptDownload,
sdk_cwd: str,
session_id: str,
@@ -2494,7 +2489,9 @@ async def _restore_cli_session_for_turn(
# session path, so we validate BEFORE any disk write.
stripped = ""
if cli_restore is not None and sdk_cwd:
stripped, ok = process_cli_restore(cli_restore, sdk_cwd, session_id, log_prefix)
stripped, ok = _process_cli_restore(
cli_restore, sdk_cwd, session_id, log_prefix
)
if not ok:
result.transcript_covers_prefix = False
cli_restore = None
@@ -3639,7 +3636,7 @@ async def stream_chat_completion_sdk(
# this turn ran without --resume (restore failed or first T2+ on a new
# pod), the T1 session file at the expected path may still be present
# and should be re-uploaded so the next turn can resume from it.
# read_cli_session_from_disk returns None when the file is absent, so
# _read_cli_session_from_disk returns None when the file is absent, so
# this is always safe.
#
# Intentionally NOT gated on skip_transcript_upload: that flag is set
@@ -3668,7 +3665,7 @@ async def stream_chat_completion_sdk(
try:
# Read the CLI's native session file from disk (written by the CLI
# after the turn), then upload the bytes to GCS.
_cli_content = read_cli_session_from_disk(
_cli_content = _read_cli_session_from_disk(
sdk_cwd, session_id, log_prefix
)
if _cli_content:

View File

@@ -1371,7 +1371,7 @@ class TestStripStaleThinkingBlocks:
class TestProcessCliRestore:
"""``process_cli_restore`` validates, strips, and writes CLI session to disk."""
"""``_process_cli_restore`` validates, strips, and writes CLI session to disk."""
def test_writes_stripped_bytes_not_raw(self, tmp_path):
"""Stripped bytes (not raw bytes) must be written to disk for --resume."""
@@ -1380,7 +1380,7 @@ class TestProcessCliRestore:
from pathlib import Path
from unittest.mock import patch
from backend.copilot.sdk.service import process_cli_restore
from backend.copilot.sdk.service import _process_cli_restore
from backend.copilot.transcript import TranscriptDownload
session_id = "12345678-0000-0000-0000-abcdef000001"
@@ -1406,7 +1406,7 @@ class TestProcessCliRestore:
return_value=projects_base_dir,
),
):
stripped_str, ok = process_cli_restore(
stripped_str, ok = _process_cli_restore(
restore, sdk_cwd, session_id, "[Test]"
)
@@ -1433,7 +1433,7 @@ class TestProcessCliRestore:
def test_invalid_content_returns_false(self):
"""Content that fails validation after strip returns (empty, False)."""
from backend.copilot.sdk.service import process_cli_restore
from backend.copilot.sdk.service import _process_cli_restore
from backend.copilot.transcript import TranscriptDownload
# A single progress-only entry — stripped result will be empty/invalid
@@ -1442,7 +1442,7 @@ class TestProcessCliRestore:
content=raw_content.encode("utf-8"), message_count=1, mode="sdk"
)
stripped_str, ok = process_cli_restore(
stripped_str, ok = _process_cli_restore(
restore,
"/tmp/nonexistent-sdk-cwd",
"12345678-0000-0000-0000-000000000099",
@@ -1454,7 +1454,7 @@ class TestProcessCliRestore:
class TestReadCliSessionFromDisk:
"""``read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
"""``_read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
def _build_session_file(self, tmp_path, session_id: str):
"""Build the session file path inside tmp_path using the same encoding as cli_session_path."""
@@ -1472,7 +1472,7 @@ class TestReadCliSessionFromDisk:
"""Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback)."""
from unittest.mock import patch
from backend.copilot.sdk.service import read_cli_session_from_disk
from backend.copilot.sdk.service import _read_cli_session_from_disk
session_id = "12345678-0000-0000-0000-aabbccdd0001"
projects_base_dir = str(tmp_path)
@@ -1491,7 +1491,7 @@ class TestReadCliSessionFromDisk:
return_value=projects_base_dir,
),
):
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
# UnicodeDecodeError path returns the raw bytes (upload-raw fallback)
assert result == b"\xff\xfe invalid utf-8\n"
@@ -1500,7 +1500,7 @@ class TestReadCliSessionFromDisk:
"""OSError on write-back returns stripped bytes for GCS upload (not raw)."""
from unittest.mock import patch
from backend.copilot.sdk.service import read_cli_session_from_disk
from backend.copilot.sdk.service import _read_cli_session_from_disk
session_id = "12345678-0000-0000-0000-aabbccdd0002"
projects_base_dir = str(tmp_path)
@@ -1527,7 +1527,7 @@ class TestReadCliSessionFromDisk:
return_value=projects_base_dir,
),
):
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
finally:
session_file.chmod(0o644)

View File

@@ -423,33 +423,20 @@ async def subscribe_to_session(
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
# RACE CONDITION FIX: If session not found, retry with backoff.
# Duplicate requests skip create_session and subscribe immediately; the
# original request's create_session (a Redis hset) may not have completed
# yet. Three 100ms retries give a 300ms window, which covers the original
# request's DB-write latency before the hset even starts.
# RACE CONDITION FIX: If session not found, retry once after small delay
# This handles the case where subscribe_to_session is called immediately
# after create_session but before Redis propagates the write
if not meta:
_max_retries = 3
_retry_delay = 0.1 # 100ms per attempt
for attempt in range(_max_retries):
logger.warning(
f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), "
f"retrying after {int(_retry_delay * 1000)}ms",
extra={"json_fields": {**log_meta, "attempt": attempt + 1}},
)
await asyncio.sleep(_retry_delay)
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
logger.info(
f"[TIMING] Session found after {attempt + 1} retries",
extra={"json_fields": {**log_meta, "attempts": attempt + 1}},
)
break
else:
logger.warning(
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
extra={"json_fields": {**log_meta}},
)
await asyncio.sleep(0.05) # 50ms
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Session still not found in Redis after {_max_retries} retries "
f"({elapsed:.1f}ms total)",
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
extra={
"json_fields": {
**log_meta,
@@ -459,6 +446,10 @@ async def subscribe_to_session(
},
)
return None
logger.info(
"[TIMING] Session found after retry",
extra={"json_fields": {**log_meta}},
)
# Note: Redis client uses decode_responses=True, so keys are strings
session_status = meta.get("status", "")

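The widened retry window can be sketched as a generic helper (hypothetical, not the router's actual code): poll an eventually-consistent store a fixed number of times with a fixed delay before giving up, replacing the old single 50ms retry with 3 × 100ms.

```python
import asyncio
from typing import Awaitable, Callable, Optional


async def read_with_retries(
    fetch: Callable[[], Awaitable[Optional[dict]]],
    max_retries: int = 3,
    delay: float = 0.1,
) -> Optional[dict]:
    # First attempt with no delay: the common case where the write has
    # already landed pays no latency penalty.
    meta = await fetch()
    if meta:
        return meta
    for _ in range(max_retries):
        # The write may still be in flight; wait and re-read.
        await asyncio.sleep(delay)
        meta = await fetch()
        if meta:
            return meta
    return None  # caller treats this as session-not-found
```

The fixed (rather than exponential) delay fits here because the race window is bounded by one request's DB-write latency, not by an outage.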
View File

@@ -143,8 +143,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,
LlmModel.GROK_4_20: 5,
LlmModel.GROK_4_20_MULTI_AGENT: 5,
LlmModel.GROK_CODE_FAST_1: 1,
LlmModel.KIMI_K2: 1,
LlmModel.QWEN3_235B_A22B_THINKING: 1,

View File

@@ -1,13 +1,10 @@
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast
import stripe
from fastapi.concurrency import run_in_threadpool
from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
@@ -34,7 +31,6 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.util.cache import cached
from backend.util.exceptions import InsufficientBalanceError
from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled
from backend.util.json import SafeJson, dumps
@@ -436,7 +432,7 @@ class UserCreditBase(ABC):
current_balance, _ = await self._get_credits(user_id)
if current_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
)
# Single unified atomic operation for all transaction types using UserBalance
@@ -575,7 +571,7 @@ class UserCreditBase(ABC):
if amount < 0 and fail_insufficient_credits:
current_balance, _ = await self._get_credits(user_id)
raise InsufficientBalanceError(
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=current_balance,
amount=amount,
@@ -586,6 +582,7 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
@@ -737,7 +734,7 @@ class UserCredit(UserCreditBase):
)
if request.amount <= 0 or request.amount > transaction.amount:
raise AssertionError(
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
)
balance, _ = await self._add_transaction(
@@ -791,12 +788,12 @@ class UserCredit(UserCreditBase):
# If the user has enough balance, just let them win the dispute.
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
dispute.close()
return
logger.warning(
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
f"Adding extra info for dispute from {user_id} for ${amount/100}"
)
# Retrieve recent transaction history to support our evidence.
# This provides a concise timeline that shows service usage and proper credit application.
@@ -1240,23 +1237,14 @@ async def get_stripe_customer_id(user_id: str) -> str:
if user.stripe_customer_id:
return user.stripe_customer_id
# Race protection: two concurrent calls (e.g. user double-clicks "Upgrade",
# or any retried request) would each pass the check above and create their
# own Stripe Customer, leaving an orphaned billable customer in Stripe.
# Pass an idempotency_key so Stripe collapses concurrent + retried calls
# into the same Customer object server-side. The 24h Stripe idempotency
# window comfortably covers any realistic in-flight retry scenario.
customer = await run_in_threadpool(
stripe.Customer.create,
customer = stripe.Customer.create(
name=user.name or "",
email=user.email,
metadata={"user_id": user_id},
idempotency_key=f"customer-create-{user_id}",
)
await User.prisma().update(
where={"id": user_id}, data={"stripeCustomerId": customer.id}
)
get_user_by_id.cache_delete(user_id)
return customer.id
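A toy model (not the Stripe SDK) of why the `idempotency_key` collapses concurrent and retried creates: the server remembers the first result per key and replays it for repeats, so double-clicks and nginx retries never mint a second billable customer.

```python
import threading


class FakeCustomerAPI:
    """Hypothetical in-memory stand-in for an idempotent create endpoint."""

    def __init__(self) -> None:
        self._by_key: dict[str, str] = {}
        self._lock = threading.Lock()
        self._counter = 0

    def create(self, *, email: str, idempotency_key: str) -> str:
        with self._lock:
            existing = self._by_key.get(idempotency_key)
            if existing is not None:
                # Replayed response: same object, no second customer.
                return existing
            self._counter += 1
            customer_id = f"cus_{self._counter}"
            self._by_key[idempotency_key] = customer_id
            return customer_id
```

With a key derived from the user ID (as in `customer-create-{user_id}` above), every concurrent caller resolves to the one customer object.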
@@ -1275,203 +1263,23 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
data={"subscriptionTier": tier},
)
get_user_by_id.cache_delete(user_id)
# Also invalidate the rate-limit tier cache so CoPilot picks up the new
# tier immediately rather than waiting up to 5 minutes for the TTL to expire.
from backend.copilot.rate_limit import get_user_tier # local import avoids circular
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
async def _cancel_customer_subscriptions(
customer_id: str,
exclude_sub_id: str | None = None,
at_period_end: bool = False,
) -> int:
"""Cancel all billable Stripe subscriptions for a customer, optionally excluding one.
Cancels both ``active`` and ``trialing`` subscriptions, since trialing subs will
start billing once the trial ends and must be cleaned up on downgrade/upgrade to
avoid double-charging or charging users who intended to cancel.
When ``at_period_end=True``, schedules cancellation at the end of the current
billing period instead of cancelling immediately — the user keeps their tier
until the period ends, then ``customer.subscription.deleted`` fires and the
webhook downgrades them to FREE.
Wraps every synchronous Stripe SDK call with run_in_threadpool so the async event
loop is never blocked. Raises stripe.StripeError on list/cancel failure so callers
that need strict consistency can react; cleanup callers can catch and log instead.
Returns the number of subscriptions cancelled/scheduled for cancellation.
"""
# Query active and trialing separately; Stripe's list API accepts a single status
# filter at a time (no OR), and we explicitly want to skip canceled/incomplete/
# past_due subs rather than filter them out client-side via status="all".
seen_ids: set[str] = set()
for status in ("active", "trialing"):
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status=status, limit=10
)
# Iterate only the first page (up to 10); avoid auto_paging_iter which would
# trigger additional sync HTTP calls inside the event loop.
if subscriptions.has_more:
logger.error(
"_cancel_customer_subscriptions: customer %s has more than 10 %s"
" subscriptions — only the first page was processed; remaining"
" subscriptions were NOT cancelled",
customer_id,
status,
async def cancel_stripe_subscription(user_id: str) -> None:
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
customer_id = await get_stripe_customer_id(user_id)
subscriptions = stripe.Subscription.list(
customer=customer_id, status="active", limit=10
)
for sub in subscriptions.auto_paging_iter():
try:
stripe.Subscription.cancel(sub["id"])
except stripe.StripeError:
logger.warning(
"cancel_stripe_subscription: failed to cancel sub %s for user %s",
sub["id"],
user_id,
)
for sub in subscriptions.data:
sub_id = sub["id"]
if exclude_sub_id and sub_id == exclude_sub_id:
continue
if sub_id in seen_ids:
continue
seen_ids.add(sub_id)
if at_period_end:
await run_in_threadpool(
stripe.Subscription.modify, sub_id, cancel_at_period_end=True
)
else:
await run_in_threadpool(stripe.Subscription.cancel, sub_id)
return len(seen_ids)
async def cancel_stripe_subscription(user_id: str) -> bool:
"""Schedule cancellation of all active/trialing Stripe subscriptions at period end.
The subscription stays active until the end of the billing period so the user
keeps their tier for the time they already paid for. The ``customer.subscription.deleted``
webhook fires at period end and downgrades the DB tier to FREE.
Returns True if at least one subscription was found and scheduled for cancellation,
False if the customer had no active/trialing subscriptions (e.g., admin-granted tier
with no associated Stripe subscription). When False, the caller should update the
DB tier directly since no webhook will fire to do it.
Raises stripe.StripeError if any modification fails, so the caller can avoid
updating the DB tier when Stripe is inconsistent.
"""
# Guard: only proceed if the user already has a Stripe customer ID. Calling
# get_stripe_customer_id for a user who has never had a paid subscription would
# create an orphaned, potentially-billable Stripe Customer object — we avoid that
# by returning False early so the caller can downgrade the DB tier directly.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return False
customer_id = user.stripe_customer_id
try:
cancelled_count = await _cancel_customer_subscriptions(
customer_id, at_period_end=True
)
return cancelled_count > 0
except stripe.StripeError:
logger.warning(
"cancel_stripe_subscription: Stripe error while cancelling subs for user %s",
user_id,
)
raise
async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int:
"""Return the prorated credit (in cents) the user would receive if they upgraded now.
Fetches the user's active Stripe subscription to determine how many seconds
remain in the current billing period, then calculates the unused portion of
the monthly cost. Returns 0 for FREE/ENTERPRISE users or when no active sub
is found.
"""
if monthly_cost_cents <= 0:
return 0
# Guard: only query Stripe if the user already has a customer ID. Admin-granted
# paid tiers have no Stripe record; calling get_stripe_customer_id would create an
# orphaned customer on every billing-page load for those users.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return 0
try:
customer_id = user.stripe_customer_id
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status="active", limit=1
)
if not subscriptions.data:
return 0
sub = subscriptions.data[0]
period_start: int = sub["current_period_start"]
period_end: int = sub["current_period_end"]
now = int(time.time())
total_seconds = period_end - period_start
remaining_seconds = max(period_end - now, 0)
if total_seconds <= 0:
return 0
return int(monthly_cost_cents * remaining_seconds / total_seconds)
except Exception:
logger.warning(
"get_proration_credit_cents: failed to compute proration for user %s",
user_id,
)
return 0
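The proration arithmetic above reduces to crediting the unused fraction of the billing period against the monthly cost; a minimal sketch with illustrative values (not the function's actual signature):

```python
def proration_credit_cents(
    monthly_cost_cents: int,
    period_start: int,
    period_end: int,
    now: int,
) -> int:
    # Mirrors the guards above: nothing to prorate for free tiers or a
    # degenerate period.
    if monthly_cost_cents <= 0:
        return 0
    total_seconds = period_end - period_start
    if total_seconds <= 0:
        return 0
    remaining_seconds = max(period_end - now, 0)
    # Unused fraction of the period, applied to the monthly cost.
    return int(monthly_cost_cents * remaining_seconds / total_seconds)
```

For example, halfway through a 30-day period on a $20.00/month plan, the credit is $10.00 (1000 cents).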
async def modify_stripe_subscription_for_tier(
user_id: str, tier: SubscriptionTier
) -> bool:
"""Modify an existing Stripe subscription to a new paid tier using proration.
For paid→paid tier changes (e.g. PRO↔BUSINESS), modifying the existing
subscription is preferable to cancelling + creating a new one via Checkout:
Stripe handles proration automatically, crediting unused time on the old plan
and charging the pro-rated amount for the new plan in the same billing cycle.
Returns:
True — a subscription was found and modified successfully.
False — no active/trialing subscription exists (e.g. admin-granted tier or
first-time paid signup); caller should fall back to Checkout.
Raises stripe.StripeError on API failures so callers can propagate a 502.
Raises ValueError when no Stripe price ID is configured for the tier.
"""
price_id = await get_subscription_price_id(tier)
if not price_id:
raise ValueError(f"No Stripe price ID configured for tier {tier}")
# Guard: only proceed if the user already has a Stripe customer ID. Calling
# get_stripe_customer_id for a user with no Stripe record (e.g. admin-granted tier)
# would create an orphaned customer object if the subsequent Subscription.list call
# fails. Return False early so the API layer falls back to Checkout instead.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return False
customer_id = user.stripe_customer_id
for status in ("active", "trialing"):
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status=status, limit=1
)
if not subscriptions.data:
continue
sub = subscriptions.data[0]
sub_id = sub["id"]
items = sub.get("items", {}).get("data", [])
if not items:
continue
item_id = items[0]["id"]
await run_in_threadpool(
stripe.Subscription.modify,
sub_id,
items=[{"id": item_id, "price": price_id}],
proration_behavior="create_prorations",
)
    logger.info(
        "modify_stripe_subscription_for_tier: modified sub %s for user %s to tier %s",
        sub_id,
        user_id,
        tier,
    )
return True
return False
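The modify-or-fall-back contract can be sketched without the Stripe SDK — the dict below is a stand-in for the `Subscription.list` results the loop above walks, and mutating the item stands in for `Subscription.modify`:

```python
def upgrade_tier(subs_by_status: dict[str, list[dict]], price_id: str) -> str:
    """Return 'modified' if an existing sub was updated in place,
    or 'checkout' to signal the caller to fall back to a new Checkout session."""
    for status in ("active", "trialing"):
        subs = subs_by_status.get(status, [])
        if not subs:
            continue
        items = subs[0].get("items", {}).get("data", [])
        if not items:
            continue
        items[0]["price"] = price_id  # stands in for stripe.Subscription.modify
        return "modified"
    return "checkout"
```

A "checkout" result corresponds to the `False` return above: no live subscription to modify (admin-granted tier or first paid signup), so the API layer creates a Checkout session instead.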
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
@@ -1483,19 +1291,8 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
return AutoTopUpConfig.model_validate(user.top_up_config)
@cached(ttl_seconds=60, maxsize=8, cache_none=False)
async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
"""Return Stripe Price ID for a tier from LaunchDarkly, cached for 60 seconds.
Price IDs are LaunchDarkly flag values that change only at deploy time.
Caching for 60 seconds avoids hitting the LD SDK on every webhook delivery
and every GET /credits/subscription page load (called 2x per request).
``cache_none=False`` prevents a transient LD failure from caching ``None``
and blocking subscription upgrades for the full 60-second TTL window.
A tier with no configured flag (FREE, ENTERPRISE) returns ``None`` from an
O(1) dict lookup before hitting LD, so the extra LD call is never made.
"""
"""Return Stripe Price ID for a tier from LaunchDarkly. None = not configured."""
flag_map = {
SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO,
SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS,
@@ -1503,7 +1300,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
flag = flag_map.get(tier)
if flag is None:
return None
price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
return price_id if isinstance(price_id, str) and price_id else None
@@ -1518,8 +1315,7 @@ async def create_subscription_checkout(
if not price_id:
raise ValueError(f"Subscription not available for tier {tier.value}")
customer_id = await get_stripe_customer_id(user_id)
session = await run_in_threadpool(
stripe.checkout.Session.create,
session = stripe.checkout.Session.create(
customer=customer_id,
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
@@ -1527,111 +1323,26 @@ async def create_subscription_checkout(
cancel_url=cancel_url,
subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
)
if not session.url:
# An empty checkout URL for a paid upgrade is always an error; surfacing it
# as ValueError means the API handler returns 422 instead of silently
# redirecting the client to an empty URL.
raise ValueError("Stripe did not return a checkout session URL")
return session.url
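Because the Stripe SDK is synchronous, blocking calls like `Session.create` are dispatched to a worker thread so they don't stall the event loop. A minimal sketch of the same pattern using only the stdlib (`asyncio.to_thread` plays the role of Starlette's `run_in_threadpool`; the URL is a placeholder):

```python
import asyncio


def blocking_stripe_call(amount: int) -> dict:
    # Stands in for a synchronous SDK call such as stripe.checkout.Session.create.
    return {"url": f"https://checkout.example/pay?amount={amount}"}


async def create_checkout(amount: int) -> str:
    # Run the blocking call in a thread so other coroutines keep making progress.
    session = await asyncio.to_thread(blocking_stripe_call, amount)
    if not session["url"]:
        raise ValueError("Stripe did not return a checkout session URL")
    return session["url"]


url = asyncio.run(create_checkout(2000))
```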
async def _cleanup_stale_subscriptions(customer_id: str, new_sub_id: str) -> None:
"""Best-effort cancel of any active subs for the customer other than new_sub_id.
Called from the webhook handler after a new subscription becomes active. Failures
are logged but not raised so a transient Stripe error doesn't crash the webhook —
a periodic reconciliation job is the intended backstop for persistent drift.
NOTE: until that reconcile job lands, a failure here means the user is silently
billed for two simultaneous subscriptions. The error log below is intentionally
`logger.exception` so it surfaces in Sentry with the customer/sub IDs needed to
manually reconcile, and the metric `stripe_stale_subscription_cleanup_failed`
is bumped so on-call can alert on persistent drift.
TODO(#stripe-reconcile-job): replace this best-effort cleanup with a periodic
reconciliation job that queries Stripe for customers with >1 active sub.
"""
try:
await _cancel_customer_subscriptions(customer_id, exclude_sub_id=new_sub_id)
except stripe.StripeError:
# Use exception() (not warning) so this surfaces as an error in Sentry —
# any failure here means a paid-to-paid upgrade may have left the user
# with two simultaneous active subscriptions.
logger.exception(
            "stripe_stale_subscription_cleanup_failed: customer=%s new_sub=%s;"
" user may be billed for two simultaneous subscriptions; manual"
" reconciliation required",
customer_id,
new_sub_id,
)
return session.url or ""
async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
"""Update User.subscriptionTier from a Stripe subscription object.
Expected shape of stripe_subscription (subset of Stripe's Subscription object):
customer: str — Stripe customer ID
status: str — "active" | "trialing" | "canceled" | ...
id: str — Stripe subscription ID
items.data[].price.id: str — Stripe price ID identifying the tier
"""
customer_id = stripe_subscription.get("customer")
if not customer_id:
logger.warning(
"sync_subscription_from_stripe: missing 'customer' field in event, "
"skipping (keys: %s)",
list(stripe_subscription.keys()),
)
return
"""Update User.subscriptionTier from a Stripe subscription object."""
customer_id = stripe_subscription["customer"]
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"sync_subscription_from_stripe: no user for customer %s", customer_id
)
return
# Cross-check: if the subscription carries a metadata.user_id (set during
# Checkout Session creation), verify it matches the user we found via
# stripeCustomerId. A mismatch indicates a customer↔user mapping
# inconsistency — updating the wrong user's tier would be a data-corruption
# bug, so we log loudly and bail out. Absence of metadata.user_id (e.g.
# subscriptions created outside the Checkout flow) is not an error — we
# simply skip the check and proceed with the customer-ID-based lookup.
metadata = stripe_subscription.get("metadata") or {}
metadata_user_id = metadata.get("user_id") if isinstance(metadata, dict) else None
if metadata_user_id and metadata_user_id != user.id:
logger.error(
"sync_subscription_from_stripe: metadata.user_id=%s does not match"
" user.id=%s found via stripeCustomerId=%s — refusing to update tier"
" to avoid corrupting the wrong user's subscription state",
metadata_user_id,
user.id,
customer_id,
)
return
# ENTERPRISE tiers are admin-managed. Never let a Stripe webhook flip an
# ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has
# a self-service Stripe sub, it's a data-consistency issue for an operator,
# not something the webhook should automatically "fix".
current_tier = user.subscriptionTier or SubscriptionTier.FREE
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier"
" for user %s (customer %s); event status=%s",
user.id,
customer_id,
stripe_subscription.get("status", ""),
)
return
status = stripe_subscription.get("status", "")
new_sub_id = stripe_subscription.get("id", "")
if status in ("active", "trialing"):
price_id = ""
items = stripe_subscription.get("items", {}).get("data", [])
if items:
price_id = items[0].get("price", {}).get("id", "")
pro_price, biz_price = await asyncio.gather(
get_subscription_price_id(SubscriptionTier.PRO),
get_subscription_price_id(SubscriptionTier.BUSINESS),
)
pro_price = await get_subscription_price_id(SubscriptionTier.PRO)
biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS)
if price_id and pro_price and price_id == pro_price:
tier = SubscriptionTier.PRO
elif price_id and biz_price and price_id == biz_price:
@@ -1648,206 +1359,10 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
)
return
else:
# A subscription was cancelled or ended. DO NOT unconditionally downgrade
# to FREE — Stripe does not guarantee webhook delivery order, so a
# `customer.subscription.deleted` for the OLD sub can arrive after we've
# already processed `customer.subscription.created` for a new paid sub.
# Ask Stripe whether any OTHER active/trialing subs exist for this
# customer; if they do, keep the user's current tier (the other sub's
# own event will/has already set the correct tier).
try:
other_subs_active, other_subs_trialing = await asyncio.gather(
run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status="active",
limit=10,
),
run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status="trialing",
limit=10,
),
)
except stripe.StripeError:
logger.warning(
"sync_subscription_from_stripe: could not verify other active"
" subs for customer %s on cancel event %s; preserving current"
" tier to avoid an unsafe downgrade",
customer_id,
new_sub_id,
)
return
# Filter out the cancelled subscription to check if other active subs
# exist. When new_sub_id is empty (malformed event with no 'id' field),
# we cannot safely exclude any sub — preserve current tier to avoid
# an unsafe downgrade on a malformed webhook payload.
if not new_sub_id:
logger.warning(
"sync_subscription_from_stripe: cancel event missing 'id' field"
" for customer %s; preserving current tier",
customer_id,
)
return
other_active_ids = {sub["id"] for sub in other_subs_active.data} - {new_sub_id}
other_trialing_ids = {sub["id"] for sub in other_subs_trialing.data} - {
new_sub_id
}
still_has_active_sub = bool(other_active_ids or other_trialing_ids)
if still_has_active_sub:
logger.info(
"sync_subscription_from_stripe: sub %s cancelled but customer %s"
" still has another active sub; keeping tier %s",
new_sub_id,
customer_id,
current_tier.value,
)
return
tier = SubscriptionTier.FREE
# Idempotency: Stripe retries webhooks on delivery failure, and several event
# types map to the same final tier. Skip the DB write + cache invalidation
# when the tier is already correct to avoid redundant writes on replay.
if current_tier == tier:
return
# When a new subscription becomes active (e.g. paid-to-paid tier upgrade
# via a fresh Checkout Session), cancel any OTHER active subscriptions for
# the same customer so the user isn't billed twice. We do this in the
# webhook rather than the API handler so that abandoning the checkout
# doesn't leave the user without a subscription.
# IMPORTANT: this runs AFTER the idempotency check above so that webhook
# replays for an already-applied event do NOT trigger another cleanup round
# (which could otherwise cancel a legitimately new subscription the user
# signed up for between the original event and its replay).
if status in ("active", "trialing") and new_sub_id:
# NOTE: paid-to-paid upgrade race (e.g. PRO → BUSINESS):
# _cleanup_stale_subscriptions cancels the old PRO sub before
# set_subscription_tier writes BUSINESS to the DB. If Stripe delivers
# the PRO `customer.subscription.deleted` event concurrently and it
# processes after the PRO cancel but before set_subscription_tier
# commits, the user could momentarily appear as FREE in the DB.
# This window is very short in practice (two sequential awaits),
# but is a known limitation of the current webhook-driven approach.
# A future improvement would be to write the new tier first, then
# cancel the old sub.
await _cleanup_stale_subscriptions(customer_id, new_sub_id)
await set_subscription_tier(user.id, tier)
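The "does the customer still have another live sub?" check on a cancel event reduces to set arithmetic — a sketch with plain IDs, mirroring the exclusion of the just-cancelled sub above:

```python
def still_has_live_sub(
    active_ids: set[str], trialing_ids: set[str], cancelled_sub_id: str
) -> bool:
    """True if any active/trialing sub other than the one just cancelled remains."""
    if not cancelled_sub_id:
        # Malformed event with no 'id': we cannot safely exclude anything,
        # so the caller preserves the current tier instead of downgrading.
        raise ValueError("cancel event missing sub id")
    other = (active_ids | trialing_ids) - {cancelled_sub_id}
    return bool(other)
```

This is what makes the out-of-order delivery safe: if the `deleted` event for the old sub arrives after the `created` event for the new one, the new sub's ID survives the set difference and the downgrade to FREE is skipped.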
async def handle_subscription_payment_failure(invoice: dict) -> None:
"""Handle a failed Stripe subscription payment.
Tries to cover the invoice amount from the user's credit balance.
- Balance sufficient → deduct from balance, then pay the Stripe invoice so
Stripe stops retrying it. The sub stays intact and the user keeps their tier.
- Balance insufficient → cancel Stripe sub immediately, downgrade to FREE.
Cancelling here avoids further Stripe retries on an invoice we cannot cover.
"""
customer_id = invoice.get("customer")
if not customer_id:
logger.warning(
"handle_subscription_payment_failure: missing customer in invoice; skipping"
)
return
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"handle_subscription_payment_failure: no user found for customer %s",
customer_id,
)
return
current_tier = user.subscriptionTier or SubscriptionTier.FREE
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"handle_subscription_payment_failure: skipping ENTERPRISE user %s"
" (customer %s) — tier is admin-managed",
user.id,
customer_id,
)
return
    # Stripe may send explicit nulls for these fields, so ``or`` covers both
    # a missing key and a present-but-None value.
    amount_due: int = invoice.get("amount_due") or 0
    sub_id: str = invoice.get("subscription") or ""
    invoice_id: str = invoice.get("id") or ""
if amount_due <= 0:
logger.info(
"handle_subscription_payment_failure: amount_due=%d for user %s;"
" nothing to deduct",
amount_due,
user.id,
)
return
credit_model = UserCredit()
try:
await credit_model._add_transaction(
user_id=user.id,
amount=-amount_due,
transaction_type=CreditTransactionType.SUBSCRIPTION,
fail_insufficient_credits=True,
# Use invoice_id as the idempotency key so that Stripe webhook retries
# (e.g. on a transient stripe.Invoice.pay failure) do not double-charge.
transaction_key=invoice_id or None,
metadata=SafeJson(
{
"stripe_customer_id": customer_id,
"stripe_subscription_id": sub_id,
"reason": "subscription_payment_failure_covered_by_balance",
}
),
)
# Balance covered the invoice. Pay the Stripe invoice so Stripe's dunning
# system stops retrying it — without this call Stripe would retry automatically
# and re-trigger this webhook, causing double-deductions each retry cycle.
if invoice_id:
try:
await run_in_threadpool(stripe.Invoice.pay, invoice_id)
except stripe.StripeError:
logger.warning(
"handle_subscription_payment_failure: balance deducted for user"
" %s but failed to mark invoice %s as paid; Stripe may retry",
user.id,
invoice_id,
)
logger.info(
"handle_subscription_payment_failure: deducted %d cents from balance"
" for user %s; Stripe invoice %s paid, sub %s intact, tier preserved",
amount_due,
user.id,
invoice_id,
sub_id,
)
except InsufficientBalanceError:
# Balance insufficient — cancel Stripe subscription first, then downgrade DB.
# Order matters: if we downgrade the DB first and the Stripe cancel fails, the
# user is permanently stuck on FREE while Stripe continues billing them.
# Cancelling Stripe first is safe: if the DB write then fails, the webhook
# customer.subscription.deleted will fire and correct the tier eventually.
logger.info(
"handle_subscription_payment_failure: insufficient balance for user %s;"
" cancelling Stripe sub %s then downgrading to FREE",
user.id,
sub_id,
)
try:
await _cancel_customer_subscriptions(customer_id)
except stripe.StripeError:
logger.warning(
"handle_subscription_payment_failure: failed to cancel Stripe sub %s"
" for user %s (customer %s); skipping tier downgrade to avoid"
" inconsistency — Stripe may continue retrying the invoice",
sub_id,
user.id,
customer_id,
)
return
await set_subscription_tier(user.id, SubscriptionTier.FREE)
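Keying the balance deduction on the invoice ID is what makes webhook retries idempotent — a toy ledger illustrating why a second delivery of the same invoice is a no-op (the class and key names are hypothetical, not the real `UserCredit` model):

```python
class Ledger:
    def __init__(self) -> None:
        self.balance = 5000  # cents
        self.seen_keys: set[str] = set()

    def deduct(self, amount: int, transaction_key: str) -> bool:
        """Deduct once per key; a replayed webhook with the same key is skipped."""
        if transaction_key in self.seen_keys:
            return False
        if self.balance < amount:
            raise ValueError("insufficient balance")
        self.seen_keys.add(transaction_key)
        self.balance -= amount
        return True


ledger = Ledger()
ledger.deduct(2000, "in_123")  # first delivery: charged
ledger.deduct(2000, "in_123")  # Stripe retry: no double-charge
```

Without the key, the failure mode described above (Stripe retrying an unpaid invoice and re-triggering the webhook) would deduct the balance once per retry cycle.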
async def admin_get_user_history(
page: int = 1,
page_size: int = 20,

View File

@@ -73,31 +73,6 @@ def _get_redis() -> Redis:
return r
class _MissingType:
"""Singleton sentinel type — distinct from ``None`` (a valid cached value).
Using a dedicated class (instead of ``Any = object()``) lets mypy prove
that comparisons ``result is _MISSING`` narrow the type correctly and
prevents accidental use of the sentinel where a real value is expected.
"""
_instance: "_MissingType | None" = None
def __new__(cls) -> "_MissingType":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __repr__(self) -> str:
return "<MISSING>"
# Sentinel returned by ``_get_from_memory`` / ``_get_from_redis`` to mean
# "no entry exists" — distinct from a cached ``None`` value, which is a
# valid result for callers that opt into caching it.
_MISSING = _MissingType()
@dataclass
class CachedValue:
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
@@ -185,7 +160,6 @@ def cached(
ttl_seconds: int,
shared_cache: bool = False,
refresh_ttl_on_get: bool = False,
cache_none: bool = True,
) -> Callable[[Callable[P, R]], CachedFunction[P, R]]:
"""
Thundering herd safe cache decorator for both sync and async functions.
@@ -198,10 +172,6 @@ def cached(
ttl_seconds: Time to live in seconds. Required - entries must expire.
shared_cache: If True, use Redis for cross-process caching
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
cache_none: If True (default) ``None`` is cached like any other value.
Set to ``False`` for functions that return ``None`` to signal a
transient error and should be re-tried on the next call without
poisoning the cache (e.g. external API calls that may fail).
Returns:
Decorated function with caching capabilities
@@ -214,12 +184,6 @@ def cached(
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
@cached(ttl_seconds=300, cache_none=False)
async def fetch_external(id: str) -> dict | None:
# Returns None on transient error — won't be stored,
# next call retries instead of returning the stale None.
...
"""
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
@@ -227,14 +191,9 @@ def cached(
cache_storage: dict[tuple, CachedValue] = {}
_event_loop_locks: dict[Any, asyncio.Lock] = {}
def _get_from_redis(redis_key: str) -> Any:
def _get_from_redis(redis_key: str) -> Any | None:
"""Get value from Redis, optionally refreshing TTL.
Returns the cached value (which may be ``None``) on a hit, or the
module-level ``_MISSING`` sentinel on a miss / corrupt entry.
Callers must compare with ``is _MISSING`` so cached ``None`` values
are not mistaken for misses.
Values are expected to carry an HMAC-SHA256 prefix for integrity
verification. Unsigned (legacy) or tampered entries are silently
discarded and treated as cache misses, so the caller recomputes and
@@ -254,11 +213,11 @@ def cached(
f"for {func_name}, discarding entry: "
"possible tampering or legacy unsigned value"
)
return _MISSING
return None
return pickle.loads(payload)
except Exception as e:
logger.error(f"Redis error during cache check for {func_name}: {e}")
return _MISSING
return None
def _set_to_redis(redis_key: str, value: Any) -> None:
"""Set HMAC-signed pickled value in Redis with TTL."""
@@ -268,13 +227,8 @@ def cached(
except Exception as e:
logger.error(f"Redis error storing cache for {func_name}: {e}")
def _get_from_memory(key: tuple) -> Any:
"""Get value from in-memory cache, checking TTL.
Returns the cached value (which may be ``None``) on a hit, or the
``_MISSING`` sentinel on a miss / TTL expiry. See
``_get_from_redis`` for the rationale.
"""
def _get_from_memory(key: tuple) -> Any | None:
"""Get value from in-memory cache, checking TTL."""
if key in cache_storage:
cached_data = cache_storage[key]
if time.time() - cached_data.timestamp < ttl_seconds:
@@ -282,7 +236,7 @@ def cached(
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
)
return cached_data.result
return _MISSING
return None
def _set_to_memory(key: tuple, value: Any) -> None:
"""Set value in in-memory cache with timestamp."""
@@ -316,11 +270,11 @@ def cached(
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not _MISSING:
if result is not None:
return result
else:
result = _get_from_memory(key)
if result is not _MISSING:
if result is not None:
return result
# Slow path: acquire lock for cache miss/expiry
@@ -328,24 +282,22 @@ def cached(
# Double-check: another coroutine might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not _MISSING:
if result is not None:
return result
else:
result = _get_from_memory(key)
if result is not _MISSING:
if result is not None:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {func_name}")
result = await target_func(*args, **kwargs)
# Store result (skip ``None`` if the caller opted out of
# caching it — used for transient-error sentinels).
if cache_none or result is not None:
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result
@@ -363,11 +315,11 @@ def cached(
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not _MISSING:
if result is not None:
return result
else:
result = _get_from_memory(key)
if result is not _MISSING:
if result is not None:
return result
# Slow path: acquire lock for cache miss/expiry
@@ -375,24 +327,22 @@ def cached(
# Double-check: another thread might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not _MISSING:
if result is not None:
return result
else:
result = _get_from_memory(key)
if result is not _MISSING:
if result is not None:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {func_name}")
result = target_func(*args, **kwargs)
# Store result (skip ``None`` if the caller opted out of
# caching it — used for transient-error sentinels).
if cache_none or result is not None:
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result

View File

@@ -1223,123 +1223,3 @@ class TestCacheHMAC:
assert call_count == 2
legacy_test_fn.cache_clear()
class TestCacheNoneHandling:
"""Tests for the ``cache_none`` parameter on the @cached decorator.
Sentry bug PRRT_kwDOJKSTjM56RTEu (HIGH): the cache previously could not
distinguish "no entry" from "entry is None", so any function returning
``None`` was effectively re-executed on every call. The fix is a
sentinel-based check inside the wrappers, plus an opt-out
``cache_none=False`` flag for callers that *want* errors to retry.
"""
@pytest.mark.asyncio
async def test_async_none_is_cached_by_default(self):
"""With ``cache_none=True`` (default), cached ``None`` is returned
from the cache instead of triggering re-execution."""
call_count = 0
@cached(ttl_seconds=300)
async def maybe_none(x: int) -> int | None:
nonlocal call_count
call_count += 1
return None
assert await maybe_none(1) is None
assert call_count == 1
# Second call should hit the cache, not re-execute.
assert await maybe_none(1) is None
assert call_count == 1
# Different argument is a different cache key — re-executes.
assert await maybe_none(2) is None
assert call_count == 2
def test_sync_none_is_cached_by_default(self):
call_count = 0
@cached(ttl_seconds=300)
def maybe_none(x: int) -> int | None:
nonlocal call_count
call_count += 1
return None
assert maybe_none(1) is None
assert maybe_none(1) is None
assert call_count == 1
@pytest.mark.asyncio
async def test_async_cache_none_false_skips_storing_none(self):
"""``cache_none=False`` skips storing ``None`` so transient errors
are retried on the next call instead of poisoning the cache."""
call_count = 0
results: list[int | None] = [None, None, 42]
@cached(ttl_seconds=300, cache_none=False)
async def maybe_none(x: int) -> int | None:
nonlocal call_count
result = results[call_count]
call_count += 1
return result
# First call: returns None, NOT stored.
assert await maybe_none(1) is None
assert call_count == 1
# Second call with same key: re-executes (None wasn't cached).
assert await maybe_none(1) is None
assert call_count == 2
# Third call: returns 42, this time it IS stored.
assert await maybe_none(1) == 42
assert call_count == 3
# Fourth call: cache hit on the stored 42.
assert await maybe_none(1) == 42
assert call_count == 3
def test_sync_cache_none_false_skips_storing_none(self):
call_count = 0
results: list[int | None] = [None, 99]
@cached(ttl_seconds=300, cache_none=False)
def maybe_none(x: int) -> int | None:
nonlocal call_count
result = results[call_count]
call_count += 1
return result
assert maybe_none(1) is None
assert call_count == 1
# None was not stored — re-executes.
assert maybe_none(1) == 99
assert call_count == 2
# 99 IS stored — no re-execution.
assert maybe_none(1) == 99
assert call_count == 2
@pytest.mark.asyncio
async def test_async_shared_cache_none_is_cached_by_default(self):
"""Shared (Redis) cache also properly returns cached ``None`` values."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
async def maybe_none_redis(x: int) -> int | None:
nonlocal call_count
call_count += 1
return None
maybe_none_redis.cache_clear()
assert await maybe_none_redis(1) is None
assert call_count == 1
assert await maybe_none_redis(1) is None
assert call_count == 1
maybe_none_redis.cache_clear()

View File

@@ -1,7 +1,6 @@
import contextlib
import logging
import os
import uuid
from enum import Enum
from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar
@@ -102,12 +101,6 @@ async def _fetch_user_context_data(user_id: str) -> Context:
"""
builder = Context.builder(user_id).kind("user").anonymous(True)
try:
uuid.UUID(user_id)
except ValueError:
# Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context.
return builder.build()
try:
from backend.util.clients import get_supabase

View File

@@ -40,8 +40,6 @@
"folder_id": null,
"folder_name": null,
"recommended_schedule_cron": null,
"is_scheduled": false,
"next_scheduled_run": null,
"settings": {
"human_in_the_loop_safe_mode": true,
"sensitive_action_safe_mode": false
@@ -88,8 +86,6 @@
"folder_id": null,
"folder_name": null,
"recommended_schedule_cron": null,
"is_scheduled": false,
"next_scheduled_run": null,
"settings": {
"human_in_the_loop_safe_mode": true,
"sensitive_action_safe_mode": false

View File

@@ -155,7 +155,6 @@
"@types/twemoji": "13.1.2",
"@vitejs/plugin-react": "5.1.2",
"@vitest/coverage-v8": "4.0.17",
"agentation": "3.0.2",
"axe-playwright": "2.2.2",
"chromatic": "13.3.3",
"concurrently": "9.2.1",

View File

@@ -376,9 +376,6 @@ importers:
'@vitest/coverage-v8':
specifier: 4.0.17
version: 4.0.17(vitest@4.0.17(@opentelemetry/api@1.9.0)(@types/node@24.10.0)(happy-dom@20.3.4)(jiti@2.6.1)(jsdom@27.4.0)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(terser@5.44.1)(yaml@2.8.2))
agentation:
specifier: 3.0.2
version: 3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
axe-playwright:
specifier: 2.2.2
version: 2.2.2(playwright@1.56.1)
@@ -4122,17 +4119,6 @@ packages:
resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==}
engines: {node: '>= 14'}
agentation@3.0.2:
resolution: {integrity: sha512-iGzBxFVTuZEIKzLY6AExSLAQH6i6SwxV4pAu7v7m3X6bInZ7qlZXAwrEqyc4+EfP4gM7z2RXBF6SF4DeH0f2lA==}
peerDependencies:
react: '>=18.0.0'
react-dom: '>=18.0.0'
peerDependenciesMeta:
react:
optional: true
react-dom:
optional: true
ai@6.0.134:
resolution: {integrity: sha512-YalNEaavld/kE444gOcsMKXdVVRGEe0SK77fAFcWYcqLg+a7xKnEet8bdfrEAJTfnMjj01rhgrIL10903w1a5Q==}
engines: {node: '>=18'}
@@ -13133,11 +13119,6 @@ snapshots:
agent-base@7.1.4:
optional: true
agentation@3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
optionalDependencies:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
ai@6.0.134(zod@3.25.76):
dependencies:
'@ai-sdk/gateway': 3.0.77(zod@3.25.76)

View File

@@ -1,5 +1,5 @@
import { describe, expect, it } from "vitest";
import { getNodeDisplayName, serializeGraphForChat } from "../helpers";
import { serializeGraphForChat } from "../helpers";
import type { CustomNode } from "../../FlowEditor/nodes/CustomNode/CustomNode";
describe("serializeGraphForChat XML injection prevention", () => {
@@ -53,53 +53,3 @@ describe("serializeGraphForChat XML injection prevention", () => {
expect(result).toContain("&lt;injection&gt;");
});
});
function makeNode(overrides: Partial<CustomNode["data"]> = {}): CustomNode {
return {
id: "node-1",
data: {
title: "AgentExecutorBlock",
description: "",
hardcodedValues: {},
inputSchema: {},
outputSchema: {},
uiType: "agent",
block_id: "b1",
costs: [],
categories: [],
...overrides,
},
type: "custom" as const,
position: { x: 0, y: 0 },
} as unknown as CustomNode;
}
describe("getNodeDisplayName", () => {
it("returns fallback when node is undefined", () => {
expect(getNodeDisplayName(undefined, "fallback-id")).toBe("fallback-id");
});
it("returns customized_name when set", () => {
const node = makeNode({
metadata: { customized_name: "My Agent" } as any,
});
expect(getNodeDisplayName(node, "fallback")).toBe("My Agent");
});
it("returns agent_name with version via getNodeDisplayTitle delegation", () => {
const node = makeNode({
hardcodedValues: { agent_name: "Researcher", graph_version: 3 },
});
expect(getNodeDisplayName(node, "fallback")).toBe("Researcher v3");
});
it("returns block title when no custom or agent name", () => {
const node = makeNode({ title: "SomeBlock" });
expect(getNodeDisplayName(node, "fallback")).toBe("SomeBlock");
});
it("returns fallback when title is empty", () => {
const node = makeNode({ title: "" });
expect(getNodeDisplayName(node, "fallback")).toBe("fallback");
});
});

View File

@@ -1,6 +1,5 @@
import type { CustomNode } from "../FlowEditor/nodes/CustomNode/CustomNode";
import type { CustomEdge } from "../FlowEditor/edges/CustomEdge";
import { getNodeDisplayTitle } from "../FlowEditor/nodes/CustomNode/helpers";
/** Maximum nodes serialized into the AI context to prevent token overruns. */
const MAX_NODES = 100;
@@ -145,16 +144,18 @@ export function getActionKey(action: GraphAction): string {
/**
* Resolves the display name for a node: prefers the user-customized name,
* then agent name from hardcodedValues, then block title, then fallback ID.
* Delegates to `getNodeDisplayTitle` for the 3-tier resolution logic.
* falls back to the block title, then to the raw ID.
* Shared between `serializeGraphForChat` and `ActionItem` to avoid duplication.
*/
export function getNodeDisplayName(
node: CustomNode | undefined,
fallback: string,
): string {
if (!node) return fallback;
return getNodeDisplayTitle(node.data) || fallback;
return (
(node?.data.metadata?.customized_name as string | undefined) ||
node?.data.title ||
fallback
);
}
/**

View File

@@ -1,92 +0,0 @@
import { describe, it, expect } from "vitest";
import { getNodeDisplayTitle, formatNodeDisplayTitle } from "../helpers";
import { CustomNodeData } from "../CustomNode";
function makeNodeData(overrides: Partial<CustomNodeData> = {}): CustomNodeData {
return {
title: "AgentExecutorBlock",
description: "",
hardcodedValues: {},
inputSchema: {},
outputSchema: {},
uiType: "agent",
block_id: "block-1",
costs: [],
categories: [],
...overrides,
} as CustomNodeData;
}
describe("getNodeDisplayTitle", () => {
it("returns customized_name when set (tier 1)", () => {
const data = makeNodeData({
metadata: { customized_name: "My Custom Agent" } as any,
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
});
expect(getNodeDisplayTitle(data)).toBe("My Custom Agent");
});
it("returns agent_name with version when no customized_name (tier 2)", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
});
expect(getNodeDisplayTitle(data)).toBe("Researcher v2");
});
it("returns agent_name without version when graph_version is undefined (tier 2)", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Researcher" },
});
expect(getNodeDisplayTitle(data)).toBe("Researcher");
});
it("returns agent_name with version 0 (tier 2)", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Researcher", graph_version: 0 },
});
expect(getNodeDisplayTitle(data)).toBe("Researcher v0");
});
it("returns generic block title when no custom or agent name (tier 3)", () => {
const data = makeNodeData({ title: "AgentExecutorBlock" });
expect(getNodeDisplayTitle(data)).toBe("AgentExecutorBlock");
});
it("prioritizes customized_name over agent_name", () => {
const data = makeNodeData({
metadata: { customized_name: "Renamed" } as any,
hardcodedValues: { agent_name: "Original Agent", graph_version: 1 },
});
expect(getNodeDisplayTitle(data)).toBe("Renamed");
});
});
describe("formatNodeDisplayTitle", () => {
it("returns custom name as-is without beautifying", () => {
const data = makeNodeData({
metadata: { customized_name: "my_custom_name" } as any,
});
expect(formatNodeDisplayTitle(data)).toBe("my_custom_name");
});
it("returns agent name as-is without beautifying", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Blockchain Agent", graph_version: 1 },
});
expect(formatNodeDisplayTitle(data)).toBe("Blockchain Agent v1");
});
it("beautifies generic block title and strips Block suffix", () => {
const data = makeNodeData({ title: "AgentExecutorBlock" });
const result = formatNodeDisplayTitle(data);
expect(result).not.toContain("Block");
expect(result).toBe("Agent Executor");
});
it("does not corrupt agent names containing 'Block'", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Blockchain Agent", graph_version: 2 },
});
expect(formatNodeDisplayTitle(data)).toBe("Blockchain Agent v2");
});
});

View File

@@ -6,10 +6,9 @@ import {
TooltipProvider,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { cn } from "@/lib/utils";
import { useEffect, useState } from "react";
import { beautifyString, cn } from "@/lib/utils";
import { useState } from "react";
import { CustomNodeData } from "../CustomNode";
import { formatNodeDisplayTitle, getNodeDisplayTitle } from "../helpers";
import { NodeBadges } from "./NodeBadges";
import { NodeContextMenu } from "./NodeContextMenu";
import { NodeCost } from "./NodeCost";
@@ -22,24 +21,15 @@ type Props = {
export const NodeHeader = ({ data, nodeId }: Props) => {
const updateNodeData = useNodeStore((state) => state.updateNodeData);
const title = getNodeDisplayTitle(data);
const displayTitle = formatNodeDisplayTitle(data);
const title = (data.metadata?.customized_name as string) || data.title;
const [isEditingTitle, setIsEditingTitle] = useState(false);
const [editedTitle, setEditedTitle] = useState(title);
useEffect(() => {
if (!isEditingTitle) {
setEditedTitle(title);
}
}, [title, isEditingTitle]);
const handleTitleEdit = () => {
if (editedTitle !== title) {
updateNodeData(nodeId, {
metadata: { ...data.metadata, customized_name: editedTitle },
});
}
updateNodeData(nodeId, {
metadata: { ...data.metadata, customized_name: editedTitle },
});
setIsEditingTitle(false);
};
@@ -82,12 +72,12 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
variant="large-semibold"
className="line-clamp-1 hover:cursor-text"
>
{displayTitle}
{beautifyString(title).replace("Block", "").trim()}
</Text>
</div>
</TooltipTrigger>
<TooltipContent>
<p>{displayTitle}</p>
<p>{beautifyString(title).replace("Block", "").trim()}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>

View File

@@ -1,121 +0,0 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, fireEvent } from "@/tests/integrations/test-utils";
import { NodeHeader } from "../NodeHeader";
import { CustomNodeData } from "../../CustomNode";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
vi.mock("../NodeCost", () => ({
NodeCost: () => <div data-testid="node-cost" />,
}));
vi.mock("../NodeContextMenu", () => ({
NodeContextMenu: () => <div data-testid="node-context-menu" />,
}));
vi.mock("../NodeBadges", () => ({
NodeBadges: () => <div data-testid="node-badges" />,
}));
function makeData(overrides: Partial<CustomNodeData> = {}): CustomNodeData {
return {
title: "AgentExecutorBlock",
description: "",
hardcodedValues: {},
inputSchema: {},
outputSchema: {},
uiType: "agent",
block_id: "block-1",
costs: [],
categories: [],
...overrides,
} as CustomNodeData;
}
describe("NodeHeader", () => {
const mockUpdateNodeData = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
useNodeStore.setState({ updateNodeData: mockUpdateNodeData } as any);
});
it("renders beautified generic block title", () => {
render(<NodeHeader data={makeData()} nodeId="abc-123" />);
expect(screen.getByText("Agent Executor")).toBeTruthy();
});
it("renders agent name with version from hardcodedValues", () => {
const data = makeData({
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
});
render(<NodeHeader data={data} nodeId="abc-123" />);
expect(screen.getByText("Researcher v2")).toBeTruthy();
});
it("renders customized_name over agent name", () => {
const data = makeData({
metadata: { customized_name: "My Custom Node" } as any,
hardcodedValues: { agent_name: "Researcher", graph_version: 1 },
});
render(<NodeHeader data={data} nodeId="abc-123" />);
expect(screen.getByText("My Custom Node")).toBeTruthy();
});
it("shows node ID prefix", () => {
render(<NodeHeader data={makeData()} nodeId="abc-123" />);
expect(screen.getByText("#abc")).toBeTruthy();
});
it("enters edit mode on double-click and saves on blur", () => {
render(<NodeHeader data={makeData()} nodeId="node-1" />);
const titleEl = screen.getByText("Agent Executor");
fireEvent.doubleClick(titleEl);
const input = screen.getByDisplayValue("AgentExecutorBlock");
fireEvent.change(input, { target: { value: "New Name" } });
fireEvent.blur(input);
expect(mockUpdateNodeData).toHaveBeenCalledWith("node-1", {
metadata: { customized_name: "New Name" },
});
});
it("does not save when title is unchanged on blur", () => {
const data = makeData({
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
});
render(<NodeHeader data={data} nodeId="node-1" />);
const titleEl = screen.getByText("Researcher v2");
fireEvent.doubleClick(titleEl);
const input = screen.getByDisplayValue("Researcher v2");
fireEvent.blur(input);
expect(mockUpdateNodeData).not.toHaveBeenCalled();
});
it("saves on Enter key", () => {
render(<NodeHeader data={makeData()} nodeId="node-1" />);
fireEvent.doubleClick(screen.getByText("Agent Executor"));
const input = screen.getByDisplayValue("AgentExecutorBlock");
fireEvent.change(input, { target: { value: "Renamed" } });
fireEvent.keyDown(input, { key: "Enter" });
expect(mockUpdateNodeData).toHaveBeenCalledWith("node-1", {
metadata: { customized_name: "Renamed" },
});
});
it("cancels edit on Escape key", () => {
render(<NodeHeader data={makeData()} nodeId="node-1" />);
fireEvent.doubleClick(screen.getByText("Agent Executor"));
const input = screen.getByDisplayValue("AgentExecutorBlock");
fireEvent.change(input, { target: { value: "Changed" } });
fireEvent.keyDown(input, { key: "Escape" });
expect(mockUpdateNodeData).not.toHaveBeenCalled();
expect(screen.getByText("Agent Executor")).toBeTruthy();
});
});

View File

@@ -1,55 +1,6 @@
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { NodeResolutionData } from "@/app/(platform)/build/stores/types";
import { beautifyString } from "@/lib/utils";
import { RJSFSchema } from "@rjsf/utils";
import { CustomNodeData } from "./CustomNode";
/**
* Resolves the display title for a node using a 3-tier fallback:
*
* 1. `customized_name` — the user's manual rename (highest priority)
* 2. `agent_name` (+ version) from `hardcodedValues` — the selected agent's
* display name, persisted by blocks like AgentExecutorBlock
* 3. `data.title` — the generic block name (e.g. "Agent Executor")
*
* `customized_name` is the user's explicit rename via double-click; it lives in
* node metadata. `agent_name` is the programmatic name of the agent graph
* selected in the block's input form; it lives in `hardcodedValues` alongside
* `graph_version`. These are distinct sources of truth — customized_name always
* wins because it reflects deliberate user intent.
*/
export function getNodeDisplayTitle(data: CustomNodeData): string {
if (data.metadata?.customized_name) {
return data.metadata.customized_name as string;
}
const agentName = data.hardcodedValues?.agent_name as string | undefined;
const graphVersion = data.hardcodedValues?.graph_version as
| number
| undefined;
if (agentName) {
return graphVersion != null ? `${agentName} v${graphVersion}` : agentName;
}
return data.title;
}
/**
* Returns the formatted display title for rendering.
* Agent names and custom names are shown as-is; generic block names get
* beautified and have the trailing " Block" suffix stripped.
*/
export function formatNodeDisplayTitle(data: CustomNodeData): string {
const title = getNodeDisplayTitle(data);
const isAgentOrCustom = !!(
data.metadata?.customized_name || data.hardcodedValues?.agent_name
);
return isAgentOrCustom
? title
: beautifyString(title)
.replace(/ Block$/, "")
.trim();
}
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
INCOMPLETE: "ring-slate-300 bg-slate-300",

View File
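The anchored `/ Block$/` pattern in `formatNodeDisplayTitle` is deliberate: a bare `replace("Block", "")` corrupts titles that merely contain the word, turning "Blockchain Agent" into "chain Agent", which is exactly what the `does not corrupt agent names containing 'Block'` test guards against. A minimal sketch, where `beautify` is a stand-in for the real `beautifyString`:

```typescript
// Stand-in for beautifyString: insert spaces at lowercase-uppercase boundaries.
function beautify(s: string): string {
  return s.replace(/([a-z])([A-Z])/g, "$1 $2");
}

// Anchored suffix strip: only a trailing " Block" word is removed, so names
// that merely contain "Block" (e.g. "Blockchain Agent") are left intact.
function stripBlockSuffix(title: string): string {
  return beautify(title).replace(/ Block$/, "").trim();
}
```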

@@ -1,4 +1,3 @@
import { formatNodeDisplayTitle } from "@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/helpers";
import { Separator } from "@/components/ui/separator";
import { ScrollArea } from "@/components/ui/scroll-area";
import { beautifyString, cn } from "@/lib/utils";
@@ -59,7 +58,9 @@ export function GraphSearchContent({
filteredNodes.map((node, index) => {
if (!node?.data) return null;
const nodeTitle = formatNodeDisplayTitle(node.data);
const nodeTitle =
(node.data.metadata?.customized_name as string) ||
beautifyString(node.data.title || "").replace(/ Block$/, "");
const nodeType = beautifyString(node.data.title || "").replace(
/ Block$/,
"",
@@ -69,10 +70,7 @@ export function GraphSearchContent({
node.data.description ||
"";
const hasCustomName = !!(
node.data.metadata?.customized_name ||
node.data.hardcodedValues?.agent_name
);
const hasCustomName = !!node.data.metadata?.customized_name;
return (
<div

View File

@@ -69,9 +69,6 @@ function calculateNodeScore(
const customizedName = String(
node.data?.metadata?.customized_name || "",
).toLowerCase();
const agentName = String(
node.data?.hardcodedValues?.agent_name || "",
).toLowerCase();
// Get input and output names with defensive checks
const inputNames = Object.keys(node.data?.inputSchema?.properties || {}).map(
@@ -84,7 +81,6 @@ function calculateNodeScore(
// 1. Check exact match in customized name, title (includes ID), node ID, or block type (highest priority)
if (
customizedName.includes(query) ||
agentName.includes(query) ||
nodeTitle.includes(query) ||
nodeID.includes(query) ||
blockType.includes(query) ||
@@ -99,7 +95,6 @@ function calculateNodeScore(
queryWords.every(
(word) =>
customizedName.includes(word) ||
agentName.includes(word) ||
nodeTitle.includes(word) ||
beautifiedBlockType.includes(word),
)

View File

@@ -93,6 +93,7 @@ export function CopilotPage() {
hasMoreMessages,
isLoadingMore,
loadMore,
forwardPaginated,
// Mobile drawer
isMobile,
isDrawerOpen,
@@ -217,6 +218,7 @@ export function CopilotPage() {
hasMoreMessages={hasMoreMessages}
isLoadingMore={isLoadingMore}
onLoadMore={loadMore}
forwardPaginated={forwardPaginated}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
historicalDurations={historicalDurations}

View File

@@ -48,40 +48,75 @@ function makeQueryResult(data: object | null) {
};
}
describe("useChatSession — pagination metadata", () => {
describe("useChatSession — newestSequence and forwardPaginated", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("returns null for oldestSequence when no session data", () => {
it("returns null / false when no session data", () => {
mockUseGetV2GetSession.mockReturnValue(makeQueryResult(null));
const { result } = renderHook(() => useChatSession());
expect(result.current.oldestSequence).toBeNull();
expect(result.current.newestSequence).toBeNull();
expect(result.current.forwardPaginated).toBe(false);
});
it("returns oldestSequence from session data", () => {
mockUseGetV2GetSession.mockReturnValue(
makeQueryResult({
messages: [],
has_more_messages: true,
oldest_sequence: 50,
active_stream: null,
}),
);
const { result } = renderHook(() => useChatSession());
expect(result.current.oldestSequence).toBe(50);
});
it("returns hasMoreMessages from session data", () => {
it("returns newestSequence from session data", () => {
mockUseGetV2GetSession.mockReturnValue(
makeQueryResult({
messages: [],
has_more_messages: true,
oldest_sequence: 0,
newest_sequence: 99,
forward_paginated: false,
active_stream: null,
}),
);
const { result } = renderHook(() => useChatSession());
expect(result.current.hasMoreMessages).toBe(true);
expect(result.current.newestSequence).toBe(99);
});
it("returns null for newestSequence when field is missing", () => {
mockUseGetV2GetSession.mockReturnValue(
makeQueryResult({
messages: [],
has_more_messages: false,
oldest_sequence: 0,
newest_sequence: null,
forward_paginated: false,
active_stream: null,
}),
);
const { result } = renderHook(() => useChatSession());
expect(result.current.newestSequence).toBeNull();
});
it("returns forwardPaginated=true when session is forward-paginated", () => {
mockUseGetV2GetSession.mockReturnValue(
makeQueryResult({
messages: [],
has_more_messages: true,
oldest_sequence: 0,
newest_sequence: 49,
forward_paginated: true,
active_stream: null,
}),
);
const { result } = renderHook(() => useChatSession());
expect(result.current.forwardPaginated).toBe(true);
});
it("returns forwardPaginated=false when session is backward-paginated", () => {
mockUseGetV2GetSession.mockReturnValue(
makeQueryResult({
messages: [],
has_more_messages: true,
oldest_sequence: 50,
newest_sequence: 99,
forward_paginated: false,
active_stream: null,
}),
);
const { result } = renderHook(() => useChatSession());
expect(result.current.forwardPaginated).toBe(false);
});
});

View File

@@ -1,131 +0,0 @@
import { renderHook } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useCopilotPage } from "../useCopilotPage";
const mockUseChatSession = vi.fn();
const mockUseCopilotStream = vi.fn();
const mockUseLoadMoreMessages = vi.fn();
vi.mock("../useChatSession", () => ({
useChatSession: (...args: unknown[]) => mockUseChatSession(...args),
}));
vi.mock("../useCopilotStream", () => ({
useCopilotStream: (...args: unknown[]) => mockUseCopilotStream(...args),
}));
vi.mock("../useLoadMoreMessages", () => ({
useLoadMoreMessages: (...args: unknown[]) => mockUseLoadMoreMessages(...args),
}));
vi.mock("../useCopilotNotifications", () => ({
useCopilotNotifications: () => undefined,
}));
vi.mock("../useWorkflowImportAutoSubmit", () => ({
useWorkflowImportAutoSubmit: () => undefined,
}));
vi.mock("../store", () => ({
useCopilotUIStore: () => ({
sessionToDelete: null,
setSessionToDelete: vi.fn(),
isDrawerOpen: false,
setDrawerOpen: vi.fn(),
copilotChatMode: "chat",
copilotLlmModel: null,
isDryRun: false,
}),
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
concatWithAssistantMerge: (a: unknown[], b: unknown[]) => [...a, ...b],
}));
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useDeleteV2DeleteSession: () => ({ mutate: vi.fn(), isPending: false }),
useGetV2ListSessions: () => ({ data: undefined, isLoading: false }),
getGetV2ListSessionsQueryKey: () => ["sessions"],
}));
vi.mock("@/components/molecules/Toast/use-toast", () => ({
toast: vi.fn(),
}));
vi.mock("@/lib/direct-upload", () => ({
uploadFileDirect: vi.fn(),
}));
vi.mock("@/lib/hooks/useBreakpoint", () => ({
useBreakpoint: () => "lg",
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: () => ({ isUserLoading: false, isLoggedIn: true }),
}));
vi.mock("@tanstack/react-query", () => ({
useQueryClient: () => ({ invalidateQueries: vi.fn() }),
}));
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" },
useGetFlag: () => false,
}));
function makeBaseChatSession(overrides: Record<string, unknown> = {}) {
return {
sessionId: "sess-1",
setSessionId: vi.fn(),
hydratedMessages: [],
rawSessionMessages: [],
historicalDurations: new Map(),
hasActiveStream: false,
hasMoreMessages: false,
oldestSequence: null,
isLoadingSession: false,
isSessionError: false,
createSession: vi.fn(),
isCreatingSession: false,
refetchSession: vi.fn(),
sessionDryRun: false,
...overrides,
};
}
function makeBaseCopilotStream(overrides: Record<string, unknown> = {}) {
return {
messages: [],
sendMessage: vi.fn(),
stop: vi.fn(),
status: "ready",
error: undefined,
isReconnecting: false,
isSyncing: false,
isUserStoppingRef: { current: false },
rateLimitMessage: null,
dismissRateLimit: vi.fn(),
...overrides,
};
}
function makeBaseLoadMore(overrides: Record<string, unknown> = {}) {
return {
pagedMessages: [],
hasMore: false,
isLoadingMore: false,
loadMore: vi.fn(),
...overrides,
};
}
describe("useCopilotPage — backward pagination message ordering", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("prepends pagedMessages before currentMessages", () => {
const pagedMsg = { id: "paged", role: "user" };
const currentMsg = { id: "current", role: "assistant" };
mockUseChatSession.mockReturnValue(makeBaseChatSession());
mockUseCopilotStream.mockReturnValue(
makeBaseCopilotStream({ messages: [currentMsg] }),
);
mockUseLoadMoreMessages.mockReturnValue(
makeBaseLoadMore({ pagedMessages: [pagedMsg] }),
);
const { result } = renderHook(() => useCopilotPage());
// Backward: pagedMessages (older) come first
expect(result.current.messages[0]).toEqual(pagedMsg);
expect(result.current.messages[1]).toEqual(currentMsg);
});
});

View File

@@ -15,8 +15,10 @@ vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
const BASE_ARGS = {
sessionId: "sess-1",
initialOldestSequence: 50,
initialOldestSequence: 0,
initialNewestSequence: 49,
initialHasMore: true,
forwardPaginated: true,
initialPageRawMessages: [],
};
@@ -24,6 +26,7 @@ function makeSuccessResponse(overrides: {
messages?: unknown[];
has_more_messages?: boolean;
oldest_sequence?: number;
newest_sequence?: number;
}) {
return {
status: 200,
@@ -31,6 +34,7 @@ function makeSuccessResponse(overrides: {
messages: overrides.messages ?? [],
has_more_messages: overrides.has_more_messages ?? false,
oldest_sequence: overrides.oldest_sequence ?? 0,
newest_sequence: overrides.newest_sequence ?? 49,
},
};
}
@@ -47,6 +51,30 @@ describe("useLoadMoreMessages", () => {
expect(result.current.isLoadingMore).toBe(false);
});
it("resetPaged clears paged state and sets hasMore=false during transition", () => {
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
act(() => {
result.current.resetPaged();
});
expect(result.current.pagedMessages).toHaveLength(0);
// hasMore must be false during transition to prevent forward loadMore
// from firing on the now-active session before forwardPaginated updates.
expect(result.current.hasMore).toBe(false);
expect(result.current.isLoadingMore).toBe(false);
});
it("resetPaged exposes a fresh loadMore via incremented epoch", () => {
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
// Just verify resetPaged is callable and doesn't throw.
expect(() => {
act(() => {
result.current.resetPaged();
});
}).not.toThrow();
});
it("resets all state on sessionId change", () => {
const { result, rerender } = renderHook(
(props) => useLoadMoreMessages(props),
@@ -57,6 +85,7 @@ describe("useLoadMoreMessages", () => {
...BASE_ARGS,
sessionId: "sess-2",
initialOldestSequence: 10,
initialNewestSequence: 59,
initialHasMore: false,
});
@@ -65,6 +94,66 @@ describe("useLoadMoreMessages", () => {
expect(result.current.isLoadingMore).toBe(false);
});
describe("loadMore — forward pagination", () => {
it("calls getV2GetSession with after_sequence and updates newestSequence", async () => {
const rawMsg = { role: "user", content: "hi", sequence: 50 };
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [rawMsg],
has_more_messages: true,
newest_sequence: 99,
}),
);
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).toHaveBeenCalledWith(
"sess-1",
expect.objectContaining({ after_sequence: 49 }),
);
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
it("sets hasMore=false when response has no more messages", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({ has_more_messages: false }),
);
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
);
await act(async () => {
await result.current.loadMore();
});
expect(result.current.hasMore).toBe(false);
});
it("is a no-op when hasMore is false", async () => {
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
initialHasMore: false,
forwardPaginated: true,
}),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).not.toHaveBeenCalled();
});
});
describe("loadMore — backward pagination", () => {
it("calls getV2GetSession with before_sequence", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
@@ -75,7 +164,13 @@ describe("useLoadMoreMessages", () => {
}),
);
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
forwardPaginated: false,
initialOldestSequence: 50,
}),
);
await act(async () => {
await result.current.loadMore();
@@ -87,30 +182,6 @@ describe("useLoadMoreMessages", () => {
);
expect(result.current.hasMore).toBe(false);
});
it("is a no-op when hasMore is false", async () => {
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, initialHasMore: false }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).not.toHaveBeenCalled();
});
it("is a no-op when oldestSequence is null", async () => {
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, initialOldestSequence: null }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).not.toHaveBeenCalled();
});
});
describe("loadMore — error handling", () => {
@@ -123,6 +194,7 @@ describe("useLoadMoreMessages", () => {
await result.current.loadMore();
});
// First error — hasMore still true
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
@@ -136,6 +208,7 @@ describe("useLoadMoreMessages", () => {
await act(async () => {
await result.current.loadMore();
});
// Reset the in-flight guard between calls
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
}
@@ -151,18 +224,122 @@ describe("useLoadMoreMessages", () => {
await result.current.loadMore();
});
// One error, not yet at threshold — hasMore still true
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
});
describe("loadMore — MAX_OLDER_MESSAGES truncation", () => {
it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => {
describe("loadMore — forward pagination cursor advancement", () => {
it("advances newestSequence after a successful forward load", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: Array.from({ length: 2001 }, (_, i) => ({
messages: [{ role: "user", content: "hi", sequence: 50 }],
has_more_messages: true,
newest_sequence: 99,
}),
);
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
);
await act(async () => {
await result.current.loadMore();
});
// A second loadMore should use after_sequence: 99 (advanced cursor)
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
"sess-1",
expect.objectContaining({ after_sequence: 99 }),
);
});
it("does not regress newestSequence when parent refetches after pages loaded", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [{ role: "user", content: "msg", sequence: 50 }],
has_more_messages: true,
newest_sequence: 99,
}),
);
const { result, rerender } = renderHook(
(props) => useLoadMoreMessages(props),
{ initialProps: { ...BASE_ARGS, forwardPaginated: true } },
);
// Load one page — newestSequence advances to 99
await act(async () => {
await result.current.loadMore();
});
// Parent refetches with a lower newest_sequence (49) — should NOT regress cursor
rerender({
...BASE_ARGS,
forwardPaginated: true,
initialNewestSequence: 49,
});
// Next loadMore should still use the advanced cursor (99)
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
"sess-1",
expect.objectContaining({ after_sequence: 99 }),
);
});
});
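  // The two tests above pin down a monotone cursor: successful loads advance
  // newestSequence, and a parent refetch reporting an older newest_sequence
  // must not rewind it. As a hedged sketch (the function name is hypothetical,
  // not the hook's internals), the rule is just a null-aware max:

```typescript
// A forward cursor that never moves backward: updates carrying an older
// newest_sequence are ignored in favor of the already-advanced value.
function advanceCursor(
  current: number | null,
  incoming: number | null,
): number | null {
  if (incoming == null) return current;
  if (current == null) return incoming;
  return Math.max(current, incoming);
}
```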
describe("loadMore — MAX_OLDER_MESSAGES truncation", () => {
it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => {
// Simulate being near the limit — 1990 existing paged messages
const nearLimitArgs = {
...BASE_ARGS,
forwardPaginated: false,
initialOldestSequence: 1990,
};
// Return 20 messages to push total past 2000
const newMessages = Array.from({ length: 20 }, (_, i) => ({
role: "user",
content: `msg ${i}`,
sequence: i,
}));
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: newMessages,
has_more_messages: true,
oldest_sequence: 0,
}),
);
const { result } = renderHook((props) => useLoadMoreMessages(props), {
initialProps: nearLimitArgs,
});
// Pre-fill pagedRawMessages to near limit by doing a successful load first
// then checking hasMore is set to false when limit reached
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: Array.from({ length: 1990 }, (_, i) => ({
role: "user",
content: `msg ${i}`,
content: `old ${i}`,
sequence: i,
})),
has_more_messages: true,
@@ -170,18 +347,109 @@ describe("useLoadMoreMessages", () => {
}),
);
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
await act(async () => {
await result.current.loadMore();
});
// Now add 20 more to exceed 2000 — hasMore should be forced false
await act(async () => {
await result.current.loadMore();
});
expect(result.current.hasMore).toBe(false);
});
it("forward truncation keeps first MAX_OLDER_MESSAGES items (not last)", async () => {
// 1990 messages already paged; load 20 more forward — total 2010 > 2000.
// Forward truncation must keep slice(0, 2000), not slice(-2000),
// to preserve the beginning of the conversation.
const forwardNearLimitArgs = {
...BASE_ARGS,
forwardPaginated: true,
initialNewestSequence: 49,
initialOldestSequence: 0,
initialHasMore: true,
};
const { result } = renderHook((props) => useLoadMoreMessages(props), {
initialProps: forwardNearLimitArgs,
});
// First load: 1990 messages — advances newestSequence to 2039
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: Array.from({ length: 1990 }, (_, i) => ({
role: "assistant",
content: `msg ${i + 50}`,
sequence: i + 50,
})),
has_more_messages: true,
newest_sequence: 2039,
}),
);
await act(async () => {
await result.current.loadMore();
});
// Second load: 20 more messages pushes total to 2010 > 2000 — hasMore=false
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: Array.from({ length: 20 }, (_, i) => ({
role: "assistant",
content: `msg ${i + 2040}`,
sequence: i + 2040,
})),
has_more_messages: false,
newest_sequence: 2059,
}),
);
await act(async () => {
await result.current.loadMore();
});
// After truncation, hasMore is forced false (total ≥ MAX_OLDER_MESSAGES).
expect(result.current.hasMore).toBe(false);
});
});
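  // The direction of the cap is what the truncation tests above protect:
  // backward pagination accumulates older history and keeps the most recent
  // window, while forward pagination must keep the start of the conversation
  // so the initial prompt survives. A minimal sketch, assuming the same
  // 2000-item limit (names are illustrative, not the hook's internals):

```typescript
const MAX_OLDER_MESSAGES = 2000;

// Truncate accumulated pages once the cap is exceeded.
// Forward mode keeps the beginning: slice(0, MAX) preserves the initial prompt.
// Backward mode keeps the end: slice(-MAX) preserves the newest older pages.
function truncatePages<T>(items: T[], forward: boolean): T[] {
  if (items.length <= MAX_OLDER_MESSAGES) return items;
  return forward
    ? items.slice(0, MAX_OLDER_MESSAGES)
    : items.slice(-MAX_OLDER_MESSAGES);
}
```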
describe("loadMore — null cursor guard", () => {
it("is a no-op when newestSequence is null (forwardPaginated=true)", async () => {
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
forwardPaginated: true,
initialNewestSequence: null,
}),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).not.toHaveBeenCalled();
});
it("is a no-op when oldestSequence is null (forwardPaginated=false)", async () => {
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
forwardPaginated: false,
initialOldestSequence: null,
}),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).not.toHaveBeenCalled();
});
});
describe("pagedMessages — initialPageRawMessages extraToolOutputs", () => {
it("calls extractToolOutputsFromRaw with non-empty initialPageRawMessages", async () => {
it("calls extractToolOutputsFromRaw for backward pagination with non-empty initialPageRawMessages", async () => {
const { extractToolOutputsFromRaw } = await import(
"../helpers/convertChatSessionToUiMessages"
);
@@ -198,6 +466,8 @@ describe("useLoadMoreMessages", () => {
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
forwardPaginated: false,
initialOldestSequence: 50,
initialPageRawMessages: [{ role: "assistant", content: "response" }],
}),
);
@@ -208,5 +478,68 @@ describe("useLoadMoreMessages", () => {
expect(extractToolOutputsFromRaw).toHaveBeenCalled();
});
it("does NOT call extractToolOutputsFromRaw for forward pagination", async () => {
const { extractToolOutputsFromRaw } = await import(
"../helpers/convertChatSessionToUiMessages"
);
const rawMsg = { role: "assistant", content: "hi", sequence: 50 };
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [rawMsg],
has_more_messages: false,
newest_sequence: 99,
}),
);
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
forwardPaginated: true,
initialPageRawMessages: [{ role: "user", content: "hello" }],
}),
);
await act(async () => {
await result.current.loadMore();
});
expect(extractToolOutputsFromRaw).not.toHaveBeenCalled();
});
});
describe("loadMore — epoch / stale-response guard", () => {
it("discards response when epoch changes during flight (resetPaged called)", async () => {
let resolveRequest!: (v: unknown) => void;
mockGetV2GetSession.mockReturnValueOnce(
new Promise((res) => {
resolveRequest = res;
}),
);
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
// Start the request without awaiting
act(() => {
result.current.loadMore();
});
// Reset epoch mid-flight
act(() => {
result.current.resetPaged();
});
// Now resolve the in-flight request
await act(async () => {
resolveRequest(
makeSuccessResponse({ messages: [{ role: "user", content: "hi" }] }),
);
});
// Response discarded — pagedMessages stays empty, isLoadingMore stays false
expect(result.current.pagedMessages).toHaveLength(0);
expect(result.current.isLoadingMore).toBe(false);
});
});
});
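// The stale-response test above relies on an epoch counter: capture the epoch
// before awaiting, bump it on reset, and drop any response whose captured
// epoch no longer matches. A self-contained sketch of the pattern (the class
// and names are illustrative, not the hook's actual implementation):

```typescript
// Epoch guard for discarding async responses that resolve after a reset.
class PagedLoader {
  private epoch = 0;
  messages: string[] = [];

  reset(): void {
    this.epoch += 1; // any request still in flight is now stale
    this.messages = [];
  }

  async loadMore(fetchPage: () => Promise<string[]>): Promise<void> {
    const startedEpoch = this.epoch; // capture before the await
    const page = await fetchPage();
    if (startedEpoch !== this.epoch) return; // reset mid-flight: discard
    this.messages.push(...page);
  }
}
```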

View File

@@ -6,11 +6,9 @@ import { Suspense, useState } from "react";
import { Skeleton } from "@/components/ui/skeleton";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { ArtifactErrorBoundary } from "./ArtifactErrorBoundary";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
import { ArtifactSkeleton } from "./ArtifactSkeleton";
import {
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
TAILWIND_CDN_URL,
wrapWithHeadInjection,
} from "@/lib/iframe-sandbox-csp";
@@ -55,18 +53,13 @@ function ArtifactContentLoader({
return (
<div ref={scrollRef} className="flex-1 overflow-y-auto">
<ArtifactErrorBoundary
artifactTitle={artifact.title}
artifactType={classification.type}
>
<ArtifactRenderer
artifact={artifact}
content={content}
pdfUrl={pdfUrl}
isSourceView={isSourceView}
classification={classification}
/>
</ArtifactErrorBoundary>
<ArtifactRenderer
artifact={artifact}
content={content}
pdfUrl={pdfUrl}
isSourceView={isSourceView}
classification={classification}
/>
</div>
);
}
@@ -207,10 +200,7 @@ function ArtifactRenderer({
if (classification.type === "html") {
// Inject Tailwind CDN — no CSP (see iframe-sandbox-csp.ts for why)
const tailwindScript = `<script src="${TAILWIND_CDN_URL}"></script>`;
const wrapped = wrapWithHeadInjection(
content,
tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
);
const wrapped = wrapWithHeadInjection(content, tailwindScript);
return (
<iframe
sandbox="allow-scripts"
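For context on the `wrapWithHeadInjection` call above: the real implementation lives in `@/lib/iframe-sandbox-csp` and is not shown in this diff, so the following is only a plausible sketch of what such a helper does, under the assumption that it splices markup immediately after the opening `<head>` tag and synthesizes a head for bare fragments.

```typescript
// Hypothetical sketch (assumed behavior, not the real helper): insert
// markup right after the opening <head> tag so injected scripts load
// before the document body; wrap headless fragments in a full document.
function wrapWithHeadInjection(html: string, injection: string): string {
  const headOpen = /<head[^>]*>/i.exec(html);
  if (headOpen) {
    const at = headOpen.index + headOpen[0].length;
    return html.slice(0, at) + injection + html.slice(at);
  }
  // No <head>: synthesize one so the injection still runs first.
  return `<html><head>${injection}</head><body>${html}</body></html>`;
}
```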

View File

@@ -1,96 +0,0 @@
"use client";
import * as Sentry from "@sentry/nextjs";
import { Component, type ErrorInfo, type ReactNode } from "react";
interface Props {
children: ReactNode;
artifactTitle: string;
artifactType: string;
}
interface State {
error: Error | null;
}
export class ArtifactErrorBoundary extends Component<Props, State> {
state: State = { error: null };
static getDerivedStateFromError(error: Error): State {
return { error };
}
componentDidCatch(error: Error, errorInfo: ErrorInfo) {
Sentry.captureException(error, {
contexts: {
react: { componentStack: errorInfo.componentStack },
},
tags: { errorBoundary: "true", context: "copilot-artifact" },
extra: {
artifactTitle: this.props.artifactTitle,
artifactType: this.props.artifactType,
},
});
}
componentDidUpdate(prevProps: Props) {
if (
this.state.error &&
(prevProps.artifactTitle !== this.props.artifactTitle ||
prevProps.artifactType !== this.props.artifactType)
) {
this.setState({ error: null });
}
}
handleCopy = () => {
const { error } = this.state;
if (!error) return;
const details = [
`Artifact: ${this.props.artifactTitle}`,
`Type: ${this.props.artifactType}`,
`Error: ${error.message}`,
error.stack ? `Stack:\n${error.stack}` : "",
]
.filter(Boolean)
.join("\n");
navigator.clipboard?.writeText(details).catch(() => {});
};
render() {
const { error } = this.state;
if (!error) return this.props.children;
const message = error.message || "Unknown rendering error";
return (
<div
role="alert"
className="flex h-full flex-col items-center justify-center gap-3 p-8 text-center"
>
<p className="text-sm font-medium text-zinc-700">
This artifact couldn&apos;t be rendered
</p>
<p className="max-w-md break-words text-xs text-zinc-500">
Something in{" "}
<span className="font-mono">{this.props.artifactTitle}</span> threw an
error while rendering. The chat and sidebar are still working.
</p>
<pre className="max-h-32 max-w-md overflow-auto whitespace-pre-wrap break-words rounded-md bg-zinc-100 px-3 py-2 text-left text-xs text-zinc-700">
{message}
</pre>
<button
type="button"
onClick={this.handleCopy}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
Copy error details
</button>
<p className="max-w-md text-xs text-zinc-400">
Paste this into the chat so the agent can regenerate a working
version.
</p>
</div>
);
}
}

View File

@@ -412,41 +412,6 @@ describe("ArtifactContent", () => {
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
});
it("injects the fragment-link interceptor into HTML artifact iframes (regression)", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () =>
Promise.resolve(
'<html><head></head><body><a href="#x">x</a><div id="x">x</div></body></html>',
),
}),
);
const { container } = render(
<ArtifactContent
artifact={makeArtifact({
id: "html-frag",
title: "page.html",
mimeType: "text/html",
})}
isSourceView={false}
classification={makeClassification({ type: "html" })}
/>,
);
await screen.findByTitle("page.html");
const srcdoc = container.querySelector("iframe")?.getAttribute("srcdoc");
expect(srcdoc).toBeTruthy();
// Markers unique to FRAGMENT_LINK_INTERCEPTOR_SCRIPT — if any of these
// disappear, the interceptor is no longer being injected and fragment
// links will navigate the parent URL again.
expect(srcdoc).toContain("__fragmentLinkInterceptor");
expect(srcdoc).toContain('a[href^="#"]');
expect(srcdoc).toContain("scrollIntoView");
});
// ── Source view ───────────────────────────────────────────────────
it("renders source view as pre tag", async () => {
@@ -958,164 +923,6 @@ describe("ArtifactContent", () => {
},
);
// ── Error boundary ────────────────────────────────────────────────
it("shows a visible error instead of crashing when the renderer throws", async () => {
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
const originalImpl = vi
.mocked(ArtifactReactPreview)
.getMockImplementation();
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
throw new Error("boom in renderer");
});
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("source"),
}),
);
const artifact = makeArtifact({
id: "crash-001",
title: "broken.tsx",
mimeType: "text/tsx",
});
const classification = makeClassification({ type: "react" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
expect(
await screen.findByText(/This artifact couldn't be rendered/i),
).toBeTruthy();
expect(screen.getByText(/boom in renderer/)).toBeTruthy();
expect(
screen.getByRole("button", { name: /copy error details/i }),
).toBeTruthy();
if (originalImpl) {
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
}
consoleErr.mockRestore();
});
it("copies artifact title, type, and error to the clipboard", async () => {
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
const writeText = vi.fn().mockResolvedValue(undefined);
Object.defineProperty(navigator, "clipboard", {
value: { writeText },
writable: true,
configurable: true,
});
const originalImpl = vi
.mocked(ArtifactReactPreview)
.getMockImplementation();
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
throw new Error("jsx parse failed at line 42");
});
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("source"),
}),
);
render(
<ArtifactContent
artifact={makeArtifact({
id: "crash-002",
title: "report.tsx",
mimeType: "text/tsx",
})}
isSourceView={false}
classification={makeClassification({ type: "react" })}
/>,
);
fireEvent.click(
await screen.findByRole("button", { name: /copy error details/i }),
);
await waitFor(() => {
expect(writeText).toHaveBeenCalled();
});
const payload = writeText.mock.calls[0]![0] as string;
expect(payload).toContain("report.tsx");
expect(payload).toContain("react");
expect(payload).toContain("jsx parse failed at line 42");
if (originalImpl) {
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
}
consoleErr.mockRestore();
});
it("renders the user-reported plotly HTML artifact into a sandboxed iframe", async () => {
const html = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>AutoGPT Beta Launch Interactive Report</title>
<script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
<style>
:root { --bg: #f8f9fa; --primary: #6c5ce7; }
* { margin: 0; padding: 0; box-sizing: border-box; }
body { font-family: 'Segoe UI', system-ui, sans-serif; }
</style>
</head>
<body>
<header><h1>\u{1F4CA} AutoGPT Beta Launch Interactive Report</h1></header>
<div class="chart-container" id="globalActivationChart"></div>
<script>
function showTab(tabId, groupId) {
const group = document.getElementById(groupId);
group.querySelectorAll('.tab-content').forEach(t => t.classList.remove('active'));
document.getElementById(tabId).classList.add('active');
}
Plotly.newPlot('globalActivationChart', [{ type: 'pie', values: [1, 2] }], {});
</script>
</body>
</html>`;
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(html),
}),
);
const artifact = makeArtifact({
id: "html-big-report",
title: "report.html",
mimeType: "text/html",
});
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={makeClassification({ type: "html" })}
/>,
);
await screen.findByTitle("report.html");
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
expect(screen.queryByText(/couldn't be rendered/i)).toBeNull();
});
it("falls back to pre tag when no renderer matches", async () => {
const { globalRegistry } = await import(
"@/components/contextual/OutputRenderers"

View File

@@ -116,11 +116,4 @@ describe("buildReactArtifactSrcDoc", () => {
expect(doc).toContain("/^[A-Z]/.test(name)");
expect(doc).toContain("wrapWithProviders");
});
it("injects the fragment-link interceptor so #anchor clicks stay inside the iframe (regression)", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("__fragmentLinkInterceptor");
expect(doc).toContain('a[href^="#"]');
expect(doc).toContain("scrollIntoView");
});
});
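The deleted regression test asserts three markers of `FRAGMENT_LINK_INTERCEPTOR_SCRIPT`. The script's source is not in this diff; the sketch below is a speculative reconstruction from those markers alone (a delegated click handler that keeps `#anchor` clicks inside the sandboxed iframe by scrolling instead of navigating the parent URL), not the project's actual implementation.

```typescript
// Speculative reconstruction based only on the asserted markers
// (__fragmentLinkInterceptor, a[href^="#"], scrollIntoView): intercept
// fragment-link clicks inside the iframe and scroll to the target
// element rather than letting the click navigate the parent page.
const FRAGMENT_LINK_INTERCEPTOR_SCRIPT = `<script>
(function __fragmentLinkInterceptor() {
  document.addEventListener("click", function (e) {
    var a = e.target && e.target.closest && e.target.closest('a[href^="#"]');
    if (!a) return;
    e.preventDefault();
    var target = document.getElementById(a.getAttribute("href").slice(1));
    if (target) target.scrollIntoView({ behavior: "smooth" });
  });
})();
</script>`;
```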

View File

@@ -19,10 +19,7 @@
* React is loaded from unpkg with pinned version and SRI integrity hashes.
*/
import {
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
TAILWIND_CDN_URL,
} from "@/lib/iframe-sandbox-csp";
import { TAILWIND_CDN_URL } from "@/lib/iframe-sandbox-csp";
export { transpileReactArtifactSource } from "./transpileReactArtifact";
@@ -98,7 +95,6 @@ export function buildReactArtifactSrcDoc(
}
</style>
<script src="${TAILWIND_CDN_URL}"></script>
${FRAGMENT_LINK_INTERCEPTOR_SCRIPT}
<script crossorigin="anonymous" src="https://unpkg.com/react@18.3.1/umd/react.production.min.js" integrity="sha384-DGyLxAyjq0f9SPpVevD6IgztCFlnMF6oW/XQGmfe+IsZ8TqEiDrcHkMLKI6fiB/Z"></script><!-- pragma: allowlist secret -->
<script crossorigin="anonymous" src="https://unpkg.com/react-dom@18.3.1/umd/react-dom.production.min.js" integrity="sha384-gTGxhz21lVGYNMcdJOyq01Edg0jhn/c22nsx0kyqP0TxaV5WVdsSH1fSDUf5YJj1"></script><!-- pragma: allowlist secret -->
</head>

View File

@@ -30,6 +30,7 @@ export interface ChatContainerProps {
hasMoreMessages?: boolean;
isLoadingMore?: boolean;
onLoadMore?: () => void;
forwardPaginated?: boolean;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
@@ -54,6 +55,7 @@ export const ChatContainer = ({
hasMoreMessages,
isLoadingMore,
onLoadMore,
forwardPaginated,
droppedFiles,
onDroppedFilesConsumed,
historicalDurations,
@@ -108,6 +110,7 @@ export const ChatContainer = ({
hasMoreMessages={hasMoreMessages}
isLoadingMore={isLoadingMore}
onLoadMore={onLoadMore}
forwardPaginated={forwardPaginated}
onRetry={handleRetry}
historicalDurations={historicalDurations}
/>

View File

@@ -86,11 +86,11 @@ export function ChatInput({
title:
next === "advanced"
? "Switched to Advanced model"
: "Switched to Balanced model",
: "Switched to Standard model",
description:
next === "advanced"
? "Using the highest-capability model."
: "Using the balanced default model.",
: "Using the balanced standard model.",
});
}

View File

@@ -162,15 +162,10 @@ describe("ChatInput mode toggle", () => {
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
});
it("hides toggle buttons when streaming", () => {
it("hides toggle button when streaming", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} isStreaming />);
expect(
screen.queryByLabelText(/switch to (fast|extended thinking) mode/i),
).toBeNull();
expect(
screen.queryByLabelText(/switch to (advanced|balanced|standard) model/i),
).toBeNull();
expect(screen.queryByLabelText(/switch to/i)).toBeNull();
});
it("shows mode toggle when hasSession is true and not streaming", () => {
@@ -239,7 +234,7 @@ describe("ChatInput model toggle", () => {
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
});
@@ -293,10 +288,10 @@ describe("ChatInput model toggle", () => {
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to balanced model/i),
title: expect.stringMatching(/switched to standard model/i),
}),
);
});

View File

@@ -2,11 +2,6 @@
import { cn } from "@/lib/utils";
import { Flask } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
// This button is only rendered on NEW chats (no active session).
// Once a session exists, it is hidden — the session's dry_run flag is
@@ -19,31 +14,27 @@ interface Props {
export function DryRunToggleButton({ isDryRun, onToggle }: Props) {
return (
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-pressed={isDryRun}
onClick={onToggle}
className={cn(
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
isDryRun
? "text-amber-900"
: "text-neutral-500 hover:text-neutral-700",
)}
aria-label={isDryRun ? "Test mode active" : "Enable Test mode"}
>
<Flask size={14} />
<span className="hidden sm:inline">
{isDryRun ? "Test mode enabled" : "Enable test mode"}
</span>
</button>
</TooltipTrigger>
<TooltipContent>
{isDryRun
? "Test mode on — new sessions run without performing real actions (click to turn off)."
: "Turn on test mode to try prompts without performing real actions."}
</TooltipContent>
</Tooltip>
<button
type="button"
aria-pressed={isDryRun}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isDryRun
? "bg-amber-100 text-amber-900 hover:bg-amber-200"
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
)}
aria-label={
isDryRun ? "Test mode active — click to disable" : "Enable Test mode"
}
title={
isDryRun
? "Test mode ON — new chats run agents as simulation (click to disable)"
: "Enable Test mode — new chats will run agents as simulation"
}
>
<Flask size={14} />
{isDryRun && "Test"}
</button>
);
}

View File

@@ -2,11 +2,6 @@
import { cn } from "@/lib/utils";
import { Brain, Lightning } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import type { CopilotMode } from "../../../store";
interface Props {
@@ -16,42 +11,37 @@ interface Props {
export function ModeToggleButton({ mode, onToggle }: Props) {
const isExtended = mode === "extended_thinking";
const tooltipText = isExtended
? "Extended Thinking — deeper reasoning (click to switch to Fast)"
: "Fast mode — quicker responses (click to switch to Thinking)";
return (
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-pressed={isExtended}
onClick={onToggle}
className={cn(
"ml-2 inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
isExtended ? "text-purple-900" : "text-amber-900",
)}
aria-label={
isExtended
? "Switch to Fast mode"
: "Switch to Extended Thinking mode"
}
>
{isExtended ? (
<>
<Brain size={14} />
Thinking
</>
) : (
<>
<Lightning size={14} />
Fast
</>
)}
</button>
</TooltipTrigger>
<TooltipContent>{tooltipText}</TooltipContent>
</Tooltip>
<button
type="button"
aria-pressed={isExtended}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isExtended
? "bg-purple-100 text-purple-900 hover:bg-purple-200"
: "bg-amber-100 text-amber-900 hover:bg-amber-200",
)}
aria-label={
isExtended ? "Switch to Fast mode" : "Switch to Extended Thinking mode"
}
title={
isExtended
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
}
>
{isExtended ? (
<>
<Brain size={14} />
Thinking
</>
) : (
<>
<Lightning size={14} />
Fast
</>
)}
</button>
);
}

View File

@@ -2,11 +2,6 @@
import { cn } from "@/lib/utils";
import { Cpu } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import type { CopilotLlmModel } from "../../../store";
interface Props {
@@ -17,33 +12,27 @@ interface Props {
export function ModelToggleButton({ model, onToggle }: Props) {
const isAdvanced = model === "advanced";
return (
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-pressed={isAdvanced}
onClick={onToggle}
className={cn(
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
isAdvanced
? "text-sky-900"
: "text-neutral-500 hover:text-neutral-700",
)}
aria-label={
isAdvanced ? "Switch to Balanced model" : "Switch to Advanced model"
}
>
<Cpu size={14} />
<span className="hidden sm:inline">
{isAdvanced ? "Advanced" : "Balanced"}
</span>
</button>
</TooltipTrigger>
<TooltipContent>
{isAdvanced
? "Using the highest-capability model (click to switch to Balanced)."
: "Using the balanced default model (click to switch to Advanced)."}
</TooltipContent>
</Tooltip>
<button
type="button"
aria-pressed={isAdvanced}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isAdvanced
? "bg-sky-100 text-sky-900 hover:bg-sky-200"
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
)}
aria-label={
isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
}
title={
isAdvanced
? "Advanced model — highest capability (click to switch to Standard)"
: "Standard model — click to switch to Advanced"
}
>
<Cpu size={14} />
{isAdvanced && "Advanced"}
</button>
);
}

View File

@@ -1,32 +1,21 @@
import {
render as rtlRender,
screen,
fireEvent,
cleanup,
} from "@testing-library/react";
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import type { ReactElement } from "react";
import { TooltipProvider } from "@/components/ui/tooltip";
import { DryRunToggleButton } from "../DryRunToggleButton";
afterEach(cleanup);
function render(ui: ReactElement) {
return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
}
// DryRunToggleButton only appears on new chats (no active session).
// It has no readOnly/isStreaming props — those scenarios are handled by hiding
// the button entirely at the ChatInput level when hasSession is true.
describe("DryRunToggleButton", () => {
it("shows enabled label when isDryRun is true", () => {
it("shows Test label when isDryRun is true", () => {
render(<DryRunToggleButton isDryRun={true} onToggle={vi.fn()} />);
expect(screen.getByText("Test mode enabled")).toBeTruthy();
expect(screen.getByText("Test")).toBeTruthy();
});
it("shows enable label when isDryRun is false", () => {
it("shows no text label when isDryRun is false", () => {
render(<DryRunToggleButton isDryRun={false} onToggle={vi.fn()} />);
expect(screen.getByText("Enable test mode")).toBeTruthy();
expect(screen.queryByText("Test")).toBeNull();
});
it("calls onToggle when clicked", () => {

View File

@@ -1,20 +1,9 @@
import {
render as rtlRender,
screen,
fireEvent,
cleanup,
} from "@testing-library/react";
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import type { ReactElement } from "react";
import { TooltipProvider } from "@/components/ui/tooltip";
import { ModelToggleButton } from "../ModelToggleButton";
afterEach(cleanup);
function render(ui: ReactElement) {
return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
}
describe("ModelToggleButton", () => {
it("shows no text label when model is standard", () => {
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
@@ -42,7 +31,7 @@ describe("ModelToggleButton", () => {
it("sets aria-pressed=true for advanced", () => {
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
const btn = screen.getByLabelText("Switch to Balanced model");
const btn = screen.getByLabelText("Switch to Standard model");
expect(btn.getAttribute("aria-pressed")).toBe("true");
});
});

View File

@@ -43,6 +43,10 @@ interface Props {
hasMoreMessages?: boolean;
isLoadingMore?: boolean;
onLoadMore?: () => void;
/** When true the load-more sentinel is placed at the bottom (forward
* pagination for completed sessions). When false it is at the top
* (backward pagination for active sessions). */
forwardPaginated?: boolean;
onRetry?: () => void;
historicalDurations?: Map<string, number>;
}
@@ -136,11 +140,25 @@ export function LoadMoreSentinel({
isLoading,
messageCount,
onLoadMore,
rootMargin = "200px 0px 0px 0px",
adjustScroll = true,
forwardPaginated = false,
}: {
hasMore: boolean;
isLoading: boolean;
messageCount: number;
onLoadMore: () => void;
/** IntersectionObserver rootMargin. Top sentinel uses "200px 0px 0px 0px"
* (pre-trigger when approaching from above); bottom sentinel should use
* "0px 0px 200px 0px" (pre-trigger when approaching from below). */
rootMargin?: string;
/** Whether to adjust scrollTop after load to preserve visual position.
* True for backward pagination (prepend above); false for forward
* pagination (append below) where no adjustment is needed. */
adjustScroll?: boolean;
/** When true the button reads "Load newer messages" (forward pagination).
* When false (default) it reads "Load older messages". */
forwardPaginated?: boolean;
}) {
const sentinelRef = useRef<HTMLDivElement>(null);
const onLoadMoreRef = useRef(onLoadMore);
@@ -185,11 +203,11 @@ export function LoadMoreSentinel({
if (autoFillRoundsRef.current >= MAX_AUTO_FILL_ROUNDS) return;
captureAndLoad(true);
},
{ rootMargin: "200px 0px 0px 0px" },
{ rootMargin },
);
observer.observe(sentinelRef.current);
return () => observer.disconnect();
}, [hasMore, isLoading, scrollRef]);
}, [hasMore, isLoading, rootMargin, scrollRef]);
// After React commits new DOM nodes (prepended messages), adjust
// scrollTop so the user stays at the same visual position.
@@ -202,7 +220,9 @@ export function LoadMoreSentinel({
scrollSnapshotRef.current;
if (!el || prevHeight === 0) return;
const delta = el.scrollHeight - prevHeight;
if (delta > 0) {
// Only restore scroll position for backward pagination (content prepended
// above). Forward pagination appends below — no adjustment needed.
if (adjustScroll && delta > 0) {
el.scrollTop = prevTop + delta;
}
// Reset the auto-fill backoff whenever the container becomes
@@ -216,7 +236,7 @@ export function LoadMoreSentinel({
}
scrollSnapshotRef.current = { scrollHeight: 0, scrollTop: 0 };
autoTriggeredRef.current = false;
}, [messageCount, scrollRef]);
}, [adjustScroll, messageCount, scrollRef]);
return (
<div
@@ -235,7 +255,7 @@ export function LoadMoreSentinel({
size="small"
onClick={() => captureAndLoad(false)}
>
Load older messages
{forwardPaginated ? "Load newer messages" : "Load older messages"}
</Button>
)
)}
@@ -252,6 +272,7 @@ export function ChatMessagesContainer({
hasMoreMessages,
isLoadingMore,
onLoadMore,
forwardPaginated,
onRetry,
historicalDurations,
}: Props) {
@@ -330,7 +351,7 @@ export function ChatMessagesContainer({
}
>
<ConversationContent className="flex min-h-full flex-1 flex-col gap-6 px-3 py-6">
{hasMoreMessages && onLoadMore && (
{hasMoreMessages && onLoadMore && !forwardPaginated && (
<LoadMoreSentinel
hasMore={hasMoreMessages}
isLoading={!!isLoadingMore}
@@ -489,6 +510,17 @@ export function ChatMessagesContainer({
</pre>
</details>
)}
{hasMoreMessages && onLoadMore && forwardPaginated && (
<LoadMoreSentinel
hasMore={hasMoreMessages}
isLoading={!!isLoadingMore}
messageCount={messages.length}
onLoadMore={onLoadMore}
rootMargin="0px 0px 200px 0px"
adjustScroll={false}
forwardPaginated
/>
)}
</ConversationContent>
<ConversationScrollButton />
</Conversation>
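The scroll-preservation logic that the `adjustScroll` prop gates can be isolated as a small pure function. This is an illustrative sketch (the helper name and `ScrollBox` shape are invented for the example): backward pagination prepends older messages above the viewport, so `scrollTop` must shift by the height delta to keep the user on the same message; forward pagination appends below, so no correction is needed.

```typescript
// Sketch of the scroll correction gated by adjustScroll: shift
// scrollTop by the growth delta after a prepend (backward pagination),
// leave it alone after an append (forward pagination).
interface ScrollBox {
  scrollHeight: number;
  scrollTop: number;
}

function restoreScrollAfterLoad(
  el: ScrollBox,
  snapshot: ScrollBox,
  adjustScroll: boolean,
): void {
  const delta = el.scrollHeight - snapshot.scrollHeight;
  if (adjustScroll && delta > 0) {
    el.scrollTop = snapshot.scrollTop + delta;
  }
}
```

This mirrors the two tests above: with `adjustScroll` true, a 100→300 height growth moves `scrollTop` from 50 to 250; with it false, `scrollTop` stays at 50.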

View File

@@ -124,22 +124,48 @@ describe("ChatMessagesContainer", () => {
vi.unstubAllGlobals();
});
it("renders top sentinel for backward pagination", () => {
it("renders top sentinel when forwardPaginated is false (backward pagination)", () => {
render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={false} />);
expect(
screen.getByRole("button", { name: /load older messages/i }),
).toBeDefined();
});
it("renders top sentinel when forwardPaginated is undefined (default, backward)", () => {
render(<ChatMessagesContainer {...BASE_PROPS} />);
expect(
screen.getByRole("button", { name: /load older messages/i }),
).toBeDefined();
});
it("renders bottom sentinel when forwardPaginated is true (forward pagination)", () => {
render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={true} />);
expect(
screen.getByRole("button", { name: /load newer messages/i }),
).toBeDefined();
});
it("hides sentinel when hasMoreMessages is false", () => {
render(<ChatMessagesContainer {...BASE_PROPS} hasMoreMessages={false} />);
render(
<ChatMessagesContainer
{...BASE_PROPS}
hasMoreMessages={false}
forwardPaginated={true}
/>,
);
expect(
screen.queryByRole("button", { name: /load older messages/i }),
).toBeNull();
});
it("hides sentinel when onLoadMore is not provided", () => {
render(<ChatMessagesContainer {...BASE_PROPS} onLoadMore={undefined} />);
render(
<ChatMessagesContainer
{...BASE_PROPS}
onLoadMore={undefined}
forwardPaginated={true}
/>,
);
expect(
screen.queryByRole("button", { name: /load older messages/i }),
).toBeNull();

View File

@@ -172,6 +172,36 @@ describe("LoadMoreSentinel", () => {
expect(mockScrollEl.scrollTop).toBe(200);
});
it("does NOT adjust scroll when adjustScroll=false (forward pagination)", () => {
mockScrollEl.scrollHeight = 100;
mockScrollEl.scrollTop = 50;
const onLoadMore = vi.fn();
const { rerender } = render(
<LoadMoreSentinel
hasMore={true}
isLoading={false}
messageCount={5}
onLoadMore={onLoadMore}
adjustScroll={false}
/>,
);
// Fire observer to capture snapshot.
MockIntersectionObserver.lastCallback?.([{ isIntersecting: true }]);
// Simulate DOM growing from appended newer messages (forward load-more).
mockScrollEl.scrollHeight = 300;
rerender(
<LoadMoreSentinel
hasMore={true}
isLoading={false}
messageCount={10}
onLoadMore={onLoadMore}
adjustScroll={false}
/>,
);
// scrollTop should remain unchanged — no jump for forward pagination.
expect(mockScrollEl.scrollTop).toBe(50);
});
it("ignores same-frame duplicate triggers until isLoading transitions", () => {
const onLoadMore = vi.fn();
const { rerender } = render(

View File

@@ -246,7 +246,7 @@ export function ChatSidebar() {
</SidebarHeader>
)}
{!isCollapsed && (
<SidebarHeader className="shrink-0 px-4 pb-3 pt-3 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
<SidebarHeader className="shrink-0 px-4 pb-4 pt-4 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}

View File

@@ -13,10 +13,6 @@ import {
getSuggestionThemes,
} from "./helpers";
import { SuggestionThemes } from "./components/SuggestionThemes/SuggestionThemes";
import { PulseChips } from "../PulseChips/PulseChips";
import { usePulseChips } from "../PulseChips/usePulseChips";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { EditNameDialog } from "./components/EditNameDialog/EditNameDialog";
interface Props {
inputLayoutId: string;
@@ -38,8 +34,6 @@ export function EmptySession({
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
const pulseChips = usePulseChips();
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
useGetV2GetSuggestedPrompts({
@@ -81,16 +75,11 @@ export function EmptySession({
<div className="mx-auto max-w-[52rem]">
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
Hey, <span className="text-violet-600">{greetingName}</span>
<EditNameDialog currentName={greetingName} />
</Text>
<Text variant="h3" className="mb-8 !font-normal">
Tell me about your work, I&apos;ll find what to automate.
</Text>
{isAgentBriefingEnabled && (
<PulseChips chips={pulseChips} onChipClick={onSend} />
)}
<div className="mb-6">
<motion.div
layoutId={inputLayoutId}

View File

@@ -1,107 +0,0 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { PencilSimpleIcon } from "@phosphor-icons/react";
import { useState } from "react";
interface Props {
currentName: string;
}
export function EditNameDialog({ currentName }: Props) {
const [isOpen, setIsOpen] = useState(false);
const [name, setName] = useState(currentName);
const [isSaving, setIsSaving] = useState(false);
const { refreshSession } = useSupabase();
const { toast } = useToast();
function handleOpenChange(open: boolean) {
if (open) setName(currentName);
setIsOpen(open);
}
async function handleSave() {
const trimmed = name.trim();
if (!trimmed) return;
setIsSaving(true);
try {
const res = await fetch("/api/auth/user", {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ full_name: trimmed }),
});
if (!res.ok) {
const body = await res.json();
toast({
title: "Failed to update name",
description: body.error ?? "Unknown error",
variant: "destructive",
});
return;
}
const session = await refreshSession();
if (session?.error) {
toast({
title: "Name saved, but session refresh failed",
description: session.error,
variant: "destructive",
});
setIsOpen(false);
return;
}
setIsOpen(false);
toast({ title: "Name updated" });
} finally {
setIsSaving(false);
}
}
return (
<Dialog
title="Edit display name"
styling={{ maxWidth: "24rem" }}
controlled={{ isOpen, set: handleOpenChange }}
>
<Dialog.Trigger>
<button
type="button"
className="ml-1 inline-flex items-center text-violet-500 transition-colors hover:text-violet-700"
>
<PencilSimpleIcon size={16} />
</button>
</Dialog.Trigger>
<Dialog.Content>
<div className="flex flex-col gap-4 px-1">
<Input
id="display-name"
label="Display name"
placeholder="Your name"
value={name}
onChange={(e) => setName(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.preventDefault();
handleSave();
}
}}
/>
<Button
variant="primary"
onClick={handleSave}
disabled={!name.trim() || isSaving}
loading={isSaving}
>
Save
</Button>
</div>
</Dialog.Content>
</Dialog>
);
}

View File

@@ -1,135 +0,0 @@
import { beforeEach, describe, expect, test, vi } from "vitest";
import {
fireEvent,
render,
screen,
waitFor,
} from "@/tests/integrations/test-utils";
import { server } from "@/mocks/mock-server";
import { http, HttpResponse } from "msw";
import { EditNameDialog } from "../EditNameDialog";
const mockToast = vi.hoisted(() => vi.fn());
const mockRefreshSession = vi.hoisted(() => vi.fn());
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: () => ({ toast: mockToast }),
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: () => ({
refreshSession: mockRefreshSession,
}),
}));
function mockUpdateNameSuccess() {
server.use(
http.put("/api/auth/user", () => {
return HttpResponse.json({ user: { id: "u1" } });
}),
);
}
function mockUpdateNameError(message = "Network error") {
server.use(
http.put("/api/auth/user", () => {
return HttpResponse.json({ error: message }, { status: 400 });
}),
);
}
async function openDialogAndGetInput() {
const trigger = screen.getByRole("button");
fireEvent.click(trigger);
await screen.findAllByLabelText(/display name/i);
const inputs =
document.querySelectorAll<HTMLInputElement>("input#display-name");
return inputs[0];
}
function getSaveButton() {
const saves = screen.getAllByRole("button", { name: /save/i });
return saves[0] as HTMLButtonElement;
}
describe("EditNameDialog", () => {
beforeEach(() => {
mockToast.mockReset();
mockRefreshSession.mockReset();
mockRefreshSession.mockResolvedValue({ user: { id: "u1" } });
});
test("opens dialog with current name prefilled", async () => {
mockUpdateNameSuccess();
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
expect(input.value).toBe("Alice");
});
test("saves name via API route and closes dialog", async () => {
mockUpdateNameSuccess();
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
fireEvent.change(input, { target: { value: "Bob" } });
fireEvent.click(getSaveButton());
await waitFor(() => {
expect(mockRefreshSession).toHaveBeenCalled();
});
expect(mockToast).toHaveBeenCalledWith({ title: "Name updated" });
});
test("shows error toast when API returns error", async () => {
mockUpdateNameError("Network error");
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
fireEvent.change(input, { target: { value: "Bob" } });
fireEvent.click(getSaveButton());
await waitFor(() => {
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Failed to update name",
description: "Network error",
variant: "destructive",
}),
);
});
expect(mockRefreshSession).not.toHaveBeenCalled();
});
test("shows warning toast when refreshSession returns an error", async () => {
mockUpdateNameSuccess();
mockRefreshSession.mockResolvedValue({ error: "refresh failed" });
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
fireEvent.change(input, { target: { value: "Bob" } });
fireEvent.click(getSaveButton());
await waitFor(() => {
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Name saved, but session refresh failed",
description: "refresh failed",
variant: "destructive",
}),
);
});
expect(mockToast).not.toHaveBeenCalledWith({ title: "Name updated" });
});
test("disables Save button while input is empty", async () => {
mockUpdateNameSuccess();
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
fireEvent.change(input, { target: { value: " " } });
expect(getSaveButton().disabled).toBe(true);
});
});

View File

@@ -1,93 +0,0 @@
.glassPanel {
position: relative;
isolation: isolate;
}
.glassPanel::before {
content: "";
position: absolute;
inset: 0;
border-radius: inherit;
padding: 1px;
background: conic-gradient(
from var(--border-angle, 0deg),
rgba(129, 120, 228, 0.08),
rgba(129, 120, 228, 0.28),
rgba(168, 130, 255, 0.18),
rgba(129, 120, 228, 0.08),
rgba(99, 102, 241, 0.24),
rgba(129, 120, 228, 0.08)
);
-webkit-mask:
linear-gradient(#000 0 0) content-box,
linear-gradient(#000 0 0);
mask:
linear-gradient(#000 0 0) content-box,
linear-gradient(#000 0 0);
-webkit-mask-composite: xor;
mask-composite: exclude;
animation: rotate-border 6s linear infinite;
pointer-events: none;
z-index: -1;
}
@property --border-angle {
syntax: "<angle>";
initial-value: 0deg;
inherits: false;
}
@keyframes rotate-border {
to {
--border-angle: 360deg;
}
}
.chip {
overflow: hidden;
}
@media (hover: hover) {
.chip {
padding-bottom: 0.9rem;
}
}
@media (hover: none) {
.chip {
padding-bottom: 2.25rem;
}
}
.chipActions {
position: absolute;
inset-inline: 0;
bottom: 0;
background: rgba(255, 255, 255, 0.95);
backdrop-filter: blur(4px);
-webkit-backdrop-filter: blur(4px);
}
@media (hover: hover) {
.chipActions {
opacity: 0;
transform: translateY(100%);
transition:
opacity 0.2s ease-out,
transform 0.2s ease-out;
}
.chip:hover .chipActions {
opacity: 1;
transform: translateY(0);
}
.chipContent {
transition: filter 0.2s ease-out;
}
.chip:hover .chipContent {
filter: blur(2px);
opacity: 0.5;
}
}

View File

@@ -1,116 +0,0 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import {
ArrowRightIcon,
EyeIcon,
ChatCircleDotsIcon,
} from "@phosphor-icons/react";
import NextLink from "next/link";
import { StatusBadge } from "@/app/(platform)/library/components/StatusBadge/StatusBadge";
import styles from "./PulseChips.module.css";
import type { PulseChipData } from "./types";
interface Props {
chips: PulseChipData[];
onChipClick?: (prompt: string) => void;
}
export function PulseChips({ chips, onChipClick }: Props) {
if (chips.length === 0) return null;
return (
<div
className={`${styles.glassPanel} mx-[0.6875rem] mb-5 rounded-large p-5`}
>
<div className="mb-3 flex items-center gap-3">
<Text variant="body-medium" className="text-zinc-600">
What&apos;s happening with your agents
</Text>
<NextLink
href="/library"
className="flex items-center gap-1 text-xs text-zinc-500 hover:text-zinc-700"
>
View all <ArrowRightIcon size={12} />
</NextLink>
</div>
<div className="flex gap-2 overflow-x-auto pb-1 scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300">
{chips.map((chip) => (
<PulseChip key={chip.id} chip={chip} onAsk={onChipClick} />
))}
</div>
</div>
);
}
interface ChipProps {
chip: PulseChipData;
onAsk?: (prompt: string) => void;
}
function PulseChip({ chip, onAsk }: ChipProps) {
function handleAsk() {
const prompt = buildChipPrompt(chip);
onAsk?.(prompt);
}
return (
<div
className={`${styles.chip} relative flex w-[15rem] shrink-0 flex-col items-start gap-2 rounded-medium border border-zinc-100 bg-white px-3 py-2`}
>
<div className={`${styles.chipContent} w-full text-left`}>
{chip.priority === "success" ? (
<span className="inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium text-emerald-600">
<span className="h-1.5 w-1.5 rounded-full bg-emerald-500" />
Completed
</span>
) : (
<StatusBadge status={chip.status} />
)}
<div className="mt-2 min-w-0">
<Text variant="small-medium" className="truncate text-zinc-900">
{chip.name}
</Text>
<Text variant="small" className="truncate text-zinc-500">
{chip.shortMessage}
</Text>
</div>
</div>
<div
className={`${styles.chipActions} flex items-center justify-center gap-1.5 rounded-b-medium px-3 py-1.5`}
>
<NextLink
href={`/library/agents/${chip.agentID}`}
className="flex items-center gap-1 rounded-md px-2 py-1 text-xs text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
>
<EyeIcon size={14} />
See
</NextLink>
<button
type="button"
onClick={handleAsk}
className="flex items-center gap-1 rounded-md px-2 py-1 text-xs text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
>
<ChatCircleDotsIcon size={14} />
Ask
</button>
</div>
</div>
);
}
function buildChipPrompt(chip: PulseChipData): string {
if (chip.priority === "success") {
return `${chip.name} just finished a run — can you summarize what it did?`;
}
switch (chip.status) {
case "error":
return `What happened with ${chip.name}? It has an error — can you check?`;
case "running":
return `Give me a status update on ${chip.name} — what has it done so far?`;
case "idle":
return `${chip.name} hasn't run recently. Should I keep it or update and re-run it?`;
default:
return `Tell me about ${chip.name} — what's its current status?`;
}
}

View File

@@ -1,105 +0,0 @@
import { describe, expect, test, vi } from "vitest";
import { render, screen, fireEvent } from "@/tests/integrations/test-utils";
import { PulseChips } from "../PulseChips";
import type { PulseChipData } from "../types";
function makeChip(overrides: Partial<PulseChipData> = {}): PulseChipData {
return {
id: "chip-1",
agentID: "agent-1",
name: "Test Agent",
status: "running",
priority: "running",
shortMessage: "Doing work…",
...overrides,
};
}
describe("PulseChips", () => {
test("renders nothing when chips array is empty", () => {
const { container } = render(<PulseChips chips={[]} />);
expect(container.innerHTML).toBe("");
});
test("renders chip names and messages", () => {
const chips = [
makeChip({ id: "1", name: "Alpha Bot", shortMessage: "Running task A" }),
makeChip({ id: "2", name: "Beta Bot", shortMessage: "Running task B" }),
];
render(<PulseChips chips={chips} />);
expect(screen.getByText("Alpha Bot")).toBeDefined();
expect(screen.getByText("Running task A")).toBeDefined();
expect(screen.getByText("Beta Bot")).toBeDefined();
expect(screen.getByText("Running task B")).toBeDefined();
});
test("renders section heading and View all link", () => {
render(<PulseChips chips={[makeChip()]} />);
expect(screen.getByText("What's happening with your agents")).toBeDefined();
expect(screen.getByText("View all")).toBeDefined();
});
test("shows Completed badge for success priority chips", () => {
render(
<PulseChips
chips={[makeChip({ priority: "success", status: "idle" })]}
/>,
);
expect(screen.getByText("Completed")).toBeDefined();
});
test("calls onChipClick with generated prompt when Ask is clicked", () => {
const onChipClick = vi.fn();
render(
<PulseChips
chips={[
makeChip({
name: "Error Agent",
status: "error",
priority: "error",
}),
]}
onChipClick={onChipClick}
/>,
);
fireEvent.click(screen.getByText("Ask"));
expect(onChipClick).toHaveBeenCalledWith(
"What happened with Error Agent? It has an error — can you check?",
);
});
test("generates success prompt for completed chips", () => {
const onChipClick = vi.fn();
render(
<PulseChips
chips={[
makeChip({
name: "Done Agent",
priority: "success",
status: "idle",
}),
]}
onChipClick={onChipClick}
/>,
);
fireEvent.click(screen.getByText("Ask"));
expect(onChipClick).toHaveBeenCalledWith(
"Done Agent just finished a run — can you summarize what it did?",
);
});
test("renders See link pointing to agent detail page", () => {
render(<PulseChips chips={[makeChip({ agentID: "agent-xyz" })]} />);
const seeLink = screen.getByText("See").closest("a");
expect(seeLink?.getAttribute("href")).toBe("/library/agents/agent-xyz");
});
});

View File

@@ -1,13 +0,0 @@
import type {
AgentStatus,
SitrepPriority,
} from "@/app/(platform)/library/types";
export interface PulseChipData {
id: string;
agentID: string;
name: string;
status: AgentStatus;
priority: SitrepPriority;
shortMessage: string;
}

View File

@@ -1,25 +0,0 @@
"use client";
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
import { useSitrepItems } from "@/app/(platform)/library/components/SitrepItem/useSitrepItems";
import type { PulseChipData } from "./types";
import { useMemo } from "react";
const THREE_DAYS_MS = 3 * 24 * 60 * 60 * 1000;
export function usePulseChips(): PulseChipData[] {
const { agents } = useLibraryAgents();
const sitrepItems = useSitrepItems(agents, 5, THREE_DAYS_MS);
return useMemo(() => {
return sitrepItems.map((item) => ({
id: item.id,
agentID: item.agentID,
name: item.agentName,
status: item.status,
priority: item.priority,
shortMessage: item.message,
}));
}, [sitrepItems]);
}

View File

@@ -6,9 +6,6 @@ import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useRouter } from "next/navigation";
import { useEffect, useRef } from "react";
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
import { formatCents } from "../usageHelpers";
export { formatCents };
interface Props {
isOpen: boolean;
@@ -21,6 +18,10 @@ interface Props {
onCreditChange?: () => void;
}
export function formatCents(cents: number): string {
return `$${(cents / 100).toFixed(2)}`;
}
export function RateLimitResetDialog({
isOpen,
onClose,

View File

@@ -1,10 +1,35 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { Button } from "@/components/atoms/Button/Button";
import Link from "next/link";
import { formatCents, formatResetTime } from "../usageHelpers";
import { formatCents } from "../RateLimitResetDialog/RateLimitResetDialog";
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
export { formatResetTime };
export function formatResetTime(
resetsAt: Date | string,
now: Date = new Date(),
): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / (1000 * 60 * 60));
// Under 24h: show relative time ("in 4h 23m")
if (hours < 24) {
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
function UsageBar({
label,

View File

@@ -1,28 +0,0 @@
export function formatCents(cents: number): string {
return `$${(cents / 100).toFixed(2)}`;
}
export function formatResetTime(
resetsAt: Date | string,
now: Date = new Date(),
): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / (1000 * 60 * 60));
if (hours < 24) {
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
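The deleted usageHelpers module above centralized these two formatters. A minimal self-contained sketch of their behavior (bodies copied from the listing; the sample dates are illustrative, not from the app):

```typescript
// formatCents and formatResetTime as defined in the removed usageHelpers file.
function formatCents(cents: number): string {
  return `$${(cents / 100).toFixed(2)}`;
}

function formatResetTime(resetsAt: Date | string, now: Date = new Date()): string {
  const resetDate = typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
  const diffMs = resetDate.getTime() - now.getTime();
  if (diffMs <= 0) return "now";
  const hours = Math.floor(diffMs / (1000 * 60 * 60));
  // Under 24h: relative time; over 24h: local weekday + time.
  if (hours < 24) {
    const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
    if (hours > 0) return `in ${hours}h ${minutes}m`;
    return `in ${minutes}m`;
  }
  return resetDate.toLocaleString(undefined, {
    weekday: "short",
    hour: "numeric",
    minute: "2-digit",
    timeZoneName: "short",
  });
}

const now = new Date("2026-04-16T12:00:00Z");
console.log(formatCents(1234)); // "$12.34"
console.log(formatResetTime(new Date("2026-04-16T16:23:00Z"), now)); // "in 4h 23m"
console.log(formatResetTime(new Date("2026-04-16T12:45:00Z"), now)); // "in 45m"
console.log(formatResetTime(new Date("2026-04-16T11:00:00Z"), now)); // "now"
```

Note the `now` parameter default makes the function testable without mocking the clock.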

View File

@@ -86,6 +86,16 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
return sessionQuery.data.data.oldest_sequence ?? null;
}, [sessionQuery.data]);
const newestSequence = useMemo(() => {
if (sessionQuery.data?.status !== 200) return null;
return sessionQuery.data.data.newest_sequence ?? null;
}, [sessionQuery.data]);
const forwardPaginated = useMemo(() => {
if (sessionQuery.data?.status !== 200) return false;
return !!sessionQuery.data.data.forward_paginated;
}, [sessionQuery.data]);
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
// array reference every render. Re-derives only when query data changes.
// When the session is complete (no active stream), mark dangling tool
@@ -185,6 +195,8 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
hasActiveStream,
hasMoreMessages,
oldestSequence,
newestSequence,
forwardPaginated,
isLoadingSession: sessionQuery.isLoading,
isSessionError: sessionQuery.isError,
createSession,

View File

@@ -56,6 +56,8 @@ export function useCopilotPage() {
hasActiveStream,
hasMoreMessages,
oldestSequence,
newestSequence,
forwardPaginated,
isLoadingSession,
isSessionError,
createSession,
@@ -84,19 +86,26 @@ export function useCopilotPage() {
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
});
const { pagedMessages, hasMore, isLoadingMore, loadMore } =
const { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged } =
useLoadMoreMessages({
sessionId,
initialOldestSequence: oldestSequence,
initialNewestSequence: newestSequence,
initialHasMore: hasMoreMessages,
forwardPaginated,
initialPageRawMessages: rawSessionMessages,
});
// Combine paginated messages with current page messages, merging consecutive
// assistant UIMessages at the page boundary so reasoning + response parts
// stay in a single bubble. Paged messages are older history prepended before
// the current page.
const messages = concatWithAssistantMerge(pagedMessages, currentMessages);
// stay in a single bubble.
// Forward pagination (completed sessions): current page is the beginning,
// paged messages are newer pages appended after.
// Backward pagination (active sessions): paged messages are older history
// prepended before the current page.
const messages = forwardPaginated
? concatWithAssistantMerge(currentMessages, pagedMessages)
: concatWithAssistantMerge(pagedMessages, currentMessages);
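The ordering rule described in the comments above can be sketched with plain array concatenation (a toy stand-in for concatWithAssistantMerge, which additionally merges adjacent assistant bubbles; the function name here is illustrative):

```typescript
// Forward pagination: the current page is the conversation start, so paged
// (newer) pages are appended after it. Backward pagination: paged (older)
// pages are prepended before the current page.
function orderMessages<T>(current: T[], paged: T[], forwardPaginated: boolean): T[] {
  return forwardPaginated ? [...current, ...paged] : [...paged, ...current];
}

console.log(orderMessages(["m1"], ["m2", "m3"], true));  // ["m1", "m2", "m3"]
console.log(orderMessages(["m3"], ["m1", "m2"], false)); // ["m1", "m2", "m3"]
```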
useCopilotNotifications(sessionId);
@@ -252,6 +261,15 @@ export function useCopilotPage() {
isUserStoppingRef.current = false;
if (sessionId) {
// When continuing a completed session that had forward-paginated history
// loaded, the paged messages would appear in the wrong position relative
// to the new streaming turn (pagedMessages are newer pages, so they'd end
// up after the streaming turn). Reset paged state so ordering is correct
// during streaming; the user can reload history afterward if needed.
if (forwardPaginated && pagedMessages.length > 0) {
resetPaged();
}
if (files && files.length > 0) {
setIsUploadingFiles(true);
try {
@@ -398,6 +416,7 @@ export function useCopilotPage() {
hasMoreMessages: hasMore,
isLoadingMore,
loadMore,
forwardPaginated,
// Mobile drawer
isMobile,
isDrawerOpen,

View File

@@ -9,7 +9,11 @@ import {
interface UseLoadMoreMessagesArgs {
sessionId: string | null;
initialOldestSequence: number | null;
initialNewestSequence: number | null;
initialHasMore: boolean;
/** True when the initial page was loaded from sequence 0 forward (completed
* sessions). False when loaded newest-first (active sessions). */
forwardPaginated: boolean;
/** Raw messages from the initial page, used for cross-page tool output matching. */
initialPageRawMessages: unknown[];
}
@@ -20,7 +24,9 @@ const MAX_OLDER_MESSAGES = 2000;
export function useLoadMoreMessages({
sessionId,
initialOldestSequence,
initialNewestSequence,
initialHasMore,
forwardPaginated,
initialPageRawMessages,
}: UseLoadMoreMessagesArgs) {
// Accumulated raw messages from all extra pages (ascending order).
@@ -30,6 +36,9 @@ export function useLoadMoreMessages({
const [oldestSequence, setOldestSequence] = useState<number | null>(
initialOldestSequence,
);
const [newestSequence, setNewestSequence] = useState<number | null>(
initialNewestSequence,
);
const [hasMore, setHasMore] = useState(initialHasMore);
const [isLoadingMore, setIsLoadingMore] = useState(false);
const isLoadingMoreRef = useRef(false);
@@ -37,7 +46,9 @@ export function useLoadMoreMessages({
// Epoch counter to discard stale loadMore responses after a reset
const epochRef = useRef(0);
// Track the sessionId and initial cursor to reset state on change
const prevSessionIdRef = useRef(sessionId);
const prevInitialOldestRef = useRef(initialOldestSequence);
// Sync initial values from parent when they change.
//
@@ -60,8 +71,10 @@ export function useLoadMoreMessages({
if (prevSessionIdRef.current !== sessionId) {
// Session changed — full reset
prevSessionIdRef.current = sessionId;
prevInitialOldestRef.current = initialOldestSequence;
setPagedRawMessages([]);
setOldestSequence(initialOldestSequence);
setNewestSequence(initialNewestSequence);
setHasMore(initialHasMore);
setIsLoadingMore(false);
isLoadingMoreRef.current = false;
@@ -70,23 +83,34 @@ export function useLoadMoreMessages({
return;
}
prevInitialOldestRef.current = initialOldestSequence;
// If we haven't paged yet, mirror the parent so the first
// `loadMore` starts from the correct cursor.
if (pagedRawMessages.length === 0) {
setOldestSequence(initialOldestSequence);
// Only regress the forward cursor if we haven't paged ahead yet —
// otherwise a parent refetch would reset a cursor we already advanced.
setNewestSequence((prev) =>
prev !== null && prev > (initialNewestSequence ?? -1)
? prev
: initialNewestSequence,
);
setHasMore(initialHasMore);
}
}, [sessionId, initialOldestSequence, initialHasMore]);
}, [sessionId, initialOldestSequence, initialNewestSequence, initialHasMore]);
// Convert all accumulated raw messages in one pass so tool outputs
// are matched across inter-page boundaries.
// Include initial page tool outputs so older paged pages can match
// tool calls whose outputs landed in the initial page.
// For backward pagination only: include initial page tool outputs so older
// paged pages can match tool calls whose outputs landed in the initial page.
// For forward pagination this is unnecessary — tool calls in newer paged
// pages cannot have their outputs in the older initial page.
const pagedMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
useMemo(() => {
if (!sessionId || pagedRawMessages.length === 0) return [];
const extraToolOutputs =
initialPageRawMessages.length > 0
!forwardPaginated && initialPageRawMessages.length > 0
? extractToolOutputsFromRaw(initialPageRawMessages)
: undefined;
return convertChatSessionMessagesToUiMessages(
@@ -94,20 +118,22 @@ export function useLoadMoreMessages({
pagedRawMessages,
{ isComplete: true, extraToolOutputs },
).messages;
}, [sessionId, pagedRawMessages, initialPageRawMessages]);
}, [sessionId, pagedRawMessages, initialPageRawMessages, forwardPaginated]);
async function loadMore() {
if (!sessionId || !hasMore || isLoadingMoreRef.current) return;
if (oldestSequence === null) return;
const cursor = forwardPaginated ? newestSequence : oldestSequence;
if (cursor === null) return;
const requestEpoch = epochRef.current;
isLoadingMoreRef.current = true;
setIsLoadingMore(true);
try {
const response = await getV2GetSession(sessionId, {
limit: 50,
before_sequence: oldestSequence,
});
const params = forwardPaginated
? { limit: 50, after_sequence: cursor }
: { limit: 50, before_sequence: cursor };
const response = await getV2GetSession(sessionId, params);
// Discard response if session/pagination was reset while awaiting
if (epochRef.current !== requestEpoch) return;
@@ -126,20 +152,40 @@ export function useLoadMoreMessages({
consecutiveErrorsRef.current = 0;
const newRaw = (response.data.messages ?? []) as unknown[];
const estimatedTotal = pagedRawMessages.length + newRaw.length;
setPagedRawMessages((prev) => {
const merged = [...newRaw, ...prev];
// Forward: append to end. Backward: prepend to start.
const merged = forwardPaginated
? [...prev, ...newRaw]
: [...newRaw, ...prev];
if (merged.length > MAX_OLDER_MESSAGES) {
return merged.slice(merged.length - MAX_OLDER_MESSAGES);
// Backward: discard the oldest (front) items — the user has scrolled
// far back and we shed the furthest history.
// Forward: discard the newest (tail) items, keeping the beginning of
// the conversation the user is here to read; the cursor still
// advances past the shed items, so the sentinel can keep fetching.
return forwardPaginated
? merged.slice(0, MAX_OLDER_MESSAGES)
: merged.slice(merged.length - MAX_OLDER_MESSAGES);
}
return merged;
});
// Note (backward mode): after truncation, oldest_sequence may reference
// a dropped message. This is safe because backward mode also sets
// hasMore=false below once the cap is reached, preventing further loads
// with the stale cursor.
setOldestSequence(response.data.oldest_sequence ?? null);
if (estimatedTotal >= MAX_OLDER_MESSAGES) {
if (forwardPaginated) {
setNewestSequence(response.data.newest_sequence ?? null);
} else {
setOldestSequence(response.data.oldest_sequence ?? null);
}
const totalAfterMerge = newRaw.length + pagedRawMessages.length;
if (forwardPaginated) {
// Forward: truncation sheds the newest tail but the cursor
// (newestSequence) still advances, so the sentinel can keep
// fetching. Only stop when the server reports no more messages.
setHasMore(!!response.data.has_more_messages);
} else if (totalAfterMerge >= MAX_OLDER_MESSAGES) {
// Backward: we've accumulated MAX_OLDER_MESSAGES of history —
// stop to avoid unbounded memory growth.
setHasMore(false);
} else {
setHasMore(!!response.data.has_more_messages);
@@ -159,5 +205,22 @@ export function useLoadMoreMessages({
}
}
return { pagedMessages, hasMore, isLoadingMore, loadMore };
function resetPaged() {
setPagedRawMessages([]);
setOldestSequence(initialOldestSequence);
setNewestSequence(initialNewestSequence);
// Set hasMore=false during the session-transition window so no loadMore
// fires with forward pagination (after_sequence) on the now-active session.
// The useEffect will restore hasMore from the parent after the refetch
// completes and forwardPaginated switches to false.
setHasMore(false);
// Clear the loading state so the spinner doesn't stay stuck if a loadMore
// was in flight when resetPaged was called.
setIsLoadingMore(false);
isLoadingMoreRef.current = false;
consecutiveErrorsRef.current = 0;
epochRef.current += 1;
}
return { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged };
}
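The truncation direction the hook applies at MAX_OLDER_MESSAGES can be shown in isolation (cap shrunk to 3 for brevity; this is a sketch of the slicing logic above, not the hook's actual internals):

```typescript
const MAX = 3; // stands in for MAX_OLDER_MESSAGES (2000 in the hook)

function truncate(merged: number[], forwardPaginated: boolean): number[] {
  if (merged.length <= MAX) return merged;
  // Forward: keep the beginning of the conversation, shed the newest tail.
  // Backward: keep the newest items, shed the oldest head.
  return forwardPaginated
    ? merged.slice(0, MAX)
    : merged.slice(merged.length - MAX);
}

console.log(truncate([1, 2, 3, 4, 5], true));  // [1, 2, 3]
console.log(truncate([1, 2, 3, 4, 5], false)); // [3, 4, 5]
```

This is the direction bug fixed in commit b7d5a59f9d: the old forward-mode code used `slice(-MAX)`, which discarded the initial prompt instead of the tail.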

View File

@@ -2,17 +2,14 @@ import { Navbar } from "@/components/layout/Navbar/Navbar";
import { NetworkStatusMonitor } from "@/services/network-status/NetworkStatusMonitor";
import { ReactNode } from "react";
import { AdminImpersonationBanner } from "./admin/components/AdminImpersonationBanner";
import { AutoPilotBridgeProvider } from "@/contexts/AutoPilotBridgeContext";
export default function PlatformLayout({ children }: { children: ReactNode }) {
return (
<AutoPilotBridgeProvider>
<main className="flex h-screen w-full flex-col">
<NetworkStatusMonitor />
<Navbar />
<AdminImpersonationBanner />
<section className="flex-1">{children}</section>
</main>
</AutoPilotBridgeProvider>
<main className="flex h-screen w-full flex-col">
<NetworkStatusMonitor />
<Navbar />
<AdminImpersonationBanner />
<section className="flex-1">{children}</section>
</main>
);
}

View File

@@ -137,10 +137,8 @@ describe("LibraryPage", () => {
user_id: "test-user",
name: "Work Agents",
agent_count: 3,
subfolder_count: 0,
color: null,
icon: null,
parent_id: null,
created_at: new Date(),
updated_at: new Date(),
},
@@ -149,10 +147,8 @@ describe("LibraryPage", () => {
user_id: "test-user",
name: "Personal",
agent_count: 1,
subfolder_count: 0,
color: null,
icon: null,
parent_id: null,
created_at: new Date(),
updated_at: new Date(),
},
@@ -162,14 +158,12 @@ describe("LibraryPage", () => {
render(<LibraryPage />);
await waitForAgentsToLoad();
expect(await screen.findByText("Work Agents")).toBeDefined();
expect(screen.getByText("Personal")).toBeDefined();
expect(screen.getAllByTestId("library-folder")).toHaveLength(2);
});
test("shows See tasks link on agent card", async () => {
test("shows See runs link on agent card", async () => {
setupHandlers({
agents: [makeAgent({ name: "Linked Agent", can_access_graph: true })],
});
@@ -178,7 +172,7 @@ describe("LibraryPage", () => {
await screen.findByText("Linked Agent");
const runLinks = screen.getAllByText("See tasks");
const runLinks = screen.getAllByText("See runs");
expect(runLinks.length).toBeGreaterThan(0);
});
@@ -196,7 +190,7 @@ describe("LibraryPage", () => {
expect(importButtons.length).toBeGreaterThan(0);
});
test("renders running agent card when execution is active", async () => {
test("renders Jump Back In when there is an active execution", async () => {
const agent = makeAgent({
id: "lib-1",
graph_id: "g-1",
@@ -224,6 +218,6 @@ describe("LibraryPage", () => {
render(<LibraryPage />);
expect(await screen.findByText("Running Agent")).toBeDefined();
expect(await screen.findByText("Jump Back In")).toBeDefined();
});
});

View File

@@ -1,44 +0,0 @@
.glassPanel {
position: relative;
isolation: isolate;
}
.glassPanel::before {
content: "";
position: absolute;
inset: 0;
border-radius: inherit;
padding: 1px;
background: conic-gradient(
from var(--border-angle, 0deg),
rgba(129, 120, 228, 0.04),
rgba(129, 120, 228, 0.14),
rgba(168, 130, 255, 0.09),
rgba(129, 120, 228, 0.04),
rgba(99, 102, 241, 0.12),
rgba(129, 120, 228, 0.04)
);
-webkit-mask:
linear-gradient(#000 0 0) content-box,
linear-gradient(#000 0 0);
mask:
linear-gradient(#000 0 0) content-box,
linear-gradient(#000 0 0);
-webkit-mask-composite: xor;
mask-composite: exclude;
animation: rotate-border 6s linear infinite;
pointer-events: none;
z-index: -1;
}
@property --border-angle {
syntax: "<angle>";
initial-value: 0deg;
inherits: false;
}
@keyframes rotate-border {
to {
--border-angle: 360deg;
}
}

View File

@@ -1,36 +0,0 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useState } from "react";
import type { FleetSummary, AgentStatusFilter } from "../../types";
import { BriefingTabContent } from "./BriefingTabContent";
import { StatsGrid } from "./StatsGrid";
import styles from "./AgentBriefingPanel.module.css";
interface Props {
summary: FleetSummary;
agents: LibraryAgent[];
}
export function AgentBriefingPanel({ summary, agents }: Props) {
const [userTab, setUserTab] = useState<AgentStatusFilter | null>(null);
const activeTab: AgentStatusFilter =
userTab ?? (summary.running > 0 ? "running" : "all");
return (
<div
className={`${styles.glassPanel} min-h-[14.75rem] rounded-large bg-gradient-to-br from-indigo-50/30 via-white/90 to-purple-50/25 px-5 pb-5 pt-[1.125rem] shadow-sm backdrop-blur-md`}
>
<Text variant="h5">Agent Briefing</Text>
<div className="mt-4 space-y-5">
<StatsGrid
summary={summary}
activeTab={activeTab}
onTabChange={setUserTab}
/>
<BriefingTabContent activeTab={activeTab} agents={agents} />
</div>
</div>
);
}

View File

@@ -1,361 +0,0 @@
"use client";
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
import {
formatResetTime,
formatCents,
} from "@/app/(platform)/copilot/components/usageHelpers";
import { useResetRateLimit } from "@/app/(platform)/copilot/hooks/useResetRateLimit";
import { Button } from "@/components/atoms/Button/Button";
import { Badge } from "@/components/atoms/Badge/Badge";
import useCredits from "@/hooks/useCredits";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useSitrepItems } from "../SitrepItem/useSitrepItems";
import { SitrepItem } from "../SitrepItem/SitrepItem";
import { useAgentStatusMap } from "../../hooks/useAgentStatus";
import type { AgentStatusFilter } from "../../types";
import { Text } from "@/components/atoms/Text/Text";
import Link from "next/link";
import { useState } from "react";
interface Props {
activeTab: AgentStatusFilter;
agents: LibraryAgent[];
}
export function BriefingTabContent({ activeTab, agents }: Props) {
if (activeTab === "all") {
return <UsageSection />;
}
if (
activeTab === "running" ||
activeTab === "attention" ||
activeTab === "completed"
) {
return <ExecutionListSection activeTab={activeTab} agents={agents} />;
}
return <AgentListSection activeTab={activeTab} agents={agents} />;
}
function UsageSection() {
const { data: usage } = useGetV2GetCopilotUsage({
query: {
select: (res) => res.data as CoPilotUsageStatus,
refetchInterval: 30000,
staleTime: 10000,
},
});
const isBillingEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
const { credits, fetchCredits } = useCredits({ fetchInitialCredits: true });
const resetCost = usage?.reset_cost;
const hasInsufficientCredits =
credits !== null && resetCost != null && credits < resetCost;
if (!usage?.daily || !usage?.weekly) return null;
return (
<div className="py-2">
<div className="flex items-center gap-2">
<Text variant="h5" className="text-neutral-800">
Usage limits
</Text>
{usage.tier && (
<Badge variant="info" size="small" className="bg-[rgb(224,237,255)]">
{usage.tier.charAt(0) + usage.tier.slice(1).toLowerCase()} plan
</Badge>
)}
<div className="flex-1" />
{isBillingEnabled && (
<Link
href="/profile/credits"
className="text-sm text-blue-600 hover:underline"
>
Manage billing
</Link>
)}
</div>
<div className="mt-4 grid grid-cols-1 gap-6 sm:grid-cols-2">
{usage.daily.limit > 0 && (
<UsageMeter
label="Today"
used={usage.daily.used}
limit={usage.daily.limit}
resetsAt={usage.daily.resets_at}
/>
)}
{usage.weekly.limit > 0 && (
<UsageMeter
label="This week"
used={usage.weekly.used}
limit={usage.weekly.limit}
resetsAt={usage.weekly.resets_at}
/>
)}
</div>
<UsageFooter
usage={usage}
hasInsufficientCredits={hasInsufficientCredits}
onCreditChange={fetchCredits}
/>
</div>
);
}
const MAX_VISIBLE = 6;
function ExecutionListSection({
activeTab,
agents,
}: {
activeTab: AgentStatusFilter;
agents: LibraryAgent[];
}) {
const allItems = useSitrepItems(agents, 50);
const [showAll, setShowAll] = useState(false);
const filtered = allItems.filter((item) => {
if (activeTab === "running") return item.priority === "running";
if (activeTab === "attention") return item.priority === "error";
if (activeTab === "completed") return item.priority === "success";
return false;
});
if (filtered.length === 0) {
return <EmptyMessage tab={activeTab} />;
}
const visible = showAll ? filtered : filtered.slice(0, MAX_VISIBLE);
const hasMore = filtered.length > MAX_VISIBLE;
return (
<div>
<div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
{visible.map((item) => (
<SitrepItem key={item.id} item={item} />
))}
</div>
{hasMore && (
<div className="mt-3 flex justify-center">
<Button
variant="secondary"
size="small"
onClick={() => setShowAll(!showAll)}
>
{showAll ? "Collapse" : `Show all (${filtered.length})`}
</Button>
</div>
)}
</div>
);
}
const TAB_STATUS_LABEL: Record<string, string> = {
listening: "Waiting for trigger event",
scheduled: "Has a scheduled run",
idle: "No recent activity",
};
function getAgentStatusLabel(tab: string, agent: LibraryAgent): string {
if (tab === "scheduled" && agent.next_scheduled_run) {
const diff = new Date(agent.next_scheduled_run).getTime() - Date.now();
const minutes = Math.round(diff / 60_000);
if (minutes <= 0) return "Scheduled to run soon";
if (minutes < 60) return `Scheduled to run in ${minutes}m`;
const hours = Math.round(minutes / 60);
if (hours < 24) return `Scheduled to run in ${hours}h`;
const days = Math.round(hours / 24);
return `Scheduled to run in ${days}d`;
}
return TAB_STATUS_LABEL[tab] ?? "";
}
function AgentListSection({
activeTab,
agents,
}: {
activeTab: AgentStatusFilter;
agents: LibraryAgent[];
}) {
const [showAll, setShowAll] = useState(false);
const statusMap = useAgentStatusMap(agents);
const filtered = agents.filter((agent) => {
const status = statusMap.get(agent.graph_id)?.status;
if (activeTab === "listening") return status === "listening";
if (activeTab === "scheduled") return status === "scheduled";
if (activeTab === "idle") return status === "idle";
return false;
});
if (filtered.length === 0) {
return <EmptyMessage tab={activeTab} />;
}
const status =
activeTab === "listening"
? ("listening" as const)
: activeTab === "scheduled"
? ("scheduled" as const)
: ("idle" as const);
const visible = showAll ? filtered : filtered.slice(0, MAX_VISIBLE);
const hasMore = filtered.length > MAX_VISIBLE;
return (
<div>
<div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
{visible.map((agent) => (
<SitrepItem
key={agent.id}
item={{
id: agent.id,
agentID: agent.id,
agentName: agent.name,
agentImageUrl: agent.image_url,
priority: status,
message: getAgentStatusLabel(activeTab, agent),
status,
}}
/>
))}
</div>
{hasMore && (
<div className="mt-3 flex justify-center">
<Button
variant="secondary"
size="small"
onClick={() => setShowAll(!showAll)}
>
{showAll ? "Collapse" : `Show all (${filtered.length})`}
</Button>
</div>
)}
</div>
);
}
function UsageFooter({
usage,
hasInsufficientCredits,
onCreditChange,
}: {
usage: CoPilotUsageStatus;
hasInsufficientCredits: boolean;
onCreditChange?: () => void;
}) {
const isDailyExhausted =
usage.daily.limit > 0 && usage.daily.used >= usage.daily.limit;
const isWeeklyExhausted =
usage.weekly.limit > 0 && usage.weekly.used >= usage.weekly.limit;
const resetCost = usage.reset_cost ?? 0;
const { resetUsage, isPending } = useResetRateLimit({ onCreditChange });
const showReset =
isDailyExhausted &&
!isWeeklyExhausted &&
resetCost > 0 &&
!hasInsufficientCredits;
const showAddCredits =
isDailyExhausted && !isWeeklyExhausted && hasInsufficientCredits;
if (!showReset && !showAddCredits) return null;
return (
<div className="mt-4 flex items-center gap-3">
{showReset && (
<Button
variant="primary"
size="small"
onClick={() => resetUsage()}
loading={isPending}
>
{isPending
? "Resetting..."
: `Reset daily limit for ${formatCents(resetCost)}`}
</Button>
)}
{showAddCredits && (
<Link
href="/profile/credits"
className="inline-flex items-center justify-center rounded-md bg-primary px-3 py-1.5 text-sm font-medium text-primary-foreground hover:bg-primary/90"
>
Add credits to reset
</Link>
)}
</div>
);
}
function UsageMeter({
label,
used,
limit,
resetsAt,
}: {
label: string;
used: number;
limit: number;
resetsAt: Date | string;
}) {
if (limit <= 0) return null;
const rawPercent = (used / limit) * 100;
const percent = Math.min(100, Math.round(rawPercent));
const isHigh = percent >= 80;
const percentLabel =
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
return (
<div className="flex flex-col gap-2">
<div className="flex items-baseline justify-between">
<Text variant="body-medium" className="text-neutral-700">
{label}
</Text>
<Text variant="body" className="tabular-nums text-neutral-500">
{percentLabel}
</Text>
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
isHigh ? "bg-orange-500" : "bg-blue-500"
}`}
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
/>
</div>
<div className="flex items-baseline justify-between">
<Text variant="small" className="tabular-nums text-neutral-500">
{used.toLocaleString()} / {limit.toLocaleString()}
</Text>
<Text variant="small" className="text-neutral-400">
Resets {formatResetTime(resetsAt)}
</Text>
</div>
</div>
);
}
const EMPTY_MESSAGES: Record<string, string> = {
running: "No agents running right now",
attention: "No agents that need attention",
completed: "No recently completed runs",
listening: "No agents listening for events",
scheduled: "No agents with scheduled runs",
idle: "No idle agents",
};
function EmptyMessage({ tab }: { tab: AgentStatusFilter }) {
return (
<div className="flex items-center justify-center pt-4">
<Text variant="body-medium" className="text-zinc-600">
{EMPTY_MESSAGES[tab] ?? "No agents in this category"}
</Text>
</div>
);
}
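The `UsageMeter` above rounds and clamps its percentage so that tiny-but-nonzero usage still renders a sliver of bar and a "<1% used" label instead of an empty track. A standalone sketch of that handling (the `usagePercent` name is ours, not from the component):

```typescript
// Sketch of UsageMeter's percent logic: round, clamp to 100, and
// special-case nonzero usage that rounds down to 0%.
function usagePercent(
  used: number,
  limit: number,
): { percent: number; label: string; barWidth: number } {
  const percent = Math.min(100, Math.round((used / limit) * 100));
  const label = used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
  // The bar never collapses to zero width while there is any usage at all.
  const barWidth = Math.max(used > 0 ? 1 : 0, percent);
  return { percent, label, barWidth };
}
```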

View File

@@ -1,102 +0,0 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import { OverflowText } from "@/components/atoms/OverflowText/OverflowText";
import { Emoji } from "@/components/atoms/Emoji/Emoji";
import { cn } from "@/lib/utils";
import type { FleetSummary, AgentStatusFilter } from "../../types";
interface Props {
summary: FleetSummary;
activeTab: AgentStatusFilter;
onTabChange: (tab: AgentStatusFilter) => void;
}
const TILES: {
label: string;
key: keyof FleetSummary;
format?: (v: number) => string;
filter: AgentStatusFilter;
emoji: string;
color: string;
}[] = [
{
label: "Spent this month",
key: "monthlySpend",
format: (v) => `$${v.toLocaleString()}`,
filter: "all",
emoji: "💵",
color: "text-zinc-700",
},
{
label: "Running now",
key: "running",
filter: "running",
emoji: "🚩",
color: "text-blue-600",
},
{
label: "Recently completed",
key: "completed",
filter: "completed",
emoji: "🗃️",
color: "text-green-600",
},
{
label: "Needs attention",
key: "error",
filter: "attention",
emoji: "⚠️",
color: "text-red-500",
},
{
label: "Scheduled",
key: "scheduled",
filter: "scheduled",
emoji: "📅",
color: "text-yellow-600",
},
{
label: "Idle",
key: "idle",
filter: "idle",
emoji: "💤",
color: "text-zinc-400",
},
];
export function StatsGrid({ summary, activeTab, onTabChange }: Props) {
return (
<div className="grid grid-cols-1 gap-3 min-[450px]:grid-cols-2 sm:grid-cols-3 lg:grid-cols-6">
{TILES.map((tile) => {
const rawValue = summary[tile.key];
const value = tile.format ? tile.format(rawValue) : rawValue;
const isActive = activeTab === tile.filter;
return (
<button
key={tile.label}
type="button"
onClick={() => onTabChange(tile.filter)}
className={cn(
"flex min-w-0 flex-col gap-1 rounded-medium border p-3 text-left shadow-md transition-all hover:shadow-lg",
isActive
? "border-zinc-900 bg-zinc-50"
: "border-zinc-100 bg-white",
)}
>
<div className="flex min-w-0 items-center gap-1.5">
<Emoji text={tile.emoji} size={18} />
<OverflowText
value={tile.label}
variant="body"
className="text-zinc-800"
/>
</div>
<Text variant="h4">{value}</Text>
</button>
);
})}
</div>
);
}

View File

@@ -1,52 +0,0 @@
"use client";
import type { SelectOption } from "@/components/atoms/Select/Select";
import { Select } from "@/components/atoms/Select/Select";
import { FunnelIcon } from "@phosphor-icons/react";
import type { AgentStatusFilter, FleetSummary } from "../../types";
interface Props {
value: AgentStatusFilter;
onChange: (value: AgentStatusFilter) => void;
summary: FleetSummary;
}
function buildOptions(summary: FleetSummary): SelectOption[] {
return [
{ value: "all", label: "All Agents" },
{ value: "running", label: `Running (${summary.running})` },
{ value: "attention", label: `Needs Attention (${summary.error})` },
{ value: "listening", label: `Listening (${summary.listening})` },
{ value: "scheduled", label: `Scheduled (${summary.scheduled})` },
{ value: "idle", label: `Idle / Stale (${summary.idle})` },
{ value: "healthy", label: "Healthy" },
];
}
export function AgentFilterMenu({ value, onChange, summary }: Props) {
function handleChange(val: string) {
onChange(val as AgentStatusFilter);
}
const options = buildOptions(summary);
return (
<div className="flex items-center" data-testid="agent-filter-dropdown">
<span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
filter
</span>
<FunnelIcon className="ml-1 h-4 w-4 sm:hidden" />
<Select
id="agent-status-filter"
label="Filter agents"
hideLabel
value={value}
onValueChange={handleChange}
options={options}
size="small"
className="ml-1 w-fit border-none !bg-transparent text-sm underline underline-offset-4 shadow-none"
wrapperClassName="mb-0"
/>
</div>
);
}

View File

@@ -1,66 +0,0 @@
"use client";
import {
EyeIcon,
ArrowsClockwiseIcon,
MonitorPlayIcon,
} from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { useRouter } from "next/navigation";
import type { AgentStatus } from "../../types";
interface Props {
status: AgentStatus;
agentID: string;
executionID?: string;
className?: string;
}
export function ContextualActionButton({
status,
agentID,
executionID,
className,
}: Props) {
const router = useRouter();
const config = ACTION_CONFIG[status];
if (!config) return null;
const Icon = config.icon;
function handleClick(e: React.MouseEvent) {
e.preventDefault();
e.stopPropagation();
const params = new URLSearchParams();
if (executionID) params.set("activeItem", executionID);
const query = params.toString();
router.push(`/library/agents/${agentID}${query ? `?${query}` : ""}`);
}
return (
<button
type="button"
onClick={handleClick}
className={cn(
"inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800",
className,
)}
>
<Icon size={12} className="shrink-0" />
{config.label}
</button>
);
}
const ACTION_CONFIG: Record<
AgentStatus,
{ label: string; icon: typeof EyeIcon }
> = {
error: { label: "View error", icon: EyeIcon },
listening: { label: "Reconnect", icon: ArrowsClockwiseIcon },
running: { label: "Watch live", icon: MonitorPlayIcon },
idle: { label: "View", icon: EyeIcon },
scheduled: { label: "View", icon: EyeIcon },
};

View File

@@ -0,0 +1,46 @@
"use client";
import { ArrowRight, Lightning } from "@phosphor-icons/react";
import NextLink from "next/link";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { useJumpBackIn } from "./useJumpBackIn";
export function JumpBackIn() {
const { execution, isLoading } = useJumpBackIn();
if (isLoading || !execution) {
return null;
}
const href = execution.libraryAgentId
? `/library/agents/${execution.libraryAgentId}?activeTab=runs&activeItem=${execution.id}`
: "#";
return (
<div className="rounded-large bg-gradient-to-r from-zinc-200 via-zinc-200/60 to-indigo-200/50 p-[1px]">
<div className="flex items-center justify-between rounded-large bg-[#F6F7F8] px-5 py-4">
<div className="flex items-center gap-3">
<div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
<Lightning size={18} weight="fill" className="text-white" />
</div>
<div className="flex flex-col">
<Text variant="small" className="text-zinc-500">
{execution.statusLabel} · {execution.duration}
</Text>
<Text variant="body-medium" className="text-zinc-900">
{execution.agentName}
</Text>
</div>
</div>
<NextLink href={href}>
<Button variant="secondary" size="small" className="gap-1.5">
Jump Back In
<ArrowRight size={16} />
</Button>
</NextLink>
</div>
</div>
);
}

View File

@@ -0,0 +1,82 @@
"use client";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { okData } from "@/app/api/helpers";
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
import { useMemo } from "react";
function isActive(status: AgentExecutionStatus) {
return (
status === AgentExecutionStatus.RUNNING ||
status === AgentExecutionStatus.QUEUED ||
status === AgentExecutionStatus.REVIEW
);
}
function formatDuration(startedAt: Date | string | null | undefined): string {
if (!startedAt) return "";
const start = new Date(startedAt);
if (isNaN(start.getTime())) return "";
const ms = Date.now() - start.getTime();
if (ms < 0) return "";
const sec = Math.floor(ms / 1000);
if (sec < 5) return "a few seconds";
if (sec < 60) return `${sec}s`;
const min = Math.floor(sec / 60);
if (min < 60) return `${min}m ${sec % 60}s`;
const hr = Math.floor(min / 60);
return `${hr}h ${min % 60}m`;
}
function getStatusLabel(status: AgentExecutionStatus) {
if (status === AgentExecutionStatus.RUNNING) return "Running";
if (status === AgentExecutionStatus.QUEUED) return "Queued";
if (status === AgentExecutionStatus.REVIEW) return "Awaiting approval";
return "";
}
export function useJumpBackIn() {
const { data: executions, isLoading: executionsLoading } =
useGetV1ListAllExecutions({
query: { select: okData },
});
const { agentInfoMap, isRefreshing: agentsLoading } = useLibraryAgents();
const activeExecution = useMemo(() => {
if (!executions) return null;
const active = executions
.filter((e) => isActive(e.status))
.sort((a, b) => {
const aTime = a.started_at ? new Date(a.started_at).getTime() : 0;
const bTime = b.started_at ? new Date(b.started_at).getTime() : 0;
return bTime - aTime;
});
return active[0] ?? null;
}, [executions]);
const enriched = useMemo(() => {
if (!activeExecution) return null;
const info = agentInfoMap.get(activeExecution.graph_id);
return {
id: activeExecution.id,
agentName: info?.name ?? "Unknown Agent",
libraryAgentId: info?.library_agent_id,
status: activeExecution.status,
statusLabel: getStatusLabel(activeExecution.status),
duration: formatDuration(activeExecution.started_at),
};
}, [activeExecution, agentInfoMap]);
return {
execution: enriched,
isLoading: executionsLoading || agentsLoading,
};
}
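`formatDuration` above buckets elapsed time into human-readable units. The same bucketing as a pure sketch, refactored to take elapsed milliseconds directly so it is testable without `Date.now()` (the `elapsedLabel` name is ours):

```typescript
// Same thresholds as formatDuration, minus the Date parsing:
// <5s → "a few seconds", <60s → "Ns", <60m → "Nm Ss", else → "Nh Mm".
function elapsedLabel(ms: number): string {
  if (ms < 0) return "";
  const sec = Math.floor(ms / 1000);
  if (sec < 5) return "a few seconds";
  if (sec < 60) return `${sec}s`;
  const min = Math.floor(sec / 60);
  if (min < 60) return `${min}m ${sec % 60}s`;
  const hr = Math.floor(min / 60);
  return `${hr}h ${min % 60}m`;
}
```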

View File

@@ -8,7 +8,7 @@ interface Props {
export function LibraryActionHeader({ setSearchTerm }: Props) {
return (
<>
<div className="mb-7 hidden items-center justify-center gap-4 md:flex">
<div className="mb-[32px] hidden items-center justify-center gap-4 md:flex">
<LibrarySearchBar setSearchTerm={setSearchTerm} />
<LibraryImportDialog />
</div>

View File

@@ -1,40 +1,29 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import { EyeIcon, ChatCircleDotsIcon } from "@phosphor-icons/react";
import { CaretCircleRightIcon } from "@phosphor-icons/react";
import Image from "next/image";
import NextLink from "next/link";
import { useRouter } from "next/navigation";
import { motion } from "framer-motion";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { cn } from "@/lib/utils";
import Avatar, {
AvatarFallback,
AvatarImage,
} from "@/components/atoms/Avatar/Avatar";
import { Link } from "@/components/atoms/Link/Link";
import { AgentCardMenu } from "./components/AgentCardMenu";
import { FavoriteButton } from "./components/FavoriteButton";
import { useLibraryAgentCard } from "./useLibraryAgentCard";
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
import { StatusBadge } from "../StatusBadge/StatusBadge";
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
import type { AgentStatusInfo } from "../../types";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
interface Props {
agent: LibraryAgent;
statusInfo: AgentStatusInfo;
draggable?: boolean;
}
export function LibraryAgentCard({
agent,
statusInfo,
draggable = true,
}: Props) {
const { id, name, image_url } = agent;
const router = useRouter();
export function LibraryAgentCard({ agent, draggable = true }: Props) {
const { id, name, graph_id, can_access_graph, image_url } = agent;
const { triggerFavoriteAnimation } = useFavoriteAnimation();
function handleDragStart(e: React.DragEvent<HTMLDivElement>) {
@@ -42,14 +31,18 @@ export function LibraryAgentCard({
e.dataTransfer.effectAllowed = "move";
}
const { isFavorite, handleToggleFavorite } = useLibraryAgentCard({
const {
isFromMarketplace,
isFavorite,
profile,
creator_image_url,
handleToggleFavorite,
} = useLibraryAgentCard({
agent,
onFavoriteAdd: triggerFavoriteAnimation,
});
const hasError = statusInfo.status === "error";
const card = (
return (
<div
draggable={draggable}
onDragStart={handleDragStart}
@@ -59,10 +52,7 @@ export function LibraryAgentCard({
layoutId={`agent-card-${id}`}
data-testid="library-agent-card"
data-agent-id={id}
className={cn(
"group relative inline-flex h-auto min-h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border bg-white hover:shadow-md",
hasError ? "border-red-400" : "border-zinc-100",
)}
className="group relative inline-flex h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border border-zinc-100 bg-white hover:shadow-md"
transition={{
type: "spring",
damping: 25,
@@ -71,10 +61,23 @@ export function LibraryAgentCard({
style={{ willChange: "transform" }}
>
<NextLink href={`/library/agents/${id}`} className="flex-shrink-0">
<div className="relative flex items-center gap-3 pl-2 pr-4 pt-3">
<StatusBadge status={statusInfo.status} />
<Text variant="small" className="text-zinc-400">
{statusInfo.totalRuns} tasks
<div className="relative flex items-center gap-2 px-4 pt-3">
<Avatar className="h-4 w-4 rounded-full">
<AvatarImage
src={
isFromMarketplace
? creator_image_url || "/avatar-placeholder.png"
: profile?.avatar_url || "/avatar-placeholder.png"
}
alt={`${name} creator avatar`}
/>
<AvatarFallback size={48}>{name.charAt(0)}</AvatarFallback>
</Avatar>
<Text
variant="small-medium"
className="uppercase tracking-wide text-zinc-400"
>
{isFromMarketplace ? "FROM MARKETPLACE" : "Built by you"}
</Text>
</div>
</NextLink>
@@ -86,7 +89,7 @@ export function LibraryAgentCard({
<AgentCardMenu agent={agent} />
<div className="flex w-full flex-1 flex-col px-4 pb-2">
<NextLink
<Link
href={`/library/agents/${id}`}
className="flex w-full items-start justify-between gap-2 no-underline hover:no-underline focus:ring-0"
>
@@ -123,52 +126,30 @@ export function LibraryAgentCard({
className="flex-shrink-0 rounded-small object-cover"
/>
)}
</NextLink>
</Link>
<div className="mt-4 flex w-full items-center justify-end gap-1 border-t border-zinc-100 pb-0 pt-2">
<button
type="button"
onClick={() => router.push(`/library/agents/${id}`)}
<div className="mt-auto flex w-full justify-start gap-6 border-t border-zinc-100 pb-1 pt-3">
<Link
href={`/library/agents/${id}`}
data-testid="library-agent-card-see-runs-link"
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
className="flex items-center gap-1 text-[13px]"
>
<EyeIcon size={14} className="shrink-0" />
See tasks
</button>
<ContextualActionButton
status={statusInfo.status}
agentID={id}
executionID={statusInfo.activeExecutionID ?? undefined}
/>
<button
type="button"
onClick={() => {
const prompt = encodeURIComponent(
`Tell me about ${name}, its current status, recent runs and how can I get the most out of it`,
);
router.push(`/copilot?autosubmit=true#prompt=${prompt}`);
}}
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
>
<ChatCircleDotsIcon size={14} className="shrink-0" />
Chat
</button>
See runs <CaretCircleRightIcon size={20} />
</Link>
{can_access_graph && (
<Link
href={`/build?flowID=${graph_id}`}
data-testid="library-agent-card-open-in-builder-link"
className="flex items-center gap-1 text-[13px]"
isExternal
>
Open in builder <CaretCircleRightIcon size={20} />
</Link>
)}
</div>
</div>
</motion.div>
</div>
);
if (hasError && statusInfo.lastError) {
return (
<Tooltip>
<TooltipTrigger asChild>{card}</TooltipTrigger>
<TooltipContent className="max-w-xs text-red-600">
{statusInfo.lastError}
</TooltipContent>
</Tooltip>
);
}
return card;
}

View File

@@ -169,7 +169,6 @@ export function AgentCardMenu({ agent }: AgentCardMenuProps) {
href={`/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`}
target="_blank"
className="flex items-center gap-2"
data-testid="library-agent-card-open-in-builder-link"
onClick={(e) => e.stopPropagation()}
>
Edit agent

View File

@@ -1,7 +1,6 @@
"use client";
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
import { LibraryAgentCard } from "../LibraryAgentCard/LibraryAgentCard";
@@ -17,11 +16,8 @@ import {
} from "framer-motion";
import { LibraryFolderEditDialog } from "../LibraryFolderEditDialog/LibraryFolderEditDialog";
import { LibraryFolderDeleteDialog } from "../LibraryFolderDeleteDialog/LibraryFolderDeleteDialog";
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
import { LibraryTab } from "../../types";
import { useLibraryAgentList } from "./useLibraryAgentList";
import { AgentBriefingPanel } from "../AgentBriefingPanel/AgentBriefingPanel";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useAgentStatusMap, getAgentStatus } from "../../hooks/useAgentStatus";
// cancels the current spring and starts a new one from current state.
const containerVariants = {
@@ -74,10 +70,6 @@ interface Props {
tabs: LibraryTab[];
activeTab: string;
onTabChange: (tabId: string) => void;
statusFilter?: AgentStatusFilter;
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
fleetSummary?: FleetSummary;
briefingAgents?: LibraryAgent[];
}
export function LibraryAgentList({
@@ -89,12 +81,7 @@ export function LibraryAgentList({
tabs,
activeTab,
onTabChange,
statusFilter = "all",
onStatusFilterChange,
fleetSummary,
briefingAgents,
}: Props) {
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
const shouldReduceMotion = useReducedMotion();
const activeContainerVariants = shouldReduceMotion
? reducedContainerVariants
@@ -108,7 +95,7 @@ export function LibraryAgentList({
const {
isFavoritesTab,
agentLoading,
displayedCount,
allAgentsCount,
favoritesCount,
agents,
hasNextPage,
@@ -129,37 +116,18 @@ export function LibraryAgentList({
selectedFolderId,
onFolderSelect,
activeTab,
statusFilter,
});
const agentStatusMap = useAgentStatusMap(agents);
return (
<>
{isAgentBriefingEnabled &&
!selectedFolderId &&
fleetSummary &&
briefingAgents &&
briefingAgents.length > 0 && (
<div className="mb-4">
<AgentBriefingPanel
summary={fleetSummary}
agents={briefingAgents}
/>
</div>
)}
{!selectedFolderId && (
<LibrarySubSection
tabs={tabs}
activeTab={activeTab}
onTabChange={onTabChange}
allCount={displayedCount}
allCount={allAgentsCount}
favoritesCount={favoritesCount}
setLibrarySort={setLibrarySort}
statusFilter={statusFilter}
onStatusFilterChange={onStatusFilterChange}
fleetSummary={fleetSummary}
/>
)}
@@ -251,13 +219,7 @@ export function LibraryAgentList({
0.04,
}}
>
<LibraryAgentCard
agent={agent}
statusInfo={getAgentStatus(
agentStatusMap,
agent.graph_id,
)}
/>
<LibraryAgentCard agent={agent} />
</motion.div>
))}
</motion.div>

View File

@@ -21,12 +21,7 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
import { useFavoriteAgents } from "../../hooks/useFavoriteAgents";
import { getQueryClient } from "@/lib/react-query/queryClient";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useRef, useState } from "react";
import type { AgentStatusFilter } from "../../types";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
const FILTER_EXHAUST_THRESHOLD = 3;
import { useEffect, useRef, useState } from "react";
interface Props {
searchTerm: string;
@@ -34,7 +29,6 @@ interface Props {
selectedFolderId: string | null;
onFolderSelect: (folderId: string | null) => void;
activeTab: string;
statusFilter?: AgentStatusFilter;
}
export function useLibraryAgentList({
@@ -43,16 +37,12 @@ export function useLibraryAgentList({
selectedFolderId,
onFolderSelect,
activeTab,
statusFilter = "all",
}: Props) {
const isFavoritesTab = activeTab === "favorites";
const { toast } = useToast();
const stableQueryClient = getQueryClient();
const queryClient = useQueryClient();
const prevSortRef = useRef<LibraryAgentSort | null>(null);
const [consecutiveEmptyPages, setConsecutiveEmptyPages] = useState(0);
const prevFilteredLengthRef = useRef(0);
const prevAgentsLengthRef = useRef(0);
const [editingFolder, setEditingFolder] = useState<LibraryFolder | null>(
null,
@@ -209,90 +199,6 @@ export function useLibraryAgentList({
const showFolders = !isFavoritesTab;
const { data: executions } = useGetV1ListAllExecutions({
query: { select: okData },
});
const { activeGraphIds, errorGraphIds, completedGraphIds } = useMemo(() => {
const active = new Set<string>();
const errors = new Set<string>();
const completed = new Set<string>();
const cutoff = Date.now() - 72 * 60 * 60 * 1000;
for (const exec of executions ?? []) {
if (
exec.status === AgentExecutionStatus.RUNNING ||
exec.status === AgentExecutionStatus.QUEUED ||
exec.status === AgentExecutionStatus.REVIEW
) {
active.add(exec.graph_id);
}
const endedTs = exec.ended_at
? exec.ended_at instanceof Date
? exec.ended_at.getTime()
: new Date(String(exec.ended_at)).getTime()
: 0;
if (
(exec.status === AgentExecutionStatus.FAILED ||
exec.status === AgentExecutionStatus.TERMINATED) &&
endedTs > cutoff
) {
errors.add(exec.graph_id);
}
if (exec.status === AgentExecutionStatus.COMPLETED && endedTs > cutoff) {
completed.add(exec.graph_id);
}
}
return {
activeGraphIds: active,
errorGraphIds: errors,
completedGraphIds: completed,
};
}, [executions]);
const filteredAgents = filterAgentsByStatus(
agents,
statusFilter,
activeGraphIds,
errorGraphIds,
completedGraphIds,
);
useEffect(() => {
if (statusFilter === "all") {
setConsecutiveEmptyPages(0);
prevFilteredLengthRef.current = filteredAgents.length;
prevAgentsLengthRef.current = agents.length;
return;
}
if (agents.length > prevAgentsLengthRef.current) {
const newFilteredCount = filteredAgents.length;
const previousCount = prevFilteredLengthRef.current;
if (newFilteredCount > previousCount) {
setConsecutiveEmptyPages(0);
} else {
setConsecutiveEmptyPages((prev) => prev + 1);
}
}
prevAgentsLengthRef.current = agents.length;
prevFilteredLengthRef.current = filteredAgents.length;
}, [agents.length, filteredAgents.length, statusFilter]);
useEffect(() => {
setConsecutiveEmptyPages(0);
prevFilteredLengthRef.current = 0;
prevAgentsLengthRef.current = 0;
}, [statusFilter]);
const filteredExhausted =
statusFilter !== "all" && consecutiveEmptyPages >= FILTER_EXHAUST_THRESHOLD;
// When a filter is active, show the filtered count instead of the API total.
const displayedCount =
statusFilter === "all" ? allAgentsCount : filteredAgents.length;
function handleFolderDeleted() {
if (selectedFolderId === deletingFolder?.id) {
onFolderSelect(null);
@@ -304,10 +210,9 @@ export function useLibraryAgentList({
agentLoading,
agentCount,
allAgentsCount,
displayedCount,
favoritesCount: favoriteAgentsData.agentCount,
agents: filteredAgents,
hasNextPage: agentsHasNextPage && !filteredExhausted,
agents,
hasNextPage: agentsHasNextPage,
isFetchingNextPage: agentsIsFetchingNextPage,
fetchNextPage: agentsFetchNextPage,
foldersData,
@@ -321,46 +226,3 @@ export function useLibraryAgentList({
handleFolderDeleted,
};
}
function filterAgentsByStatus<
T extends {
graph_id: string;
has_external_trigger: boolean;
recommended_schedule_cron?: string | null;
},
>(
agents: T[],
statusFilter: AgentStatusFilter,
activeGraphIds: Set<string>,
errorGraphIds: Set<string>,
completedGraphIds: Set<string>,
): T[] {
if (statusFilter === "all") return agents;
return agents.filter((agent) => {
const isRunning = activeGraphIds.has(agent.graph_id);
const hasError = errorGraphIds.has(agent.graph_id);
if (statusFilter === "running") return isRunning;
if (statusFilter === "attention") return hasError && !isRunning;
if (statusFilter === "completed")
return completedGraphIds.has(agent.graph_id);
if (statusFilter === "listening")
return !isRunning && !hasError && agent.has_external_trigger;
if (statusFilter === "scheduled")
return (
!isRunning &&
!hasError &&
!agent.has_external_trigger &&
!!agent.recommended_schedule_cron
);
if (statusFilter === "idle")
return (
!isRunning &&
!hasError &&
!agent.has_external_trigger &&
!agent.recommended_schedule_cron
);
if (statusFilter === "healthy") return !hasError;
return true;
});
}
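The removed `filterAgentsByStatus` encodes a precedence among the mutually exclusive buckets: a running agent never shows as needing attention, and trigger-driven agents are never "scheduled" or "idle" (the "completed" and "healthy" filters are orthogonal and checked separately). A single-agent sketch of that precedence (`classify` is our name; it returns the first bucket the agent would match):

```typescript
type Bucket = "running" | "attention" | "listening" | "scheduled" | "idle";

// Mirrors the predicate order in filterAgentsByStatus above.
function classify(
  isRunning: boolean,
  hasError: boolean,
  hasExternalTrigger: boolean,
  hasScheduleCron: boolean,
): Bucket {
  if (isRunning) return "running"; // running outranks error
  if (hasError) return "attention";
  if (hasExternalTrigger) return "listening"; // triggers outrank schedules
  if (hasScheduleCron) return "scheduled";
  return "idle";
}
```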

View File

@@ -2,11 +2,14 @@
import { Text } from "@/components/atoms/Text/Text";
import { Button } from "@/components/atoms/Button/Button";
import { FolderIcon, FolderColor } from "./FolderIcon";
import {
FolderIcon,
FolderColor,
folderCardStyles,
resolveColor,
} from "./FolderIcon";
import { useState } from "react";
import { PencilSimpleIcon, TrashIcon } from "@phosphor-icons/react";
import type { AgentStatus } from "../../types";
import { StatusBadge } from "../StatusBadge/StatusBadge";
interface Props {
id: string;
@@ -18,8 +21,6 @@ interface Props {
onDelete?: () => void;
onAgentDrop?: (agentId: string, folderId: string) => void;
onClick?: () => void;
/** Worst status among child agents (optional, for status aggregation). */
worstStatus?: AgentStatus;
}
export function LibraryFolder({
@@ -32,10 +33,11 @@ export function LibraryFolder({
onDelete,
onAgentDrop,
onClick,
worstStatus,
}: Props) {
const [isHovered, setIsHovered] = useState(false);
const [isDragOver, setIsDragOver] = useState(false);
const resolvedColor = resolveColor(color);
const cardStyle = folderCardStyles[resolvedColor];
function handleDragOver(e: React.DragEvent<HTMLDivElement>) {
if (e.dataTransfer.types.includes("application/agent-id")) {
@@ -62,10 +64,10 @@ export function LibraryFolder({
<div
data-testid="library-folder"
data-folder-id={id}
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 shadow-sm backdrop-blur-md transition-all duration-200 hover:shadow-md ${
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 transition-all duration-200 hover:shadow-md ${
isDragOver
? "border-blue-400 bg-blue-50 ring-2 ring-blue-200"
: "border-indigo-200/40 bg-gradient-to-br from-indigo-50/40 via-white/70 to-purple-50/30"
: `${cardStyle.border} ${cardStyle.bg}`
}`}
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
@@ -74,7 +76,7 @@ export function LibraryFolder({
onDrop={handleDrop}
onClick={onClick}
>
<div className="flex w-full items-center justify-between gap-4">
<div className="flex w-full items-start justify-between gap-4">
{/* Left side - Folder name and agent count */}
<div className="flex flex-1 flex-col gap-2">
<Text
@@ -84,22 +86,17 @@ export function LibraryFolder({
>
{name}
</Text>
-<div className="flex items-center gap-2">
-<Text
-variant="small"
-className="text-zinc-500"
-data-testid="library-folder-agent-count"
->
-{agentCount} {agentCount === 1 ? "agent" : "agents"}
-</Text>
-{worstStatus && worstStatus !== "idle" && (
-<StatusBadge status={worstStatus} />
-)}
-</div>
+<Text
+variant="small"
+className="text-zinc-500"
+data-testid="library-folder-agent-count"
+>
+{agentCount} {agentCount === 1 ? "agent" : "agents"}
+</Text>
</div>
{/* Right side - Custom folder icon */}
-<div className="relative top-5 flex flex-shrink-0 items-center">
+<div className="flex-shrink-0">
<FolderIcon isOpen={isHovered} color={color} icon={icon} />
</div>
</div>
@@ -117,7 +114,7 @@ export function LibraryFolder({
e.stopPropagation();
onEdit?.();
}}
-className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
+className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
>
<PencilSimpleIcon className="h-4 w-4" />
</Button>
@@ -129,7 +126,7 @@ export function LibraryFolder({
e.stopPropagation();
onDelete?.();
}}
-className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
+className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
>
<TrashIcon className="h-4 w-4" />
</Button>


@@ -19,11 +19,11 @@ export function LibrarySortMenu({ setLibrarySort }: Props) {
const { handleSortChange } = useLibrarySortMenu({ setLibrarySort });
return (
<div className="flex items-center" data-testid="sort-by-dropdown">
-<span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
+<span className="hidden whitespace-nowrap text-sm sm:inline">
sort by
</span>
<Select onValueChange={handleSortChange}>
-<SelectTrigger className="!m-0 ml-1 w-fit space-x-1 border-none !bg-transparent px-[1rem] text-sm underline underline-offset-4 !shadow-none !ring-offset-transparent">
+<SelectTrigger className="ml-1 w-fit space-x-1 border-none px-0 text-sm underline underline-offset-4 shadow-none">
<ArrowDownNarrowWideIcon className="h-4 w-4 sm:hidden" />
<SelectValue placeholder="Last Modified" />
</SelectTrigger>


@@ -6,10 +6,9 @@ import {
} from "@/components/molecules/TabsLine/TabsLine";
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
-import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
+import { LibraryTab } from "../../types";
import LibraryFolderCreationDialog from "../LibraryFolderCreationDialog/LibraryFolderCreationDialog";
import { LibrarySortMenu } from "../LibrarySortMenu/LibrarySortMenu";
-import { AgentFilterMenu } from "../AgentFilterMenu/AgentFilterMenu";
interface Props {
tabs: LibraryTab[];
@@ -18,9 +17,6 @@ interface Props {
allCount: number;
favoritesCount: number;
setLibrarySort: (value: LibraryAgentSort) => void;
-statusFilter?: AgentStatusFilter;
-onStatusFilterChange?: (filter: AgentStatusFilter) => void;
-fleetSummary?: FleetSummary;
}
export function LibrarySubSection({
@@ -30,9 +26,6 @@ export function LibrarySubSection({
allCount,
favoritesCount,
setLibrarySort,
-statusFilter = "all",
-onStatusFilterChange,
-fleetSummary,
}: Props) {
const { registerFavoritesTabRef } = useFavoriteAnimation();
const favoritesRef = useRef<HTMLButtonElement>(null);
@@ -75,15 +68,8 @@ export function LibrarySubSection({
))}
</TabsLineList>
</TabsLine>
-<div className="relative top-1.5 hidden items-center gap-6 md:flex">
+<div className="hidden items-center gap-6 md:flex">
<LibraryFolderCreationDialog />
-{fleetSummary && onStatusFilterChange && (
-<AgentFilterMenu
-value={statusFilter}
-onChange={onStatusFilterChange}
-summary={fleetSummary}
-/>
-)}
<LibrarySortMenu setLibrarySort={setLibrarySort} />
</div>
</div>


@@ -1,17 +0,0 @@
.spinner {
aspect-ratio: 1;
border-radius: 50%;
background:
radial-gradient(farthest-side, currentColor 94%, #0000) top/3px 3px
no-repeat,
conic-gradient(#0000 30%, currentColor);
-webkit-mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
animation: spin 1s infinite linear;
}
@keyframes spin {
100% {
transform: rotate(1turn);
}
}


@@ -1,175 +0,0 @@
"use client";
import { OverflowText } from "@/components/atoms/OverflowText/OverflowText";
import { Text } from "@/components/atoms/Text/Text";
import {
WarningCircleIcon,
ClockCountdownIcon,
CheckCircleIcon,
ChatCircleDotsIcon,
EarIcon,
CalendarDotsIcon,
MoonIcon,
EyeIcon,
} from "@phosphor-icons/react";
import NextLink from "next/link";
import { cn } from "@/lib/utils";
import { useRouter } from "next/navigation";
import type { SitrepItemData, SitrepPriority } from "../../types";
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
import styles from "./SitrepItem.module.css";
interface Props {
item: SitrepItemData;
}
const PRIORITY_CONFIG: Record<
SitrepPriority,
{
icon?: typeof WarningCircleIcon;
color: string;
bg: string;
cssSpinner?: boolean;
}
> = {
error: {
icon: WarningCircleIcon,
color: "text-red-500",
bg: "bg-red-50",
},
running: {
color: "text-zinc-800",
bg: "",
cssSpinner: true,
},
stale: {
icon: ClockCountdownIcon,
color: "text-yellow-600",
bg: "bg-yellow-50",
},
success: {
icon: CheckCircleIcon,
color: "text-green-600",
bg: "bg-green-50",
},
listening: {
icon: EarIcon,
color: "text-purple-500",
bg: "bg-purple-50",
},
scheduled: {
icon: CalendarDotsIcon,
color: "text-yellow-600",
bg: "bg-yellow-50",
},
idle: {
icon: MoonIcon,
color: "text-zinc-400",
bg: "bg-zinc-100",
},
};
export function SitrepItem({ item }: Props) {
const config = PRIORITY_CONFIG[item.priority];
const router = useRouter();
function handleAskAutoPilot() {
const prompt = buildAutoPilotPrompt(item);
const encoded = encodeURIComponent(prompt);
router.push(`/copilot?autosubmit=true#prompt=${encoded}`);
}
return (
<div
className={cn(
"flex flex-col gap-2 rounded-medium border border-zinc-200/50 bg-transparent p-2 sm:flex-row sm:items-center sm:gap-3",
)}
>
<div className="flex min-w-0 flex-1 items-center gap-3">
{item.agentImageUrl ? (
<img
src={item.agentImageUrl}
alt={item.agentName}
className="h-6 w-6 flex-shrink-0 rounded-full object-cover"
/>
) : (
<div
className={cn(
"flex h-6 w-6 flex-shrink-0 items-center justify-center rounded-full",
config.bg,
)}
>
{config.cssSpinner ? (
<div
className={cn(
styles.spinner,
"h-[21px] w-[21px] text-zinc-800",
)}
/>
) : (
config.icon && (
<config.icon size={14} className={config.color} weight="fill" />
)
)}
</div>
)}
<div className="min-w-0 flex-1">
<Text variant="body-medium" className="leading-tight text-zinc-900">
{item.agentName}
</Text>
<OverflowText
value={item.message}
variant="small"
className="leading-tight text-zinc-500"
/>
</div>
</div>
<div className="flex flex-shrink-0 flex-wrap items-center justify-center gap-1.5 sm:flex-nowrap sm:justify-end">
{item.priority === "success" ? (
<NextLink
href={`/library/agents/${item.agentID}${item.executionID ? `?activeItem=${item.executionID}` : ""}`}
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
>
<EyeIcon size={14} className="shrink-0" />
See task
</NextLink>
) : (
<ContextualActionButton
status={item.status}
agentID={item.agentID}
executionID={item.executionID}
/>
)}
<button
type="button"
onClick={handleAskAutoPilot}
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
>
<ChatCircleDotsIcon size={14} className="shrink-0" />
Ask AutoPilot
</button>
</div>
</div>
);
}
function buildAutoPilotPrompt(item: SitrepItemData): string {
switch (item.priority) {
case "error":
return `What happened with ${item.agentName}? It says "${item.message}" — can you check the logs and tell me what to fix?`;
case "running":
return `Give me a status update on the ${item.agentName} run — what has it found so far?`;
case "stale":
return `${item.agentName} hasn't run recently. Should I keep it or update and re-run it?`;
case "success":
return `Show me what ${item.agentName} found in its last run — summarize the results and any key takeaways.`;
case "listening":
return `What is ${item.agentName} listening for? Give me a summary of its trigger configuration.`;
case "scheduled":
return `When is ${item.agentName} scheduled to run next?`;
case "idle":
return `${item.agentName} has been idle. Should I keep it or update and re-run it?`;
}
}
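For reference, the deep link built by handleAskAutoPilot carries the generated prompt in the URL fragment. A small self-contained sketch of that round trip (the route shape is taken from the component above; the agent name and failure message are hypothetical):

```typescript
// Prompt template copied from the "error" branch of buildAutoPilotPrompt,
// with a hypothetical agent name and message filled in.
const prompt = `What happened with My Agent? It says "rate limited" — can you check the logs and tell me what to fix?`;
const encoded = encodeURIComponent(prompt);
const href = `/copilot?autosubmit=true#prompt=${encoded}`;

// The fragment must survive an encode/decode round trip, including the
// quotes and dash inside the message.
const decoded = decodeURIComponent(encoded);
console.log(decoded === prompt); // true
console.log(href.startsWith("/copilot?autosubmit=true#prompt=")); // true
```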


@@ -1,34 +0,0 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import { ClockCounterClockwise } from "@phosphor-icons/react";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useSitrepItems } from "./useSitrepItems";
import { SitrepItem } from "./SitrepItem";
interface Props {
agents: LibraryAgent[];
maxItems?: number;
}
export function SitrepList({ agents, maxItems = 10 }: Props) {
const items = useSitrepItems(agents, maxItems);
if (items.length === 0) return null;
return (
<div>
<div className="mb-2 flex items-center gap-1.5">
<ClockCounterClockwise size={16} className="text-zinc-700" />
<Text variant="body-medium" className="text-zinc-700">
Recent tasks
</Text>
</div>
<div className="grid grid-cols-1 gap-1 lg:grid-cols-2">
{items.map((item) => (
<SitrepItem key={item.id} item={item} />
))}
</div>
</div>
);
}


@@ -1,198 +0,0 @@
"use client";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import { useMemo } from "react";
import type { SitrepItemData, SitrepPriority } from "../../types";
import {
isActive,
isFailed,
toEndTime,
endedAfter,
runningMessage,
SEVENTY_TWO_HOURS_MS,
} from "../../hooks/executionHelpers";
export function useSitrepItems(
agents: LibraryAgent[],
maxItems: number,
scheduledWithinMs?: number,
): SitrepItemData[] {
const { data: executions } = useGetV1ListAllExecutions({
query: { select: okData },
});
return useMemo(() => {
if (agents.length === 0) return [];
const graphIdToAgent = new Map(agents.map((a) => [a.graph_id, a]));
const agentExecutions = groupByAgent(executions ?? [], graphIdToAgent);
const items: SitrepItemData[] = [];
const coveredAgentIds = new Set<string>();
for (const [agent, execs] of agentExecutions) {
const item = buildSitrepFromExecutions(agent, execs);
if (item) {
items.push(item);
coveredAgentIds.add(agent.id);
}
}
for (const agent of agents) {
if (coveredAgentIds.has(agent.id)) continue;
const configItem = buildSitrepFromConfig(agent, scheduledWithinMs);
if (configItem) items.push(configItem);
}
const order: Record<SitrepPriority, number> = {
error: 0,
running: 1,
stale: 2,
success: 3,
listening: 4,
scheduled: 5,
idle: 6,
};
items.sort((a, b) => order[a.priority] - order[b.priority]);
return items.slice(0, maxItems);
}, [agents, executions, maxItems, scheduledWithinMs]);
}
function groupByAgent(
executions: GraphExecutionMeta[],
graphIdToAgent: Map<string, LibraryAgent>,
): Map<LibraryAgent, GraphExecutionMeta[]> {
const map = new Map<LibraryAgent, GraphExecutionMeta[]>();
for (const exec of executions) {
const agent = graphIdToAgent.get(exec.graph_id);
if (!agent) continue;
const list = map.get(agent);
if (list) {
list.push(exec);
} else {
map.set(agent, [exec]);
}
}
return map;
}
function buildSitrepFromExecutions(
agent: LibraryAgent,
executions: GraphExecutionMeta[],
): SitrepItemData | null {
const active = executions.find((e) => isActive(e.status));
if (active) {
return {
id: `${agent.id}-${active.id}`,
agentID: agent.id,
agentName: agent.name,
executionID: active.id,
priority: "running",
message:
active.stats?.activity_status ??
runningMessage(active.status, active.started_at),
status: "running",
};
}
const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
const recent = executions
.filter((e) => endedAfter(e, cutoff))
.sort((a, b) => toEndTime(b) - toEndTime(a));
const lastFailed = recent.find((e) => isFailed(e.status));
if (lastFailed) {
const errorMsg =
lastFailed.stats?.error ??
lastFailed.stats?.activity_status ??
"Execution failed";
return {
id: `${agent.id}-${lastFailed.id}`,
agentID: agent.id,
agentName: agent.name,
executionID: lastFailed.id,
priority: "error",
message: typeof errorMsg === "string" ? errorMsg : "Execution failed",
status: "error",
};
}
const lastCompleted = recent.find(
(e) => e.status === AgentExecutionStatus.COMPLETED,
);
if (lastCompleted) {
const summary =
lastCompleted.stats?.activity_status ?? "Completed successfully";
return {
id: `${agent.id}-${lastCompleted.id}`,
agentID: agent.id,
agentName: agent.name,
executionID: lastCompleted.id,
priority: "success",
message: typeof summary === "string" ? summary : "Completed successfully",
status: "idle",
};
}
return null;
}
function buildSitrepFromConfig(
agent: LibraryAgent,
scheduledWithinMs?: number,
): SitrepItemData | null {
if (agent.has_external_trigger) {
return {
id: `${agent.id}-listening`,
agentID: agent.id,
agentName: agent.name,
priority: "listening",
message: "Waiting for trigger event",
status: "listening",
};
}
if (agent.is_scheduled || agent.recommended_schedule_cron) {
if (!isNextRunWithin(agent.next_scheduled_run, scheduledWithinMs)) {
return null;
}
return {
id: `${agent.id}-scheduled`,
agentID: agent.id,
agentName: agent.name,
priority: "scheduled",
message: formatNextRun(agent.next_scheduled_run),
status: "scheduled",
};
}
return null;
}
function isNextRunWithin(
iso: string | undefined | null,
windowMs: number | undefined,
): boolean {
if (windowMs === undefined) return true;
if (!iso) return false;
const diff = new Date(iso).getTime() - Date.now();
return diff <= windowMs;
}
function formatNextRun(iso: string | undefined | null): string {
if (!iso) return "Has a scheduled run";
const diff = new Date(iso).getTime() - Date.now();
const minutes = Math.round(diff / 60_000);
if (minutes <= 0) return "Scheduled to run soon";
if (minutes < 60) return `Scheduled to run in ${minutes}m`;
const hours = Math.round(minutes / 60);
if (hours < 24) return `Scheduled to run in ${hours}h`;
const days = Math.round(hours / 24);
return `Scheduled to run in ${days}d`;
}
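The rounding in formatNextRun is easiest to see with a fixed clock. A minimal adaptation for illustration (the Date.now() call replaced by an explicit now parameter; the thresholds are otherwise copied from above):

```typescript
// Adapted from formatNextRun: `now` is injected instead of Date.now() so
// each rounding threshold can be checked deterministically.
function formatNextRunAt(iso: string | null, now: number): string {
  if (!iso) return "Has a scheduled run";
  const diff = new Date(iso).getTime() - now;
  const minutes = Math.round(diff / 60_000);
  if (minutes <= 0) return "Scheduled to run soon";
  if (minutes < 60) return `Scheduled to run in ${minutes}m`;
  const hours = Math.round(minutes / 60);
  if (hours < 24) return `Scheduled to run in ${hours}h`;
  const days = Math.round(hours / 24);
  return `Scheduled to run in ${days}d`;
}

const now = Date.parse("2026-01-01T00:00:00Z");
console.log(formatNextRunAt("2026-01-01T00:30:00Z", now)); // "Scheduled to run in 30m"
console.log(formatNextRunAt("2026-01-01T03:00:00Z", now)); // "Scheduled to run in 3h"
console.log(formatNextRunAt("2025-12-31T23:00:00Z", now)); // "Scheduled to run soon"
console.log(formatNextRunAt(null, now)); // "Has a scheduled run"
```

Note that Math.round makes the hour bucket round rather than floor, so a run 90 minutes away reports "in 2h".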


@@ -1,84 +0,0 @@
"use client";
import { cn } from "@/lib/utils";
import type { AgentStatus } from "../../types";
const STATUS_CONFIG: Record<
AgentStatus,
{ label: string; bg: string; text: string; pulse: boolean }
> = {
running: {
label: "Running",
bg: "",
text: "text-blue-600",
pulse: true,
},
error: {
label: "Error",
bg: "",
text: "text-red-500",
pulse: false,
},
listening: {
label: "Listening",
bg: "",
text: "text-purple-500",
pulse: true,
},
scheduled: {
label: "Scheduled",
bg: "",
text: "text-yellow-600",
pulse: false,
},
idle: {
label: "Idle",
bg: "",
text: "text-zinc-500",
pulse: false,
},
};
interface Props {
status: AgentStatus;
className?: string;
}
export function StatusBadge({ status, className }: Props) {
const config = STATUS_CONFIG[status];
return (
<span
className={cn(
"inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium",
config.bg,
config.text,
className,
)}
>
<span
className={cn(
"inline-block h-1.5 w-1.5 rounded-full",
config.pulse && "animate-pulse",
statusDotColor(status),
)}
/>
{config.label}
</span>
);
}
function statusDotColor(status: AgentStatus): string {
switch (status) {
case "running":
return "bg-blue-500";
case "error":
return "bg-red-500";
case "listening":
return "bg-purple-500";
case "scheduled":
return "bg-yellow-500";
case "idle":
return "bg-zinc-400";
}
}


@@ -1,59 +0,0 @@
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
export const SEVENTY_TWO_HOURS_MS = 72 * 60 * 60 * 1000;
export function isActive(status: string): boolean {
return (
status === AgentExecutionStatus.RUNNING ||
status === AgentExecutionStatus.QUEUED ||
status === AgentExecutionStatus.REVIEW
);
}
export function isFailed(status: string): boolean {
return (
status === AgentExecutionStatus.FAILED ||
status === AgentExecutionStatus.TERMINATED
);
}
export function toEndTime(exec: GraphExecutionMeta): number {
if (!exec.ended_at) return 0;
return exec.ended_at instanceof Date
? exec.ended_at.getTime()
: new Date(exec.ended_at).getTime();
}
export function endedAfter(exec: GraphExecutionMeta, cutoff: number): boolean {
if (!exec.ended_at) return false;
return toEndTime(exec) > cutoff;
}
export function runningMessage(
status: string,
startedAt?: string | Date | null,
): string {
if (status === AgentExecutionStatus.QUEUED) return "Queued for execution";
if (status === AgentExecutionStatus.REVIEW) return "Awaiting review";
if (!startedAt) return "Currently executing";
const ms =
Date.now() -
(startedAt instanceof Date
? startedAt.getTime()
: new Date(startedAt).getTime());
return `Running for ${formatRelativeDuration(ms)}`;
}
export function formatRelativeDuration(ms: number): string {
const seconds = Math.floor(ms / 1000);
if (seconds < 60) return "a few seconds";
const minutes = Math.floor(seconds / 60);
if (minutes < 60) return `${minutes}m`;
const hours = Math.floor(minutes / 60);
const remainingMin = minutes % 60;
if (hours < 24)
return remainingMin > 0 ? `${hours}h ${remainingMin}m` : `${hours}h`;
const days = Math.floor(hours / 24);
return `${days}d ${hours % 24}h`;
}
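formatRelativeDuration is a pure function, so its bucket boundaries can be exercised directly. The body below is copied verbatim from the helper above to make a standalone sketch:

```typescript
// Verbatim copy of formatRelativeDuration: each unit floors toward the
// coarser bucket, and a remainder is only shown below the 24h threshold.
function formatRelativeDuration(ms: number): string {
  const seconds = Math.floor(ms / 1000);
  if (seconds < 60) return "a few seconds";
  const minutes = Math.floor(seconds / 60);
  if (minutes < 60) return `${minutes}m`;
  const hours = Math.floor(minutes / 60);
  const remainingMin = minutes % 60;
  if (hours < 24)
    return remainingMin > 0 ? `${hours}h ${remainingMin}m` : `${hours}h`;
  const days = Math.floor(hours / 24);
  return `${days}d ${hours % 24}h`;
}

console.log(formatRelativeDuration(45_000)); // "a few seconds"
console.log(formatRelativeDuration(5 * 60_000)); // "5m"
console.log(formatRelativeDuration(90 * 60_000)); // "1h 30m"
console.log(formatRelativeDuration(26 * 3_600_000)); // "1d 2h"
```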


@@ -1,213 +0,0 @@
"use client";
import { useMemo } from "react";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import type {
AgentStatus,
AgentHealth,
AgentStatusInfo,
FleetSummary,
} from "../types";
import {
isActive,
isFailed,
toEndTime,
SEVENTY_TWO_HOURS_MS,
} from "./executionHelpers";
function deriveHealth(
status: AgentStatus,
lastRunAt: string | null,
): AgentHealth {
if (status === "error") return "attention";
if (status === "idle" && lastRunAt) {
const daysSince =
(Date.now() - new Date(lastRunAt).getTime()) / (1000 * 60 * 60 * 24);
if (daysSince > 14) return "stale";
}
return "good";
}
function computeAgentStatus(
agent: LibraryAgent,
agentExecutions: GraphExecutionMeta[],
): AgentStatusInfo {
const activeExec = agentExecutions.find((e) => isActive(e.status));
let status: AgentStatus;
let lastError: string | null = null;
let lastRunAt: string | null = null;
const activeExecutionID = activeExec?.id ?? null;
if (activeExec) {
status = "running";
} else {
const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
const recentFailed = agentExecutions.find(
(e) =>
isFailed(e.status) &&
e.ended_at &&
new Date(
e.ended_at instanceof Date ? e.ended_at.getTime() : e.ended_at,
).getTime() > cutoff,
);
if (recentFailed) {
status = "error";
lastError =
(recentFailed.stats?.error as string) ??
(recentFailed.stats?.activity_status as string) ??
"Execution failed";
} else if (agent.has_external_trigger) {
status = "listening";
} else if (agent.is_scheduled || agent.recommended_schedule_cron) {
status = "scheduled";
} else {
status = "idle";
}
}
const completedExecs = agentExecutions.filter((e) => e.ended_at);
if (completedExecs.length > 0) {
const sorted = completedExecs.sort((a, b) => toEndTime(b) - toEndTime(a));
const endedAt = sorted[0].ended_at;
lastRunAt =
endedAt instanceof Date ? endedAt.toISOString() : String(endedAt);
}
const totalRuns = agent.execution_count ?? agentExecutions.length;
return {
status,
health: deriveHealth(status, lastRunAt),
progress: null,
totalRuns,
lastRunAt,
lastError,
activeExecutionID,
monthlySpend: 0,
nextScheduledRun: null,
triggerType: agent.has_external_trigger ? "webhook" : null,
};
}
export function useAgentStatusMap(
agents: LibraryAgent[],
): Map<string, AgentStatusInfo> {
const { data: executions } = useGetV1ListAllExecutions({
query: { select: okData },
});
return useMemo(() => {
const map = new Map<string, AgentStatusInfo>();
const execsByGraph = new Map<string, GraphExecutionMeta[]>();
for (const exec of executions ?? []) {
const list = execsByGraph.get(exec.graph_id);
if (list) {
list.push(exec);
} else {
execsByGraph.set(exec.graph_id, [exec]);
}
}
for (const agent of agents) {
const agentExecs = execsByGraph.get(agent.graph_id) ?? [];
map.set(agent.graph_id, computeAgentStatus(agent, agentExecs));
}
return map;
}, [agents, executions]);
}
const DEFAULT_STATUS: AgentStatusInfo = {
status: "idle",
health: "good",
progress: null,
totalRuns: 0,
lastRunAt: null,
lastError: null,
activeExecutionID: null,
monthlySpend: 0,
nextScheduledRun: null,
triggerType: null,
};
export function getAgentStatus(
statusMap: Map<string, AgentStatusInfo>,
graphID: string,
): AgentStatusInfo {
return statusMap.get(graphID) ?? DEFAULT_STATUS;
}
export function useFleetSummary(agents: LibraryAgent[]): FleetSummary {
const { data: executions } = useGetV1ListAllExecutions({
query: { select: okData },
});
return useMemo(() => {
const counts: FleetSummary = {
running: 0,
error: 0,
completed: 0,
listening: 0,
scheduled: 0,
idle: 0,
monthlySpend: 0,
};
const activeGraphIds = new Set<string>();
const errorGraphIds = new Set<string>();
const completedGraphIds = new Set<string>();
if (executions) {
const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
for (const exec of executions) {
if (isActive(exec.status)) {
activeGraphIds.add(exec.graph_id);
}
const endedTs = exec.ended_at
? new Date(
exec.ended_at instanceof Date
? exec.ended_at.getTime()
: exec.ended_at,
).getTime()
: 0;
if (isFailed(exec.status) && endedTs > cutoff) {
errorGraphIds.add(exec.graph_id);
}
if (
exec.status === AgentExecutionStatus.COMPLETED &&
endedTs > cutoff
) {
completedGraphIds.add(exec.graph_id);
}
}
}
for (const agent of agents) {
if (activeGraphIds.has(agent.graph_id)) {
counts.running += 1;
} else if (errorGraphIds.has(agent.graph_id)) {
counts.error += 1;
} else if (agent.has_external_trigger) {
counts.listening += 1;
} else if (agent.is_scheduled || agent.recommended_schedule_cron) {
counts.scheduled += 1;
} else {
counts.idle += 1;
}
if (completedGraphIds.has(agent.graph_id)) {
counts.completed += 1;
}
}
return counts;
}, [agents, executions]);
}
export { deriveHealth };
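deriveHealth encodes a 14-day staleness window for idle agents. A small sketch with the clock injected (an adaptation for determinism; the hook above reads Date.now() inline):

```typescript
type AgentStatus = "running" | "error" | "listening" | "scheduled" | "idle";
type AgentHealth = "good" | "attention" | "stale";

// Adapted from deriveHealth above: errors always demand attention, and an
// idle agent whose last run is more than 14 days old is considered stale.
function deriveHealthAt(
  status: AgentStatus,
  lastRunAt: string | null,
  now: number,
): AgentHealth {
  if (status === "error") return "attention";
  if (status === "idle" && lastRunAt) {
    const daysSince = (now - new Date(lastRunAt).getTime()) / 86_400_000;
    if (daysSince > 14) return "stale";
  }
  return "good";
}

const t = Date.parse("2026-02-01T00:00:00Z");
console.log(deriveHealthAt("error", null, t)); // "attention"
console.log(deriveHealthAt("idle", "2026-01-01T00:00:00Z", t)); // "stale" (31 days idle)
console.log(deriveHealthAt("idle", "2026-01-30T00:00:00Z", t)); // "good" (2 days idle)
```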


@@ -1,116 +0,0 @@
"use client";
import {
getGetV1ListAllExecutionsQueryKey,
useGetV1ListAllExecutions,
} from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import { useExecutionEvents } from "@/hooks/useExecutionEvents";
import { useQueryClient } from "@tanstack/react-query";
import { useCallback, useMemo } from "react";
import type { FleetSummary } from "../types";
import { isActive, isFailed, SEVENTY_TWO_HOURS_MS } from "./executionHelpers";
function isRecentFailure(
status: string,
endedAt?: string | Date | null,
): boolean {
if (!isFailed(status)) return false;
if (!endedAt) return false;
const ts =
endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime();
return Date.now() - ts < SEVENTY_TWO_HOURS_MS;
}
function isRecentCompletion(
status: string,
endedAt?: string | Date | null,
): boolean {
if (status !== AgentExecutionStatus.COMPLETED) return false;
if (!endedAt) return false;
const ts =
endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime();
return Date.now() - ts < SEVENTY_TWO_HOURS_MS;
}
export function useLibraryFleetSummary(
agents: LibraryAgent[],
): FleetSummary | undefined {
const queryClient = useQueryClient();
const { data: executions, isSuccess } = useGetV1ListAllExecutions({
query: { select: okData },
});
const graphIDs = useMemo(() => agents.map((a) => a.graph_id), [agents]);
const handleExecutionUpdate = useCallback(() => {
queryClient.invalidateQueries({
queryKey: getGetV1ListAllExecutionsQueryKey(),
});
}, [queryClient]);
useExecutionEvents({
graphIds: graphIDs.length > 0 ? graphIDs : undefined,
enabled: graphIDs.length > 0,
onExecutionUpdate: handleExecutionUpdate,
});
return useMemo(() => {
if (!isSuccess || !executions) return undefined;
const agentsWithActiveExecution = new Set<string>();
const agentsWithRecentFailure = new Set<string>();
const agentsWithRecentCompletion = new Set<string>();
for (const exec of executions) {
if (isActive(exec.status)) {
agentsWithActiveExecution.add(exec.graph_id);
}
if (isRecentFailure(exec.status, exec.ended_at)) {
agentsWithRecentFailure.add(exec.graph_id);
}
if (isRecentCompletion(exec.status, exec.ended_at)) {
agentsWithRecentCompletion.add(exec.graph_id);
}
}
const summary: FleetSummary = {
running: 0,
error: 0,
completed: 0,
listening: 0,
scheduled: 0,
idle: 0,
monthlySpend: 0,
};
for (const agent of agents) {
if (agentsWithActiveExecution.has(agent.graph_id)) {
summary.running += 1;
} else if (agentsWithRecentFailure.has(agent.graph_id)) {
summary.error += 1;
} else if (agent.has_external_trigger) {
summary.listening += 1;
} else if (agent.is_scheduled || agent.recommended_schedule_cron) {
summary.scheduled += 1;
} else {
summary.idle += 1;
}
// Parallel counter: mutually exclusive with running/error (which match
// the sitrep priority order used by the "Recently completed" tab list)
// but orthogonal to listening/scheduled/idle.
if (
!agentsWithActiveExecution.has(agent.graph_id) &&
!agentsWithRecentFailure.has(agent.graph_id) &&
agentsWithRecentCompletion.has(agent.graph_id)
) {
summary.completed += 1;
}
}
return summary;
}, [agents, executions, isSuccess]);
}
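The per-agent bucketing above is a priority cascade (active run, then recent failure, then trigger, then schedule, then idle), with `completed` counted in parallel. A pure-function sketch of that cascade, using a hypothetical minimal shape in place of LibraryAgent:

```typescript
// Hypothetical minimal stand-in for LibraryAgent: only the fields the
// cascade reads are included.
interface AgentLite {
  graph_id: string;
  has_external_trigger: boolean;
  is_scheduled: boolean;
}

interface Summary {
  running: number; error: number; completed: number;
  listening: number; scheduled: number; idle: number;
}

// Mirrors the loop in useLibraryFleetSummary: the first five buckets are
// mutually exclusive; `completed` overlaps listening/scheduled/idle but
// never running/error.
function summarize(
  agents: AgentLite[],
  active: Set<string>,
  failed: Set<string>,
  completed: Set<string>,
): Summary {
  const s: Summary = { running: 0, error: 0, completed: 0, listening: 0, scheduled: 0, idle: 0 };
  for (const a of agents) {
    if (active.has(a.graph_id)) s.running += 1;
    else if (failed.has(a.graph_id)) s.error += 1;
    else if (a.has_external_trigger) s.listening += 1;
    else if (a.is_scheduled) s.scheduled += 1;
    else s.idle += 1;
    if (!active.has(a.graph_id) && !failed.has(a.graph_id) && completed.has(a.graph_id)) {
      s.completed += 1;
    }
  }
  return s;
}

const agents: AgentLite[] = [
  { graph_id: "a", has_external_trigger: false, is_scheduled: false },
  { graph_id: "b", has_external_trigger: true, is_scheduled: false },
  { graph_id: "c", has_external_trigger: false, is_scheduled: true },
];
const out = summarize(agents, new Set(["a"]), new Set(), new Set(["b"]));
// expected counts: running 1, listening 1, scheduled 1, completed 1, idle 0
```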


@@ -2,14 +2,12 @@
import { useEffect, useState, useCallback } from "react";
import { HeartIcon, ListIcon } from "@phosphor-icons/react";
+import { JumpBackIn } from "./components/JumpBackIn/JumpBackIn";
import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryActionHeader";
import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList";
import { useLibraryListPage } from "./components/useLibraryListPage";
import { FavoriteAnimationProvider } from "./context/FavoriteAnimationContext";
-import type { LibraryTab, AgentStatusFilter } from "./types";
-import { useLibraryFleetSummary } from "./hooks/useLibraryFleetSummary";
-import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
-import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
+import { LibraryTab } from "./types";
const LIBRARY_TABS: LibraryTab[] = [
{ id: "all", title: "All", icon: ListIcon },
@@ -21,10 +19,6 @@ export default function LibraryPage() {
useLibraryListPage();
const [selectedFolderId, setSelectedFolderId] = useState<string | null>(null);
const [activeTab, setActiveTab] = useState(LIBRARY_TABS[0].id);
-const [statusFilter, setStatusFilter] = useState<AgentStatusFilter>("all");
-const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
-const { agents } = useLibraryAgents();
-const fleetSummary = useLibraryFleetSummary(agents);
useEffect(() => {
document.title = "Library AutoGPT Platform";
@@ -46,6 +40,7 @@ export default function LibraryPage() {
>
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
<LibraryActionHeader setSearchTerm={setSearchTerm} />
+<JumpBackIn />
<LibraryAgentList
searchTerm={searchTerm}
librarySort={librarySort}
@@ -55,10 +50,6 @@ export default function LibraryPage() {
tabs={LIBRARY_TABS}
activeTab={activeTab}
onTabChange={handleTabChange}
-statusFilter={statusFilter}
-onStatusFilterChange={setStatusFilter}
-fleetSummary={isAgentBriefingEnabled ? fleetSummary : undefined}
-briefingAgents={isAgentBriefingEnabled ? agents : undefined}
/>
</main>
</FavoriteAnimationProvider>


@@ -1,76 +1,7 @@
-import type { Icon } from "@phosphor-icons/react";
+import { Icon } from "@phosphor-icons/react";
export interface LibraryTab {
id: string;
title: string;
icon: Icon;
}
-/** Agent execution status — drives StatusBadge visuals & filtering. */
-export type AgentStatus =
-| "running"
-| "error"
-| "listening"
-| "scheduled"
-| "idle";
-/** Derived health bucket for quick triage. */
-export type AgentHealth = "good" | "attention" | "stale";
-/** Real-time metadata that powers the Intelligence Layer features. */
-export interface AgentStatusInfo {
-status: AgentStatus;
-health: AgentHealth;
-/** 0-100 progress for currently running agents. */
-progress: number | null;
-totalRuns: number;
-lastRunAt: string | null;
-lastError: string | null;
-/** ID of the currently active execution (when status is "running"). */
-activeExecutionID: string | null;
-monthlySpend: number;
-nextScheduledRun: string | null;
-triggerType: string | null;
-}
-/** Fleet-wide aggregate counts used by the Briefing Panel stats grid. */
-export interface FleetSummary {
-running: number;
-error: number;
-completed: number;
-listening: number;
-scheduled: number;
-idle: number;
-monthlySpend: number;
-}
-export type SitrepPriority =
-| "error"
-| "running"
-| "stale"
-| "success"
-| "listening"
-| "scheduled"
-| "idle";
-export interface SitrepItemData {
-id: string;
-agentID: string;
-agentName: string;
-agentImageUrl?: string | null;
-executionID?: string;
-priority: SitrepPriority;
-message: string;
-status: AgentStatus;
-}
-/** Filter options for the agent filter dropdown. */
-export type AgentStatusFilter =
-| "all"
-| "running"
-| "attention"
-| "completed"
-| "listening"
-| "scheduled"
-| "idle"
-| "healthy";
