mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
31 Commits
fix/artifa
...
test-scree
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce0bad96a3 | ||
|
|
51286cc0a9 | ||
|
|
b7d5a59f9d | ||
|
|
8c3bdb0315 | ||
|
|
8524091a5f | ||
|
|
22c5d6f86c | ||
|
|
2f6a02a7fa | ||
|
|
ddf8bb7d8b | ||
|
|
d635844412 | ||
|
|
6f46bce634 | ||
|
|
3056be165f | ||
|
|
7e8a68a5c0 | ||
|
|
54f585b8c4 | ||
|
|
9ea7e61652 | ||
|
|
6d1688b0f0 | ||
|
|
ce22b21824 | ||
|
|
c73c5b380c | ||
|
|
8f93942ee5 | ||
|
|
60b1aba221 | ||
|
|
f9a33f2aa6 | ||
|
|
df3c4b381c | ||
|
|
7bf5a8c226 | ||
|
|
a0f149fcb2 | ||
|
|
f35791170a | ||
|
|
3771bfad9c | ||
|
|
2e2f518c58 | ||
|
|
89f2dcc338 | ||
|
|
3d3aef58ac | ||
|
|
e85c042eb6 | ||
|
|
e7b621f0b0 | ||
|
|
e8c356a728 |
@@ -18,6 +18,7 @@ from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.message_dedup import acquire_dedup_lock
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -190,6 +191,8 @@ class SessionDetailResponse(BaseModel):
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
has_more_messages: bool = False
|
||||
oldest_sequence: int | None = None
|
||||
newest_sequence: int | None = None
|
||||
forward_paginated: bool = False
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
@@ -454,39 +457,113 @@ async def update_session_title_route(
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
before_sequence: int | None = Query(default=None, ge=0),
|
||||
limit: int = Query(
|
||||
default=50,
|
||||
ge=1,
|
||||
le=200,
|
||||
description="Maximum number of messages to return.",
|
||||
),
|
||||
before_sequence: int | None = Query(
|
||||
default=None,
|
||||
ge=0,
|
||||
description=(
|
||||
"Backward pagination cursor. Return messages with sequence number "
|
||||
"strictly less than this value. Used by active-session load-more. "
|
||||
"Mutually exclusive with after_sequence."
|
||||
),
|
||||
),
|
||||
after_sequence: int | None = Query(
|
||||
default=None,
|
||||
ge=0,
|
||||
description=(
|
||||
"Forward pagination cursor. Return messages with sequence number "
|
||||
"strictly greater than this value. Used by completed-session load-more. "
|
||||
"Mutually exclusive with before_sequence."
|
||||
),
|
||||
),
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
|
||||
When no pagination params are provided, returns the most recent messages.
|
||||
Supports cursor-based pagination via ``limit``, ``before_sequence``, and
|
||||
``after_sequence``. The two cursor parameters are mutually exclusive.
|
||||
|
||||
On the initial load (no cursor provided) of a completed session, messages
|
||||
are returned in forward order starting from sequence 0 so the user always
|
||||
sees their initial prompt. Active sessions use the legacy newest-first
|
||||
order so streaming context is preserved.
|
||||
"""
|
||||
if before_sequence is not None and after_sequence is not None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="before_sequence and after_sequence are mutually exclusive",
|
||||
)
|
||||
|
||||
is_initial_load = before_sequence is None and after_sequence is None
|
||||
|
||||
# Check active stream before the DB query on initial loads so we can
|
||||
# choose the correct pagination direction (forward for completed sessions,
|
||||
# newest-first for active ones).
|
||||
active_session = None
|
||||
last_message_id = None
|
||||
if is_initial_load:
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
# Completed sessions on initial load start from sequence 0 so the user's
|
||||
# initial prompt is always visible. Active sessions keep the legacy
|
||||
# newest-first behavior to preserve streaming context.
|
||||
from_start = is_initial_load and active_session is None
|
||||
forward_paginated = from_start or after_sequence is not None
|
||||
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id, limit, before_sequence, user_id=user_id
|
||||
session_id,
|
||||
limit,
|
||||
before_sequence=before_sequence,
|
||||
after_sequence=after_sequence,
|
||||
from_start=from_start,
|
||||
user_id=user_id,
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
# Close the TOCTOU window: if the session was active at pre-check, re-verify
|
||||
# after the DB fetch. The session may have completed between the two awaits,
|
||||
# which would have caused messages to be fetched newest-first even though the
|
||||
# session is now complete. Re-fetch from seq 0 so the initial prompt is
|
||||
# always visible.
|
||||
if is_initial_load and active_session is not None:
|
||||
post_active, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||
if post_active is None:
|
||||
active_session = None
|
||||
last_message_id = None
|
||||
from_start = True
|
||||
forward_paginated = True
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id,
|
||||
limit,
|
||||
before_sequence=None,
|
||||
after_sequence=None,
|
||||
from_start=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [
|
||||
_strip_injected_context(message.model_dump()) for message in page.messages
|
||||
]
|
||||
|
||||
# Only check active stream on initial load (not on "load more" requests)
|
||||
active_stream_info = None
|
||||
if before_sequence is None:
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
if active_session and last_message_id is not None:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
if active_session:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Skip session metadata on "load more" — frontend only needs messages
|
||||
if before_sequence is not None:
|
||||
if not is_initial_load:
|
||||
return SessionDetailResponse(
|
||||
id=page.session.session_id,
|
||||
created_at=page.session.started_at.isoformat(),
|
||||
@@ -496,6 +573,8 @@ async def get_session(
|
||||
active_stream=None,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
newest_sequence=page.newest_sequence,
|
||||
forward_paginated=forward_paginated,
|
||||
total_prompt_tokens=0,
|
||||
total_completion_tokens=0,
|
||||
)
|
||||
@@ -512,6 +591,8 @@ async def get_session(
|
||||
active_stream=active_stream_info,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
newest_sequence=page.newest_sequence,
|
||||
forward_paginated=forward_paginated,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
metadata=page.session.metadata,
|
||||
@@ -832,6 +913,9 @@ async def stream_chat_post(
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
# Capture the original message text BEFORE any mutation (attachment enrichment)
|
||||
# so the idempotency hash is stable across retries.
|
||||
original_message = request.message
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
@@ -860,36 +944,58 @@ async def stream_chat_post(
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message returns None when a duplicate is
|
||||
# detected — in that case skip enqueue to avoid processing the message twice.
|
||||
is_duplicate_message = False
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
# ── Idempotency guard ────────────────────────────────────────────────────
|
||||
# Blocks duplicate executor tasks from concurrent/retried POSTs.
|
||||
# See backend/copilot/message_dedup.py for the full lifecycle description.
|
||||
dedup_lock = None
|
||||
if request.is_user_message:
|
||||
dedup_lock = await acquire_dedup_lock(
|
||||
session_id, original_message, sanitized_file_ids
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
is_duplicate_message = (
|
||||
await append_and_save_message(session_id, message)
|
||||
) is None
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
if not is_duplicate_message and request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
if dedup_lock is None and (original_message or sanitized_file_ids):
|
||||
|
||||
async def _empty_sse() -> AsyncGenerator[str, None]:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_empty_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
# Create a task in the stream registry for reconnection support.
|
||||
# For duplicate messages, skip create_session entirely so the infra-retry
|
||||
# client subscribes to the *existing* turn's Redis stream and receives the
|
||||
# in-progress executor output rather than an empty stream.
|
||||
turn_id = ""
|
||||
if not is_duplicate_message:
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
# message loss from concurrent requests.
|
||||
#
|
||||
# If any of these operations raises, release the dedup lock before propagating
|
||||
# so subsequent retries are not blocked for 30 s.
|
||||
try:
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
await append_and_save_message(session_id, message)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
@@ -907,6 +1013,7 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
@@ -918,10 +1025,10 @@ async def stream_chat_post(
|
||||
mode=request.mode,
|
||||
model=request.model,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
|
||||
)
|
||||
except Exception:
|
||||
if dedup_lock:
|
||||
await dedup_lock.release()
|
||||
raise
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
@@ -945,6 +1052,12 @@ async def stream_chat_post(
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
# True for every exit path except GeneratorExit (client disconnect).
|
||||
# On disconnect the backend turn is still running — releasing the lock
|
||||
# there would reopen the infra-retry duplicate window. The 30 s TTL
|
||||
# is the fallback. All other exits (normal finish, early return, error)
|
||||
# should release so the user can re-send the same message.
|
||||
release_dedup_lock_on_exit = True
|
||||
try:
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
@@ -956,7 +1069,7 @@ async def stream_chat_post(
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
return
|
||||
return # finally releases dedup_lock
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
@@ -998,7 +1111,7 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
break # finally releases dedup_lock
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
@@ -1014,6 +1127,7 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
release_dedup_lock_on_exit = False
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
@@ -1028,7 +1142,10 @@ async def stream_chat_post(
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
# finally releases dedup_lock
|
||||
finally:
|
||||
if dedup_lock and release_dedup_lock_on_exit:
|
||||
await dedup_lock.release()
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
|
||||
@@ -133,12 +133,21 @@ def test_stream_chat_rejects_too_many_file_ids():
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
|
||||
def _mock_stream_internals(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
redis_set_returns: object = True,
|
||||
):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing RabbitMQ.
|
||||
validation and enrichment logic without needing Redis/RabbitMQ.
|
||||
|
||||
Args:
|
||||
redis_set_returns: Value returned by the mocked Redis ``set`` call.
|
||||
``True`` (default) simulates a fresh key (new message);
|
||||
``None`` simulates a collision (duplicate blocked).
|
||||
|
||||
Returns:
|
||||
A namespace with ``save`` and ``enqueue`` mock objects so
|
||||
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
|
||||
callers can make additional assertions about side-effects.
|
||||
"""
|
||||
import types
|
||||
@@ -149,7 +158,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
|
||||
)
|
||||
mock_save = mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
@@ -165,9 +174,15 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
return types.SimpleNamespace(
|
||||
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=redis_set_returns)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
|
||||
return ns
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
@@ -196,29 +211,6 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ─── Duplicate message dedup ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_skips_enqueue_for_duplicate_message(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
):
|
||||
"""When append_and_save_message returns None (duplicate detected),
|
||||
enqueue_copilot_turn and stream_registry.create_session must NOT be called
|
||||
to avoid double-processing and to prevent overwriting the active stream's
|
||||
turn_id in Redis (which would cause reconnecting clients to miss the response)."""
|
||||
mocks = _mock_stream_internals(mocker)
|
||||
# Override save to return None — signalling a duplicate
|
||||
mocks.save.return_value = None
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mocks.enqueue.assert_not_called()
|
||||
mocks.registry.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -714,6 +706,237 @@ class TestStripInjectedContext:
|
||||
assert result["content"] == "hello"
|
||||
|
||||
|
||||
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""A second POST with the same message within the 30-s window must return
|
||||
an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
|
||||
turn complete without creating a ghost response."""
|
||||
# redis_set_returns=None simulates a collision: the NX key already exists.
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=None)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-dup/stream",
|
||||
json={"message": "duplicate message", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.text
|
||||
# The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
|
||||
assert '"finish"' in body
|
||||
assert "[DONE]" in body
|
||||
# The empty SSE response must include the AI SDK protocol header so the
|
||||
# frontend treats it as a valid stream and marks the turn complete.
|
||||
assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
|
||||
# The duplicate guard must prevent save/enqueue side effects.
|
||||
ns.save.assert_not_called()
|
||||
ns.enqueue.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_first_post_proceeds_normally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The first POST (Redis NX key set successfully) must proceed through the
|
||||
normal streaming path — no early return."""
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=True)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-new/stream",
|
||||
json={"message": "first message", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Redis set must have been called once with the NX flag.
|
||||
ns.redis.set.assert_called_once()
|
||||
call_kwargs = ns.redis.set.call_args
|
||||
assert call_kwargs.kwargs.get("nx") is True
|
||||
|
||||
|
||||
def test_stream_chat_dedup_skipped_for_non_user_messages(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""System/assistant messages (is_user_message=False) bypass the dedup
|
||||
guard — they are injected programmatically and must always be processed."""
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=None)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-sys/stream",
|
||||
json={"message": "system context", "is_user_message": False},
|
||||
)
|
||||
|
||||
# Even though redis_set_returns=None (would block a user message),
|
||||
# the endpoint must proceed because is_user_message=False.
|
||||
assert response.status_code == 200
|
||||
ns.redis.set.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The dedup hash must be computed from the original request message,
|
||||
not the mutated version that has the [Attached files] block appended.
|
||||
A file_id is sent so the route actually appends the [Attached files] block,
|
||||
exercising the mutation path — the hash must still match the original text."""
|
||||
import hashlib
|
||||
|
||||
ns = _mock_stream_internals(mocker, redis_set_returns=True)
|
||||
|
||||
file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
# Mock workspace + prisma so the attachment block is actually appended.
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
fake_file = type(
|
||||
"F",
|
||||
(),
|
||||
{
|
||||
"id": file_id,
|
||||
"name": "doc.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"sizeBytes": 1024,
|
||||
},
|
||||
)()
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-hash/stream",
|
||||
json={
|
||||
"message": "plain message",
|
||||
"is_user_message": True,
|
||||
"file_ids": [file_id],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
ns.redis.set.assert_called_once()
|
||||
call_args = ns.redis.set.call_args
|
||||
dedup_key = call_args.args[0]
|
||||
|
||||
# Hash must use the original message + sorted file IDs, not the mutated text.
|
||||
expected_hash = hashlib.sha256(
|
||||
f"sess-hash:plain message:{file_id}".encode()
|
||||
).hexdigest()[:16]
|
||||
expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
|
||||
assert dedup_key == expected_key, (
|
||||
f"Dedup key {dedup_key!r} does not match expected {expected_key!r} — "
|
||||
"hash may be using mutated message or wrong inputs"
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_after_stream_finish(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The dedup Redis key must be deleted after the turn completes (when
|
||||
subscriber_queue is None the route yields StreamFinish immediately and
|
||||
should release the key so the user can re-send the same message)."""
|
||||
from unittest.mock import AsyncMock as _AsyncMock
|
||||
|
||||
# Set up all internals manually so we can control subscribe_to_session.
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = _AsyncMock(return_value=None)
|
||||
# None → early-finish path: StreamFinish yielded immediately, dedup key released.
|
||||
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mock_redis = mocker.AsyncMock()
|
||||
mock_redis.set = _AsyncMock(return_value=True)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=_AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-finish/stream",
|
||||
json={"message": "hello", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.text
|
||||
assert '"finish"' in body
|
||||
# The dedup key must be released so intentional re-sends are allowed.
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""The route must not crash when the dedup Redis delete fails on the
|
||||
subscriber_queue-is-None early-finish path (except Exception: pass)."""
|
||||
from unittest.mock import AsyncMock as _AsyncMock
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = _AsyncMock(return_value=None)
|
||||
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mock_redis = mocker.AsyncMock()
|
||||
mock_redis.set = _AsyncMock(return_value=True)
|
||||
# Make the delete raise so the except-pass branch is exercised.
|
||||
mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=_AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
|
||||
# Should not raise even though delete fails.
|
||||
response = client.post(
|
||||
"/sessions/sess-finish-err/stream",
|
||||
json={"message": "hello", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert '"finish"' in response.text
|
||||
# delete must have been attempted — the except-pass branch silenced the error.
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
|
||||
@@ -759,7 +982,7 @@ def test_disconnect_stream_returns_404_when_session_missing(
|
||||
mock_disconnect.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── GET /sessions/{session_id} — backward pagination ─────────────────────────
|
||||
# ─── GET /sessions/{session_id} — forward/backward pagination ──────────────────
|
||||
|
||||
|
||||
def _make_paginated_messages(
|
||||
@@ -784,6 +1007,7 @@ def _make_paginated_messages(
|
||||
messages=[ChatMessage(role="user", content="hello", sequence=0)],
|
||||
has_more=has_more,
|
||||
oldest_sequence=0,
|
||||
newest_sequence=0,
|
||||
session=session_info,
|
||||
)
|
||||
mock_paginate = mocker.patch(
|
||||
@@ -794,11 +1018,11 @@ def _make_paginated_messages(
|
||||
return page, mock_paginate
|
||||
|
||||
|
||||
def test_get_session_returns_backward_paginated(
|
||||
def test_get_session_completed_returns_forward_paginated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""All sessions use backward (newest-first) pagination."""
|
||||
"""Completed sessions (no active stream) return forward_paginated=True."""
|
||||
_make_paginated_messages(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
@@ -810,6 +1034,92 @@ def test_get_session_returns_backward_paginated(
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["oldest_sequence"] == 0
|
||||
assert "forward_paginated" not in data
|
||||
assert "newest_sequence" not in data
|
||||
assert data["forward_paginated"] is True
|
||||
assert data["newest_sequence"] == 0
|
||||
|
||||
|
||||
def test_get_session_active_returns_backward_paginated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Active sessions (with running stream) return forward_paginated=False."""
|
||||
from backend.copilot.stream_registry import ActiveSession
|
||||
|
||||
_make_paginated_messages(mocker)
|
||||
active = MagicMock(spec=ActiveSession)
|
||||
active.turn_id = "turn-1"
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(active, "msg-1"),
|
||||
)
|
||||
|
||||
response = client.get("/sessions/sess-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["forward_paginated"] is False
|
||||
assert data["active_stream"] is not None
|
||||
assert data["active_stream"]["turn_id"] == "turn-1"
|
||||
|
||||
|
||||
def test_get_session_after_sequence_returns_forward_paginated(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""after_sequence param returns forward_paginated=True; no stream check needed."""
|
||||
_, mock_paginate = _make_paginated_messages(mocker)
|
||||
|
||||
response = client.get("/sessions/sess-1?after_sequence=10")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["forward_paginated"] is True
|
||||
call_kwargs = mock_paginate.call_args
|
||||
assert call_kwargs.kwargs.get("after_sequence") == 10
|
||||
assert call_kwargs.kwargs.get("before_sequence") is None
|
||||
|
||||
|
||||
def test_get_session_both_cursors_returns_400(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Sending both before_sequence and after_sequence returns 400."""
|
||||
response = client.get("/sessions/sess-1?before_sequence=5&after_sequence=10")
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_get_session_toctou_refetch_when_session_completes_mid_request(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Race condition: session was active at pre-check but completes before DB fetch.
|
||||
|
||||
The route should detect the race via a post-fetch re-check, then re-fetch
|
||||
from seq 0 so the initial prompt is always visible.
|
||||
"""
|
||||
from backend.copilot.stream_registry import ActiveSession
|
||||
|
||||
page, mock_paginate = _make_paginated_messages(mocker)
|
||||
active = MagicMock(spec=ActiveSession)
|
||||
active.turn_id = "turn-1"
|
||||
|
||||
# First call: session appears active. Second call: session has completed.
|
||||
mock_get_active = mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.get_active_session",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[(active, "msg-1"), (None, None)],
|
||||
)
|
||||
|
||||
response = client.get("/sessions/sess-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Post-race: session is now completed → forward_paginated=True, no stream
|
||||
assert data["forward_paginated"] is True
|
||||
assert data["active_stream"] is None
|
||||
# The DB was queried twice: once newest-first, once from_start=True
|
||||
assert mock_paginate.call_count == 2
|
||||
assert mock_get_active.call_count == 2
|
||||
second_call = mock_paginate.call_args_list[1]
|
||||
assert second_call.kwargs.get("from_start") is True
|
||||
|
||||
@@ -12,7 +12,6 @@ import prisma.models
|
||||
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.data.graph as graph_db
|
||||
from backend.api.features.library.db import _fetch_schedule_info
|
||||
from backend.data.graph import GraphModel, GraphSettings
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -118,5 +117,4 @@ async def add_graph_to_library(
|
||||
f"for store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id)
|
||||
return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
|
||||
@@ -21,17 +21,13 @@ async def test_add_graph_to_library_create_new_agent() -> None:
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library._fetch_schedule_info",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(created_agent, schedule_info={})
|
||||
mock_from_db.assert_called_once_with(created_agent)
|
||||
# Verify create was called with correct data
|
||||
create_call = mock_prisma.return_value.create.call_args
|
||||
create_data = create_call.kwargs["data"]
|
||||
@@ -58,10 +54,6 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library._fetch_schedule_info",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(
|
||||
side_effect=prisma.errors.UniqueViolationError(
|
||||
@@ -73,7 +65,7 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(updated_agent, schedule_info={})
|
||||
mock_from_db.assert_called_once_with(updated_agent)
|
||||
# Verify update was called with correct where and data
|
||||
update_call = mock_prisma.return_value.update.call_args
|
||||
assert update_call.kwargs["where"] == {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
import fastapi
|
||||
@@ -44,65 +43,6 @@ config = Config()
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
|
||||
"""Fetch execution counts per graph in a single batched query."""
|
||||
if not graph_ids:
|
||||
return {}
|
||||
rows = await prisma.models.AgentGraphExecution.prisma().group_by(
|
||||
by=["agentGraphId"],
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": {"in": graph_ids},
|
||||
"isDeleted": False,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
return {
|
||||
row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
|
||||
async def _fetch_schedule_info(
|
||||
user_id: str, graph_id: Optional[str] = None
|
||||
) -> dict[str, str]:
|
||||
"""Fetch a map of graph_id → earliest next_run_time ISO string.
|
||||
|
||||
When `graph_id` is provided, the scheduler query is narrowed to that graph,
|
||||
which is cheaper for single-agent lookups (detail page, post-update, etc.).
|
||||
"""
|
||||
try:
|
||||
scheduler_client = get_scheduler_client()
|
||||
schedules = await scheduler_client.get_execution_schedules(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
earliest: dict[str, tuple[datetime, str]] = {}
|
||||
for s in schedules:
|
||||
parsed = _parse_iso_datetime(s.next_run_time)
|
||||
if parsed is None:
|
||||
continue
|
||||
current = earliest.get(s.graph_id)
|
||||
if current is None or parsed < current[0]:
|
||||
earliest[s.graph_id] = (parsed, s.next_run_time)
|
||||
return {graph_id: iso for graph_id, (_, iso) in earliest.items()}
|
||||
except Exception:
|
||||
logger.warning("Failed to fetch schedules for library agents", exc_info=True)
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_iso_datetime(value: str) -> Optional[datetime]:
|
||||
"""Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC)."""
|
||||
try:
|
||||
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
logger.warning("Failed to parse schedule next_run_time: %s", value)
|
||||
return None
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
|
||||
|
||||
async def list_library_agents(
|
||||
user_id: str,
|
||||
search_term: Optional[str] = None,
|
||||
@@ -197,22 +137,12 @@ async def list_library_agents(
|
||||
|
||||
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts, schedule_info = await asyncio.gather(
|
||||
_fetch_execution_counts(user_id, graph_ids),
|
||||
_fetch_schedule_info(user_id),
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -284,22 +214,12 @@ async def list_favorite_library_agents(
|
||||
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts, schedule_info = await asyncio.gather(
|
||||
_fetch_execution_counts(user_id, graph_ids),
|
||||
_fetch_schedule_info(user_id),
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -365,12 +285,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
where={"userId": store_listing.owningUserId}
|
||||
)
|
||||
|
||||
schedule_info = (
|
||||
await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id)
|
||||
if library_agent.AgentGraph
|
||||
else {}
|
||||
)
|
||||
|
||||
return library_model.LibraryAgent.from_db(
|
||||
library_agent,
|
||||
sub_graphs=(
|
||||
@@ -380,7 +294,6 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
),
|
||||
store_listing=store_listing,
|
||||
profile=profile,
|
||||
schedule_info=schedule_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -416,10 +329,7 @@ async def get_library_agent_by_store_version_id(
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
if not agent:
|
||||
return None
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
|
||||
return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
return library_model.LibraryAgent.from_db(agent) if agent else None
|
||||
|
||||
|
||||
async def get_library_agent_by_graph_id(
|
||||
@@ -448,10 +358,7 @@ async def get_library_agent_by_graph_id(
|
||||
assert agent.AgentGraph # make type checker happy
|
||||
# Include sub-graphs so we can make a full credentials input schema
|
||||
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
|
||||
return library_model.LibraryAgent.from_db(
|
||||
agent, sub_graphs=sub_graphs, schedule_info=schedule_info
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
@@ -593,11 +500,7 @@ async def create_library_agent(
|
||||
for agent, graph in zip(library_agents, graph_entries):
|
||||
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
|
||||
|
||||
schedule_info = await _fetch_schedule_info(user_id)
|
||||
return [
|
||||
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
for agent in library_agents
|
||||
]
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
|
||||
|
||||
|
||||
async def update_agent_version_in_library(
|
||||
@@ -659,8 +562,7 @@ async def update_agent_version_in_library(
|
||||
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
|
||||
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id)
|
||||
return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info)
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
|
||||
|
||||
async def create_graph_in_library(
|
||||
@@ -1565,11 +1467,7 @@ async def bulk_move_agents_to_folder(
|
||||
),
|
||||
)
|
||||
|
||||
schedule_info = await _fetch_schedule_info(user_id)
|
||||
return [
|
||||
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
|
||||
for agent in agents
|
||||
]
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
|
||||
|
||||
|
||||
def collect_tree_ids(
|
||||
|
||||
@@ -65,11 +65,6 @@ async def test_get_library_agents(mocker):
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={}),
|
||||
)
|
||||
|
||||
# Call function
|
||||
result = await db.list_library_agents("test-user")
|
||||
|
||||
@@ -358,136 +353,3 @@ async def test_create_library_agent_uses_upsert():
|
||||
# Verify update branch restores soft-deleted/archived agents
|
||||
assert data["update"]["isDeleted"] is False
|
||||
assert data["update"]["isArchived"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_favorite_library_agents(mocker):
|
||||
mock_library_agents = [
|
||||
prisma.models.LibraryAgent(
|
||||
id="fav1",
|
||||
userId="test-user",
|
||||
agentGraphId="agent-fav",
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=True,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent-fav",
|
||||
version=1,
|
||||
name="Favorite Agent",
|
||||
description="My Favorite",
|
||||
userId="other-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={"agent-fav": 7}),
|
||||
)
|
||||
|
||||
result = await db.list_favorite_library_agents("test-user")
|
||||
|
||||
assert len(result.agents) == 1
|
||||
assert result.agents[0].id == "fav1"
|
||||
assert result.agents[0].name == "Favorite Agent"
|
||||
assert result.agents[0].graph_id == "agent-fav"
|
||||
assert result.pagination.total_items == 1
|
||||
assert result.pagination.total_pages == 1
|
||||
assert result.pagination.current_page == 1
|
||||
assert result.pagination.page_size == 50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_library_agents_skips_failed_agent(mocker):
|
||||
"""Agents that fail parsing should be skipped — covers the except branch."""
|
||||
mock_library_agents = [
|
||||
prisma.models.LibraryAgent(
|
||||
id="ua-bad",
|
||||
userId="test-user",
|
||||
agentGraphId="agent-bad",
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id="agent-bad",
|
||||
version=1,
|
||||
name="Bad Agent",
|
||||
description="",
|
||||
userId="other-user",
|
||||
isActive=True,
|
||||
createdAt=datetime.now(),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={}),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db",
|
||||
side_effect=Exception("parse error"),
|
||||
)
|
||||
|
||||
result = await db.list_library_agents("test-user")
|
||||
|
||||
assert len(result.agents) == 0
|
||||
assert result.pagination.total_items == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_execution_counts_empty_graph_ids():
|
||||
result = await db._fetch_execution_counts("user-1", [])
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_execution_counts_uses_group_by(mocker):
|
||||
mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma")
|
||||
mock_prisma.return_value.group_by = mocker.AsyncMock(
|
||||
return_value=[
|
||||
{"agentGraphId": "graph-1", "_count": {"_all": 5}},
|
||||
{"agentGraphId": "graph-2", "_count": {"_all": 2}},
|
||||
]
|
||||
)
|
||||
|
||||
result = await db._fetch_execution_counts(
|
||||
"user-1", ["graph-1", "graph-2", "graph-3"]
|
||||
)
|
||||
|
||||
assert result == {"graph-1": 5, "graph-2": 2}
|
||||
mock_prisma.return_value.group_by.assert_called_once_with(
|
||||
by=["agentGraphId"],
|
||||
where={
|
||||
"userId": "user-1",
|
||||
"agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
|
||||
"isDeleted": False,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
|
||||
@@ -214,14 +214,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
folder_name: str | None = None # Denormalized for display
|
||||
|
||||
recommended_schedule_cron: str | None = None
|
||||
is_scheduled: bool = pydantic.Field(
|
||||
default=False,
|
||||
description="Whether this agent has active execution schedules",
|
||||
)
|
||||
next_scheduled_run: str | None = pydantic.Field(
|
||||
default=None,
|
||||
description="ISO 8601 timestamp of the next scheduled run, if any",
|
||||
)
|
||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||
|
||||
@@ -231,8 +223,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
|
||||
store_listing: Optional[prisma.models.StoreListing] = None,
|
||||
profile: Optional[prisma.models.Profile] = None,
|
||||
execution_count_override: Optional[int] = None,
|
||||
schedule_info: Optional[dict[str, str]] = None,
|
||||
) -> "LibraryAgent":
|
||||
"""
|
||||
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
|
||||
@@ -268,14 +258,10 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
status = status_result.status
|
||||
new_output = status_result.new_output
|
||||
|
||||
execution_count = (
|
||||
execution_count_override
|
||||
if execution_count_override is not None
|
||||
else len(executions)
|
||||
)
|
||||
execution_count = len(executions)
|
||||
success_rate: float | None = None
|
||||
avg_correctness_score: float | None = None
|
||||
if executions and execution_count > 0:
|
||||
if execution_count > 0:
|
||||
success_count = sum(
|
||||
1
|
||||
for e in executions
|
||||
@@ -368,10 +354,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
folder_id=agent.folderId,
|
||||
folder_name=agent.Folder.name if agent.Folder else None,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info),
|
||||
next_scheduled_run=(
|
||||
schedule_info.get(agent.agentGraphId) if schedule_info else None
|
||||
),
|
||||
settings=_parse_settings(agent.settings),
|
||||
marketplace_listing=marketplace_listing_data,
|
||||
)
|
||||
|
||||
@@ -1,66 +1,11 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
from . import model as library_model
|
||||
|
||||
|
||||
def _make_library_agent(
|
||||
*,
|
||||
graph_id: str = "g1",
|
||||
executions: list | None = None,
|
||||
) -> prisma.models.LibraryAgent:
|
||||
return prisma.models.LibraryAgent(
|
||||
id="la1",
|
||||
userId="u1",
|
||||
agentGraphId=graph_id,
|
||||
settings="{}", # type: ignore
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=True,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
createdAt=datetime.datetime.now(),
|
||||
updatedAt=datetime.datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
AgentGraph=prisma.models.AgentGraph(
|
||||
id=graph_id,
|
||||
version=1,
|
||||
name="Agent",
|
||||
description="Desc",
|
||||
userId="u1",
|
||||
isActive=True,
|
||||
createdAt=datetime.datetime.now(),
|
||||
Executions=executions,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_from_db_execution_count_override_covers_success_rate():
|
||||
"""Covers execution_count_override is not None branch and executions/count > 0 block."""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
exec1 = prisma.models.AgentGraphExecution(
|
||||
id="exec-1",
|
||||
agentGraphId="g1",
|
||||
agentGraphVersion=1,
|
||||
userId="u1",
|
||||
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
isDeleted=False,
|
||||
isShared=False,
|
||||
)
|
||||
agent = _make_library_agent(executions=[exec1])
|
||||
|
||||
result = library_model.LibraryAgent.from_db(agent, execution_count_override=1)
|
||||
|
||||
assert result.execution_count == 1
|
||||
assert result.success_rate is not None
|
||||
assert result.success_rate == 100.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_preset_from_db(test_user_id: str):
|
||||
# Create mock DB agent
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,8 +5,7 @@ import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Literal, Sequence, cast, get_args
|
||||
from urllib.parse import urlparse
|
||||
from typing import Annotated, Any, Literal, Sequence, get_args
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
@@ -55,11 +54,8 @@ from backend.data.credit import (
|
||||
cancel_stripe_subscription,
|
||||
create_subscription_checkout,
|
||||
get_auto_top_up,
|
||||
get_proration_credit_cents,
|
||||
get_subscription_price_id,
|
||||
get_user_credit_model,
|
||||
handle_subscription_payment_failure,
|
||||
modify_stripe_subscription_for_tier,
|
||||
set_auto_top_up,
|
||||
set_subscription_tier,
|
||||
sync_subscription_from_stripe,
|
||||
@@ -703,72 +699,9 @@ class SubscriptionCheckoutResponse(BaseModel):
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
|
||||
monthly_cost: int # amount in cents (Stripe convention)
|
||||
tier_costs: dict[str, int] # tier name -> amount in cents
|
||||
proration_credit_cents: int # unused portion of current sub to convert on upgrade
|
||||
|
||||
|
||||
def _validate_checkout_redirect_url(url: str) -> bool:
|
||||
"""Return True if `url` matches the configured frontend origin.
|
||||
|
||||
Prevents open-redirect: attackers must not be able to supply arbitrary
|
||||
success_url/cancel_url that Stripe will redirect users to after checkout.
|
||||
|
||||
Pre-parse rejection rules (applied before urlparse):
|
||||
- Backslashes (``\\``) are normalised differently across parsers/browsers.
|
||||
- Control characters (U+0000–U+001F) are not valid in URLs and may confuse
|
||||
some URL-parsing implementations.
|
||||
"""
|
||||
# Reject characters that can confuse URL parsers before any parsing.
|
||||
if "\\" in url:
|
||||
return False
|
||||
if any(ord(c) < 0x20 for c in url):
|
||||
return False
|
||||
|
||||
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
if not allowed:
|
||||
# No configured origin — refuse to validate rather than allow arbitrary URLs.
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
allowed_parsed = urlparse(allowed)
|
||||
except ValueError:
|
||||
return False
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
|
||||
# can trick browsers into connecting to a different host than displayed.
|
||||
# ``@`` in query/fragment is harmless and must be allowed.
|
||||
if "@" in parsed.netloc:
|
||||
return False
|
||||
return (
|
||||
parsed.scheme == allowed_parsed.scheme
|
||||
and parsed.netloc == allowed_parsed.netloc
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
|
||||
async def _get_stripe_price_amount(price_id: str) -> int | None:
|
||||
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
|
||||
|
||||
Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
|
||||
of caching the ``None`` sentinel so the next request retries Stripe instead
|
||||
of being served a stale "no price" for the rest of the TTL window. Callers
|
||||
should treat ``None`` as an unknown price and fall back to 0.
|
||||
|
||||
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
|
||||
every GET /credits/subscription page load and reduces quota consumption.
|
||||
"""
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
return price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"Failed to retrieve Stripe price %s — returning None (not cached)",
|
||||
price_id,
|
||||
)
|
||||
return None
|
||||
tier: str
|
||||
monthly_cost: int
|
||||
tier_costs: dict[str, int]
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -789,26 +722,21 @@ async def get_subscription_status(
|
||||
*[get_subscription_price_id(t) for t in paid_tiers]
|
||||
)
|
||||
|
||||
tier_costs: dict[str, int] = {
|
||||
SubscriptionTier.FREE.value: 0,
|
||||
SubscriptionTier.ENTERPRISE.value: 0,
|
||||
}
|
||||
|
||||
async def _cost(pid: str | None) -> int:
|
||||
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
|
||||
|
||||
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
|
||||
for t, cost in zip(paid_tiers, costs):
|
||||
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
|
||||
for t, price_id in zip(paid_tiers, price_ids):
|
||||
cost = 0
|
||||
if price_id:
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
cost = price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
pass
|
||||
tier_costs[t.value] = cost
|
||||
|
||||
current_monthly_cost = tier_costs.get(tier.value, 0)
|
||||
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
|
||||
|
||||
return SubscriptionStatusResponse(
|
||||
tier=tier.value,
|
||||
monthly_cost=current_monthly_cost,
|
||||
monthly_cost=tier_costs.get(tier.value, 0),
|
||||
tier_costs=tier_costs,
|
||||
proration_credit_cents=proration_credit,
|
||||
)
|
||||
|
||||
|
||||
@@ -838,125 +766,24 @@ async def update_subscription_tier(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
|
||||
# keeps their tier for the time they already paid for. The DB tier is NOT
|
||||
# updated here when a subscription exists — the customer.subscription.deleted
|
||||
# webhook fires at period end and downgrades to FREE then.
|
||||
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
|
||||
# tier), cancel_stripe_subscription returns False and we update the DB tier
|
||||
# immediately since no webhook will ever fire.
|
||||
# When payment is disabled entirely, update the DB tier directly.
|
||||
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
|
||||
if tier == SubscriptionTier.FREE:
|
||||
if payment_enabled:
|
||||
try:
|
||||
had_subscription = await cancel_stripe_subscription(user_id)
|
||||
except stripe.StripeError as e:
|
||||
# Log full Stripe error server-side but return a generic message
|
||||
# to the client — raw Stripe errors can leak customer/sub IDs and
|
||||
# infrastructure config details.
|
||||
logger.exception(
|
||||
"Stripe error cancelling subscription for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
if not had_subscription:
|
||||
# No active Stripe subscription found — the user was on an
|
||||
# admin-granted tier. Update DB immediately since the
|
||||
# subscription.deleted webhook will never fire.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
await cancel_stripe_subscription(user_id)
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Paid tier changes require payment to be enabled — block self-service upgrades
|
||||
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
|
||||
# Beta users (payment not enabled) → update tier directly without Stripe.
|
||||
if not payment_enabled:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier}",
|
||||
)
|
||||
|
||||
# No-op short-circuit: if the user is already on the requested paid tier,
|
||||
# do NOT create a new Checkout Session. Without this guard, a duplicate
|
||||
# request (double-click, retried POST, stale page) creates a second
|
||||
# subscription for the same price; the user would be charged for both
|
||||
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
|
||||
# which only fires after the second charge has cleared.
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Paid→paid tier change: if the user already has a Stripe subscription,
|
||||
# modify it in-place with proration instead of creating a new Checkout
|
||||
# Session. This preserves remaining paid time and avoids double-charging.
|
||||
# The customer.subscription.updated webhook fires and updates the DB tier.
|
||||
current_tier = user.subscription_tier or SubscriptionTier.FREE
|
||||
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
|
||||
try:
|
||||
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||
if modified:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
# modify_stripe_subscription_for_tier returns False when no active
|
||||
# Stripe subscription exists — i.e. the user has an admin-granted
|
||||
# paid tier with no Stripe record. In that case, update the DB
|
||||
# tier directly (same as the FREE-downgrade path for admin-granted
|
||||
# users) rather than sending them through a new Checkout Session.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to update your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
# Paid upgrade from FREE → create Stripe Checkout Session.
|
||||
# Paid upgrade → create Stripe Checkout Session.
|
||||
if not request.success_url or not request.cancel_url:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url are required for paid tier upgrades",
|
||||
)
|
||||
# Open-redirect protection: both URLs must point to the configured frontend
|
||||
# origin, otherwise an attacker could use our Stripe integration as a
|
||||
# redirector to arbitrary phishing sites.
|
||||
#
|
||||
# Fail early with a clear 503 if the server is misconfigured (neither
|
||||
# frontend_base_url nor platform_base_url set), so operators get an
|
||||
# actionable error instead of the misleading "must match the platform
|
||||
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
|
||||
# produce when `allowed` is empty.
|
||||
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
|
||||
logger.error(
|
||||
"update_subscription_tier: neither frontend_base_url nor "
|
||||
"platform_base_url is configured; cannot validate checkout redirect URLs"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Payment redirect URLs cannot be validated: "
|
||||
"frontend_base_url or platform_base_url must be set on the server."
|
||||
),
|
||||
)
|
||||
if not _validate_checkout_redirect_url(
|
||||
request.success_url
|
||||
) or not _validate_checkout_redirect_url(request.cancel_url):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url must match the platform frontend origin",
|
||||
)
|
||||
try:
|
||||
url = await create_subscription_checkout(
|
||||
user_id=user_id,
|
||||
@@ -964,19 +791,8 @@ async def update_subscription_tier(
|
||||
success_url=request.success_url,
|
||||
cancel_url=request.cancel_url,
|
||||
)
|
||||
except ValueError as e:
|
||||
except (ValueError, stripe.StripeError) as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error creating checkout session for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to start checkout right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
return SubscriptionCheckoutResponse(url=url)
|
||||
|
||||
@@ -985,78 +801,44 @@ async def update_subscription_tier(
|
||||
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
||||
)
|
||||
async def stripe_webhook(request: Request):
|
||||
webhook_secret = settings.secrets.stripe_webhook_secret
|
||||
if not webhook_secret:
|
||||
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
|
||||
# signature over the same empty key). Reject all webhook calls when unconfigured.
|
||||
logger.error(
|
||||
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
|
||||
"rejecting request to prevent signature bypass"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Webhook not configured")
|
||||
|
||||
# Get the raw request body
|
||||
payload = await request.body()
|
||||
# Get the signature header
|
||||
sig_header = request.headers.get("stripe-signature")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
|
||||
except ValueError:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||
except stripe.SignatureVerificationError:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
# Defensive payload extraction. A malformed payload (missing/non-dict
|
||||
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
|
||||
# AFTER signature verification — which Stripe interprets as a delivery
|
||||
# failure and retries forever, while spamming Sentry with no useful info.
|
||||
# Acknowledge with 200 and a warning so Stripe stops retrying.
|
||||
event_type = event.get("type", "")
|
||||
event_data = event.get("data") or {}
|
||||
data_object = event_data.get("object") if isinstance(event_data, dict) else None
|
||||
if not isinstance(data_object, dict):
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing or non-dict data.object; ignoring",
|
||||
event_type,
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.secrets.stripe_webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
|
||||
)
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
|
||||
)
|
||||
return Response(status_code=200)
|
||||
|
||||
if event_type in (
|
||||
"checkout.session.completed",
|
||||
"checkout.session.async_payment_succeeded",
|
||||
if (
|
||||
event["type"] == "checkout.session.completed"
|
||||
or event["type"] == "checkout.session.async_payment_succeeded"
|
||||
):
|
||||
session_id = data_object.get("id")
|
||||
if not session_id:
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing data.object.id; ignoring", event_type
|
||||
)
|
||||
return Response(status_code=200)
|
||||
await UserCredit().fulfill_checkout(session_id=session_id)
|
||||
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
|
||||
|
||||
if event_type in (
|
||||
if event["type"] in (
|
||||
"customer.subscription.created",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
):
|
||||
await sync_subscription_from_stripe(data_object)
|
||||
await sync_subscription_from_stripe(event["data"]["object"])
|
||||
|
||||
if event_type == "invoice.payment_failed":
|
||||
await handle_subscription_payment_failure(data_object)
|
||||
if event["type"] == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(event["data"]["object"])
|
||||
|
||||
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
|
||||
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
|
||||
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
|
||||
# to satisfy the type checker without changing runtime behaviour.
|
||||
if event_type == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
|
||||
|
||||
if event_type == "refund.created" or event_type == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(
|
||||
cast("stripe.Refund | stripe.Dispute", data_object)
|
||||
)
|
||||
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(event["data"]["object"])
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@@ -106,6 +106,7 @@ class LlmModelMeta(EnumMeta):
|
||||
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> "LlmModel | None":
|
||||
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
|
||||
@@ -202,8 +203,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
GROK_4_20 = "x-ai/grok-4.20"
|
||||
GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
|
||||
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
|
||||
KIMI_K2 = "moonshotai/kimi-k2"
|
||||
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
|
||||
@@ -628,18 +627,6 @@ MODEL_METADATA = {
|
||||
LlmModel.GROK_4_1_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_4_20: ModelMetadata(
|
||||
"open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3
|
||||
),
|
||||
LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata(
|
||||
"open_router",
|
||||
2000000,
|
||||
100000,
|
||||
"Grok 4.20 Multi-Agent",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
3,
|
||||
),
|
||||
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
|
||||
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
|
||||
),
|
||||
@@ -1000,6 +987,7 @@ async def llm_call(
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
|
||||
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
|
||||
# Cache tool definitions alongside the system prompt.
|
||||
# Placing cache_control on the last tool caches all tool schemas as a
|
||||
|
||||
@@ -41,6 +41,7 @@ class PaginatedMessages(BaseModel):
|
||||
messages: list[ChatMessage]
|
||||
has_more: bool
|
||||
oldest_sequence: int | None
|
||||
newest_sequence: int | None
|
||||
session: ChatSessionInfo
|
||||
|
||||
|
||||
@@ -65,30 +66,48 @@ async def get_chat_messages_paginated(
|
||||
session_id: str,
|
||||
limit: int = 50,
|
||||
before_sequence: int | None = None,
|
||||
after_sequence: int | None = None,
|
||||
from_start: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> PaginatedMessages | None:
|
||||
"""Get paginated messages for a session, newest first.
|
||||
"""Get paginated messages for a session.
|
||||
|
||||
Verifies session existence (and ownership when ``user_id`` is provided)
|
||||
in parallel with the message query. Returns ``None`` when the session
|
||||
is not found or does not belong to the user.
|
||||
Three modes:
|
||||
|
||||
After fetching, a visibility guarantee ensures the page contains at least
|
||||
one user or assistant message. If the entire page is tool messages (which
|
||||
are hidden in the UI), it expands backward until a visible message is found
|
||||
so the chat never appears blank.
|
||||
- ``before_sequence`` set: backward pagination (DESC), returns messages
|
||||
with sequence < ``before_sequence``. Used for active sessions or manual
|
||||
backward navigation.
|
||||
- ``from_start=True`` or ``after_sequence`` set: forward pagination (ASC).
|
||||
Returns messages from sequence 0 (``from_start``) or after
|
||||
``after_sequence``. Used on initial load of completed sessions and for
|
||||
loading subsequent forward pages.
|
||||
- Both cursors ``None`` and ``from_start=False``: newest-first (DESC
|
||||
without filter). Used for active sessions on initial load.
|
||||
|
||||
Verifies session existence (and ownership when ``user_id`` is provided).
|
||||
Returns ``None`` when the session is not found or does not belong to the
|
||||
user.
|
||||
"""
|
||||
# Build session-existence / ownership check
|
||||
session_where: ChatSessionWhereInput = {"id": session_id}
|
||||
if user_id is not None:
|
||||
session_where["userId"] = user_id
|
||||
|
||||
# Build message include — fetch paginated messages in the same query
|
||||
forward = from_start or after_sequence is not None
|
||||
|
||||
# Build message include — fetch paginated messages in the same query.
|
||||
# Note: when both from_start=True and after_sequence is not None, the
|
||||
# after_sequence filter takes precedence (the elif branch below is skipped).
|
||||
# This combination is not reachable via the HTTP route (mutual exclusion is
|
||||
# enforced there), so we rely on the documented priority here without an
|
||||
# additional assertion.
|
||||
msg_include: FindManyChatMessageArgsFromChatSession = {
|
||||
"order_by": {"sequence": "desc"},
|
||||
"order_by": {"sequence": "asc" if forward else "desc"},
|
||||
"take": limit + 1,
|
||||
}
|
||||
if before_sequence is not None:
|
||||
if after_sequence is not None:
|
||||
msg_include["where"] = {"sequence": {"gt": after_sequence}}
|
||||
elif before_sequence is not None:
|
||||
msg_include["where"] = {"sequence": {"lt": before_sequence}}
|
||||
|
||||
# Single query: session existence/ownership + paginated messages
|
||||
@@ -106,129 +125,95 @@ async def get_chat_messages_paginated(
|
||||
has_more = len(results) > limit
|
||||
results = results[:limit]
|
||||
|
||||
# Reverse to ascending order
|
||||
results.reverse()
|
||||
if not forward:
|
||||
# Backward mode: DB returned DESC; reverse to ascending order.
|
||||
results.reverse()
|
||||
|
||||
# Tool-call boundary fix: if the oldest message is a tool message,
|
||||
# expand backward to include the preceding assistant message that
|
||||
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
|
||||
# can pair them correctly.
|
||||
if results and results[0].role == "tool":
|
||||
results, has_more = await _expand_tool_boundary(
|
||||
session_id, results, has_more, user_id
|
||||
)
|
||||
# Tool-call boundary fix: if the oldest message is a tool message,
|
||||
# expand backward to include the preceding assistant message that
|
||||
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
|
||||
# can pair them correctly.
|
||||
if results and results[0].role == "tool":
|
||||
boundary_where: ChatMessageWhereInput = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
boundary_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=boundary_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_BOUNDARY_SCAN_LIMIT,
|
||||
)
|
||||
# Find the first non-tool message (should be the assistant)
|
||||
boundary_msgs = []
|
||||
found_owner = False
|
||||
for msg in extra:
|
||||
boundary_msgs.append(msg)
|
||||
if msg.role != "tool":
|
||||
found_owner = True
|
||||
break
|
||||
boundary_msgs.reverse()
|
||||
if not found_owner:
|
||||
logger.warning(
|
||||
"Boundary expansion did not find owning assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
if boundary_msgs:
|
||||
results = boundary_msgs + results
|
||||
# Only mark has_more if the expanded boundary isn't the
|
||||
# very start of the conversation (sequence 0).
|
||||
if boundary_msgs[0].sequence > 0:
|
||||
has_more = True
|
||||
else:
|
||||
# Forward mode: DB returned ASC.
|
||||
# Tool-call tail boundary fix: if the last message in this page is a
|
||||
# tool message, the NEXT forward page would start after it and begin
|
||||
# mid-tool-group — the owning assistant message is on this page but
|
||||
# the following tool results are on the next page.
|
||||
# Trim the current page so it ends on the owning assistant message,
|
||||
# which keeps tool groups intact across page boundaries.
|
||||
if results and results[-1].role == "tool":
|
||||
# Walk backward through results to find the last non-tool message.
|
||||
trim_idx = len(results) - 1
|
||||
while trim_idx >= 0 and results[trim_idx].role == "tool":
|
||||
trim_idx -= 1
|
||||
|
||||
# Visibility guarantee: if the entire page has no user/assistant messages
|
||||
# (all tool messages), the chat would appear blank. Expand backward
|
||||
# until we find at least one visible message.
|
||||
if results and not any(m.role in ("user", "assistant") for m in results):
|
||||
results, has_more = await _expand_for_visibility(
|
||||
session_id, results, has_more, user_id
|
||||
)
|
||||
if trim_idx >= 0:
|
||||
# Trim results so the page ends at the owning assistant.
|
||||
# Mark has_more=True so the client knows to fetch the rest.
|
||||
results = results[: trim_idx + 1]
|
||||
has_more = True
|
||||
else:
|
||||
# Entire page is tool messages with no visible owner — log and
|
||||
# keep as-is so the caller is not stuck with an empty page.
|
||||
logger.warning(
|
||||
"Forward tail boundary: entire page is tool messages "
|
||||
"for session=%s, no owning assistant found (%d msgs)",
|
||||
session_id,
|
||||
len(results),
|
||||
)
|
||||
|
||||
messages = [ChatMessage.from_db(m) for m in results]
|
||||
oldest_sequence = messages[0].sequence if messages else None
|
||||
# newest_sequence is only meaningful in forward mode; in backward mode it
|
||||
# points to the last message of the page (not the session's newest message)
|
||||
# which is not a valid forward cursor. Return None in backward mode so
|
||||
# clients don't accidentally use it as one.
|
||||
newest_sequence = messages[-1].sequence if (messages and forward) else None
|
||||
|
||||
return PaginatedMessages(
|
||||
messages=messages,
|
||||
has_more=has_more,
|
||||
oldest_sequence=oldest_sequence,
|
||||
newest_sequence=newest_sequence,
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
|
||||
async def _expand_tool_boundary(
|
||||
session_id: str,
|
||||
results: list[Any],
|
||||
has_more: bool,
|
||||
user_id: str | None,
|
||||
) -> tuple[list[Any], bool]:
|
||||
"""Expand backward from the oldest message to include the owning assistant
|
||||
message when the page starts mid-tool-group."""
|
||||
boundary_where: ChatMessageWhereInput = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
boundary_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=boundary_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_BOUNDARY_SCAN_LIMIT,
|
||||
)
|
||||
# Find the first non-tool message (should be the assistant)
|
||||
boundary_msgs = []
|
||||
found_owner = False
|
||||
for msg in extra:
|
||||
boundary_msgs.append(msg)
|
||||
if msg.role != "tool":
|
||||
found_owner = True
|
||||
break
|
||||
boundary_msgs.reverse()
|
||||
if not found_owner:
|
||||
logger.warning(
|
||||
"Boundary expansion did not find owning assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
if boundary_msgs:
|
||||
results = boundary_msgs + results
|
||||
has_more = boundary_msgs[0].sequence > 0
|
||||
return results, has_more
|
||||
|
||||
|
||||
_VISIBILITY_EXPAND_LIMIT = 200
|
||||
|
||||
|
||||
async def _expand_for_visibility(
|
||||
session_id: str,
|
||||
results: list[Any],
|
||||
has_more: bool,
|
||||
user_id: str | None,
|
||||
) -> tuple[list[Any], bool]:
|
||||
"""Expand backward until the page contains at least one user or assistant
|
||||
message, so the chat is never blank."""
|
||||
expand_where: ChatMessageWhereInput = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
expand_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=expand_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_VISIBILITY_EXPAND_LIMIT,
|
||||
)
|
||||
if not extra:
|
||||
return results, has_more
|
||||
|
||||
# Collect messages until we find a visible one (user/assistant)
|
||||
prepend = []
|
||||
found_visible = False
|
||||
for msg in extra:
|
||||
prepend.append(msg)
|
||||
if msg.role in ("user", "assistant"):
|
||||
found_visible = True
|
||||
break
|
||||
|
||||
if not found_visible:
|
||||
logger.warning(
|
||||
"Visibility expansion did not find any user/assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
|
||||
prepend.reverse()
|
||||
if prepend:
|
||||
results = prepend + results
|
||||
has_more = prepend[0].sequence > 0
|
||||
return results, has_more
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
|
||||
@@ -175,136 +175,181 @@ async def test_no_where_on_messages_without_before_sequence(
|
||||
assert "where" not in include["Messages"]
|
||||
|
||||
|
||||
# ---------- Visibility guarantee ----------
|
||||
# ---------- Forward pagination (from_start / after_sequence) ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visibility_expands_when_all_tool_messages(
|
||||
async def test_from_start_uses_asc_order_no_where(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""When the entire page is tool messages, expand backward to find
|
||||
at least one visible (user/assistant) message so the chat isn't blank."""
|
||||
find_first, find_many = mock_db
|
||||
# Newest 3 messages are all tool messages (DESC → reversed to ASC)
|
||||
"""from_start=True queries messages in ASC order with no where filter."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(12, role="tool"),
|
||||
_make_msg(11, role="tool"),
|
||||
_make_msg(10, role="tool"),
|
||||
],
|
||||
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
|
||||
)
|
||||
# Boundary expansion finds the owning assistant first (boundary fix),
|
||||
# then visibility expansion finds a user message further back
|
||||
find_many.side_effect = [
|
||||
# First call: boundary fix (oldest msg is tool → find owner)
|
||||
[_make_msg(9, role="tool"), _make_msg(8, role="tool")],
|
||||
# Second call: visibility expansion (still all tool → find visible)
|
||||
[_make_msg(7, role="tool"), _make_msg(6, role="assistant")],
|
||||
]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=3)
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
# Should include the expanded messages + original tool messages
|
||||
roles = [m.role for m in page.messages]
|
||||
assert "assistant" in roles or "user" in roles
|
||||
assert page.has_more is True
|
||||
call_kwargs = find_first.call_args
|
||||
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
|
||||
assert include["Messages"]["order_by"] == {"sequence": "asc"}
|
||||
assert "where" not in include["Messages"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_visibility_expansion_when_visible_messages_present(
|
||||
async def test_from_start_returns_messages_ascending(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""No visibility expansion needed when page already has visible messages."""
|
||||
find_first, find_many = mock_db
|
||||
# Page has an assistant message among tool messages
|
||||
"""from_start=True returns messages in ascending sequence order."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(5, role="tool"),
|
||||
_make_msg(4, role="assistant"),
|
||||
_make_msg(3, role="user"),
|
||||
],
|
||||
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=3)
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
# Boundary expansion might fire (oldest is tool), but NOT visibility
|
||||
assert [m.sequence for m in page.messages][0] <= 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visibility_no_expansion_when_no_earlier_messages(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""When the page is all tool messages but there are no earlier messages
|
||||
in the DB, visibility expansion returns early without changes."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(1, role="tool"), _make_msg(0, role="tool")],
|
||||
)
|
||||
# Boundary expansion: no earlier messages
|
||||
# Visibility expansion: no earlier messages
|
||||
find_many.side_effect = [[], []]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
|
||||
|
||||
assert page is not None
|
||||
assert all(m.role == "tool" for m in page.messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visibility_expansion_reaches_seq_zero(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""When visibility expansion finds a visible message at sequence 0,
|
||||
has_more should be False."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
|
||||
)
|
||||
find_many.side_effect = [
|
||||
# Boundary expansion
|
||||
[_make_msg(3, role="tool")],
|
||||
# Visibility expansion — finds user at seq 0
|
||||
[
|
||||
_make_msg(2, role="tool"),
|
||||
_make_msg(1, role="tool"),
|
||||
_make_msg(0, role="user"),
|
||||
],
|
||||
]
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
|
||||
|
||||
assert page is not None
|
||||
assert page.messages[0].role == "user"
|
||||
assert page.messages[0].sequence == 0
|
||||
assert [m.sequence for m in page.messages] == [0, 1, 2]
|
||||
assert page.oldest_sequence == 0
|
||||
assert page.newest_sequence == 2
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visibility_expansion_with_user_id(
|
||||
async def test_from_start_has_more_when_results_exceed_limit(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Visibility expansion passes user_id filter to the boundary query."""
|
||||
"""from_start=True sets has_more when DB returns more than limit items."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=2, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
assert page.has_more is True
|
||||
assert [m.sequence for m in page.messages] == [0, 1]
|
||||
assert page.newest_sequence == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_sequence_uses_gt_filter_asc_order(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""after_sequence adds a sequence > N where clause and uses ASC order."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(11), _make_msg(12)],
|
||||
)
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
|
||||
|
||||
call_kwargs = find_first.call_args
|
||||
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
|
||||
assert include["Messages"]["order_by"] == {"sequence": "asc"}
|
||||
assert include["Messages"]["where"] == {"sequence": {"gt": 10}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_sequence_returns_messages_in_order(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""after_sequence returns only messages with sequence > cursor, ascending."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(11), _make_msg(12), _make_msg(13)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [11, 12, 13]
|
||||
assert page.oldest_sequence == 11
|
||||
assert page.newest_sequence == 13
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_newest_sequence_none_for_backward_mode(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""newest_sequence is None in backward mode — it is not a valid forward cursor."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(5), _make_msg(4), _make_msg(3)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
|
||||
|
||||
assert page is not None
|
||||
assert page.newest_sequence is None
|
||||
assert page.oldest_sequence == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_mode_no_boundary_expansion(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Forward pagination never triggers backward boundary expansion."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(10, role="tool")],
|
||||
messages=[_make_msg(0, role="tool"), _make_msg(1, role="tool")],
|
||||
)
|
||||
find_many.side_effect = [
|
||||
# Boundary expansion
|
||||
[_make_msg(9, role="tool")],
|
||||
# Visibility expansion
|
||||
[_make_msg(8, role="assistant")],
|
||||
]
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=1, user_id="user-abc")
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
|
||||
|
||||
# Both find_many calls should include the user_id session filter
|
||||
for call in find_many.call_args_list:
|
||||
where = call.kwargs.get("where") or call[1].get("where")
|
||||
assert "Session" in where
|
||||
assert where["Session"] == {"is": {"userId": "user-abc"}}
|
||||
assert find_many.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_tail_boundary_trims_trailing_tool_messages(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Forward pages that end with tool messages are trimmed to the owning
|
||||
assistant so the next after_sequence page doesn't start mid-tool-group."""
|
||||
find_first, _ = mock_db
|
||||
# DB returns 4 messages ASC: assistant at 0, tool at 1, tool at 2, tool at 3
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(0, role="assistant"),
|
||||
_make_msg(1, role="tool"),
|
||||
_make_msg(2, role="tool"),
|
||||
_make_msg(3, role="tool"),
|
||||
],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
# Page should be trimmed to end at the assistant message
|
||||
assert [m.sequence for m in page.messages] == [0]
|
||||
assert page.newest_sequence == 0
|
||||
# has_more must be True so the client fetches the tool messages on next page
|
||||
assert page.has_more is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_tail_boundary_no_trim_when_last_not_tool(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Forward pages that end with a non-tool message are not trimmed."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(0, role="user"),
|
||||
_make_msg(1, role="assistant"),
|
||||
_make_msg(2, role="tool"),
|
||||
_make_msg(3, role="assistant"),
|
||||
],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [0, 1, 2, 3]
|
||||
assert page.newest_sequence == 3
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -461,8 +506,7 @@ async def test_boundary_expansion_warns_when_no_owner_found(
|
||||
|
||||
with patch("backend.copilot.db.logger") as mock_logger:
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
|
||||
# Two warnings: boundary expansion + visibility expansion (all tool msgs)
|
||||
assert mock_logger.warning.call_count == 2
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
assert page is not None
|
||||
assert page.messages[0].role == "tool"
|
||||
|
||||
71
autogpt_platform/backend/backend/copilot/message_dedup.py
Normal file
71
autogpt_platform/backend/backend/copilot/message_dedup.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Per-request idempotency lock for the /stream endpoint.
|
||||
|
||||
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
|
||||
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
|
||||
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
|
||||
``None`` when the key already exists (duplicate request).
|
||||
2. ``release()`` — deletes the key. Must be called on turn completion or turn
|
||||
error so the next legitimate send is never blocked.
|
||||
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
|
||||
the backend turn is still running, and releasing would reopen the duplicate
|
||||
window for infra-level retries. The 30 s TTL is the safety net.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_KEY_PREFIX = "chat:msg_dedup"
|
||||
_TTL_SECONDS = 30
|
||||
|
||||
|
||||
class _DedupLock:
|
||||
def __init__(self, key: str, redis) -> None:
|
||||
self._key = key
|
||||
self._redis = redis
|
||||
|
||||
async def release(self) -> None:
|
||||
"""Best-effort key deletion. The TTL handles failures silently."""
|
||||
try:
|
||||
await self._redis.delete(self._key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def acquire_dedup_lock(
|
||||
session_id: str,
|
||||
message: str | None,
|
||||
file_ids: list[str] | None,
|
||||
) -> _DedupLock | None:
|
||||
"""Acquire the idempotency lock for this (session, message, files) tuple.
|
||||
|
||||
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
|
||||
Returns ``None`` when a duplicate is detected (lock already held).
|
||||
Returns ``None`` when there is nothing to deduplicate (no message, no files).
|
||||
"""
|
||||
if not message and not file_ids:
|
||||
return None
|
||||
|
||||
sorted_ids = ":".join(sorted(file_ids or []))
|
||||
content_hash = hashlib.sha256(
|
||||
f"{session_id}:{message or ''}:{sorted_ids}".encode()
|
||||
).hexdigest()[:16]
|
||||
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
|
||||
|
||||
redis = await get_redis_async()
|
||||
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"[STREAM] Duplicate user message blocked for session {session_id}, "
|
||||
f"hash={content_hash} — returning empty SSE",
|
||||
)
|
||||
return None
|
||||
|
||||
return _DedupLock(key, redis)
|
||||
@@ -0,0 +1,94 @@
|
||||
"""Unit tests for backend.copilot.message_dedup."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
|
||||
|
||||
|
||||
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=set_returns)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
return mock_redis
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_none_when_no_message_no_files(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Nothing to deduplicate — no Redis call made, None returned."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
result = await acquire_dedup_lock("sess-1", None, None)
|
||||
assert result is None
|
||||
mock_redis.set.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_lock_on_first_request(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""First request acquires the lock and returns a _DedupLock."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
mock_redis.set.assert_called_once()
|
||||
key_arg = mock_redis.set.call_args.args[0]
|
||||
assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_none_on_duplicate(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Duplicate request (NX fails) returns None to signal the caller."""
|
||||
_patch_redis(mocker, set_returns=None)
|
||||
result = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_key_stable_across_file_order(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""File IDs are sorted before hashing so order doesn't affect the key."""
|
||||
mock_redis_1 = _patch_redis(mocker, set_returns=True)
|
||||
await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
|
||||
key_ab = mock_redis_1.set.call_args.args[0]
|
||||
|
||||
mock_redis_2 = _patch_redis(mocker, set_returns=True)
|
||||
await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
|
||||
key_ba = mock_redis_2.set.call_args.args[0]
|
||||
|
||||
assert key_ab == key_ba
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_deletes_key(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""release() calls Redis delete exactly once."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
await lock.release()
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_swallows_redis_error(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""release() must not raise even when Redis delete fails."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
await lock.release() # must not raise
|
||||
mock_redis.delete.assert_called_once()
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, AsyncIterator, Self, cast
|
||||
from typing import Any, Self, cast
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
@@ -521,7 +522,10 @@ async def upsert_chat_session(
|
||||
callers are aware of the persistence failure.
|
||||
RedisError: If the cache write fails (after successful DB write).
|
||||
"""
|
||||
async with _get_session_lock(session.session_id) as _:
|
||||
# Acquire session-specific lock to prevent concurrent upserts
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
# Always query DB for existing message count to ensure consistency
|
||||
existing_message_count = await chat_db().get_next_sequence(session.session_id)
|
||||
|
||||
@@ -647,50 +651,20 @@ async def _save_session_to_db(
|
||||
msg.sequence = existing_message_count + i
|
||||
|
||||
|
||||
async def append_and_save_message(
|
||||
session_id: str, message: ChatMessage
|
||||
) -> ChatSession | None:
|
||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||
"""Atomically append a message to a session and persist it.
|
||||
|
||||
Returns the updated session, or None if the message was detected as a
|
||||
duplicate (idempotency guard). Callers must check for None and skip any
|
||||
downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected.
|
||||
|
||||
Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas.
|
||||
The idempotency check below provides a last-resort guard when the lock degrades.
|
||||
Acquires the session lock, re-fetches the latest session state,
|
||||
appends the message, and saves — preventing message loss when
|
||||
concurrent requests modify the same session.
|
||||
"""
|
||||
async with _get_session_lock(session_id) as lock_acquired:
|
||||
# When the lock degraded (Redis down or 2s timeout), bypass cache for
|
||||
# the idempotency check. Stale cache could let two concurrent writers
|
||||
# both see the old state, pass the check, and write the same message.
|
||||
if lock_acquired:
|
||||
session = await get_chat_session(session_id)
|
||||
else:
|
||||
session = await _get_session_from_db(session_id)
|
||||
lock = await _get_session_lock(session_id)
|
||||
|
||||
async with lock:
|
||||
session = await get_chat_session(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
# Idempotency: skip if the trailing block of same-role messages already
|
||||
# contains this content. Uses is_message_duplicate which checks all
|
||||
# consecutive trailing messages of the same role, not just [-1].
|
||||
#
|
||||
# This collapses infra/nginx retries whether they land on the same pod
|
||||
# (serialised by the Redis lock) or a different pod.
|
||||
#
|
||||
# Legit same-text messages are distinguished by the assistant turn
|
||||
# between them: if the user said "yes", got a response, and says
|
||||
# "yes" again, session.messages[-1] is the assistant reply, so the
|
||||
# role check fails and the second message goes through normally.
|
||||
#
|
||||
# Edge case: if a turn dies without writing any assistant message,
|
||||
# the user's next send of the same text is blocked here permanently.
|
||||
# The fix is to ensure failed turns always write an error/timeout
|
||||
# assistant message so the session always ends on an assistant turn.
|
||||
if message.content is not None and is_message_duplicate(
|
||||
session.messages, message.role, message.content
|
||||
):
|
||||
return None # duplicate — caller should skip enqueue
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||
|
||||
@@ -705,9 +679,6 @@ async def append_and_save_message(
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
||||
# Invalidate the stale entry so future reads fall back to DB,
|
||||
# preventing a retry from bypassing the idempotency check above.
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
return session
|
||||
|
||||
@@ -793,6 +764,10 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
# Shut down any local browser daemon for this session (best-effort).
|
||||
# Inline import required: all tool modules import ChatSession from this
|
||||
# module, so any top-level import from tools.* would create a cycle.
|
||||
@@ -857,38 +832,25 @@ async def update_session_title(
|
||||
|
||||
# ==================== Chat session locks ==================== #
|
||||
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_session_lock(session_id: str) -> AsyncIterator[bool]:
|
||||
"""Distributed Redis lock for a session, usable as an async context manager.
|
||||
|
||||
Yields True if the lock was acquired, False if it timed out or Redis was
|
||||
unavailable. Callers should treat False as a degraded mode and prefer fresh
|
||||
DB reads over cache to avoid acting on stale state.
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition
|
||||
is atomic and release is owner-verified. Blocks up to 2s for a concurrent
|
||||
writer to finish; the 10s TTL ensures a dead pod never holds the lock forever.
|
||||
This was originally added to solve the specific problem of race conditions between
|
||||
the session title thread and the conversation thread, which always occurs on the
|
||||
same instance as we prevent rapid request sends on the frontend.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks. Explicit cleanup also occurs
|
||||
in `delete_chat_session()`.
|
||||
"""
|
||||
_lock_key = f"copilot:session_lock:{session_id}"
|
||||
lock = None
|
||||
acquired = False
|
||||
try:
|
||||
_redis = await get_redis_async()
|
||||
lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2)
|
||||
acquired = await lock.acquire(blocking=True)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
"Could not acquire session lock for %s within 2s", session_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Redis unavailable for session lock on %s: %s", session_id, e)
|
||||
|
||||
try:
|
||||
yield acquired
|
||||
finally:
|
||||
if acquired and lock is not None:
|
||||
try:
|
||||
await lock.release()
|
||||
except Exception:
|
||||
pass # TTL will expire the key
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
@@ -11,13 +11,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
append_and_save_message,
|
||||
get_chat_session,
|
||||
is_message_duplicate,
|
||||
maybe_append_user_message,
|
||||
@@ -576,345 +574,3 @@ def test_maybe_append_assistant_skips_duplicate():
|
||||
result = maybe_append_user_message(session, "dup", is_user_message=False)
|
||||
assert result is False
|
||||
assert len(session.messages) == 2
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# append_and_save_message #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession:
|
||||
s = ChatSession.new(user_id="u1", dry_run=False)
|
||||
s.messages = list(msgs)
|
||||
return s
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_returns_none_for_duplicate(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""append_and_save_message returns None when the trailing message is a duplicate."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="user", content="hello"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
|
||||
result = await append_and_save_message(
|
||||
session.session_id, ChatMessage(role="user", content="hello")
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_appends_new_message(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""append_and_save_message appends a non-duplicate message and returns the session."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=2)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="second message")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
assert result is not None
|
||||
assert result.messages[-1].content == "second message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_raises_when_session_not_found(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""append_and_save_message raises ValueError when the session does not exist."""
|
||||
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await append_and_save_message(
|
||||
"missing-session-id", ChatMessage(role="user", content="hi")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_uses_db_when_lock_degraded(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When the Redis lock times out (acquired=False), the fallback reads from DB."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=False)
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mock_get_from_db = mocker.patch(
|
||||
"backend.copilot.model._get_session_from_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="new msg")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
# DB path was used (not cache-first)
|
||||
mock_get_from_db.assert_called_once_with(session.session_id)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_raises_database_error_on_save_failure(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When _save_session_to_db fails, append_and_save_message raises DatabaseError."""
|
||||
from backend.util.exceptions import DatabaseError
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
side_effect=RuntimeError("db down"),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(DatabaseError):
|
||||
await append_and_save_message(
|
||||
session.session_id, ChatMessage(role="user", content="new msg")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_invalidates_cache_on_cache_failure(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When cache_chat_session fails, invalidate_session_cache is called to avoid stale reads."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock()
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
side_effect=RuntimeError("redis write failed"),
|
||||
)
|
||||
mock_invalidate = mocker.patch(
|
||||
"backend.copilot.model.invalidate_session_cache",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
result = await append_and_save_message(
|
||||
session.session_id, ChatMessage(role="user", content="new msg")
|
||||
)
|
||||
# DB write succeeded, cache invalidation was called
|
||||
mock_invalidate.assert_called_once_with(session.session_id)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_uses_db_when_redis_unavailable(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""When get_redis_async raises, _get_session_lock yields False (degraded) and DB is read."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
side_effect=ConnectionError("redis down"),
|
||||
)
|
||||
mock_get_from_db = mocker.patch(
|
||||
"backend.copilot.model._get_session_from_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="new msg")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
mock_get_from_db.assert_called_once_with(session.session_id)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_append_and_save_message_lock_release_failure_is_ignored(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""If lock.release() raises, the exception is swallowed (TTL will clean up)."""
|
||||
|
||||
session = _make_session_with_messages(
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
)
|
||||
mock_redis_lock = mocker.AsyncMock()
|
||||
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
|
||||
mock_redis_lock.release = mocker.AsyncMock(
|
||||
side_effect=RuntimeError("release failed")
|
||||
)
|
||||
mock_redis_client = mocker.MagicMock()
|
||||
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_redis_async",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=mock_redis_client,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.get_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
return_value=session,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model._save_session_to_db",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.chat_db",
|
||||
return_value=mocker.MagicMock(
|
||||
get_next_sequence=mocker.AsyncMock(return_value=1)
|
||||
),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.copilot.model.cache_chat_session",
|
||||
new_callable=mocker.AsyncMock,
|
||||
)
|
||||
|
||||
new_msg = ChatMessage(role="user", content="new msg")
|
||||
result = await append_and_save_message(session.session_id, new_msg)
|
||||
assert result is not None
|
||||
|
||||
@@ -125,12 +125,7 @@ config = ChatConfig()
|
||||
|
||||
|
||||
class _SystemPromptPreset(SystemPromptPreset, total=False):
|
||||
"""Extends :class:`SystemPromptPreset` with ``exclude_dynamic_sections``.
|
||||
|
||||
The field was added to the upstream TypedDict in claude-agent-sdk 0.1.59.
|
||||
Until the package is pinned to that version we declare it locally so Pyright
|
||||
accepts the kwarg without a ``# type: ignore`` comment.
|
||||
"""
|
||||
"""Extends SystemPromptPreset with fields added in claude-agent-sdk 0.1.59."""
|
||||
|
||||
exclude_dynamic_sections: NotRequired[bool]
|
||||
|
||||
@@ -898,7 +893,7 @@ def _write_cli_session_to_disk(
|
||||
return False
|
||||
|
||||
|
||||
def read_cli_session_from_disk(
|
||||
def _read_cli_session_from_disk(
|
||||
sdk_cwd: str,
|
||||
session_id: str,
|
||||
log_prefix: str,
|
||||
@@ -978,7 +973,7 @@ def read_cli_session_from_disk(
|
||||
return stripped_bytes
|
||||
|
||||
|
||||
def process_cli_restore(
|
||||
def _process_cli_restore(
|
||||
cli_restore: TranscriptDownload,
|
||||
sdk_cwd: str,
|
||||
session_id: str,
|
||||
@@ -2494,7 +2489,9 @@ async def _restore_cli_session_for_turn(
|
||||
# session path, so we validate BEFORE any disk write.
|
||||
stripped = ""
|
||||
if cli_restore is not None and sdk_cwd:
|
||||
stripped, ok = process_cli_restore(cli_restore, sdk_cwd, session_id, log_prefix)
|
||||
stripped, ok = _process_cli_restore(
|
||||
cli_restore, sdk_cwd, session_id, log_prefix
|
||||
)
|
||||
if not ok:
|
||||
result.transcript_covers_prefix = False
|
||||
cli_restore = None
|
||||
@@ -3639,7 +3636,7 @@ async def stream_chat_completion_sdk(
|
||||
# this turn ran without --resume (restore failed or first T2+ on a new
|
||||
# pod), the T1 session file at the expected path may still be present
|
||||
# and should be re-uploaded so the next turn can resume from it.
|
||||
# read_cli_session_from_disk returns None when the file is absent, so
|
||||
# _read_cli_session_from_disk returns None when the file is absent, so
|
||||
# this is always safe.
|
||||
#
|
||||
# Intentionally NOT gated on skip_transcript_upload: that flag is set
|
||||
@@ -3668,7 +3665,7 @@ async def stream_chat_completion_sdk(
|
||||
try:
|
||||
# Read the CLI's native session file from disk (written by the CLI
|
||||
# after the turn), then upload the bytes to GCS.
|
||||
_cli_content = read_cli_session_from_disk(
|
||||
_cli_content = _read_cli_session_from_disk(
|
||||
sdk_cwd, session_id, log_prefix
|
||||
)
|
||||
if _cli_content:
|
||||
|
||||
@@ -1371,7 +1371,7 @@ class TestStripStaleThinkingBlocks:
|
||||
|
||||
|
||||
class TestProcessCliRestore:
|
||||
"""``process_cli_restore`` validates, strips, and writes CLI session to disk."""
|
||||
"""``_process_cli_restore`` validates, strips, and writes CLI session to disk."""
|
||||
|
||||
def test_writes_stripped_bytes_not_raw(self, tmp_path):
|
||||
"""Stripped bytes (not raw bytes) must be written to disk for --resume."""
|
||||
@@ -1380,7 +1380,7 @@ class TestProcessCliRestore:
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.copilot.sdk.service import process_cli_restore
|
||||
from backend.copilot.sdk.service import _process_cli_restore
|
||||
from backend.copilot.transcript import TranscriptDownload
|
||||
|
||||
session_id = "12345678-0000-0000-0000-abcdef000001"
|
||||
@@ -1406,7 +1406,7 @@ class TestProcessCliRestore:
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
):
|
||||
stripped_str, ok = process_cli_restore(
|
||||
stripped_str, ok = _process_cli_restore(
|
||||
restore, sdk_cwd, session_id, "[Test]"
|
||||
)
|
||||
|
||||
@@ -1433,7 +1433,7 @@ class TestProcessCliRestore:
|
||||
|
||||
def test_invalid_content_returns_false(self):
|
||||
"""Content that fails validation after strip returns (empty, False)."""
|
||||
from backend.copilot.sdk.service import process_cli_restore
|
||||
from backend.copilot.sdk.service import _process_cli_restore
|
||||
from backend.copilot.transcript import TranscriptDownload
|
||||
|
||||
# A single progress-only entry — stripped result will be empty/invalid
|
||||
@@ -1442,7 +1442,7 @@ class TestProcessCliRestore:
|
||||
content=raw_content.encode("utf-8"), message_count=1, mode="sdk"
|
||||
)
|
||||
|
||||
stripped_str, ok = process_cli_restore(
|
||||
stripped_str, ok = _process_cli_restore(
|
||||
restore,
|
||||
"/tmp/nonexistent-sdk-cwd",
|
||||
"12345678-0000-0000-0000-000000000099",
|
||||
@@ -1454,7 +1454,7 @@ class TestProcessCliRestore:
|
||||
|
||||
|
||||
class TestReadCliSessionFromDisk:
|
||||
"""``read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
|
||||
"""``_read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
|
||||
|
||||
def _build_session_file(self, tmp_path, session_id: str):
|
||||
"""Build the session file path inside tmp_path using the same encoding as cli_session_path."""
|
||||
@@ -1472,7 +1472,7 @@ class TestReadCliSessionFromDisk:
|
||||
"""Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.copilot.sdk.service import read_cli_session_from_disk
|
||||
from backend.copilot.sdk.service import _read_cli_session_from_disk
|
||||
|
||||
session_id = "12345678-0000-0000-0000-aabbccdd0001"
|
||||
projects_base_dir = str(tmp_path)
|
||||
@@ -1491,7 +1491,7 @@ class TestReadCliSessionFromDisk:
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
):
|
||||
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
|
||||
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
|
||||
|
||||
# UnicodeDecodeError path returns the raw bytes (upload-raw fallback)
|
||||
assert result == b"\xff\xfe invalid utf-8\n"
|
||||
@@ -1500,7 +1500,7 @@ class TestReadCliSessionFromDisk:
|
||||
"""OSError on write-back returns stripped bytes for GCS upload (not raw)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.copilot.sdk.service import read_cli_session_from_disk
|
||||
from backend.copilot.sdk.service import _read_cli_session_from_disk
|
||||
|
||||
session_id = "12345678-0000-0000-0000-aabbccdd0002"
|
||||
projects_base_dir = str(tmp_path)
|
||||
@@ -1527,7 +1527,7 @@ class TestReadCliSessionFromDisk:
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
):
|
||||
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
|
||||
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
|
||||
finally:
|
||||
session_file.chmod(0o644)
|
||||
|
||||
|
||||
@@ -423,33 +423,20 @@ async def subscribe_to_session(
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||
)
|
||||
|
||||
# RACE CONDITION FIX: If session not found, retry with backoff.
|
||||
# Duplicate requests skip create_session and subscribe immediately; the
|
||||
# original request's create_session (a Redis hset) may not have completed
|
||||
# yet. 3 × 100ms gives a 300ms window which covers DB-write latency on the
|
||||
# original request before the hset even starts.
|
||||
# RACE CONDITION FIX: If session not found, retry once after small delay
|
||||
# This handles the case where subscribe_to_session is called immediately
|
||||
# after create_session but before Redis propagates the write
|
||||
if not meta:
|
||||
_max_retries = 3
|
||||
_retry_delay = 0.1 # 100ms per attempt
|
||||
for attempt in range(_max_retries):
|
||||
logger.warning(
|
||||
f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), "
|
||||
f"retrying after {int(_retry_delay * 1000)}ms",
|
||||
extra={"json_fields": {**log_meta, "attempt": attempt + 1}},
|
||||
)
|
||||
await asyncio.sleep(_retry_delay)
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
if meta:
|
||||
logger.info(
|
||||
f"[TIMING] Session found after {attempt + 1} retries",
|
||||
extra={"json_fields": {**log_meta, "attempts": attempt + 1}},
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
await asyncio.sleep(0.05) # 50ms
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
if not meta:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Session still not found in Redis after {_max_retries} retries "
|
||||
f"({elapsed:.1f}ms total)",
|
||||
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -459,6 +446,10 @@ async def subscribe_to_session(
|
||||
},
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"[TIMING] Session found after retry",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
session_status = meta.get("status", "")
|
||||
|
||||
@@ -143,8 +143,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.GROK_4: 9,
|
||||
LlmModel.GROK_4_FAST: 1,
|
||||
LlmModel.GROK_4_1_FAST: 1,
|
||||
LlmModel.GROK_4_20: 5,
|
||||
LlmModel.GROK_4_20_MULTI_AGENT: 5,
|
||||
LlmModel.GROK_CODE_FAST_1: 1,
|
||||
LlmModel.KIMI_K2: 1,
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: 1,
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import stripe
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from prisma.enums import (
|
||||
CreditRefundRequestStatus,
|
||||
CreditTransactionType,
|
||||
@@ -34,7 +31,6 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.util.cache import cached
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled
|
||||
from backend.util.json import SafeJson, dumps
|
||||
@@ -436,7 +432,7 @@ class UserCreditBase(ABC):
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
if current_balance >= ceiling_balance:
|
||||
raise ValueError(
|
||||
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
|
||||
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
|
||||
)
|
||||
|
||||
# Single unified atomic operation for all transaction types using UserBalance
|
||||
@@ -575,7 +571,7 @@ class UserCreditBase(ABC):
|
||||
if amount < 0 and fail_insufficient_credits:
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
raise InsufficientBalanceError(
|
||||
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
|
||||
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
|
||||
user_id=user_id,
|
||||
balance=current_balance,
|
||||
amount=amount,
|
||||
@@ -586,6 +582,7 @@ class UserCreditBase(ABC):
|
||||
|
||||
|
||||
class UserCredit(UserCreditBase):
|
||||
|
||||
async def _send_refund_notification(
|
||||
self,
|
||||
notification_request: RefundRequestData,
|
||||
@@ -737,7 +734,7 @@ class UserCredit(UserCreditBase):
|
||||
)
|
||||
if request.amount <= 0 or request.amount > transaction.amount:
|
||||
raise AssertionError(
|
||||
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
|
||||
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
|
||||
)
|
||||
|
||||
balance, _ = await self._add_transaction(
|
||||
@@ -791,12 +788,12 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
# If the user has enough balance, just let them win the dispute.
|
||||
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
|
||||
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
|
||||
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
|
||||
dispute.close()
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
|
||||
f"Adding extra info for dispute from {user_id} for ${amount/100}"
|
||||
)
|
||||
# Retrieve recent transaction history to support our evidence.
|
||||
# This provides a concise timeline that shows service usage and proper credit application.
|
||||
@@ -1240,23 +1237,14 @@ async def get_stripe_customer_id(user_id: str) -> str:
|
||||
if user.stripe_customer_id:
|
||||
return user.stripe_customer_id
|
||||
|
||||
# Race protection: two concurrent calls (e.g. user double-clicks "Upgrade",
|
||||
# or any retried request) would each pass the check above and create their
|
||||
# own Stripe Customer, leaving an orphaned billable customer in Stripe.
|
||||
# Pass an idempotency_key so Stripe collapses concurrent + retried calls
|
||||
# into the same Customer object server-side. The 24h Stripe idempotency
|
||||
# window comfortably covers any realistic in-flight retry scenario.
|
||||
customer = await run_in_threadpool(
|
||||
stripe.Customer.create,
|
||||
customer = stripe.Customer.create(
|
||||
name=user.name or "",
|
||||
email=user.email,
|
||||
metadata={"user_id": user_id},
|
||||
idempotency_key=f"customer-create-{user_id}",
|
||||
)
|
||||
await User.prisma().update(
|
||||
where={"id": user_id}, data={"stripeCustomerId": customer.id}
|
||||
)
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
return customer.id
|
||||
|
||||
|
||||
@@ -1275,203 +1263,23 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
|
||||
data={"subscriptionTier": tier},
|
||||
)
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
# Also invalidate the rate-limit tier cache so CoPilot picks up the new
|
||||
# tier immediately rather than waiting up to 5 minutes for the TTL to expire.
|
||||
from backend.copilot.rate_limit import get_user_tier # local import avoids circular
|
||||
|
||||
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def _cancel_customer_subscriptions(
|
||||
customer_id: str,
|
||||
exclude_sub_id: str | None = None,
|
||||
at_period_end: bool = False,
|
||||
) -> int:
|
||||
"""Cancel all billable Stripe subscriptions for a customer, optionally excluding one.
|
||||
|
||||
Cancels both ``active`` and ``trialing`` subscriptions, since trialing subs will
|
||||
start billing once the trial ends and must be cleaned up on downgrade/upgrade to
|
||||
avoid double-charging or charging users who intended to cancel.
|
||||
|
||||
When ``at_period_end=True``, schedules cancellation at the end of the current
|
||||
billing period instead of cancelling immediately — the user keeps their tier
|
||||
until the period ends, then ``customer.subscription.deleted`` fires and the
|
||||
webhook downgrades them to FREE.
|
||||
|
||||
Wraps every synchronous Stripe SDK call with run_in_threadpool so the async event
|
||||
loop is never blocked. Raises stripe.StripeError on list/cancel failure so callers
|
||||
that need strict consistency can react; cleanup callers can catch and log instead.
|
||||
|
||||
Returns the number of subscriptions cancelled/scheduled for cancellation.
|
||||
"""
|
||||
# Query active and trialing separately; Stripe's list API accepts a single status
|
||||
# filter at a time (no OR), and we explicitly want to skip canceled/incomplete/
|
||||
# past_due subs rather than filter them out client-side via status="all".
|
||||
seen_ids: set[str] = set()
|
||||
for status in ("active", "trialing"):
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list, customer=customer_id, status=status, limit=10
|
||||
)
|
||||
# Iterate only the first page (up to 10); avoid auto_paging_iter which would
|
||||
# trigger additional sync HTTP calls inside the event loop.
|
||||
if subscriptions.has_more:
|
||||
logger.error(
|
||||
"_cancel_customer_subscriptions: customer %s has more than 10 %s"
|
||||
" subscriptions — only the first page was processed; remaining"
|
||||
" subscriptions were NOT cancelled",
|
||||
customer_id,
|
||||
status,
|
||||
async def cancel_stripe_subscription(user_id: str) -> None:
|
||||
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
|
||||
customer_id = await get_stripe_customer_id(user_id)
|
||||
subscriptions = stripe.Subscription.list(
|
||||
customer=customer_id, status="active", limit=10
|
||||
)
|
||||
for sub in subscriptions.auto_paging_iter():
|
||||
try:
|
||||
stripe.Subscription.cancel(sub["id"])
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"cancel_stripe_subscription: failed to cancel sub %s for user %s",
|
||||
sub["id"],
|
||||
user_id,
|
||||
)
|
||||
for sub in subscriptions.data:
|
||||
sub_id = sub["id"]
|
||||
if exclude_sub_id and sub_id == exclude_sub_id:
|
||||
continue
|
||||
if sub_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(sub_id)
|
||||
if at_period_end:
|
||||
await run_in_threadpool(
|
||||
stripe.Subscription.modify, sub_id, cancel_at_period_end=True
|
||||
)
|
||||
else:
|
||||
await run_in_threadpool(stripe.Subscription.cancel, sub_id)
|
||||
return len(seen_ids)
|
||||
|
||||
|
||||
async def cancel_stripe_subscription(user_id: str) -> bool:
|
||||
"""Schedule cancellation of all active/trialing Stripe subscriptions at period end.
|
||||
|
||||
The subscription stays active until the end of the billing period so the user
|
||||
keeps their tier for the time they already paid for. The ``customer.subscription.deleted``
|
||||
webhook fires at period end and downgrades the DB tier to FREE.
|
||||
|
||||
Returns True if at least one subscription was found and scheduled for cancellation,
|
||||
False if the customer had no active/trialing subscriptions (e.g., admin-granted tier
|
||||
with no associated Stripe subscription). When False, the caller should update the
|
||||
DB tier directly since no webhook will fire to do it.
|
||||
|
||||
Raises stripe.StripeError if any modification fails, so the caller can avoid
|
||||
updating the DB tier when Stripe is inconsistent.
|
||||
"""
|
||||
# Guard: only proceed if the user already has a Stripe customer ID. Calling
|
||||
# get_stripe_customer_id for a user who has never had a paid subscription would
|
||||
# create an orphaned, potentially-billable Stripe Customer object — we avoid that
|
||||
# by returning False early so the caller can downgrade the DB tier directly.
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user.stripe_customer_id:
|
||||
return False
|
||||
|
||||
customer_id = user.stripe_customer_id
|
||||
try:
|
||||
cancelled_count = await _cancel_customer_subscriptions(
|
||||
customer_id, at_period_end=True
|
||||
)
|
||||
return cancelled_count > 0
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"cancel_stripe_subscription: Stripe error while cancelling subs for user %s",
|
||||
user_id,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int:
|
||||
"""Return the prorated credit (in cents) the user would receive if they upgraded now.
|
||||
|
||||
Fetches the user's active Stripe subscription to determine how many seconds
|
||||
remain in the current billing period, then calculates the unused portion of
|
||||
the monthly cost. Returns 0 for FREE/ENTERPRISE users or when no active sub
|
||||
is found.
|
||||
"""
|
||||
if monthly_cost_cents <= 0:
|
||||
return 0
|
||||
# Guard: only query Stripe if the user already has a customer ID. Admin-granted
|
||||
# paid tiers have no Stripe record; calling get_stripe_customer_id would create an
|
||||
# orphaned customer on every billing-page load for those users.
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user.stripe_customer_id:
|
||||
return 0
|
||||
try:
|
||||
customer_id = user.stripe_customer_id
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list, customer=customer_id, status="active", limit=1
|
||||
)
|
||||
if not subscriptions.data:
|
||||
return 0
|
||||
sub = subscriptions.data[0]
|
||||
period_start: int = sub["current_period_start"]
|
||||
period_end: int = sub["current_period_end"]
|
||||
now = int(time.time())
|
||||
total_seconds = period_end - period_start
|
||||
remaining_seconds = max(period_end - now, 0)
|
||||
if total_seconds <= 0:
|
||||
return 0
|
||||
return int(monthly_cost_cents * remaining_seconds / total_seconds)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"get_proration_credit_cents: failed to compute proration for user %s",
|
||||
user_id,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
async def modify_stripe_subscription_for_tier(
|
||||
user_id: str, tier: SubscriptionTier
|
||||
) -> bool:
|
||||
"""Modify an existing Stripe subscription to a new paid tier using proration.
|
||||
|
||||
For paid→paid tier changes (e.g. PRO↔BUSINESS), modifying the existing
|
||||
subscription is preferable to cancelling + creating a new one via Checkout:
|
||||
Stripe handles proration automatically, crediting unused time on the old plan
|
||||
and charging the pro-rated amount for the new plan in the same billing cycle.
|
||||
|
||||
Returns:
|
||||
True — a subscription was found and modified successfully.
|
||||
False — no active/trialing subscription exists (e.g. admin-granted tier or
|
||||
first-time paid signup); caller should fall back to Checkout.
|
||||
|
||||
Raises stripe.StripeError on API failures so callers can propagate a 502.
|
||||
Raises ValueError when no Stripe price ID is configured for the tier.
|
||||
"""
|
||||
price_id = await get_subscription_price_id(tier)
|
||||
if not price_id:
|
||||
raise ValueError(f"No Stripe price ID configured for tier {tier}")
|
||||
|
||||
# Guard: only proceed if the user already has a Stripe customer ID. Calling
|
||||
# get_stripe_customer_id for a user with no Stripe record (e.g. admin-granted tier)
|
||||
# would create an orphaned customer object if the subsequent Subscription.list call
|
||||
# fails. Return False early so the API layer falls back to Checkout instead.
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user.stripe_customer_id:
|
||||
return False
|
||||
|
||||
customer_id = user.stripe_customer_id
|
||||
for status in ("active", "trialing"):
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list, customer=customer_id, status=status, limit=1
|
||||
)
|
||||
if not subscriptions.data:
|
||||
continue
|
||||
sub = subscriptions.data[0]
|
||||
sub_id = sub["id"]
|
||||
items = sub.get("items", {}).get("data", [])
|
||||
if not items:
|
||||
continue
|
||||
item_id = items[0]["id"]
|
||||
await run_in_threadpool(
|
||||
stripe.Subscription.modify,
|
||||
sub_id,
|
||||
items=[{"id": item_id, "price": price_id}],
|
||||
proration_behavior="create_prorations",
|
||||
)
|
||||
logger.info(
|
||||
"modify_stripe_subscription_for_tier: modified sub %s for user %s → %s",
|
||||
sub_id,
|
||||
user_id,
|
||||
tier,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
@@ -1483,19 +1291,8 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
return AutoTopUpConfig.model_validate(user.top_up_config)
|
||||
|
||||
|
||||
@cached(ttl_seconds=60, maxsize=8, cache_none=False)
|
||||
async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
|
||||
"""Return Stripe Price ID for a tier from LaunchDarkly, cached for 60 seconds.
|
||||
|
||||
Price IDs are LaunchDarkly flag values that change only at deploy time.
|
||||
Caching for 60 seconds avoids hitting the LD SDK on every webhook delivery
|
||||
and every GET /credits/subscription page load (called 2x per request).
|
||||
|
||||
``cache_none=False`` prevents a transient LD failure from caching ``None``
|
||||
and blocking subscription upgrades for the full 60-second TTL window.
|
||||
A tier with no configured flag (FREE, ENTERPRISE) returns ``None`` from an
|
||||
O(1) dict lookup before hitting LD, so the extra LD call is never made.
|
||||
"""
|
||||
"""Return Stripe Price ID for a tier from LaunchDarkly. None = not configured."""
|
||||
flag_map = {
|
||||
SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO,
|
||||
SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS,
|
||||
@@ -1503,7 +1300,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
|
||||
flag = flag_map.get(tier)
|
||||
if flag is None:
|
||||
return None
|
||||
price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
|
||||
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
|
||||
return price_id if isinstance(price_id, str) and price_id else None
|
||||
|
||||
|
||||
@@ -1518,8 +1315,7 @@ async def create_subscription_checkout(
|
||||
if not price_id:
|
||||
raise ValueError(f"Subscription not available for tier {tier.value}")
|
||||
customer_id = await get_stripe_customer_id(user_id)
|
||||
session = await run_in_threadpool(
|
||||
stripe.checkout.Session.create,
|
||||
session = stripe.checkout.Session.create(
|
||||
customer=customer_id,
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
@@ -1527,111 +1323,26 @@ async def create_subscription_checkout(
|
||||
cancel_url=cancel_url,
|
||||
subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
|
||||
)
|
||||
if not session.url:
|
||||
# An empty checkout URL for a paid upgrade is always an error; surfacing it
|
||||
# as ValueError means the API handler returns 422 instead of silently
|
||||
# redirecting the client to an empty URL.
|
||||
raise ValueError("Stripe did not return a checkout session URL")
|
||||
return session.url
|
||||
|
||||
|
||||
async def _cleanup_stale_subscriptions(customer_id: str, new_sub_id: str) -> None:
|
||||
"""Best-effort cancel of any active subs for the customer other than new_sub_id.
|
||||
|
||||
Called from the webhook handler after a new subscription becomes active. Failures
|
||||
are logged but not raised so a transient Stripe error doesn't crash the webhook —
|
||||
a periodic reconciliation job is the intended backstop for persistent drift.
|
||||
|
||||
NOTE: until that reconcile job lands, a failure here means the user is silently
|
||||
billed for two simultaneous subscriptions. The error log below is intentionally
|
||||
`logger.exception` so it surfaces in Sentry with the customer/sub IDs needed to
|
||||
manually reconcile, and the metric `stripe_stale_subscription_cleanup_failed`
|
||||
is bumped so on-call can alert on persistent drift.
|
||||
TODO(#stripe-reconcile-job): replace this best-effort cleanup with a periodic
|
||||
reconciliation job that queries Stripe for customers with >1 active sub.
|
||||
"""
|
||||
try:
|
||||
await _cancel_customer_subscriptions(customer_id, exclude_sub_id=new_sub_id)
|
||||
except stripe.StripeError:
|
||||
# Use exception() (not warning) so this surfaces as an error in Sentry —
|
||||
# any failure here means a paid-to-paid upgrade may have left the user
|
||||
# with two simultaneous active subscriptions.
|
||||
logger.exception(
|
||||
"stripe_stale_subscription_cleanup_failed: customer=%s new_sub=%s —"
|
||||
" user may be billed for two simultaneous subscriptions; manual"
|
||||
" reconciliation required",
|
||||
customer_id,
|
||||
new_sub_id,
|
||||
)
|
||||
return session.url or ""
|
||||
|
||||
|
||||
async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
|
||||
"""Update User.subscriptionTier from a Stripe subscription object.
|
||||
|
||||
Expected shape of stripe_subscription (subset of Stripe's Subscription object):
|
||||
customer: str — Stripe customer ID
|
||||
status: str — "active" | "trialing" | "canceled" | ...
|
||||
id: str — Stripe subscription ID
|
||||
items.data[].price.id: str — Stripe price ID identifying the tier
|
||||
"""
|
||||
customer_id = stripe_subscription.get("customer")
|
||||
if not customer_id:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: missing 'customer' field in event, "
|
||||
"skipping (keys: %s)",
|
||||
list(stripe_subscription.keys()),
|
||||
)
|
||||
return
|
||||
"""Update User.subscriptionTier from a Stripe subscription object."""
|
||||
customer_id = stripe_subscription["customer"]
|
||||
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
|
||||
if not user:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: no user for customer %s", customer_id
|
||||
)
|
||||
return
|
||||
# Cross-check: if the subscription carries a metadata.user_id (set during
|
||||
# Checkout Session creation), verify it matches the user we found via
|
||||
# stripeCustomerId. A mismatch indicates a customer↔user mapping
|
||||
# inconsistency — updating the wrong user's tier would be a data-corruption
|
||||
# bug, so we log loudly and bail out. Absence of metadata.user_id (e.g.
|
||||
# subscriptions created outside the Checkout flow) is not an error — we
|
||||
# simply skip the check and proceed with the customer-ID-based lookup.
|
||||
metadata = stripe_subscription.get("metadata") or {}
|
||||
metadata_user_id = metadata.get("user_id") if isinstance(metadata, dict) else None
|
||||
if metadata_user_id and metadata_user_id != user.id:
|
||||
logger.error(
|
||||
"sync_subscription_from_stripe: metadata.user_id=%s does not match"
|
||||
" user.id=%s found via stripeCustomerId=%s — refusing to update tier"
|
||||
" to avoid corrupting the wrong user's subscription state",
|
||||
metadata_user_id,
|
||||
user.id,
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
# ENTERPRISE tiers are admin-managed. Never let a Stripe webhook flip an
|
||||
# ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has
|
||||
# a self-service Stripe sub, it's a data-consistency issue for an operator,
|
||||
# not something the webhook should automatically "fix".
|
||||
current_tier = user.subscriptionTier or SubscriptionTier.FREE
|
||||
if current_tier == SubscriptionTier.ENTERPRISE:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier"
|
||||
" for user %s (customer %s); event status=%s",
|
||||
user.id,
|
||||
customer_id,
|
||||
stripe_subscription.get("status", ""),
|
||||
)
|
||||
return
|
||||
status = stripe_subscription.get("status", "")
|
||||
new_sub_id = stripe_subscription.get("id", "")
|
||||
if status in ("active", "trialing"):
|
||||
price_id = ""
|
||||
items = stripe_subscription.get("items", {}).get("data", [])
|
||||
if items:
|
||||
price_id = items[0].get("price", {}).get("id", "")
|
||||
pro_price, biz_price = await asyncio.gather(
|
||||
get_subscription_price_id(SubscriptionTier.PRO),
|
||||
get_subscription_price_id(SubscriptionTier.BUSINESS),
|
||||
)
|
||||
pro_price = await get_subscription_price_id(SubscriptionTier.PRO)
|
||||
biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS)
|
||||
if price_id and pro_price and price_id == pro_price:
|
||||
tier = SubscriptionTier.PRO
|
||||
elif price_id and biz_price and price_id == biz_price:
|
||||
@@ -1648,206 +1359,10 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
|
||||
)
|
||||
return
|
||||
else:
|
||||
# A subscription was cancelled or ended. DO NOT unconditionally downgrade
|
||||
# to FREE — Stripe does not guarantee webhook delivery order, so a
|
||||
# `customer.subscription.deleted` for the OLD sub can arrive after we've
|
||||
# already processed `customer.subscription.created` for a new paid sub.
|
||||
# Ask Stripe whether any OTHER active/trialing subs exist for this
|
||||
# customer; if they do, keep the user's current tier (the other sub's
|
||||
# own event will/has already set the correct tier).
|
||||
try:
|
||||
other_subs_active, other_subs_trialing = await asyncio.gather(
|
||||
run_in_threadpool(
|
||||
stripe.Subscription.list,
|
||||
customer=customer_id,
|
||||
status="active",
|
||||
limit=10,
|
||||
),
|
||||
run_in_threadpool(
|
||||
stripe.Subscription.list,
|
||||
customer=customer_id,
|
||||
status="trialing",
|
||||
limit=10,
|
||||
),
|
||||
)
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: could not verify other active"
|
||||
" subs for customer %s on cancel event %s; preserving current"
|
||||
" tier to avoid an unsafe downgrade",
|
||||
customer_id,
|
||||
new_sub_id,
|
||||
)
|
||||
return
|
||||
# Filter out the cancelled subscription to check if other active subs
|
||||
# exist. When new_sub_id is empty (malformed event with no 'id' field),
|
||||
# we cannot safely exclude any sub — preserve current tier to avoid
|
||||
# an unsafe downgrade on a malformed webhook payload.
|
||||
if not new_sub_id:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: cancel event missing 'id' field"
|
||||
" for customer %s; preserving current tier",
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
other_active_ids = {sub["id"] for sub in other_subs_active.data} - {new_sub_id}
|
||||
other_trialing_ids = {sub["id"] for sub in other_subs_trialing.data} - {
|
||||
new_sub_id
|
||||
}
|
||||
still_has_active_sub = bool(other_active_ids or other_trialing_ids)
|
||||
if still_has_active_sub:
|
||||
logger.info(
|
||||
"sync_subscription_from_stripe: sub %s cancelled but customer %s"
|
||||
" still has another active sub; keeping tier %s",
|
||||
new_sub_id,
|
||||
customer_id,
|
||||
current_tier.value,
|
||||
)
|
||||
return
|
||||
tier = SubscriptionTier.FREE
|
||||
# Idempotency: Stripe retries webhooks on delivery failure, and several event
|
||||
# types map to the same final tier. Skip the DB write + cache invalidation
|
||||
# when the tier is already correct to avoid redundant writes on replay.
|
||||
if current_tier == tier:
|
||||
return
|
||||
# When a new subscription becomes active (e.g. paid-to-paid tier upgrade
|
||||
# via a fresh Checkout Session), cancel any OTHER active subscriptions for
|
||||
# the same customer so the user isn't billed twice. We do this in the
|
||||
# webhook rather than the API handler so that abandoning the checkout
|
||||
# doesn't leave the user without a subscription.
|
||||
# IMPORTANT: this runs AFTER the idempotency check above so that webhook
|
||||
# replays for an already-applied event do NOT trigger another cleanup round
|
||||
# (which could otherwise cancel a legitimately new subscription the user
|
||||
# signed up for between the original event and its replay).
|
||||
if status in ("active", "trialing") and new_sub_id:
|
||||
# NOTE: paid-to-paid upgrade race (e.g. PRO → BUSINESS):
|
||||
# _cleanup_stale_subscriptions cancels the old PRO sub before
|
||||
# set_subscription_tier writes BUSINESS to the DB. If Stripe delivers
|
||||
# the PRO `customer.subscription.deleted` event concurrently and it
|
||||
# processes after the PRO cancel but before set_subscription_tier
|
||||
# commits, the user could momentarily appear as FREE in the DB.
|
||||
# This window is very short in practice (two sequential awaits),
|
||||
# but is a known limitation of the current webhook-driven approach.
|
||||
# A future improvement would be to write the new tier first, then
|
||||
# cancel the old sub.
|
||||
await _cleanup_stale_subscriptions(customer_id, new_sub_id)
|
||||
await set_subscription_tier(user.id, tier)
|
||||
|
||||
|
||||
async def handle_subscription_payment_failure(invoice: dict) -> None:
|
||||
"""Handle a failed Stripe subscription payment.
|
||||
|
||||
Tries to cover the invoice amount from the user's credit balance.
|
||||
|
||||
- Balance sufficient → deduct from balance, then pay the Stripe invoice so
|
||||
Stripe stops retrying it. The sub stays intact and the user keeps their tier.
|
||||
- Balance insufficient → cancel Stripe sub immediately, downgrade to FREE.
|
||||
Cancelling here avoids further Stripe retries on an invoice we cannot cover.
|
||||
"""
|
||||
customer_id = invoice.get("customer")
|
||||
if not customer_id:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: missing customer in invoice; skipping"
|
||||
)
|
||||
return
|
||||
|
||||
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
|
||||
if not user:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: no user found for customer %s",
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
|
||||
current_tier = user.subscriptionTier or SubscriptionTier.FREE
|
||||
if current_tier == SubscriptionTier.ENTERPRISE:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: skipping ENTERPRISE user %s"
|
||||
" (customer %s) — tier is admin-managed",
|
||||
user.id,
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
|
||||
amount_due: int = invoice.get("amount_due", 0)
|
||||
sub_id: str = invoice.get("subscription", "")
|
||||
invoice_id: str = invoice.get("id", "")
|
||||
|
||||
if amount_due <= 0:
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: amount_due=%d for user %s;"
|
||||
" nothing to deduct",
|
||||
amount_due,
|
||||
user.id,
|
||||
)
|
||||
return
|
||||
|
||||
credit_model = UserCredit()
|
||||
try:
|
||||
await credit_model._add_transaction(
|
||||
user_id=user.id,
|
||||
amount=-amount_due,
|
||||
transaction_type=CreditTransactionType.SUBSCRIPTION,
|
||||
fail_insufficient_credits=True,
|
||||
# Use invoice_id as the idempotency key so that Stripe webhook retries
|
||||
# (e.g. on a transient stripe.Invoice.pay failure) do not double-charge.
|
||||
transaction_key=invoice_id or None,
|
||||
metadata=SafeJson(
|
||||
{
|
||||
"stripe_customer_id": customer_id,
|
||||
"stripe_subscription_id": sub_id,
|
||||
"reason": "subscription_payment_failure_covered_by_balance",
|
||||
}
|
||||
),
|
||||
)
|
||||
# Balance covered the invoice. Pay the Stripe invoice so Stripe's dunning
|
||||
# system stops retrying it — without this call Stripe would retry automatically
|
||||
# and re-trigger this webhook, causing double-deductions each retry cycle.
|
||||
if invoice_id:
|
||||
try:
|
||||
await run_in_threadpool(stripe.Invoice.pay, invoice_id)
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: balance deducted for user"
|
||||
" %s but failed to mark invoice %s as paid; Stripe may retry",
|
||||
user.id,
|
||||
invoice_id,
|
||||
)
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: deducted %d cents from balance"
|
||||
" for user %s; Stripe invoice %s paid, sub %s intact, tier preserved",
|
||||
amount_due,
|
||||
user.id,
|
||||
invoice_id,
|
||||
sub_id,
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
# Balance insufficient — cancel Stripe subscription first, then downgrade DB.
|
||||
# Order matters: if we downgrade the DB first and the Stripe cancel fails, the
|
||||
# user is permanently stuck on FREE while Stripe continues billing them.
|
||||
# Cancelling Stripe first is safe: if the DB write then fails, the webhook
|
||||
# customer.subscription.deleted will fire and correct the tier eventually.
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: insufficient balance for user %s;"
|
||||
" cancelling Stripe sub %s then downgrading to FREE",
|
||||
user.id,
|
||||
sub_id,
|
||||
)
|
||||
try:
|
||||
await _cancel_customer_subscriptions(customer_id)
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: failed to cancel Stripe sub %s"
|
||||
" for user %s (customer %s); skipping tier downgrade to avoid"
|
||||
" inconsistency — Stripe may continue retrying the invoice",
|
||||
sub_id,
|
||||
user.id,
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
await set_subscription_tier(user.id, SubscriptionTier.FREE)
|
||||
|
||||
|
||||
async def admin_get_user_history(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -73,31 +73,6 @@ def _get_redis() -> Redis:
|
||||
return r
|
||||
|
||||
|
||||
class _MissingType:
|
||||
"""Singleton sentinel type — distinct from ``None`` (a valid cached value).
|
||||
|
||||
Using a dedicated class (instead of ``Any = object()``) lets mypy prove
|
||||
that comparisons ``result is _MISSING`` narrow the type correctly and
|
||||
prevents accidental use of the sentinel where a real value is expected.
|
||||
"""
|
||||
|
||||
_instance: "_MissingType | None" = None
|
||||
|
||||
def __new__(cls) -> "_MissingType":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<MISSING>"
|
||||
|
||||
|
||||
# Sentinel returned by ``_get_from_memory`` / ``_get_from_redis`` to mean
|
||||
# "no entry exists" — distinct from a cached ``None`` value, which is a
|
||||
# valid result for callers that opt into caching it.
|
||||
_MISSING = _MissingType()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedValue:
|
||||
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
|
||||
@@ -185,7 +160,6 @@ def cached(
|
||||
ttl_seconds: int,
|
||||
shared_cache: bool = False,
|
||||
refresh_ttl_on_get: bool = False,
|
||||
cache_none: bool = True,
|
||||
) -> Callable[[Callable[P, R]], CachedFunction[P, R]]:
|
||||
"""
|
||||
Thundering herd safe cache decorator for both sync and async functions.
|
||||
@@ -198,10 +172,6 @@ def cached(
|
||||
ttl_seconds: Time to live in seconds. Required - entries must expire.
|
||||
shared_cache: If True, use Redis for cross-process caching
|
||||
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
|
||||
cache_none: If True (default) ``None`` is cached like any other value.
|
||||
Set to ``False`` for functions that return ``None`` to signal a
|
||||
transient error and should be re-tried on the next call without
|
||||
poisoning the cache (e.g. external API calls that may fail).
|
||||
|
||||
Returns:
|
||||
Decorated function with caching capabilities
|
||||
@@ -214,12 +184,6 @@ def cached(
|
||||
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cached(ttl_seconds=300, cache_none=False)
|
||||
async def fetch_external(id: str) -> dict | None:
|
||||
# Returns None on transient error — won't be stored,
|
||||
# next call retries instead of returning the stale None.
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
|
||||
@@ -227,14 +191,9 @@ def cached(
|
||||
cache_storage: dict[tuple, CachedValue] = {}
|
||||
_event_loop_locks: dict[Any, asyncio.Lock] = {}
|
||||
|
||||
def _get_from_redis(redis_key: str) -> Any:
|
||||
def _get_from_redis(redis_key: str) -> Any | None:
|
||||
"""Get value from Redis, optionally refreshing TTL.
|
||||
|
||||
Returns the cached value (which may be ``None``) on a hit, or the
|
||||
module-level ``_MISSING`` sentinel on a miss / corrupt entry.
|
||||
Callers must compare with ``is _MISSING`` so cached ``None`` values
|
||||
are not mistaken for misses.
|
||||
|
||||
Values are expected to carry an HMAC-SHA256 prefix for integrity
|
||||
verification. Unsigned (legacy) or tampered entries are silently
|
||||
discarded and treated as cache misses, so the caller recomputes and
|
||||
@@ -254,11 +213,11 @@ def cached(
|
||||
f"for {func_name}, discarding entry: "
|
||||
"possible tampering or legacy unsigned value"
|
||||
)
|
||||
return _MISSING
|
||||
return None
|
||||
return pickle.loads(payload)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis error during cache check for {func_name}: {e}")
|
||||
return _MISSING
|
||||
return None
|
||||
|
||||
def _set_to_redis(redis_key: str, value: Any) -> None:
|
||||
"""Set HMAC-signed pickled value in Redis with TTL."""
|
||||
@@ -268,13 +227,8 @@ def cached(
|
||||
except Exception as e:
|
||||
logger.error(f"Redis error storing cache for {func_name}: {e}")
|
||||
|
||||
def _get_from_memory(key: tuple) -> Any:
|
||||
"""Get value from in-memory cache, checking TTL.
|
||||
|
||||
Returns the cached value (which may be ``None``) on a hit, or the
|
||||
``_MISSING`` sentinel on a miss / TTL expiry. See
|
||||
``_get_from_redis`` for the rationale.
|
||||
"""
|
||||
def _get_from_memory(key: tuple) -> Any | None:
|
||||
"""Get value from in-memory cache, checking TTL."""
|
||||
if key in cache_storage:
|
||||
cached_data = cache_storage[key]
|
||||
if time.time() - cached_data.timestamp < ttl_seconds:
|
||||
@@ -282,7 +236,7 @@ def cached(
|
||||
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
|
||||
)
|
||||
return cached_data.result
|
||||
return _MISSING
|
||||
return None
|
||||
|
||||
def _set_to_memory(key: tuple, value: Any) -> None:
|
||||
"""Set value in in-memory cache with timestamp."""
|
||||
@@ -316,11 +270,11 @@ def cached(
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not _MISSING:
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not _MISSING:
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
@@ -328,24 +282,22 @@ def cached(
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not _MISSING:
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not _MISSING:
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store result (skip ``None`` if the caller opted out of
|
||||
# caching it — used for transient-error sentinels).
|
||||
if cache_none or result is not None:
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
@@ -363,11 +315,11 @@ def cached(
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not _MISSING:
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not _MISSING:
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
@@ -375,24 +327,22 @@ def cached(
|
||||
# Double-check: another thread might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not _MISSING:
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not _MISSING:
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result (skip ``None`` if the caller opted out of
|
||||
# caching it — used for transient-error sentinels).
|
||||
if cache_none or result is not None:
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -1223,123 +1223,3 @@ class TestCacheHMAC:
|
||||
assert call_count == 2
|
||||
|
||||
legacy_test_fn.cache_clear()
|
||||
|
||||
|
||||
class TestCacheNoneHandling:
|
||||
"""Tests for the ``cache_none`` parameter on the @cached decorator.
|
||||
|
||||
Sentry bug PRRT_kwDOJKSTjM56RTEu (HIGH): the cache previously could not
|
||||
distinguish "no entry" from "entry is None", so any function returning
|
||||
``None`` was effectively re-executed on every call. The fix is a
|
||||
sentinel-based check inside the wrappers, plus an opt-out
|
||||
``cache_none=False`` flag for callers that *want* errors to retry.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_none_is_cached_by_default(self):
|
||||
"""With ``cache_none=True`` (default), cached ``None`` is returned
|
||||
from the cache instead of triggering re-execution."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=300)
|
||||
async def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return None
|
||||
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# Second call should hit the cache, not re-execute.
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# Different argument is a different cache key — re-executes.
|
||||
assert await maybe_none(2) is None
|
||||
assert call_count == 2
|
||||
|
||||
def test_sync_none_is_cached_by_default(self):
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=300)
|
||||
def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return None
|
||||
|
||||
assert maybe_none(1) is None
|
||||
assert maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_none_false_skips_storing_none(self):
|
||||
"""``cache_none=False`` skips storing ``None`` so transient errors
|
||||
are retried on the next call instead of poisoning the cache."""
|
||||
call_count = 0
|
||||
results: list[int | None] = [None, None, 42]
|
||||
|
||||
@cached(ttl_seconds=300, cache_none=False)
|
||||
async def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
result = results[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
# First call: returns None, NOT stored.
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same key: re-executes (None wasn't cached).
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 2
|
||||
|
||||
# Third call: returns 42, this time it IS stored.
|
||||
assert await maybe_none(1) == 42
|
||||
assert call_count == 3
|
||||
|
||||
# Fourth call: cache hit on the stored 42.
|
||||
assert await maybe_none(1) == 42
|
||||
assert call_count == 3
|
||||
|
||||
def test_sync_cache_none_false_skips_storing_none(self):
|
||||
call_count = 0
|
||||
results: list[int | None] = [None, 99]
|
||||
|
||||
@cached(ttl_seconds=300, cache_none=False)
|
||||
def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
result = results[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
assert maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# None was not stored — re-executes.
|
||||
assert maybe_none(1) == 99
|
||||
assert call_count == 2
|
||||
|
||||
# 99 IS stored — no re-execution.
|
||||
assert maybe_none(1) == 99
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_shared_cache_none_is_cached_by_default(self):
|
||||
"""Shared (Redis) cache also properly returns cached ``None`` values."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
async def maybe_none_redis(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return None
|
||||
|
||||
maybe_none_redis.cache_clear()
|
||||
|
||||
assert await maybe_none_redis(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
assert await maybe_none_redis(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
maybe_none_redis.cache_clear()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
@@ -102,12 +101,6 @@ async def _fetch_user_context_data(user_id: str) -> Context:
|
||||
"""
|
||||
builder = Context.builder(user_id).kind("user").anonymous(True)
|
||||
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
# Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context.
|
||||
return builder.build()
|
||||
|
||||
try:
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
|
||||
@@ -40,8 +40,6 @@
|
||||
"folder_id": null,
|
||||
"folder_name": null,
|
||||
"recommended_schedule_cron": null,
|
||||
"is_scheduled": false,
|
||||
"next_scheduled_run": null,
|
||||
"settings": {
|
||||
"human_in_the_loop_safe_mode": true,
|
||||
"sensitive_action_safe_mode": false
|
||||
@@ -88,8 +86,6 @@
|
||||
"folder_id": null,
|
||||
"folder_name": null,
|
||||
"recommended_schedule_cron": null,
|
||||
"is_scheduled": false,
|
||||
"next_scheduled_run": null,
|
||||
"settings": {
|
||||
"human_in_the_loop_safe_mode": true,
|
||||
"sensitive_action_safe_mode": false
|
||||
|
||||
@@ -155,7 +155,6 @@
|
||||
"@types/twemoji": "13.1.2",
|
||||
"@vitejs/plugin-react": "5.1.2",
|
||||
"@vitest/coverage-v8": "4.0.17",
|
||||
"agentation": "3.0.2",
|
||||
"axe-playwright": "2.2.2",
|
||||
"chromatic": "13.3.3",
|
||||
"concurrently": "9.2.1",
|
||||
|
||||
19
autogpt_platform/frontend/pnpm-lock.yaml
generated
19
autogpt_platform/frontend/pnpm-lock.yaml
generated
@@ -376,9 +376,6 @@ importers:
|
||||
'@vitest/coverage-v8':
|
||||
specifier: 4.0.17
|
||||
version: 4.0.17(vitest@4.0.17(@opentelemetry/api@1.9.0)(@types/node@24.10.0)(happy-dom@20.3.4)(jiti@2.6.1)(jsdom@27.4.0)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(terser@5.44.1)(yaml@2.8.2))
|
||||
agentation:
|
||||
specifier: 3.0.2
|
||||
version: 3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
axe-playwright:
|
||||
specifier: 2.2.2
|
||||
version: 2.2.2(playwright@1.56.1)
|
||||
@@ -4122,17 +4119,6 @@ packages:
|
||||
resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==}
|
||||
engines: {node: '>= 14'}
|
||||
|
||||
agentation@3.0.2:
|
||||
resolution: {integrity: sha512-iGzBxFVTuZEIKzLY6AExSLAQH6i6SwxV4pAu7v7m3X6bInZ7qlZXAwrEqyc4+EfP4gM7z2RXBF6SF4DeH0f2lA==}
|
||||
peerDependencies:
|
||||
react: '>=18.0.0'
|
||||
react-dom: '>=18.0.0'
|
||||
peerDependenciesMeta:
|
||||
react:
|
||||
optional: true
|
||||
react-dom:
|
||||
optional: true
|
||||
|
||||
ai@6.0.134:
|
||||
resolution: {integrity: sha512-YalNEaavld/kE444gOcsMKXdVVRGEe0SK77fAFcWYcqLg+a7xKnEet8bdfrEAJTfnMjj01rhgrIL10903w1a5Q==}
|
||||
engines: {node: '>=18'}
|
||||
@@ -13133,11 +13119,6 @@ snapshots:
|
||||
agent-base@7.1.4:
|
||||
optional: true
|
||||
|
||||
agentation@3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
optionalDependencies:
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
ai@6.0.134(zod@3.25.76):
|
||||
dependencies:
|
||||
'@ai-sdk/gateway': 3.0.77(zod@3.25.76)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { getNodeDisplayName, serializeGraphForChat } from "../helpers";
|
||||
import { serializeGraphForChat } from "../helpers";
|
||||
import type { CustomNode } from "../../FlowEditor/nodes/CustomNode/CustomNode";
|
||||
|
||||
describe("serializeGraphForChat – XML injection prevention", () => {
|
||||
@@ -53,53 +53,3 @@ describe("serializeGraphForChat – XML injection prevention", () => {
|
||||
expect(result).toContain("<injection>");
|
||||
});
|
||||
});
|
||||
|
||||
function makeNode(overrides: Partial<CustomNode["data"]> = {}): CustomNode {
|
||||
return {
|
||||
id: "node-1",
|
||||
data: {
|
||||
title: "AgentExecutorBlock",
|
||||
description: "",
|
||||
hardcodedValues: {},
|
||||
inputSchema: {},
|
||||
outputSchema: {},
|
||||
uiType: "agent",
|
||||
block_id: "b1",
|
||||
costs: [],
|
||||
categories: [],
|
||||
...overrides,
|
||||
},
|
||||
type: "custom" as const,
|
||||
position: { x: 0, y: 0 },
|
||||
} as unknown as CustomNode;
|
||||
}
|
||||
|
||||
describe("getNodeDisplayName", () => {
|
||||
it("returns fallback when node is undefined", () => {
|
||||
expect(getNodeDisplayName(undefined, "fallback-id")).toBe("fallback-id");
|
||||
});
|
||||
|
||||
it("returns customized_name when set", () => {
|
||||
const node = makeNode({
|
||||
metadata: { customized_name: "My Agent" } as any,
|
||||
});
|
||||
expect(getNodeDisplayName(node, "fallback")).toBe("My Agent");
|
||||
});
|
||||
|
||||
it("returns agent_name with version via getNodeDisplayTitle delegation", () => {
|
||||
const node = makeNode({
|
||||
hardcodedValues: { agent_name: "Researcher", graph_version: 3 },
|
||||
});
|
||||
expect(getNodeDisplayName(node, "fallback")).toBe("Researcher v3");
|
||||
});
|
||||
|
||||
it("returns block title when no custom or agent name", () => {
|
||||
const node = makeNode({ title: "SomeBlock" });
|
||||
expect(getNodeDisplayName(node, "fallback")).toBe("SomeBlock");
|
||||
});
|
||||
|
||||
it("returns fallback when title is empty", () => {
|
||||
const node = makeNode({ title: "" });
|
||||
expect(getNodeDisplayName(node, "fallback")).toBe("fallback");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { CustomNode } from "../FlowEditor/nodes/CustomNode/CustomNode";
|
||||
import type { CustomEdge } from "../FlowEditor/edges/CustomEdge";
|
||||
import { getNodeDisplayTitle } from "../FlowEditor/nodes/CustomNode/helpers";
|
||||
|
||||
/** Maximum nodes serialized into the AI context to prevent token overruns. */
|
||||
const MAX_NODES = 100;
|
||||
@@ -145,16 +144,18 @@ export function getActionKey(action: GraphAction): string {
|
||||
|
||||
/**
|
||||
* Resolves the display name for a node: prefers the user-customized name,
|
||||
* then agent name from hardcodedValues, then block title, then fallback ID.
|
||||
* Delegates to `getNodeDisplayTitle` for the 3-tier resolution logic.
|
||||
* falls back to the block title, then to the raw ID.
|
||||
* Shared between `serializeGraphForChat` and `ActionItem` to avoid duplication.
|
||||
*/
|
||||
export function getNodeDisplayName(
|
||||
node: CustomNode | undefined,
|
||||
fallback: string,
|
||||
): string {
|
||||
if (!node) return fallback;
|
||||
return getNodeDisplayTitle(node.data) || fallback;
|
||||
return (
|
||||
(node?.data.metadata?.customized_name as string | undefined) ||
|
||||
node?.data.title ||
|
||||
fallback
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { getNodeDisplayTitle, formatNodeDisplayTitle } from "../helpers";
|
||||
import { CustomNodeData } from "../CustomNode";
|
||||
|
||||
function makeNodeData(overrides: Partial<CustomNodeData> = {}): CustomNodeData {
|
||||
return {
|
||||
title: "AgentExecutorBlock",
|
||||
description: "",
|
||||
hardcodedValues: {},
|
||||
inputSchema: {},
|
||||
outputSchema: {},
|
||||
uiType: "agent",
|
||||
block_id: "block-1",
|
||||
costs: [],
|
||||
categories: [],
|
||||
...overrides,
|
||||
} as CustomNodeData;
|
||||
}
|
||||
|
||||
describe("getNodeDisplayTitle", () => {
|
||||
it("returns customized_name when set (tier 1)", () => {
|
||||
const data = makeNodeData({
|
||||
metadata: { customized_name: "My Custom Agent" } as any,
|
||||
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
|
||||
});
|
||||
expect(getNodeDisplayTitle(data)).toBe("My Custom Agent");
|
||||
});
|
||||
|
||||
it("returns agent_name with version when no customized_name (tier 2)", () => {
|
||||
const data = makeNodeData({
|
||||
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
|
||||
});
|
||||
expect(getNodeDisplayTitle(data)).toBe("Researcher v2");
|
||||
});
|
||||
|
||||
it("returns agent_name without version when graph_version is undefined (tier 2)", () => {
|
||||
const data = makeNodeData({
|
||||
hardcodedValues: { agent_name: "Researcher" },
|
||||
});
|
||||
expect(getNodeDisplayTitle(data)).toBe("Researcher");
|
||||
});
|
||||
|
||||
it("returns agent_name with version 0 (tier 2)", () => {
|
||||
const data = makeNodeData({
|
||||
hardcodedValues: { agent_name: "Researcher", graph_version: 0 },
|
||||
});
|
||||
expect(getNodeDisplayTitle(data)).toBe("Researcher v0");
|
||||
});
|
||||
|
||||
it("returns generic block title when no custom or agent name (tier 3)", () => {
|
||||
const data = makeNodeData({ title: "AgentExecutorBlock" });
|
||||
expect(getNodeDisplayTitle(data)).toBe("AgentExecutorBlock");
|
||||
});
|
||||
|
||||
it("prioritizes customized_name over agent_name", () => {
|
||||
const data = makeNodeData({
|
||||
metadata: { customized_name: "Renamed" } as any,
|
||||
hardcodedValues: { agent_name: "Original Agent", graph_version: 1 },
|
||||
});
|
||||
expect(getNodeDisplayTitle(data)).toBe("Renamed");
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatNodeDisplayTitle", () => {
|
||||
it("returns custom name as-is without beautifying", () => {
|
||||
const data = makeNodeData({
|
||||
metadata: { customized_name: "my_custom_name" } as any,
|
||||
});
|
||||
expect(formatNodeDisplayTitle(data)).toBe("my_custom_name");
|
||||
});
|
||||
|
||||
it("returns agent name as-is without beautifying", () => {
|
||||
const data = makeNodeData({
|
||||
hardcodedValues: { agent_name: "Blockchain Agent", graph_version: 1 },
|
||||
});
|
||||
expect(formatNodeDisplayTitle(data)).toBe("Blockchain Agent v1");
|
||||
});
|
||||
|
||||
it("beautifies generic block title and strips Block suffix", () => {
|
||||
const data = makeNodeData({ title: "AgentExecutorBlock" });
|
||||
const result = formatNodeDisplayTitle(data);
|
||||
expect(result).not.toContain("Block");
|
||||
expect(result).toBe("Agent Executor");
|
||||
});
|
||||
|
||||
it("does not corrupt agent names containing 'Block'", () => {
|
||||
const data = makeNodeData({
|
||||
hardcodedValues: { agent_name: "Blockchain Agent", graph_version: 2 },
|
||||
});
|
||||
expect(formatNodeDisplayTitle(data)).toBe("Blockchain Agent v2");
|
||||
});
|
||||
});
|
||||
@@ -6,10 +6,9 @@ import {
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useEffect, useState } from "react";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
import { useState } from "react";
|
||||
import { CustomNodeData } from "../CustomNode";
|
||||
import { formatNodeDisplayTitle, getNodeDisplayTitle } from "../helpers";
|
||||
import { NodeBadges } from "./NodeBadges";
|
||||
import { NodeContextMenu } from "./NodeContextMenu";
|
||||
import { NodeCost } from "./NodeCost";
|
||||
@@ -22,24 +21,15 @@ type Props = {
|
||||
export const NodeHeader = ({ data, nodeId }: Props) => {
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
|
||||
const title = getNodeDisplayTitle(data);
|
||||
const displayTitle = formatNodeDisplayTitle(data);
|
||||
const title = (data.metadata?.customized_name as string) || data.title;
|
||||
|
||||
const [isEditingTitle, setIsEditingTitle] = useState(false);
|
||||
const [editedTitle, setEditedTitle] = useState(title);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isEditingTitle) {
|
||||
setEditedTitle(title);
|
||||
}
|
||||
}, [title, isEditingTitle]);
|
||||
|
||||
const handleTitleEdit = () => {
|
||||
if (editedTitle !== title) {
|
||||
updateNodeData(nodeId, {
|
||||
metadata: { ...data.metadata, customized_name: editedTitle },
|
||||
});
|
||||
}
|
||||
updateNodeData(nodeId, {
|
||||
metadata: { ...data.metadata, customized_name: editedTitle },
|
||||
});
|
||||
setIsEditingTitle(false);
|
||||
};
|
||||
|
||||
@@ -82,12 +72,12 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
|
||||
variant="large-semibold"
|
||||
className="line-clamp-1 hover:cursor-text"
|
||||
>
|
||||
{displayTitle}
|
||||
{beautifyString(title).replace("Block", "").trim()}
|
||||
</Text>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p>{displayTitle}</p>
|
||||
<p>{beautifyString(title).replace("Block", "").trim()}</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { render, screen, fireEvent } from "@/tests/integrations/test-utils";
|
||||
import { NodeHeader } from "../NodeHeader";
|
||||
import { CustomNodeData } from "../../CustomNode";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
|
||||
vi.mock("../NodeCost", () => ({
|
||||
NodeCost: () => <div data-testid="node-cost" />,
|
||||
}));
|
||||
|
||||
vi.mock("../NodeContextMenu", () => ({
|
||||
NodeContextMenu: () => <div data-testid="node-context-menu" />,
|
||||
}));
|
||||
|
||||
vi.mock("../NodeBadges", () => ({
|
||||
NodeBadges: () => <div data-testid="node-badges" />,
|
||||
}));
|
||||
|
||||
function makeData(overrides: Partial<CustomNodeData> = {}): CustomNodeData {
|
||||
return {
|
||||
title: "AgentExecutorBlock",
|
||||
description: "",
|
||||
hardcodedValues: {},
|
||||
inputSchema: {},
|
||||
outputSchema: {},
|
||||
uiType: "agent",
|
||||
block_id: "block-1",
|
||||
costs: [],
|
||||
categories: [],
|
||||
...overrides,
|
||||
} as CustomNodeData;
|
||||
}
|
||||
|
||||
describe("NodeHeader", () => {
|
||||
const mockUpdateNodeData = vi.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
useNodeStore.setState({ updateNodeData: mockUpdateNodeData } as any);
|
||||
});
|
||||
|
||||
it("renders beautified generic block title", () => {
|
||||
render(<NodeHeader data={makeData()} nodeId="abc-123" />);
|
||||
expect(screen.getByText("Agent Executor")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("renders agent name with version from hardcodedValues", () => {
|
||||
const data = makeData({
|
||||
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
|
||||
});
|
||||
render(<NodeHeader data={data} nodeId="abc-123" />);
|
||||
expect(screen.getByText("Researcher v2")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("renders customized_name over agent name", () => {
|
||||
const data = makeData({
|
||||
metadata: { customized_name: "My Custom Node" } as any,
|
||||
hardcodedValues: { agent_name: "Researcher", graph_version: 1 },
|
||||
});
|
||||
render(<NodeHeader data={data} nodeId="abc-123" />);
|
||||
expect(screen.getByText("My Custom Node")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("shows node ID prefix", () => {
|
||||
render(<NodeHeader data={makeData()} nodeId="abc-123" />);
|
||||
expect(screen.getByText("#abc")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("enters edit mode on double-click and saves on blur", () => {
|
||||
render(<NodeHeader data={makeData()} nodeId="node-1" />);
|
||||
const titleEl = screen.getByText("Agent Executor");
|
||||
fireEvent.doubleClick(titleEl);
|
||||
|
||||
const input = screen.getByDisplayValue("AgentExecutorBlock");
|
||||
fireEvent.change(input, { target: { value: "New Name" } });
|
||||
fireEvent.blur(input);
|
||||
|
||||
expect(mockUpdateNodeData).toHaveBeenCalledWith("node-1", {
|
||||
metadata: { customized_name: "New Name" },
|
||||
});
|
||||
});
|
||||
|
||||
it("does not save when title is unchanged on blur", () => {
|
||||
const data = makeData({
|
||||
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
|
||||
});
|
||||
render(<NodeHeader data={data} nodeId="node-1" />);
|
||||
const titleEl = screen.getByText("Researcher v2");
|
||||
fireEvent.doubleClick(titleEl);
|
||||
|
||||
const input = screen.getByDisplayValue("Researcher v2");
|
||||
fireEvent.blur(input);
|
||||
|
||||
expect(mockUpdateNodeData).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("saves on Enter key", () => {
|
||||
render(<NodeHeader data={makeData()} nodeId="node-1" />);
|
||||
fireEvent.doubleClick(screen.getByText("Agent Executor"));
|
||||
|
||||
const input = screen.getByDisplayValue("AgentExecutorBlock");
|
||||
fireEvent.change(input, { target: { value: "Renamed" } });
|
||||
fireEvent.keyDown(input, { key: "Enter" });
|
||||
|
||||
expect(mockUpdateNodeData).toHaveBeenCalledWith("node-1", {
|
||||
metadata: { customized_name: "Renamed" },
|
||||
});
|
||||
});
|
||||
|
||||
it("cancels edit on Escape key", () => {
|
||||
render(<NodeHeader data={makeData()} nodeId="node-1" />);
|
||||
fireEvent.doubleClick(screen.getByText("Agent Executor"));
|
||||
|
||||
const input = screen.getByDisplayValue("AgentExecutorBlock");
|
||||
fireEvent.change(input, { target: { value: "Changed" } });
|
||||
fireEvent.keyDown(input, { key: "Escape" });
|
||||
|
||||
expect(mockUpdateNodeData).not.toHaveBeenCalled();
|
||||
expect(screen.getByText("Agent Executor")).toBeTruthy();
|
||||
});
|
||||
});
|
||||
@@ -1,55 +1,6 @@
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { NodeResolutionData } from "@/app/(platform)/build/stores/types";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
import { CustomNodeData } from "./CustomNode";
|
||||
|
||||
/**
|
||||
* Resolves the display title for a node using a 3-tier fallback:
|
||||
*
|
||||
* 1. `customized_name` — the user's manual rename (highest priority)
|
||||
* 2. `agent_name` (+ version) from `hardcodedValues` — the selected agent's
|
||||
* display name, persisted by blocks like AgentExecutorBlock
|
||||
* 3. `data.title` — the generic block name (e.g. "Agent Executor")
|
||||
*
|
||||
* `customized_name` is the user's explicit rename via double-click; it lives in
|
||||
* node metadata. `agent_name` is the programmatic name of the agent graph
|
||||
* selected in the block's input form; it lives in `hardcodedValues` alongside
|
||||
* `graph_version`. These are distinct sources of truth — customized_name always
|
||||
* wins because it reflects deliberate user intent.
|
||||
*/
|
||||
export function getNodeDisplayTitle(data: CustomNodeData): string {
|
||||
if (data.metadata?.customized_name) {
|
||||
return data.metadata.customized_name as string;
|
||||
}
|
||||
|
||||
const agentName = data.hardcodedValues?.agent_name as string | undefined;
|
||||
const graphVersion = data.hardcodedValues?.graph_version as
|
||||
| number
|
||||
| undefined;
|
||||
if (agentName) {
|
||||
return graphVersion != null ? `${agentName} v${graphVersion}` : agentName;
|
||||
}
|
||||
|
||||
return data.title;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the formatted display title for rendering.
|
||||
* Agent names and custom names are shown as-is; generic block names get
|
||||
* beautified and have the trailing " Block" suffix stripped.
|
||||
*/
|
||||
export function formatNodeDisplayTitle(data: CustomNodeData): string {
|
||||
const title = getNodeDisplayTitle(data);
|
||||
const isAgentOrCustom = !!(
|
||||
data.metadata?.customized_name || data.hardcodedValues?.agent_name
|
||||
);
|
||||
return isAgentOrCustom
|
||||
? title
|
||||
: beautifyString(title)
|
||||
.replace(/ Block$/, "")
|
||||
.trim();
|
||||
}
|
||||
|
||||
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
|
||||
INCOMPLETE: "ring-slate-300 bg-slate-300",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { formatNodeDisplayTitle } from "@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/helpers";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
@@ -59,7 +58,9 @@ export function GraphSearchContent({
|
||||
filteredNodes.map((node, index) => {
|
||||
if (!node?.data) return null;
|
||||
|
||||
const nodeTitle = formatNodeDisplayTitle(node.data);
|
||||
const nodeTitle =
|
||||
(node.data.metadata?.customized_name as string) ||
|
||||
beautifyString(node.data.title || "").replace(/ Block$/, "");
|
||||
const nodeType = beautifyString(node.data.title || "").replace(
|
||||
/ Block$/,
|
||||
"",
|
||||
@@ -69,10 +70,7 @@ export function GraphSearchContent({
|
||||
node.data.description ||
|
||||
"";
|
||||
|
||||
const hasCustomName = !!(
|
||||
node.data.metadata?.customized_name ||
|
||||
node.data.hardcodedValues?.agent_name
|
||||
);
|
||||
const hasCustomName = !!node.data.metadata?.customized_name;
|
||||
|
||||
return (
|
||||
<div
|
||||
|
||||
@@ -69,9 +69,6 @@ function calculateNodeScore(
|
||||
const customizedName = String(
|
||||
node.data?.metadata?.customized_name || "",
|
||||
).toLowerCase();
|
||||
const agentName = String(
|
||||
node.data?.hardcodedValues?.agent_name || "",
|
||||
).toLowerCase();
|
||||
|
||||
// Get input and output names with defensive checks
|
||||
const inputNames = Object.keys(node.data?.inputSchema?.properties || {}).map(
|
||||
@@ -84,7 +81,6 @@ function calculateNodeScore(
|
||||
// 1. Check exact match in customized name, title (includes ID), node ID, or block type (highest priority)
|
||||
if (
|
||||
customizedName.includes(query) ||
|
||||
agentName.includes(query) ||
|
||||
nodeTitle.includes(query) ||
|
||||
nodeID.includes(query) ||
|
||||
blockType.includes(query) ||
|
||||
@@ -99,7 +95,6 @@ function calculateNodeScore(
|
||||
queryWords.every(
|
||||
(word) =>
|
||||
customizedName.includes(word) ||
|
||||
agentName.includes(word) ||
|
||||
nodeTitle.includes(word) ||
|
||||
beautifiedBlockType.includes(word),
|
||||
)
|
||||
|
||||
@@ -93,6 +93,7 @@ export function CopilotPage() {
|
||||
hasMoreMessages,
|
||||
isLoadingMore,
|
||||
loadMore,
|
||||
forwardPaginated,
|
||||
// Mobile drawer
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
@@ -217,6 +218,7 @@ export function CopilotPage() {
|
||||
hasMoreMessages={hasMoreMessages}
|
||||
isLoadingMore={isLoadingMore}
|
||||
onLoadMore={loadMore}
|
||||
forwardPaginated={forwardPaginated}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
historicalDurations={historicalDurations}
|
||||
|
||||
@@ -48,40 +48,75 @@ function makeQueryResult(data: object | null) {
|
||||
};
|
||||
}
|
||||
|
||||
describe("useChatSession — pagination metadata", () => {
|
||||
describe("useChatSession — newestSequence and forwardPaginated", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("returns null for oldestSequence when no session data", () => {
|
||||
it("returns null / false when no session data", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(makeQueryResult(null));
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.oldestSequence).toBeNull();
|
||||
expect(result.current.newestSequence).toBeNull();
|
||||
expect(result.current.forwardPaginated).toBe(false);
|
||||
});
|
||||
|
||||
it("returns oldestSequence from session data", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(
|
||||
makeQueryResult({
|
||||
messages: [],
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 50,
|
||||
active_stream: null,
|
||||
}),
|
||||
);
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.oldestSequence).toBe(50);
|
||||
});
|
||||
|
||||
it("returns hasMoreMessages from session data", () => {
|
||||
it("returns newestSequence from session data", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(
|
||||
makeQueryResult({
|
||||
messages: [],
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 0,
|
||||
newest_sequence: 99,
|
||||
forward_paginated: false,
|
||||
active_stream: null,
|
||||
}),
|
||||
);
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.hasMoreMessages).toBe(true);
|
||||
expect(result.current.newestSequence).toBe(99);
|
||||
});
|
||||
|
||||
it("returns null for newestSequence when field is missing", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(
|
||||
makeQueryResult({
|
||||
messages: [],
|
||||
has_more_messages: false,
|
||||
oldest_sequence: 0,
|
||||
newest_sequence: null,
|
||||
forward_paginated: false,
|
||||
active_stream: null,
|
||||
}),
|
||||
);
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.newestSequence).toBeNull();
|
||||
});
|
||||
|
||||
it("returns forwardPaginated=true when session is forward-paginated", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(
|
||||
makeQueryResult({
|
||||
messages: [],
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 0,
|
||||
newest_sequence: 49,
|
||||
forward_paginated: true,
|
||||
active_stream: null,
|
||||
}),
|
||||
);
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.forwardPaginated).toBe(true);
|
||||
});
|
||||
|
||||
it("returns forwardPaginated=false when session is backward-paginated", () => {
|
||||
mockUseGetV2GetSession.mockReturnValue(
|
||||
makeQueryResult({
|
||||
messages: [],
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 50,
|
||||
newest_sequence: 99,
|
||||
forward_paginated: false,
|
||||
active_stream: null,
|
||||
}),
|
||||
);
|
||||
const { result } = renderHook(() => useChatSession());
|
||||
expect(result.current.forwardPaginated).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { useCopilotPage } from "../useCopilotPage";
|
||||
|
||||
const mockUseChatSession = vi.fn();
|
||||
const mockUseCopilotStream = vi.fn();
|
||||
const mockUseLoadMoreMessages = vi.fn();
|
||||
|
||||
vi.mock("../useChatSession", () => ({
|
||||
useChatSession: (...args: unknown[]) => mockUseChatSession(...args),
|
||||
}));
|
||||
vi.mock("../useCopilotStream", () => ({
|
||||
useCopilotStream: (...args: unknown[]) => mockUseCopilotStream(...args),
|
||||
}));
|
||||
vi.mock("../useLoadMoreMessages", () => ({
|
||||
useLoadMoreMessages: (...args: unknown[]) => mockUseLoadMoreMessages(...args),
|
||||
}));
|
||||
vi.mock("../useCopilotNotifications", () => ({
|
||||
useCopilotNotifications: () => undefined,
|
||||
}));
|
||||
vi.mock("../useWorkflowImportAutoSubmit", () => ({
|
||||
useWorkflowImportAutoSubmit: () => undefined,
|
||||
}));
|
||||
vi.mock("../store", () => ({
|
||||
useCopilotUIStore: () => ({
|
||||
sessionToDelete: null,
|
||||
setSessionToDelete: vi.fn(),
|
||||
isDrawerOpen: false,
|
||||
setDrawerOpen: vi.fn(),
|
||||
copilotChatMode: "chat",
|
||||
copilotLlmModel: null,
|
||||
isDryRun: false,
|
||||
}),
|
||||
}));
|
||||
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
|
||||
concatWithAssistantMerge: (a: unknown[], b: unknown[]) => [...a, ...b],
|
||||
}));
|
||||
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
|
||||
useDeleteV2DeleteSession: () => ({ mutate: vi.fn(), isPending: false }),
|
||||
useGetV2ListSessions: () => ({ data: undefined, isLoading: false }),
|
||||
getGetV2ListSessionsQueryKey: () => ["sessions"],
|
||||
}));
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
toast: vi.fn(),
|
||||
}));
|
||||
vi.mock("@/lib/direct-upload", () => ({
|
||||
uploadFileDirect: vi.fn(),
|
||||
}));
|
||||
vi.mock("@/lib/hooks/useBreakpoint", () => ({
|
||||
useBreakpoint: () => "lg",
|
||||
}));
|
||||
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
|
||||
useSupabase: () => ({ isUserLoading: false, isLoggedIn: true }),
|
||||
}));
|
||||
vi.mock("@tanstack/react-query", () => ({
|
||||
useQueryClient: () => ({ invalidateQueries: vi.fn() }),
|
||||
}));
|
||||
vi.mock("@/services/feature-flags/use-get-flag", () => ({
|
||||
Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" },
|
||||
useGetFlag: () => false,
|
||||
}));
|
||||
|
||||
function makeBaseChatSession(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
sessionId: "sess-1",
|
||||
setSessionId: vi.fn(),
|
||||
hydratedMessages: [],
|
||||
rawSessionMessages: [],
|
||||
historicalDurations: new Map(),
|
||||
hasActiveStream: false,
|
||||
hasMoreMessages: false,
|
||||
oldestSequence: null,
|
||||
isLoadingSession: false,
|
||||
isSessionError: false,
|
||||
createSession: vi.fn(),
|
||||
isCreatingSession: false,
|
||||
refetchSession: vi.fn(),
|
||||
sessionDryRun: false,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function makeBaseCopilotStream(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
messages: [],
|
||||
sendMessage: vi.fn(),
|
||||
stop: vi.fn(),
|
||||
status: "ready",
|
||||
error: undefined,
|
||||
isReconnecting: false,
|
||||
isSyncing: false,
|
||||
isUserStoppingRef: { current: false },
|
||||
rateLimitMessage: null,
|
||||
dismissRateLimit: vi.fn(),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function makeBaseLoadMore(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
pagedMessages: [],
|
||||
hasMore: false,
|
||||
isLoadingMore: false,
|
||||
loadMore: vi.fn(),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("useCopilotPage — backward pagination message ordering", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("prepends pagedMessages before currentMessages", () => {
|
||||
const pagedMsg = { id: "paged", role: "user" };
|
||||
const currentMsg = { id: "current", role: "assistant" };
|
||||
mockUseChatSession.mockReturnValue(makeBaseChatSession());
|
||||
mockUseCopilotStream.mockReturnValue(
|
||||
makeBaseCopilotStream({ messages: [currentMsg] }),
|
||||
);
|
||||
mockUseLoadMoreMessages.mockReturnValue(
|
||||
makeBaseLoadMore({ pagedMessages: [pagedMsg] }),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useCopilotPage());
|
||||
|
||||
// Backward: pagedMessages (older) come first
|
||||
expect(result.current.messages[0]).toEqual(pagedMsg);
|
||||
expect(result.current.messages[1]).toEqual(currentMsg);
|
||||
});
|
||||
});
|
||||
@@ -15,8 +15,10 @@ vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
|
||||
|
||||
const BASE_ARGS = {
|
||||
sessionId: "sess-1",
|
||||
initialOldestSequence: 50,
|
||||
initialOldestSequence: 0,
|
||||
initialNewestSequence: 49,
|
||||
initialHasMore: true,
|
||||
forwardPaginated: true,
|
||||
initialPageRawMessages: [],
|
||||
};
|
||||
|
||||
@@ -24,6 +26,7 @@ function makeSuccessResponse(overrides: {
|
||||
messages?: unknown[];
|
||||
has_more_messages?: boolean;
|
||||
oldest_sequence?: number;
|
||||
newest_sequence?: number;
|
||||
}) {
|
||||
return {
|
||||
status: 200,
|
||||
@@ -31,6 +34,7 @@ function makeSuccessResponse(overrides: {
|
||||
messages: overrides.messages ?? [],
|
||||
has_more_messages: overrides.has_more_messages ?? false,
|
||||
oldest_sequence: overrides.oldest_sequence ?? 0,
|
||||
newest_sequence: overrides.newest_sequence ?? 49,
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -47,6 +51,30 @@ describe("useLoadMoreMessages", () => {
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("resetPaged clears paged state and sets hasMore=false during transition", () => {
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
// hasMore must be false during transition to prevent forward loadMore
|
||||
// from firing on the now-active session before forwardPaginated updates.
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("resetPaged exposes a fresh loadMore via incremented epoch", () => {
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
// Just verify resetPaged is callable and doesn't throw.
|
||||
expect(() => {
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
it("resets all state on sessionId change", () => {
|
||||
const { result, rerender } = renderHook(
|
||||
(props) => useLoadMoreMessages(props),
|
||||
@@ -57,6 +85,7 @@ describe("useLoadMoreMessages", () => {
|
||||
...BASE_ARGS,
|
||||
sessionId: "sess-2",
|
||||
initialOldestSequence: 10,
|
||||
initialNewestSequence: 59,
|
||||
initialHasMore: false,
|
||||
});
|
||||
|
||||
@@ -65,6 +94,66 @@ describe("useLoadMoreMessages", () => {
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
describe("loadMore — forward pagination", () => {
|
||||
it("calls getV2GetSession with after_sequence and updates newestSequence", async () => {
|
||||
const rawMsg = { role: "user", content: "hi", sequence: 50 };
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [rawMsg],
|
||||
has_more_messages: true,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ after_sequence: 49 }),
|
||||
);
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("sets hasMore=false when response has no more messages", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({ has_more_messages: false }),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("is a no-op when hasMore is false", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
initialHasMore: false,
|
||||
forwardPaginated: true,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — backward pagination", () => {
|
||||
it("calls getV2GetSession with before_sequence", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
@@ -75,7 +164,13 @@ describe("useLoadMoreMessages", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: 50,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
@@ -87,30 +182,6 @@ describe("useLoadMoreMessages", () => {
|
||||
);
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("is a no-op when hasMore is false", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, initialHasMore: false }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("is a no-op when oldestSequence is null", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, initialOldestSequence: null }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — error handling", () => {
|
||||
@@ -123,6 +194,7 @@ describe("useLoadMoreMessages", () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// First error — hasMore still true
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
@@ -136,6 +208,7 @@ describe("useLoadMoreMessages", () => {
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
// Reset the in-flight guard between calls
|
||||
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
|
||||
}
|
||||
|
||||
@@ -151,18 +224,122 @@ describe("useLoadMoreMessages", () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// One error, not yet at threshold — hasMore still true
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — MAX_OLDER_MESSAGES truncation", () => {
|
||||
it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => {
|
||||
describe("loadMore — forward pagination cursor advancement", () => {
|
||||
it("advances newestSequence after a successful forward load", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 2001 }, (_, i) => ({
|
||||
messages: [{ role: "user", content: "hi", sequence: 50 }],
|
||||
has_more_messages: true,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// A second loadMore should use after_sequence: 99 (advanced cursor)
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ after_sequence: 99 }),
|
||||
);
|
||||
});
|
||||
|
||||
it("does not regress newestSequence when parent refetches after pages loaded", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [{ role: "user", content: "msg", sequence: 50 }],
|
||||
has_more_messages: true,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
(props) => useLoadMoreMessages(props),
|
||||
{ initialProps: { ...BASE_ARGS, forwardPaginated: true } },
|
||||
);
|
||||
|
||||
// Load one page — newestSequence advances to 99
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// Parent refetches with a lower newest_sequence (49) — should NOT regress cursor
|
||||
rerender({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialNewestSequence: 49,
|
||||
});
|
||||
|
||||
// Next loadMore should still use the advanced cursor (99)
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ after_sequence: 99 }),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — MAX_OLDER_MESSAGES truncation", () => {
|
||||
it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => {
|
||||
// Simulate being near the limit — 1990 existing paged messages
|
||||
const nearLimitArgs = {
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: 1990,
|
||||
};
|
||||
|
||||
// Return 20 messages to push total past 2000
|
||||
const newMessages = Array.from({ length: 20 }, (_, i) => ({
|
||||
role: "user",
|
||||
content: `msg ${i}`,
|
||||
sequence: i,
|
||||
}));
|
||||
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: newMessages,
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 0,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook((props) => useLoadMoreMessages(props), {
|
||||
initialProps: nearLimitArgs,
|
||||
});
|
||||
|
||||
// Pre-fill pagedRawMessages to near limit by doing a successful load first
|
||||
// then checking hasMore is set to false when limit reached
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 1990 }, (_, i) => ({
|
||||
role: "user",
|
||||
content: `msg ${i}`,
|
||||
content: `old ${i}`,
|
||||
sequence: i,
|
||||
})),
|
||||
has_more_messages: true,
|
||||
@@ -170,18 +347,109 @@ describe("useLoadMoreMessages", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// Now add 20 more to exceed 2000 — hasMore should be forced false
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("forward truncation keeps first MAX_OLDER_MESSAGES items (not last)", async () => {
|
||||
// 1990 messages already paged; load 20 more forward — total 2010 > 2000.
|
||||
// Forward truncation must keep slice(0, 2000), not slice(-2000),
|
||||
// to preserve the beginning of the conversation.
|
||||
const forwardNearLimitArgs = {
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialNewestSequence: 49,
|
||||
initialOldestSequence: 0,
|
||||
initialHasMore: true,
|
||||
};
|
||||
|
||||
const { result } = renderHook((props) => useLoadMoreMessages(props), {
|
||||
initialProps: forwardNearLimitArgs,
|
||||
});
|
||||
|
||||
// First load: 1990 messages — advances newestSequence to 2039
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 1990 }, (_, i) => ({
|
||||
role: "assistant",
|
||||
content: `msg ${i + 50}`,
|
||||
sequence: i + 50,
|
||||
})),
|
||||
has_more_messages: true,
|
||||
newest_sequence: 2039,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// Second load: 20 more messages pushes total to 2010 > 2000 — hasMore=false
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 20 }, (_, i) => ({
|
||||
role: "assistant",
|
||||
content: `msg ${i + 2040}`,
|
||||
sequence: i + 2040,
|
||||
})),
|
||||
has_more_messages: false,
|
||||
newest_sequence: 2059,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// After truncation, hasMore is forced false (total ≥ MAX_OLDER_MESSAGES).
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — null cursor guard", () => {
|
||||
it("is a no-op when newestSequence is null (forwardPaginated=true)", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialNewestSequence: null,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("is a no-op when oldestSequence is null (forwardPaginated=false)", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: null,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("pagedMessages — initialPageRawMessages extraToolOutputs", () => {
|
||||
it("calls extractToolOutputsFromRaw with non-empty initialPageRawMessages", async () => {
|
||||
it("calls extractToolOutputsFromRaw for backward pagination with non-empty initialPageRawMessages", async () => {
|
||||
const { extractToolOutputsFromRaw } = await import(
|
||||
"../helpers/convertChatSessionToUiMessages"
|
||||
);
|
||||
@@ -198,6 +466,8 @@ describe("useLoadMoreMessages", () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: 50,
|
||||
initialPageRawMessages: [{ role: "assistant", content: "response" }],
|
||||
}),
|
||||
);
|
||||
@@ -208,5 +478,68 @@ describe("useLoadMoreMessages", () => {
|
||||
|
||||
expect(extractToolOutputsFromRaw).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("does NOT call extractToolOutputsFromRaw for forward pagination", async () => {
|
||||
const { extractToolOutputsFromRaw } = await import(
|
||||
"../helpers/convertChatSessionToUiMessages"
|
||||
);
|
||||
|
||||
const rawMsg = { role: "assistant", content: "hi", sequence: 50 };
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [rawMsg],
|
||||
has_more_messages: false,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialPageRawMessages: [{ role: "user", content: "hello" }],
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(extractToolOutputsFromRaw).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — epoch / stale-response guard", () => {
|
||||
it("discards response when epoch changes during flight (resetPaged called)", async () => {
|
||||
let resolveRequest!: (v: unknown) => void;
|
||||
mockGetV2GetSession.mockReturnValueOnce(
|
||||
new Promise((res) => {
|
||||
resolveRequest = res;
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
// Start the request without awaiting
|
||||
act(() => {
|
||||
result.current.loadMore();
|
||||
});
|
||||
|
||||
// Reset epoch mid-flight
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
|
||||
// Now resolve the in-flight request
|
||||
await act(async () => {
|
||||
resolveRequest(
|
||||
makeSuccessResponse({ messages: [{ role: "user", content: "hi" }] }),
|
||||
);
|
||||
});
|
||||
|
||||
// Response discarded — pagedMessages stays empty, isLoadingMore stays false
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,11 +6,9 @@ import { Suspense, useState } from "react";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import type { ArtifactRef } from "../../../store";
|
||||
import type { ArtifactClassification } from "../helpers";
|
||||
import { ArtifactErrorBoundary } from "./ArtifactErrorBoundary";
|
||||
import { ArtifactReactPreview } from "./ArtifactReactPreview";
|
||||
import { ArtifactSkeleton } from "./ArtifactSkeleton";
|
||||
import {
|
||||
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
|
||||
TAILWIND_CDN_URL,
|
||||
wrapWithHeadInjection,
|
||||
} from "@/lib/iframe-sandbox-csp";
|
||||
@@ -55,18 +53,13 @@ function ArtifactContentLoader({
|
||||
|
||||
return (
|
||||
<div ref={scrollRef} className="flex-1 overflow-y-auto">
|
||||
<ArtifactErrorBoundary
|
||||
artifactTitle={artifact.title}
|
||||
artifactType={classification.type}
|
||||
>
|
||||
<ArtifactRenderer
|
||||
artifact={artifact}
|
||||
content={content}
|
||||
pdfUrl={pdfUrl}
|
||||
isSourceView={isSourceView}
|
||||
classification={classification}
|
||||
/>
|
||||
</ArtifactErrorBoundary>
|
||||
<ArtifactRenderer
|
||||
artifact={artifact}
|
||||
content={content}
|
||||
pdfUrl={pdfUrl}
|
||||
isSourceView={isSourceView}
|
||||
classification={classification}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -207,10 +200,7 @@ function ArtifactRenderer({
|
||||
if (classification.type === "html") {
|
||||
// Inject Tailwind CDN — no CSP (see iframe-sandbox-csp.ts for why)
|
||||
const tailwindScript = `<script src="${TAILWIND_CDN_URL}"></script>`;
|
||||
const wrapped = wrapWithHeadInjection(
|
||||
content,
|
||||
tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
|
||||
);
|
||||
const wrapped = wrapWithHeadInjection(content, tailwindScript);
|
||||
return (
|
||||
<iframe
|
||||
sandbox="allow-scripts"
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { Component, type ErrorInfo, type ReactNode } from "react";
|
||||
|
||||
interface Props {
|
||||
children: ReactNode;
|
||||
artifactTitle: string;
|
||||
artifactType: string;
|
||||
}
|
||||
|
||||
interface State {
|
||||
error: Error | null;
|
||||
}
|
||||
|
||||
export class ArtifactErrorBoundary extends Component<Props, State> {
|
||||
state: State = { error: null };
|
||||
|
||||
static getDerivedStateFromError(error: Error): State {
|
||||
return { error };
|
||||
}
|
||||
|
||||
componentDidCatch(error: Error, errorInfo: ErrorInfo) {
|
||||
Sentry.captureException(error, {
|
||||
contexts: {
|
||||
react: { componentStack: errorInfo.componentStack },
|
||||
},
|
||||
tags: { errorBoundary: "true", context: "copilot-artifact" },
|
||||
extra: {
|
||||
artifactTitle: this.props.artifactTitle,
|
||||
artifactType: this.props.artifactType,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
componentDidUpdate(prevProps: Props) {
|
||||
if (
|
||||
this.state.error &&
|
||||
(prevProps.artifactTitle !== this.props.artifactTitle ||
|
||||
prevProps.artifactType !== this.props.artifactType)
|
||||
) {
|
||||
this.setState({ error: null });
|
||||
}
|
||||
}
|
||||
|
||||
handleCopy = () => {
|
||||
const { error } = this.state;
|
||||
if (!error) return;
|
||||
const details = [
|
||||
`Artifact: ${this.props.artifactTitle}`,
|
||||
`Type: ${this.props.artifactType}`,
|
||||
`Error: ${error.message}`,
|
||||
error.stack ? `Stack:\n${error.stack}` : "",
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join("\n");
|
||||
navigator.clipboard?.writeText(details).catch(() => {});
|
||||
};
|
||||
|
||||
render() {
|
||||
const { error } = this.state;
|
||||
if (!error) return this.props.children;
|
||||
|
||||
const message = error.message || "Unknown rendering error";
|
||||
|
||||
return (
|
||||
<div
|
||||
role="alert"
|
||||
className="flex h-full flex-col items-center justify-center gap-3 p-8 text-center"
|
||||
>
|
||||
<p className="text-sm font-medium text-zinc-700">
|
||||
This artifact couldn't be rendered
|
||||
</p>
|
||||
<p className="max-w-md break-words text-xs text-zinc-500">
|
||||
Something in{" "}
|
||||
<span className="font-mono">{this.props.artifactTitle}</span> threw an
|
||||
error while rendering. The chat and sidebar are still working.
|
||||
</p>
|
||||
<pre className="max-h-32 max-w-md overflow-auto whitespace-pre-wrap break-words rounded-md bg-zinc-100 px-3 py-2 text-left text-xs text-zinc-700">
|
||||
{message}
|
||||
</pre>
|
||||
<button
|
||||
type="button"
|
||||
onClick={this.handleCopy}
|
||||
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
|
||||
>
|
||||
Copy error details
|
||||
</button>
|
||||
<p className="max-w-md text-xs text-zinc-400">
|
||||
Paste this into the chat so the agent can regenerate a working
|
||||
version.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -412,41 +412,6 @@ describe("ArtifactContent", () => {
|
||||
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
|
||||
});
|
||||
|
||||
it("injects the fragment-link interceptor into HTML artifact iframes (regression)", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () =>
|
||||
Promise.resolve(
|
||||
'<html><head></head><body><a href="#x">x</a><div id="x">x</div></body></html>',
|
||||
),
|
||||
}),
|
||||
);
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={makeArtifact({
|
||||
id: "html-frag",
|
||||
title: "page.html",
|
||||
mimeType: "text/html",
|
||||
})}
|
||||
isSourceView={false}
|
||||
classification={makeClassification({ type: "html" })}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTitle("page.html");
|
||||
const srcdoc = container.querySelector("iframe")?.getAttribute("srcdoc");
|
||||
expect(srcdoc).toBeTruthy();
|
||||
// Markers unique to FRAGMENT_LINK_INTERCEPTOR_SCRIPT — if any of these
|
||||
// disappear, the interceptor is no longer being injected and fragment
|
||||
// links will navigate the parent URL again.
|
||||
expect(srcdoc).toContain("__fragmentLinkInterceptor");
|
||||
expect(srcdoc).toContain('a[href^="#"]');
|
||||
expect(srcdoc).toContain("scrollIntoView");
|
||||
});
|
||||
|
||||
// ── Source view ───────────────────────────────────────────────────
|
||||
|
||||
it("renders source view as pre tag", async () => {
|
||||
@@ -958,164 +923,6 @@ describe("ArtifactContent", () => {
|
||||
},
|
||||
);
|
||||
|
||||
// ── Error boundary ────────────────────────────────────────────────
|
||||
|
||||
it("shows a visible error instead of crashing when the renderer throws", async () => {
|
||||
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
|
||||
const originalImpl = vi
|
||||
.mocked(ArtifactReactPreview)
|
||||
.getMockImplementation();
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
|
||||
throw new Error("boom in renderer");
|
||||
});
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("source"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "crash-001",
|
||||
title: "broken.tsx",
|
||||
mimeType: "text/tsx",
|
||||
});
|
||||
const classification = makeClassification({ type: "react" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(
|
||||
await screen.findByText(/This artifact couldn't be rendered/i),
|
||||
).toBeTruthy();
|
||||
expect(screen.getByText(/boom in renderer/)).toBeTruthy();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /copy error details/i }),
|
||||
).toBeTruthy();
|
||||
|
||||
if (originalImpl) {
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
|
||||
}
|
||||
consoleErr.mockRestore();
|
||||
});
|
||||
|
||||
it("copies artifact title, type, and error to the clipboard", async () => {
|
||||
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
|
||||
const writeText = vi.fn().mockResolvedValue(undefined);
|
||||
Object.defineProperty(navigator, "clipboard", {
|
||||
value: { writeText },
|
||||
writable: true,
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
const originalImpl = vi
|
||||
.mocked(ArtifactReactPreview)
|
||||
.getMockImplementation();
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
|
||||
throw new Error("jsx parse failed at line 42");
|
||||
});
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("source"),
|
||||
}),
|
||||
);
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={makeArtifact({
|
||||
id: "crash-002",
|
||||
title: "report.tsx",
|
||||
mimeType: "text/tsx",
|
||||
})}
|
||||
isSourceView={false}
|
||||
classification={makeClassification({ type: "react" })}
|
||||
/>,
|
||||
);
|
||||
|
||||
fireEvent.click(
|
||||
await screen.findByRole("button", { name: /copy error details/i }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(writeText).toHaveBeenCalled();
|
||||
});
|
||||
const payload = writeText.mock.calls[0]![0] as string;
|
||||
expect(payload).toContain("report.tsx");
|
||||
expect(payload).toContain("react");
|
||||
expect(payload).toContain("jsx parse failed at line 42");
|
||||
|
||||
if (originalImpl) {
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
|
||||
}
|
||||
consoleErr.mockRestore();
|
||||
});
|
||||
|
||||
it("renders the user-reported plotly HTML artifact into a sandboxed iframe", async () => {
|
||||
const html = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>AutoGPT Beta Launch Interactive Report</title>
|
||||
<script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
|
||||
<style>
|
||||
:root { --bg: #f8f9fa; --primary: #6c5ce7; }
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body { font-family: 'Segoe UI', system-ui, sans-serif; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header><h1>\u{1F4CA} AutoGPT Beta Launch Interactive Report</h1></header>
|
||||
<div class="chart-container" id="globalActivationChart"></div>
|
||||
<script>
|
||||
function showTab(tabId, groupId) {
|
||||
const group = document.getElementById(groupId);
|
||||
group.querySelectorAll('.tab-content').forEach(t => t.classList.remove('active'));
|
||||
document.getElementById(tabId).classList.add('active');
|
||||
}
|
||||
Plotly.newPlot('globalActivationChart', [{ type: 'pie', values: [1, 2] }], {});
|
||||
</script>
|
||||
</body>
|
||||
</html>`;
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(html),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "html-big-report",
|
||||
title: "report.html",
|
||||
mimeType: "text/html",
|
||||
});
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={makeClassification({ type: "html" })}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTitle("report.html");
|
||||
const iframe = container.querySelector("iframe");
|
||||
expect(iframe).toBeTruthy();
|
||||
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
|
||||
expect(screen.queryByText(/couldn't be rendered/i)).toBeNull();
|
||||
});
|
||||
|
||||
it("falls back to pre tag when no renderer matches", async () => {
|
||||
const { globalRegistry } = await import(
|
||||
"@/components/contextual/OutputRenderers"
|
||||
|
||||
@@ -116,11 +116,4 @@ describe("buildReactArtifactSrcDoc", () => {
|
||||
expect(doc).toContain("/^[A-Z]/.test(name)");
|
||||
expect(doc).toContain("wrapWithProviders");
|
||||
});
|
||||
|
||||
it("injects the fragment-link interceptor so #anchor clicks stay inside the iframe (regression)", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("__fragmentLinkInterceptor");
|
||||
expect(doc).toContain('a[href^="#"]');
|
||||
expect(doc).toContain("scrollIntoView");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -19,10 +19,7 @@
|
||||
* React is loaded from unpkg with pinned version and SRI integrity hashes.
|
||||
*/
|
||||
|
||||
import {
|
||||
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
|
||||
TAILWIND_CDN_URL,
|
||||
} from "@/lib/iframe-sandbox-csp";
|
||||
import { TAILWIND_CDN_URL } from "@/lib/iframe-sandbox-csp";
|
||||
|
||||
export { transpileReactArtifactSource } from "./transpileReactArtifact";
|
||||
|
||||
@@ -98,7 +95,6 @@ export function buildReactArtifactSrcDoc(
|
||||
}
|
||||
</style>
|
||||
<script src="${TAILWIND_CDN_URL}"></script>
|
||||
${FRAGMENT_LINK_INTERCEPTOR_SCRIPT}
|
||||
<script crossorigin="anonymous" src="https://unpkg.com/react@18.3.1/umd/react.production.min.js" integrity="sha384-DGyLxAyjq0f9SPpVevD6IgztCFlnMF6oW/XQGmfe+IsZ8TqEiDrcHkMLKI6fiB/Z"></script><!-- pragma: allowlist secret -->
|
||||
<script crossorigin="anonymous" src="https://unpkg.com/react-dom@18.3.1/umd/react-dom.production.min.js" integrity="sha384-gTGxhz21lVGYNMcdJOyq01Edg0jhn/c22nsx0kyqP0TxaV5WVdsSH1fSDUf5YJj1"></script><!-- pragma: allowlist secret -->
|
||||
</head>
|
||||
|
||||
@@ -30,6 +30,7 @@ export interface ChatContainerProps {
|
||||
hasMoreMessages?: boolean;
|
||||
isLoadingMore?: boolean;
|
||||
onLoadMore?: () => void;
|
||||
forwardPaginated?: boolean;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
@@ -54,6 +55,7 @@ export const ChatContainer = ({
|
||||
hasMoreMessages,
|
||||
isLoadingMore,
|
||||
onLoadMore,
|
||||
forwardPaginated,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
historicalDurations,
|
||||
@@ -108,6 +110,7 @@ export const ChatContainer = ({
|
||||
hasMoreMessages={hasMoreMessages}
|
||||
isLoadingMore={isLoadingMore}
|
||||
onLoadMore={onLoadMore}
|
||||
forwardPaginated={forwardPaginated}
|
||||
onRetry={handleRetry}
|
||||
historicalDurations={historicalDurations}
|
||||
/>
|
||||
|
||||
@@ -86,11 +86,11 @@ export function ChatInput({
|
||||
title:
|
||||
next === "advanced"
|
||||
? "Switched to Advanced model"
|
||||
: "Switched to Balanced model",
|
||||
: "Switched to Standard model",
|
||||
description:
|
||||
next === "advanced"
|
||||
? "Using the highest-capability model."
|
||||
: "Using the balanced default model.",
|
||||
: "Using the balanced standard model.",
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -162,15 +162,10 @@ describe("ChatInput mode toggle", () => {
|
||||
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
|
||||
});
|
||||
|
||||
it("hides toggle buttons when streaming", () => {
|
||||
it("hides toggle button when streaming", () => {
|
||||
mockFlagValue = true;
|
||||
render(<ChatInput onSend={mockOnSend} isStreaming />);
|
||||
expect(
|
||||
screen.queryByLabelText(/switch to (fast|extended thinking) mode/i),
|
||||
).toBeNull();
|
||||
expect(
|
||||
screen.queryByLabelText(/switch to (advanced|balanced|standard) model/i),
|
||||
).toBeNull();
|
||||
expect(screen.queryByLabelText(/switch to/i)).toBeNull();
|
||||
});
|
||||
|
||||
it("shows mode toggle when hasSession is true and not streaming", () => {
|
||||
@@ -239,7 +234,7 @@ describe("ChatInput model toggle", () => {
|
||||
mockFlagValue = true;
|
||||
mockCopilotLlmModel = "advanced";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
|
||||
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
|
||||
expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
|
||||
});
|
||||
|
||||
@@ -293,10 +288,10 @@ describe("ChatInput model toggle", () => {
|
||||
mockFlagValue = true;
|
||||
mockCopilotLlmModel = "advanced";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
|
||||
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
|
||||
expect(toast).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
title: expect.stringMatching(/switched to balanced model/i),
|
||||
title: expect.stringMatching(/switched to standard model/i),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -2,11 +2,6 @@
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Flask } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
|
||||
// This button is only rendered on NEW chats (no active session).
|
||||
// Once a session exists, it is hidden — the session's dry_run flag is
|
||||
@@ -19,31 +14,27 @@ interface Props {
|
||||
|
||||
export function DryRunToggleButton({ isDryRun, onToggle }: Props) {
|
||||
return (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isDryRun}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
|
||||
isDryRun
|
||||
? "text-amber-900"
|
||||
: "text-neutral-500 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={isDryRun ? "Test mode active" : "Enable Test mode"}
|
||||
>
|
||||
<Flask size={14} />
|
||||
<span className="hidden sm:inline">
|
||||
{isDryRun ? "Test mode enabled" : "Enable test mode"}
|
||||
</span>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{isDryRun
|
||||
? "Test mode on — new sessions run without performing real actions (click to turn off)."
|
||||
: "Turn on test mode to try prompts without performing real actions."}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isDryRun}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isDryRun
|
||||
? "bg-amber-100 text-amber-900 hover:bg-amber-200"
|
||||
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={
|
||||
isDryRun ? "Test mode active — click to disable" : "Enable Test mode"
|
||||
}
|
||||
title={
|
||||
isDryRun
|
||||
? "Test mode ON — new chats run agents as simulation (click to disable)"
|
||||
: "Enable Test mode — new chats will run agents as simulation"
|
||||
}
|
||||
>
|
||||
<Flask size={14} />
|
||||
{isDryRun && "Test"}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,11 +2,6 @@
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Brain, Lightning } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import type { CopilotMode } from "../../../store";
|
||||
|
||||
interface Props {
|
||||
@@ -16,42 +11,37 @@ interface Props {
|
||||
|
||||
export function ModeToggleButton({ mode, onToggle }: Props) {
|
||||
const isExtended = mode === "extended_thinking";
|
||||
|
||||
const tooltipText = isExtended
|
||||
? "Extended Thinking — deeper reasoning (click to switch to Fast)"
|
||||
: "Fast mode — quicker responses (click to switch to Thinking)";
|
||||
|
||||
return (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isExtended}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"ml-2 inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
|
||||
isExtended ? "text-purple-900" : "text-amber-900",
|
||||
)}
|
||||
aria-label={
|
||||
isExtended
|
||||
? "Switch to Fast mode"
|
||||
: "Switch to Extended Thinking mode"
|
||||
}
|
||||
>
|
||||
{isExtended ? (
|
||||
<>
|
||||
<Brain size={14} />
|
||||
Thinking
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Lightning size={14} />
|
||||
Fast
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{tooltipText}</TooltipContent>
|
||||
</Tooltip>
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isExtended}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isExtended
|
||||
? "bg-purple-100 text-purple-900 hover:bg-purple-200"
|
||||
: "bg-amber-100 text-amber-900 hover:bg-amber-200",
|
||||
)}
|
||||
aria-label={
|
||||
isExtended ? "Switch to Fast mode" : "Switch to Extended Thinking mode"
|
||||
}
|
||||
title={
|
||||
isExtended
|
||||
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
|
||||
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
|
||||
}
|
||||
>
|
||||
{isExtended ? (
|
||||
<>
|
||||
<Brain size={14} />
|
||||
Thinking
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Lightning size={14} />
|
||||
Fast
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,11 +2,6 @@
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Cpu } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import type { CopilotLlmModel } from "../../../store";
|
||||
|
||||
interface Props {
|
||||
@@ -17,33 +12,27 @@ interface Props {
|
||||
export function ModelToggleButton({ model, onToggle }: Props) {
|
||||
const isAdvanced = model === "advanced";
|
||||
return (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isAdvanced}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
|
||||
isAdvanced
|
||||
? "text-sky-900"
|
||||
: "text-neutral-500 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={
|
||||
isAdvanced ? "Switch to Balanced model" : "Switch to Advanced model"
|
||||
}
|
||||
>
|
||||
<Cpu size={14} />
|
||||
<span className="hidden sm:inline">
|
||||
{isAdvanced ? "Advanced" : "Balanced"}
|
||||
</span>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{isAdvanced
|
||||
? "Using the highest-capability model (click to switch to Balanced)."
|
||||
: "Using the balanced default model (click to switch to Advanced)."}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isAdvanced}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isAdvanced
|
||||
? "bg-sky-100 text-sky-900 hover:bg-sky-200"
|
||||
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={
|
||||
isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
|
||||
}
|
||||
title={
|
||||
isAdvanced
|
||||
? "Advanced model — highest capability (click to switch to Standard)"
|
||||
: "Standard model — click to switch to Advanced"
|
||||
}
|
||||
>
|
||||
<Cpu size={14} />
|
||||
{isAdvanced && "Advanced"}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,32 +1,21 @@
|
||||
import {
|
||||
render as rtlRender,
|
||||
screen,
|
||||
fireEvent,
|
||||
cleanup,
|
||||
} from "@testing-library/react";
|
||||
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import type { ReactElement } from "react";
|
||||
import { TooltipProvider } from "@/components/ui/tooltip";
|
||||
import { DryRunToggleButton } from "../DryRunToggleButton";
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
function render(ui: ReactElement) {
|
||||
return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
|
||||
}
|
||||
|
||||
// DryRunToggleButton only appears on new chats (no active session).
|
||||
// It has no readOnly/isStreaming props — those scenarios are handled by hiding
|
||||
// the button entirely at the ChatInput level when hasSession is true.
|
||||
describe("DryRunToggleButton", () => {
|
||||
it("shows enabled label when isDryRun is true", () => {
|
||||
it("shows Test label when isDryRun is true", () => {
|
||||
render(<DryRunToggleButton isDryRun={true} onToggle={vi.fn()} />);
|
||||
expect(screen.getByText("Test mode enabled")).toBeTruthy();
|
||||
expect(screen.getByText("Test")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("shows enable label when isDryRun is false", () => {
|
||||
it("shows no text label when isDryRun is false", () => {
|
||||
render(<DryRunToggleButton isDryRun={false} onToggle={vi.fn()} />);
|
||||
expect(screen.getByText("Enable test mode")).toBeTruthy();
|
||||
expect(screen.queryByText("Test")).toBeNull();
|
||||
});
|
||||
|
||||
it("calls onToggle when clicked", () => {
|
||||
|
||||
@@ -1,20 +1,9 @@
|
||||
import {
|
||||
render as rtlRender,
|
||||
screen,
|
||||
fireEvent,
|
||||
cleanup,
|
||||
} from "@testing-library/react";
|
||||
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import type { ReactElement } from "react";
|
||||
import { TooltipProvider } from "@/components/ui/tooltip";
|
||||
import { ModelToggleButton } from "../ModelToggleButton";
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
function render(ui: ReactElement) {
|
||||
return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
|
||||
}
|
||||
|
||||
describe("ModelToggleButton", () => {
|
||||
it("shows no text label when model is standard", () => {
|
||||
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
|
||||
@@ -42,7 +31,7 @@ describe("ModelToggleButton", () => {
|
||||
|
||||
it("sets aria-pressed=true for advanced", () => {
|
||||
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
|
||||
const btn = screen.getByLabelText("Switch to Balanced model");
|
||||
const btn = screen.getByLabelText("Switch to Standard model");
|
||||
expect(btn.getAttribute("aria-pressed")).toBe("true");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -43,6 +43,10 @@ interface Props {
|
||||
hasMoreMessages?: boolean;
|
||||
isLoadingMore?: boolean;
|
||||
onLoadMore?: () => void;
|
||||
/** When true the load-more sentinel is placed at the bottom (forward
|
||||
* pagination for completed sessions). When false it is at the top
|
||||
* (backward pagination for active sessions). */
|
||||
forwardPaginated?: boolean;
|
||||
onRetry?: () => void;
|
||||
historicalDurations?: Map<string, number>;
|
||||
}
|
||||
@@ -136,11 +140,25 @@ export function LoadMoreSentinel({
|
||||
isLoading,
|
||||
messageCount,
|
||||
onLoadMore,
|
||||
rootMargin = "200px 0px 0px 0px",
|
||||
adjustScroll = true,
|
||||
forwardPaginated = false,
|
||||
}: {
|
||||
hasMore: boolean;
|
||||
isLoading: boolean;
|
||||
messageCount: number;
|
||||
onLoadMore: () => void;
|
||||
/** IntersectionObserver rootMargin. Top sentinel uses "200px 0px 0px 0px"
|
||||
* (pre-trigger when approaching from above); bottom sentinel should use
|
||||
* "0px 0px 200px 0px" (pre-trigger when approaching from below). */
|
||||
rootMargin?: string;
|
||||
/** Whether to adjust scrollTop after load to preserve visual position.
|
||||
* True for backward pagination (prepend above); false for forward
|
||||
* pagination (append below) where no adjustment is needed. */
|
||||
adjustScroll?: boolean;
|
||||
/** When true the button reads "Load newer messages" (forward pagination).
|
||||
* When false (default) it reads "Load older messages". */
|
||||
forwardPaginated?: boolean;
|
||||
}) {
|
||||
const sentinelRef = useRef<HTMLDivElement>(null);
|
||||
const onLoadMoreRef = useRef(onLoadMore);
|
||||
@@ -185,11 +203,11 @@ export function LoadMoreSentinel({
|
||||
if (autoFillRoundsRef.current >= MAX_AUTO_FILL_ROUNDS) return;
|
||||
captureAndLoad(true);
|
||||
},
|
||||
{ rootMargin: "200px 0px 0px 0px" },
|
||||
{ rootMargin },
|
||||
);
|
||||
observer.observe(sentinelRef.current);
|
||||
return () => observer.disconnect();
|
||||
}, [hasMore, isLoading, scrollRef]);
|
||||
}, [hasMore, isLoading, rootMargin, scrollRef]);
|
||||
|
||||
// After React commits new DOM nodes (prepended messages), adjust
|
||||
// scrollTop so the user stays at the same visual position.
|
||||
@@ -202,7 +220,9 @@ export function LoadMoreSentinel({
|
||||
scrollSnapshotRef.current;
|
||||
if (!el || prevHeight === 0) return;
|
||||
const delta = el.scrollHeight - prevHeight;
|
||||
if (delta > 0) {
|
||||
// Only restore scroll position for backward pagination (content prepended
|
||||
// above). Forward pagination appends below — no adjustment needed.
|
||||
if (adjustScroll && delta > 0) {
|
||||
el.scrollTop = prevTop + delta;
|
||||
}
|
||||
// Reset the auto-fill backoff whenever the container becomes
|
||||
@@ -216,7 +236,7 @@ export function LoadMoreSentinel({
|
||||
}
|
||||
scrollSnapshotRef.current = { scrollHeight: 0, scrollTop: 0 };
|
||||
autoTriggeredRef.current = false;
|
||||
}, [messageCount, scrollRef]);
|
||||
}, [adjustScroll, messageCount, scrollRef]);
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -235,7 +255,7 @@ export function LoadMoreSentinel({
|
||||
size="small"
|
||||
onClick={() => captureAndLoad(false)}
|
||||
>
|
||||
Load older messages
|
||||
{forwardPaginated ? "Load newer messages" : "Load older messages"}
|
||||
</Button>
|
||||
)
|
||||
)}
|
||||
@@ -252,6 +272,7 @@ export function ChatMessagesContainer({
|
||||
hasMoreMessages,
|
||||
isLoadingMore,
|
||||
onLoadMore,
|
||||
forwardPaginated,
|
||||
onRetry,
|
||||
historicalDurations,
|
||||
}: Props) {
|
||||
@@ -330,7 +351,7 @@ export function ChatMessagesContainer({
|
||||
}
|
||||
>
|
||||
<ConversationContent className="flex min-h-full flex-1 flex-col gap-6 px-3 py-6">
|
||||
{hasMoreMessages && onLoadMore && (
|
||||
{hasMoreMessages && onLoadMore && !forwardPaginated && (
|
||||
<LoadMoreSentinel
|
||||
hasMore={hasMoreMessages}
|
||||
isLoading={!!isLoadingMore}
|
||||
@@ -489,6 +510,17 @@ export function ChatMessagesContainer({
|
||||
</pre>
|
||||
</details>
|
||||
)}
|
||||
{hasMoreMessages && onLoadMore && forwardPaginated && (
|
||||
<LoadMoreSentinel
|
||||
hasMore={hasMoreMessages}
|
||||
isLoading={!!isLoadingMore}
|
||||
messageCount={messages.length}
|
||||
onLoadMore={onLoadMore}
|
||||
rootMargin="0px 0px 200px 0px"
|
||||
adjustScroll={false}
|
||||
forwardPaginated
|
||||
/>
|
||||
)}
|
||||
</ConversationContent>
|
||||
<ConversationScrollButton />
|
||||
</Conversation>
|
||||
|
||||
@@ -124,22 +124,48 @@ describe("ChatMessagesContainer", () => {
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("renders top sentinel for backward pagination", () => {
|
||||
it("renders top sentinel when forwardPaginated is false (backward pagination)", () => {
|
||||
render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={false} />);
|
||||
expect(
|
||||
screen.getByRole("button", { name: /load older messages/i }),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders top sentinel when forwardPaginated is undefined (default, backward)", () => {
|
||||
render(<ChatMessagesContainer {...BASE_PROPS} />);
|
||||
expect(
|
||||
screen.getByRole("button", { name: /load older messages/i }),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders bottom sentinel when forwardPaginated is true (forward pagination)", () => {
|
||||
render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={true} />);
|
||||
expect(
|
||||
screen.getByRole("button", { name: /load newer messages/i }),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("hides sentinel when hasMoreMessages is false", () => {
|
||||
render(<ChatMessagesContainer {...BASE_PROPS} hasMoreMessages={false} />);
|
||||
render(
|
||||
<ChatMessagesContainer
|
||||
{...BASE_PROPS}
|
||||
hasMoreMessages={false}
|
||||
forwardPaginated={true}
|
||||
/>,
|
||||
);
|
||||
expect(
|
||||
screen.queryByRole("button", { name: /load older messages/i }),
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it("hides sentinel when onLoadMore is not provided", () => {
|
||||
render(<ChatMessagesContainer {...BASE_PROPS} onLoadMore={undefined} />);
|
||||
render(
|
||||
<ChatMessagesContainer
|
||||
{...BASE_PROPS}
|
||||
onLoadMore={undefined}
|
||||
forwardPaginated={true}
|
||||
/>,
|
||||
);
|
||||
expect(
|
||||
screen.queryByRole("button", { name: /load older messages/i }),
|
||||
).toBeNull();
|
||||
|
||||
@@ -172,6 +172,36 @@ describe("LoadMoreSentinel", () => {
|
||||
expect(mockScrollEl.scrollTop).toBe(200);
|
||||
});
|
||||
|
||||
it("does NOT adjust scroll when adjustScroll=false (forward pagination)", () => {
|
||||
mockScrollEl.scrollHeight = 100;
|
||||
mockScrollEl.scrollTop = 50;
|
||||
const onLoadMore = vi.fn();
|
||||
const { rerender } = render(
|
||||
<LoadMoreSentinel
|
||||
hasMore={true}
|
||||
isLoading={false}
|
||||
messageCount={5}
|
||||
onLoadMore={onLoadMore}
|
||||
adjustScroll={false}
|
||||
/>,
|
||||
);
|
||||
// Fire observer to capture snapshot.
|
||||
MockIntersectionObserver.lastCallback?.([{ isIntersecting: true }]);
|
||||
// Simulate DOM growing from appended newer messages (forward load-more).
|
||||
mockScrollEl.scrollHeight = 300;
|
||||
rerender(
|
||||
<LoadMoreSentinel
|
||||
hasMore={true}
|
||||
isLoading={false}
|
||||
messageCount={10}
|
||||
onLoadMore={onLoadMore}
|
||||
adjustScroll={false}
|
||||
/>,
|
||||
);
|
||||
// scrollTop should remain unchanged — no jump for forward pagination.
|
||||
expect(mockScrollEl.scrollTop).toBe(50);
|
||||
});
|
||||
|
||||
it("ignores same-frame duplicate triggers until isLoading transitions", () => {
|
||||
const onLoadMore = vi.fn();
|
||||
const { rerender } = render(
|
||||
|
||||
@@ -246,7 +246,7 @@ export function ChatSidebar() {
|
||||
</SidebarHeader>
|
||||
)}
|
||||
{!isCollapsed && (
|
||||
<SidebarHeader className="shrink-0 px-4 pb-3 pt-3 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<SidebarHeader className="shrink-0 px-4 pb-4 pt-4 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
|
||||
@@ -13,10 +13,6 @@ import {
|
||||
getSuggestionThemes,
|
||||
} from "./helpers";
|
||||
import { SuggestionThemes } from "./components/SuggestionThemes/SuggestionThemes";
|
||||
import { PulseChips } from "../PulseChips/PulseChips";
|
||||
import { usePulseChips } from "../PulseChips/usePulseChips";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { EditNameDialog } from "./components/EditNameDialog/EditNameDialog";
|
||||
|
||||
interface Props {
|
||||
inputLayoutId: string;
|
||||
@@ -38,8 +34,6 @@ export function EmptySession({
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
|
||||
const pulseChips = usePulseChips();
|
||||
|
||||
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
|
||||
useGetV2GetSuggestedPrompts({
|
||||
@@ -81,16 +75,11 @@ export function EmptySession({
|
||||
<div className="mx-auto max-w-[52rem]">
|
||||
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
|
||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||
<EditNameDialog currentName={greetingName} />
|
||||
</Text>
|
||||
<Text variant="h3" className="mb-8 !font-normal">
|
||||
Tell me about your work — I'll find what to automate.
|
||||
</Text>
|
||||
|
||||
{isAgentBriefingEnabled && (
|
||||
<PulseChips chips={pulseChips} onChipClick={onSend} />
|
||||
)}
|
||||
|
||||
<div className="mb-6">
|
||||
<motion.div
|
||||
layoutId={inputLayoutId}
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { PencilSimpleIcon } from "@phosphor-icons/react";
|
||||
import { useState } from "react";
|
||||
|
||||
interface Props {
|
||||
currentName: string;
|
||||
}
|
||||
|
||||
export function EditNameDialog({ currentName }: Props) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [name, setName] = useState(currentName);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const { refreshSession } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
|
||||
function handleOpenChange(open: boolean) {
|
||||
if (open) setName(currentName);
|
||||
setIsOpen(open);
|
||||
}
|
||||
|
||||
async function handleSave() {
|
||||
const trimmed = name.trim();
|
||||
if (!trimmed) return;
|
||||
|
||||
setIsSaving(true);
|
||||
try {
|
||||
const res = await fetch("/api/auth/user", {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ full_name: trimmed }),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const body = await res.json();
|
||||
toast({
|
||||
title: "Failed to update name",
|
||||
description: body.error ?? "Unknown error",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const session = await refreshSession();
|
||||
if (session?.error) {
|
||||
toast({
|
||||
title: "Name saved, but session refresh failed",
|
||||
description: session.error,
|
||||
variant: "destructive",
|
||||
});
|
||||
setIsOpen(false);
|
||||
return;
|
||||
}
|
||||
setIsOpen(false);
|
||||
toast({ title: "Name updated" });
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit display name"
|
||||
styling={{ maxWidth: "24rem" }}
|
||||
controlled={{ isOpen, set: handleOpenChange }}
|
||||
>
|
||||
<Dialog.Trigger>
|
||||
<button
|
||||
type="button"
|
||||
className="ml-1 inline-flex items-center text-violet-500 transition-colors hover:text-violet-700"
|
||||
>
|
||||
<PencilSimpleIcon size={16} />
|
||||
</button>
|
||||
</Dialog.Trigger>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col gap-4 px-1">
|
||||
<Input
|
||||
id="display-name"
|
||||
label="Display name"
|
||||
placeholder="Your name"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
e.preventDefault();
|
||||
handleSave();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
variant="primary"
|
||||
onClick={handleSave}
|
||||
disabled={!name.trim() || isSaving}
|
||||
loading={isSaving}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
import {
|
||||
fireEvent,
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { EditNameDialog } from "../EditNameDialog";
|
||||
|
||||
const mockToast = vi.hoisted(() => vi.fn());
|
||||
const mockRefreshSession = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
useToast: () => ({ toast: mockToast }),
|
||||
}));
|
||||
|
||||
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
|
||||
useSupabase: () => ({
|
||||
refreshSession: mockRefreshSession,
|
||||
}),
|
||||
}));
|
||||
|
||||
function mockUpdateNameSuccess() {
|
||||
server.use(
|
||||
http.put("/api/auth/user", () => {
|
||||
return HttpResponse.json({ user: { id: "u1" } });
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
function mockUpdateNameError(message = "Network error") {
|
||||
server.use(
|
||||
http.put("/api/auth/user", () => {
|
||||
return HttpResponse.json({ error: message }, { status: 400 });
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
async function openDialogAndGetInput() {
|
||||
const trigger = screen.getByRole("button");
|
||||
fireEvent.click(trigger);
|
||||
await screen.findAllByLabelText(/display name/i);
|
||||
const inputs =
|
||||
document.querySelectorAll<HTMLInputElement>("input#display-name");
|
||||
return inputs[0];
|
||||
}
|
||||
|
||||
function getSaveButton() {
|
||||
const saves = screen.getAllByRole("button", { name: /save/i });
|
||||
return saves[0] as HTMLButtonElement;
|
||||
}
|
||||
|
||||
describe("EditNameDialog", () => {
|
||||
beforeEach(() => {
|
||||
mockToast.mockReset();
|
||||
mockRefreshSession.mockReset();
|
||||
mockRefreshSession.mockResolvedValue({ user: { id: "u1" } });
|
||||
});
|
||||
|
||||
test("opens dialog with current name prefilled", async () => {
|
||||
mockUpdateNameSuccess();
|
||||
render(<EditNameDialog currentName="Alice" />);
|
||||
|
||||
const input = await openDialogAndGetInput();
|
||||
expect(input.value).toBe("Alice");
|
||||
});
|
||||
|
||||
test("saves name via API route and closes dialog", async () => {
|
||||
mockUpdateNameSuccess();
|
||||
render(<EditNameDialog currentName="Alice" />);
|
||||
|
||||
const input = await openDialogAndGetInput();
|
||||
fireEvent.change(input, { target: { value: "Bob" } });
|
||||
fireEvent.click(getSaveButton());
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRefreshSession).toHaveBeenCalled();
|
||||
});
|
||||
expect(mockToast).toHaveBeenCalledWith({ title: "Name updated" });
|
||||
});
|
||||
|
||||
test("shows error toast when API returns error", async () => {
|
||||
mockUpdateNameError("Network error");
|
||||
render(<EditNameDialog currentName="Alice" />);
|
||||
|
||||
const input = await openDialogAndGetInput();
|
||||
fireEvent.change(input, { target: { value: "Bob" } });
|
||||
fireEvent.click(getSaveButton());
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToast).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
title: "Failed to update name",
|
||||
description: "Network error",
|
||||
variant: "destructive",
|
||||
}),
|
||||
);
|
||||
});
|
||||
expect(mockRefreshSession).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("shows warning toast when refreshSession returns an error", async () => {
|
||||
mockUpdateNameSuccess();
|
||||
mockRefreshSession.mockResolvedValue({ error: "refresh failed" });
|
||||
|
||||
render(<EditNameDialog currentName="Alice" />);
|
||||
|
||||
const input = await openDialogAndGetInput();
|
||||
fireEvent.change(input, { target: { value: "Bob" } });
|
||||
fireEvent.click(getSaveButton());
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToast).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
title: "Name saved, but session refresh failed",
|
||||
description: "refresh failed",
|
||||
variant: "destructive",
|
||||
}),
|
||||
);
|
||||
});
|
||||
expect(mockToast).not.toHaveBeenCalledWith({ title: "Name updated" });
|
||||
});
|
||||
|
||||
test("disables Save button while input is empty", async () => {
|
||||
mockUpdateNameSuccess();
|
||||
render(<EditNameDialog currentName="Alice" />);
|
||||
|
||||
const input = await openDialogAndGetInput();
|
||||
fireEvent.change(input, { target: { value: " " } });
|
||||
|
||||
expect(getSaveButton().disabled).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -1,93 +0,0 @@
|
||||
.glassPanel {
|
||||
position: relative;
|
||||
isolation: isolate;
|
||||
}
|
||||
|
||||
.glassPanel::before {
|
||||
content: "";
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
border-radius: inherit;
|
||||
padding: 1px;
|
||||
background: conic-gradient(
|
||||
from var(--border-angle, 0deg),
|
||||
rgba(129, 120, 228, 0.08),
|
||||
rgba(129, 120, 228, 0.28),
|
||||
rgba(168, 130, 255, 0.18),
|
||||
rgba(129, 120, 228, 0.08),
|
||||
rgba(99, 102, 241, 0.24),
|
||||
rgba(129, 120, 228, 0.08)
|
||||
);
|
||||
-webkit-mask:
|
||||
linear-gradient(#000 0 0) content-box,
|
||||
linear-gradient(#000 0 0);
|
||||
mask:
|
||||
linear-gradient(#000 0 0) content-box,
|
||||
linear-gradient(#000 0 0);
|
||||
-webkit-mask-composite: xor;
|
||||
mask-composite: exclude;
|
||||
animation: rotate-border 6s linear infinite;
|
||||
pointer-events: none;
|
||||
z-index: -1;
|
||||
}
|
||||
|
||||
@property --border-angle {
|
||||
syntax: "<angle>";
|
||||
initial-value: 0deg;
|
||||
inherits: false;
|
||||
}
|
||||
|
||||
@keyframes rotate-border {
|
||||
to {
|
||||
--border-angle: 360deg;
|
||||
}
|
||||
}
|
||||
|
||||
.chip {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
@media (hover: hover) {
|
||||
.chip {
|
||||
padding-bottom: 0.9rem;
|
||||
}
|
||||
}
|
||||
|
||||
@media (hover: none) {
|
||||
.chip {
|
||||
padding-bottom: 2.25rem;
|
||||
}
|
||||
}
|
||||
|
||||
.chipActions {
|
||||
position: absolute;
|
||||
inset-inline: 0;
|
||||
bottom: 0;
|
||||
background: rgba(255, 255, 255, 0.95);
|
||||
backdrop-filter: blur(4px);
|
||||
-webkit-backdrop-filter: blur(4px);
|
||||
}
|
||||
|
||||
@media (hover: hover) {
|
||||
.chipActions {
|
||||
opacity: 0;
|
||||
transform: translateY(100%);
|
||||
transition:
|
||||
opacity 0.2s ease-out,
|
||||
transform 0.2s ease-out;
|
||||
}
|
||||
|
||||
.chip:hover .chipActions {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.chipContent {
|
||||
transition: filter 0.2s ease-out;
|
||||
}
|
||||
|
||||
.chip:hover .chipContent {
|
||||
filter: blur(2px);
|
||||
opacity: 0.5;
|
||||
}
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
ArrowRightIcon,
|
||||
EyeIcon,
|
||||
ChatCircleDotsIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import NextLink from "next/link";
|
||||
import { StatusBadge } from "@/app/(platform)/library/components/StatusBadge/StatusBadge";
|
||||
import styles from "./PulseChips.module.css";
|
||||
import type { PulseChipData } from "./types";
|
||||
|
||||
interface Props {
|
||||
chips: PulseChipData[];
|
||||
onChipClick?: (prompt: string) => void;
|
||||
}
|
||||
|
||||
export function PulseChips({ chips, onChipClick }: Props) {
|
||||
if (chips.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${styles.glassPanel} mx-[0.6875rem] mb-5 rounded-large p-5`}
|
||||
>
|
||||
<div className="mb-3 flex items-center gap-3">
|
||||
<Text variant="body-medium" className="text-zinc-600">
|
||||
What's happening with your agents
|
||||
</Text>
|
||||
<NextLink
|
||||
href="/library"
|
||||
className="flex items-center gap-1 text-xs text-zinc-500 hover:text-zinc-700"
|
||||
>
|
||||
View all <ArrowRightIcon size={12} />
|
||||
</NextLink>
|
||||
</div>
|
||||
<div className="flex gap-2 overflow-x-auto pb-1 scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300">
|
||||
{chips.map((chip) => (
|
||||
<PulseChip key={chip.id} chip={chip} onAsk={onChipClick} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface ChipProps {
|
||||
chip: PulseChipData;
|
||||
onAsk?: (prompt: string) => void;
|
||||
}
|
||||
|
||||
function PulseChip({ chip, onAsk }: ChipProps) {
|
||||
function handleAsk() {
|
||||
const prompt = buildChipPrompt(chip);
|
||||
onAsk?.(prompt);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${styles.chip} relative flex w-[15rem] shrink-0 flex-col items-start gap-2 rounded-medium border border-zinc-100 bg-white px-3 py-2`}
|
||||
>
|
||||
<div className={`${styles.chipContent} w-full text-left`}>
|
||||
{chip.priority === "success" ? (
|
||||
<span className="inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium text-emerald-600">
|
||||
<span className="h-1.5 w-1.5 rounded-full bg-emerald-500" />
|
||||
Completed
|
||||
</span>
|
||||
) : (
|
||||
<StatusBadge status={chip.status} />
|
||||
)}
|
||||
<div className="mt-2 min-w-0">
|
||||
<Text variant="small-medium" className="truncate text-zinc-900">
|
||||
{chip.name}
|
||||
</Text>
|
||||
<Text variant="small" className="truncate text-zinc-500">
|
||||
{chip.shortMessage}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className={`${styles.chipActions} flex items-center justify-center gap-1.5 rounded-b-medium px-3 py-1.5`}
|
||||
>
|
||||
<NextLink
|
||||
href={`/library/agents/${chip.agentID}`}
|
||||
className="flex items-center gap-1 rounded-md px-2 py-1 text-xs text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
|
||||
>
|
||||
<EyeIcon size={14} />
|
||||
See
|
||||
</NextLink>
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleAsk}
|
||||
className="flex items-center gap-1 rounded-md px-2 py-1 text-xs text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
|
||||
>
|
||||
<ChatCircleDotsIcon size={14} />
|
||||
Ask
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function buildChipPrompt(chip: PulseChipData): string {
|
||||
if (chip.priority === "success") {
|
||||
return `${chip.name} just finished a run — can you summarize what it did?`;
|
||||
}
|
||||
switch (chip.status) {
|
||||
case "error":
|
||||
return `What happened with ${chip.name}? It has an error — can you check?`;
|
||||
case "running":
|
||||
return `Give me a status update on ${chip.name} — what has it done so far?`;
|
||||
case "idle":
|
||||
return `${chip.name} hasn't run recently. Should I keep it or update and re-run it?`;
|
||||
default:
|
||||
return `Tell me about ${chip.name} — what's its current status?`;
|
||||
}
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
import { describe, expect, test, vi } from "vitest";
|
||||
import { render, screen, fireEvent } from "@/tests/integrations/test-utils";
|
||||
import { PulseChips } from "../PulseChips";
|
||||
import type { PulseChipData } from "../types";
|
||||
|
||||
function makeChip(overrides: Partial<PulseChipData> = {}): PulseChipData {
|
||||
return {
|
||||
id: "chip-1",
|
||||
agentID: "agent-1",
|
||||
name: "Test Agent",
|
||||
status: "running",
|
||||
priority: "running",
|
||||
shortMessage: "Doing work…",
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("PulseChips", () => {
|
||||
test("renders nothing when chips array is empty", () => {
|
||||
const { container } = render(<PulseChips chips={[]} />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
test("renders chip names and messages", () => {
|
||||
const chips = [
|
||||
makeChip({ id: "1", name: "Alpha Bot", shortMessage: "Running task A" }),
|
||||
makeChip({ id: "2", name: "Beta Bot", shortMessage: "Running task B" }),
|
||||
];
|
||||
|
||||
render(<PulseChips chips={chips} />);
|
||||
|
||||
expect(screen.getByText("Alpha Bot")).toBeDefined();
|
||||
expect(screen.getByText("Running task A")).toBeDefined();
|
||||
expect(screen.getByText("Beta Bot")).toBeDefined();
|
||||
expect(screen.getByText("Running task B")).toBeDefined();
|
||||
});
|
||||
|
||||
test("renders section heading and View all link", () => {
|
||||
render(<PulseChips chips={[makeChip()]} />);
|
||||
|
||||
expect(screen.getByText("What's happening with your agents")).toBeDefined();
|
||||
expect(screen.getByText("View all")).toBeDefined();
|
||||
});
|
||||
|
||||
test("shows Completed badge for success priority chips", () => {
|
||||
render(
|
||||
<PulseChips
|
||||
chips={[makeChip({ priority: "success", status: "idle" })]}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByText("Completed")).toBeDefined();
|
||||
});
|
||||
|
||||
test("calls onChipClick with generated prompt when Ask is clicked", () => {
|
||||
const onChipClick = vi.fn();
|
||||
render(
|
||||
<PulseChips
|
||||
chips={[
|
||||
makeChip({
|
||||
name: "Error Agent",
|
||||
status: "error",
|
||||
priority: "error",
|
||||
}),
|
||||
]}
|
||||
onChipClick={onChipClick}
|
||||
/>,
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByText("Ask"));
|
||||
|
||||
expect(onChipClick).toHaveBeenCalledWith(
|
||||
"What happened with Error Agent? It has an error — can you check?",
|
||||
);
|
||||
});
|
||||
|
||||
test("generates success prompt for completed chips", () => {
|
||||
const onChipClick = vi.fn();
|
||||
render(
|
||||
<PulseChips
|
||||
chips={[
|
||||
makeChip({
|
||||
name: "Done Agent",
|
||||
priority: "success",
|
||||
status: "idle",
|
||||
}),
|
||||
]}
|
||||
onChipClick={onChipClick}
|
||||
/>,
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByText("Ask"));
|
||||
|
||||
expect(onChipClick).toHaveBeenCalledWith(
|
||||
"Done Agent just finished a run — can you summarize what it did?",
|
||||
);
|
||||
});
|
||||
|
||||
test("renders See link pointing to agent detail page", () => {
|
||||
render(<PulseChips chips={[makeChip({ agentID: "agent-xyz" })]} />);
|
||||
|
||||
const seeLink = screen.getByText("See").closest("a");
|
||||
expect(seeLink?.getAttribute("href")).toBe("/library/agents/agent-xyz");
|
||||
});
|
||||
});
|
||||
@@ -1,13 +0,0 @@
|
||||
import type {
|
||||
AgentStatus,
|
||||
SitrepPriority,
|
||||
} from "@/app/(platform)/library/types";
|
||||
|
||||
export interface PulseChipData {
|
||||
id: string;
|
||||
agentID: string;
|
||||
name: string;
|
||||
status: AgentStatus;
|
||||
priority: SitrepPriority;
|
||||
shortMessage: string;
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
|
||||
import { useSitrepItems } from "@/app/(platform)/library/components/SitrepItem/useSitrepItems";
|
||||
import type { PulseChipData } from "./types";
|
||||
import { useMemo } from "react";
|
||||
|
||||
const THREE_DAYS_MS = 3 * 24 * 60 * 60 * 1000;
|
||||
|
||||
export function usePulseChips(): PulseChipData[] {
|
||||
const { agents } = useLibraryAgents();
|
||||
|
||||
const sitrepItems = useSitrepItems(agents, 5, THREE_DAYS_MS);
|
||||
|
||||
return useMemo(() => {
|
||||
return sitrepItems.map((item) => ({
|
||||
id: item.id,
|
||||
agentID: item.agentID,
|
||||
name: item.agentName,
|
||||
status: item.status,
|
||||
priority: item.priority,
|
||||
shortMessage: item.message,
|
||||
}));
|
||||
}, [sitrepItems]);
|
||||
}
|
||||
@@ -6,9 +6,6 @@ import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
|
||||
import { formatCents } from "../usageHelpers";
|
||||
|
||||
export { formatCents };
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
@@ -21,6 +18,10 @@ interface Props {
|
||||
onCreditChange?: () => void;
|
||||
}
|
||||
|
||||
export function formatCents(cents: number): string {
|
||||
return `$${(cents / 100).toFixed(2)}`;
|
||||
}
|
||||
|
||||
export function RateLimitResetDialog({
|
||||
isOpen,
|
||||
onClose,
|
||||
|
||||
@@ -1,10 +1,35 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import Link from "next/link";
|
||||
import { formatCents, formatResetTime } from "../usageHelpers";
|
||||
import { formatCents } from "../RateLimitResetDialog/RateLimitResetDialog";
|
||||
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
|
||||
|
||||
export { formatResetTime };
|
||||
export function formatResetTime(
|
||||
resetsAt: Date | string,
|
||||
now: Date = new Date(),
|
||||
): string {
|
||||
const resetDate =
|
||||
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
|
||||
const diffMs = resetDate.getTime() - now.getTime();
|
||||
if (diffMs <= 0) return "now";
|
||||
|
||||
const hours = Math.floor(diffMs / (1000 * 60 * 60));
|
||||
|
||||
// Under 24h: show relative time ("in 4h 23m")
|
||||
if (hours < 24) {
|
||||
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
|
||||
if (hours > 0) return `in ${hours}h ${minutes}m`;
|
||||
return `in ${minutes}m`;
|
||||
}
|
||||
|
||||
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
|
||||
return resetDate.toLocaleString(undefined, {
|
||||
weekday: "short",
|
||||
hour: "numeric",
|
||||
minute: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
}
|
||||
|
||||
function UsageBar({
|
||||
label,
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
export function formatCents(cents: number): string {
|
||||
return `$${(cents / 100).toFixed(2)}`;
|
||||
}
|
||||
|
||||
export function formatResetTime(
|
||||
resetsAt: Date | string,
|
||||
now: Date = new Date(),
|
||||
): string {
|
||||
const resetDate =
|
||||
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
|
||||
const diffMs = resetDate.getTime() - now.getTime();
|
||||
if (diffMs <= 0) return "now";
|
||||
|
||||
const hours = Math.floor(diffMs / (1000 * 60 * 60));
|
||||
|
||||
if (hours < 24) {
|
||||
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
|
||||
if (hours > 0) return `in ${hours}h ${minutes}m`;
|
||||
return `in ${minutes}m`;
|
||||
}
|
||||
|
||||
return resetDate.toLocaleString(undefined, {
|
||||
weekday: "short",
|
||||
hour: "numeric",
|
||||
minute: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
}
|
||||
@@ -86,6 +86,16 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
|
||||
return sessionQuery.data.data.oldest_sequence ?? null;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
const newestSequence = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200) return null;
|
||||
return sessionQuery.data.data.newest_sequence ?? null;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
const forwardPaginated = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200) return false;
|
||||
return !!sessionQuery.data.data.forward_paginated;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
|
||||
// array reference every render. Re-derives only when query data changes.
|
||||
// When the session is complete (no active stream), mark dangling tool
|
||||
@@ -185,6 +195,8 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
|
||||
hasActiveStream,
|
||||
hasMoreMessages,
|
||||
oldestSequence,
|
||||
newestSequence,
|
||||
forwardPaginated,
|
||||
isLoadingSession: sessionQuery.isLoading,
|
||||
isSessionError: sessionQuery.isError,
|
||||
createSession,
|
||||
|
||||
@@ -56,6 +56,8 @@ export function useCopilotPage() {
|
||||
hasActiveStream,
|
||||
hasMoreMessages,
|
||||
oldestSequence,
|
||||
newestSequence,
|
||||
forwardPaginated,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
createSession,
|
||||
@@ -84,19 +86,26 @@ export function useCopilotPage() {
|
||||
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
|
||||
});
|
||||
|
||||
const { pagedMessages, hasMore, isLoadingMore, loadMore } =
|
||||
const { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged } =
|
||||
useLoadMoreMessages({
|
||||
sessionId,
|
||||
initialOldestSequence: oldestSequence,
|
||||
initialNewestSequence: newestSequence,
|
||||
initialHasMore: hasMoreMessages,
|
||||
forwardPaginated,
|
||||
initialPageRawMessages: rawSessionMessages,
|
||||
});
|
||||
|
||||
// Combine paginated messages with current page messages, merging consecutive
|
||||
// assistant UIMessages at the page boundary so reasoning + response parts
|
||||
// stay in a single bubble. Paged messages are older history prepended before
|
||||
// the current page.
|
||||
const messages = concatWithAssistantMerge(pagedMessages, currentMessages);
|
||||
// stay in a single bubble.
|
||||
// Forward pagination (completed sessions): current page is the beginning,
|
||||
// paged messages are newer pages appended after.
|
||||
// Backward pagination (active sessions): paged messages are older history
|
||||
// prepended before the current page.
|
||||
const messages = forwardPaginated
|
||||
? concatWithAssistantMerge(currentMessages, pagedMessages)
|
||||
: concatWithAssistantMerge(pagedMessages, currentMessages);
|
||||
|
||||
useCopilotNotifications(sessionId);
|
||||
|
||||
@@ -252,6 +261,15 @@ export function useCopilotPage() {
|
||||
isUserStoppingRef.current = false;
|
||||
|
||||
if (sessionId) {
|
||||
// When continuing a completed session that had forward-paginated history
|
||||
// loaded, the paged messages would appear in wrong position relative to
|
||||
// the new streaming turn (pagedMessages are newer pages, so they'd end
|
||||
// up after the streaming turn). Reset paged state so ordering is correct
|
||||
// during streaming; the user can reload history afterward if needed.
|
||||
if (forwardPaginated && pagedMessages.length > 0) {
|
||||
resetPaged();
|
||||
}
|
||||
|
||||
if (files && files.length > 0) {
|
||||
setIsUploadingFiles(true);
|
||||
try {
|
||||
@@ -398,6 +416,7 @@ export function useCopilotPage() {
|
||||
hasMoreMessages: hasMore,
|
||||
isLoadingMore,
|
||||
loadMore,
|
||||
forwardPaginated,
|
||||
// Mobile drawer
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
|
||||
@@ -9,7 +9,11 @@ import {
|
||||
interface UseLoadMoreMessagesArgs {
|
||||
sessionId: string | null;
|
||||
initialOldestSequence: number | null;
|
||||
initialNewestSequence: number | null;
|
||||
initialHasMore: boolean;
|
||||
/** True when the initial page was loaded from sequence 0 forward (completed
|
||||
* sessions). False when loaded newest-first (active sessions). */
|
||||
forwardPaginated: boolean;
|
||||
/** Raw messages from the initial page, used for cross-page tool output matching. */
|
||||
initialPageRawMessages: unknown[];
|
||||
}
|
||||
@@ -20,7 +24,9 @@ const MAX_OLDER_MESSAGES = 2000;
|
||||
export function useLoadMoreMessages({
|
||||
sessionId,
|
||||
initialOldestSequence,
|
||||
initialNewestSequence,
|
||||
initialHasMore,
|
||||
forwardPaginated,
|
||||
initialPageRawMessages,
|
||||
}: UseLoadMoreMessagesArgs) {
|
||||
// Accumulated raw messages from all extra pages (ascending order).
|
||||
@@ -30,6 +36,9 @@ export function useLoadMoreMessages({
|
||||
const [oldestSequence, setOldestSequence] = useState<number | null>(
|
||||
initialOldestSequence,
|
||||
);
|
||||
const [newestSequence, setNewestSequence] = useState<number | null>(
|
||||
initialNewestSequence,
|
||||
);
|
||||
const [hasMore, setHasMore] = useState(initialHasMore);
|
||||
const [isLoadingMore, setIsLoadingMore] = useState(false);
|
||||
const isLoadingMoreRef = useRef(false);
|
||||
@@ -37,7 +46,9 @@ export function useLoadMoreMessages({
|
||||
// Epoch counter to discard stale loadMore responses after a reset
|
||||
const epochRef = useRef(0);
|
||||
|
||||
// Track the sessionId and initial cursor to reset state on change
|
||||
const prevSessionIdRef = useRef(sessionId);
|
||||
const prevInitialOldestRef = useRef(initialOldestSequence);
|
||||
|
||||
// Sync initial values from parent when they change.
|
||||
//
|
||||
@@ -60,8 +71,10 @@ export function useLoadMoreMessages({
|
||||
if (prevSessionIdRef.current !== sessionId) {
|
||||
// Session changed — full reset
|
||||
prevSessionIdRef.current = sessionId;
|
||||
prevInitialOldestRef.current = initialOldestSequence;
|
||||
setPagedRawMessages([]);
|
||||
setOldestSequence(initialOldestSequence);
|
||||
setNewestSequence(initialNewestSequence);
|
||||
setHasMore(initialHasMore);
|
||||
setIsLoadingMore(false);
|
||||
isLoadingMoreRef.current = false;
|
||||
@@ -70,23 +83,34 @@ export function useLoadMoreMessages({
|
||||
return;
|
||||
}
|
||||
|
||||
prevInitialOldestRef.current = initialOldestSequence;
|
||||
|
||||
// If we haven't paged yet, mirror the parent so the first
|
||||
// `loadMore` starts from the correct cursor.
|
||||
if (pagedRawMessages.length === 0) {
|
||||
setOldestSequence(initialOldestSequence);
|
||||
// Only regress the forward cursor if we haven't paged ahead yet —
|
||||
// otherwise a parent refetch would reset a cursor we already advanced.
|
||||
setNewestSequence((prev) =>
|
||||
prev !== null && prev > (initialNewestSequence ?? -1)
|
||||
? prev
|
||||
: initialNewestSequence,
|
||||
);
|
||||
setHasMore(initialHasMore);
|
||||
}
|
||||
}, [sessionId, initialOldestSequence, initialHasMore]);
|
||||
}, [sessionId, initialOldestSequence, initialNewestSequence, initialHasMore]);
|
||||
|
||||
// Convert all accumulated raw messages in one pass so tool outputs
|
||||
// are matched across inter-page boundaries.
|
||||
// Include initial page tool outputs so older paged pages can match
|
||||
// tool calls whose outputs landed in the initial page.
|
||||
// For backward pagination only: include initial page tool outputs so older
|
||||
// paged pages can match tool calls whose outputs landed in the initial page.
|
||||
// For forward pagination this is unnecessary — tool calls in newer paged
|
||||
// pages cannot have their outputs in the older initial page.
|
||||
const pagedMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
|
||||
useMemo(() => {
|
||||
if (!sessionId || pagedRawMessages.length === 0) return [];
|
||||
const extraToolOutputs =
|
||||
initialPageRawMessages.length > 0
|
||||
!forwardPaginated && initialPageRawMessages.length > 0
|
||||
? extractToolOutputsFromRaw(initialPageRawMessages)
|
||||
: undefined;
|
||||
return convertChatSessionMessagesToUiMessages(
|
||||
@@ -94,20 +118,22 @@ export function useLoadMoreMessages({
|
||||
pagedRawMessages,
|
||||
{ isComplete: true, extraToolOutputs },
|
||||
).messages;
|
||||
}, [sessionId, pagedRawMessages, initialPageRawMessages]);
|
||||
}, [sessionId, pagedRawMessages, initialPageRawMessages, forwardPaginated]);
|
||||
|
||||
async function loadMore() {
|
||||
if (!sessionId || !hasMore || isLoadingMoreRef.current) return;
|
||||
if (oldestSequence === null) return;
|
||||
|
||||
const cursor = forwardPaginated ? newestSequence : oldestSequence;
|
||||
if (cursor === null) return;
|
||||
|
||||
const requestEpoch = epochRef.current;
|
||||
isLoadingMoreRef.current = true;
|
||||
setIsLoadingMore(true);
|
||||
try {
|
||||
const response = await getV2GetSession(sessionId, {
|
||||
limit: 50,
|
||||
before_sequence: oldestSequence,
|
||||
});
|
||||
const params = forwardPaginated
|
||||
? { limit: 50, after_sequence: cursor }
|
||||
: { limit: 50, before_sequence: cursor };
|
||||
const response = await getV2GetSession(sessionId, params);
|
||||
|
||||
// Discard response if session/pagination was reset while awaiting
|
||||
if (epochRef.current !== requestEpoch) return;
|
||||
@@ -126,20 +152,40 @@ export function useLoadMoreMessages({
|
||||
consecutiveErrorsRef.current = 0;
|
||||
|
||||
const newRaw = (response.data.messages ?? []) as unknown[];
|
||||
const estimatedTotal = pagedRawMessages.length + newRaw.length;
|
||||
setPagedRawMessages((prev) => {
|
||||
const merged = [...newRaw, ...prev];
|
||||
// Forward: append to end. Backward: prepend to start.
|
||||
const merged = forwardPaginated
|
||||
? [...prev, ...newRaw]
|
||||
: [...newRaw, ...prev];
|
||||
if (merged.length > MAX_OLDER_MESSAGES) {
|
||||
return merged.slice(merged.length - MAX_OLDER_MESSAGES);
|
||||
// Backward: discard the oldest (front) items — user has scrolled far
|
||||
// back and we shed the furthest history.
|
||||
// Forward: discard the newest (tail) items — we only ever fetch
|
||||
// forward, so the tail is the most recently appended page; shedding
|
||||
// it means the sentinel stalls, which is safer than discarding the
|
||||
// beginning of the conversation the user is here to read.
|
||||
return forwardPaginated
|
||||
? merged.slice(0, MAX_OLDER_MESSAGES)
|
||||
: merged.slice(merged.length - MAX_OLDER_MESSAGES);
|
||||
}
|
||||
return merged;
|
||||
});
|
||||
|
||||
// Note: after truncation, oldest_sequence may reference a dropped
|
||||
// message. This is safe because we also set hasMore=false below,
|
||||
// preventing further loads with the stale cursor.
|
||||
setOldestSequence(response.data.oldest_sequence ?? null);
|
||||
if (estimatedTotal >= MAX_OLDER_MESSAGES) {
|
||||
if (forwardPaginated) {
|
||||
setNewestSequence(response.data.newest_sequence ?? null);
|
||||
} else {
|
||||
setOldestSequence(response.data.oldest_sequence ?? null);
|
||||
}
|
||||
|
||||
const totalAfterMerge = newRaw.length + pagedRawMessages.length;
|
||||
if (forwardPaginated) {
|
||||
// Forward: truncation sheds the newest tail but the cursor
|
||||
// (newestSequence) still advances, so the sentinel can keep
|
||||
// fetching. Only stop when the server reports no more messages.
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
} else if (totalAfterMerge >= MAX_OLDER_MESSAGES) {
|
||||
// Backward: we've accumulated MAX_OLDER_MESSAGES of history —
|
||||
// stop to avoid unbounded memory growth.
|
||||
setHasMore(false);
|
||||
} else {
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
@@ -159,5 +205,22 @@ export function useLoadMoreMessages({
|
||||
}
|
||||
}
|
||||
|
||||
return { pagedMessages, hasMore, isLoadingMore, loadMore };
|
||||
function resetPaged() {
|
||||
setPagedRawMessages([]);
|
||||
setOldestSequence(initialOldestSequence);
|
||||
setNewestSequence(initialNewestSequence);
|
||||
// Set hasMore=false during the session-transition window so no loadMore
|
||||
// fires with forward pagination (after_sequence) on the now-active session.
|
||||
// The useEffect will restore hasMore from the parent after the refetch
|
||||
// completes and forwardPaginated switches to false.
|
||||
setHasMore(false);
|
||||
// Clear the loading state so the spinner doesn't stay stuck if a loadMore
|
||||
// was in flight when resetPaged was called.
|
||||
setIsLoadingMore(false);
|
||||
isLoadingMoreRef.current = false;
|
||||
consecutiveErrorsRef.current = 0;
|
||||
epochRef.current += 1;
|
||||
}
|
||||
|
||||
return { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged };
|
||||
}
|
||||
|
||||
@@ -2,17 +2,14 @@ import { Navbar } from "@/components/layout/Navbar/Navbar";
|
||||
import { NetworkStatusMonitor } from "@/services/network-status/NetworkStatusMonitor";
|
||||
import { ReactNode } from "react";
|
||||
import { AdminImpersonationBanner } from "./admin/components/AdminImpersonationBanner";
|
||||
import { AutoPilotBridgeProvider } from "@/contexts/AutoPilotBridgeContext";
|
||||
|
||||
export default function PlatformLayout({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<AutoPilotBridgeProvider>
|
||||
<main className="flex h-screen w-full flex-col">
|
||||
<NetworkStatusMonitor />
|
||||
<Navbar />
|
||||
<AdminImpersonationBanner />
|
||||
<section className="flex-1">{children}</section>
|
||||
</main>
|
||||
</AutoPilotBridgeProvider>
|
||||
<main className="flex h-screen w-full flex-col">
|
||||
<NetworkStatusMonitor />
|
||||
<Navbar />
|
||||
<AdminImpersonationBanner />
|
||||
<section className="flex-1">{children}</section>
|
||||
</main>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -137,10 +137,8 @@ describe("LibraryPage", () => {
|
||||
user_id: "test-user",
|
||||
name: "Work Agents",
|
||||
agent_count: 3,
|
||||
subfolder_count: 0,
|
||||
color: null,
|
||||
icon: null,
|
||||
parent_id: null,
|
||||
created_at: new Date(),
|
||||
updated_at: new Date(),
|
||||
},
|
||||
@@ -149,10 +147,8 @@ describe("LibraryPage", () => {
|
||||
user_id: "test-user",
|
||||
name: "Personal",
|
||||
agent_count: 1,
|
||||
subfolder_count: 0,
|
||||
color: null,
|
||||
icon: null,
|
||||
parent_id: null,
|
||||
created_at: new Date(),
|
||||
updated_at: new Date(),
|
||||
},
|
||||
@@ -162,14 +158,12 @@ describe("LibraryPage", () => {
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
await waitForAgentsToLoad();
|
||||
|
||||
expect(await screen.findByText("Work Agents")).toBeDefined();
|
||||
expect(screen.getByText("Personal")).toBeDefined();
|
||||
expect(screen.getAllByTestId("library-folder")).toHaveLength(2);
|
||||
});
|
||||
|
||||
test("shows See tasks link on agent card", async () => {
|
||||
test("shows See runs link on agent card", async () => {
|
||||
setupHandlers({
|
||||
agents: [makeAgent({ name: "Linked Agent", can_access_graph: true })],
|
||||
});
|
||||
@@ -178,7 +172,7 @@ describe("LibraryPage", () => {
|
||||
|
||||
await screen.findByText("Linked Agent");
|
||||
|
||||
const runLinks = screen.getAllByText("See tasks");
|
||||
const runLinks = screen.getAllByText("See runs");
|
||||
expect(runLinks.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
@@ -196,7 +190,7 @@ describe("LibraryPage", () => {
|
||||
expect(importButtons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("renders running agent card when execution is active", async () => {
|
||||
test("renders Jump Back In when there is an active execution", async () => {
|
||||
const agent = makeAgent({
|
||||
id: "lib-1",
|
||||
graph_id: "g-1",
|
||||
@@ -224,6 +218,6 @@ describe("LibraryPage", () => {
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText("Running Agent")).toBeDefined();
|
||||
expect(await screen.findByText("Jump Back In")).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
.glassPanel {
|
||||
position: relative;
|
||||
isolation: isolate;
|
||||
}
|
||||
|
||||
.glassPanel::before {
|
||||
content: "";
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
border-radius: inherit;
|
||||
padding: 1px;
|
||||
background: conic-gradient(
|
||||
from var(--border-angle, 0deg),
|
||||
rgba(129, 120, 228, 0.04),
|
||||
rgba(129, 120, 228, 0.14),
|
||||
rgba(168, 130, 255, 0.09),
|
||||
rgba(129, 120, 228, 0.04),
|
||||
rgba(99, 102, 241, 0.12),
|
||||
rgba(129, 120, 228, 0.04)
|
||||
);
|
||||
-webkit-mask:
|
||||
linear-gradient(#000 0 0) content-box,
|
||||
linear-gradient(#000 0 0);
|
||||
mask:
|
||||
linear-gradient(#000 0 0) content-box,
|
||||
linear-gradient(#000 0 0);
|
||||
-webkit-mask-composite: xor;
|
||||
mask-composite: exclude;
|
||||
animation: rotate-border 6s linear infinite;
|
||||
pointer-events: none;
|
||||
z-index: -1;
|
||||
}
|
||||
|
||||
@property --border-angle {
|
||||
syntax: "<angle>";
|
||||
initial-value: 0deg;
|
||||
inherits: false;
|
||||
}
|
||||
|
||||
@keyframes rotate-border {
|
||||
to {
|
||||
--border-angle: 360deg;
|
||||
}
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { useState } from "react";
|
||||
import type { FleetSummary, AgentStatusFilter } from "../../types";
|
||||
import { BriefingTabContent } from "./BriefingTabContent";
|
||||
import { StatsGrid } from "./StatsGrid";
|
||||
import styles from "./AgentBriefingPanel.module.css";
|
||||
|
||||
interface Props {
|
||||
summary: FleetSummary;
|
||||
agents: LibraryAgent[];
|
||||
}
|
||||
|
||||
export function AgentBriefingPanel({ summary, agents }: Props) {
|
||||
const [userTab, setUserTab] = useState<AgentStatusFilter | null>(null);
|
||||
const activeTab: AgentStatusFilter =
|
||||
userTab ?? (summary.running > 0 ? "running" : "all");
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${styles.glassPanel} min-h-[14.75rem] rounded-large bg-gradient-to-br from-indigo-50/30 via-white/90 to-purple-50/25 px-5 pb-5 pt-[1.125rem] shadow-sm backdrop-blur-md`}
|
||||
>
|
||||
<Text variant="h5">Agent Briefing</Text>
|
||||
<div className="mt-4 space-y-5">
|
||||
<StatsGrid
|
||||
summary={summary}
|
||||
activeTab={activeTab}
|
||||
onTabChange={setUserTab}
|
||||
/>
|
||||
<BriefingTabContent activeTab={activeTab} agents={agents} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,361 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import {
|
||||
formatResetTime,
|
||||
formatCents,
|
||||
} from "@/app/(platform)/copilot/components/usageHelpers";
|
||||
import { useResetRateLimit } from "@/app/(platform)/copilot/hooks/useResetRateLimit";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import useCredits from "@/hooks/useCredits";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useSitrepItems } from "../SitrepItem/useSitrepItems";
|
||||
import { SitrepItem } from "../SitrepItem/SitrepItem";
|
||||
import { useAgentStatusMap } from "../../hooks/useAgentStatus";
|
||||
import type { AgentStatusFilter } from "../../types";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
|
||||
interface Props {
|
||||
activeTab: AgentStatusFilter;
|
||||
agents: LibraryAgent[];
|
||||
}
|
||||
|
||||
export function BriefingTabContent({ activeTab, agents }: Props) {
|
||||
if (activeTab === "all") {
|
||||
return <UsageSection />;
|
||||
}
|
||||
|
||||
if (
|
||||
activeTab === "running" ||
|
||||
activeTab === "attention" ||
|
||||
activeTab === "completed"
|
||||
) {
|
||||
return <ExecutionListSection activeTab={activeTab} agents={agents} />;
|
||||
}
|
||||
|
||||
return <AgentListSection activeTab={activeTab} agents={agents} />;
|
||||
}
|
||||
|
||||
function UsageSection() {
|
||||
const { data: usage } = useGetV2GetCopilotUsage({
|
||||
query: {
|
||||
select: (res) => res.data as CoPilotUsageStatus,
|
||||
refetchInterval: 30000,
|
||||
staleTime: 10000,
|
||||
},
|
||||
});
|
||||
|
||||
const isBillingEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
|
||||
const { credits, fetchCredits } = useCredits({ fetchInitialCredits: true });
|
||||
const resetCost = usage?.reset_cost;
|
||||
const hasInsufficientCredits =
|
||||
credits !== null && resetCost != null && credits < resetCost;
|
||||
|
||||
if (!usage?.daily || !usage?.weekly) return null;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<Text variant="h5" className="text-neutral-800">
|
||||
Usage limits
|
||||
</Text>
|
||||
{usage.tier && (
|
||||
<Badge variant="info" size="small" className="bg-[rgb(224,237,255)]">
|
||||
{usage.tier.charAt(0) + usage.tier.slice(1).toLowerCase()} plan
|
||||
</Badge>
|
||||
)}
|
||||
<div className="flex-1" />
|
||||
{isBillingEnabled && (
|
||||
<Link
|
||||
href="/profile/credits"
|
||||
className="text-sm text-blue-600 hover:underline"
|
||||
>
|
||||
Manage billing
|
||||
</Link>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-4 grid grid-cols-1 gap-6 sm:grid-cols-2">
|
||||
{usage.daily.limit > 0 && (
|
||||
<UsageMeter
|
||||
label="Today"
|
||||
used={usage.daily.used}
|
||||
limit={usage.daily.limit}
|
||||
resetsAt={usage.daily.resets_at}
|
||||
/>
|
||||
)}
|
||||
{usage.weekly.limit > 0 && (
|
||||
<UsageMeter
|
||||
label="This week"
|
||||
used={usage.weekly.used}
|
||||
limit={usage.weekly.limit}
|
||||
resetsAt={usage.weekly.resets_at}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<UsageFooter
|
||||
usage={usage}
|
||||
hasInsufficientCredits={hasInsufficientCredits}
|
||||
onCreditChange={fetchCredits}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const MAX_VISIBLE = 6;
|
||||
|
||||
function ExecutionListSection({
|
||||
activeTab,
|
||||
agents,
|
||||
}: {
|
||||
activeTab: AgentStatusFilter;
|
||||
agents: LibraryAgent[];
|
||||
}) {
|
||||
const allItems = useSitrepItems(agents, 50);
|
||||
const [showAll, setShowAll] = useState(false);
|
||||
|
||||
const filtered = allItems.filter((item) => {
|
||||
if (activeTab === "running") return item.priority === "running";
|
||||
if (activeTab === "attention") return item.priority === "error";
|
||||
if (activeTab === "completed") return item.priority === "success";
|
||||
return false;
|
||||
});
|
||||
|
||||
if (filtered.length === 0) {
|
||||
return <EmptyMessage tab={activeTab} />;
|
||||
}
|
||||
|
||||
const visible = showAll ? filtered : filtered.slice(0, MAX_VISIBLE);
|
||||
const hasMore = filtered.length > MAX_VISIBLE;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
|
||||
{visible.map((item) => (
|
||||
<SitrepItem key={item.id} item={item} />
|
||||
))}
|
||||
</div>
|
||||
{hasMore && (
|
||||
<div className="mt-3 flex justify-center">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => setShowAll(!showAll)}
|
||||
>
|
||||
{showAll ? "Collapse" : `Show all (${filtered.length})`}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const TAB_STATUS_LABEL: Record<string, string> = {
|
||||
listening: "Waiting for trigger event",
|
||||
scheduled: "Has a scheduled run",
|
||||
idle: "No recent activity",
|
||||
};
|
||||
|
||||
function getAgentStatusLabel(tab: string, agent: LibraryAgent): string {
|
||||
if (tab === "scheduled" && agent.next_scheduled_run) {
|
||||
const diff = new Date(agent.next_scheduled_run).getTime() - Date.now();
|
||||
const minutes = Math.round(diff / 60_000);
|
||||
if (minutes <= 0) return "Scheduled to run soon";
|
||||
if (minutes < 60) return `Scheduled to run in ${minutes}m`;
|
||||
const hours = Math.round(minutes / 60);
|
||||
if (hours < 24) return `Scheduled to run in ${hours}h`;
|
||||
const days = Math.round(hours / 24);
|
||||
return `Scheduled to run in ${days}d`;
|
||||
}
|
||||
return TAB_STATUS_LABEL[tab] ?? "";
|
||||
}
|
||||
|
||||
function AgentListSection({
|
||||
activeTab,
|
||||
agents,
|
||||
}: {
|
||||
activeTab: AgentStatusFilter;
|
||||
agents: LibraryAgent[];
|
||||
}) {
|
||||
const [showAll, setShowAll] = useState(false);
|
||||
const statusMap = useAgentStatusMap(agents);
|
||||
|
||||
const filtered = agents.filter((agent) => {
|
||||
const status = statusMap.get(agent.graph_id)?.status;
|
||||
if (activeTab === "listening") return status === "listening";
|
||||
if (activeTab === "scheduled") return status === "scheduled";
|
||||
if (activeTab === "idle") return status === "idle";
|
||||
return false;
|
||||
});
|
||||
|
||||
if (filtered.length === 0) {
|
||||
return <EmptyMessage tab={activeTab} />;
|
||||
}
|
||||
|
||||
const status =
|
||||
activeTab === "listening"
|
||||
? ("listening" as const)
|
||||
: activeTab === "scheduled"
|
||||
? ("scheduled" as const)
|
||||
: ("idle" as const);
|
||||
|
||||
const visible = showAll ? filtered : filtered.slice(0, MAX_VISIBLE);
|
||||
const hasMore = filtered.length > MAX_VISIBLE;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
|
||||
{visible.map((agent) => (
|
||||
<SitrepItem
|
||||
key={agent.id}
|
||||
item={{
|
||||
id: agent.id,
|
||||
agentID: agent.id,
|
||||
agentName: agent.name,
|
||||
agentImageUrl: agent.image_url,
|
||||
priority: status,
|
||||
message: getAgentStatusLabel(activeTab, agent),
|
||||
status,
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
{hasMore && (
|
||||
<div className="mt-3 flex justify-center">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => setShowAll(!showAll)}
|
||||
>
|
||||
{showAll ? "Collapse" : `Show all (${filtered.length})`}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function UsageFooter({
|
||||
usage,
|
||||
hasInsufficientCredits,
|
||||
onCreditChange,
|
||||
}: {
|
||||
usage: CoPilotUsageStatus;
|
||||
hasInsufficientCredits: boolean;
|
||||
onCreditChange?: () => void;
|
||||
}) {
|
||||
const isDailyExhausted =
|
||||
usage.daily.limit > 0 && usage.daily.used >= usage.daily.limit;
|
||||
const isWeeklyExhausted =
|
||||
usage.weekly.limit > 0 && usage.weekly.used >= usage.weekly.limit;
|
||||
const resetCost = usage.reset_cost ?? 0;
|
||||
const { resetUsage, isPending } = useResetRateLimit({ onCreditChange });
|
||||
|
||||
const showReset =
|
||||
isDailyExhausted &&
|
||||
!isWeeklyExhausted &&
|
||||
resetCost > 0 &&
|
||||
!hasInsufficientCredits;
|
||||
|
||||
const showAddCredits =
|
||||
isDailyExhausted && !isWeeklyExhausted && hasInsufficientCredits;
|
||||
|
||||
if (!showReset && !showAddCredits) return null;
|
||||
|
||||
return (
|
||||
<div className="mt-4 flex items-center gap-3">
|
||||
{showReset && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={() => resetUsage()}
|
||||
loading={isPending}
|
||||
>
|
||||
{isPending
|
||||
? "Resetting..."
|
||||
: `Reset daily limit for ${formatCents(resetCost)}`}
|
||||
</Button>
|
||||
)}
|
||||
{showAddCredits && (
|
||||
<Link
|
||||
href="/profile/credits"
|
||||
className="inline-flex items-center justify-center rounded-md bg-primary px-3 py-1.5 text-sm font-medium text-primary-foreground hover:bg-primary/90"
|
||||
>
|
||||
Add credits to reset
|
||||
</Link>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function UsageMeter({
|
||||
label,
|
||||
used,
|
||||
limit,
|
||||
resetsAt,
|
||||
}: {
|
||||
label: string;
|
||||
used: number;
|
||||
limit: number;
|
||||
resetsAt: Date | string;
|
||||
}) {
|
||||
if (limit <= 0) return null;
|
||||
|
||||
const rawPercent = (used / limit) * 100;
|
||||
const percent = Math.min(100, Math.round(rawPercent));
|
||||
const isHigh = percent >= 80;
|
||||
const percentLabel =
|
||||
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex items-baseline justify-between">
|
||||
<Text variant="body-medium" className="text-neutral-700">
|
||||
{label}
|
||||
</Text>
|
||||
<Text variant="body" className="tabular-nums text-neutral-500">
|
||||
{percentLabel}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
|
||||
<div
|
||||
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
|
||||
isHigh ? "bg-orange-500" : "bg-blue-500"
|
||||
}`}
|
||||
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-baseline justify-between">
|
||||
<Text variant="small" className="tabular-nums text-neutral-500">
|
||||
{used.toLocaleString()} / {limit.toLocaleString()}
|
||||
</Text>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
Resets {formatResetTime(resetsAt)}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const EMPTY_MESSAGES: Record<string, string> = {
|
||||
running: "No agents running right now",
|
||||
attention: "No agents that need attention",
|
||||
completed: "No recently completed runs",
|
||||
listening: "No agents listening for events",
|
||||
scheduled: "No agents with scheduled runs",
|
||||
idle: "No idle agents",
|
||||
};
|
||||
|
||||
function EmptyMessage({ tab }: { tab: AgentStatusFilter }) {
|
||||
return (
|
||||
<div className="flex items-center justify-center pt-4">
|
||||
<Text variant="body-medium" className="text-zinc-600">
|
||||
{EMPTY_MESSAGES[tab] ?? "No agents in this category"}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { OverflowText } from "@/components/atoms/OverflowText/OverflowText";
|
||||
import { Emoji } from "@/components/atoms/Emoji/Emoji";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { FleetSummary, AgentStatusFilter } from "../../types";
|
||||
|
||||
interface Props {
|
||||
summary: FleetSummary;
|
||||
activeTab: AgentStatusFilter;
|
||||
onTabChange: (tab: AgentStatusFilter) => void;
|
||||
}
|
||||
|
||||
const TILES: {
|
||||
label: string;
|
||||
key: keyof FleetSummary;
|
||||
format?: (v: number) => string;
|
||||
filter: AgentStatusFilter;
|
||||
emoji: string;
|
||||
color: string;
|
||||
}[] = [
|
||||
{
|
||||
label: "Spent this month",
|
||||
key: "monthlySpend",
|
||||
format: (v) => `$${v.toLocaleString()}`,
|
||||
filter: "all",
|
||||
emoji: "💵",
|
||||
color: "text-zinc-700",
|
||||
},
|
||||
{
|
||||
label: "Running now",
|
||||
key: "running",
|
||||
filter: "running",
|
||||
emoji: "🚩",
|
||||
color: "text-blue-600",
|
||||
},
|
||||
{
|
||||
label: "Recently completed",
|
||||
key: "completed",
|
||||
filter: "completed",
|
||||
emoji: "🗃️",
|
||||
color: "text-green-600",
|
||||
},
|
||||
{
|
||||
label: "Needs attention",
|
||||
key: "error",
|
||||
filter: "attention",
|
||||
emoji: "⚠️",
|
||||
color: "text-red-500",
|
||||
},
|
||||
{
|
||||
label: "Scheduled",
|
||||
key: "scheduled",
|
||||
filter: "scheduled",
|
||||
emoji: "📅",
|
||||
color: "text-yellow-600",
|
||||
},
|
||||
{
|
||||
label: "Idle",
|
||||
key: "idle",
|
||||
filter: "idle",
|
||||
emoji: "💤",
|
||||
color: "text-zinc-400",
|
||||
},
|
||||
];
|
||||
|
||||
export function StatsGrid({ summary, activeTab, onTabChange }: Props) {
|
||||
return (
|
||||
<div className="grid grid-cols-1 gap-3 min-[450px]:grid-cols-2 sm:grid-cols-3 lg:grid-cols-6">
|
||||
{TILES.map((tile) => {
|
||||
const rawValue = summary[tile.key];
|
||||
const value = tile.format ? tile.format(rawValue) : rawValue;
|
||||
const isActive = activeTab === tile.filter;
|
||||
|
||||
return (
|
||||
<button
|
||||
key={tile.label}
|
||||
type="button"
|
||||
onClick={() => onTabChange(tile.filter)}
|
||||
className={cn(
|
||||
"flex min-w-0 flex-col gap-1 rounded-medium border p-3 text-left shadow-md transition-all hover:shadow-lg",
|
||||
isActive
|
||||
? "border-zinc-900 bg-zinc-50"
|
||||
: "border-zinc-100 bg-white",
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 items-center gap-1.5">
|
||||
<Emoji text={tile.emoji} size={18} />
|
||||
<OverflowText
|
||||
value={tile.label}
|
||||
variant="body"
|
||||
className="text-zinc-800"
|
||||
/>
|
||||
</div>
|
||||
<Text variant="h4">{value}</Text>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { SelectOption } from "@/components/atoms/Select/Select";
|
||||
import { Select } from "@/components/atoms/Select/Select";
|
||||
import { FunnelIcon } from "@phosphor-icons/react";
|
||||
import type { AgentStatusFilter, FleetSummary } from "../../types";
|
||||
|
||||
interface Props {
|
||||
value: AgentStatusFilter;
|
||||
onChange: (value: AgentStatusFilter) => void;
|
||||
summary: FleetSummary;
|
||||
}
|
||||
|
||||
function buildOptions(summary: FleetSummary): SelectOption[] {
|
||||
return [
|
||||
{ value: "all", label: "All Agents" },
|
||||
{ value: "running", label: `Running (${summary.running})` },
|
||||
{ value: "attention", label: `Needs Attention (${summary.error})` },
|
||||
{ value: "listening", label: `Listening (${summary.listening})` },
|
||||
{ value: "scheduled", label: `Scheduled (${summary.scheduled})` },
|
||||
{ value: "idle", label: `Idle / Stale (${summary.idle})` },
|
||||
{ value: "healthy", label: "Healthy" },
|
||||
];
|
||||
}
|
||||
|
||||
export function AgentFilterMenu({ value, onChange, summary }: Props) {
|
||||
function handleChange(val: string) {
|
||||
onChange(val as AgentStatusFilter);
|
||||
}
|
||||
|
||||
const options = buildOptions(summary);
|
||||
|
||||
return (
|
||||
<div className="flex items-center" data-testid="agent-filter-dropdown">
|
||||
<span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
|
||||
filter
|
||||
</span>
|
||||
<FunnelIcon className="ml-1 h-4 w-4 sm:hidden" />
|
||||
<Select
|
||||
id="agent-status-filter"
|
||||
label="Filter agents"
|
||||
hideLabel
|
||||
value={value}
|
||||
onValueChange={handleChange}
|
||||
options={options}
|
||||
size="small"
|
||||
className="ml-1 w-fit border-none !bg-transparent text-sm underline underline-offset-4 shadow-none"
|
||||
wrapperClassName="mb-0"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
EyeIcon,
|
||||
ArrowsClockwiseIcon,
|
||||
MonitorPlayIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { AgentStatus } from "../../types";
|
||||
|
||||
interface Props {
|
||||
status: AgentStatus;
|
||||
agentID: string;
|
||||
executionID?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ContextualActionButton({
|
||||
status,
|
||||
agentID,
|
||||
executionID,
|
||||
className,
|
||||
}: Props) {
|
||||
const router = useRouter();
|
||||
|
||||
const config = ACTION_CONFIG[status];
|
||||
if (!config) return null;
|
||||
|
||||
const Icon = config.icon;
|
||||
|
||||
function handleClick(e: React.MouseEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
|
||||
const params = new URLSearchParams();
|
||||
if (executionID) params.set("activeItem", executionID);
|
||||
const query = params.toString();
|
||||
router.push(`/library/agents/${agentID}${query ? `?${query}` : ""}`);
|
||||
}
|
||||
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleClick}
|
||||
className={cn(
|
||||
"inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<Icon size={12} className="shrink-0" />
|
||||
{config.label}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
const ACTION_CONFIG: Record<
|
||||
AgentStatus,
|
||||
{ label: string; icon: typeof EyeIcon }
|
||||
> = {
|
||||
error: { label: "View error", icon: EyeIcon },
|
||||
listening: { label: "Reconnect", icon: ArrowsClockwiseIcon },
|
||||
running: { label: "Watch live", icon: MonitorPlayIcon },
|
||||
idle: { label: "View", icon: EyeIcon },
|
||||
scheduled: { label: "View", icon: EyeIcon },
|
||||
};
|
||||
@@ -0,0 +1,46 @@
|
||||
"use client";
|
||||
|
||||
import { ArrowRight, Lightning } from "@phosphor-icons/react";
|
||||
import NextLink from "next/link";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useJumpBackIn } from "./useJumpBackIn";
|
||||
|
||||
export function JumpBackIn() {
|
||||
const { execution, isLoading } = useJumpBackIn();
|
||||
|
||||
if (isLoading || !execution) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const href = execution.libraryAgentId
|
||||
? `/library/agents/${execution.libraryAgentId}?activeTab=runs&activeItem=${execution.id}`
|
||||
: "#";
|
||||
|
||||
return (
|
||||
<div className="rounded-large bg-gradient-to-r from-zinc-200 via-zinc-200/60 to-indigo-200/50 p-[1px]">
|
||||
<div className="flex items-center justify-between rounded-large bg-[#F6F7F8] px-5 py-4">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
|
||||
<Lightning size={18} weight="fill" className="text-white" />
|
||||
</div>
|
||||
<div className="flex flex-col">
|
||||
<Text variant="small" className="text-zinc-500">
|
||||
{execution.statusLabel} · {execution.duration}
|
||||
</Text>
|
||||
<Text variant="body-medium" className="text-zinc-900">
|
||||
{execution.agentName}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
<NextLink href={href}>
|
||||
<Button variant="secondary" size="small" className="gap-1.5">
|
||||
Jump Back In
|
||||
<ArrowRight size={16} />
|
||||
</Button>
|
||||
</NextLink>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
|
||||
import { useMemo } from "react";
|
||||
|
||||
function isActive(status: AgentExecutionStatus) {
|
||||
return (
|
||||
status === AgentExecutionStatus.RUNNING ||
|
||||
status === AgentExecutionStatus.QUEUED ||
|
||||
status === AgentExecutionStatus.REVIEW
|
||||
);
|
||||
}
|
||||
|
||||
function formatDuration(startedAt: Date | string | null | undefined): string {
|
||||
if (!startedAt) return "";
|
||||
|
||||
const start = new Date(startedAt);
|
||||
if (isNaN(start.getTime())) return "";
|
||||
|
||||
const ms = Date.now() - start.getTime();
|
||||
if (ms < 0) return "";
|
||||
|
||||
const sec = Math.floor(ms / 1000);
|
||||
if (sec < 5) return "a few seconds";
|
||||
if (sec < 60) return `${sec}s`;
|
||||
const min = Math.floor(sec / 60);
|
||||
if (min < 60) return `${min}m ${sec % 60}s`;
|
||||
const hr = Math.floor(min / 60);
|
||||
return `${hr}h ${min % 60}m`;
|
||||
}
|
||||
|
||||
function getStatusLabel(status: AgentExecutionStatus) {
|
||||
if (status === AgentExecutionStatus.RUNNING) return "Running";
|
||||
if (status === AgentExecutionStatus.QUEUED) return "Queued";
|
||||
if (status === AgentExecutionStatus.REVIEW) return "Awaiting approval";
|
||||
return "";
|
||||
}
|
||||
|
||||
export function useJumpBackIn() {
|
||||
const { data: executions, isLoading: executionsLoading } =
|
||||
useGetV1ListAllExecutions({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
const { agentInfoMap, isRefreshing: agentsLoading } = useLibraryAgents();
|
||||
|
||||
const activeExecution = useMemo(() => {
|
||||
if (!executions) return null;
|
||||
|
||||
const active = executions
|
||||
.filter((e) => isActive(e.status))
|
||||
.sort((a, b) => {
|
||||
const aTime = a.started_at ? new Date(a.started_at).getTime() : 0;
|
||||
const bTime = b.started_at ? new Date(b.started_at).getTime() : 0;
|
||||
return bTime - aTime;
|
||||
});
|
||||
|
||||
return active[0] ?? null;
|
||||
}, [executions]);
|
||||
|
||||
const enriched = useMemo(() => {
|
||||
if (!activeExecution) return null;
|
||||
|
||||
const info = agentInfoMap.get(activeExecution.graph_id);
|
||||
return {
|
||||
id: activeExecution.id,
|
||||
agentName: info?.name ?? "Unknown Agent",
|
||||
libraryAgentId: info?.library_agent_id,
|
||||
status: activeExecution.status,
|
||||
statusLabel: getStatusLabel(activeExecution.status),
|
||||
duration: formatDuration(activeExecution.started_at),
|
||||
};
|
||||
}, [activeExecution, agentInfoMap]);
|
||||
|
||||
return {
|
||||
execution: enriched,
|
||||
isLoading: executionsLoading || agentsLoading,
|
||||
};
|
||||
}
|
||||
@@ -8,7 +8,7 @@ interface Props {
|
||||
export function LibraryActionHeader({ setSearchTerm }: Props) {
|
||||
return (
|
||||
<>
|
||||
<div className="mb-7 hidden items-center justify-center gap-4 md:flex">
|
||||
<div className="mb-[32px] hidden items-center justify-center gap-4 md:flex">
|
||||
<LibrarySearchBar setSearchTerm={setSearchTerm} />
|
||||
<LibraryImportDialog />
|
||||
</div>
|
||||
|
||||
@@ -1,40 +1,29 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { EyeIcon, ChatCircleDotsIcon } from "@phosphor-icons/react";
|
||||
import { CaretCircleRightIcon } from "@phosphor-icons/react";
|
||||
import Image from "next/image";
|
||||
import NextLink from "next/link";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { motion } from "framer-motion";
|
||||
|
||||
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Avatar, {
|
||||
AvatarFallback,
|
||||
AvatarImage,
|
||||
} from "@/components/atoms/Avatar/Avatar";
|
||||
import { Link } from "@/components/atoms/Link/Link";
|
||||
import { AgentCardMenu } from "./components/AgentCardMenu";
|
||||
import { FavoriteButton } from "./components/FavoriteButton";
|
||||
import { useLibraryAgentCard } from "./useLibraryAgentCard";
|
||||
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
|
||||
import { StatusBadge } from "../StatusBadge/StatusBadge";
|
||||
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
|
||||
import type { AgentStatusInfo } from "../../types";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
|
||||
interface Props {
|
||||
agent: LibraryAgent;
|
||||
statusInfo: AgentStatusInfo;
|
||||
draggable?: boolean;
|
||||
}
|
||||
|
||||
export function LibraryAgentCard({
|
||||
agent,
|
||||
statusInfo,
|
||||
draggable = true,
|
||||
}: Props) {
|
||||
const { id, name, image_url } = agent;
|
||||
const router = useRouter();
|
||||
export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
const { id, name, graph_id, can_access_graph, image_url } = agent;
|
||||
const { triggerFavoriteAnimation } = useFavoriteAnimation();
|
||||
|
||||
function handleDragStart(e: React.DragEvent<HTMLDivElement>) {
|
||||
@@ -42,14 +31,18 @@ export function LibraryAgentCard({
|
||||
e.dataTransfer.effectAllowed = "move";
|
||||
}
|
||||
|
||||
const { isFavorite, handleToggleFavorite } = useLibraryAgentCard({
|
||||
const {
|
||||
isFromMarketplace,
|
||||
isFavorite,
|
||||
profile,
|
||||
creator_image_url,
|
||||
handleToggleFavorite,
|
||||
} = useLibraryAgentCard({
|
||||
agent,
|
||||
onFavoriteAdd: triggerFavoriteAnimation,
|
||||
});
|
||||
|
||||
const hasError = statusInfo.status === "error";
|
||||
|
||||
const card = (
|
||||
return (
|
||||
<div
|
||||
draggable={draggable}
|
||||
onDragStart={handleDragStart}
|
||||
@@ -59,10 +52,7 @@ export function LibraryAgentCard({
|
||||
layoutId={`agent-card-${id}`}
|
||||
data-testid="library-agent-card"
|
||||
data-agent-id={id}
|
||||
className={cn(
|
||||
"group relative inline-flex h-auto min-h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border bg-white hover:shadow-md",
|
||||
hasError ? "border-red-400" : "border-zinc-100",
|
||||
)}
|
||||
className="group relative inline-flex h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border border-zinc-100 bg-white hover:shadow-md"
|
||||
transition={{
|
||||
type: "spring",
|
||||
damping: 25,
|
||||
@@ -71,10 +61,23 @@ export function LibraryAgentCard({
|
||||
style={{ willChange: "transform" }}
|
||||
>
|
||||
<NextLink href={`/library/agents/${id}`} className="flex-shrink-0">
|
||||
<div className="relative flex items-center gap-3 pl-2 pr-4 pt-3">
|
||||
<StatusBadge status={statusInfo.status} />
|
||||
<Text variant="small" className="text-zinc-400">
|
||||
{statusInfo.totalRuns} tasks
|
||||
<div className="relative flex items-center gap-2 px-4 pt-3">
|
||||
<Avatar className="h-4 w-4 rounded-full">
|
||||
<AvatarImage
|
||||
src={
|
||||
isFromMarketplace
|
||||
? creator_image_url || "/avatar-placeholder.png"
|
||||
: profile?.avatar_url || "/avatar-placeholder.png"
|
||||
}
|
||||
alt={`${name} creator avatar`}
|
||||
/>
|
||||
<AvatarFallback size={48}>{name.charAt(0)}</AvatarFallback>
|
||||
</Avatar>
|
||||
<Text
|
||||
variant="small-medium"
|
||||
className="uppercase tracking-wide text-zinc-400"
|
||||
>
|
||||
{isFromMarketplace ? "FROM MARKETPLACE" : "Built by you"}
|
||||
</Text>
|
||||
</div>
|
||||
</NextLink>
|
||||
@@ -86,7 +89,7 @@ export function LibraryAgentCard({
|
||||
<AgentCardMenu agent={agent} />
|
||||
|
||||
<div className="flex w-full flex-1 flex-col px-4 pb-2">
|
||||
<NextLink
|
||||
<Link
|
||||
href={`/library/agents/${id}`}
|
||||
className="flex w-full items-start justify-between gap-2 no-underline hover:no-underline focus:ring-0"
|
||||
>
|
||||
@@ -123,52 +126,30 @@ export function LibraryAgentCard({
|
||||
className="flex-shrink-0 rounded-small object-cover"
|
||||
/>
|
||||
)}
|
||||
</NextLink>
|
||||
</Link>
|
||||
|
||||
<div className="mt-4 flex w-full items-center justify-end gap-1 border-t border-zinc-100 pb-0 pt-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => router.push(`/library/agents/${id}`)}
|
||||
<div className="mt-auto flex w-full justify-start gap-6 border-t border-zinc-100 pb-1 pt-3">
|
||||
<Link
|
||||
href={`/library/agents/${id}`}
|
||||
data-testid="library-agent-card-see-runs-link"
|
||||
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
|
||||
className="flex items-center gap-1 text-[13px]"
|
||||
>
|
||||
<EyeIcon size={14} className="shrink-0" />
|
||||
See tasks
|
||||
</button>
|
||||
<ContextualActionButton
|
||||
status={statusInfo.status}
|
||||
agentID={id}
|
||||
executionID={statusInfo.activeExecutionID ?? undefined}
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
const prompt = encodeURIComponent(
|
||||
`Tell me about ${name}, its current status, recent runs and how can I get the most out of it`,
|
||||
);
|
||||
router.push(`/copilot?autosubmit=true#prompt=${prompt}`);
|
||||
}}
|
||||
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
|
||||
>
|
||||
<ChatCircleDotsIcon size={14} className="shrink-0" />
|
||||
Chat
|
||||
</button>
|
||||
See runs <CaretCircleRightIcon size={20} />
|
||||
</Link>
|
||||
|
||||
{can_access_graph && (
|
||||
<Link
|
||||
href={`/build?flowID=${graph_id}`}
|
||||
data-testid="library-agent-card-open-in-builder-link"
|
||||
className="flex items-center gap-1 text-[13px]"
|
||||
isExternal
|
||||
>
|
||||
Open in builder <CaretCircleRightIcon size={20} />
|
||||
</Link>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
</div>
|
||||
);
|
||||
|
||||
if (hasError && statusInfo.lastError) {
|
||||
return (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>{card}</TooltipTrigger>
|
||||
<TooltipContent className="max-w-xs text-red-600">
|
||||
{statusInfo.lastError}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
return card;
|
||||
}
|
||||
|
||||
@@ -169,7 +169,6 @@ export function AgentCardMenu({ agent }: AgentCardMenuProps) {
|
||||
href={`/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`}
|
||||
target="_blank"
|
||||
className="flex items-center gap-2"
|
||||
data-testid="library-agent-card-open-in-builder-link"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
Edit agent
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
|
||||
import { LibraryAgentCard } from "../LibraryAgentCard/LibraryAgentCard";
|
||||
@@ -17,11 +16,8 @@ import {
|
||||
} from "framer-motion";
|
||||
import { LibraryFolderEditDialog } from "../LibraryFolderEditDialog/LibraryFolderEditDialog";
|
||||
import { LibraryFolderDeleteDialog } from "../LibraryFolderDeleteDialog/LibraryFolderDeleteDialog";
|
||||
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
|
||||
import { LibraryTab } from "../../types";
|
||||
import { useLibraryAgentList } from "./useLibraryAgentList";
|
||||
import { AgentBriefingPanel } from "../AgentBriefingPanel/AgentBriefingPanel";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useAgentStatusMap, getAgentStatus } from "../../hooks/useAgentStatus";
|
||||
|
||||
// cancels the current spring and starts a new one from current state.
|
||||
const containerVariants = {
|
||||
@@ -74,10 +70,6 @@ interface Props {
|
||||
tabs: LibraryTab[];
|
||||
activeTab: string;
|
||||
onTabChange: (tabId: string) => void;
|
||||
statusFilter?: AgentStatusFilter;
|
||||
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
|
||||
fleetSummary?: FleetSummary;
|
||||
briefingAgents?: LibraryAgent[];
|
||||
}
|
||||
|
||||
export function LibraryAgentList({
|
||||
@@ -89,12 +81,7 @@ export function LibraryAgentList({
|
||||
tabs,
|
||||
activeTab,
|
||||
onTabChange,
|
||||
statusFilter = "all",
|
||||
onStatusFilterChange,
|
||||
fleetSummary,
|
||||
briefingAgents,
|
||||
}: Props) {
|
||||
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
|
||||
const shouldReduceMotion = useReducedMotion();
|
||||
const activeContainerVariants = shouldReduceMotion
|
||||
? reducedContainerVariants
|
||||
@@ -108,7 +95,7 @@ export function LibraryAgentList({
|
||||
const {
|
||||
isFavoritesTab,
|
||||
agentLoading,
|
||||
displayedCount,
|
||||
allAgentsCount,
|
||||
favoritesCount,
|
||||
agents,
|
||||
hasNextPage,
|
||||
@@ -129,37 +116,18 @@ export function LibraryAgentList({
|
||||
selectedFolderId,
|
||||
onFolderSelect,
|
||||
activeTab,
|
||||
statusFilter,
|
||||
});
|
||||
|
||||
const agentStatusMap = useAgentStatusMap(agents);
|
||||
|
||||
return (
|
||||
<>
|
||||
{isAgentBriefingEnabled &&
|
||||
!selectedFolderId &&
|
||||
fleetSummary &&
|
||||
briefingAgents &&
|
||||
briefingAgents.length > 0 && (
|
||||
<div className="mb-4">
|
||||
<AgentBriefingPanel
|
||||
summary={fleetSummary}
|
||||
agents={briefingAgents}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!selectedFolderId && (
|
||||
<LibrarySubSection
|
||||
tabs={tabs}
|
||||
activeTab={activeTab}
|
||||
onTabChange={onTabChange}
|
||||
allCount={displayedCount}
|
||||
allCount={allAgentsCount}
|
||||
favoritesCount={favoritesCount}
|
||||
setLibrarySort={setLibrarySort}
|
||||
statusFilter={statusFilter}
|
||||
onStatusFilterChange={onStatusFilterChange}
|
||||
fleetSummary={fleetSummary}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -251,13 +219,7 @@ export function LibraryAgentList({
|
||||
0.04,
|
||||
}}
|
||||
>
|
||||
<LibraryAgentCard
|
||||
agent={agent}
|
||||
statusInfo={getAgentStatus(
|
||||
agentStatusMap,
|
||||
agent.graph_id,
|
||||
)}
|
||||
/>
|
||||
<LibraryAgentCard agent={agent} />
|
||||
</motion.div>
|
||||
))}
|
||||
</motion.div>
|
||||
|
||||
@@ -21,12 +21,7 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useFavoriteAgents } from "../../hooks/useFavoriteAgents";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import type { AgentStatusFilter } from "../../types";
|
||||
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
|
||||
const FILTER_EXHAUST_THRESHOLD = 3;
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
interface Props {
|
||||
searchTerm: string;
|
||||
@@ -34,7 +29,6 @@ interface Props {
|
||||
selectedFolderId: string | null;
|
||||
onFolderSelect: (folderId: string | null) => void;
|
||||
activeTab: string;
|
||||
statusFilter?: AgentStatusFilter;
|
||||
}
|
||||
|
||||
export function useLibraryAgentList({
|
||||
@@ -43,16 +37,12 @@ export function useLibraryAgentList({
|
||||
selectedFolderId,
|
||||
onFolderSelect,
|
||||
activeTab,
|
||||
statusFilter = "all",
|
||||
}: Props) {
|
||||
const isFavoritesTab = activeTab === "favorites";
|
||||
const { toast } = useToast();
|
||||
const stableQueryClient = getQueryClient();
|
||||
const queryClient = useQueryClient();
|
||||
const prevSortRef = useRef<LibraryAgentSort | null>(null);
|
||||
const [consecutiveEmptyPages, setConsecutiveEmptyPages] = useState(0);
|
||||
const prevFilteredLengthRef = useRef(0);
|
||||
const prevAgentsLengthRef = useRef(0);
|
||||
|
||||
const [editingFolder, setEditingFolder] = useState<LibraryFolder | null>(
|
||||
null,
|
||||
@@ -209,90 +199,6 @@ export function useLibraryAgentList({
|
||||
|
||||
const showFolders = !isFavoritesTab;
|
||||
|
||||
const { data: executions } = useGetV1ListAllExecutions({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
const { activeGraphIds, errorGraphIds, completedGraphIds } = useMemo(() => {
|
||||
const active = new Set<string>();
|
||||
const errors = new Set<string>();
|
||||
const completed = new Set<string>();
|
||||
const cutoff = Date.now() - 72 * 60 * 60 * 1000;
|
||||
for (const exec of executions ?? []) {
|
||||
if (
|
||||
exec.status === AgentExecutionStatus.RUNNING ||
|
||||
exec.status === AgentExecutionStatus.QUEUED ||
|
||||
exec.status === AgentExecutionStatus.REVIEW
|
||||
) {
|
||||
active.add(exec.graph_id);
|
||||
}
|
||||
const endedTs = exec.ended_at
|
||||
? exec.ended_at instanceof Date
|
||||
? exec.ended_at.getTime()
|
||||
: new Date(String(exec.ended_at)).getTime()
|
||||
: 0;
|
||||
if (
|
||||
(exec.status === AgentExecutionStatus.FAILED ||
|
||||
exec.status === AgentExecutionStatus.TERMINATED) &&
|
||||
endedTs > cutoff
|
||||
) {
|
||||
errors.add(exec.graph_id);
|
||||
}
|
||||
if (exec.status === AgentExecutionStatus.COMPLETED && endedTs > cutoff) {
|
||||
completed.add(exec.graph_id);
|
||||
}
|
||||
}
|
||||
return {
|
||||
activeGraphIds: active,
|
||||
errorGraphIds: errors,
|
||||
completedGraphIds: completed,
|
||||
};
|
||||
}, [executions]);
|
||||
|
||||
const filteredAgents = filterAgentsByStatus(
|
||||
agents,
|
||||
statusFilter,
|
||||
activeGraphIds,
|
||||
errorGraphIds,
|
||||
completedGraphIds,
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (statusFilter === "all") {
|
||||
setConsecutiveEmptyPages(0);
|
||||
prevFilteredLengthRef.current = filteredAgents.length;
|
||||
prevAgentsLengthRef.current = agents.length;
|
||||
return;
|
||||
}
|
||||
|
||||
if (agents.length > prevAgentsLengthRef.current) {
|
||||
const newFilteredCount = filteredAgents.length;
|
||||
const previousCount = prevFilteredLengthRef.current;
|
||||
|
||||
if (newFilteredCount > previousCount) {
|
||||
setConsecutiveEmptyPages(0);
|
||||
} else {
|
||||
setConsecutiveEmptyPages((prev) => prev + 1);
|
||||
}
|
||||
}
|
||||
|
||||
prevAgentsLengthRef.current = agents.length;
|
||||
prevFilteredLengthRef.current = filteredAgents.length;
|
||||
}, [agents.length, filteredAgents.length, statusFilter]);
|
||||
|
||||
useEffect(() => {
|
||||
setConsecutiveEmptyPages(0);
|
||||
prevFilteredLengthRef.current = 0;
|
||||
prevAgentsLengthRef.current = 0;
|
||||
}, [statusFilter]);
|
||||
|
||||
const filteredExhausted =
|
||||
statusFilter !== "all" && consecutiveEmptyPages >= FILTER_EXHAUST_THRESHOLD;
|
||||
|
||||
// When a filter is active, show the filtered count instead of the API total.
|
||||
const displayedCount =
|
||||
statusFilter === "all" ? allAgentsCount : filteredAgents.length;
|
||||
|
||||
function handleFolderDeleted() {
|
||||
if (selectedFolderId === deletingFolder?.id) {
|
||||
onFolderSelect(null);
|
||||
@@ -304,10 +210,9 @@ export function useLibraryAgentList({
|
||||
agentLoading,
|
||||
agentCount,
|
||||
allAgentsCount,
|
||||
displayedCount,
|
||||
favoritesCount: favoriteAgentsData.agentCount,
|
||||
agents: filteredAgents,
|
||||
hasNextPage: agentsHasNextPage && !filteredExhausted,
|
||||
agents,
|
||||
hasNextPage: agentsHasNextPage,
|
||||
isFetchingNextPage: agentsIsFetchingNextPage,
|
||||
fetchNextPage: agentsFetchNextPage,
|
||||
foldersData,
|
||||
@@ -321,46 +226,3 @@ export function useLibraryAgentList({
|
||||
handleFolderDeleted,
|
||||
};
|
||||
}
|
||||
|
||||
function filterAgentsByStatus<
|
||||
T extends {
|
||||
graph_id: string;
|
||||
has_external_trigger: boolean;
|
||||
recommended_schedule_cron?: string | null;
|
||||
},
|
||||
>(
|
||||
agents: T[],
|
||||
statusFilter: AgentStatusFilter,
|
||||
activeGraphIds: Set<string>,
|
||||
errorGraphIds: Set<string>,
|
||||
completedGraphIds: Set<string>,
|
||||
): T[] {
|
||||
if (statusFilter === "all") return agents;
|
||||
return agents.filter((agent) => {
|
||||
const isRunning = activeGraphIds.has(agent.graph_id);
|
||||
const hasError = errorGraphIds.has(agent.graph_id);
|
||||
|
||||
if (statusFilter === "running") return isRunning;
|
||||
if (statusFilter === "attention") return hasError && !isRunning;
|
||||
if (statusFilter === "completed")
|
||||
return completedGraphIds.has(agent.graph_id);
|
||||
if (statusFilter === "listening")
|
||||
return !isRunning && !hasError && agent.has_external_trigger;
|
||||
if (statusFilter === "scheduled")
|
||||
return (
|
||||
!isRunning &&
|
||||
!hasError &&
|
||||
!agent.has_external_trigger &&
|
||||
!!agent.recommended_schedule_cron
|
||||
);
|
||||
if (statusFilter === "idle")
|
||||
return (
|
||||
!isRunning &&
|
||||
!hasError &&
|
||||
!agent.has_external_trigger &&
|
||||
!agent.recommended_schedule_cron
|
||||
);
|
||||
if (statusFilter === "healthy") return !hasError;
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,11 +2,14 @@
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { FolderIcon, FolderColor } from "./FolderIcon";
|
||||
import {
|
||||
FolderIcon,
|
||||
FolderColor,
|
||||
folderCardStyles,
|
||||
resolveColor,
|
||||
} from "./FolderIcon";
|
||||
import { useState } from "react";
|
||||
import { PencilSimpleIcon, TrashIcon } from "@phosphor-icons/react";
|
||||
import type { AgentStatus } from "../../types";
|
||||
import { StatusBadge } from "../StatusBadge/StatusBadge";
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
@@ -18,8 +21,6 @@ interface Props {
|
||||
onDelete?: () => void;
|
||||
onAgentDrop?: (agentId: string, folderId: string) => void;
|
||||
onClick?: () => void;
|
||||
/** Worst status among child agents (optional, for status aggregation). */
|
||||
worstStatus?: AgentStatus;
|
||||
}
|
||||
|
||||
export function LibraryFolder({
|
||||
@@ -32,10 +33,11 @@ export function LibraryFolder({
|
||||
onDelete,
|
||||
onAgentDrop,
|
||||
onClick,
|
||||
worstStatus,
|
||||
}: Props) {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const [isDragOver, setIsDragOver] = useState(false);
|
||||
const resolvedColor = resolveColor(color);
|
||||
const cardStyle = folderCardStyles[resolvedColor];
|
||||
|
||||
function handleDragOver(e: React.DragEvent<HTMLDivElement>) {
|
||||
if (e.dataTransfer.types.includes("application/agent-id")) {
|
||||
@@ -62,10 +64,10 @@ export function LibraryFolder({
|
||||
<div
|
||||
data-testid="library-folder"
|
||||
data-folder-id={id}
|
||||
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 shadow-sm backdrop-blur-md transition-all duration-200 hover:shadow-md ${
|
||||
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 transition-all duration-200 hover:shadow-md ${
|
||||
isDragOver
|
||||
? "border-blue-400 bg-blue-50 ring-2 ring-blue-200"
|
||||
: "border-indigo-200/40 bg-gradient-to-br from-indigo-50/40 via-white/70 to-purple-50/30"
|
||||
: `${cardStyle.border} ${cardStyle.bg}`
|
||||
}`}
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
@@ -74,7 +76,7 @@ export function LibraryFolder({
|
||||
onDrop={handleDrop}
|
||||
onClick={onClick}
|
||||
>
|
||||
<div className="flex w-full items-center justify-between gap-4">
|
||||
<div className="flex w-full items-start justify-between gap-4">
|
||||
{/* Left side - Folder name and agent count */}
|
||||
<div className="flex flex-1 flex-col gap-2">
|
||||
<Text
|
||||
@@ -84,22 +86,17 @@ export function LibraryFolder({
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
<div className="flex items-center gap-2">
|
||||
<Text
|
||||
variant="small"
|
||||
className="text-zinc-500"
|
||||
data-testid="library-folder-agent-count"
|
||||
>
|
||||
{agentCount} {agentCount === 1 ? "agent" : "agents"}
|
||||
</Text>
|
||||
{worstStatus && worstStatus !== "idle" && (
|
||||
<StatusBadge status={worstStatus} />
|
||||
)}
|
||||
</div>
|
||||
<Text
|
||||
variant="small"
|
||||
className="text-zinc-500"
|
||||
data-testid="library-folder-agent-count"
|
||||
>
|
||||
{agentCount} {agentCount === 1 ? "agent" : "agents"}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{/* Right side - Custom folder icon */}
|
||||
<div className="relative top-5 flex flex-shrink-0 items-center">
|
||||
<div className="flex-shrink-0">
|
||||
<FolderIcon isOpen={isHovered} color={color} icon={icon} />
|
||||
</div>
|
||||
</div>
|
||||
@@ -117,7 +114,7 @@ export function LibraryFolder({
|
||||
e.stopPropagation();
|
||||
onEdit?.();
|
||||
}}
|
||||
className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
|
||||
className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
|
||||
>
|
||||
<PencilSimpleIcon className="h-4 w-4" />
|
||||
</Button>
|
||||
@@ -129,7 +126,7 @@ export function LibraryFolder({
|
||||
e.stopPropagation();
|
||||
onDelete?.();
|
||||
}}
|
||||
className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
|
||||
className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
|
||||
>
|
||||
<TrashIcon className="h-4 w-4" />
|
||||
</Button>
|
||||
|
||||
@@ -19,11 +19,11 @@ export function LibrarySortMenu({ setLibrarySort }: Props) {
|
||||
const { handleSortChange } = useLibrarySortMenu({ setLibrarySort });
|
||||
return (
|
||||
<div className="flex items-center" data-testid="sort-by-dropdown">
|
||||
<span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
|
||||
<span className="hidden whitespace-nowrap text-sm sm:inline">
|
||||
sort by
|
||||
</span>
|
||||
<Select onValueChange={handleSortChange}>
|
||||
<SelectTrigger className="!m-0 ml-1 w-fit space-x-1 border-none !bg-transparent px-[1rem] text-sm underline underline-offset-4 !shadow-none !ring-offset-transparent">
|
||||
<SelectTrigger className="ml-1 w-fit space-x-1 border-none px-0 text-sm underline underline-offset-4 shadow-none">
|
||||
<ArrowDownNarrowWideIcon className="h-4 w-4 sm:hidden" />
|
||||
<SelectValue placeholder="Last Modified" />
|
||||
</SelectTrigger>
|
||||
|
||||
@@ -6,10 +6,9 @@ import {
|
||||
} from "@/components/molecules/TabsLine/TabsLine";
|
||||
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
|
||||
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
|
||||
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
|
||||
import { LibraryTab } from "../../types";
|
||||
import LibraryFolderCreationDialog from "../LibraryFolderCreationDialog/LibraryFolderCreationDialog";
|
||||
import { LibrarySortMenu } from "../LibrarySortMenu/LibrarySortMenu";
|
||||
import { AgentFilterMenu } from "../AgentFilterMenu/AgentFilterMenu";
|
||||
|
||||
interface Props {
|
||||
tabs: LibraryTab[];
|
||||
@@ -18,9 +17,6 @@ interface Props {
|
||||
allCount: number;
|
||||
favoritesCount: number;
|
||||
setLibrarySort: (value: LibraryAgentSort) => void;
|
||||
statusFilter?: AgentStatusFilter;
|
||||
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
|
||||
fleetSummary?: FleetSummary;
|
||||
}
|
||||
|
||||
export function LibrarySubSection({
|
||||
@@ -30,9 +26,6 @@ export function LibrarySubSection({
|
||||
allCount,
|
||||
favoritesCount,
|
||||
setLibrarySort,
|
||||
statusFilter = "all",
|
||||
onStatusFilterChange,
|
||||
fleetSummary,
|
||||
}: Props) {
|
||||
const { registerFavoritesTabRef } = useFavoriteAnimation();
|
||||
const favoritesRef = useRef<HTMLButtonElement>(null);
|
||||
@@ -75,15 +68,8 @@ export function LibrarySubSection({
|
||||
))}
|
||||
</TabsLineList>
|
||||
</TabsLine>
|
||||
<div className="relative top-1.5 hidden items-center gap-6 md:flex">
|
||||
<div className="hidden items-center gap-6 md:flex">
|
||||
<LibraryFolderCreationDialog />
|
||||
{fleetSummary && onStatusFilterChange && (
|
||||
<AgentFilterMenu
|
||||
value={statusFilter}
|
||||
onChange={onStatusFilterChange}
|
||||
summary={fleetSummary}
|
||||
/>
|
||||
)}
|
||||
<LibrarySortMenu setLibrarySort={setLibrarySort} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
.spinner {
|
||||
aspect-ratio: 1;
|
||||
border-radius: 50%;
|
||||
background:
|
||||
radial-gradient(farthest-side, currentColor 94%, #0000) top/3px 3px
|
||||
no-repeat,
|
||||
conic-gradient(#0000 30%, currentColor);
|
||||
-webkit-mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
|
||||
mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
|
||||
animation: spin 1s infinite linear;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
100% {
|
||||
transform: rotate(1turn);
|
||||
}
|
||||
}
|
||||
@@ -1,175 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { OverflowText } from "@/components/atoms/OverflowText/OverflowText";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
WarningCircleIcon,
|
||||
ClockCountdownIcon,
|
||||
CheckCircleIcon,
|
||||
ChatCircleDotsIcon,
|
||||
EarIcon,
|
||||
CalendarDotsIcon,
|
||||
MoonIcon,
|
||||
EyeIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import NextLink from "next/link";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { SitrepItemData, SitrepPriority } from "../../types";
|
||||
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
|
||||
import styles from "./SitrepItem.module.css";
|
||||
|
||||
interface Props {
|
||||
item: SitrepItemData;
|
||||
}
|
||||
|
||||
const PRIORITY_CONFIG: Record<
|
||||
SitrepPriority,
|
||||
{
|
||||
icon?: typeof WarningCircleIcon;
|
||||
color: string;
|
||||
bg: string;
|
||||
cssSpinner?: boolean;
|
||||
}
|
||||
> = {
|
||||
error: {
|
||||
icon: WarningCircleIcon,
|
||||
color: "text-red-500",
|
||||
bg: "bg-red-50",
|
||||
},
|
||||
running: {
|
||||
color: "text-zinc-800",
|
||||
bg: "",
|
||||
cssSpinner: true,
|
||||
},
|
||||
stale: {
|
||||
icon: ClockCountdownIcon,
|
||||
color: "text-yellow-600",
|
||||
bg: "bg-yellow-50",
|
||||
},
|
||||
success: {
|
||||
icon: CheckCircleIcon,
|
||||
color: "text-green-600",
|
||||
bg: "bg-green-50",
|
||||
},
|
||||
listening: {
|
||||
icon: EarIcon,
|
||||
color: "text-purple-500",
|
||||
bg: "bg-purple-50",
|
||||
},
|
||||
scheduled: {
|
||||
icon: CalendarDotsIcon,
|
||||
color: "text-yellow-600",
|
||||
bg: "bg-yellow-50",
|
||||
},
|
||||
idle: {
|
||||
icon: MoonIcon,
|
||||
color: "text-zinc-400",
|
||||
bg: "bg-zinc-100",
|
||||
},
|
||||
};
|
||||
|
||||
export function SitrepItem({ item }: Props) {
|
||||
const config = PRIORITY_CONFIG[item.priority];
|
||||
const router = useRouter();
|
||||
|
||||
function handleAskAutoPilot() {
|
||||
const prompt = buildAutoPilotPrompt(item);
|
||||
const encoded = encodeURIComponent(prompt);
|
||||
router.push(`/copilot?autosubmit=true#prompt=${encoded}`);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-col gap-2 rounded-medium border border-zinc-200/50 bg-transparent p-2 sm:flex-row sm:items-center sm:gap-3",
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 flex-1 items-center gap-3">
|
||||
{item.agentImageUrl ? (
|
||||
<img
|
||||
src={item.agentImageUrl}
|
||||
alt={item.agentName}
|
||||
className="h-6 w-6 flex-shrink-0 rounded-full object-cover"
|
||||
/>
|
||||
) : (
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-6 w-6 flex-shrink-0 items-center justify-center rounded-full",
|
||||
config.bg,
|
||||
)}
|
||||
>
|
||||
{config.cssSpinner ? (
|
||||
<div
|
||||
className={cn(
|
||||
styles.spinner,
|
||||
"h-[21px] w-[21px] text-zinc-800",
|
||||
)}
|
||||
/>
|
||||
) : (
|
||||
config.icon && (
|
||||
<config.icon size={14} className={config.color} weight="fill" />
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="min-w-0 flex-1">
|
||||
<Text variant="body-medium" className="leading-tight text-zinc-900">
|
||||
{item.agentName}
|
||||
</Text>
|
||||
<OverflowText
|
||||
value={item.message}
|
||||
variant="small"
|
||||
className="leading-tight text-zinc-500"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-shrink-0 flex-wrap items-center justify-center gap-1.5 sm:flex-nowrap sm:justify-end">
|
||||
{item.priority === "success" ? (
|
||||
<NextLink
|
||||
href={`/library/agents/${item.agentID}${item.executionID ? `?activeItem=${item.executionID}` : ""}`}
|
||||
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
|
||||
>
|
||||
<EyeIcon size={14} className="shrink-0" />
|
||||
See task
|
||||
</NextLink>
|
||||
) : (
|
||||
<ContextualActionButton
|
||||
status={item.status}
|
||||
agentID={item.agentID}
|
||||
executionID={item.executionID}
|
||||
/>
|
||||
)}
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleAskAutoPilot}
|
||||
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
|
||||
>
|
||||
<ChatCircleDotsIcon size={14} className="shrink-0" />
|
||||
Ask AutoPilot
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function buildAutoPilotPrompt(item: SitrepItemData): string {
|
||||
switch (item.priority) {
|
||||
case "error":
|
||||
return `What happened with ${item.agentName}? It says "${item.message}" — can you check the logs and tell me what to fix?`;
|
||||
case "running":
|
||||
return `Give me a status update on the ${item.agentName} run — what has it found so far?`;
|
||||
case "stale":
|
||||
return `${item.agentName} hasn't run recently. Should I keep it or update and re-run it?`;
|
||||
case "success":
|
||||
return `Show me what ${item.agentName} found in its last run — summarize the results and any key takeaways.`;
|
||||
case "listening":
|
||||
return `What is ${item.agentName} listening for? Give me a summary of its trigger configuration.`;
|
||||
case "scheduled":
|
||||
return `When is ${item.agentName} scheduled to run next?`;
|
||||
case "idle":
|
||||
return `${item.agentName} has been idle. Should I keep it or update and re-run it?`;
|
||||
}
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { ClockCounterClockwise } from "@phosphor-icons/react";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { useSitrepItems } from "./useSitrepItems";
|
||||
import { SitrepItem } from "./SitrepItem";
|
||||
|
||||
interface Props {
|
||||
agents: LibraryAgent[];
|
||||
maxItems?: number;
|
||||
}
|
||||
|
||||
export function SitrepList({ agents, maxItems = 10 }: Props) {
|
||||
const items = useSitrepItems(agents, maxItems);
|
||||
|
||||
if (items.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="mb-2 flex items-center gap-1.5">
|
||||
<ClockCounterClockwise size={16} className="text-zinc-700" />
|
||||
<Text variant="body-medium" className="text-zinc-700">
|
||||
Recent tasks
|
||||
</Text>
|
||||
</div>
|
||||
<div className="grid grid-cols-1 gap-1 lg:grid-cols-2">
|
||||
{items.map((item) => (
|
||||
<SitrepItem key={item.id} item={item} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,198 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useMemo } from "react";
|
||||
import type { SitrepItemData, SitrepPriority } from "../../types";
|
||||
import {
|
||||
isActive,
|
||||
isFailed,
|
||||
toEndTime,
|
||||
endedAfter,
|
||||
runningMessage,
|
||||
SEVENTY_TWO_HOURS_MS,
|
||||
} from "../../hooks/executionHelpers";
|
||||
|
||||
export function useSitrepItems(
|
||||
agents: LibraryAgent[],
|
||||
maxItems: number,
|
||||
scheduledWithinMs?: number,
|
||||
): SitrepItemData[] {
|
||||
const { data: executions } = useGetV1ListAllExecutions({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
return useMemo(() => {
|
||||
if (agents.length === 0) return [];
|
||||
|
||||
const graphIdToAgent = new Map(agents.map((a) => [a.graph_id, a]));
|
||||
const agentExecutions = groupByAgent(executions ?? [], graphIdToAgent);
|
||||
const items: SitrepItemData[] = [];
|
||||
const coveredAgentIds = new Set<string>();
|
||||
|
||||
for (const [agent, execs] of agentExecutions) {
|
||||
const item = buildSitrepFromExecutions(agent, execs);
|
||||
if (item) {
|
||||
items.push(item);
|
||||
coveredAgentIds.add(agent.id);
|
||||
}
|
||||
}
|
||||
|
||||
for (const agent of agents) {
|
||||
if (coveredAgentIds.has(agent.id)) continue;
|
||||
const configItem = buildSitrepFromConfig(agent, scheduledWithinMs);
|
||||
if (configItem) items.push(configItem);
|
||||
}
|
||||
|
||||
const order: Record<SitrepPriority, number> = {
|
||||
error: 0,
|
||||
running: 1,
|
||||
stale: 2,
|
||||
success: 3,
|
||||
listening: 4,
|
||||
scheduled: 5,
|
||||
idle: 6,
|
||||
};
|
||||
items.sort((a, b) => order[a.priority] - order[b.priority]);
|
||||
|
||||
return items.slice(0, maxItems);
|
||||
}, [agents, executions, maxItems, scheduledWithinMs]);
|
||||
}
|
||||
|
||||
function groupByAgent(
|
||||
executions: GraphExecutionMeta[],
|
||||
graphIdToAgent: Map<string, LibraryAgent>,
|
||||
): Map<LibraryAgent, GraphExecutionMeta[]> {
|
||||
const map = new Map<LibraryAgent, GraphExecutionMeta[]>();
|
||||
|
||||
for (const exec of executions) {
|
||||
const agent = graphIdToAgent.get(exec.graph_id);
|
||||
if (!agent) continue;
|
||||
const list = map.get(agent);
|
||||
if (list) {
|
||||
list.push(exec);
|
||||
} else {
|
||||
map.set(agent, [exec]);
|
||||
}
|
||||
}
|
||||
|
||||
return map;
|
||||
}
|
||||
|
||||
function buildSitrepFromExecutions(
|
||||
agent: LibraryAgent,
|
||||
executions: GraphExecutionMeta[],
|
||||
): SitrepItemData | null {
|
||||
const active = executions.find((e) => isActive(e.status));
|
||||
if (active) {
|
||||
return {
|
||||
id: `${agent.id}-${active.id}`,
|
||||
agentID: agent.id,
|
||||
agentName: agent.name,
|
||||
executionID: active.id,
|
||||
priority: "running",
|
||||
message:
|
||||
active.stats?.activity_status ??
|
||||
runningMessage(active.status, active.started_at),
|
||||
status: "running",
|
||||
};
|
||||
}
|
||||
|
||||
const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
|
||||
const recent = executions
|
||||
.filter((e) => endedAfter(e, cutoff))
|
||||
.sort((a, b) => toEndTime(b) - toEndTime(a));
|
||||
|
||||
const lastFailed = recent.find((e) => isFailed(e.status));
|
||||
if (lastFailed) {
|
||||
const errorMsg =
|
||||
lastFailed.stats?.error ??
|
||||
lastFailed.stats?.activity_status ??
|
||||
"Execution failed";
|
||||
return {
|
||||
id: `${agent.id}-${lastFailed.id}`,
|
||||
agentID: agent.id,
|
||||
agentName: agent.name,
|
||||
executionID: lastFailed.id,
|
||||
priority: "error",
|
||||
message: typeof errorMsg === "string" ? errorMsg : "Execution failed",
|
||||
status: "error",
|
||||
};
|
||||
}
|
||||
|
||||
const lastCompleted = recent.find(
|
||||
(e) => e.status === AgentExecutionStatus.COMPLETED,
|
||||
);
|
||||
if (lastCompleted) {
|
||||
const summary =
|
||||
lastCompleted.stats?.activity_status ?? "Completed successfully";
|
||||
return {
|
||||
id: `${agent.id}-${lastCompleted.id}`,
|
||||
agentID: agent.id,
|
||||
agentName: agent.name,
|
||||
executionID: lastCompleted.id,
|
||||
priority: "success",
|
||||
message: typeof summary === "string" ? summary : "Completed successfully",
|
||||
status: "idle",
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function buildSitrepFromConfig(
|
||||
agent: LibraryAgent,
|
||||
scheduledWithinMs?: number,
|
||||
): SitrepItemData | null {
|
||||
if (agent.has_external_trigger) {
|
||||
return {
|
||||
id: `${agent.id}-listening`,
|
||||
agentID: agent.id,
|
||||
agentName: agent.name,
|
||||
priority: "listening",
|
||||
message: "Waiting for trigger event",
|
||||
status: "listening",
|
||||
};
|
||||
}
|
||||
|
||||
if (agent.is_scheduled || agent.recommended_schedule_cron) {
|
||||
if (!isNextRunWithin(agent.next_scheduled_run, scheduledWithinMs)) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
id: `${agent.id}-scheduled`,
|
||||
agentID: agent.id,
|
||||
agentName: agent.name,
|
||||
priority: "scheduled",
|
||||
message: formatNextRun(agent.next_scheduled_run),
|
||||
status: "scheduled",
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function isNextRunWithin(
|
||||
iso: string | undefined | null,
|
||||
windowMs: number | undefined,
|
||||
): boolean {
|
||||
if (windowMs === undefined) return true;
|
||||
if (!iso) return false;
|
||||
const diff = new Date(iso).getTime() - Date.now();
|
||||
return diff <= windowMs;
|
||||
}
|
||||
|
||||
function formatNextRun(iso: string | undefined | null): string {
|
||||
if (!iso) return "Has a scheduled run";
|
||||
const diff = new Date(iso).getTime() - Date.now();
|
||||
const minutes = Math.round(diff / 60_000);
|
||||
if (minutes <= 0) return "Scheduled to run soon";
|
||||
if (minutes < 60) return `Scheduled to run in ${minutes}m`;
|
||||
const hours = Math.round(minutes / 60);
|
||||
if (hours < 24) return `Scheduled to run in ${hours}h`;
|
||||
const days = Math.round(hours / 24);
|
||||
return `Scheduled to run in ${days}d`;
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { AgentStatus } from "../../types";
|
||||
|
||||
const STATUS_CONFIG: Record<
|
||||
AgentStatus,
|
||||
{ label: string; bg: string; text: string; pulse: boolean }
|
||||
> = {
|
||||
running: {
|
||||
label: "Running",
|
||||
bg: "",
|
||||
text: "text-blue-600",
|
||||
pulse: true,
|
||||
},
|
||||
error: {
|
||||
label: "Error",
|
||||
bg: "",
|
||||
text: "text-red-500",
|
||||
pulse: false,
|
||||
},
|
||||
listening: {
|
||||
label: "Listening",
|
||||
bg: "",
|
||||
text: "text-purple-500",
|
||||
pulse: true,
|
||||
},
|
||||
scheduled: {
|
||||
label: "Scheduled",
|
||||
bg: "",
|
||||
text: "text-yellow-600",
|
||||
pulse: false,
|
||||
},
|
||||
idle: {
|
||||
label: "Idle",
|
||||
bg: "",
|
||||
text: "text-zinc-500",
|
||||
pulse: false,
|
||||
},
|
||||
};
|
||||
|
||||
interface Props {
|
||||
status: AgentStatus;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function StatusBadge({ status, className }: Props) {
|
||||
const config = STATUS_CONFIG[status];
|
||||
|
||||
return (
|
||||
<span
|
||||
className={cn(
|
||||
"inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium",
|
||||
config.bg,
|
||||
config.text,
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
"inline-block h-1.5 w-1.5 rounded-full",
|
||||
config.pulse && "animate-pulse",
|
||||
statusDotColor(status),
|
||||
)}
|
||||
/>
|
||||
{config.label}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
function statusDotColor(status: AgentStatus): string {
|
||||
switch (status) {
|
||||
case "running":
|
||||
return "bg-blue-500";
|
||||
case "error":
|
||||
return "bg-red-500";
|
||||
case "listening":
|
||||
return "bg-purple-500";
|
||||
case "scheduled":
|
||||
return "bg-yellow-500";
|
||||
case "idle":
|
||||
return "bg-zinc-400";
|
||||
}
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
||||
|
||||
export const SEVENTY_TWO_HOURS_MS = 72 * 60 * 60 * 1000;
|
||||
|
||||
export function isActive(status: string): boolean {
|
||||
return (
|
||||
status === AgentExecutionStatus.RUNNING ||
|
||||
status === AgentExecutionStatus.QUEUED ||
|
||||
status === AgentExecutionStatus.REVIEW
|
||||
);
|
||||
}
|
||||
|
||||
export function isFailed(status: string): boolean {
|
||||
return (
|
||||
status === AgentExecutionStatus.FAILED ||
|
||||
status === AgentExecutionStatus.TERMINATED
|
||||
);
|
||||
}
|
||||
|
||||
export function toEndTime(exec: GraphExecutionMeta): number {
|
||||
if (!exec.ended_at) return 0;
|
||||
return exec.ended_at instanceof Date
|
||||
? exec.ended_at.getTime()
|
||||
: new Date(exec.ended_at).getTime();
|
||||
}
|
||||
|
||||
export function endedAfter(exec: GraphExecutionMeta, cutoff: number): boolean {
|
||||
if (!exec.ended_at) return false;
|
||||
return toEndTime(exec) > cutoff;
|
||||
}
|
||||
|
||||
export function runningMessage(
|
||||
status: string,
|
||||
startedAt?: string | Date | null,
|
||||
): string {
|
||||
if (status === AgentExecutionStatus.QUEUED) return "Queued for execution";
|
||||
if (status === AgentExecutionStatus.REVIEW) return "Awaiting review";
|
||||
if (!startedAt) return "Currently executing";
|
||||
const ms =
|
||||
Date.now() -
|
||||
(startedAt instanceof Date
|
||||
? startedAt.getTime()
|
||||
: new Date(startedAt).getTime());
|
||||
return `Running for ${formatRelativeDuration(ms)}`;
|
||||
}
|
||||
|
||||
export function formatRelativeDuration(ms: number): string {
|
||||
const seconds = Math.floor(ms / 1000);
|
||||
if (seconds < 60) return "a few seconds";
|
||||
const minutes = Math.floor(seconds / 60);
|
||||
if (minutes < 60) return `${minutes}m`;
|
||||
const hours = Math.floor(minutes / 60);
|
||||
const remainingMin = minutes % 60;
|
||||
if (hours < 24)
|
||||
return remainingMin > 0 ? `${hours}h ${remainingMin}m` : `${hours}h`;
|
||||
const days = Math.floor(hours / 24);
|
||||
return `${days}d ${hours % 24}h`;
|
||||
}
|
||||
@@ -1,213 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import type {
|
||||
AgentStatus,
|
||||
AgentHealth,
|
||||
AgentStatusInfo,
|
||||
FleetSummary,
|
||||
} from "../types";
|
||||
import {
|
||||
isActive,
|
||||
isFailed,
|
||||
toEndTime,
|
||||
SEVENTY_TWO_HOURS_MS,
|
||||
} from "./executionHelpers";
|
||||
|
||||
function deriveHealth(
|
||||
status: AgentStatus,
|
||||
lastRunAt: string | null,
|
||||
): AgentHealth {
|
||||
if (status === "error") return "attention";
|
||||
if (status === "idle" && lastRunAt) {
|
||||
const daysSince =
|
||||
(Date.now() - new Date(lastRunAt).getTime()) / (1000 * 60 * 60 * 24);
|
||||
if (daysSince > 14) return "stale";
|
||||
}
|
||||
return "good";
|
||||
}
|
||||
|
||||
function computeAgentStatus(
|
||||
agent: LibraryAgent,
|
||||
agentExecutions: GraphExecutionMeta[],
|
||||
): AgentStatusInfo {
|
||||
const activeExec = agentExecutions.find((e) => isActive(e.status));
|
||||
|
||||
let status: AgentStatus;
|
||||
let lastError: string | null = null;
|
||||
let lastRunAt: string | null = null;
|
||||
const activeExecutionID = activeExec?.id ?? null;
|
||||
|
||||
if (activeExec) {
|
||||
status = "running";
|
||||
} else {
|
||||
const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
|
||||
const recentFailed = agentExecutions.find(
|
||||
(e) =>
|
||||
isFailed(e.status) &&
|
||||
e.ended_at &&
|
||||
new Date(
|
||||
e.ended_at instanceof Date ? e.ended_at.getTime() : e.ended_at,
|
||||
).getTime() > cutoff,
|
||||
);
|
||||
|
||||
if (recentFailed) {
|
||||
status = "error";
|
||||
lastError =
|
||||
(recentFailed.stats?.error as string) ??
|
||||
(recentFailed.stats?.activity_status as string) ??
|
||||
"Execution failed";
|
||||
} else if (agent.has_external_trigger) {
|
||||
status = "listening";
|
||||
} else if (agent.is_scheduled || agent.recommended_schedule_cron) {
|
||||
status = "scheduled";
|
||||
} else {
|
||||
status = "idle";
|
||||
}
|
||||
}
|
||||
|
||||
const completedExecs = agentExecutions.filter((e) => e.ended_at);
|
||||
if (completedExecs.length > 0) {
|
||||
const sorted = completedExecs.sort((a, b) => toEndTime(b) - toEndTime(a));
|
||||
const endedAt = sorted[0].ended_at;
|
||||
lastRunAt =
|
||||
endedAt instanceof Date ? endedAt.toISOString() : String(endedAt);
|
||||
}
|
||||
|
||||
const totalRuns = agent.execution_count ?? agentExecutions.length;
|
||||
|
||||
return {
|
||||
status,
|
||||
health: deriveHealth(status, lastRunAt),
|
||||
progress: null,
|
||||
totalRuns,
|
||||
lastRunAt,
|
||||
lastError,
|
||||
activeExecutionID,
|
||||
monthlySpend: 0,
|
||||
nextScheduledRun: null,
|
||||
triggerType: agent.has_external_trigger ? "webhook" : null,
|
||||
};
|
||||
}
|
||||
|
||||
export function useAgentStatusMap(
|
||||
agents: LibraryAgent[],
|
||||
): Map<string, AgentStatusInfo> {
|
||||
const { data: executions } = useGetV1ListAllExecutions({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
return useMemo(() => {
|
||||
const map = new Map<string, AgentStatusInfo>();
|
||||
const execsByGraph = new Map<string, GraphExecutionMeta[]>();
|
||||
|
||||
for (const exec of executions ?? []) {
|
||||
const list = execsByGraph.get(exec.graph_id);
|
||||
if (list) {
|
||||
list.push(exec);
|
||||
} else {
|
||||
execsByGraph.set(exec.graph_id, [exec]);
|
||||
}
|
||||
}
|
||||
|
||||
for (const agent of agents) {
|
||||
const agentExecs = execsByGraph.get(agent.graph_id) ?? [];
|
||||
map.set(agent.graph_id, computeAgentStatus(agent, agentExecs));
|
||||
}
|
||||
|
||||
return map;
|
||||
}, [agents, executions]);
|
||||
}
|
||||
|
||||
const DEFAULT_STATUS: AgentStatusInfo = {
|
||||
status: "idle",
|
||||
health: "good",
|
||||
progress: null,
|
||||
totalRuns: 0,
|
||||
lastRunAt: null,
|
||||
lastError: null,
|
||||
activeExecutionID: null,
|
||||
monthlySpend: 0,
|
||||
nextScheduledRun: null,
|
||||
triggerType: null,
|
||||
};
|
||||
|
||||
export function getAgentStatus(
|
||||
statusMap: Map<string, AgentStatusInfo>,
|
||||
graphID: string,
|
||||
): AgentStatusInfo {
|
||||
return statusMap.get(graphID) ?? DEFAULT_STATUS;
|
||||
}
|
||||
|
||||
export function useFleetSummary(agents: LibraryAgent[]): FleetSummary {
|
||||
const { data: executions } = useGetV1ListAllExecutions({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
return useMemo(() => {
|
||||
const counts: FleetSummary = {
|
||||
running: 0,
|
||||
error: 0,
|
||||
completed: 0,
|
||||
listening: 0,
|
||||
scheduled: 0,
|
||||
idle: 0,
|
||||
monthlySpend: 0,
|
||||
};
|
||||
|
||||
const activeGraphIds = new Set<string>();
|
||||
const errorGraphIds = new Set<string>();
|
||||
const completedGraphIds = new Set<string>();
|
||||
|
||||
if (executions) {
|
||||
const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
|
||||
for (const exec of executions) {
|
||||
if (isActive(exec.status)) {
|
||||
activeGraphIds.add(exec.graph_id);
|
||||
}
|
||||
const endedTs = exec.ended_at
|
||||
? new Date(
|
||||
exec.ended_at instanceof Date
|
||||
? exec.ended_at.getTime()
|
||||
: exec.ended_at,
|
||||
).getTime()
|
||||
: 0;
|
||||
if (isFailed(exec.status) && endedTs > cutoff) {
|
||||
errorGraphIds.add(exec.graph_id);
|
||||
}
|
||||
if (
|
||||
exec.status === AgentExecutionStatus.COMPLETED &&
|
||||
endedTs > cutoff
|
||||
) {
|
||||
completedGraphIds.add(exec.graph_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const agent of agents) {
|
||||
if (activeGraphIds.has(agent.graph_id)) {
|
||||
counts.running += 1;
|
||||
} else if (errorGraphIds.has(agent.graph_id)) {
|
||||
counts.error += 1;
|
||||
} else if (agent.has_external_trigger) {
|
||||
counts.listening += 1;
|
||||
} else if (agent.is_scheduled || agent.recommended_schedule_cron) {
|
||||
counts.scheduled += 1;
|
||||
} else {
|
||||
counts.idle += 1;
|
||||
}
|
||||
if (completedGraphIds.has(agent.graph_id)) {
|
||||
counts.completed += 1;
|
||||
}
|
||||
}
|
||||
|
||||
return counts;
|
||||
}, [agents, executions]);
|
||||
}
|
||||
|
||||
export { deriveHealth };
|
||||
@@ -1,116 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
getGetV1ListAllExecutionsQueryKey,
|
||||
useGetV1ListAllExecutions,
|
||||
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useExecutionEvents } from "@/hooks/useExecutionEvents";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useCallback, useMemo } from "react";
|
||||
import type { FleetSummary } from "../types";
|
||||
import { isActive, isFailed, SEVENTY_TWO_HOURS_MS } from "./executionHelpers";
|
||||
|
||||
function isRecentFailure(
|
||||
status: string,
|
||||
endedAt?: string | Date | null,
|
||||
): boolean {
|
||||
if (!isFailed(status)) return false;
|
||||
if (!endedAt) return false;
|
||||
const ts =
|
||||
endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime();
|
||||
return Date.now() - ts < SEVENTY_TWO_HOURS_MS;
|
||||
}
|
||||
|
||||
function isRecentCompletion(
|
||||
status: string,
|
||||
endedAt?: string | Date | null,
|
||||
): boolean {
|
||||
if (status !== AgentExecutionStatus.COMPLETED) return false;
|
||||
if (!endedAt) return false;
|
||||
const ts =
|
||||
endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime();
|
||||
return Date.now() - ts < SEVENTY_TWO_HOURS_MS;
|
||||
}
|
||||
|
||||
export function useLibraryFleetSummary(
|
||||
agents: LibraryAgent[],
|
||||
): FleetSummary | undefined {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { data: executions, isSuccess } = useGetV1ListAllExecutions({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
const graphIDs = useMemo(() => agents.map((a) => a.graph_id), [agents]);
|
||||
|
||||
const handleExecutionUpdate = useCallback(() => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV1ListAllExecutionsQueryKey(),
|
||||
});
|
||||
}, [queryClient]);
|
||||
|
||||
useExecutionEvents({
|
||||
graphIds: graphIDs.length > 0 ? graphIDs : undefined,
|
||||
enabled: graphIDs.length > 0,
|
||||
onExecutionUpdate: handleExecutionUpdate,
|
||||
});
|
||||
|
||||
return useMemo(() => {
|
||||
if (!isSuccess || !executions) return undefined;
|
||||
|
||||
const agentsWithActiveExecution = new Set<string>();
|
||||
const agentsWithRecentFailure = new Set<string>();
|
||||
const agentsWithRecentCompletion = new Set<string>();
|
||||
|
||||
for (const exec of executions) {
|
||||
if (isActive(exec.status)) {
|
||||
agentsWithActiveExecution.add(exec.graph_id);
|
||||
}
|
||||
if (isRecentFailure(exec.status, exec.ended_at)) {
|
||||
agentsWithRecentFailure.add(exec.graph_id);
|
||||
}
|
||||
if (isRecentCompletion(exec.status, exec.ended_at)) {
|
||||
agentsWithRecentCompletion.add(exec.graph_id);
|
||||
}
|
||||
}
|
||||
|
||||
const summary: FleetSummary = {
|
||||
running: 0,
|
||||
error: 0,
|
||||
completed: 0,
|
||||
listening: 0,
|
||||
scheduled: 0,
|
||||
idle: 0,
|
||||
monthlySpend: 0,
|
||||
};
|
||||
|
||||
for (const agent of agents) {
|
||||
if (agentsWithActiveExecution.has(agent.graph_id)) {
|
||||
summary.running += 1;
|
||||
} else if (agentsWithRecentFailure.has(agent.graph_id)) {
|
||||
summary.error += 1;
|
||||
} else if (agent.has_external_trigger) {
|
||||
summary.listening += 1;
|
||||
} else if (agent.is_scheduled || agent.recommended_schedule_cron) {
|
||||
summary.scheduled += 1;
|
||||
} else {
|
||||
summary.idle += 1;
|
||||
}
|
||||
// Parallel counter: mutually exclusive with running/error (which match
|
||||
// the sitrep priority order used by the "Recently completed" tab list)
|
||||
// but orthogonal to listening/scheduled/idle.
|
||||
if (
|
||||
!agentsWithActiveExecution.has(agent.graph_id) &&
|
||||
!agentsWithRecentFailure.has(agent.graph_id) &&
|
||||
agentsWithRecentCompletion.has(agent.graph_id)
|
||||
) {
|
||||
summary.completed += 1;
|
||||
}
|
||||
}
|
||||
|
||||
return summary;
|
||||
}, [agents, executions, isSuccess]);
|
||||
}
|
||||
@@ -2,14 +2,12 @@
|
||||
|
||||
import { useEffect, useState, useCallback } from "react";
|
||||
import { HeartIcon, ListIcon } from "@phosphor-icons/react";
|
||||
import { JumpBackIn } from "./components/JumpBackIn/JumpBackIn";
|
||||
import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryActionHeader";
|
||||
import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList";
|
||||
import { useLibraryListPage } from "./components/useLibraryListPage";
|
||||
import { FavoriteAnimationProvider } from "./context/FavoriteAnimationContext";
|
||||
import type { LibraryTab, AgentStatusFilter } from "./types";
|
||||
import { useLibraryFleetSummary } from "./hooks/useLibraryFleetSummary";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
|
||||
import { LibraryTab } from "./types";
|
||||
|
||||
const LIBRARY_TABS: LibraryTab[] = [
|
||||
{ id: "all", title: "All", icon: ListIcon },
|
||||
@@ -21,10 +19,6 @@ export default function LibraryPage() {
|
||||
useLibraryListPage();
|
||||
const [selectedFolderId, setSelectedFolderId] = useState<string | null>(null);
|
||||
const [activeTab, setActiveTab] = useState(LIBRARY_TABS[0].id);
|
||||
const [statusFilter, setStatusFilter] = useState<AgentStatusFilter>("all");
|
||||
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
|
||||
const { agents } = useLibraryAgents();
|
||||
const fleetSummary = useLibraryFleetSummary(agents);
|
||||
|
||||
useEffect(() => {
|
||||
document.title = "Library – AutoGPT Platform";
|
||||
@@ -46,6 +40,7 @@ export default function LibraryPage() {
|
||||
>
|
||||
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
|
||||
<LibraryActionHeader setSearchTerm={setSearchTerm} />
|
||||
<JumpBackIn />
|
||||
<LibraryAgentList
|
||||
searchTerm={searchTerm}
|
||||
librarySort={librarySort}
|
||||
@@ -55,10 +50,6 @@ export default function LibraryPage() {
|
||||
tabs={LIBRARY_TABS}
|
||||
activeTab={activeTab}
|
||||
onTabChange={handleTabChange}
|
||||
statusFilter={statusFilter}
|
||||
onStatusFilterChange={setStatusFilter}
|
||||
fleetSummary={isAgentBriefingEnabled ? fleetSummary : undefined}
|
||||
briefingAgents={isAgentBriefingEnabled ? agents : undefined}
|
||||
/>
|
||||
</main>
|
||||
</FavoriteAnimationProvider>
|
||||
|
||||
@@ -1,76 +1,7 @@
|
||||
import type { Icon } from "@phosphor-icons/react";
|
||||
import { Icon } from "@phosphor-icons/react";
|
||||
|
||||
export interface LibraryTab {
|
||||
id: string;
|
||||
title: string;
|
||||
icon: Icon;
|
||||
}
|
||||
|
||||
/** Agent execution status — drives StatusBadge visuals & filtering. */
|
||||
export type AgentStatus =
|
||||
| "running"
|
||||
| "error"
|
||||
| "listening"
|
||||
| "scheduled"
|
||||
| "idle";
|
||||
|
||||
/** Derived health bucket for quick triage. */
|
||||
export type AgentHealth = "good" | "attention" | "stale";
|
||||
|
||||
/** Real-time metadata that powers the Intelligence Layer features. */
|
||||
export interface AgentStatusInfo {
|
||||
status: AgentStatus;
|
||||
health: AgentHealth;
|
||||
/** 0-100 progress for currently running agents. */
|
||||
progress: number | null;
|
||||
totalRuns: number;
|
||||
lastRunAt: string | null;
|
||||
lastError: string | null;
|
||||
/** ID of the currently active execution (when status is "running"). */
|
||||
activeExecutionID: string | null;
|
||||
monthlySpend: number;
|
||||
nextScheduledRun: string | null;
|
||||
triggerType: string | null;
|
||||
}
|
||||
|
||||
/** Fleet-wide aggregate counts used by the Briefing Panel stats grid. */
|
||||
export interface FleetSummary {
|
||||
running: number;
|
||||
error: number;
|
||||
completed: number;
|
||||
listening: number;
|
||||
scheduled: number;
|
||||
idle: number;
|
||||
monthlySpend: number;
|
||||
}
|
||||
|
||||
export type SitrepPriority =
|
||||
| "error"
|
||||
| "running"
|
||||
| "stale"
|
||||
| "success"
|
||||
| "listening"
|
||||
| "scheduled"
|
||||
| "idle";
|
||||
|
||||
export interface SitrepItemData {
|
||||
id: string;
|
||||
agentID: string;
|
||||
agentName: string;
|
||||
agentImageUrl?: string | null;
|
||||
executionID?: string;
|
||||
priority: SitrepPriority;
|
||||
message: string;
|
||||
status: AgentStatus;
|
||||
}
|
||||
|
||||
/** Filter options for the agent filter dropdown. */
|
||||
export type AgentStatusFilter =
|
||||
| "all"
|
||||
| "running"
|
||||
| "attention"
|
||||
| "completed"
|
||||
| "listening"
|
||||
| "scheduled"
|
||||
| "idle"
|
||||
| "healthy";
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user