mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Merge remote-tracking branch 'origin/dev' into feat/task-decomposition-copilot
# Conflicts: # autogpt_platform/backend/backend/copilot/model.py
This commit is contained in:
@@ -18,7 +18,6 @@ from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.message_dedup import acquire_dedup_lock
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -192,6 +191,8 @@ class SessionDetailResponse(BaseModel):
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
has_more_messages: bool = False
|
||||
oldest_sequence: int | None = None
|
||||
newest_sequence: int | None = None
|
||||
forward_paginated: bool = False
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
@@ -456,52 +457,113 @@ async def update_session_title_route(
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
before_sequence: int | None = Query(default=None, ge=0),
|
||||
limit: int = Query(
|
||||
default=50,
|
||||
ge=1,
|
||||
le=200,
|
||||
description="Maximum number of messages to return.",
|
||||
),
|
||||
before_sequence: int | None = Query(
|
||||
default=None,
|
||||
ge=0,
|
||||
description=(
|
||||
"Backward pagination cursor. Return messages with sequence number "
|
||||
"strictly less than this value. Used by active-session load-more. "
|
||||
"Mutually exclusive with after_sequence."
|
||||
),
|
||||
),
|
||||
after_sequence: int | None = Query(
|
||||
default=None,
|
||||
ge=0,
|
||||
description=(
|
||||
"Forward pagination cursor. Return messages with sequence number "
|
||||
"strictly greater than this value. Used by completed-session load-more. "
|
||||
"Mutually exclusive with before_sequence."
|
||||
),
|
||||
),
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
|
||||
When no pagination params are provided, returns the most recent messages.
|
||||
Supports cursor-based pagination via ``limit``, ``before_sequence``, and
|
||||
``after_sequence``. The two cursor parameters are mutually exclusive.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of messages to return (1-200, default 50).
|
||||
before_sequence: Return messages with sequence < this value (cursor).
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, including
|
||||
active_stream info and pagination metadata.
|
||||
On the initial load (no cursor provided) of a completed session, messages
|
||||
are returned in forward order starting from sequence 0 so the user always
|
||||
sees their initial prompt. Active sessions use the legacy newest-first
|
||||
order so streaming context is preserved.
|
||||
"""
|
||||
if before_sequence is not None and after_sequence is not None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="before_sequence and after_sequence are mutually exclusive",
|
||||
)
|
||||
|
||||
is_initial_load = before_sequence is None and after_sequence is None
|
||||
|
||||
# Check active stream before the DB query on initial loads so we can
|
||||
# choose the correct pagination direction (forward for completed sessions,
|
||||
# newest-first for active ones).
|
||||
active_session = None
|
||||
last_message_id = None
|
||||
if is_initial_load:
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
# Completed sessions on initial load start from sequence 0 so the user's
|
||||
# initial prompt is always visible. Active sessions keep the legacy
|
||||
# newest-first behavior to preserve streaming context.
|
||||
from_start = is_initial_load and active_session is None
|
||||
forward_paginated = from_start or after_sequence is not None
|
||||
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id, limit, before_sequence, user_id=user_id
|
||||
session_id,
|
||||
limit,
|
||||
before_sequence=before_sequence,
|
||||
after_sequence=after_sequence,
|
||||
from_start=from_start,
|
||||
user_id=user_id,
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
# Close the TOCTOU window: if the session was active at pre-check, re-verify
|
||||
# after the DB fetch. The session may have completed between the two awaits,
|
||||
# which would have caused messages to be fetched newest-first even though the
|
||||
# session is now complete. Re-fetch from seq 0 so the initial prompt is
|
||||
# always visible.
|
||||
if is_initial_load and active_session is not None:
|
||||
post_active, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||
if post_active is None:
|
||||
active_session = None
|
||||
last_message_id = None
|
||||
from_start = True
|
||||
forward_paginated = True
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id,
|
||||
limit,
|
||||
before_sequence=None,
|
||||
after_sequence=None,
|
||||
from_start=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [
|
||||
_strip_injected_context(message.model_dump()) for message in page.messages
|
||||
]
|
||||
|
||||
# Only check active stream on initial load (not on "load more" requests)
|
||||
active_stream_info = None
|
||||
if before_sequence is None:
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
if active_session and last_message_id is not None:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Skip session metadata on "load more" — frontend only needs messages
|
||||
if before_sequence is not None:
|
||||
if not is_initial_load:
|
||||
return SessionDetailResponse(
|
||||
id=page.session.session_id,
|
||||
created_at=page.session.started_at.isoformat(),
|
||||
@@ -511,6 +573,8 @@ async def get_session(
|
||||
active_stream=None,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
newest_sequence=page.newest_sequence,
|
||||
forward_paginated=forward_paginated,
|
||||
total_prompt_tokens=0,
|
||||
total_completion_tokens=0,
|
||||
)
|
||||
@@ -527,6 +591,8 @@ async def get_session(
|
||||
active_stream=active_stream_info,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
newest_sequence=page.newest_sequence,
|
||||
forward_paginated=forward_paginated,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
metadata=page.session.metadata,
|
||||
@@ -872,9 +938,6 @@ async def stream_chat_post(
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
# Capture the original message text BEFORE any mutation (attachment enrichment)
|
||||
# so the idempotency hash is stable across retries.
|
||||
original_message = request.message
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
@@ -903,58 +966,36 @@ async def stream_chat_post(
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# ── Idempotency guard ────────────────────────────────────────────────────
|
||||
# Blocks duplicate executor tasks from concurrent/retried POSTs.
|
||||
# See backend/copilot/message_dedup.py for the full lifecycle description.
|
||||
dedup_lock = None
|
||||
if request.is_user_message:
|
||||
dedup_lock = await acquire_dedup_lock(
|
||||
session_id, original_message, sanitized_file_ids
|
||||
)
|
||||
if dedup_lock is None and (original_message or sanitized_file_ids):
|
||||
|
||||
async def _empty_sse() -> AsyncGenerator[str, None]:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_empty_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
# message loss from concurrent requests.
|
||||
#
|
||||
# If any of these operations raises, release the dedup lock before propagating
|
||||
# so subsequent retries are not blocked for 30 s.
|
||||
try:
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
# saved yet. append_and_save_message returns None when a duplicate is
|
||||
# detected — in that case skip enqueue to avoid processing the message twice.
|
||||
is_duplicate_message = False
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
is_duplicate_message = (
|
||||
await append_and_save_message(session_id, message)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
) is None
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
if not is_duplicate_message and request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
# Create a task in the stream registry for reconnection support.
|
||||
# For duplicate messages, skip create_session entirely so the infra-retry
|
||||
# client subscribes to the *existing* turn's Redis stream and receives the
|
||||
# in-progress executor output rather than an empty stream.
|
||||
turn_id = ""
|
||||
if not is_duplicate_message:
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
@@ -972,7 +1013,6 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
@@ -984,10 +1024,10 @@ async def stream_chat_post(
|
||||
mode=request.mode,
|
||||
model=request.model,
|
||||
)
|
||||
except Exception:
|
||||
if dedup_lock:
|
||||
await dedup_lock.release()
|
||||
raise
|
||||
else:
|
||||
logger.info(
|
||||
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
@@ -1011,12 +1051,6 @@ async def stream_chat_post(
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
# True for every exit path except GeneratorExit (client disconnect).
|
||||
# On disconnect the backend turn is still running — releasing the lock
|
||||
# there would reopen the infra-retry duplicate window. The 30 s TTL
|
||||
# is the fallback. All other exits (normal finish, early return, error)
|
||||
# should release so the user can re-send the same message.
|
||||
release_dedup_lock_on_exit = True
|
||||
try:
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
@@ -1028,7 +1062,7 @@ async def stream_chat_post(
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
return # finally releases dedup_lock
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
@@ -1070,7 +1104,7 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
break # finally releases dedup_lock
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
@@ -1086,7 +1120,6 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
release_dedup_lock_on_exit = False
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
@@ -1101,10 +1134,7 @@ async def stream_chat_post(
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
# finally releases dedup_lock
|
||||
finally:
|
||||
if dedup_lock and release_dedup_lock_on_exit:
|
||||
await dedup_lock.release()
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
|
||||
@@ -133,21 +133,12 @@ def test_stream_chat_rejects_too_many_file_ids():
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
redis_set_returns: object = True,
|
||||
):
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ.
|
||||
|
||||
Args:
|
||||
redis_set_returns: Value returned by the mocked Redis ``set`` call.
|
||||
``True`` (default) simulates a fresh key (new message);
|
||||
``None`` simulates a collision (duplicate blocked).
|
||||
validation and enrichment logic without needing RabbitMQ.
|
||||
|
||||
Returns:
|
||||
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
|
||||
A namespace with ``save`` and ``enqueue`` mock objects so
|
||||
callers can make additional assertions about side-effects.
|
||||
"""
|
||||
import types
|
||||
@@ -158,7 +149,7 @@ def _mock_stream_internals(
|
||||
)
|
||||
mock_save = mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
@@ -174,15 +165,9 @@ def _mock_stream_internals(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=redis_set_returns)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
return types.SimpleNamespace(
|
||||
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
|
||||
)
|
||||
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
|
||||
return ns
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
@@ -211,6 +196,29 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ─── Duplicate message dedup ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_skips_enqueue_for_duplicate_message(
    mocker: pytest_mock.MockerFixture,
):
    """A duplicate message must not start a second turn.

    ``append_and_save_message`` returning ``None`` signals that the message
    was a duplicate. In that case neither ``enqueue_copilot_turn`` nor
    ``stream_registry.create_session`` may be called: the former would
    double-process the message, the latter would overwrite the active
    stream's turn_id in Redis and make reconnecting clients miss the response.
    """
    internals = _mock_stream_internals(mocker)
    # A None result from the save mock is the "duplicate" signal.
    internals.save.return_value = None

    payload = {"message": "hello"}
    resp = client.post("/sessions/sess-1/stream", json=payload)

    assert resp.status_code == 200
    internals.enqueue.assert_not_called()
    internals.registry.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -706,237 +714,6 @@ class TestStripInjectedContext:
|
||||
assert result["content"] == "hello"
|
||||
|
||||
|
||||
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A repeated POST inside the 30-s dedup window yields an empty SSE stream.

    The stream must consist of StreamFinish plus the ``[DONE]`` terminator so
    the frontend marks the turn complete without creating a ghost response.
    """
    # The mocked Redis ``set`` returns None, i.e. the NX key already exists.
    internals = _mock_stream_internals(mocker, redis_set_returns=None)

    payload = {"message": "duplicate message", "is_user_message": True}
    resp = client.post("/sessions/sess-dup/stream", json=payload)

    assert resp.status_code == 200

    # StreamFinish (type=finish) and the SSE terminator must both be present.
    body = resp.text
    for marker in ('"finish"', "[DONE]"):
        assert marker in body

    # The AI SDK protocol header lets the frontend treat the empty response
    # as a valid stream and mark the turn complete.
    assert resp.headers.get("x-vercel-ai-ui-message-stream") == "v1"

    # A blocked duplicate must have no save/enqueue side effects.
    internals.save.assert_not_called()
    internals.enqueue.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_first_post_proceeds_normally(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """The first POST takes the normal streaming path with no early return.

    "First" here means the mocked Redis ``set`` reports that the NX key was
    stored successfully.
    """
    internals = _mock_stream_internals(mocker, redis_set_returns=True)
    redis_set = internals.redis.set

    payload = {"message": "first message", "is_user_message": True}
    resp = client.post("/sessions/sess-new/stream", json=payload)

    assert resp.status_code == 200
    # Exactly one Redis ``set`` call, and it must carry the NX flag.
    redis_set.assert_called_once()
    assert redis_set.call_args.kwargs.get("nx") is True
|
||||
|
||||
|
||||
def test_stream_chat_dedup_skipped_for_non_user_messages(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Messages sent with ``is_user_message=False`` bypass the dedup guard.

    System/assistant messages are injected programmatically and must always
    be processed.
    """
    # redis_set_returns=None would block a user message; it must be
    # irrelevant here because the guard is never consulted.
    internals = _mock_stream_internals(mocker, redis_set_returns=None)

    payload = {"message": "system context", "is_user_message": False}
    resp = client.post("/sessions/sess-sys/stream", json=payload)

    assert resp.status_code == 200
    internals.redis.set.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """The dedup key is derived from the message as sent, not as enriched.

    A file_id is included so the route really appends the [Attached files]
    block to the message. The Redis key must still hash the original text
    together with the file id.
    """
    import hashlib

    file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
    internals = _mock_stream_internals(mocker, redis_set_returns=True)

    # Stub the workspace lookup and the prisma file query so the attachment
    # block is actually appended by the route.
    workspace_stub = type("W", (), {"id": "ws-1"})()
    mocker.patch(
        "backend.api.features.chat.routes.get_or_create_workspace",
        return_value=workspace_stub,
    )
    file_attrs = {
        "id": file_id,
        "name": "doc.pdf",
        "mimeType": "application/pdf",
        "sizeBytes": 1024,
    }
    file_stub = type("F", (), file_attrs)()
    prisma_stub = mocker.MagicMock()
    prisma_stub.find_many = mocker.AsyncMock(return_value=[file_stub])
    mocker.patch(
        "prisma.models.UserWorkspaceFile.prisma",
        return_value=prisma_stub,
    )

    payload = {
        "message": "plain message",
        "is_user_message": True,
        "file_ids": [file_id],
    }
    resp = client.post("/sessions/sess-hash/stream", json=payload)

    assert resp.status_code == 200
    redis_set = internals.redis.set
    redis_set.assert_called_once()
    dedup_key = redis_set.call_args.args[0]

    # The expected key hashes the original message plus the file id, never
    # the mutated text.
    digest = hashlib.sha256(f"sess-hash:plain message:{file_id}".encode()).hexdigest()
    expected_hash = digest[:16]
    expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
    assert dedup_key == expected_key, (
        f"Dedup key {dedup_key!r} does not match expected {expected_key!r} — "
        "hash may be using mutated message or wrong inputs"
    )
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_after_stream_finish(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """The dedup Redis key is deleted once the turn completes.

    ``subscribe_to_session`` is mocked to return ``None``, the case where the
    route is expected to yield StreamFinish immediately. The key must then
    be released so the user can re-send the same message.
    """
    from unittest.mock import AsyncMock as _AsyncMock

    # Everything is patched by hand here (instead of through the shared
    # helper) so that subscribe_to_session can be controlled.
    for target in (
        "backend.api.features.chat.routes._validate_and_get_session",
        "backend.api.features.chat.routes.append_and_save_message",
        "backend.api.features.chat.routes.enqueue_copilot_turn",
        "backend.api.features.chat.routes.track_user_message",
    ):
        mocker.patch(target, return_value=None)

    registry_stub = mocker.MagicMock()
    registry_stub.create_session = _AsyncMock(return_value=None)
    # None -> early-finish path: StreamFinish is yielded immediately.
    registry_stub.subscribe_to_session = _AsyncMock(return_value=None)
    mocker.patch(
        "backend.api.features.chat.routes.stream_registry",
        registry_stub,
    )

    redis_stub = mocker.AsyncMock()
    redis_stub.set = _AsyncMock(return_value=True)
    mocker.patch(
        "backend.copilot.message_dedup.get_redis_async",
        new_callable=_AsyncMock,
        return_value=redis_stub,
    )

    payload = {"message": "hello", "is_user_message": True}
    resp = client.post("/sessions/sess-finish/stream", json=payload)

    assert resp.status_code == 200
    assert '"finish"' in resp.text
    # Releasing the key is what allows an intentional re-send.
    redis_stub.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A failing dedup-key delete must not crash the route.

    Exercises the early-finish path (``subscribe_to_session`` returns
    ``None``) with a Redis ``delete`` that raises. The request must still
    succeed and the delete must have been attempted.
    """
    from unittest.mock import AsyncMock as _AsyncMock

    for target in (
        "backend.api.features.chat.routes._validate_and_get_session",
        "backend.api.features.chat.routes.append_and_save_message",
        "backend.api.features.chat.routes.enqueue_copilot_turn",
        "backend.api.features.chat.routes.track_user_message",
    ):
        mocker.patch(target, return_value=None)

    registry_stub = mocker.MagicMock()
    registry_stub.create_session = _AsyncMock(return_value=None)
    registry_stub.subscribe_to_session = _AsyncMock(return_value=None)
    mocker.patch(
        "backend.api.features.chat.routes.stream_registry",
        registry_stub,
    )

    redis_stub = mocker.AsyncMock()
    redis_stub.set = _AsyncMock(return_value=True)
    # The delete raises so the error-swallowing branch is the one exercised.
    redis_stub.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
    mocker.patch(
        "backend.copilot.message_dedup.get_redis_async",
        new_callable=_AsyncMock,
        return_value=redis_stub,
    )

    # The request itself must complete despite the failing delete.
    payload = {"message": "hello", "is_user_message": True}
    resp = client.post("/sessions/sess-finish-err/stream", json=payload)

    assert resp.status_code == 200
    assert '"finish"' in resp.text
    # The delete was attempted; its error was silenced.
    redis_stub.delete.assert_called_once()
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
|
||||
@@ -980,3 +757,146 @@ def test_disconnect_stream_returns_404_when_session_missing(
|
||||
|
||||
assert response.status_code == 404
|
||||
mock_disconnect.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── GET /sessions/{session_id} — forward/backward pagination ──────────────────
|
||||
|
||||
|
||||
def _make_paginated_messages(
    mocker: pytest_mock.MockerFixture, *, has_more: bool = False
):
    """Build a one-message page and patch the route's DB lookup to return it.

    Args:
        mocker: pytest-mock fixture used to install the patch.
        has_more: Value stored on the page's ``has_more`` field.

    Returns:
        A ``(page, mock_paginate)`` tuple: the ``PaginatedMessages`` instance
        and the ``AsyncMock`` that replaces ``get_chat_messages_paginated`` in
        the routes module.
    """
    from datetime import UTC, datetime

    from backend.copilot.db import PaginatedMessages
    from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata

    timestamp = datetime.now(UTC)
    session = ChatSessionInfo(
        session_id="sess-1",
        user_id=TEST_USER_ID,
        usage=[],
        started_at=timestamp,
        updated_at=timestamp,
        metadata=ChatSessionMetadata(),
    )
    only_message = ChatMessage(role="user", content="hello", sequence=0)
    page = PaginatedMessages(
        messages=[only_message],
        has_more=has_more,
        oldest_sequence=0,
        newest_sequence=0,
        session=session,
    )

    paginate_mock = mocker.patch(
        "backend.api.features.chat.routes.get_chat_messages_paginated",
        new_callable=AsyncMock,
        return_value=page,
    )
    return page, paginate_mock
|
||||
|
||||
|
||||
def test_get_session_completed_returns_forward_paginated(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """With no active stream, the response reports forward_paginated=True."""
    _make_paginated_messages(mocker)
    # (None, None) from the registry means no stream is running.
    mocker.patch(
        "backend.api.features.chat.routes.stream_registry.get_active_session",
        new_callable=AsyncMock,
        return_value=(None, None),
    )

    resp = client.get("/sessions/sess-1")

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["forward_paginated"] is True
    assert payload["newest_sequence"] == 0
|
||||
|
||||
|
||||
def test_get_session_active_returns_backward_paginated(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """With a running stream, the response reports forward_paginated=False.

    The active stream's turn_id must also be surfaced in ``active_stream``.
    """
    from backend.copilot.stream_registry import ActiveSession

    _make_paginated_messages(mocker)

    running = MagicMock(spec=ActiveSession)
    running.turn_id = "turn-1"
    mocker.patch(
        "backend.api.features.chat.routes.stream_registry.get_active_session",
        new_callable=AsyncMock,
        return_value=(running, "msg-1"),
    )

    resp = client.get("/sessions/sess-1")

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["forward_paginated"] is False
    stream_info = payload["active_stream"]
    assert stream_info is not None
    assert stream_info["turn_id"] == "turn-1"
|
||||
|
||||
|
||||
def test_get_session_after_sequence_returns_forward_paginated(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """An ``after_sequence`` cursor yields forward_paginated=True.

    No active-stream mock is installed for this request. The cursor must be
    forwarded to the DB call as ``after_sequence`` with ``before_sequence``
    left as ``None``.
    """
    _, paginate_mock = _make_paginated_messages(mocker)

    resp = client.get("/sessions/sess-1?after_sequence=10")

    assert resp.status_code == 200
    assert resp.json()["forward_paginated"] is True

    passed = paginate_mock.call_args.kwargs
    assert passed.get("after_sequence") == 10
    assert passed.get("before_sequence") is None
|
||||
|
||||
|
||||
def test_get_session_both_cursors_returns_400(
    test_user_id: str,
) -> None:
    """Supplying before_sequence and after_sequence together is a 400."""
    url = "/sessions/sess-1?before_sequence=5&after_sequence=10"
    resp = client.get(url)

    assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_get_session_toctou_refetch_when_session_completes_mid_request(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """The route re-fetches from the start when a session completes mid-request.

    The first active-session lookup reports a running stream, the second
    reports none. The route is expected to notice the change on its
    post-fetch re-check and query the DB again with ``from_start=True`` so
    the initial prompt is always visible.
    """
    from backend.copilot.stream_registry import ActiveSession

    _, paginate_mock = _make_paginated_messages(mocker)

    running = MagicMock(spec=ActiveSession)
    running.turn_id = "turn-1"

    # Lookup #1 sees an active session; lookup #2 sees it completed.
    active_lookup = mocker.patch(
        "backend.api.features.chat.routes.stream_registry.get_active_session",
        new_callable=AsyncMock,
        side_effect=[(running, "msg-1"), (None, None)],
    )

    resp = client.get("/sessions/sess-1")

    assert resp.status_code == 200
    payload = resp.json()
    # After the race the session counts as completed: forward order, no stream.
    assert payload["forward_paginated"] is True
    assert payload["active_stream"] is None

    # Two DB queries (newest-first, then from the start) and two lookups.
    assert paginate_mock.call_count == 2
    assert active_lookup.call_count == 2
    refetch = paginate_mock.call_args_list[1]
    assert refetch.kwargs.get("from_start") is True
|
||||
|
||||
@@ -43,6 +43,25 @@ config = Config()
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
|
||||
"""Fetch execution counts per graph in a single batched query."""
|
||||
if not graph_ids:
|
||||
return {}
|
||||
rows = await prisma.models.AgentGraphExecution.prisma().group_by(
|
||||
by=["agentGraphId"],
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": {"in": graph_ids},
|
||||
"isDeleted": False,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
return {
|
||||
row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
|
||||
async def list_library_agents(
|
||||
user_id: str,
|
||||
search_term: Optional[str] = None,
|
||||
@@ -137,12 +156,18 @@ async def list_library_agents(
|
||||
|
||||
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts = await _fetch_execution_counts(user_id, graph_ids)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -214,12 +239,18 @@ async def list_favorite_library_agents(
|
||||
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
|
||||
execution_counts = await _fetch_execution_counts(user_id, graph_ids)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent,
|
||||
execution_count_override=execution_counts.get(agent.agentGraphId),
|
||||
)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
|
||||
@@ -65,6 +65,11 @@ async def test_get_library_agents(mocker):
|
||||
)
|
||||
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db._fetch_execution_counts",
|
||||
new=mocker.AsyncMock(return_value={}),
|
||||
)
|
||||
|
||||
# Call function
|
||||
result = await db.list_library_agents("test-user")
|
||||
|
||||
@@ -353,3 +358,136 @@ async def test_create_library_agent_uses_upsert():
|
||||
# Verify update branch restores soft-deleted/archived agents
|
||||
assert data["update"]["isDeleted"] is False
|
||||
assert data["update"]["isArchived"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_favorite_library_agents(mocker):
    """Listing favorites returns the agent, its pagination and its run count.

    The Prisma ``LibraryAgent`` client is mocked to return one favorite
    agent, and ``_fetch_execution_counts`` is stubbed to report 7 executions
    for its graph. The test checks the mapped fields, the pagination
    metadata, and that the batched count is surfaced on the result.
    """
    mock_library_agents = [
        prisma.models.LibraryAgent(
            id="fav1",
            userId="test-user",
            agentGraphId="agent-fav",
            settings="{}",  # type: ignore
            agentGraphVersion=1,
            isCreatedByUser=False,
            isDeleted=False,
            isArchived=False,
            createdAt=datetime.now(),
            updatedAt=datetime.now(),
            isFavorite=True,
            useGraphIsActiveVersion=True,
            AgentGraph=prisma.models.AgentGraph(
                id="agent-fav",
                version=1,
                name="Favorite Agent",
                description="My Favorite",
                userId="other-user",
                isActive=True,
                createdAt=datetime.now(),
            ),
        )
    ]

    mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
    mock_library_agent.return_value.find_many = mocker.AsyncMock(
        return_value=mock_library_agents
    )
    mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)

    mocker.patch(
        "backend.api.features.library.db._fetch_execution_counts",
        new=mocker.AsyncMock(return_value={"agent-fav": 7}),
    )

    result = await db.list_favorite_library_agents("test-user")

    assert len(result.agents) == 1
    assert result.agents[0].id == "fav1"
    assert result.agents[0].name == "Favorite Agent"
    assert result.agents[0].graph_id == "agent-fav"
    # The stubbed batched count must reach the agent; without this check the
    # ``{"agent-fav": 7}`` stub above is never actually verified.
    assert result.agents[0].execution_count == 7
    assert result.pagination.total_items == 1
    assert result.pagination.total_pages == 1
    assert result.pagination.current_page == 1
    assert result.pagination.page_size == 50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_library_agents_skips_failed_agent(mocker):
    """An agent whose conversion raises is dropped from the listing.

    ``LibraryAgent.from_db`` is patched to raise, which exercises the
    ``except`` branch: the agent list ends up empty while the pagination
    total still reflects the mocked DB count.
    """
    broken_graph = prisma.models.AgentGraph(
        id="agent-bad",
        version=1,
        name="Bad Agent",
        description="",
        userId="other-user",
        isActive=True,
        createdAt=datetime.now(),
    )
    broken_agent = prisma.models.LibraryAgent(
        id="ua-bad",
        userId="test-user",
        agentGraphId="agent-bad",
        settings="{}",  # type: ignore
        agentGraphVersion=1,
        isCreatedByUser=False,
        isDeleted=False,
        isArchived=False,
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
        isFavorite=False,
        useGraphIsActiveVersion=True,
        AgentGraph=broken_graph,
    )

    library_client = mocker.patch("prisma.models.LibraryAgent.prisma").return_value
    library_client.find_many = mocker.AsyncMock(return_value=[broken_agent])
    library_client.count = mocker.AsyncMock(return_value=1)

    mocker.patch(
        "backend.api.features.library.db._fetch_execution_counts",
        new=mocker.AsyncMock(return_value={}),
    )
    # Force every conversion attempt to fail.
    mocker.patch(
        "backend.api.features.library.model.LibraryAgent.from_db",
        side_effect=Exception("parse error"),
    )

    listing = await db.list_library_agents("test-user")

    assert len(listing.agents) == 0
    assert listing.pagination.total_items == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_execution_counts_empty_graph_ids():
    """An empty graph-ID list yields an empty mapping."""
    counts = await db._fetch_execution_counts("user-1", [])
    assert counts == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_execution_counts_uses_group_by(mocker):
    """Counts come from a single ``group_by`` call with the expected filter.

    The mocked query returns rows for two of the three requested graphs;
    the resulting mapping contains exactly those two.
    """
    group_by_mock = mocker.AsyncMock(
        return_value=[
            {"agentGraphId": "graph-1", "_count": {"_all": 5}},
            {"agentGraphId": "graph-2", "_count": {"_all": 2}},
        ]
    )
    execution_client = mocker.patch("prisma.models.AgentGraphExecution.prisma")
    execution_client.return_value.group_by = group_by_mock

    counts = await db._fetch_execution_counts(
        "user-1", ["graph-1", "graph-2", "graph-3"]
    )

    assert counts == {"graph-1": 5, "graph-2": 2}
    execution_client.return_value.group_by.assert_called_once_with(
        by=["agentGraphId"],
        where={
            "userId": "user-1",
            "agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
            "isDeleted": False,
        },
        count=True,
    )
|
||||
|
||||
@@ -223,6 +223,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
|
||||
store_listing: Optional[prisma.models.StoreListing] = None,
|
||||
profile: Optional[prisma.models.Profile] = None,
|
||||
execution_count_override: Optional[int] = None,
|
||||
) -> "LibraryAgent":
|
||||
"""
|
||||
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
|
||||
@@ -258,10 +259,14 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
status = status_result.status
|
||||
new_output = status_result.new_output
|
||||
|
||||
execution_count = len(executions)
|
||||
execution_count = (
|
||||
execution_count_override
|
||||
if execution_count_override is not None
|
||||
else len(executions)
|
||||
)
|
||||
success_rate: float | None = None
|
||||
avg_correctness_score: float | None = None
|
||||
if execution_count > 0:
|
||||
if executions and execution_count > 0:
|
||||
success_count = sum(
|
||||
1
|
||||
for e in executions
|
||||
|
||||
@@ -1,11 +1,66 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
from . import model as library_model
|
||||
|
||||
|
||||
def _make_library_agent(
    *,
    graph_id: str = "g1",
    executions: list | None = None,
) -> prisma.models.LibraryAgent:
    """Build a minimal Prisma ``LibraryAgent`` fixture with a nested graph.

    Args:
        graph_id: ID used for both ``agentGraphId`` and the nested graph.
        executions: Value passed as the nested graph's ``Executions``.

    Returns:
        A user-created, non-deleted, non-archived library agent at graph
        version 1.
    """
    graph = prisma.models.AgentGraph(
        id=graph_id,
        version=1,
        name="Agent",
        description="Desc",
        userId="u1",
        isActive=True,
        createdAt=datetime.datetime.now(),
        Executions=executions,
    )
    return prisma.models.LibraryAgent(
        id="la1",
        userId="u1",
        agentGraphId=graph_id,
        settings="{}",  # type: ignore
        agentGraphVersion=1,
        isCreatedByUser=True,
        isDeleted=False,
        isArchived=False,
        createdAt=datetime.datetime.now(),
        updatedAt=datetime.datetime.now(),
        isFavorite=False,
        useGraphIsActiveVersion=True,
        AgentGraph=graph,
    )
|
||||
|
||||
|
||||
def test_from_db_execution_count_override_covers_success_rate():
    """``from_db`` uses the count override and still derives the success rate.

    One COMPLETED execution is attached to the graph and an override of 1
    is passed, so the result should report one execution and a 100%
    success rate.
    """
    timestamp = datetime.datetime.now(datetime.timezone.utc)
    completed_run = prisma.models.AgentGraphExecution(
        id="exec-1",
        agentGraphId="g1",
        agentGraphVersion=1,
        userId="u1",
        executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
        createdAt=timestamp,
        updatedAt=timestamp,
        isDeleted=False,
        isShared=False,
    )
    db_agent = _make_library_agent(executions=[completed_run])

    converted = library_model.LibraryAgent.from_db(
        db_agent, execution_count_override=1
    )

    assert converted.execution_count == 1
    assert converted.success_rate is not None
    assert converted.success_rate == 100.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_preset_from_db(test_user_id: str):
|
||||
# Create mock DB agent
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,8 @@ import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Literal, Sequence, get_args
|
||||
from typing import Annotated, Any, Literal, Sequence, cast, get_args
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
@@ -54,8 +55,11 @@ from backend.data.credit import (
|
||||
cancel_stripe_subscription,
|
||||
create_subscription_checkout,
|
||||
get_auto_top_up,
|
||||
get_proration_credit_cents,
|
||||
get_subscription_price_id,
|
||||
get_user_credit_model,
|
||||
handle_subscription_payment_failure,
|
||||
modify_stripe_subscription_for_tier,
|
||||
set_auto_top_up,
|
||||
set_subscription_tier,
|
||||
sync_subscription_from_stripe,
|
||||
@@ -699,9 +703,72 @@ class SubscriptionCheckoutResponse(BaseModel):
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
tier: str
|
||||
monthly_cost: int
|
||||
tier_costs: dict[str, int]
|
||||
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
|
||||
monthly_cost: int # amount in cents (Stripe convention)
|
||||
tier_costs: dict[str, int] # tier name -> amount in cents
|
||||
proration_credit_cents: int # unused portion of current sub to convert on upgrade
|
||||
|
||||
|
||||
def _validate_checkout_redirect_url(url: str) -> bool:
|
||||
"""Return True if `url` matches the configured frontend origin.
|
||||
|
||||
Prevents open-redirect: attackers must not be able to supply arbitrary
|
||||
success_url/cancel_url that Stripe will redirect users to after checkout.
|
||||
|
||||
Pre-parse rejection rules (applied before urlparse):
|
||||
- Backslashes (``\\``) are normalised differently across parsers/browsers.
|
||||
- Control characters (U+0000–U+001F) are not valid in URLs and may confuse
|
||||
some URL-parsing implementations.
|
||||
"""
|
||||
# Reject characters that can confuse URL parsers before any parsing.
|
||||
if "\\" in url:
|
||||
return False
|
||||
if any(ord(c) < 0x20 for c in url):
|
||||
return False
|
||||
|
||||
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
if not allowed:
|
||||
# No configured origin — refuse to validate rather than allow arbitrary URLs.
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
allowed_parsed = urlparse(allowed)
|
||||
except ValueError:
|
||||
return False
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
|
||||
# can trick browsers into connecting to a different host than displayed.
|
||||
# ``@`` in query/fragment is harmless and must be allowed.
|
||||
if "@" in parsed.netloc:
|
||||
return False
|
||||
return (
|
||||
parsed.scheme == allowed_parsed.scheme
|
||||
and parsed.netloc == allowed_parsed.netloc
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
async def _get_stripe_price_amount(price_id: str) -> int | None:
    """Look up a Stripe Price's ``unit_amount`` (cents), cached for 5 minutes.

    Stripe prices rarely change, so caching saves a ~200-600 ms Stripe
    round-trip on every GET /credits/subscription page load and reduces
    quota consumption.

    Args:
        price_id: Stripe Price ID to retrieve.

    Returns:
        The price's ``unit_amount``, or ``0`` when that value is falsy.
        ``None`` when Stripe raises a ``StripeError``. ``cache_none=False``
        keeps that ``None`` out of the cache so the next request retries
        Stripe rather than being served a stale "no price" for the rest of
        the TTL window. Callers should treat ``None`` as an unknown price
        and fall back to 0.
    """
    try:
        # The Stripe SDK call is blocking, so run it off the event loop.
        stripe_price = await run_in_threadpool(stripe.Price.retrieve, price_id)
        amount = stripe_price.unit_amount
        return amount if amount else 0
    except stripe.StripeError:
        logger.warning(
            "Failed to retrieve Stripe price %s — returning None (not cached)",
            price_id,
        )
        return None
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -722,21 +789,26 @@ async def get_subscription_status(
|
||||
*[get_subscription_price_id(t) for t in paid_tiers]
|
||||
)
|
||||
|
||||
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
|
||||
for t, price_id in zip(paid_tiers, price_ids):
|
||||
cost = 0
|
||||
if price_id:
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
cost = price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
pass
|
||||
tier_costs: dict[str, int] = {
|
||||
SubscriptionTier.FREE.value: 0,
|
||||
SubscriptionTier.ENTERPRISE.value: 0,
|
||||
}
|
||||
|
||||
async def _cost(pid: str | None) -> int:
|
||||
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
|
||||
|
||||
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
|
||||
for t, cost in zip(paid_tiers, costs):
|
||||
tier_costs[t.value] = cost
|
||||
|
||||
current_monthly_cost = tier_costs.get(tier.value, 0)
|
||||
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
|
||||
|
||||
return SubscriptionStatusResponse(
|
||||
tier=tier.value,
|
||||
monthly_cost=tier_costs.get(tier.value, 0),
|
||||
monthly_cost=current_monthly_cost,
|
||||
tier_costs=tier_costs,
|
||||
proration_credit_cents=proration_credit,
|
||||
)
|
||||
|
||||
|
||||
@@ -766,24 +838,125 @@ async def update_subscription_tier(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
|
||||
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
|
||||
# keeps their tier for the time they already paid for. The DB tier is NOT
|
||||
# updated here when a subscription exists — the customer.subscription.deleted
|
||||
# webhook fires at period end and downgrades to FREE then.
|
||||
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
|
||||
# tier), cancel_stripe_subscription returns False and we update the DB tier
|
||||
# immediately since no webhook will ever fire.
|
||||
# When payment is disabled entirely, update the DB tier directly.
|
||||
if tier == SubscriptionTier.FREE:
|
||||
if payment_enabled:
|
||||
await cancel_stripe_subscription(user_id)
|
||||
try:
|
||||
had_subscription = await cancel_stripe_subscription(user_id)
|
||||
except stripe.StripeError as e:
|
||||
# Log full Stripe error server-side but return a generic message
|
||||
# to the client — raw Stripe errors can leak customer/sub IDs and
|
||||
# infrastructure config details.
|
||||
logger.exception(
|
||||
"Stripe error cancelling subscription for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
if not had_subscription:
|
||||
# No active Stripe subscription found — the user was on an
|
||||
# admin-granted tier. Update DB immediately since the
|
||||
# subscription.deleted webhook will never fire.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Beta users (payment not enabled) → update tier directly without Stripe.
|
||||
# Paid tier changes require payment to be enabled — block self-service upgrades
|
||||
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
|
||||
if not payment_enabled:
|
||||
await set_subscription_tier(user_id, tier)
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier}",
|
||||
)
|
||||
|
||||
# No-op short-circuit: if the user is already on the requested paid tier,
|
||||
# do NOT create a new Checkout Session. Without this guard, a duplicate
|
||||
# request (double-click, retried POST, stale page) creates a second
|
||||
# subscription for the same price; the user would be charged for both
|
||||
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
|
||||
# which only fires after the second charge has cleared.
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Paid upgrade → create Stripe Checkout Session.
|
||||
# Paid→paid tier change: if the user already has a Stripe subscription,
|
||||
# modify it in-place with proration instead of creating a new Checkout
|
||||
# Session. This preserves remaining paid time and avoids double-charging.
|
||||
# The customer.subscription.updated webhook fires and updates the DB tier.
|
||||
current_tier = user.subscription_tier or SubscriptionTier.FREE
|
||||
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
|
||||
try:
|
||||
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||
if modified:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
# modify_stripe_subscription_for_tier returns False when no active
|
||||
# Stripe subscription exists — i.e. the user has an admin-granted
|
||||
# paid tier with no Stripe record. In that case, update the DB
|
||||
# tier directly (same as the FREE-downgrade path for admin-granted
|
||||
# users) rather than sending them through a new Checkout Session.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to update your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
# Paid upgrade from FREE → create Stripe Checkout Session.
|
||||
if not request.success_url or not request.cancel_url:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url are required for paid tier upgrades",
|
||||
)
|
||||
# Open-redirect protection: both URLs must point to the configured frontend
|
||||
# origin, otherwise an attacker could use our Stripe integration as a
|
||||
# redirector to arbitrary phishing sites.
|
||||
#
|
||||
# Fail early with a clear 503 if the server is misconfigured (neither
|
||||
# frontend_base_url nor platform_base_url set), so operators get an
|
||||
# actionable error instead of the misleading "must match the platform
|
||||
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
|
||||
# produce when `allowed` is empty.
|
||||
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
|
||||
logger.error(
|
||||
"update_subscription_tier: neither frontend_base_url nor "
|
||||
"platform_base_url is configured; cannot validate checkout redirect URLs"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Payment redirect URLs cannot be validated: "
|
||||
"frontend_base_url or platform_base_url must be set on the server."
|
||||
),
|
||||
)
|
||||
if not _validate_checkout_redirect_url(
|
||||
request.success_url
|
||||
) or not _validate_checkout_redirect_url(request.cancel_url):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url must match the platform frontend origin",
|
||||
)
|
||||
try:
|
||||
url = await create_subscription_checkout(
|
||||
user_id=user_id,
|
||||
@@ -791,8 +964,19 @@ async def update_subscription_tier(
|
||||
success_url=request.success_url,
|
||||
cancel_url=request.cancel_url,
|
||||
)
|
||||
except (ValueError, stripe.StripeError) as e:
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error creating checkout session for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to start checkout right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
return SubscriptionCheckoutResponse(url=url)
|
||||
|
||||
@@ -801,44 +985,78 @@ async def update_subscription_tier(
|
||||
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
||||
)
|
||||
async def stripe_webhook(request: Request):
|
||||
webhook_secret = settings.secrets.stripe_webhook_secret
|
||||
if not webhook_secret:
|
||||
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
|
||||
# signature over the same empty key). Reject all webhook calls when unconfigured.
|
||||
logger.error(
|
||||
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
|
||||
"rejecting request to prevent signature bypass"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Webhook not configured")
|
||||
|
||||
# Get the raw request body
|
||||
payload = await request.body()
|
||||
# Get the signature header
|
||||
sig_header = request.headers.get("stripe-signature")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.secrets.stripe_webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
|
||||
except ValueError:
|
||||
# Invalid payload
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
|
||||
)
|
||||
except stripe.SignatureVerificationError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||
except stripe.SignatureVerificationError:
|
||||
# Invalid signature
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
# Defensive payload extraction. A malformed payload (missing/non-dict
|
||||
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
|
||||
# AFTER signature verification — which Stripe interprets as a delivery
|
||||
# failure and retries forever, while spamming Sentry with no useful info.
|
||||
# Acknowledge with 200 and a warning so Stripe stops retrying.
|
||||
event_type = event.get("type", "")
|
||||
event_data = event.get("data") or {}
|
||||
data_object = event_data.get("object") if isinstance(event_data, dict) else None
|
||||
if not isinstance(data_object, dict):
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing or non-dict data.object; ignoring",
|
||||
event_type,
|
||||
)
|
||||
return Response(status_code=200)
|
||||
|
||||
if (
|
||||
event["type"] == "checkout.session.completed"
|
||||
or event["type"] == "checkout.session.async_payment_succeeded"
|
||||
if event_type in (
|
||||
"checkout.session.completed",
|
||||
"checkout.session.async_payment_succeeded",
|
||||
):
|
||||
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
|
||||
session_id = data_object.get("id")
|
||||
if not session_id:
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing data.object.id; ignoring", event_type
|
||||
)
|
||||
return Response(status_code=200)
|
||||
await UserCredit().fulfill_checkout(session_id=session_id)
|
||||
|
||||
if event["type"] in (
|
||||
if event_type in (
|
||||
"customer.subscription.created",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
):
|
||||
await sync_subscription_from_stripe(event["data"]["object"])
|
||||
await sync_subscription_from_stripe(data_object)
|
||||
|
||||
if event["type"] == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(event["data"]["object"])
|
||||
if event_type == "invoice.payment_failed":
|
||||
await handle_subscription_payment_failure(data_object)
|
||||
|
||||
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(event["data"]["object"])
|
||||
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
|
||||
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
|
||||
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
|
||||
# to satisfy the type checker without changing runtime behaviour.
|
||||
if event_type == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
|
||||
|
||||
if event_type == "refund.created" or event_type == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(
|
||||
cast("stripe.Refund | stripe.Dispute", data_object)
|
||||
)
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@@ -106,7 +106,6 @@ class LlmModelMeta(EnumMeta):
|
||||
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> "LlmModel | None":
|
||||
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
|
||||
@@ -203,6 +202,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
GROK_4_20 = "x-ai/grok-4.20"
|
||||
GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
|
||||
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
|
||||
KIMI_K2 = "moonshotai/kimi-k2"
|
||||
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
|
||||
@@ -627,6 +628,18 @@ MODEL_METADATA = {
|
||||
LlmModel.GROK_4_1_FAST: ModelMetadata(
|
||||
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
|
||||
),
|
||||
LlmModel.GROK_4_20: ModelMetadata(
|
||||
"open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3
|
||||
),
|
||||
LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata(
|
||||
"open_router",
|
||||
2000000,
|
||||
100000,
|
||||
"Grok 4.20 Multi-Agent",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
3,
|
||||
),
|
||||
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
|
||||
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
|
||||
),
|
||||
@@ -987,7 +1000,6 @@ async def llm_call(
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
|
||||
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
|
||||
# Cache tool definitions alongside the system prompt.
|
||||
# Placing cache_control on the last tool caches all tool schemas as a
|
||||
|
||||
@@ -67,11 +67,15 @@ from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
TranscriptDownload,
|
||||
detect_gap,
|
||||
download_transcript,
|
||||
extract_context_messages,
|
||||
strip_for_upload,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util import json as util_json
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import (
|
||||
compress_context,
|
||||
@@ -699,81 +703,147 @@ async def _compress_session_messages(
|
||||
return messages
|
||||
|
||||
|
||||
def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool:
    """Tell whether a downloaded transcript lags behind the current session.

    A download is stale when it has a known ``message_count`` that falls
    short of ``session_msg_count - 1``, i.e. the session has moved past
    what the stored transcript captures. Loading such a transcript would
    silently drop intermediate turns, so callers should neither load nor
    upload it.

    Args:
        dl: The downloaded transcript, or ``None`` when there is none.
        session_msg_count: Number of messages currently in the session.

    Returns:
        ``False`` when ``dl`` is ``None`` or its ``message_count`` is
        unknown (falsy, e.g. ``0``). Transcripts uploaded before
        message-count tracking existed must stay usable. Otherwise
        ``True`` exactly when ``message_count < session_msg_count - 1``.
    """
    if dl is None or not dl.message_count:
        return False
    required_coverage = session_msg_count - 1
    return dl.message_count < required_coverage
|
||||
|
||||
|
||||
def should_upload_transcript(
|
||||
user_id: str | None, transcript_covers_prefix: bool
|
||||
) -> bool:
|
||||
def should_upload_transcript(user_id: str | None, upload_safe: bool) -> bool:
|
||||
"""Return ``True`` when the caller should upload the final transcript.
|
||||
|
||||
Uploads require a logged-in user (for the storage key) *and* a
|
||||
transcript that covered the session prefix when loaded — otherwise
|
||||
we'd be overwriting a more complete version in storage with a
|
||||
partial one built from just the current turn.
|
||||
Uploads require a logged-in user (for the storage key) *and* a safe
|
||||
upload signal from ``_load_prior_transcript`` — i.e. GCS does not hold a
|
||||
newer version that we'd be overwriting.
|
||||
"""
|
||||
return bool(user_id) and transcript_covers_prefix
|
||||
return bool(user_id) and upload_safe
|
||||
|
||||
|
||||
def _append_gap_to_builder(
|
||||
gap: list[ChatMessage],
|
||||
builder: TranscriptBuilder,
|
||||
) -> None:
|
||||
"""Append gap messages from chat-db into the TranscriptBuilder.
|
||||
|
||||
Converts ChatMessage (OpenAI format) to TranscriptBuilder entries
|
||||
(Claude CLI JSONL format) so the uploaded transcript covers all turns.
|
||||
|
||||
Pre-condition: ``gap`` always starts at a user or assistant boundary
|
||||
(never mid-turn at a ``tool`` role), because ``detect_gap`` enforces
|
||||
``session_messages[wm-1].role == 'assistant'`` before returning a non-empty
|
||||
gap. Any ``tool`` role messages within the gap always follow an assistant
|
||||
entry that already exists in the builder or in the gap itself.
|
||||
"""
|
||||
for msg in gap:
|
||||
if msg.role == "user":
|
||||
builder.append_user(msg.content or "")
|
||||
elif msg.role == "assistant":
|
||||
content_blocks: list[dict] = []
|
||||
if msg.content:
|
||||
content_blocks.append({"type": "text", "text": msg.content})
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
||||
input_data = util_json.loads(fn.get("arguments", "{}"), fallback={})
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id", "") if isinstance(tc, dict) else "",
|
||||
"name": fn.get("name", "unknown"),
|
||||
"input": input_data,
|
||||
}
|
||||
)
|
||||
if not content_blocks:
|
||||
# Fallback: ensure every assistant gap message produces an entry
|
||||
# so the builder's entry count matches the gap length.
|
||||
content_blocks.append({"type": "text", "text": ""})
|
||||
builder.append_assistant(content_blocks=content_blocks)
|
||||
elif msg.role == "tool":
|
||||
if msg.tool_call_id:
|
||||
builder.append_tool_result(
|
||||
tool_use_id=msg.tool_call_id,
|
||||
content=msg.content or "",
|
||||
)
|
||||
else:
|
||||
# Malformed tool message — no tool_call_id to link to an
|
||||
# assistant tool_use block. Skip to avoid an unmatched
|
||||
# tool_result entry in the builder (which would confuse --resume).
|
||||
logger.warning(
|
||||
"[Baseline] Skipping tool gap message with no tool_call_id"
|
||||
)
|
||||
|
||||
|
||||
async def _load_prior_transcript(
    user_id: str,
    session_id: str,
    session_msg_count: int,
    session_messages: list[ChatMessage],
    transcript_builder: TranscriptBuilder,
) -> tuple[bool, "TranscriptDownload | None"]:
    """Download and load the prior CLI session into ``transcript_builder``.

    Args:
        user_id: Owner of the stored transcript.
        session_id: Session whose transcript should be restored.
        session_msg_count: Accepted for call-site compatibility; this
            implementation does not consult it (the gap is derived from
            ``session_messages`` instead).
        session_messages: Current session messages, handed to ``detect_gap``
            to find messages the stored transcript does not cover.
        transcript_builder: Builder that receives the stored transcript and
            any gap messages.  Left untouched when nothing valid is loaded.

    Returns a tuple of (upload_safe, transcript_download):
    - ``upload_safe`` is ``True`` when it is safe to upload at the end of this
      turn. Upload is suppressed only for **download errors** (unknown GCS
      state) — missing and invalid files return ``True`` because there is
      nothing in GCS worth protecting against overwriting.
    - ``transcript_download`` is a ``TranscriptDownload`` with str content
      (pre-decoded and stripped) when available, or ``None`` when no valid
      transcript could be loaded. Callers pass this to
      ``extract_context_messages`` to build the LLM context.
    """
    try:
        restore = await download_transcript(
            user_id, session_id, log_prefix="[Baseline]"
        )
    except Exception as e:
        logger.warning("[Baseline] Session restore failed: %s", e)
        # Unknown GCS state — be conservative, skip upload.
        return False, None

    if restore is None:
        logger.debug("[Baseline] No CLI session available — will upload fresh")
        # Nothing in GCS to protect; allow upload so the first baseline turn
        # writes the initial transcript snapshot.
        return True, None

    content_bytes = restore.content
    try:
        # ``TranscriptDownload.content`` may be bytes or str; only bytes need
        # decoding.
        raw_str = (
            content_bytes.decode("utf-8")
            if isinstance(content_bytes, bytes)
            else content_bytes
        )
    except UnicodeDecodeError:
        logger.warning("[Baseline] CLI session content is not valid UTF-8")
        # Corrupt file in GCS; overwriting with a valid one is better.
        return True, None

    stripped = strip_for_upload(raw_str)
    if not validate_transcript(stripped):
        logger.warning("[Baseline] CLI session content invalid after strip")
        # Corrupt file in GCS; overwriting with a valid one is better.
        return True, None

    transcript_builder.load_previous(stripped, log_prefix="[Baseline]")
    logger.info(
        "[Baseline] Loaded CLI session: %dB, msg_count=%d",
        len(content_bytes) if isinstance(content_bytes, bytes) else len(raw_str),
        restore.message_count,
    )

    # Top up the builder with any session messages the stored transcript
    # does not cover yet.
    gap = detect_gap(restore, session_messages)
    if gap:
        _append_gap_to_builder(gap, transcript_builder)
        logger.info(
            "[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB",
            restore.message_count,
            len(gap),
        )

    # Return a str-content version so extract_context_messages receives a
    # pre-decoded, stripped transcript (avoids redundant decode + strip).
    # TranscriptDownload.content is typed as bytes | str; we pass str here
    # to avoid a redundant encode + decode round-trip.
    str_restore = TranscriptDownload(
        content=stripped,
        message_count=restore.message_count,
        mode=restore.mode,
    )
    return True, str_restore
|
||||
|
||||
|
||||
async def _upload_final_transcript(
|
||||
@@ -807,10 +877,10 @@ async def _upload_final_transcript(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=content,
|
||||
content=content.encode("utf-8"),
|
||||
message_count=session_msg_count,
|
||||
mode="baseline",
|
||||
log_prefix="[Baseline]",
|
||||
skip_strip=True,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(upload_task)
|
||||
@@ -897,7 +967,7 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
# --- Transcript support (feature parity with SDK path) ---
|
||||
transcript_builder = TranscriptBuilder()
|
||||
transcript_covers_prefix = True
|
||||
transcript_upload_safe = True
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
@@ -914,15 +984,16 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
# Run download + prompt build concurrently — both are independent I/O
|
||||
# on the request critical path.
|
||||
transcript_download: TranscriptDownload | None = None
|
||||
if user_id and len(session.messages) > 1:
|
||||
(
|
||||
transcript_covers_prefix,
|
||||
(transcript_upload_safe, transcript_download),
|
||||
(base_system_prompt, understanding),
|
||||
) = await asyncio.gather(
|
||||
_load_prior_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
session_msg_count=len(session.messages),
|
||||
session_messages=session.messages,
|
||||
transcript_builder=transcript_builder,
|
||||
),
|
||||
prompt_task,
|
||||
@@ -962,9 +1033,14 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
warm_ctx = await fetch_warm_context(user_id, message or "")
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
# Context path: transcript content (compacted, isCompactSummary preserved) +
|
||||
# gap (DB messages after watermark) + current user turn.
|
||||
# This avoids re-reading the full session history from DB on every turn.
|
||||
# See extract_context_messages() in transcript.py for the shared primitive.
|
||||
prior_context = extract_context_messages(transcript_download, session.messages)
|
||||
messages_for_context = await _compress_session_messages(
|
||||
session.messages, model=active_model
|
||||
prior_context + ([session.messages[-1]] if session.messages else []),
|
||||
model=active_model,
|
||||
)
|
||||
|
||||
# Build OpenAI message list from session history.
|
||||
@@ -1308,7 +1384,7 @@ async def stream_chat_completion_baseline(
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
if user_id and should_upload_transcript(user_id, transcript_covers_prefix):
|
||||
if user_id and should_upload_transcript(user_id, transcript_upload_safe):
|
||||
await _upload_final_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Integration tests for baseline transcript flow.
|
||||
|
||||
Exercises the real helpers in ``baseline/service.py`` that download,
|
||||
validate, load, append to, backfill, and upload the transcript.
|
||||
Exercises the real helpers in ``baseline/service.py`` that restore,
|
||||
validate, load, append to, backfill, and upload the CLI session.
|
||||
Storage is mocked via ``download_transcript`` / ``upload_transcript``
|
||||
patches; no network access is required.
|
||||
"""
|
||||
@@ -12,13 +12,14 @@ from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_append_gap_to_builder,
|
||||
_load_prior_transcript,
|
||||
_record_turn_to_transcript,
|
||||
_resolve_baseline_model,
|
||||
_upload_final_transcript,
|
||||
is_transcript_stale,
|
||||
should_upload_transcript,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.service import config
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
@@ -54,6 +55,13 @@ def _make_transcript_content(*roles: str) -> str:
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _make_session_messages(*roles: str) -> list[ChatMessage]:
    """Create one ``ChatMessage`` per role, in the order given.

    Each message's content is ``"<role> message <index>"``, where the index
    is the zero-based position of the role in ``roles``.
    """
    built: list[ChatMessage] = []
    for position, role in enumerate(roles):
        built.append(ChatMessage(role=role, content=f"{role} message {position}"))
    return built
|
||||
|
||||
|
||||
class TestResolveBaselineModel:
|
||||
"""Model selection honours the per-request mode."""
|
||||
|
||||
@@ -73,87 +81,102 @@ class TestResolveBaselineModel:
|
||||
|
||||
|
||||
class TestLoadPriorTranscript:
|
||||
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
|
||||
"""``_load_prior_transcript`` wraps the CLI session restore + validate + load flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loads_fresh_transcript(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
restore = TranscriptDownload(
|
||||
content=content.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
session_messages=_make_session_messages("user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
assert dl.message_count == 2
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_stale_transcript(self):
|
||||
"""msg_count strictly less than session-1 is treated as stale."""
|
||||
async def test_fills_gap_when_transcript_is_behind(self):
|
||||
"""When transcript covers fewer messages than session, gap is filled from DB."""
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
# session has 6 messages, transcript only covers 2 → stale.
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
# transcript covers 2 messages, session has 4 (plus current user turn = 5)
|
||||
restore = TranscriptDownload(
|
||||
content=content.encode("utf-8"), message_count=2, mode="baseline"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=6,
|
||||
session_messages=_make_session_messages(
|
||||
"user", "assistant", "user", "assistant", "user"
|
||||
),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
# 2 from transcript + 2 gap messages (user+assistant at positions 2,3)
|
||||
assert builder.entry_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_transcript_returns_false(self):
|
||||
async def test_missing_transcript_allows_upload(self):
|
||||
"""Nothing in GCS → upload is safe; the turn writes the first snapshot."""
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
upload_safe, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert upload_safe is True
|
||||
assert dl is None
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_transcript_returns_false(self):
|
||||
async def test_invalid_transcript_allows_upload(self):
|
||||
"""Corrupt file in GCS → overwriting with a valid one is better."""
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content='{"type":"progress","uuid":"a"}\n',
|
||||
restore = TranscriptDownload(
|
||||
content=b'{"type":"progress","uuid":"a"}\n',
|
||||
message_count=1,
|
||||
mode="sdk",
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
upload_safe, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert upload_safe is True
|
||||
assert dl is None
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -163,36 +186,39 @@ class TestLoadPriorTranscript:
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("boom")),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert dl is None
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_message_count_not_stale(self):
|
||||
"""When msg_count is 0 (unknown), staleness check is skipped."""
|
||||
"""When msg_count is 0 (unknown), gap detection is skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
restore = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant").encode("utf-8"),
|
||||
message_count=0,
|
||||
mode="sdk",
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=20,
|
||||
session_messages=_make_session_messages(*["user"] * 20),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
assert builder.entry_count == 2
|
||||
|
||||
|
||||
@@ -227,7 +253,7 @@ class TestUploadFinalTranscript:
|
||||
assert call_kwargs["user_id"] == "user-1"
|
||||
assert call_kwargs["session_id"] == "session-1"
|
||||
assert call_kwargs["message_count"] == 2
|
||||
assert "hello" in call_kwargs["content"]
|
||||
assert b"hello" in call_kwargs["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_upload_when_builder_empty(self):
|
||||
@@ -374,17 +400,19 @@ class TestRoundTrip:
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_round_trip(self):
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
restore = TranscriptDownload(
|
||||
content=prior.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, _ = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
session_messages=_make_session_messages("user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
@@ -424,11 +452,11 @@ class TestRoundTrip:
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "new question" in uploaded
|
||||
assert "new answer" in uploaded
|
||||
assert b"new question" in uploaded
|
||||
assert b"new answer" in uploaded
|
||||
# Original content preserved in the round trip.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
assert b"user message 0" in uploaded
|
||||
assert b"assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_append_guard(self):
|
||||
@@ -459,36 +487,6 @@ class TestRoundTrip:
|
||||
assert builder.entry_count == initial_count
|
||||
|
||||
|
||||
class TestIsTranscriptStale:
    """``is_transcript_stale`` gates prior-transcript loading."""

    def test_none_download_is_not_stale(self):
        # No download at all can never be stale.
        verdict = is_transcript_stale(None, session_msg_count=5)
        assert verdict is False

    def test_zero_message_count_is_not_stale(self):
        """Legacy transcripts without msg_count tracking must remain usable."""
        legacy = TranscriptDownload(content="", message_count=0)
        verdict = is_transcript_stale(legacy, session_msg_count=20)
        assert verdict is False

    def test_stale_when_covers_less_than_prefix(self):
        # A 6-message session needs the transcript to cover at least 5 (6-1).
        behind = TranscriptDownload(content="", message_count=2)
        verdict = is_transcript_stale(behind, session_msg_count=6)
        assert verdict is True

    def test_fresh_when_covers_full_prefix(self):
        current = TranscriptDownload(content="", message_count=5)
        verdict = is_transcript_stale(current, session_msg_count=6)
        assert verdict is False

    def test_fresh_when_exceeds_prefix(self):
        """Race: transcript ahead of session count is still acceptable."""
        ahead = TranscriptDownload(content="", message_count=10)
        verdict = is_transcript_stale(ahead, session_msg_count=6)
        assert verdict is False

    def test_boundary_equal_to_prefix_minus_one(self):
        # Exactly session_msg_count - 1 sits on the fresh side of the boundary.
        boundary = TranscriptDownload(content="", message_count=5)
        verdict = is_transcript_stale(boundary, session_msg_count=6)
        assert verdict is False
|
||||
|
||||
|
||||
class TestShouldUploadTranscript:
|
||||
"""``should_upload_transcript`` gates the final upload."""
|
||||
|
||||
@@ -510,7 +508,7 @@ class TestShouldUploadTranscript:
|
||||
|
||||
|
||||
class TestTranscriptLifecycle:
|
||||
"""End-to-end: download → validate → build → upload.
|
||||
"""End-to-end: restore → validate → build → upload.
|
||||
|
||||
Simulates the full transcript lifecycle inside
|
||||
``stream_chat_completion_baseline`` by mocking the storage layer and
|
||||
@@ -519,27 +517,29 @@ class TestTranscriptLifecycle:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_lifecycle_happy_path(self):
|
||||
"""Fresh download, append a turn, upload covers the session."""
|
||||
"""Fresh restore, append a turn, upload covers the session."""
|
||||
builder = TranscriptBuilder()
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
restore = TranscriptDownload(
|
||||
content=prior.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
# --- 1. Download & load prior transcript ---
|
||||
covers = await _load_prior_transcript(
|
||||
# --- 1. Restore & load prior session ---
|
||||
covers, _ = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
session_messages=_make_session_messages("user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
@@ -559,10 +559,7 @@ class TestTranscriptLifecycle:
|
||||
|
||||
# --- 3. Gate + upload ---
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is True
|
||||
should_upload_transcript(user_id="user-1", upload_safe=covers) is True
|
||||
)
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
@@ -574,20 +571,21 @@ class TestTranscriptLifecycle:
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "follow-up question" in uploaded
|
||||
assert "follow-up answer" in uploaded
|
||||
assert b"follow-up question" in uploaded
|
||||
assert b"follow-up answer" in uploaded
|
||||
# Original prior-turn content preserved.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
assert b"user message 0" in uploaded
|
||||
assert b"assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_stale_download_suppresses_upload(self):
|
||||
"""Stale download → covers=False → upload must be skipped."""
|
||||
async def test_lifecycle_stale_download_fills_gap(self):
|
||||
"""When transcript covers fewer messages, gap is filled rather than rejected."""
|
||||
builder = TranscriptBuilder()
|
||||
# session has 10 msgs but stored transcript only covers 2 → stale.
|
||||
# session has 5 msgs but stored transcript only covers 2 → gap filled.
|
||||
stale = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
content=_make_transcript_content("user", "assistant").encode("utf-8"),
|
||||
message_count=2,
|
||||
mode="baseline",
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
@@ -601,20 +599,18 @@ class TestTranscriptLifecycle:
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, _ = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=10,
|
||||
session_messages=_make_session_messages(
|
||||
"user", "assistant", "user", "assistant", "user"
|
||||
),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
# The caller's gate mirrors the production path.
|
||||
assert (
|
||||
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
|
||||
is False
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
assert covers is True
|
||||
# Gap was filled: 2 from transcript + 2 gap messages
|
||||
assert builder.entry_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_anonymous_user_skips_upload(self):
|
||||
@@ -627,15 +623,11 @@ class TestTranscriptLifecycle:
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
assert (
|
||||
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
|
||||
is False
|
||||
)
|
||||
assert should_upload_transcript(user_id=None, upload_safe=True) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_missing_download_still_uploads_new_content(self):
|
||||
"""No prior transcript → covers defaults to True in the service,
|
||||
new turn should upload cleanly."""
|
||||
"""No prior session → upload is safe; the turn writes the first snapshot."""
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
@@ -648,20 +640,117 @@ class TestTranscriptLifecycle:
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
upload_safe, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=1,
|
||||
session_messages=_make_session_messages("user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
# No download: covers is False, so the production path would
|
||||
# skip upload. This protects against overwriting a future
|
||||
# more-complete transcript with a single-turn snapshot.
|
||||
assert covers is False
|
||||
# Nothing in GCS → upload is safe so the first baseline turn
|
||||
# can write the initial transcript snapshot.
|
||||
assert upload_safe is True
|
||||
assert dl is None
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is False
|
||||
should_upload_transcript(user_id="user-1", upload_safe=upload_safe)
|
||||
is True
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _append_gap_to_builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAppendGapToBuilder:
    """``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries."""

    def test_user_message_appended(self):
        transcript = TranscriptBuilder()
        _append_gap_to_builder([ChatMessage(role="user", content="hello")], transcript)
        assert transcript.entry_count == 1
        assert transcript.last_entry_type == "user"

    def test_assistant_text_message_appended(self):
        transcript = TranscriptBuilder()
        question = ChatMessage(role="user", content="q")
        reply = ChatMessage(role="assistant", content="answer")
        _append_gap_to_builder([question, reply], transcript)
        assert transcript.entry_count == 2
        assert transcript.last_entry_type == "assistant"
        assert "answer" in transcript.to_jsonl()

    def test_assistant_with_tool_calls_appended(self):
        """Assistant tool_calls are recorded as tool_use blocks in the transcript."""
        transcript = TranscriptBuilder()
        call = {
            "id": "tc-1",
            "type": "function",
            "function": {"name": "my_tool", "arguments": '{"key":"val"}'},
        }
        gap = [ChatMessage(role="assistant", content=None, tool_calls=[call])]
        _append_gap_to_builder(gap, transcript)
        assert transcript.entry_count == 1
        serialized = transcript.to_jsonl()
        assert "tool_use" in serialized
        assert "my_tool" in serialized
        assert "tc-1" in serialized

    def test_assistant_invalid_json_args_uses_empty_dict(self):
        """Malformed JSON in tool_call arguments falls back to {}."""
        transcript = TranscriptBuilder()
        call = {
            "id": "tc-bad",
            "type": "function",
            "function": {"name": "bad_tool", "arguments": "not-json"},
        }
        gap = [ChatMessage(role="assistant", content=None, tool_calls=[call])]
        _append_gap_to_builder(gap, transcript)
        assert transcript.entry_count == 1
        serialized = transcript.to_jsonl()
        assert '"input":{}' in serialized

    def test_assistant_empty_content_and_no_tools_uses_fallback(self):
        """Assistant with no content and no tool_calls gets a fallback empty text block."""
        transcript = TranscriptBuilder()
        gap = [ChatMessage(role="assistant", content=None)]
        _append_gap_to_builder(gap, transcript)
        assert transcript.entry_count == 1
        serialized = transcript.to_jsonl()
        assert "text" in serialized

    def test_tool_role_with_tool_call_id_appended(self):
        """Tool result messages are appended when tool_call_id is set."""
        transcript = TranscriptBuilder()
        # Seed the builder with the assistant tool_use entry the result answers.
        transcript.append_user("use tool")
        transcript.append_assistant(
            content_blocks=[
                {"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}}
            ]
        )
        gap = [ChatMessage(role="tool", tool_call_id="tc-1", content="result")]
        _append_gap_to_builder(gap, transcript)
        assert transcript.entry_count == 3
        assert "tool_result" in transcript.to_jsonl()

    def test_tool_role_without_tool_call_id_skipped(self):
        """Tool messages without tool_call_id are silently skipped."""
        transcript = TranscriptBuilder()
        gap = [ChatMessage(role="tool", tool_call_id=None, content="orphan")]
        _append_gap_to_builder(gap, transcript)
        assert transcript.entry_count == 0

    def test_tool_call_missing_function_key_uses_unknown_name(self):
        """A tool_call dict with no 'function' key uses 'unknown' as the tool name."""
        transcript = TranscriptBuilder()
        # The tool call dict is present but has no 'function' sub-dict at all.
        bare_call = {"id": "tc-x"}
        gap = [ChatMessage(role="assistant", content=None, tool_calls=[bare_call])]
        _append_gap_to_builder(gap, transcript)
        assert transcript.entry_count == 1
        serialized = transcript.to_jsonl()
        assert "unknown" in serialized
|
||||
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
# Allowed base directory for the Read tool. Public so service.py can use it
|
||||
# for sweep operations without depending on a private implementation detail.
|
||||
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
|
||||
# _projects_base() function.
|
||||
# projects_base() function.
|
||||
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
|
||||
|
||||
|
||||
@@ -10,9 +10,11 @@ from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatMessageWhereInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
FindManyChatMessageArgsFromChatSession,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -30,6 +32,8 @@ from .model import get_chat_session as get_chat_session_cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BOUNDARY_SCAN_LIMIT = 10
|
||||
|
||||
|
||||
class PaginatedMessages(BaseModel):
|
||||
"""Result of a paginated message query."""
|
||||
@@ -37,6 +41,7 @@ class PaginatedMessages(BaseModel):
|
||||
messages: list[ChatMessage]
|
||||
has_more: bool
|
||||
oldest_sequence: int | None
|
||||
newest_sequence: int | None
|
||||
session: ChatSessionInfo
|
||||
|
||||
|
||||
@@ -61,32 +66,48 @@ async def get_chat_messages_paginated(
|
||||
session_id: str,
|
||||
limit: int = 50,
|
||||
before_sequence: int | None = None,
|
||||
after_sequence: int | None = None,
|
||||
from_start: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> PaginatedMessages | None:
|
||||
"""Get paginated messages for a session, newest first.
|
||||
"""Get paginated messages for a session.
|
||||
|
||||
Verifies session existence (and ownership when ``user_id`` is provided)
|
||||
in parallel with the message query. Returns ``None`` when the session
|
||||
is not found or does not belong to the user.
|
||||
Three modes:
|
||||
|
||||
Args:
|
||||
session_id: The chat session ID.
|
||||
limit: Max messages to return.
|
||||
before_sequence: Cursor — return messages with sequence < this value.
|
||||
user_id: If provided, filters via ``Session.userId`` so only the
|
||||
session owner's messages are returned (acts as an ownership guard).
|
||||
- ``before_sequence`` set: backward pagination (DESC), returns messages
|
||||
with sequence < ``before_sequence``. Used for active sessions or manual
|
||||
backward navigation.
|
||||
- ``from_start=True`` or ``after_sequence`` set: forward pagination (ASC).
|
||||
Returns messages from sequence 0 (``from_start``) or after
|
||||
``after_sequence``. Used on initial load of completed sessions and for
|
||||
loading subsequent forward pages.
|
||||
- Both cursors ``None`` and ``from_start=False``: newest-first (DESC
|
||||
without filter). Used for active sessions on initial load.
|
||||
|
||||
Verifies session existence (and ownership when ``user_id`` is provided).
|
||||
Returns ``None`` when the session is not found or does not belong to the
|
||||
user.
|
||||
"""
|
||||
# Build session-existence / ownership check
|
||||
session_where: ChatSessionWhereInput = {"id": session_id}
|
||||
if user_id is not None:
|
||||
session_where["userId"] = user_id
|
||||
|
||||
# Build message include — fetch paginated messages in the same query
|
||||
msg_include: dict[str, Any] = {
|
||||
"order_by": {"sequence": "desc"},
|
||||
forward = from_start or after_sequence is not None
|
||||
|
||||
# Build message include — fetch paginated messages in the same query.
|
||||
# Note: when both from_start=True and after_sequence is not None, the
|
||||
# after_sequence filter takes precedence (the elif branch below is skipped).
|
||||
# This combination is not reachable via the HTTP route (mutual exclusion is
|
||||
# enforced there), so we rely on the documented priority here without an
|
||||
# additional assertion.
|
||||
msg_include: FindManyChatMessageArgsFromChatSession = {
|
||||
"order_by": {"sequence": "asc" if forward else "desc"},
|
||||
"take": limit + 1,
|
||||
}
|
||||
if before_sequence is not None:
|
||||
if after_sequence is not None:
|
||||
msg_include["where"] = {"sequence": {"gt": after_sequence}}
|
||||
elif before_sequence is not None:
|
||||
msg_include["where"] = {"sequence": {"lt": before_sequence}}
|
||||
|
||||
# Single query: session existence/ownership + paginated messages
|
||||
@@ -104,57 +125,96 @@ async def get_chat_messages_paginated(
|
||||
has_more = len(results) > limit
|
||||
results = results[:limit]
|
||||
|
||||
# Reverse to ascending order
|
||||
results.reverse()
|
||||
if not forward:
|
||||
# Backward mode: DB returned DESC; reverse to ascending order.
|
||||
results.reverse()
|
||||
|
||||
# Tool-call boundary fix: if the oldest message is a tool message,
|
||||
# expand backward to include the preceding assistant message that
|
||||
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
|
||||
# can pair them correctly.
|
||||
_BOUNDARY_SCAN_LIMIT = 10
|
||||
if results and results[0].role == "tool":
|
||||
boundary_where: dict[str, Any] = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
boundary_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=boundary_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_BOUNDARY_SCAN_LIMIT,
|
||||
)
|
||||
# Find the first non-tool message (should be the assistant)
|
||||
boundary_msgs = []
|
||||
found_owner = False
|
||||
for msg in extra:
|
||||
boundary_msgs.append(msg)
|
||||
if msg.role != "tool":
|
||||
found_owner = True
|
||||
break
|
||||
boundary_msgs.reverse()
|
||||
if not found_owner:
|
||||
logger.warning(
|
||||
"Boundary expansion did not find owning assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
# Tool-call boundary fix: if the oldest message is a tool message,
|
||||
# expand backward to include the preceding assistant message that
|
||||
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
|
||||
# can pair them correctly.
|
||||
if results and results[0].role == "tool":
|
||||
boundary_where: ChatMessageWhereInput = {
|
||||
"sessionId": session_id,
|
||||
"sequence": {"lt": results[0].sequence},
|
||||
}
|
||||
if user_id is not None:
|
||||
boundary_where["Session"] = {"is": {"userId": user_id}}
|
||||
extra = await PrismaChatMessage.prisma().find_many(
|
||||
where=boundary_where,
|
||||
order={"sequence": "desc"},
|
||||
take=_BOUNDARY_SCAN_LIMIT,
|
||||
)
|
||||
if boundary_msgs:
|
||||
results = boundary_msgs + results
|
||||
# Only mark has_more if the expanded boundary isn't the
|
||||
# very start of the conversation (sequence 0).
|
||||
if boundary_msgs[0].sequence > 0:
|
||||
# Find the first non-tool message (should be the assistant)
|
||||
boundary_msgs = []
|
||||
found_owner = False
|
||||
for msg in extra:
|
||||
boundary_msgs.append(msg)
|
||||
if msg.role != "tool":
|
||||
found_owner = True
|
||||
break
|
||||
boundary_msgs.reverse()
|
||||
if not found_owner:
|
||||
logger.warning(
|
||||
"Boundary expansion did not find owning assistant message "
|
||||
"for session=%s before sequence=%s (%d msgs scanned)",
|
||||
session_id,
|
||||
results[0].sequence,
|
||||
len(extra),
|
||||
)
|
||||
if boundary_msgs:
|
||||
results = boundary_msgs + results
|
||||
# Only mark has_more if the expanded boundary isn't the
|
||||
# very start of the conversation (sequence 0).
|
||||
if boundary_msgs[0].sequence > 0:
|
||||
has_more = True
|
||||
else:
|
||||
# Forward mode: DB returned ASC.
|
||||
# Tool-call tail boundary fix: if the last message in this page is a
|
||||
# tool message, the NEXT forward page would start after it and begin
|
||||
# mid-tool-group — the owning assistant message is on this page but
|
||||
# the following tool results are on the next page.
|
||||
# Trim the current page so it ends on the owning assistant message,
|
||||
# which keeps tool groups intact across page boundaries.
|
||||
if results and results[-1].role == "tool":
|
||||
# Walk backward through results to find the last non-tool message.
|
||||
trim_idx = len(results) - 1
|
||||
while trim_idx >= 0 and results[trim_idx].role == "tool":
|
||||
trim_idx -= 1
|
||||
|
||||
if trim_idx >= 0:
|
||||
# Trim results so the page ends at the owning assistant.
|
||||
# Mark has_more=True so the client knows to fetch the rest.
|
||||
results = results[: trim_idx + 1]
|
||||
has_more = True
|
||||
else:
|
||||
# Entire page is tool messages with no visible owner — log and
|
||||
# keep as-is so the caller is not stuck with an empty page.
|
||||
logger.warning(
|
||||
"Forward tail boundary: entire page is tool messages "
|
||||
"for session=%s, no owning assistant found (%d msgs)",
|
||||
session_id,
|
||||
len(results),
|
||||
)
|
||||
|
||||
messages = [ChatMessage.from_db(m) for m in results]
|
||||
oldest_sequence = messages[0].sequence if messages else None
|
||||
# oldest_sequence is only meaningful in backward mode (used as backward
|
||||
# pagination cursor). In forward mode the page always starts near seq 0
|
||||
# and clients should use newest_sequence as the forward cursor instead.
|
||||
# Return None in forward mode so clients don't accidentally treat it as a
|
||||
# backward cursor on a forward-paginated session.
|
||||
oldest_sequence = messages[0].sequence if (messages and not forward) else None
|
||||
# newest_sequence is only meaningful in forward mode; in backward mode it
|
||||
# points to the last message of the page (not the session's newest message)
|
||||
# which is not a valid forward cursor. Return None in backward mode so
|
||||
# clients don't accidentally use it as one.
|
||||
newest_sequence = messages[-1].sequence if (messages and forward) else None
|
||||
|
||||
return PaginatedMessages(
|
||||
messages=messages,
|
||||
has_more=has_more,
|
||||
oldest_sequence=oldest_sequence,
|
||||
newest_sequence=newest_sequence,
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -175,6 +175,187 @@ async def test_no_where_on_messages_without_before_sequence(
|
||||
assert "where" not in include["Messages"]
|
||||
|
||||
|
||||
# ---------- Forward pagination (from_start / after_sequence) ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_start_uses_asc_order_no_where(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""from_start=True queries messages in ASC order with no where filter."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
|
||||
)
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
|
||||
|
||||
call_kwargs = find_first.call_args
|
||||
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
|
||||
assert include["Messages"]["order_by"] == {"sequence": "asc"}
|
||||
assert "where" not in include["Messages"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_start_returns_messages_ascending(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""from_start=True returns messages in ascending sequence order."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [0, 1, 2]
|
||||
assert (
|
||||
page.oldest_sequence is None
|
||||
) # None in forward mode — not a valid backward cursor
|
||||
assert page.newest_sequence == 2
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_start_has_more_when_results_exceed_limit(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""from_start=True sets has_more when DB returns more than limit items."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(0), _make_msg(1), _make_msg(2)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=2, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
assert page.has_more is True
|
||||
assert [m.sequence for m in page.messages] == [0, 1]
|
||||
assert page.newest_sequence == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_sequence_uses_gt_filter_asc_order(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""after_sequence adds a sequence > N where clause and uses ASC order."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(11), _make_msg(12)],
|
||||
)
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
|
||||
|
||||
call_kwargs = find_first.call_args
|
||||
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
|
||||
assert include["Messages"]["order_by"] == {"sequence": "asc"}
|
||||
assert include["Messages"]["where"] == {"sequence": {"gt": 10}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_sequence_returns_messages_in_order(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""after_sequence returns only messages with sequence > cursor, ascending."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(11), _make_msg(12), _make_msg(13)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [11, 12, 13]
|
||||
assert (
|
||||
page.oldest_sequence is None
|
||||
) # None in forward mode — not a valid backward cursor
|
||||
assert page.newest_sequence == 13
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_newest_sequence_none_for_backward_mode(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""newest_sequence is None in backward mode — it is not a valid forward cursor."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(5), _make_msg(4), _make_msg(3)],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
|
||||
|
||||
assert page is not None
|
||||
assert page.newest_sequence is None
|
||||
assert page.oldest_sequence == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_mode_no_boundary_expansion(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Forward pagination never triggers backward boundary expansion."""
|
||||
find_first, find_many = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[_make_msg(0, role="tool"), _make_msg(1, role="tool")],
|
||||
)
|
||||
|
||||
await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True)
|
||||
|
||||
assert find_many.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_tail_boundary_trims_trailing_tool_messages(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Forward pages that end with tool messages are trimmed to the owning
|
||||
assistant so the next after_sequence page doesn't start mid-tool-group."""
|
||||
find_first, _ = mock_db
|
||||
# DB returns 4 messages ASC: assistant at 0, tool at 1, tool at 2, tool at 3
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(0, role="assistant"),
|
||||
_make_msg(1, role="tool"),
|
||||
_make_msg(2, role="tool"),
|
||||
_make_msg(3, role="tool"),
|
||||
],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
# Page should be trimmed to end at the assistant message
|
||||
assert [m.sequence for m in page.messages] == [0]
|
||||
assert page.newest_sequence == 0
|
||||
# has_more must be True so the client fetches the tool messages on next page
|
||||
assert page.has_more is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_tail_boundary_no_trim_when_last_not_tool(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
):
|
||||
"""Forward pages that end with a non-tool message are not trimmed."""
|
||||
find_first, _ = mock_db
|
||||
find_first.return_value = _make_session(
|
||||
messages=[
|
||||
_make_msg(0, role="user"),
|
||||
_make_msg(1, role="assistant"),
|
||||
_make_msg(2, role="tool"),
|
||||
_make_msg(3, role="assistant"),
|
||||
],
|
||||
)
|
||||
|
||||
page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True)
|
||||
|
||||
assert page is not None
|
||||
assert [m.sequence for m in page.messages] == [0, 1, 2, 3]
|
||||
assert page.newest_sequence == 3
|
||||
assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_id_filter_applied_to_session_where(
|
||||
mock_db: tuple[AsyncMock, AsyncMock],
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
"""Per-request idempotency lock for the /stream endpoint.
|
||||
|
||||
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
|
||||
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
|
||||
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
|
||||
``None`` when the key already exists (duplicate request).
|
||||
2. ``release()`` — deletes the key. Must be called on turn completion or turn
|
||||
error so the next legitimate send is never blocked.
|
||||
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
|
||||
the backend turn is still running, and releasing would reopen the duplicate
|
||||
window for infra-level retries. The 30 s TTL is the safety net.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_KEY_PREFIX = "chat:msg_dedup"
|
||||
_TTL_SECONDS = 30
|
||||
|
||||
|
||||
class _DedupLock:
|
||||
def __init__(self, key: str, redis) -> None:
|
||||
self._key = key
|
||||
self._redis = redis
|
||||
|
||||
async def release(self) -> None:
|
||||
"""Best-effort key deletion. The TTL handles failures silently."""
|
||||
try:
|
||||
await self._redis.delete(self._key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def acquire_dedup_lock(
|
||||
session_id: str,
|
||||
message: str | None,
|
||||
file_ids: list[str] | None,
|
||||
) -> _DedupLock | None:
|
||||
"""Acquire the idempotency lock for this (session, message, files) tuple.
|
||||
|
||||
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
|
||||
Returns ``None`` when a duplicate is detected (lock already held).
|
||||
Returns ``None`` when there is nothing to deduplicate (no message, no files).
|
||||
"""
|
||||
if not message and not file_ids:
|
||||
return None
|
||||
|
||||
sorted_ids = ":".join(sorted(file_ids or []))
|
||||
content_hash = hashlib.sha256(
|
||||
f"{session_id}:{message or ''}:{sorted_ids}".encode()
|
||||
).hexdigest()[:16]
|
||||
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
|
||||
|
||||
redis = await get_redis_async()
|
||||
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"[STREAM] Duplicate user message blocked for session {session_id}, "
|
||||
f"hash={content_hash} — returning empty SSE",
|
||||
)
|
||||
return None
|
||||
|
||||
return _DedupLock(key, redis)
|
||||
@@ -1,94 +0,0 @@
|
||||
"""Unit tests for backend.copilot.message_dedup."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
|
||||
|
||||
|
||||
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=set_returns)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
return mock_redis
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_none_when_no_message_no_files(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Nothing to deduplicate — no Redis call made, None returned."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
result = await acquire_dedup_lock("sess-1", None, None)
|
||||
assert result is None
|
||||
mock_redis.set.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_lock_on_first_request(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""First request acquires the lock and returns a _DedupLock."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
mock_redis.set.assert_called_once()
|
||||
key_arg = mock_redis.set.call_args.args[0]
|
||||
assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_returns_none_on_duplicate(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Duplicate request (NX fails) returns None to signal the caller."""
|
||||
_patch_redis(mocker, set_returns=None)
|
||||
result = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_key_stable_across_file_order(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""File IDs are sorted before hashing so order doesn't affect the key."""
|
||||
mock_redis_1 = _patch_redis(mocker, set_returns=True)
|
||||
await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
|
||||
key_ab = mock_redis_1.set.call_args.args[0]
|
||||
|
||||
mock_redis_2 = _patch_redis(mocker, set_returns=True)
|
||||
await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
|
||||
key_ba = mock_redis_2.set.call_args.args[0]
|
||||
|
||||
assert key_ab == key_ba
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_deletes_key(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""release() calls Redis delete exactly once."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
await lock.release()
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_swallows_redis_error(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""release() must not raise even when Redis delete fails."""
|
||||
mock_redis = _patch_redis(mocker, set_returns=True)
|
||||
mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
|
||||
lock = await acquire_dedup_lock("sess-1", "hello", None)
|
||||
assert lock is not None
|
||||
await lock.release() # must not raise
|
||||
mock_redis.delete.assert_called_once()
|
||||
@@ -1,9 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Callable, Self, cast
|
||||
from weakref import WeakValueDictionary
|
||||
from typing import Any, AsyncIterator, Callable, Self, cast
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
@@ -522,10 +521,7 @@ async def upsert_chat_session(
|
||||
callers are aware of the persistence failure.
|
||||
RedisError: If the cache write fails (after successful DB write).
|
||||
"""
|
||||
# Acquire session-specific lock to prevent concurrent upserts
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
async with _get_session_lock(session.session_id) as _:
|
||||
# Always query DB for existing message count to ensure consistency
|
||||
existing_message_count = await chat_db().get_next_sequence(session.session_id)
|
||||
|
||||
@@ -651,20 +647,50 @@ async def _save_session_to_db(
|
||||
msg.sequence = existing_message_count + i
|
||||
|
||||
|
||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||
async def append_and_save_message(
|
||||
session_id: str, message: ChatMessage
|
||||
) -> ChatSession | None:
|
||||
"""Atomically append a message to a session and persist it.
|
||||
|
||||
Acquires the session lock, re-fetches the latest session state,
|
||||
appends the message, and saves — preventing message loss when
|
||||
concurrent requests modify the same session.
|
||||
"""
|
||||
lock = await _get_session_lock(session_id)
|
||||
Returns the updated session, or None if the message was detected as a
|
||||
duplicate (idempotency guard). Callers must check for None and skip any
|
||||
downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected.
|
||||
|
||||
async with lock:
|
||||
session = await get_chat_session(session_id)
|
||||
Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas.
|
||||
The idempotency check below provides a last-resort guard when the lock degrades.
|
||||
"""
|
||||
async with _get_session_lock(session_id) as lock_acquired:
|
||||
# When the lock degraded (Redis down or 2s timeout), bypass cache for
|
||||
# the idempotency check. Stale cache could let two concurrent writers
|
||||
# both see the old state, pass the check, and write the same message.
|
||||
if lock_acquired:
|
||||
session = await get_chat_session(session_id)
|
||||
else:
|
||||
session = await _get_session_from_db(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
# Idempotency: skip if the trailing block of same-role messages already
|
||||
# contains this content. Uses is_message_duplicate which checks all
|
||||
# consecutive trailing messages of the same role, not just [-1].
|
||||
#
|
||||
# This collapses infra/nginx retries whether they land on the same pod
|
||||
# (serialised by the Redis lock) or a different pod.
|
||||
#
|
||||
# Legit same-text messages are distinguished by the assistant turn
|
||||
# between them: if the user said "yes", got a response, and says
|
||||
# "yes" again, session.messages[-1] is the assistant reply, so the
|
||||
# role check fails and the second message goes through normally.
|
||||
#
|
||||
# Edge case: if a turn dies without writing any assistant message,
|
||||
# the user's next send of the same text is blocked here permanently.
|
||||
# The fix is to ensure failed turns always write an error/timeout
|
||||
# assistant message so the session always ends on an assistant turn.
|
||||
if message.content is not None and is_message_duplicate(
|
||||
session.messages, message.role, message.content
|
||||
):
|
||||
return None # duplicate — caller should skip enqueue
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||
|
||||
@@ -679,6 +705,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
||||
# Invalidate the stale entry so future reads fall back to DB,
|
||||
# preventing a retry from bypassing the idempotency check above.
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
return session
|
||||
|
||||
@@ -699,9 +728,7 @@ async def append_message_if(
|
||||
Returns the updated session on append, or ``None`` if the predicate
|
||||
rejected, the session no longer exists, or the append failed.
|
||||
"""
|
||||
lock = await _get_session_lock(session_id)
|
||||
|
||||
async with lock:
|
||||
async with _get_session_lock(session_id) as _lock_acquired:
|
||||
# Read from DB directly — the Redis cache can be stale because the
|
||||
# executor's upsert_chat_session overwrites it with in-memory copies
|
||||
# during streaming, which may not include messages appended by the
|
||||
@@ -815,10 +842,6 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
# Shut down any local browser daemon for this session (best-effort).
|
||||
# Inline import required: all tool modules import ChatSession from this
|
||||
# module, so any top-level import from tools.* would create a cycle.
|
||||
@@ -883,25 +906,38 @@ async def update_session_title(
|
||||
|
||||
# ==================== Chat session locks ==================== #
|
||||
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_session_lock(session_id: str) -> AsyncIterator[bool]:
|
||||
"""Distributed Redis lock for a session, usable as an async context manager.
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
Yields True if the lock was acquired, False if it timed out or Redis was
|
||||
unavailable. Callers should treat False as a degraded mode and prefer fresh
|
||||
DB reads over cache to avoid acting on stale state.
|
||||
|
||||
This was originally added to solve the specific problem of race conditions between
|
||||
the session title thread and the conversation thread, which always occurs on the
|
||||
same instance as we prevent rapid request sends on the frontend.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks. Explicit cleanup also occurs
|
||||
in `delete_chat_session()`.
|
||||
Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition
|
||||
is atomic and release is owner-verified. Blocks up to 2s for a concurrent
|
||||
writer to finish; the 10s TTL ensures a dead pod never holds the lock forever.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
_lock_key = f"copilot:session_lock:{session_id}"
|
||||
lock = None
|
||||
acquired = False
|
||||
try:
|
||||
_redis = await get_redis_async()
|
||||
lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2)
|
||||
acquired = await lock.acquire(blocking=True)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
"Could not acquire session lock for %s within 2s", session_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Redis unavailable for session lock on %s: %s", session_id, e)
|
||||
|
||||
try:
|
||||
yield acquired
|
||||
finally:
|
||||
if acquired and lock is not None:
|
||||
try:
|
||||
await lock.release()
|
||||
except Exception:
|
||||
pass # TTL will expire the key
|
||||
|
||||
@@ -11,11 +11,13 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
append_and_save_message,
|
||||
get_chat_session,
|
||||
is_message_duplicate,
|
||||
maybe_append_user_message,
|
||||
@@ -574,3 +576,345 @@ def test_maybe_append_assistant_skips_duplicate():
|
||||
result = maybe_append_user_message(session, "dup", is_user_message=False)
|
||||
assert result is False
|
||||
assert len(session.messages) == 2
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# append_and_save_message #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession:
    """Return a new non-dry-run session for user ``u1`` whose history is *msgs*."""
    session = ChatSession.new(user_id="u1", dry_run=False)
    # Copy the varargs tuple into a fresh list so the session owns its history.
    session.messages = [*msgs]
    return session
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_returns_none_for_duplicate(
    mocker: MockerFixture,
) -> None:
    """Appending a copy of the trailing message yields ``None``."""
    # The stored session already ends with the message we are about to send.
    stored = _make_session_with_messages(
        ChatMessage(role="user", content="hello"),
    )

    # Redis lock that is granted immediately.
    redis_lock = mocker.AsyncMock(
        acquire=mocker.AsyncMock(return_value=True),
        release=mocker.AsyncMock(),
    )
    redis_client = mocker.MagicMock(lock=mocker.MagicMock(return_value=redis_lock))
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=stored,
    )

    duplicate = ChatMessage(role="user", content="hello")
    outcome = await append_and_save_message(stored.session_id, duplicate)

    assert outcome is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_appends_new_message(
    mocker: MockerFixture,
) -> None:
    """A non-duplicate message is appended and the updated session is returned."""
    stored = _make_session_with_messages(
        ChatMessage(role="user", content="hello"),
        ChatMessage(role="assistant", content="hi"),
    )

    # Redis lock that is granted immediately.
    redis_lock = mocker.AsyncMock(
        acquire=mocker.AsyncMock(return_value=True),
        release=mocker.AsyncMock(),
    )
    redis_client = mocker.MagicMock(lock=mocker.MagicMock(return_value=redis_lock))
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=stored,
    )

    # Persistence and caching are stubbed out so no real DB/Redis is touched.
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    fake_db = mocker.MagicMock(get_next_sequence=mocker.AsyncMock(return_value=2))
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=fake_db,
    )
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
    )

    incoming = ChatMessage(role="user", content="second message")
    outcome = await append_and_save_message(stored.session_id, incoming)

    assert outcome is not None
    assert outcome.messages[-1].content == "second message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_when_session_not_found(
    mocker: MockerFixture,
) -> None:
    """A missing session makes append_and_save_message raise ``ValueError``."""
    # Redis lock that is granted immediately.
    redis_lock = mocker.AsyncMock(
        acquire=mocker.AsyncMock(return_value=True),
        release=mocker.AsyncMock(),
    )
    redis_client = mocker.MagicMock(lock=mocker.MagicMock(return_value=redis_lock))
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=redis_client,
    )
    # The session lookup finds nothing.
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=None,
    )

    orphan_message = ChatMessage(role="user", content="hi")
    with pytest.raises(ValueError, match="not found"):
        await append_and_save_message("missing-session-id", orphan_message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_lock_degraded(
    mocker: MockerFixture,
) -> None:
    """A lock that is not acquired (acquire() -> False) forces a read from the DB."""
    stored = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )

    # The lock exists but acquire() reports failure, i.e. degraded mode.
    redis_lock = mocker.AsyncMock(acquire=mocker.AsyncMock(return_value=False))
    redis_client = mocker.MagicMock(lock=mocker.MagicMock(return_value=redis_lock))
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=redis_client,
    )

    db_reader = mocker.patch(
        "backend.copilot.model._get_session_from_db",
        new_callable=mocker.AsyncMock,
        return_value=stored,
    )
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    fake_db = mocker.MagicMock(get_next_sequence=mocker.AsyncMock(return_value=1))
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=fake_db,
    )
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
    )

    incoming = ChatMessage(role="user", content="new msg")
    outcome = await append_and_save_message(stored.session_id, incoming)

    # The session was loaded straight from the DB rather than cache-first.
    db_reader.assert_called_once_with(stored.session_id)
    assert outcome is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_database_error_on_save_failure(
    mocker: MockerFixture,
) -> None:
    """A failing ``_save_session_to_db`` surfaces to the caller as ``DatabaseError``."""
    from backend.util.exceptions import DatabaseError

    stored = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )

    # Redis lock that is granted immediately.
    redis_lock = mocker.AsyncMock(
        acquire=mocker.AsyncMock(return_value=True),
        release=mocker.AsyncMock(),
    )
    redis_client = mocker.MagicMock(lock=mocker.MagicMock(return_value=redis_lock))
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=stored,
    )

    # The DB write blows up with a generic error.
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
        side_effect=RuntimeError("db down"),
    )
    fake_db = mocker.MagicMock(get_next_sequence=mocker.AsyncMock(return_value=1))
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=fake_db,
    )

    incoming = ChatMessage(role="user", content="new msg")
    with pytest.raises(DatabaseError):
        await append_and_save_message(stored.session_id, incoming)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_invalidates_cache_on_cache_failure(
    mocker: MockerFixture,
) -> None:
    """If ``cache_chat_session`` raises, the session cache entry is invalidated."""
    stored = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )

    # Redis lock that is granted immediately.
    redis_lock = mocker.AsyncMock(
        acquire=mocker.AsyncMock(return_value=True),
        release=mocker.AsyncMock(),
    )
    redis_client = mocker.MagicMock(lock=mocker.MagicMock(return_value=redis_lock))
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=stored,
    )

    # The DB write succeeds ...
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    fake_db = mocker.MagicMock(get_next_sequence=mocker.AsyncMock(return_value=1))
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=fake_db,
    )
    # ... but refreshing the cache afterwards fails.
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
        side_effect=RuntimeError("redis write failed"),
    )
    invalidate_spy = mocker.patch(
        "backend.copilot.model.invalidate_session_cache",
        new_callable=mocker.AsyncMock,
    )

    incoming = ChatMessage(role="user", content="new msg")
    outcome = await append_and_save_message(stored.session_id, incoming)

    # The call still returns a session, and the cache entry was invalidated.
    invalidate_spy.assert_called_once_with(stored.session_id)
    assert outcome is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_redis_unavailable(
    mocker: MockerFixture,
) -> None:
    """If ``get_redis_async`` raises, the lock is degraded and the session is read from the DB."""
    stored = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )

    # Obtaining the Redis client itself fails.
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        side_effect=ConnectionError("redis down"),
    )

    db_reader = mocker.patch(
        "backend.copilot.model._get_session_from_db",
        new_callable=mocker.AsyncMock,
        return_value=stored,
    )
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    fake_db = mocker.MagicMock(get_next_sequence=mocker.AsyncMock(return_value=1))
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=fake_db,
    )
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
    )

    incoming = ChatMessage(role="user", content="new msg")
    outcome = await append_and_save_message(stored.session_id, incoming)

    db_reader.assert_called_once_with(stored.session_id)
    assert outcome is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_lock_release_failure_is_ignored(
    mocker: MockerFixture,
) -> None:
    """If lock.release() raises, the exception is swallowed (TTL will clean up)."""
    session = _make_session_with_messages(
        ChatMessage(role="assistant", content="hi"),
    )

    # The lock is acquired successfully, but releasing it fails.
    mock_redis_lock = mocker.AsyncMock()
    mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
    mock_redis_lock.release = mocker.AsyncMock(
        side_effect=RuntimeError("release failed")
    )
    mock_redis_client = mocker.MagicMock()
    mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
    mocker.patch(
        "backend.copilot.model.get_redis_async",
        new_callable=mocker.AsyncMock,
        return_value=mock_redis_client,
    )
    mocker.patch(
        "backend.copilot.model.get_chat_session",
        new_callable=mocker.AsyncMock,
        return_value=session,
    )
    mocker.patch(
        "backend.copilot.model._save_session_to_db",
        new_callable=mocker.AsyncMock,
    )
    mocker.patch(
        "backend.copilot.model.chat_db",
        return_value=mocker.MagicMock(
            get_next_sequence=mocker.AsyncMock(return_value=1)
        ),
    )
    mocker.patch(
        "backend.copilot.model.cache_chat_session",
        new_callable=mocker.AsyncMock,
    )

    new_msg = ChatMessage(role="user", content="new msg")
    result = await append_and_save_message(session.session_id, new_msg)

    # Without this check the test would also pass if release() were never
    # attempted, which would not exercise the "failure is swallowed" path.
    mock_redis_lock.release.assert_awaited()
    assert result is not None
|
||||
|
||||
@@ -174,6 +174,7 @@ sandbox so `bash_exec` can access it for further processing.
|
||||
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
|
||||
|
||||
### GitHub CLI (`gh`) and git
|
||||
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
|
||||
@@ -8,7 +8,7 @@ Cross-mode transcript flow
|
||||
==========================
|
||||
|
||||
Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking
|
||||
mode) read and write the same JSONL transcript store via
|
||||
mode) read and write the same CLI session store via
|
||||
``backend.copilot.transcript.upload_transcript`` /
|
||||
``download_transcript``.
|
||||
|
||||
@@ -250,8 +250,9 @@ class TestSdkToFastModeSwitch:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_s_baseline_loads_sdk_transcript(self):
|
||||
"""Scenario S: SDK-written transcript is accepted by baseline's load helper."""
|
||||
"""Scenario S: SDK-written CLI session is accepted by baseline's load helper."""
|
||||
from backend.copilot.baseline.service import _load_prior_transcript
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
@@ -267,33 +268,41 @@ class TestSdkToFastModeSwitch:
|
||||
sdk_transcript = builder_sdk.to_jsonl()
|
||||
|
||||
# Baseline session now has those 2 SDK messages + 1 new baseline message.
|
||||
download = TranscriptDownload(content=sdk_transcript, message_count=2)
|
||||
restore = TranscriptDownload(
|
||||
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
baseline_builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3, # 2 SDK + 1 new baseline
|
||||
session_messages=[
|
||||
ChatMessage(role="user", content="sdk-question"),
|
||||
ChatMessage(role="assistant", content="sdk-answer"),
|
||||
ChatMessage(role="user", content="baseline-question"),
|
||||
],
|
||||
transcript_builder=baseline_builder,
|
||||
)
|
||||
|
||||
# Transcript is valid and covers the prefix.
|
||||
# CLI session is valid and covers the prefix.
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
assert baseline_builder.entry_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_s_stale_sdk_transcript_not_loaded(self):
|
||||
"""Scenario S (stale): SDK transcript is stale — baseline does not load it.
|
||||
"""Scenario S (stale): SDK CLI session is stale — baseline does not load it.
|
||||
|
||||
If SDK mode produced more turns than the transcript captured (e.g.
|
||||
upload failed on one turn), the baseline rejects the stale transcript
|
||||
If SDK mode produced more turns than the session captured (e.g.
|
||||
upload failed on one turn), the baseline rejects the stale session
|
||||
to avoid injecting an incomplete history.
|
||||
"""
|
||||
from backend.copilot.baseline.service import _load_prior_transcript
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
@@ -306,21 +315,33 @@ class TestSdkToFastModeSwitch:
|
||||
)
|
||||
sdk_transcript = builder_sdk.to_jsonl()
|
||||
|
||||
# Transcript covers only 2 messages but session has 10 (many SDK turns).
|
||||
download = TranscriptDownload(content=sdk_transcript, message_count=2)
|
||||
# Session covers only 2 messages but session has 10 (many SDK turns).
|
||||
# With watermark=2 and 10 total messages, detect_gap will fill the gap
|
||||
# by appending messages 2..8 (positions 2 to total-2).
|
||||
restore = TranscriptDownload(
|
||||
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
# Build a session with 10 alternating user/assistant messages + current user
|
||||
session_messages = [
|
||||
ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}")
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
baseline_builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
covers, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=10,
|
||||
session_messages=session_messages,
|
||||
transcript_builder=baseline_builder,
|
||||
)
|
||||
|
||||
# Stale transcript must be rejected.
|
||||
assert covers is False
|
||||
assert baseline_builder.is_empty
|
||||
# With gap filling, covers is True and gap messages are appended.
|
||||
assert covers is True
|
||||
assert dl is not None
|
||||
# 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn)
|
||||
assert baseline_builder.entry_count == 9
|
||||
|
||||
@@ -27,6 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.copilot.transcript import (
|
||||
TranscriptDownload,
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
@@ -999,14 +1000,15 @@ def _make_sdk_patches(
|
||||
f"{_SVC}.download_transcript",
|
||||
dict(
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(content=original_transcript, message_count=2),
|
||||
return_value=TranscriptDownload(
|
||||
content=original_transcript.encode("utf-8"),
|
||||
message_count=2,
|
||||
mode="sdk",
|
||||
),
|
||||
),
|
||||
),
|
||||
(
|
||||
f"{_SVC}.restore_cli_session",
|
||||
dict(new_callable=AsyncMock, return_value=True),
|
||||
),
|
||||
(f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)),
|
||||
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.validate_transcript", dict(return_value=True)),
|
||||
(
|
||||
f"{_SVC}.compact_transcript",
|
||||
@@ -1037,7 +1039,6 @@ def _make_sdk_patches(
|
||||
claude_agent_fallback_model=None,
|
||||
),
|
||||
),
|
||||
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
|
||||
]
|
||||
|
||||
@@ -1914,14 +1915,14 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
compacted_transcript=None,
|
||||
client_side_effect=_client_factory,
|
||||
)
|
||||
# Override restore_cli_session to return False (CLI native session unavailable)
|
||||
# Override download_transcript to return None (CLI native session unavailable)
|
||||
patches = [
|
||||
(
|
||||
(
|
||||
f"{_SVC}.restore_cli_session",
|
||||
dict(new_callable=AsyncMock, return_value=False),
|
||||
f"{_SVC}.download_transcript",
|
||||
dict(new_callable=AsyncMock, return_value=None),
|
||||
)
|
||||
if p[0] == f"{_SVC}.restore_cli_session"
|
||||
if p[0] == f"{_SVC}.download_transcript"
|
||||
else p
|
||||
)
|
||||
for p in patches
|
||||
@@ -1944,7 +1945,7 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
# captured_options holds {"options": ClaudeAgentOptions}, so check
|
||||
# the attribute directly rather than dict keys.
|
||||
assert not getattr(captured_options.get("options"), "resume", None), (
|
||||
f"--resume was set even though restore_cli_session returned False: "
|
||||
f"--resume was set even though download_transcript returned None: "
|
||||
f"{captured_options}"
|
||||
)
|
||||
assert any(isinstance(e, StreamStart) for e in events)
|
||||
|
||||
@@ -365,7 +365,7 @@ def create_security_hooks(
|
||||
trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
|
||||
# Sanitize untrusted input: strip control chars for logging AND
|
||||
# for the value passed downstream. read_compacted_entries()
|
||||
# validates against _projects_base() as defence-in-depth, but
|
||||
# validates against projects_base() as defence-in-depth, but
|
||||
# sanitizing here prevents log injection and rejects obviously
|
||||
# malformed paths early.
|
||||
transcript_path = _sanitize(
|
||||
|
||||
@@ -16,6 +16,7 @@ import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field as dataclass_field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, NotRequired, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -92,12 +93,15 @@ from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tracking import track_user_message
|
||||
from ..transcript import (
|
||||
_run_compression,
|
||||
TranscriptDownload,
|
||||
cleanup_stale_project_dirs,
|
||||
cli_session_path,
|
||||
compact_transcript,
|
||||
download_transcript,
|
||||
extract_context_messages,
|
||||
projects_base,
|
||||
read_compacted_entries,
|
||||
restore_cli_session,
|
||||
upload_cli_session,
|
||||
strip_for_upload,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
@@ -121,7 +125,12 @@ config = ChatConfig()
|
||||
|
||||
|
||||
class _SystemPromptPreset(SystemPromptPreset, total=False):
    """Extends :class:`SystemPromptPreset` with ``exclude_dynamic_sections``.

    The field was added to the upstream TypedDict in claude-agent-sdk 0.1.59.
    Until the package is pinned to that version we declare it locally so Pyright
    accepts the kwarg without a ``# type: ignore`` comment.
    """

    # Optional key; ``total=False`` plus ``NotRequired`` means callers may omit it.
    exclude_dynamic_sections: NotRequired[bool]
|
||||
|
||||
@@ -849,6 +858,181 @@ def _make_sdk_cwd(session_id: str) -> str:
|
||||
return cwd
|
||||
|
||||
|
||||
def _write_cli_session_to_disk(
    content: bytes,
    sdk_cwd: str,
    session_id: str,
    log_prefix: str,
) -> bool:
    """Persist downloaded CLI session bytes at the path used for ``--resume``.

    The destination comes from ``cli_session_path(sdk_cwd, session_id)`` and is
    resolved with ``os.path.realpath``. As a path-traversal guard, nothing is
    written unless the resolved path lies underneath ``projects_base()``.

    Args:
        content: Session bytes to write.
        sdk_cwd: SDK working directory used to derive the session file path.
        session_id: Session identifier used to derive the session file path.
        log_prefix: Prefix prepended to every log line emitted here.

    Returns:
        True when the file was written; False when the resolved path is outside
        the projects base or an ``OSError`` occurred while creating the parent
        directory or writing the file.
    """
    session_file = cli_session_path(sdk_cwd, session_id)
    target = os.path.realpath(session_file)
    allowed_prefix = projects_base() + os.sep

    # Guard clause: refuse anything that resolves outside the projects base.
    if not target.startswith(allowed_prefix):
        logger.warning(
            "%s CLI session restore path outside projects base: %s",
            log_prefix,
            os.path.basename(session_file),
        )
        return False

    try:
        os.makedirs(os.path.dirname(target), exist_ok=True)
        Path(target).write_bytes(content)
    except OSError as e:
        logger.warning(
            "%s Failed to write CLI session file %s: %s",
            log_prefix,
            os.path.basename(session_file),
            e.strerror or str(e),
        )
        return False

    logger.info(
        "%s Wrote CLI session to disk (%dB) for --resume",
        log_prefix,
        len(content),
    )
    return True
|
||||
|
||||
|
||||
def read_cli_session_from_disk(
    sdk_cwd: str,
    session_id: str,
    log_prefix: str,
) -> bytes | None:
    """Load the CLI session JSONL from disk and strip it for upload.

    The file location comes from ``cli_session_path(sdk_cwd, session_id)`` and is
    resolved with ``os.path.realpath``. As a path-traversal guard, the file is
    only read when the resolved path lies underneath ``projects_base()``.

    The content is decoded as UTF-8 and passed through ``strip_for_upload``.
    When the stripped form is smaller than the original, it is also written
    back to the same file (best effort).

    Args:
        sdk_cwd: SDK working directory used to derive the session file path.
        session_id: Session identifier used to derive the session file path.
        log_prefix: Prefix prepended to every log line emitted here.

    Returns:
        ``None`` when the path is outside the projects base, the file is
        missing, or it cannot be read. The raw bytes when decoding or stripping
        fails. Otherwise the stripped bytes.
    """
    session_file = cli_session_path(sdk_cwd, session_id)
    real_path = os.path.realpath(session_file)
    allowed_prefix = projects_base() + os.sep
    if not real_path.startswith(allowed_prefix):
        logger.warning(
            "%s CLI session file outside projects base, skipping upload: %s",
            log_prefix,
            os.path.basename(real_path),
        )
        return None

    disk_file = Path(real_path)
    try:
        original = disk_file.read_bytes()
    except FileNotFoundError:
        logger.debug(
            "%s CLI session file not found, skipping upload: %s",
            log_prefix,
            os.path.basename(session_file),
        )
        return None
    except OSError as e:
        logger.warning(
            "%s Failed to read CLI session file %s: %s",
            log_prefix,
            os.path.basename(session_file),
            e.strerror or str(e),
        )
        return None

    # Strip stale thinking blocks and metadata entries before uploading.
    # Thinking blocks from non-last turns can be massive; keeping them causes
    # the CLI to auto-compact its session when the context window fills up,
    # silently losing conversation history.
    try:
        slimmed = strip_for_upload(original.decode("utf-8")).encode("utf-8")
    except UnicodeDecodeError:
        logger.warning("%s CLI session is not valid UTF-8, uploading raw", log_prefix)
        return original
    except (OSError, ValueError) as e:
        # Only these two families fall back to the raw bytes; any other
        # exception type is not caught here and reaches the caller.
        logger.warning(
            "%s Failed to strip CLI session, uploading raw: %s", log_prefix, e
        )
        return original

    # Nothing was removed, so there is no point rewriting the local file.
    if len(slimmed) >= len(original):
        return slimmed

    # Write back locally so same-pod turns also benefit.
    try:
        disk_file.write_bytes(slimmed)
    except OSError as e:
        # The stripped content is still returned even though the local
        # write-back failed; only the same-pod optimization is lost.
        logger.warning(
            "%s Failed to write back stripped CLI session: %s",
            log_prefix,
            e.strerror or str(e),
        )
    else:
        logger.info(
            "%s Stripped CLI session: %dB → %dB",
            log_prefix,
            len(original),
            len(slimmed),
        )
    return slimmed
|
||||
|
||||
|
||||
def process_cli_restore(
    cli_restore: TranscriptDownload,
    sdk_cwd: str,
    session_id: str,
    log_prefix: str,
) -> tuple[str, bool]:
    """Validate and write a restored CLI session to disk.

    Decodes bytes → UTF-8, runs ``strip_for_upload`` on the text, validates the
    result with ``validate_transcript``, then writes the stripped content to
    disk so the CLI can ``--resume`` from it.

    Args:
        cli_restore: Downloaded session; ``content`` may be ``bytes`` or ``str``.
        sdk_cwd: SDK working directory used to derive the session file path.
        session_id: Session identifier used to derive the session file path.
        log_prefix: Prefix prepended to every log line emitted here.

    Returns:
        ``(stripped_content, success)``. ``success=False`` (with an empty
        string) means the content was not valid UTF-8, could not be stripped,
        failed validation, or could not be written to disk — the caller should
        skip ``--resume``.
    """
    raw_content = cli_restore.content
    try:
        raw_str = (
            raw_content.decode("utf-8")
            if isinstance(raw_content, bytes)
            else raw_content
        )
    except UnicodeDecodeError:
        logger.warning(
            "%s CLI session content is not valid UTF-8, skipping", log_prefix
        )
        return "", False

    # read_cli_session_from_disk treats ValueError from strip_for_upload as an
    # expected failure on malformed JSONL. Handle it the same way here so bad
    # downloaded content degrades to "no --resume" instead of raising.
    try:
        stripped = strip_for_upload(raw_str)
        is_valid = validate_transcript(stripped)
    except ValueError as e:
        logger.warning(
            "%s Failed to strip CLI session — running without --resume: %s",
            log_prefix,
            e,
        )
        return "", False

    # Sizes are measured on raw_str (always str here), so the unit is characters
    # regardless of whether the download carried bytes or str.
    # lines_stripped = original lines minus remaining lines after stripping.
    raw_trimmed = raw_str.strip()
    stripped_trimmed = stripped.strip()
    original_lines = len(raw_trimmed.split("\n")) if raw_trimmed else 0
    remaining_lines = len(stripped_trimmed.split("\n")) if stripped_trimmed else 0
    logger.info(
        "%s Restored CLI session: %dB raw, %d lines stripped, msg_count=%d, valid=%s",
        log_prefix,
        len(raw_str),
        original_lines - remaining_lines,
        cli_restore.message_count,
        is_valid,
    )
    if not is_valid:
        logger.warning(
            "%s CLI session content invalid after strip — running without --resume",
            log_prefix,
        )
        return "", False

    stripped_bytes = stripped.encode("utf-8")
    if not _write_cli_session_to_disk(stripped_bytes, sdk_cwd, session_id, log_prefix):
        return "", False

    return stripped, True
|
||||
|
||||
|
||||
async def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
"""Remove SDK session artifacts for a specific working directory.
|
||||
|
||||
@@ -922,8 +1106,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
result.append(block)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Unknown content block type: {type(block).__name__}. "
|
||||
f"This may indicate a new SDK version with additional block types."
|
||||
"[SDK] Unknown content block type: %s."
|
||||
" This may indicate a new SDK version with additional block types.",
|
||||
type(block).__name__,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -978,10 +1163,11 @@ async def _compress_messages(
|
||||
|
||||
if result.was_compacted:
|
||||
logger.info(
|
||||
f"[SDK] Context compacted: {result.original_token_count} -> "
|
||||
f"{result.token_count} tokens "
|
||||
f"({result.messages_summarized} summarized, "
|
||||
f"{result.messages_dropped} dropped)"
|
||||
"[SDK] Context compacted: %d -> %d tokens (%d summarized, %d dropped)",
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
# Convert compressed dicts back to ChatMessages
|
||||
return [
|
||||
@@ -1048,11 +1234,17 @@ def _session_messages_to_transcript(messages: list[ChatMessage]) -> str:
|
||||
)
|
||||
if blocks:
|
||||
builder.append_assistant(blocks)
|
||||
elif msg.role == "tool" and msg.tool_call_id:
|
||||
builder.append_tool_result(
|
||||
tool_use_id=msg.tool_call_id,
|
||||
content=msg.content or "",
|
||||
)
|
||||
elif msg.role == "tool":
|
||||
if msg.tool_call_id:
|
||||
builder.append_tool_result(
|
||||
tool_use_id=msg.tool_call_id,
|
||||
content=msg.content or "",
|
||||
)
|
||||
else:
|
||||
# Malformed tool message — no tool_call_id to link to an
|
||||
# assistant tool_use block. Skip to avoid an unmatched
|
||||
# tool_result entry in the builder (which would confuse --resume).
|
||||
logger.warning("[SDK] Skipping tool gap message with no tool_call_id")
|
||||
return builder.to_jsonl()
|
||||
|
||||
|
||||
@@ -1098,6 +1290,7 @@ async def _build_query_message(
|
||||
transcript_msg_count: int,
|
||||
session_id: str,
|
||||
target_tokens: int | None = None,
|
||||
prior_messages: "list[ChatMessage] | None" = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Build the query message with appropriate context.
|
||||
|
||||
@@ -1203,15 +1396,16 @@ async def _build_query_message(
|
||||
)
|
||||
return current_message, False
|
||||
|
||||
source = prior_messages if prior_messages is not None else prior
|
||||
logger.warning(
|
||||
"[SDK] [%s] No --resume for %d-message session — compressing"
|
||||
" full session history (pod affinity issue or first turn after"
|
||||
" restore failure); target_tokens=%s",
|
||||
"[SDK] [%s] No --resume for %d-message session — compressing context "
|
||||
"(source=%s, target_tokens=%s)",
|
||||
session_id[:8],
|
||||
msg_count,
|
||||
"transcript+gap" if prior_messages is not None else "full-db",
|
||||
target_tokens,
|
||||
)
|
||||
compressed, was_compressed = await _compress_messages(prior, target_tokens)
|
||||
compressed, was_compressed = await _compress_messages(source, target_tokens)
|
||||
history_context = _format_conversation_context(compressed)
|
||||
if history_context:
|
||||
logger.info(
|
||||
@@ -1228,7 +1422,7 @@ async def _build_query_message(
|
||||
"[SDK] [%s] Fallback context empty after compression"
|
||||
" (%d messages) — sending message without history",
|
||||
session_id[:8],
|
||||
len(prior),
|
||||
len(source),
|
||||
)
|
||||
|
||||
return current_message, False
|
||||
@@ -2233,6 +2427,161 @@ async def _seed_transcript(
|
||||
return _seeded, True, len(_prior)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RestoreResult:
|
||||
"""Return value from ``_restore_cli_session_for_turn``."""
|
||||
|
||||
transcript_content: str = ""
|
||||
transcript_covers_prefix: bool = True
|
||||
use_resume: bool = False
|
||||
resume_file: str | None = None
|
||||
transcript_msg_count: int = 0
|
||||
baseline_download: "TranscriptDownload | None" = None
|
||||
context_messages: "list[ChatMessage] | None" = None
|
||||
|
||||
|
||||
async def _restore_cli_session_for_turn(
    user_id: str | None,
    session_id: str,
    session: "ChatSession",
    sdk_cwd: str,
    transcript_builder: "TranscriptBuilder",
    log_prefix: str,
) -> _RestoreResult:
    """Prepare transcript state for this turn, preferring CLI ``--resume``.

    One ``download_transcript`` call fetches the stored transcript together
    with its ``message_count`` watermark and ``mode``.  Three outcomes:

    * SDK-written transcript that ``process_cli_restore`` accepts: it is loaded
      into ``transcript_builder`` and the result asks for ``--resume``.
    * Non-SDK ("baseline") transcript: never resumed; context messages are
      extracted from the transcript content plus any gap, and the content is
      loaded into the builder if it validates.
    * Nothing usable (download failed/empty, or restore rejected): context
      messages come from ``extract_context_messages`` with no download.

    Args:
        user_id: Owner of the session; falsy disables restore entirely.
        session_id: Session identifier, also used as the resume target.
        session: Chat session whose ``messages`` supply fallback context.
        sdk_cwd: SDK working directory; when falsy no disk work is attempted.
        transcript_builder: Builder that receives the previous transcript.
        log_prefix: Prefix prepended to every log line.

    Returns:
        A ``_RestoreResult`` the caller copies into its local variables.
    """
    outcome = _RestoreResult()

    resume_applicable = (
        config.claude_agent_use_resume and user_id and len(session.messages) > 1
    )
    if not resume_applicable:
        return outcome

    # Best-effort download: any failure just means "no --resume this turn".
    try:
        sdk_download = await download_transcript(
            user_id, session_id, log_prefix=log_prefix
        )
    except Exception as download_err:
        logger.warning(
            "%s CLI session restore failed, continuing without --resume: %s",
            log_prefix,
            download_err,
        )
        sdk_download = None

    # Only SDK-written transcripts are candidates for --resume.  Transcripts
    # written by another mode are kept aside purely as a context source.
    if sdk_download is not None and sdk_download.mode != "sdk":
        logger.info(
            "%s Transcript written by mode=%r — skipping --resume, "
            "will use transcript content + gap for context",
            log_prefix,
            sdk_download.mode,
        )
        outcome.baseline_download = sdk_download
        sdk_download = None

    # Validation, stripping and the disk write are delegated to
    # process_cli_restore; a rejected transcript must not be resumed.
    resume_jsonl = ""
    if sdk_download is not None and sdk_cwd:
        resume_jsonl, restored_ok = process_cli_restore(
            sdk_download, sdk_cwd, session_id, log_prefix
        )
        if not restored_ok:
            # NOTE(review): this False is always overwritten with True by the
            # fallback section below — confirm which value is intended.
            outcome.transcript_covers_prefix = False
            sdk_download = None

    if sdk_download is not None:
        # NOTE(review): when sdk_cwd is falsy this branch is still taken with
        # an empty transcript — confirm sdk_cwd can never be empty here.
        outcome.transcript_content = resume_jsonl
        transcript_builder.load_previous(resume_jsonl, log_prefix=log_prefix)
        outcome.use_resume = True
        outcome.resume_file = session_id
        outcome.transcript_msg_count = sdk_download.message_count
        return outcome

    # --- No --resume from here on ---------------------------------------

    if sdk_cwd:
        # Remove any leftover local session file so the CLI can create a fresh
        # session under this session_id instead of rejecting it as in use.
        stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id))
        stale_file = Path(stale_path)
        # Only delete files that resolve inside the projects base directory.
        if stale_file.exists() and stale_path.startswith(projects_base() + os.sep):
            try:
                stale_file.unlink()
                logger.debug(
                    "%s Removed stale local CLI session file for clean fallback",
                    log_prefix,
                )
            except OSError as unlink_err:
                logger.debug(
                    "%s Failed to remove stale local session file: %s",
                    log_prefix,
                    unlink_err,
                )

    # Build the context to inject: the baseline download (if any) is passed
    # along with the session messages; with None only the messages are used.
    baseline = outcome.baseline_download
    history = extract_context_messages(baseline, session.messages)
    outcome.context_messages = history

    if baseline is not None and baseline.message_count > 0:
        outcome.transcript_msg_count = baseline.message_count
    else:
        # Everything before the current (last) message.
        outcome.transcript_msg_count = len(session.messages) - 1
    outcome.transcript_covers_prefix = True

    context_source = (
        "baseline transcript + gap" if baseline is not None else "DB fallback"
    )
    logger.info(
        "%s Context built from %s: %d messages (transcript watermark=%d, "
        "will inject as <conversation_history>)",
        log_prefix,
        context_source,
        len(history),
        outcome.transcript_msg_count,
    )

    if baseline is None:
        return outcome

    # Load the baseline transcript into the builder and publish it as
    # transcript_content so the caller sees non-empty content (the original
    # comment notes this stops a later re-seed from duplicating entries).
    try:
        raw = baseline.content
        text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        candidate = strip_for_upload(text)
        if validate_transcript(candidate):
            transcript_builder.load_previous(candidate, log_prefix=log_prefix)
            outcome.transcript_content = candidate
    except (UnicodeDecodeError, ValueError, OSError) as load_err:
        # Only decoding/parsing/I-O style failures are tolerated; anything
        # else propagates so programming errors are not hidden.
        logger.debug(
            "%s Could not load baseline transcript into builder: %s",
            log_prefix,
            load_err,
        )

    return outcome
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -2427,28 +2776,9 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
return sandbox
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
if not (
|
||||
config.claude_agent_use_resume and user_id and len(session.messages) > 1
|
||||
):
|
||||
return None
|
||||
try:
|
||||
return await download_transcript(
|
||||
user_id, session_id, log_prefix=log_prefix
|
||||
)
|
||||
except Exception as transcript_err:
|
||||
logger.warning(
|
||||
"%s Transcript download failed, continuing without --resume: %s",
|
||||
log_prefix,
|
||||
transcript_err,
|
||||
)
|
||||
return None
|
||||
|
||||
e2b_sandbox, (base_system_prompt, understanding), dl = await asyncio.gather(
|
||||
e2b_sandbox, (base_system_prompt, understanding) = await asyncio.gather(
|
||||
_setup_e2b(),
|
||||
_build_system_prompt(user_id if not has_history else None),
|
||||
_fetch_transcript(),
|
||||
)
|
||||
|
||||
use_e2b = e2b_sandbox is not None
|
||||
@@ -2473,95 +2803,17 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
warm_ctx = await fetch_warm_context(user_id, message or "") or ""
|
||||
|
||||
# Process transcript download result and restore CLI native session.
|
||||
# The CLI native session file (uploaded after each turn) is the
|
||||
# source of truth for --resume. Our custom JSONL (TranscriptEntry)
|
||||
# is loaded into the builder for future upload_transcript calls.
|
||||
transcript_msg_count = 0
|
||||
if dl:
|
||||
is_valid = validate_transcript(dl.content)
|
||||
dl_lines = dl.content.strip().split("\n") if dl.content else []
|
||||
logger.info(
|
||||
"%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s",
|
||||
log_prefix,
|
||||
len(dl.content),
|
||||
len(dl_lines),
|
||||
dl.message_count,
|
||||
is_valid,
|
||||
)
|
||||
if is_valid:
|
||||
# Load previous FULL context into builder for state tracking.
|
||||
transcript_content = dl.content
|
||||
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
|
||||
# Restore CLI's native session file so --resume session_id works.
|
||||
# Falls back gracefully if not available (first turn or upload missed).
|
||||
# user_id is guaranteed non-None here: _fetch_transcript only sets dl
|
||||
# when `config.claude_agent_use_resume and user_id` is truthy.
|
||||
cli_restored = user_id is not None and await restore_cli_session(
|
||||
user_id, session_id, sdk_cwd, log_prefix=log_prefix
|
||||
)
|
||||
if cli_restored:
|
||||
use_resume = True
|
||||
resume_file = session_id # CLI --resume expects UUID, not file path
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.info(
|
||||
"%s Using --resume %s (%dB transcript, msg_count=%d)",
|
||||
log_prefix,
|
||||
session_id[:8],
|
||||
len(dl.content),
|
||||
transcript_msg_count,
|
||||
)
|
||||
else:
|
||||
# Builder loaded but CLI native session not available.
|
||||
# --resume will not be used this turn; upload after turn
|
||||
# will seed the native session for the next turn.
|
||||
#
|
||||
# Still record transcript_msg_count so _build_query_message
|
||||
# can use the transcript-aware gap path (inject only new
|
||||
# messages since the transcript end) instead of compressing
|
||||
# the full DB history. This avoids prompt-too-long on
|
||||
# large sessions where the CLI session is temporarily
|
||||
# unavailable (e.g. mixed-version rolling deployment).
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.info(
|
||||
"%s CLI session not restored — running without"
|
||||
" --resume this turn (transcript_msg_count=%d for"
|
||||
" gap-aware fallback)",
|
||||
log_prefix,
|
||||
transcript_msg_count,
|
||||
)
|
||||
else:
|
||||
logger.warning("%s Transcript downloaded but invalid", log_prefix)
|
||||
transcript_covers_prefix = False
|
||||
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
# No transcript in storage — reconstruct from DB messages as a
|
||||
# last-resort fallback (e.g., first turn after a crash or transition).
|
||||
# This path loses tool call IDs and structural fidelity but prevents
|
||||
# a completely context-free response for established sessions.
|
||||
prior = session.messages[:-1]
|
||||
reconstructed = _session_messages_to_transcript(prior)
|
||||
if reconstructed:
|
||||
# Populate builder only; no --resume since there is no CLI
|
||||
# native session to restore. The transcript builder state is
|
||||
# still useful for the upload that seeds future native sessions.
|
||||
transcript_content = reconstructed
|
||||
transcript_builder.load_previous(reconstructed, log_prefix=log_prefix)
|
||||
transcript_msg_count = len(prior)
|
||||
transcript_covers_prefix = True
|
||||
logger.info(
|
||||
"%s Reconstructed transcript from %d session messages "
|
||||
"(no CLI native session — running without --resume this turn)",
|
||||
log_prefix,
|
||||
len(prior),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s No transcript available and reconstruction produced empty"
|
||||
" output (%d messages in session)",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
)
|
||||
transcript_covers_prefix = False
|
||||
# Restore CLI session — single GCS round-trip covers both --resume and builder state.
|
||||
# message_count watermark lives in the companion .meta.json alongside the session file.
|
||||
_restore = await _restore_cli_session_for_turn(
|
||||
user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix
|
||||
)
|
||||
transcript_content = _restore.transcript_content
|
||||
transcript_covers_prefix = _restore.transcript_covers_prefix
|
||||
use_resume = _restore.use_resume
|
||||
resume_file = _restore.resume_file
|
||||
transcript_msg_count = _restore.transcript_msg_count
|
||||
restore_context_messages = _restore.context_messages
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
@@ -2680,14 +2932,14 @@ async def stream_chat_completion_sdk(
|
||||
else:
|
||||
# Set session_id whenever NOT resuming so the CLI writes the
|
||||
# native session file to a predictable path for
|
||||
# upload_cli_session() after the turn. This covers:
|
||||
# upload_transcript() after the turn. This covers:
|
||||
# • T1 fresh: no prior history, first SDK turn.
|
||||
# • Mode-switch T1: has_history=True (prior baseline turns in
|
||||
# DB) but no CLI session file was ever uploaded — the CLI has
|
||||
# never been invoked with this session_id before.
|
||||
# • T2+ without --resume (restore failed): no session file was
|
||||
# restored to local storage (restore_cli_session returned
|
||||
# False), so no conflict with an existing file.
|
||||
# restored to local storage (download_transcript returned
|
||||
# None), so no conflict with an existing file.
|
||||
# When --resume is active the session_id is already implied by
|
||||
# the resume file; passing it again would be rejected by the CLI.
|
||||
sdk_options_kwargs["session_id"] = session_id
|
||||
@@ -2780,6 +3032,7 @@ async def stream_chat_completion_sdk(
|
||||
use_resume,
|
||||
transcript_msg_count,
|
||||
session_id,
|
||||
prior_messages=restore_context_messages,
|
||||
)
|
||||
# If files are attached, prepare them: images become vision
|
||||
# content blocks in the user message, other files go to sdk_cwd.
|
||||
@@ -2909,7 +3162,7 @@ async def stream_chat_completion_sdk(
|
||||
elif "session_id" in sdk_options_kwargs:
|
||||
# Initial invocation used session_id (T1 or mode-switch
|
||||
# T1): keep it so the CLI writes the session file to the
|
||||
# predictable path for upload_cli_session(). Storage is
|
||||
# predictable path for upload_transcript(). Storage is
|
||||
# ephemeral per invocation, so no "Session ID already in
|
||||
# use" conflict occurs — no prior file was restored.
|
||||
sdk_options_kwargs_retry.pop("resume", None)
|
||||
@@ -2932,6 +3185,10 @@ async def stream_chat_completion_sdk(
|
||||
system_prompt, cross_user_cache=_cross_user_retry
|
||||
)
|
||||
state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs
|
||||
# Retry intentionally omits prior_messages (transcript+gap context) and
|
||||
# falls back to full session.messages[:-1] from DB — the authoritative
|
||||
# source. transcript+gap is an optimisation for the first attempt only;
|
||||
# on retry the extra overhead of full-DB context is acceptable.
|
||||
state.query_message, state.was_compacted = await _build_query_message(
|
||||
current_message,
|
||||
session,
|
||||
@@ -3367,86 +3624,23 @@ async def stream_chat_completion_sdk(
|
||||
_background_tasks.add(_ingest_task)
|
||||
_ingest_task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# TranscriptBuilder is the single source of truth. It mirrors the
|
||||
# CLI's active context: on compaction, replace_entries() syncs it
|
||||
# with the compacted session file. No CLI file read needed here.
|
||||
if skip_transcript_upload:
|
||||
logger.warning(
|
||||
"%s Skipping transcript upload — transcript was dropped "
|
||||
"during prompt-too-long recovery",
|
||||
log_prefix,
|
||||
)
|
||||
elif (
|
||||
config.claude_agent_use_resume
|
||||
and user_id
|
||||
and session is not None
|
||||
and state is not None
|
||||
):
|
||||
try:
|
||||
transcript_upload_content = state.transcript_builder.to_jsonl()
|
||||
entry_count = state.transcript_builder.entry_count
|
||||
|
||||
if not transcript_upload_content:
|
||||
logger.warning(
|
||||
"%s No transcript to upload (builder empty)", log_prefix
|
||||
)
|
||||
elif not validate_transcript(transcript_upload_content):
|
||||
logger.warning(
|
||||
"%s Transcript invalid, skipping upload (entries=%d)",
|
||||
log_prefix,
|
||||
entry_count,
|
||||
)
|
||||
elif not transcript_covers_prefix:
|
||||
logger.warning(
|
||||
"%s Skipping transcript upload — builder does not "
|
||||
"cover full session prefix (entries=%d, session=%d)",
|
||||
log_prefix,
|
||||
entry_count,
|
||||
len(session.messages),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s Uploading transcript (entries=%d, bytes=%d)",
|
||||
log_prefix,
|
||||
entry_count,
|
||||
len(transcript_upload_content),
|
||||
)
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=transcript_upload_content,
|
||||
message_count=len(session.messages),
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
)
|
||||
except Exception as upload_err:
|
||||
logger.error(
|
||||
"%s Transcript upload failed in finally: %s",
|
||||
log_prefix,
|
||||
upload_err,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# --- Upload CLI native session file for cross-pod --resume ---
|
||||
# The CLI writes its native session JSONL after each turn completes.
|
||||
# Uploading it here enables --resume on any pod (no pod affinity needed).
|
||||
# Runs after upload_transcript so both are available for the next turn.
|
||||
# asyncio.shield: same pattern as upload_transcript above — if the
|
||||
# outer finally-block coroutine is cancelled while awaiting shield,
|
||||
# the CancelledError propagates (BaseException, not caught by
|
||||
# `except Exception`) letting the caller handle cancellation, while
|
||||
# the shielded inner coroutine continues running to completion so the
|
||||
# upload is not lost. This is intentional and matches the pattern
|
||||
# used for upload_transcript immediately above.
|
||||
# The companion .meta.json carries the message_count watermark and mode
|
||||
# so the next turn can restore both --resume context and gap-fill state
|
||||
# in a single GCS round-trip via download_transcript().
|
||||
# asyncio.shield: if the outer finally-block coroutine is cancelled
|
||||
# while awaiting shield, the CancelledError propagates (BaseException,
|
||||
# not caught by `except Exception`) letting the caller handle
|
||||
# cancellation, while the shielded inner coroutine continues running
|
||||
# to completion so the upload is not lost.
|
||||
#
|
||||
# NOTE: upload is attempted regardless of state.use_resume — even when
|
||||
# this turn ran without --resume (restore failed or first T2+ on a new
|
||||
# pod), the T1 session file at the expected path may still be present
|
||||
# and should be re-uploaded so the next turn can resume from it.
|
||||
# upload_cli_session silently skips when the file is absent, so this is
|
||||
# always safe.
|
||||
# read_cli_session_from_disk returns None when the file is absent, so
|
||||
# this is always safe.
|
||||
#
|
||||
# Intentionally NOT gated on skip_transcript_upload: that flag is set
|
||||
# when our custom JSONL transcript is dropped (transcript_lost=True on
|
||||
@@ -3472,14 +3666,36 @@ async def stream_chat_completion_sdk(
|
||||
skip_transcript_upload,
|
||||
)
|
||||
try:
|
||||
await asyncio.shield(
|
||||
upload_cli_session(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
# Read the CLI's native session file from disk (written by the CLI
|
||||
# after the turn), then upload the bytes to GCS.
|
||||
_cli_content = read_cli_session_from_disk(
|
||||
sdk_cwd, session_id, log_prefix
|
||||
)
|
||||
if _cli_content:
|
||||
# Watermark = number of DB messages this transcript covers.
|
||||
# len(session.messages) is accurate: the CLI session file
|
||||
# was just written after the turn completed, so it covers
|
||||
# all messages through this turn. Any gap from a prior
|
||||
# missed upload was already detected by detect_gap and
|
||||
# injected as context, so the model has the full history.
|
||||
#
|
||||
# Previously this used _final_tmsg_count + 2, which
|
||||
# under-counted for tool-use turns (delta = 2 + 2*N_tool_calls),
|
||||
# causing persistent spurious gap-fills on every subsequent turn.
|
||||
# That concern was addressed by the inflated-watermark fix
|
||||
# (using the GCS watermark as the anchor for gap detection),
|
||||
# which makes len(session.messages) safe to use here.
|
||||
_jsonl_covered = len(session.messages)
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=_cli_content,
|
||||
message_count=_jsonl_covered,
|
||||
mode="sdk",
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
)
|
||||
except Exception as cli_upload_err:
|
||||
logger.warning(
|
||||
"%s CLI session upload failed in finally: %s",
|
||||
|
||||
@@ -22,6 +22,7 @@ from .service import (
|
||||
_iter_sdk_messages,
|
||||
_normalize_model_name,
|
||||
_reduce_context,
|
||||
_restore_cli_session_for_turn,
|
||||
_TokenUsage,
|
||||
)
|
||||
|
||||
@@ -615,3 +616,340 @@ class TestSdkSessionIdSelection:
|
||||
)
|
||||
assert retry.get("resume") == self.SESSION_ID
|
||||
assert "session_id" not in retry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _restore_cli_session_for_turn — mode check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRestoreCliSessionModeCheck:
|
||||
"""SDK skips --resume when the transcript was written by the baseline mode."""
|
||||
|
||||
@pytest.mark.asyncio
async def test_baseline_mode_transcript_skips_gcs_content(self, tmp_path):
    """A transcript with mode='baseline' must not be used as the --resume source.

    The mode check disables --resume, but the downloaded baseline content is
    still used: it is turned into ``context_messages`` (transcript content
    plus any gap) rather than being discarded.
    """
    from datetime import UTC, datetime

    from backend.copilot.model import ChatMessage, ChatSession
    from backend.copilot.transcript import TranscriptDownload
    from backend.copilot.transcript_builder import TranscriptBuilder

    session = ChatSession(
        session_id="test-session",
        user_id="user-1",
        messages=[
            ChatMessage(role="user", content="hello-unique-marker"),
            ChatMessage(role="assistant", content="world-unique-marker"),
            ChatMessage(role="user", content="follow up"),
        ],
        title="test",
        usage=[],
        started_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )
    builder = TranscriptBuilder()
    # Baseline content carrying a sentinel.  It is not a --resume source, but
    # it must surface in context_messages (asserted below).
    baseline_restore = TranscriptDownload(
        content=b'{"type":"user","uuid":"bad-uuid","message":{"role":"user","content":"BASELINE_SENTINEL"}}\n',
        message_count=1,
        mode="baseline",
    )

    import backend.copilot.sdk.service as _svc_mod

    download_mock = AsyncMock(return_value=baseline_restore)
    with (
        patch(
            "backend.copilot.sdk.service.download_transcript",
            new=download_mock,
        ),
        patch.object(_svc_mod.config, "claude_agent_use_resume", True),
    ):
        result = await _restore_cli_session_for_turn(
            user_id="user-1",
            session_id="test-session",
            session=session,
            sdk_cwd=str(tmp_path),
            transcript_builder=builder,
            log_prefix="[Test]",
        )

    # download_transcript was called (attempted GCS restore)
    download_mock.assert_awaited_once()
    # use_resume must be False — baseline transcripts cannot be used with --resume
    assert result.use_resume is False
    # No resume target may be set when --resume is not used.
    assert result.resume_file is None
    # context_messages must be populated from transcript content + gap.
    assert result.context_messages is not None
    # The baseline transcript has 1 user message (BASELINE_SENTINEL).
    # Watermark=1 but position 0 is 'user', not 'assistant', so detect_gap returns [].
    # Result: 1 message from transcript, no gap.
    assert len(result.context_messages) == 1
    assert "BASELINE_SENTINEL" in (result.context_messages[0].content or "")
|
||||
|
||||
@pytest.mark.asyncio
async def test_sdk_mode_transcript_allows_resume(self, tmp_path):
    """A valid SDK-written transcript is accepted for --resume.

    Besides ``use_resume``, the resume branch also sets the resume target to
    the session id and copies the download's message_count watermark, and it
    returns before any context messages are built — all asserted here.
    """
    import json as stdlib_json
    from datetime import UTC, datetime

    from backend.copilot.model import ChatMessage, ChatSession
    from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
    from backend.copilot.transcript_builder import TranscriptBuilder

    # Minimal two-entry transcript: one user entry, one assistant reply.
    lines = [
        stdlib_json.dumps(
            {
                "type": "user",
                "uuid": "uid-0",
                "parentUuid": "",
                "message": {"role": "user", "content": "hi"},
            }
        ),
        stdlib_json.dumps(
            {
                "type": "assistant",
                "uuid": "uid-1",
                "parentUuid": "uid-0",
                "message": {
                    "role": "assistant",
                    "id": "msg_1",
                    "model": "test",
                    "type": "message",
                    "stop_reason": STOP_REASON_END_TURN,
                    "content": [{"type": "text", "text": "hello"}],
                },
            }
        ),
    ]
    content = ("\n".join(lines) + "\n").encode("utf-8")

    session = ChatSession(
        session_id="test-session",
        user_id="user-1",
        messages=[
            ChatMessage(role="user", content="hi"),
            ChatMessage(role="assistant", content="hello"),
            ChatMessage(role="user", content="follow up"),
        ],
        title="test",
        usage=[],
        started_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )
    builder = TranscriptBuilder()
    sdk_restore = TranscriptDownload(
        content=content,
        message_count=2,
        mode="sdk",
    )

    import backend.copilot.sdk.service as _svc_mod

    with (
        patch(
            "backend.copilot.sdk.service.download_transcript",
            new=AsyncMock(return_value=sdk_restore),
        ),
        patch.object(_svc_mod.config, "claude_agent_use_resume", True),
    ):
        result = await _restore_cli_session_for_turn(
            user_id="user-1",
            session_id="test-session",
            session=session,
            sdk_cwd=str(tmp_path),
            transcript_builder=builder,
            log_prefix="[Test]",
        )

    assert result.use_resume is True
    # The resume target is the session id, not a file path.
    assert result.resume_file == "test-session"
    # Watermark is taken from the downloaded transcript's message_count.
    assert result.transcript_msg_count == 2
    # The resume path returns before any fallback context is built.
    assert result.context_messages is None
|
||||
|
||||
@pytest.mark.asyncio
async def test_baseline_mode_context_messages_from_transcript_content(
    self, tmp_path
):
    """Baseline-mode transcript: context comes from the transcript, no --resume.

    With a mode='baseline' download the restore helper must leave
    ``use_resume`` False, expose the transcript's messages through
    ``context_messages`` and report non-empty ``transcript_content``.
    """
    import json as stdlib_json
    from datetime import UTC, datetime

    from backend.copilot.model import ChatMessage, ChatSession
    from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
    from backend.copilot.transcript_builder import TranscriptBuilder

    import backend.copilot.sdk.service as _svc_mod

    # Two-entry JSONL transcript: a user entry followed by an assistant reply.
    user_entry = {
        "type": "user",
        "uuid": "uid-0",
        "parentUuid": "",
        "message": {"role": "user", "content": "TRANSCRIPT_USER"},
    }
    assistant_entry = {
        "type": "assistant",
        "uuid": "uid-1",
        "parentUuid": "uid-0",
        "message": {
            "role": "assistant",
            "id": "msg_1",
            "model": "test",
            "type": "message",
            "stop_reason": STOP_REASON_END_TURN,
            "content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT"}],
        },
    }
    jsonl_bytes = "".join(
        stdlib_json.dumps(entry) + "\n" for entry in (user_entry, assistant_entry)
    ).encode("utf-8")

    now = datetime.now(UTC)
    chat_session = ChatSession(
        session_id="test-session",
        user_id="user-1",
        messages=[
            ChatMessage(role="user", content="DB_USER"),
            ChatMessage(role="assistant", content="DB_ASSISTANT"),
            ChatMessage(role="user", content="current turn"),
        ],
        title="test",
        usage=[],
        started_at=now,
        updated_at=datetime.now(UTC),
    )
    download = TranscriptDownload(
        content=jsonl_bytes,
        message_count=2,
        mode="baseline",
    )
    fake_download = AsyncMock(return_value=download)

    with (
        patch(
            "backend.copilot.sdk.service.download_transcript",
            new=fake_download,
        ),
        patch.object(_svc_mod.config, "claude_agent_use_resume", True),
    ):
        result = await _restore_cli_session_for_turn(
            user_id="user-1",
            session_id="test-session",
            session=chat_session,
            sdk_cwd=str(tmp_path),
            transcript_builder=TranscriptBuilder(),
            log_prefix="[Test]",
        )

    assert result.use_resume is False

    context = result.context_messages
    assert context is not None
    # Transcript content has 2 messages, no gap (watermark=2, session prior=2)
    assert len(context) == 2
    assert context[0].role == "user"
    assert context[1].role == "assistant"
    assert "TRANSCRIPT_ASSISTANT" in (context[1].content or "")

    # Non-empty transcript_content lets the caller's seeding guard skip DB
    # reconstruction, which would otherwise append duplicate builder entries.
    assert result.transcript_content != ""
|
||||
|
||||
@pytest.mark.asyncio
async def test_baseline_mode_gap_present_context_includes_gap(self, tmp_path):
    """Baseline transcript that lags the DB → context is transcript msgs + gap msgs.

    The downloaded transcript covers two messages (``message_count=2``) while
    the session holds four prior messages plus the current turn.  The restore
    is expected to decline ``--resume`` and return four context messages, the
    last two being the DB messages the transcript does not cover.
    """
    import json as stdlib_json
    from datetime import UTC, datetime

    from backend.copilot.model import ChatMessage, ChatSession
    from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
    from backend.copilot.transcript_builder import TranscriptBuilder

    # The stored transcript only contains the first user/assistant exchange.
    transcript_entries = (
        {
            "type": "user",
            "uuid": "uid-0",
            "parentUuid": "",
            "message": {"role": "user", "content": "TRANSCRIPT_USER_0"},
        },
        {
            "type": "assistant",
            "uuid": "uid-1",
            "parentUuid": "uid-0",
            "message": {
                "role": "assistant",
                "id": "msg_1",
                "model": "test",
                "type": "message",
                "stop_reason": STOP_REASON_END_TURN,
                "content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT_1"}],
            },
        },
    )
    # One JSON object per line, newline-terminated.
    jsonl_text = "".join(
        stdlib_json.dumps(entry) + "\n" for entry in transcript_entries
    )
    content = jsonl_text.encode("utf-8")

    # Four prior messages in the DB, followed by the turn being processed.
    db_history = (
        ("user", "DB_USER_0"),
        ("assistant", "DB_ASSISTANT_1"),
        ("user", "GAP_USER_2"),
        ("assistant", "GAP_ASSISTANT_3"),
        ("user", "current turn"),
    )
    session = ChatSession(
        session_id="test-session",
        user_id="user-1",
        messages=[
            ChatMessage(role=role, content=text) for role, text in db_history
        ],
        title="test",
        usage=[],
        started_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )
    builder = TranscriptBuilder()
    # watermark=2 while the session has 4 prior messages → gap of 2
    baseline_restore = TranscriptDownload(
        content=content,
        message_count=2,
        mode="baseline",
    )

    import backend.copilot.sdk.service as _svc_mod

    download_stub = AsyncMock(return_value=baseline_restore)
    with (
        patch("backend.copilot.sdk.service.download_transcript", new=download_stub),
        patch.object(_svc_mod.config, "claude_agent_use_resume", True),
    ):
        result = await _restore_cli_session_for_turn(
            user_id="user-1",
            session_id="test-session",
            session=session,
            sdk_cwd=str(tmp_path),
            transcript_builder=builder,
            log_prefix="[Test]",
        )

    assert result.use_resume is False
    assert result.context_messages is not None

    restored = result.context_messages
    # 2 transcript messages followed by the 2 gap messages
    assert len(restored) == 4
    assert [msg.role for msg in restored] == [
        "user",
        "assistant",
        "user",
        "assistant",
    ]

    # The trailing pair is taken from the DB messages.
    gap_user, gap_asst = restored[2], restored[3]
    assert gap_user.content == "GAP_USER_2"
    assert gap_asst.content == "GAP_ASSISTANT_3"
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Unit tests for the watermark-fix logic in stream_chat_completion_sdk.
|
||||
|
||||
The fix is at the upload step: when use_resume=True and transcript_msg_count>0
|
||||
we set the JSONL coverage watermark to transcript_msg_count + 2 (the pair just
|
||||
recorded) instead of len(session.messages). This prevents the "inflated
|
||||
watermark" bug where a stale JSONL in GCS could hide missing context from
|
||||
future gap-fill checks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _compute_jsonl_covered(
    use_resume: bool,
    transcript_msg_count: int,
    session_msg_count: int,
) -> int:
    """Return the JSONL coverage watermark for a finished turn.

    Stand-alone copy of the watermark rule from ``stream_chat_completion_sdk``
    so the rule can be unit-tested without the streaming stack.

    Args:
        use_resume: Whether the turn ran with ``--resume``.
        transcript_msg_count: Messages covered by the restored transcript;
            ``0`` when unknown.
        session_msg_count: Number of messages in the session.

    Returns:
        ``transcript_msg_count + 2`` (the restored coverage plus the pair just
        recorded) for a resumed turn with a known count; otherwise
        ``session_msg_count``.
    """
    resumed_with_known_count = use_resume and transcript_msg_count > 0
    if not resumed_with_known_count:
        return session_msg_count
    return transcript_msg_count + 2
|
||||
|
||||
|
||||
class TestWatermarkFix:
    """Coverage-watermark rule, mirrored from the SDK service's finally-block."""

    def test_inflated_watermark_triggers_gap_fill(self):
        """A stale 12-message JSONL must not be stamped with the DB count (47).

        Old behaviour: watermark=46, so the next turn's gap check
        (transcript_msg_count < db-1) never fired because 46 >= 47-1, and the
        lost context went unnoticed.
        New behaviour: watermark = 12 + 2 = 14, so the gap check fires
        (14 < 46) and the missing turns are supplied to the model.
        """
        db_total = 47  # the value the old code would have recorded

        covered = _compute_jsonl_covered(
            use_resume=True,
            transcript_msg_count=12,
            session_msg_count=db_total,
        )

        # Restored coverage plus the pair just recorded — not the DB total.
        assert covered == 14
        # Next-turn gap condition: watermark < msg_count - 1 → 14 < 46.
        assert covered < db_total - 1

    def test_no_false_positive_when_transcript_current(self):
        """An up-to-date transcript (46 covered, DB=47) produces no gap.

        The recorded watermark is 46 + 2 = 48; on the next turn the session
        has 48 messages and 48 >= 48-1, so the gap check stays quiet.
        """
        covered = _compute_jsonl_covered(
            use_resume=True,
            transcript_msg_count=46,
            session_msg_count=47,
        )

        assert covered == 48

        msgs_on_next_turn = 48
        assert covered >= msgs_on_next_turn - 1

    def test_fresh_session_falls_back_to_db_count(self):
        """Without resume the watermark is the session message count."""
        db_total = 3

        covered = _compute_jsonl_covered(
            use_resume=False,
            transcript_msg_count=0,
            session_msg_count=db_total,
        )

        assert covered == db_total

    def test_old_format_meta_zero_count_falls_back_to_db(self):
        """A zero transcript count (old-format meta / unset) uses the DB count."""
        db_total = 10

        covered = _compute_jsonl_covered(
            use_resume=True,
            transcript_msg_count=0,
            session_msg_count=db_total,
        )

        assert covered == db_total
|
||||
@@ -12,18 +12,20 @@ from backend.copilot.transcript import (
|
||||
ENTRY_TYPE_MESSAGE,
|
||||
STOP_REASON_END_TURN,
|
||||
STRIPPABLE_TYPES,
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
TranscriptDownload,
|
||||
TranscriptMode,
|
||||
cleanup_stale_project_dirs,
|
||||
cli_session_path,
|
||||
compact_transcript,
|
||||
delete_transcript,
|
||||
detect_gap,
|
||||
download_transcript,
|
||||
extract_context_messages,
|
||||
projects_base,
|
||||
read_compacted_entries,
|
||||
restore_cli_session,
|
||||
strip_for_upload,
|
||||
strip_progress_entries,
|
||||
strip_stale_thinking_blocks,
|
||||
upload_cli_session,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -34,18 +36,20 @@ __all__ = [
|
||||
"ENTRY_TYPE_MESSAGE",
|
||||
"STOP_REASON_END_TURN",
|
||||
"STRIPPABLE_TYPES",
|
||||
"TRANSCRIPT_STORAGE_PREFIX",
|
||||
"TranscriptDownload",
|
||||
"TranscriptMode",
|
||||
"cleanup_stale_project_dirs",
|
||||
"cli_session_path",
|
||||
"compact_transcript",
|
||||
"delete_transcript",
|
||||
"detect_gap",
|
||||
"download_transcript",
|
||||
"extract_context_messages",
|
||||
"projects_base",
|
||||
"read_compacted_entries",
|
||||
"restore_cli_session",
|
||||
"strip_for_upload",
|
||||
"strip_progress_entries",
|
||||
"strip_stale_thinking_blocks",
|
||||
"upload_cli_session",
|
||||
"upload_transcript",
|
||||
"validate_transcript",
|
||||
"write_transcript_to_tempfile",
|
||||
|
||||
@@ -297,8 +297,8 @@ class TestStripProgressEntries:
|
||||
|
||||
class TestDeleteTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deletes_both_jsonl_and_meta(self):
|
||||
"""delete_transcript removes both the .jsonl and .meta.json files."""
|
||||
async def test_deletes_cli_session_and_meta(self):
|
||||
"""delete_transcript removes the CLI session .jsonl and .meta.json."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
@@ -309,7 +309,7 @@ class TestDeleteTranscript:
|
||||
):
|
||||
await delete_transcript("user-123", "session-456")
|
||||
|
||||
assert mock_storage.delete.call_count == 3
|
||||
assert mock_storage.delete.call_count == 2
|
||||
paths = [call.args[0] for call in mock_storage.delete.call_args_list]
|
||||
assert any(p.endswith(".jsonl") for p in paths)
|
||||
assert any(p.endswith(".meta.json") for p in paths)
|
||||
@@ -319,7 +319,7 @@ class TestDeleteTranscript:
|
||||
"""If .jsonl delete fails, .meta.json delete is still attempted."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock(
|
||||
side_effect=[Exception("jsonl delete failed"), None, None]
|
||||
side_effect=[Exception("jsonl delete failed"), None]
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -330,14 +330,14 @@ class TestDeleteTranscript:
|
||||
# Should not raise
|
||||
await delete_transcript("user-123", "session-456")
|
||||
|
||||
assert mock_storage.delete.call_count == 3
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_meta_delete_failure(self):
|
||||
"""If .meta.json delete fails, no exception propagates."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock(
|
||||
side_effect=[None, Exception("meta delete failed"), None]
|
||||
side_effect=[None, Exception("meta delete failed")]
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1368,3 +1368,172 @@ class TestStripStaleThinkingBlocks:
|
||||
# Both entries of last turn (msg_last) preserved
|
||||
assert lines[1]["message"]["content"][0]["type"] == "thinking"
|
||||
assert lines[2]["message"]["content"][0]["type"] == "text"
|
||||
|
||||
|
||||
class TestProcessCliRestore:
    """Tests for ``process_cli_restore``: validation, stripping, and disk write."""

    def test_writes_stripped_bytes_not_raw(self, tmp_path):
        """The file written for --resume holds the stripped transcript, not the raw one."""
        import os
        import re
        from pathlib import Path
        from unittest.mock import patch

        from backend.copilot.sdk.service import process_cli_restore
        from backend.copilot.transcript import TranscriptDownload

        session_id = "12345678-0000-0000-0000-abcdef000001"
        # tmp_path serves as both the SDK cwd and the patched projects base.
        sdk_cwd = projects_base_dir = str(tmp_path)

        # A strippable progress entry followed by a valid user/assistant pair.
        progress_line = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
        user_line = '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
        assistant_line = '{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
        raw_bytes = (progress_line + user_line + assistant_line).encode("utf-8")
        restore = TranscriptDownload(content=raw_bytes, message_count=2, mode="sdk")

        service_patch = patch(
            "backend.copilot.sdk.service.projects_base",
            return_value=projects_base_dir,
        )
        transcript_patch = patch(
            "backend.copilot.transcript.projects_base",
            return_value=projects_base_dir,
        )
        with service_patch, transcript_patch:
            stripped_str, ok = process_cli_restore(
                restore, sdk_cwd, session_id, "[Test]"
            )

        assert ok, "Expected successful restore"

        # Rebuild the expected session path: resolved cwd with every
        # non-alphanumeric character replaced by "-".
        encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
        session_file = Path(projects_base_dir, encoded_cwd, f"{session_id}.jsonl")
        assert session_file.exists(), "Session file should have been written"

        written_bytes = session_file.read_bytes()
        # The progress entry must be gone from what was written.
        assert (
            b"progress" not in written_bytes
        ), "Raw bytes with progress entry should not have been written"
        assert (
            b"hello" in written_bytes
        ), "Stripped content should still contain assistant turn"

        # What is on disk is exactly the returned stripped string, re-encoded.
        expected_bytes = stripped_str.encode("utf-8")
        assert written_bytes == expected_bytes, "Written bytes must equal stripped content"

    def test_invalid_content_returns_false(self):
        """Content that is invalid after stripping yields ``("", False)``."""
        from backend.copilot.sdk.service import process_cli_restore
        from backend.copilot.transcript import TranscriptDownload

        # Only a progress entry — nothing valid is left once it is stripped.
        progress_only = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
        restore = TranscriptDownload(
            content=progress_only.encode("utf-8"),
            message_count=1,
            mode="sdk",
        )

        stripped_str, ok = process_cli_restore(
            restore,
            "/tmp/nonexistent-sdk-cwd",
            "12345678-0000-0000-0000-000000000099",
            "[Test]",
        )

        assert not ok
        assert stripped_str == ""
|
||||
|
||||
|
||||
class TestReadCliSessionFromDisk:
    """Tests for ``read_cli_session_from_disk``: read, strip, optional write-back."""

    def _build_session_file(self, tmp_path, session_id: str):
        """Create the session directory under ``tmp_path`` and return its paths.

        The directory name is the resolved cwd with every non-alphanumeric
        character replaced by "-", matching the encoding used by
        ``cli_session_path``.

        Returns:
            ``(sdk_cwd, session_file)`` — the cwd as a string and the
            ``<session_id>.jsonl`` path (the file itself is not created).
        """
        import os
        import re
        from pathlib import Path

        sdk_cwd = str(tmp_path)
        resolved_cwd = os.path.realpath(sdk_cwd)
        encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", resolved_cwd)

        session_dir = Path(str(tmp_path)) / encoded_cwd
        session_dir.mkdir(parents=True, exist_ok=True)

        session_file = session_dir / f"{session_id}.jsonl"
        return sdk_cwd, session_file

    def test_returns_raw_bytes_for_invalid_utf8(self, tmp_path):
        """Undecodable bytes are returned unchanged (upload-raw fallback)."""
        from unittest.mock import patch

        from backend.copilot.sdk.service import read_cli_session_from_disk

        session_id = "12345678-0000-0000-0000-aabbccdd0001"
        projects_base_dir = str(tmp_path)
        sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)

        # Bytes that cannot be decoded as UTF-8.
        undecodable = b"\xff\xfe invalid utf-8\n"
        session_file.write_bytes(undecodable)

        service_patch = patch(
            "backend.copilot.sdk.service.projects_base",
            return_value=projects_base_dir,
        )
        transcript_patch = patch(
            "backend.copilot.transcript.projects_base",
            return_value=projects_base_dir,
        )
        with service_patch, transcript_patch:
            result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")

        # The decode-failure path hands back the file content untouched.
        assert result == b"\xff\xfe invalid utf-8\n"

    def test_write_back_oserror_still_returns_stripped_bytes(self, tmp_path):
        """A failed write-back still returns stripped (not raw) bytes for upload."""
        from unittest.mock import patch

        from backend.copilot.sdk.service import read_cli_session_from_disk

        session_id = "12345678-0000-0000-0000-aabbccdd0002"
        projects_base_dir = str(tmp_path)
        sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)

        # Includes a strippable progress entry, so stripping shrinks the content.
        progress_line = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
        user_line = '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
        assistant_line = '{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
        session_file.write_bytes(
            (progress_line + user_line + assistant_line).encode("utf-8")
        )
        # Read-only file: the write-back is intended to fail with OSError.
        session_file.chmod(0o444)

        service_patch = patch(
            "backend.copilot.sdk.service.projects_base",
            return_value=projects_base_dir,
        )
        transcript_patch = patch(
            "backend.copilot.transcript.projects_base",
            return_value=projects_base_dir,
        )
        try:
            with service_patch, transcript_patch:
                result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
        finally:
            # Restore write permission so the temp dir can be cleaned up.
            session_file.chmod(0o644)

        # The caller must get the cleaned content, never None or the raw bytes.
        assert result is not None
        assert (
            b"progress" not in result
        ), "Stripped bytes must not contain progress entry"
        assert b"hello" in result, "Stripped bytes should contain assistant turn"
|
||||
|
||||
@@ -61,18 +61,23 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
# (CLI version, platform). When that happens, multi-turn still works
|
||||
# via conversation compression (non-resume path), but we can't test
|
||||
# the --resume round-trip.
|
||||
transcript = None
|
||||
cli_session = None
|
||||
for _ in range(10):
|
||||
await asyncio.sleep(0.5)
|
||||
transcript = await download_transcript(test_user_id, session.session_id)
|
||||
if transcript:
|
||||
cli_session = await download_transcript(test_user_id, session.session_id)
|
||||
# Wait until both the session bytes AND the message_count watermark are
|
||||
# present — a session with message_count=0 means the .meta.json hasn't
|
||||
# been uploaded yet, so --resume on the next turn would skip gap-fill.
|
||||
if cli_session and cli_session.message_count > 0:
|
||||
break
|
||||
if not transcript:
|
||||
if not cli_session:
|
||||
return pytest.skip(
|
||||
"CLI did not produce a usable transcript — "
|
||||
"cannot test --resume round-trip in this environment"
|
||||
)
|
||||
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
|
||||
logger.info(
|
||||
f"Turn 1 CLI session uploaded: {len(cli_session.content)} bytes, msg_count={cli_session.message_count}"
|
||||
)
|
||||
|
||||
# Reload session for turn 2
|
||||
session = await get_chat_session(session.session_id, test_user_id)
|
||||
|
||||
@@ -423,20 +423,33 @@ async def subscribe_to_session(
|
||||
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||
)
|
||||
|
||||
# RACE CONDITION FIX: If session not found, retry once after small delay
|
||||
# This handles the case where subscribe_to_session is called immediately
|
||||
# after create_session but before Redis propagates the write
|
||||
# RACE CONDITION FIX: If session not found, retry with backoff.
|
||||
# Duplicate requests skip create_session and subscribe immediately; the
|
||||
# original request's create_session (a Redis hset) may not have completed
|
||||
# yet. 3 × 100ms gives a 300ms window which covers DB-write latency on the
|
||||
# original request before the hset even starts.
|
||||
if not meta:
|
||||
logger.warning(
|
||||
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
await asyncio.sleep(0.05) # 50ms
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
if not meta:
|
||||
_max_retries = 3
|
||||
_retry_delay = 0.1 # 100ms per attempt
|
||||
for attempt in range(_max_retries):
|
||||
logger.warning(
|
||||
f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), "
|
||||
f"retrying after {int(_retry_delay * 1000)}ms",
|
||||
extra={"json_fields": {**log_meta, "attempt": attempt + 1}},
|
||||
)
|
||||
await asyncio.sleep(_retry_delay)
|
||||
meta = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
if meta:
|
||||
logger.info(
|
||||
f"[TIMING] Session found after {attempt + 1} retries",
|
||||
extra={"json_fields": {**log_meta, "attempts": attempt + 1}},
|
||||
)
|
||||
break
|
||||
else:
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
|
||||
f"[TIMING] Session still not found in Redis after {_max_retries} retries "
|
||||
f"({elapsed:.1f}ms total)",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -446,10 +459,6 @@ async def subscribe_to_session(
|
||||
},
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"[TIMING] Session found after retry",
|
||||
extra={"json_fields": {**log_meta}},
|
||||
)
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
session_status = meta.get("status", "")
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""JSONL transcript management for stateless multi-turn resume.
|
||||
|
||||
The Claude Code CLI persists conversations as JSONL files (one JSON object per
|
||||
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
|
||||
(progress entries, metadata), and upload the result to bucket storage. On the
|
||||
next turn we download the transcript, write it to a temp file, and pass
|
||||
``--resume`` so the CLI can reconstruct the full conversation.
|
||||
line). When the SDK's ``Stop`` hook fires the caller reads this file, strips
|
||||
bloat (progress entries, metadata), and uploads the result to bucket storage.
|
||||
On the next turn the caller downloads the bytes and writes them to disk before
|
||||
passing ``--resume`` so the CLI can reconstruct the full conversation.
|
||||
|
||||
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
@@ -20,6 +20,7 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
@@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
@@ -44,17 +48,17 @@ STRIPPABLE_TYPES = frozenset(
|
||||
)
|
||||
|
||||
|
||||
TranscriptMode = Literal["sdk", "baseline"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptDownload:
|
||||
"""Result of downloading a transcript with its metadata."""
|
||||
|
||||
content: str
|
||||
message_count: int = 0 # session.messages length when uploaded
|
||||
uploaded_at: float = 0.0 # epoch timestamp of upload
|
||||
content: bytes | str
|
||||
message_count: int = 0
|
||||
# "sdk" = Claude CLI native, "baseline" = TranscriptBuilder
|
||||
mode: TranscriptMode = "sdk"
|
||||
|
||||
|
||||
# Workspace storage constants — deterministic path from session_id.
|
||||
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
|
||||
# Storage prefix for the CLI's native session JSONL files (for cross-pod --resume).
|
||||
_CLI_SESSION_STORAGE_PREFIX = "cli-sessions"
|
||||
|
||||
@@ -363,7 +367,7 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _projects_base() -> str:
|
||||
def projects_base() -> str:
    """Return the resolved path to the CLI's ``projects`` directory.

    The config root is ``$CLAUDE_CONFIG_DIR`` when that variable is set to a
    non-empty value, otherwise ``~/.claude``.  The result is passed through
    ``os.path.realpath``.
    """
    config_root = os.environ.get("CLAUDE_CONFIG_DIR")
    if not config_root:
        # Unset or empty → fall back to the default config location.
        config_root = os.path.expanduser("~/.claude")
    projects_dir = os.path.join(config_root, "projects")
    return os.path.realpath(projects_dir)
|
||||
@@ -390,8 +394,8 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
|
||||
Returns the number of directories removed.
|
||||
"""
|
||||
projects_base = _projects_base()
|
||||
if not os.path.isdir(projects_base):
|
||||
_pbase = projects_base()
|
||||
if not os.path.isdir(_pbase):
|
||||
return 0
|
||||
|
||||
now = time.time()
|
||||
@@ -399,7 +403,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
|
||||
# Scoped mode: only clean up the one directory for the current session.
|
||||
if encoded_cwd:
|
||||
target = Path(projects_base) / encoded_cwd
|
||||
target = Path(_pbase) / encoded_cwd
|
||||
if not target.is_dir():
|
||||
return 0
|
||||
# Guard: only sweep copilot-generated dirs.
|
||||
@@ -437,7 +441,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
# Only safe for single-tenant deployments; callers should prefer the
|
||||
# scoped variant by passing encoded_cwd.
|
||||
try:
|
||||
entries = Path(projects_base).iterdir()
|
||||
entries = Path(_pbase).iterdir()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to list projects dir: %s", e)
|
||||
return 0
|
||||
@@ -490,9 +494,9 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
if not transcript_path:
|
||||
return None
|
||||
|
||||
projects_base = _projects_base()
|
||||
_pbase = projects_base()
|
||||
real_path = os.path.realpath(transcript_path)
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
if not real_path.startswith(_pbase + os.sep):
|
||||
logger.warning(
|
||||
"[Transcript] transcript_path outside projects base: %s", transcript_path
|
||||
)
|
||||
@@ -611,28 +615,6 @@ def validate_transcript(content: str | None) -> bool:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
    """Return ``(workspace_id, file_id, filename)`` for a session's transcript.

    Resulting layout: ``chat-transcripts/{user_id}/{session_id}.jsonl``.
    Both IDs go through ``_sanitize_id`` before being placed in the path.
    """
    workspace_id = TRANSCRIPT_STORAGE_PREFIX
    file_id = _sanitize_id(user_id)
    filename = f"{_sanitize_id(session_id)}.jsonl"
    return workspace_id, file_id, filename
|
||||
|
||||
|
||||
def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
    """Return ``(workspace_id, file_id, filename)`` for a transcript's metadata.

    Same layout as the transcript itself, with a ``.meta.json`` filename.
    Both IDs go through ``_sanitize_id`` before being placed in the path.
    """
    workspace_id = TRANSCRIPT_STORAGE_PREFIX
    file_id = _sanitize_id(user_id)
    filename = f"{_sanitize_id(session_id)}.meta.json"
    return workspace_id, file_id, filename
|
||||
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
wid, fid, fname = parts
|
||||
@@ -642,24 +624,12 @@ def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
return f"local://{wid}/{fid}/{fname}"
|
||||
|
||||
|
||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
    """Build the full storage path string that ``retrieve()`` expects.

    Combines the transcript path parts for the session with the
    backend-specific formatting done by ``_build_path_from_parts``.
    """
    transcript_parts = _storage_path_parts(user_id, session_id)
    return _build_path_from_parts(transcript_parts, backend)
|
||||
|
||||
|
||||
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
    """Build the full storage path for the companion ``.meta.json`` file.

    Combines the metadata path parts for the session with the
    backend-specific formatting done by ``_build_path_from_parts``.
    """
    meta_parts = _meta_storage_path_parts(user_id, session_id)
    return _build_path_from_parts(meta_parts, backend)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI native session file — cross-pod --resume support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
|
||||
def cli_session_path(sdk_cwd: str, session_id: str) -> str:
|
||||
"""Expected path of the CLI's native session JSONL file.
|
||||
|
||||
The CLI resolves the working directory via ``os.path.realpath``, then
|
||||
@@ -675,7 +645,7 @@ def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
|
||||
"""
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
safe_id = _sanitize_id(session_id)
|
||||
return os.path.join(_projects_base(), encoded_cwd, f"{safe_id}.jsonl")
|
||||
return os.path.join(projects_base(), encoded_cwd, f"{safe_id}.jsonl")
|
||||
|
||||
|
||||
def _cli_session_storage_path_parts(
|
||||
@@ -689,235 +659,82 @@ def _cli_session_storage_path_parts(
|
||||
)
|
||||
|
||||
|
||||
async def upload_cli_session(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
sdk_cwd: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> None:
|
||||
"""Upload the CLI's native session JSONL file to remote storage.
|
||||
|
||||
Called after each turn so the next turn can restore the file on any pod
|
||||
(eliminating the pod-affinity requirement for --resume).
|
||||
|
||||
The CLI only writes the session file after the turn completes, so this
|
||||
must run in the finally block, AFTER the SDK stream has finished.
|
||||
"""
|
||||
session_file = _cli_session_path(sdk_cwd, session_id)
|
||||
real_path = os.path.realpath(session_file)
|
||||
projects_base = _projects_base()
|
||||
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"%s CLI session file outside projects base, skipping upload: %s",
|
||||
log_prefix,
|
||||
os.path.basename(real_path),
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
raw_bytes = Path(real_path).read_bytes()
|
||||
except FileNotFoundError:
|
||||
logger.debug(
|
||||
"%s CLI session file not found, skipping upload: %s",
|
||||
log_prefix,
|
||||
session_file,
|
||||
)
|
||||
return
|
||||
except OSError as e:
|
||||
logger.warning("%s Failed to read CLI session file: %s", log_prefix, e)
|
||||
return
|
||||
|
||||
# Strip stale thinking blocks and metadata entries (progress, file-history-snapshot,
|
||||
# queue-operation) from the CLI session before writing it back locally and uploading
|
||||
# to GCS. Thinking blocks from non-last assistant turns are not needed for --resume
|
||||
# but can be massive (tens of thousands of tokens each), causing the CLI to auto-compact
|
||||
# its session when the context window fills up. Stripping keeps the session well below
|
||||
# the ~200K-token compaction threshold and prevents silent context loss.
|
||||
try:
|
||||
raw_text = raw_bytes.decode("utf-8")
|
||||
stripped_text = strip_for_upload(raw_text)
|
||||
stripped_bytes = stripped_text.encode("utf-8")
|
||||
if len(stripped_bytes) < len(raw_bytes):
|
||||
# Write the stripped version back locally so same-pod turns also benefit.
|
||||
Path(real_path).write_bytes(stripped_bytes)
|
||||
logger.info(
|
||||
"%s Stripped CLI session file: %dB → %dB",
|
||||
log_prefix,
|
||||
len(raw_bytes),
|
||||
len(stripped_bytes),
|
||||
)
|
||||
content = stripped_bytes
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"%s Failed to strip CLI session file, uploading raw: %s", log_prefix, e
|
||||
)
|
||||
content = raw_bytes
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
|
||||
try:
|
||||
await storage.store(
|
||||
workspace_id=wid, file_id=fid, filename=fname, content=content
|
||||
)
|
||||
logger.info(
|
||||
"%s Uploaded CLI session file (%dB) for cross-pod --resume",
|
||||
log_prefix,
|
||||
len(content),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to upload CLI session file: %s", log_prefix, e)
|
||||
|
||||
|
||||
async def restore_cli_session(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
sdk_cwd: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> bool:
|
||||
"""Download and restore the CLI's native session file for --resume.
|
||||
|
||||
Returns True if the file was successfully restored and --resume can be
|
||||
used with the session UUID. Returns False if not available (first turn
|
||||
or upload failed), in which case the caller should not set --resume.
|
||||
"""
|
||||
session_file = _cli_session_path(sdk_cwd, session_id)
|
||||
real_path = os.path.realpath(session_file)
|
||||
projects_base = _projects_base()
|
||||
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"%s CLI session restore path outside projects base: %s",
|
||||
log_prefix,
|
||||
os.path.basename(session_file),
|
||||
)
|
||||
return False
|
||||
|
||||
# If the session file already exists locally (same-pod reuse), use it directly.
|
||||
# Downloading from storage could overwrite a newer local version when a previous
|
||||
# turn's upload failed: stored content is stale while the local file already
|
||||
# contains extended history from that turn.
|
||||
if Path(real_path).exists():
|
||||
logger.debug(
|
||||
"%s CLI session file already exists locally — using it for --resume",
|
||||
log_prefix,
|
||||
)
|
||||
return True
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for the CLI session meta file."""
|
||||
return (
|
||||
_CLI_SESSION_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.meta.json",
|
||||
)
|
||||
|
||||
try:
|
||||
content = await storage.retrieve(path)
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to download CLI session: %s", log_prefix, e)
|
||||
return False
|
||||
|
||||
try:
|
||||
os.makedirs(os.path.dirname(real_path), exist_ok=True)
|
||||
Path(real_path).write_bytes(content)
|
||||
logger.info(
|
||||
"%s Restored CLI session file (%dB) for --resume",
|
||||
log_prefix,
|
||||
len(content),
|
||||
)
|
||||
return True
|
||||
except OSError as e:
|
||||
logger.warning("%s Failed to write CLI session file: %s", log_prefix, e)
|
||||
return False
|
||||
|
||||
|
||||
async def upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
content: str,
|
||||
content: bytes,
|
||||
message_count: int = 0,
|
||||
mode: TranscriptMode = "sdk",
|
||||
log_prefix: str = "[Transcript]",
|
||||
skip_strip: bool = False,
|
||||
) -> None:
|
||||
"""Strip progress entries and stale thinking blocks, then upload transcript.
|
||||
"""Upload CLI session content to GCS with companion meta.json.
|
||||
|
||||
The transcript represents the FULL active context (atomic).
|
||||
Each upload REPLACES the previous transcript entirely.
|
||||
Pure GCS operation — no disk I/O. The caller is responsible for reading
|
||||
the session file from disk before calling this function.
|
||||
|
||||
The executor holds a cluster lock per session, so concurrent uploads for
|
||||
the same session cannot happen.
|
||||
Also uploads a companion .meta.json with the message_count watermark so
|
||||
download_transcript can return it without a separate fetch.
|
||||
|
||||
Args:
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
skip_strip: When ``True``, skip the strip + re-validate pass.
|
||||
Safe for builder-generated content (baseline path) which
|
||||
never emits progress entries or stale thinking blocks.
|
||||
Called after each turn so the next turn can restore the file on any pod
|
||||
(eliminating the pod-affinity requirement for --resume).
|
||||
"""
|
||||
if skip_strip:
|
||||
# Caller guarantees the content is already clean and valid.
|
||||
stripped = content
|
||||
else:
|
||||
# Strip metadata entries and stale thinking blocks in a single parse.
|
||||
# SDK-built transcripts may have progress entries; strip for safety.
|
||||
stripped = strip_for_upload(content)
|
||||
if not skip_strip and not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
for line in stripped.strip().split("\n")
|
||||
]
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
log_prefix,
|
||||
entry_types,
|
||||
len(stripped),
|
||||
len(content),
|
||||
)
|
||||
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
|
||||
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
|
||||
return
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
encoded = stripped.encode("utf-8")
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
|
||||
mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id)
|
||||
meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()}
|
||||
meta_encoded = json.dumps(meta).encode("utf-8")
|
||||
|
||||
# Transcript + metadata are independent objects at different keys, so
|
||||
# write them concurrently. ``return_exceptions`` keeps a metadata
|
||||
# failure from sinking the transcript write.
|
||||
transcript_result, metadata_result = await asyncio.gather(
|
||||
storage.store(
|
||||
workspace_id=wid,
|
||||
file_id=fid,
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
),
|
||||
storage.store(
|
||||
workspace_id=mwid,
|
||||
file_id=mfid,
|
||||
filename=mfname,
|
||||
content=meta_encoded,
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(transcript_result, BaseException):
|
||||
raise transcript_result
|
||||
if isinstance(metadata_result, BaseException):
|
||||
# Metadata is best-effort — the gap-fill logic in
|
||||
# _build_query_message tolerates a missing metadata file.
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result)
|
||||
# Write JSONL first, meta second — sequential so a crash between the two
|
||||
# leaves an orphaned JSONL (no meta) rather than an orphaned meta (wrong
|
||||
# watermark / mode paired with stale or absent content).
|
||||
# On any failure we roll back the other file so the pair is always absent
|
||||
# together; download_transcript returns None when either file is missing.
|
||||
try:
|
||||
await storage.store(
|
||||
workspace_id=wid, file_id=fid, filename=fname, content=content
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.warning(
|
||||
"%s Failed to upload CLI session file: %s", log_prefix, session_err
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await storage.store(
|
||||
workspace_id=mwid, file_id=mfid, filename=mfname, content=meta_encoded
|
||||
)
|
||||
except Exception as meta_err:
|
||||
logger.warning("%s Failed to upload CLI session meta: %s", log_prefix, meta_err)
|
||||
# Roll back the JSONL so neither file exists — avoids orphaned JSONL being
|
||||
# used with wrong mode/watermark defaults on the next restore.
|
||||
try:
|
||||
session_path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
)
|
||||
await storage.delete(session_path)
|
||||
except Exception as rollback_err:
|
||||
logger.debug(
|
||||
"%s Session rollback failed (harmless — download will return None): %s",
|
||||
log_prefix,
|
||||
rollback_err,
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
"%s Uploaded CLI session (%dB, msg_count=%d, mode=%s)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -926,83 +743,173 @@ async def download_transcript(
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download transcript and metadata from bucket storage.
|
||||
"""Download CLI session from GCS. Returns content + message_count + mode, or None if not found.
|
||||
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
Pure GCS operation — no disk I/O. The caller is responsible for writing
|
||||
content to disk if --resume is needed.
|
||||
|
||||
The content and metadata fetches run concurrently since they are
|
||||
independent objects in the bucket.
|
||||
Returns a TranscriptDownload with the raw content, message_count watermark,
|
||||
and mode on success, or None if not available (first turn or upload failed).
|
||||
"""
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
)
|
||||
meta_path = _build_path_from_parts(
|
||||
_cli_session_meta_path_parts(user_id, session_id), storage
|
||||
)
|
||||
|
||||
content_task = asyncio.create_task(storage.retrieve(path))
|
||||
meta_task = asyncio.create_task(storage.retrieve(meta_path))
|
||||
content_result, meta_result = await asyncio.gather(
|
||||
content_task, meta_task, return_exceptions=True
|
||||
storage.retrieve(path),
|
||||
storage.retrieve(meta_path),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
if isinstance(content_result, FileNotFoundError):
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
|
||||
return None
|
||||
if isinstance(content_result, BaseException):
|
||||
logger.warning(
|
||||
"%s Failed to download transcript: %s", log_prefix, content_result
|
||||
"%s Failed to download CLI session: %s", log_prefix, content_result
|
||||
)
|
||||
return None
|
||||
|
||||
content = content_result.decode("utf-8")
|
||||
content: bytes = content_result
|
||||
|
||||
# Metadata is best-effort — old transcripts won't have it.
|
||||
# Parse message_count and mode from companion meta — best-effort, defaults.
|
||||
message_count = 0
|
||||
uploaded_at = 0.0
|
||||
mode: TranscriptMode = "sdk"
|
||||
if isinstance(meta_result, FileNotFoundError):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
pass # No meta — old upload; default to "sdk"
|
||||
elif isinstance(meta_result, BaseException):
|
||||
logger.debug(
|
||||
"%s Failed to load transcript metadata: %s", log_prefix, meta_result
|
||||
)
|
||||
logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result)
|
||||
else:
|
||||
meta = json.loads(meta_result.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
try:
|
||||
meta_str = meta_result.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
logger.debug("%s CLI session meta is not valid UTF-8, ignoring", log_prefix)
|
||||
meta_str = None
|
||||
if meta_str is not None:
|
||||
meta = json.loads(meta_str, fallback={})
|
||||
if isinstance(meta, dict):
|
||||
raw_count = meta.get("message_count", 0)
|
||||
message_count = (
|
||||
raw_count if isinstance(raw_count, int) and raw_count >= 0 else 0
|
||||
)
|
||||
raw_mode = meta.get("mode", "sdk")
|
||||
mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk"
|
||||
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
uploaded_at=uploaded_at,
|
||||
"%s Downloaded CLI session (%dB, msg_count=%d, mode=%s)",
|
||||
log_prefix,
|
||||
len(content),
|
||||
message_count,
|
||||
mode,
|
||||
)
|
||||
return TranscriptDownload(content=content, message_count=message_count, mode=mode)
|
||||
|
||||
|
||||
def detect_gap(
|
||||
download: TranscriptDownload,
|
||||
session_messages: list[ChatMessage],
|
||||
) -> list[ChatMessage]:
|
||||
"""Return chat-db messages after the transcript watermark (excluding current user turn).
|
||||
|
||||
Returns [] if transcript is current, watermark is zero, or the watermark
|
||||
position doesn't end on an assistant turn (misaligned watermark).
|
||||
"""
|
||||
if download.message_count == 0:
|
||||
return []
|
||||
wm = download.message_count
|
||||
total = len(session_messages)
|
||||
if wm >= total - 1:
|
||||
return []
|
||||
# Sanity: position wm-1 should be an assistant turn; misaligned watermark
|
||||
# means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context.
|
||||
# In normal operation ``message_count`` is always written after a complete
|
||||
# user→assistant exchange (never mid-turn), so the last covered position is
|
||||
# always assistant. This guard fires only on data corruption or message deletion.
|
||||
if session_messages[wm - 1].role != "assistant":
|
||||
return []
|
||||
return list(session_messages[wm : total - 1])
|
||||
|
||||
|
||||
def extract_context_messages(
|
||||
download: TranscriptDownload | None,
|
||||
session_messages: "list[ChatMessage]",
|
||||
) -> "list[ChatMessage]":
|
||||
"""Return context messages for the current turn: transcript content + gap.
|
||||
|
||||
This is the shared context primitive used by both the SDK path
|
||||
(``use_resume=False`` → ``<conversation_history>`` injection) and the
|
||||
baseline path (OpenAI messages array).
|
||||
|
||||
How it works:
|
||||
|
||||
- When a transcript exists, ``TranscriptBuilder.load_previous`` preserves
|
||||
``isCompactSummary=True`` compaction entries, so the returned messages
|
||||
mirror the compacted context the CLI would see via ``--resume``.
|
||||
- The gap (DB messages after the transcript watermark) is always small in
|
||||
normal operation; it only grows during mode switches or when an upload
|
||||
was missed.
|
||||
- Falls back to full DB messages when no transcript exists (first turn,
|
||||
upload failure, or GCS unavailable).
|
||||
- Returns *prior* messages only (excluding the current user turn at
|
||||
``session_messages[-1]``). Callers that need the current turn append
|
||||
``session_messages[-1]`` themselves.
|
||||
- **Tool calls from transcript entries are flattened to text**: assistant
|
||||
messages derived from the JSONL use ``_flatten_assistant_content``, which
|
||||
serialises ``tool_use`` blocks as human-readable text rather than
|
||||
structured ``tool_calls``. Gap messages (from DB) preserve their
|
||||
original ``tool_calls`` field. This is the same trade-off as the old
|
||||
``_compress_session_messages(session.messages)`` approach — no regression.
|
||||
|
||||
Args:
|
||||
download: The ``TranscriptDownload`` from GCS, or ``None`` when no
|
||||
transcript is available. ``content`` may be either ``bytes`` or
|
||||
``str`` (the baseline path decodes + strips before returning).
|
||||
session_messages: All messages in the session, with the current user
|
||||
turn as the last element.
|
||||
|
||||
Returns:
|
||||
A list of ``ChatMessage`` objects covering the prior conversation
|
||||
context, suitable for injection as conversation history.
|
||||
"""
|
||||
from .model import ChatMessage as _ChatMessage # runtime import
|
||||
|
||||
prior = session_messages[:-1]
|
||||
|
||||
if download is None:
|
||||
return prior
|
||||
|
||||
raw_content = download.content
|
||||
if not raw_content:
|
||||
return prior
|
||||
|
||||
# Handle both bytes (raw GCS download) and str (pre-decoded baseline path).
|
||||
if isinstance(raw_content, bytes):
|
||||
try:
|
||||
content_str: str = raw_content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return prior
|
||||
else:
|
||||
content_str = raw_content
|
||||
|
||||
raw = _transcript_to_messages(content_str)
|
||||
if not raw:
|
||||
return prior
|
||||
|
||||
transcript_msgs = [
|
||||
_ChatMessage(role=m["role"], content=m.get("content") or "") for m in raw
|
||||
]
|
||||
gap = detect_gap(download, session_messages)
|
||||
return transcript_msgs + gap
|
||||
|
||||
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript and its metadata from bucket storage.
|
||||
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
"""Delete CLI session JSONL and its companion .meta.json from bucket storage."""
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
try:
|
||||
await storage.delete(path)
|
||||
logger.info("[Transcript] Deleted transcript for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete transcript: %s", e)
|
||||
|
||||
# Also delete the companion .meta.json to avoid orphaned metadata.
|
||||
try:
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
await storage.delete(meta_path)
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
|
||||
# Also delete the CLI native session file to prevent storage growth.
|
||||
try:
|
||||
cli_path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
@@ -1012,6 +919,15 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete CLI session: %s", e)
|
||||
|
||||
try:
|
||||
cli_meta_path = _build_path_from_parts(
|
||||
_cli_session_meta_path_parts(user_id, session_id), storage
|
||||
)
|
||||
await storage.delete(cli_meta_path)
|
||||
logger.info("[Transcript] Deleted CLI session meta for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete CLI session meta: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction — LLM summarization for prompt-too-long recovery
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -143,6 +143,8 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.GROK_4: 9,
|
||||
LlmModel.GROK_4_FAST: 1,
|
||||
LlmModel.GROK_4_1_FAST: 1,
|
||||
LlmModel.GROK_4_20: 5,
|
||||
LlmModel.GROK_4_20_MULTI_AGENT: 5,
|
||||
LlmModel.GROK_CODE_FAST_1: 1,
|
||||
LlmModel.KIMI_K2: 1,
|
||||
LlmModel.QWEN3_235B_A22B_THINKING: 1,
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import stripe
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from prisma.enums import (
|
||||
CreditRefundRequestStatus,
|
||||
CreditTransactionType,
|
||||
@@ -31,6 +34,7 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.util.cache import cached
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled
|
||||
from backend.util.json import SafeJson, dumps
|
||||
@@ -432,7 +436,7 @@ class UserCreditBase(ABC):
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
if current_balance >= ceiling_balance:
|
||||
raise ValueError(
|
||||
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
|
||||
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
|
||||
)
|
||||
|
||||
# Single unified atomic operation for all transaction types using UserBalance
|
||||
@@ -571,7 +575,7 @@ class UserCreditBase(ABC):
|
||||
if amount < 0 and fail_insufficient_credits:
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
raise InsufficientBalanceError(
|
||||
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
|
||||
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
|
||||
user_id=user_id,
|
||||
balance=current_balance,
|
||||
amount=amount,
|
||||
@@ -582,7 +586,6 @@ class UserCreditBase(ABC):
|
||||
|
||||
|
||||
class UserCredit(UserCreditBase):
|
||||
|
||||
async def _send_refund_notification(
|
||||
self,
|
||||
notification_request: RefundRequestData,
|
||||
@@ -734,7 +737,7 @@ class UserCredit(UserCreditBase):
|
||||
)
|
||||
if request.amount <= 0 or request.amount > transaction.amount:
|
||||
raise AssertionError(
|
||||
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
|
||||
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
|
||||
)
|
||||
|
||||
balance, _ = await self._add_transaction(
|
||||
@@ -788,12 +791,12 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
# If the user has enough balance, just let them win the dispute.
|
||||
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
|
||||
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
|
||||
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
|
||||
dispute.close()
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
f"Adding extra info for dispute from {user_id} for ${amount/100}"
|
||||
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
|
||||
)
|
||||
# Retrieve recent transaction history to support our evidence.
|
||||
# This provides a concise timeline that shows service usage and proper credit application.
|
||||
@@ -1237,14 +1240,23 @@ async def get_stripe_customer_id(user_id: str) -> str:
|
||||
if user.stripe_customer_id:
|
||||
return user.stripe_customer_id
|
||||
|
||||
customer = stripe.Customer.create(
|
||||
# Race protection: two concurrent calls (e.g. user double-clicks "Upgrade",
|
||||
# or any retried request) would each pass the check above and create their
|
||||
# own Stripe Customer, leaving an orphaned billable customer in Stripe.
|
||||
# Pass an idempotency_key so Stripe collapses concurrent + retried calls
|
||||
# into the same Customer object server-side. The 24h Stripe idempotency
|
||||
# window comfortably covers any realistic in-flight retry scenario.
|
||||
customer = await run_in_threadpool(
|
||||
stripe.Customer.create,
|
||||
name=user.name or "",
|
||||
email=user.email,
|
||||
metadata={"user_id": user_id},
|
||||
idempotency_key=f"customer-create-{user_id}",
|
||||
)
|
||||
await User.prisma().update(
|
||||
where={"id": user_id}, data={"stripeCustomerId": customer.id}
|
||||
)
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
return customer.id
|
||||
|
||||
|
||||
@@ -1263,23 +1275,203 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
|
||||
data={"subscriptionTier": tier},
|
||||
)
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
# Also invalidate the rate-limit tier cache so CoPilot picks up the new
|
||||
# tier immediately rather than waiting up to 5 minutes for the TTL to expire.
|
||||
from backend.copilot.rate_limit import get_user_tier # local import avoids circular
|
||||
|
||||
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def cancel_stripe_subscription(user_id: str) -> None:
|
||||
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
|
||||
customer_id = await get_stripe_customer_id(user_id)
|
||||
subscriptions = stripe.Subscription.list(
|
||||
customer=customer_id, status="active", limit=10
|
||||
)
|
||||
for sub in subscriptions.auto_paging_iter():
|
||||
try:
|
||||
stripe.Subscription.cancel(sub["id"])
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"cancel_stripe_subscription: failed to cancel sub %s for user %s",
|
||||
sub["id"],
|
||||
user_id,
|
||||
async def _cancel_customer_subscriptions(
|
||||
customer_id: str,
|
||||
exclude_sub_id: str | None = None,
|
||||
at_period_end: bool = False,
|
||||
) -> int:
|
||||
"""Cancel all billable Stripe subscriptions for a customer, optionally excluding one.
|
||||
|
||||
Cancels both ``active`` and ``trialing`` subscriptions, since trialing subs will
|
||||
start billing once the trial ends and must be cleaned up on downgrade/upgrade to
|
||||
avoid double-charging or charging users who intended to cancel.
|
||||
|
||||
When ``at_period_end=True``, schedules cancellation at the end of the current
|
||||
billing period instead of cancelling immediately — the user keeps their tier
|
||||
until the period ends, then ``customer.subscription.deleted`` fires and the
|
||||
webhook downgrades them to FREE.
|
||||
|
||||
Wraps every synchronous Stripe SDK call with run_in_threadpool so the async event
|
||||
loop is never blocked. Raises stripe.StripeError on list/cancel failure so callers
|
||||
that need strict consistency can react; cleanup callers can catch and log instead.
|
||||
|
||||
Returns the number of subscriptions cancelled/scheduled for cancellation.
|
||||
"""
|
||||
# Query active and trialing separately; Stripe's list API accepts a single status
|
||||
# filter at a time (no OR), and we explicitly want to skip canceled/incomplete/
|
||||
# past_due subs rather than filter them out client-side via status="all".
|
||||
seen_ids: set[str] = set()
|
||||
for status in ("active", "trialing"):
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list, customer=customer_id, status=status, limit=10
|
||||
)
|
||||
# Iterate only the first page (up to 10); avoid auto_paging_iter which would
|
||||
# trigger additional sync HTTP calls inside the event loop.
|
||||
if subscriptions.has_more:
|
||||
logger.error(
|
||||
"_cancel_customer_subscriptions: customer %s has more than 10 %s"
|
||||
" subscriptions — only the first page was processed; remaining"
|
||||
" subscriptions were NOT cancelled",
|
||||
customer_id,
|
||||
status,
|
||||
)
|
||||
for sub in subscriptions.data:
|
||||
sub_id = sub["id"]
|
||||
if exclude_sub_id and sub_id == exclude_sub_id:
|
||||
continue
|
||||
if sub_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(sub_id)
|
||||
if at_period_end:
|
||||
await run_in_threadpool(
|
||||
stripe.Subscription.modify, sub_id, cancel_at_period_end=True
|
||||
)
|
||||
else:
|
||||
await run_in_threadpool(stripe.Subscription.cancel, sub_id)
|
||||
return len(seen_ids)
|
||||
|
||||
|
||||
async def cancel_stripe_subscription(user_id: str) -> bool:
|
||||
"""Schedule cancellation of all active/trialing Stripe subscriptions at period end.
|
||||
|
||||
The subscription stays active until the end of the billing period so the user
|
||||
keeps their tier for the time they already paid for. The ``customer.subscription.deleted``
|
||||
webhook fires at period end and downgrades the DB tier to FREE.
|
||||
|
||||
Returns True if at least one subscription was found and scheduled for cancellation,
|
||||
False if the customer had no active/trialing subscriptions (e.g., admin-granted tier
|
||||
with no associated Stripe subscription). When False, the caller should update the
|
||||
DB tier directly since no webhook will fire to do it.
|
||||
|
||||
Raises stripe.StripeError if any modification fails, so the caller can avoid
|
||||
updating the DB tier when Stripe is inconsistent.
|
||||
"""
|
||||
# Guard: only proceed if the user already has a Stripe customer ID. Calling
|
||||
# get_stripe_customer_id for a user who has never had a paid subscription would
|
||||
# create an orphaned, potentially-billable Stripe Customer object — we avoid that
|
||||
# by returning False early so the caller can downgrade the DB tier directly.
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user.stripe_customer_id:
|
||||
return False
|
||||
|
||||
customer_id = user.stripe_customer_id
|
||||
try:
|
||||
cancelled_count = await _cancel_customer_subscriptions(
|
||||
customer_id, at_period_end=True
|
||||
)
|
||||
return cancelled_count > 0
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"cancel_stripe_subscription: Stripe error while cancelling subs for user %s",
|
||||
user_id,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int:
    """Return the prorated credit (in cents) the user would receive if they upgraded now.

    Looks up the user's first ``active`` Stripe subscription, measures how much
    of the current billing period is still unused, and returns that fraction of
    ``monthly_cost_cents``.

    Args:
        user_id: ID of the user whose subscription is inspected.
        monthly_cost_cents: Monthly price of the user's current plan, in cents.

    Returns:
        The unused portion of the monthly cost, in whole cents (rounded down).
        The result is always within ``0..monthly_cost_cents``. Returns 0 when
        the cost is not positive, the user has no Stripe customer ID, no active
        subscription is found, the billing period is empty, or the lookup fails.
    """
    if monthly_cost_cents <= 0:
        return 0

    # Guard: only query Stripe if the user already has a customer ID. Admin-granted
    # paid tiers have no Stripe record; creating a customer here would leave an
    # orphaned customer behind on every billing-page load for those users.
    user = await get_user_by_id(user_id)
    if not user.stripe_customer_id:
        return 0

    try:
        customer_id = user.stripe_customer_id
        subscriptions = await run_in_threadpool(
            stripe.Subscription.list, customer=customer_id, status="active", limit=1
        )
        if not subscriptions.data:
            return 0

        sub = subscriptions.data[0]
        period_start: int = sub["current_period_start"]
        period_end: int = sub["current_period_end"]
        total_seconds = period_end - period_start
        if total_seconds <= 0:
            return 0

        now = int(time.time())
        # Clamp to [0, total_seconds]: a local clock that lags behind
        # current_period_start must not yield a credit above the monthly cost.
        remaining_seconds = min(max(period_end - now, 0), total_seconds)
        # Integer arithmetic keeps the result exact (no float rounding).
        return monthly_cost_cents * remaining_seconds // total_seconds
    except Exception:
        # Best-effort: proration is informational, so failures degrade to 0.
        # exc_info keeps the root cause in the logs instead of dropping it.
        logger.warning(
            "get_proration_credit_cents: failed to compute proration for user %s",
            user_id,
            exc_info=True,
        )
        return 0
|
||||
|
||||
|
||||
async def modify_stripe_subscription_for_tier(
    user_id: str, tier: SubscriptionTier
) -> bool:
    """Switch the user's existing Stripe subscription to the price of ``tier``.

    Intended for paid→paid changes (e.g. PRO↔BUSINESS): the first item of the
    existing subscription is updated in place with
    ``proration_behavior="create_prorations"`` instead of cancelling and going
    through Checkout again.

    Subscriptions are searched by status, ``active`` first and then
    ``trialing``; the first subscription that has at least one item is modified.

    Returns:
        True  — a subscription was found and modified.
        False — the user has no Stripe customer ID, or no active/trialing
                subscription with items exists (e.g. admin-granted tier or
                first-time paid signup); the caller should fall back to Checkout.

    Raises:
        ValueError: no Stripe price ID is configured for ``tier``.
        stripe.StripeError: propagated from the Stripe API calls so callers can
            surface a 502.
    """
    target_price = await get_subscription_price_id(tier)
    if not target_price:
        raise ValueError(f"No Stripe price ID configured for tier {tier}")

    # Only work with an existing Stripe customer. Creating one here for a user
    # without a Stripe record would leave an orphaned customer object behind;
    # returning False lets the API layer fall back to Checkout instead.
    account = await get_user_by_id(user_id)
    stripe_customer = account.stripe_customer_id
    if not stripe_customer:
        return False

    for wanted_status in ("active", "trialing"):
        listing = await run_in_threadpool(
            stripe.Subscription.list,
            customer=stripe_customer,
            status=wanted_status,
            limit=1,
        )
        if listing.data:
            current = listing.data[0]
            subscription_id = current["id"]
            line_items = current.get("items", {}).get("data", [])
            if line_items:
                # Swap the price on the first subscription item.
                await run_in_threadpool(
                    stripe.Subscription.modify,
                    subscription_id,
                    items=[{"id": line_items[0]["id"], "price": target_price}],
                    proration_behavior="create_prorations",
                )
                logger.info(
                    "modify_stripe_subscription_for_tier: modified sub %s for user %s → %s",
                    subscription_id,
                    user_id,
                    tier,
                )
                return True
    return False
|
||||
|
||||
|
||||
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
@@ -1291,8 +1483,19 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
return AutoTopUpConfig.model_validate(user.top_up_config)
|
||||
|
||||
|
||||
@cached(ttl_seconds=60, maxsize=8, cache_none=False)
|
||||
async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
|
||||
"""Return Stripe Price ID for a tier from LaunchDarkly. None = not configured."""
|
||||
"""Return Stripe Price ID for a tier from LaunchDarkly, cached for 60 seconds.
|
||||
|
||||
Price IDs are LaunchDarkly flag values that change only at deploy time.
|
||||
Caching for 60 seconds avoids hitting the LD SDK on every webhook delivery
|
||||
and every GET /credits/subscription page load (called 2x per request).
|
||||
|
||||
``cache_none=False`` prevents a transient LD failure from caching ``None``
|
||||
and blocking subscription upgrades for the full 60-second TTL window.
|
||||
A tier with no configured flag (FREE, ENTERPRISE) returns ``None`` from an
|
||||
O(1) dict lookup before hitting LD, so the extra LD call is never made.
|
||||
"""
|
||||
flag_map = {
|
||||
SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO,
|
||||
SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS,
|
||||
@@ -1300,7 +1503,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
|
||||
flag = flag_map.get(tier)
|
||||
if flag is None:
|
||||
return None
|
||||
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
|
||||
price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
|
||||
return price_id if isinstance(price_id, str) and price_id else None
|
||||
|
||||
|
||||
@@ -1315,7 +1518,8 @@ async def create_subscription_checkout(
|
||||
if not price_id:
|
||||
raise ValueError(f"Subscription not available for tier {tier.value}")
|
||||
customer_id = await get_stripe_customer_id(user_id)
|
||||
session = stripe.checkout.Session.create(
|
||||
session = await run_in_threadpool(
|
||||
stripe.checkout.Session.create,
|
||||
customer=customer_id,
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
@@ -1323,26 +1527,111 @@ async def create_subscription_checkout(
|
||||
cancel_url=cancel_url,
|
||||
subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
|
||||
)
|
||||
return session.url or ""
|
||||
if not session.url:
|
||||
# An empty checkout URL for a paid upgrade is always an error; surfacing it
|
||||
# as ValueError means the API handler returns 422 instead of silently
|
||||
# redirecting the client to an empty URL.
|
||||
raise ValueError("Stripe did not return a checkout session URL")
|
||||
return session.url
|
||||
|
||||
|
||||
async def _cleanup_stale_subscriptions(customer_id: str, new_sub_id: str) -> None:
    """Best-effort cancel of the customer's subscriptions other than ``new_sub_id``.

    Meant to run from the webhook handler once a new subscription has become
    active. A ``stripe.StripeError`` is logged and swallowed so that a transient
    Stripe failure does not crash the webhook; nothing is returned either way.

    NOTE: a failure here can leave the user billed for two simultaneous
    subscriptions. The failure is therefore reported with ``logger.exception``
    (error level, with traceback) and the message starts with
    ``stripe_stale_subscription_cleanup_failed`` and carries the customer and
    subscription IDs needed for manual reconciliation.
    NOTE(review): only the log line is emitted here; any metric of that name
    would have to be derived from the log — confirm with on-call tooling.

    TODO(#stripe-reconcile-job): replace this best-effort cleanup with a periodic
    reconciliation job that queries Stripe for customers with >1 active sub.
    """
    failure_message = (
        "stripe_stale_subscription_cleanup_failed: customer=%s new_sub=%s —"
        " user may be billed for two simultaneous subscriptions; manual"
        " reconciliation required"
    )
    try:
        await _cancel_customer_subscriptions(customer_id, exclude_sub_id=new_sub_id)
    except stripe.StripeError:
        # Deliberately error-level: a failed cleanup after a paid-to-paid
        # upgrade may mean double billing and needs human attention.
        logger.exception(failure_message, customer_id, new_sub_id)
|
||||
|
||||
|
||||
async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
|
||||
"""Update User.subscriptionTier from a Stripe subscription object."""
|
||||
customer_id = stripe_subscription["customer"]
|
||||
"""Update User.subscriptionTier from a Stripe subscription object.
|
||||
|
||||
Expected shape of stripe_subscription (subset of Stripe's Subscription object):
|
||||
customer: str — Stripe customer ID
|
||||
status: str — "active" | "trialing" | "canceled" | ...
|
||||
id: str — Stripe subscription ID
|
||||
items.data[].price.id: str — Stripe price ID identifying the tier
|
||||
"""
|
||||
customer_id = stripe_subscription.get("customer")
|
||||
if not customer_id:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: missing 'customer' field in event, "
|
||||
"skipping (keys: %s)",
|
||||
list(stripe_subscription.keys()),
|
||||
)
|
||||
return
|
||||
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
|
||||
if not user:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: no user for customer %s", customer_id
|
||||
)
|
||||
return
|
||||
# Cross-check: if the subscription carries a metadata.user_id (set during
|
||||
# Checkout Session creation), verify it matches the user we found via
|
||||
# stripeCustomerId. A mismatch indicates a customer↔user mapping
|
||||
# inconsistency — updating the wrong user's tier would be a data-corruption
|
||||
# bug, so we log loudly and bail out. Absence of metadata.user_id (e.g.
|
||||
# subscriptions created outside the Checkout flow) is not an error — we
|
||||
# simply skip the check and proceed with the customer-ID-based lookup.
|
||||
metadata = stripe_subscription.get("metadata") or {}
|
||||
metadata_user_id = metadata.get("user_id") if isinstance(metadata, dict) else None
|
||||
if metadata_user_id and metadata_user_id != user.id:
|
||||
logger.error(
|
||||
"sync_subscription_from_stripe: metadata.user_id=%s does not match"
|
||||
" user.id=%s found via stripeCustomerId=%s — refusing to update tier"
|
||||
" to avoid corrupting the wrong user's subscription state",
|
||||
metadata_user_id,
|
||||
user.id,
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
# ENTERPRISE tiers are admin-managed. Never let a Stripe webhook flip an
|
||||
# ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has
|
||||
# a self-service Stripe sub, it's a data-consistency issue for an operator,
|
||||
# not something the webhook should automatically "fix".
|
||||
current_tier = user.subscriptionTier or SubscriptionTier.FREE
|
||||
if current_tier == SubscriptionTier.ENTERPRISE:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier"
|
||||
" for user %s (customer %s); event status=%s",
|
||||
user.id,
|
||||
customer_id,
|
||||
stripe_subscription.get("status", ""),
|
||||
)
|
||||
return
|
||||
status = stripe_subscription.get("status", "")
|
||||
new_sub_id = stripe_subscription.get("id", "")
|
||||
if status in ("active", "trialing"):
|
||||
price_id = ""
|
||||
items = stripe_subscription.get("items", {}).get("data", [])
|
||||
if items:
|
||||
price_id = items[0].get("price", {}).get("id", "")
|
||||
pro_price = await get_subscription_price_id(SubscriptionTier.PRO)
|
||||
biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS)
|
||||
pro_price, biz_price = await asyncio.gather(
|
||||
get_subscription_price_id(SubscriptionTier.PRO),
|
||||
get_subscription_price_id(SubscriptionTier.BUSINESS),
|
||||
)
|
||||
if price_id and pro_price and price_id == pro_price:
|
||||
tier = SubscriptionTier.PRO
|
||||
elif price_id and biz_price and price_id == biz_price:
|
||||
@@ -1359,10 +1648,206 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
|
||||
)
|
||||
return
|
||||
else:
|
||||
# A subscription was cancelled or ended. DO NOT unconditionally downgrade
|
||||
# to FREE — Stripe does not guarantee webhook delivery order, so a
|
||||
# `customer.subscription.deleted` for the OLD sub can arrive after we've
|
||||
# already processed `customer.subscription.created` for a new paid sub.
|
||||
# Ask Stripe whether any OTHER active/trialing subs exist for this
|
||||
# customer; if they do, keep the user's current tier (the other sub's
|
||||
# own event will/has already set the correct tier).
|
||||
try:
|
||||
other_subs_active, other_subs_trialing = await asyncio.gather(
|
||||
run_in_threadpool(
|
||||
stripe.Subscription.list,
|
||||
customer=customer_id,
|
||||
status="active",
|
||||
limit=10,
|
||||
),
|
||||
run_in_threadpool(
|
||||
stripe.Subscription.list,
|
||||
customer=customer_id,
|
||||
status="trialing",
|
||||
limit=10,
|
||||
),
|
||||
)
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: could not verify other active"
|
||||
" subs for customer %s on cancel event %s; preserving current"
|
||||
" tier to avoid an unsafe downgrade",
|
||||
customer_id,
|
||||
new_sub_id,
|
||||
)
|
||||
return
|
||||
# Filter out the cancelled subscription to check if other active subs
|
||||
# exist. When new_sub_id is empty (malformed event with no 'id' field),
|
||||
# we cannot safely exclude any sub — preserve current tier to avoid
|
||||
# an unsafe downgrade on a malformed webhook payload.
|
||||
if not new_sub_id:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: cancel event missing 'id' field"
|
||||
" for customer %s; preserving current tier",
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
other_active_ids = {sub["id"] for sub in other_subs_active.data} - {new_sub_id}
|
||||
other_trialing_ids = {sub["id"] for sub in other_subs_trialing.data} - {
|
||||
new_sub_id
|
||||
}
|
||||
still_has_active_sub = bool(other_active_ids or other_trialing_ids)
|
||||
if still_has_active_sub:
|
||||
logger.info(
|
||||
"sync_subscription_from_stripe: sub %s cancelled but customer %s"
|
||||
" still has another active sub; keeping tier %s",
|
||||
new_sub_id,
|
||||
customer_id,
|
||||
current_tier.value,
|
||||
)
|
||||
return
|
||||
tier = SubscriptionTier.FREE
|
||||
# Idempotency: Stripe retries webhooks on delivery failure, and several event
|
||||
# types map to the same final tier. Skip the DB write + cache invalidation
|
||||
# when the tier is already correct to avoid redundant writes on replay.
|
||||
if current_tier == tier:
|
||||
return
|
||||
# When a new subscription becomes active (e.g. paid-to-paid tier upgrade
|
||||
# via a fresh Checkout Session), cancel any OTHER active subscriptions for
|
||||
# the same customer so the user isn't billed twice. We do this in the
|
||||
# webhook rather than the API handler so that abandoning the checkout
|
||||
# doesn't leave the user without a subscription.
|
||||
# IMPORTANT: this runs AFTER the idempotency check above so that webhook
|
||||
# replays for an already-applied event do NOT trigger another cleanup round
|
||||
# (which could otherwise cancel a legitimately new subscription the user
|
||||
# signed up for between the original event and its replay).
|
||||
if status in ("active", "trialing") and new_sub_id:
|
||||
# NOTE: paid-to-paid upgrade race (e.g. PRO → BUSINESS):
|
||||
# _cleanup_stale_subscriptions cancels the old PRO sub before
|
||||
# set_subscription_tier writes BUSINESS to the DB. If Stripe delivers
|
||||
# the PRO `customer.subscription.deleted` event concurrently and it
|
||||
# processes after the PRO cancel but before set_subscription_tier
|
||||
# commits, the user could momentarily appear as FREE in the DB.
|
||||
# This window is very short in practice (two sequential awaits),
|
||||
# but is a known limitation of the current webhook-driven approach.
|
||||
# A future improvement would be to write the new tier first, then
|
||||
# cancel the old sub.
|
||||
await _cleanup_stale_subscriptions(customer_id, new_sub_id)
|
||||
await set_subscription_tier(user.id, tier)
|
||||
|
||||
|
||||
async def handle_subscription_payment_failure(invoice: dict) -> None:
    """Handle a failed Stripe subscription payment.

    Tries to cover the invoice amount from the user's credit balance.

    - Balance sufficient → deduct from balance, then mark the Stripe invoice as
      paid out of band so Stripe stops retrying it and does NOT charge the
      payment method again. The sub stays intact and the user keeps their tier.
    - Balance insufficient → cancel the customer's Stripe subscriptions
      immediately, then downgrade to FREE. If the cancel fails, the tier is left
      untouched to avoid a DB/Stripe mismatch.

    The event is ignored (with a log line) when the invoice has no customer, no
    user matches the customer, the user is on the admin-managed ENTERPRISE
    tier, or nothing is due.

    Args:
        invoice: Stripe invoice payload; the fields read are ``customer``,
            ``amount_due``, ``subscription`` and ``id``.
    """
    customer_id = invoice.get("customer")
    if not customer_id:
        logger.warning(
            "handle_subscription_payment_failure: missing customer in invoice; skipping"
        )
        return

    user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
    if not user:
        logger.warning(
            "handle_subscription_payment_failure: no user found for customer %s",
            customer_id,
        )
        return

    current_tier = user.subscriptionTier or SubscriptionTier.FREE
    if current_tier == SubscriptionTier.ENTERPRISE:
        logger.warning(
            "handle_subscription_payment_failure: skipping ENTERPRISE user %s"
            " (customer %s) — tier is admin-managed",
            user.id,
            customer_id,
        )
        return

    # ``or`` (rather than a .get default) also normalises explicit nulls in the
    # payload, which would otherwise break the comparison / logging below.
    amount_due: int = invoice.get("amount_due") or 0
    sub_id: str = invoice.get("subscription") or ""
    invoice_id: str = invoice.get("id") or ""

    if amount_due <= 0:
        logger.info(
            "handle_subscription_payment_failure: amount_due=%d for user %s;"
            " nothing to deduct",
            amount_due,
            user.id,
        )
        return

    credit_model = UserCredit()
    try:
        await credit_model._add_transaction(
            user_id=user.id,
            amount=-amount_due,
            transaction_type=CreditTransactionType.SUBSCRIPTION,
            fail_insufficient_credits=True,
            # Use invoice_id as the idempotency key so that Stripe webhook retries
            # (e.g. after a transient stripe.Invoice.pay failure) do not double-charge.
            transaction_key=invoice_id or None,
            metadata=SafeJson(
                {
                    "stripe_customer_id": customer_id,
                    "stripe_subscription_id": sub_id,
                    "reason": "subscription_payment_failure_covered_by_balance",
                }
            ),
        )
    except InsufficientBalanceError:
        # Balance insufficient — cancel Stripe subscription first, then downgrade DB.
        # Order matters: if we downgrade the DB first and the Stripe cancel fails, the
        # user is stuck on FREE while Stripe continues billing them. Cancelling
        # Stripe first is safe: if the DB write then fails, the
        # customer.subscription.deleted webhook is expected to correct the tier.
        logger.info(
            "handle_subscription_payment_failure: insufficient balance for user %s;"
            " cancelling Stripe sub %s then downgrading to FREE",
            user.id,
            sub_id,
        )
        try:
            await _cancel_customer_subscriptions(customer_id)
        except stripe.StripeError:
            logger.warning(
                "handle_subscription_payment_failure: failed to cancel Stripe sub %s"
                " for user %s (customer %s); skipping tier downgrade to avoid"
                " inconsistency — Stripe may continue retrying the invoice",
                sub_id,
                user.id,
                customer_id,
            )
            return
        await set_subscription_tier(user.id, SubscriptionTier.FREE)
        return

    # Balance covered the invoice. Tell Stripe so its dunning system stops
    # retrying it and re-triggering this webhook.
    if not invoice_id:
        logger.warning(
            "handle_subscription_payment_failure: balance deducted for user %s"
            " but invoice has no id; cannot mark it as paid, Stripe may retry",
            user.id,
        )
        return

    try:
        # paid_out_of_band=True records the invoice as paid WITHOUT attempting
        # another charge. A plain Invoice.pay would retry the payment method
        # that just failed — or, if it succeeded, bill the user a second time
        # on top of the credits already deducted above.
        await run_in_threadpool(
            stripe.Invoice.pay, invoice_id, paid_out_of_band=True
        )
    except stripe.StripeError:
        logger.warning(
            "handle_subscription_payment_failure: balance deducted for user"
            " %s but failed to mark invoice %s as paid; Stripe may retry",
            user.id,
            invoice_id,
            exc_info=True,
        )
        return

    # Only reached when Stripe accepted the out-of-band payment, so the
    # "invoice paid" claim in this message is actually true.
    logger.info(
        "handle_subscription_payment_failure: deducted %d cents from balance"
        " for user %s; Stripe invoice %s paid, sub %s intact, tier preserved",
        amount_due,
        user.id,
        invoice_id,
        sub_id,
    )
|
||||
|
||||
|
||||
async def admin_get_user_history(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -73,6 +73,31 @@ def _get_redis() -> Redis:
|
||||
return r
|
||||
|
||||
|
||||
class _MissingType:
|
||||
"""Singleton sentinel type — distinct from ``None`` (a valid cached value).
|
||||
|
||||
Using a dedicated class (instead of ``Any = object()``) lets mypy prove
|
||||
that comparisons ``result is _MISSING`` narrow the type correctly and
|
||||
prevents accidental use of the sentinel where a real value is expected.
|
||||
"""
|
||||
|
||||
_instance: "_MissingType | None" = None
|
||||
|
||||
def __new__(cls) -> "_MissingType":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<MISSING>"
|
||||
|
||||
|
||||
# Sentinel returned by ``_get_from_memory`` / ``_get_from_redis`` to mean
|
||||
# "no entry exists" — distinct from a cached ``None`` value, which is a
|
||||
# valid result for callers that opt into caching it.
|
||||
_MISSING = _MissingType()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedValue:
|
||||
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
|
||||
@@ -160,6 +185,7 @@ def cached(
|
||||
ttl_seconds: int,
|
||||
shared_cache: bool = False,
|
||||
refresh_ttl_on_get: bool = False,
|
||||
cache_none: bool = True,
|
||||
) -> Callable[[Callable[P, R]], CachedFunction[P, R]]:
|
||||
"""
|
||||
Thundering herd safe cache decorator for both sync and async functions.
|
||||
@@ -172,6 +198,10 @@ def cached(
|
||||
ttl_seconds: Time to live in seconds. Required - entries must expire.
|
||||
shared_cache: If True, use Redis for cross-process caching
|
||||
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
|
||||
cache_none: If True (default) ``None`` is cached like any other value.
|
||||
Set to ``False`` for functions that return ``None`` to signal a
|
||||
transient error and should be re-tried on the next call without
|
||||
poisoning the cache (e.g. external API calls that may fail).
|
||||
|
||||
Returns:
|
||||
Decorated function with caching capabilities
|
||||
@@ -184,6 +214,12 @@ def cached(
|
||||
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cached(ttl_seconds=300, cache_none=False)
|
||||
async def fetch_external(id: str) -> dict | None:
|
||||
# Returns None on transient error — won't be stored,
|
||||
# next call retries instead of returning the stale None.
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
|
||||
@@ -191,9 +227,14 @@ def cached(
|
||||
cache_storage: dict[tuple, CachedValue] = {}
|
||||
_event_loop_locks: dict[Any, asyncio.Lock] = {}
|
||||
|
||||
def _get_from_redis(redis_key: str) -> Any | None:
|
||||
def _get_from_redis(redis_key: str) -> Any:
|
||||
"""Get value from Redis, optionally refreshing TTL.
|
||||
|
||||
Returns the cached value (which may be ``None``) on a hit, or the
|
||||
module-level ``_MISSING`` sentinel on a miss / corrupt entry.
|
||||
Callers must compare with ``is _MISSING`` so cached ``None`` values
|
||||
are not mistaken for misses.
|
||||
|
||||
Values are expected to carry an HMAC-SHA256 prefix for integrity
|
||||
verification. Unsigned (legacy) or tampered entries are silently
|
||||
discarded and treated as cache misses, so the caller recomputes and
|
||||
@@ -213,11 +254,11 @@ def cached(
|
||||
f"for {func_name}, discarding entry: "
|
||||
"possible tampering or legacy unsigned value"
|
||||
)
|
||||
return None
|
||||
return _MISSING
|
||||
return pickle.loads(payload)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis error during cache check for {func_name}: {e}")
|
||||
return None
|
||||
return _MISSING
|
||||
|
||||
def _set_to_redis(redis_key: str, value: Any) -> None:
|
||||
"""Set HMAC-signed pickled value in Redis with TTL."""
|
||||
@@ -227,8 +268,13 @@ def cached(
|
||||
except Exception as e:
|
||||
logger.error(f"Redis error storing cache for {func_name}: {e}")
|
||||
|
||||
def _get_from_memory(key: tuple) -> Any | None:
|
||||
"""Get value from in-memory cache, checking TTL."""
|
||||
def _get_from_memory(key: tuple) -> Any:
|
||||
"""Get value from in-memory cache, checking TTL.
|
||||
|
||||
Returns the cached value (which may be ``None``) on a hit, or the
|
||||
``_MISSING`` sentinel on a miss / TTL expiry. See
|
||||
``_get_from_redis`` for the rationale.
|
||||
"""
|
||||
if key in cache_storage:
|
||||
cached_data = cache_storage[key]
|
||||
if time.time() - cached_data.timestamp < ttl_seconds:
|
||||
@@ -236,7 +282,7 @@ def cached(
|
||||
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
|
||||
)
|
||||
return cached_data.result
|
||||
return None
|
||||
return _MISSING
|
||||
|
||||
def _set_to_memory(key: tuple, value: Any) -> None:
|
||||
"""Set value in in-memory cache with timestamp."""
|
||||
@@ -270,11 +316,11 @@ def cached(
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
@@ -282,22 +328,24 @@ def cached(
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
# Store result (skip ``None`` if the caller opted out of
|
||||
# caching it — used for transient-error sentinels).
|
||||
if cache_none or result is not None:
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
@@ -315,11 +363,11 @@ def cached(
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
@@ -327,22 +375,24 @@ def cached(
|
||||
# Double-check: another thread might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
# Store result (skip ``None`` if the caller opted out of
|
||||
# caching it — used for transient-error sentinels).
|
||||
if cache_none or result is not None:
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -1223,3 +1223,123 @@ class TestCacheHMAC:
|
||||
assert call_count == 2
|
||||
|
||||
legacy_test_fn.cache_clear()
|
||||
|
||||
|
||||
class TestCacheNoneHandling:
|
||||
"""Tests for the ``cache_none`` parameter on the @cached decorator.
|
||||
|
||||
Sentry bug PRRT_kwDOJKSTjM56RTEu (HIGH): the cache previously could not
|
||||
distinguish "no entry" from "entry is None", so any function returning
|
||||
``None`` was effectively re-executed on every call. The fix is a
|
||||
sentinel-based check inside the wrappers, plus an opt-out
|
||||
``cache_none=False`` flag for callers that *want* errors to retry.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_none_is_cached_by_default(self):
|
||||
"""With ``cache_none=True`` (default), cached ``None`` is returned
|
||||
from the cache instead of triggering re-execution."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=300)
|
||||
async def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return None
|
||||
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# Second call should hit the cache, not re-execute.
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# Different argument is a different cache key — re-executes.
|
||||
assert await maybe_none(2) is None
|
||||
assert call_count == 2
|
||||
|
||||
def test_sync_none_is_cached_by_default(self):
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=300)
|
||||
def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return None
|
||||
|
||||
assert maybe_none(1) is None
|
||||
assert maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_none_false_skips_storing_none(self):
|
||||
"""``cache_none=False`` skips storing ``None`` so transient errors
|
||||
are retried on the next call instead of poisoning the cache."""
|
||||
call_count = 0
|
||||
results: list[int | None] = [None, None, 42]
|
||||
|
||||
@cached(ttl_seconds=300, cache_none=False)
|
||||
async def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
result = results[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
# First call: returns None, NOT stored.
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same key: re-executes (None wasn't cached).
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 2
|
||||
|
||||
# Third call: returns 42, this time it IS stored.
|
||||
assert await maybe_none(1) == 42
|
||||
assert call_count == 3
|
||||
|
||||
# Fourth call: cache hit on the stored 42.
|
||||
assert await maybe_none(1) == 42
|
||||
assert call_count == 3
|
||||
|
||||
def test_sync_cache_none_false_skips_storing_none(self):
    """With ``cache_none=False`` a sync ``None`` result is not stored, but a real value is."""
    scripted: list[int | None] = [None, 99]
    executions = 0

    @cached(ttl_seconds=300, cache_none=False)
    def maybe_none(x: int) -> int | None:
        nonlocal executions
        value = scripted[executions]
        executions += 1
        return value

    first = maybe_none(1)
    assert first is None
    assert executions == 1

    # The None above was not stored, so the body runs again.
    second = maybe_none(1)
    assert second == 99
    assert executions == 2

    # 99 was stored, so this call is a cache hit with no re-execution.
    third = maybe_none(1)
    assert third == 99
    assert executions == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_async_shared_cache_none_is_cached_by_default(self):
    """Shared (Redis) cache also properly returns cached ``None`` values."""
    call_count = 0

    @cached(ttl_seconds=30, shared_cache=True)
    async def maybe_none_redis(x: int) -> int | None:
        nonlocal call_count
        call_count += 1
        return None

    # Start from a clean slate in case an earlier run left an entry behind.
    maybe_none_redis.cache_clear()

    try:
        assert await maybe_none_redis(1) is None
        assert call_count == 1

        # Second call with the same key must be a cache hit on the stored None.
        assert await maybe_none_redis(1) is None
        assert call_count == 1
    finally:
        # Always clean up, even when an assertion above fails; otherwise the
        # entry would stay in the shared cache (presumably visible to other
        # tests/runs until the TTL expires — confirm against the backend).
        maybe_none_redis.cache_clear()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
@@ -101,6 +102,12 @@ async def _fetch_user_context_data(user_id: str) -> Context:
|
||||
"""
|
||||
builder = Context.builder(user_id).kind("user").anonymous(True)
|
||||
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
# Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context.
|
||||
return builder.build()
|
||||
|
||||
try:
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
|
||||
@@ -88,17 +88,19 @@ async def cmd_download(session_ids: list[str]) -> None:
|
||||
print(f"[{sid[:12]}] Not found in GCS")
|
||||
continue
|
||||
|
||||
content_str = (
|
||||
dl.content.decode("utf-8") if isinstance(dl.content, bytes) else dl.content
|
||||
)
|
||||
out = _transcript_path(sid)
|
||||
with open(out, "w") as f:
|
||||
f.write(dl.content)
|
||||
f.write(content_str)
|
||||
|
||||
lines = len(dl.content.strip().split("\n"))
|
||||
lines = len(content_str.strip().split("\n"))
|
||||
meta = {
|
||||
"session_id": sid,
|
||||
"user_id": user_id,
|
||||
"message_count": dl.message_count,
|
||||
"uploaded_at": dl.uploaded_at,
|
||||
"transcript_bytes": len(dl.content),
|
||||
"transcript_bytes": len(content_str),
|
||||
"transcript_lines": lines,
|
||||
}
|
||||
with open(_meta_path(sid), "w") as f:
|
||||
@@ -106,7 +108,7 @@ async def cmd_download(session_ids: list[str]) -> None:
|
||||
|
||||
print(
|
||||
f"[{sid[:12]}] Saved: {lines} entries, "
|
||||
f"{len(dl.content)} bytes, msg_count={dl.message_count}"
|
||||
f"{len(content_str)} bytes, msg_count={dl.message_count}"
|
||||
)
|
||||
print("\nDone. Run 'load' command to import into local dev environment.")
|
||||
|
||||
@@ -227,7 +229,7 @@ async def cmd_load(session_ids: list[str]) -> None:
|
||||
await upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=sid,
|
||||
content=content,
|
||||
content=content.encode("utf-8"),
|
||||
message_count=msg_count,
|
||||
)
|
||||
print(f"[{sid[:12]}] Stored transcript in local workspace storage")
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Unit tests for the transcript watermark (message_count) fix.
|
||||
|
||||
The bug: upload used message_count=len(session.messages) (DB count). When a
|
||||
prior turn's GCS upload failed silently, the JSONL on GCS was stale (e.g.
|
||||
covered only T1-T12) but the meta.json watermark matched the full DB count
|
||||
(e.g. 46). The next turn's gap-fill check (transcript_msg_count < msg_count-1)
|
||||
never triggered, so the model silently lost context for the skipped turns.
|
||||
|
||||
The fix: watermark = previous_coverage + 2 (current user+asst pair) when
|
||||
use_resume=True and transcript_msg_count > 0. This ensures the watermark
|
||||
reflects the JSONL content, not the DB count.
|
||||
|
||||
These tests exercise _build_query_message directly to verify that gap-fill
|
||||
triggers with the corrected watermark but NOT with the inflated (buggy) one.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.service import _build_query_message
|
||||
|
||||
|
||||
def _make_messages(n_pairs: int, *, current_user: str = "current") -> list[MagicMock]:
|
||||
"""Build a flat list of n_pairs*2 alternating user/asst messages, plus
|
||||
one trailing user message for the *current* turn."""
|
||||
msgs: list[MagicMock] = []
|
||||
for i in range(n_pairs):
|
||||
u = MagicMock()
|
||||
u.role = "user"
|
||||
u.content = f"user message {i}"
|
||||
a = MagicMock()
|
||||
a.role = "assistant"
|
||||
a.content = f"assistant response {i}"
|
||||
msgs.extend([u, a])
|
||||
# Current turn's user message
|
||||
cur = MagicMock()
|
||||
cur.role = "user"
|
||||
cur.content = current_user
|
||||
msgs.append(cur)
|
||||
return msgs
|
||||
|
||||
|
||||
def _make_session(messages: list[MagicMock]) -> MagicMock:
|
||||
session = MagicMock()
|
||||
session.messages = messages
|
||||
return session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gap_fill_triggers_for_stale_jsonl():
    """A watermark that reflects the stale JSONL makes gap-fill fire.

    Scenario: the JSONL covers T1-T12 (24 messages) and the 'Test' turn
    uploaded watermark 24 + 2 = 26. On the next turn the DB holds 47
    messages, so the gap check 26 < 47 - 1 = 46 is true and the missing
    turns are expected to be injected as ``<conversation_history>``.
    """
    prompt = "memory test - recall all"

    # 23 completed turns (46 messages) plus the current user message = 47.
    history = _make_messages(23, current_user=prompt)
    assert len(history) == 47

    # transcript_msg_count=26 is the watermark the fix would have uploaded.
    query_text, _ = await _build_query_message(
        current_message=prompt,
        session=_make_session(history),
        use_resume=True,
        transcript_msg_count=26,
        session_id="test-session-id",
    )

    assert "<conversation_history>" in query_text, (
        "Expected gap-fill to inject <conversation_history> when "
        "watermark=26 < msg_count-1=46"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_gap_fill_when_watermark_is_current():
    """With watermark == DB count - 1 the query is the bare current message."""
    prompt = "next message"

    # 23 completed turns (46 messages) plus the current user message = 47.
    history = _make_messages(23, current_user=prompt)

    query_text, _ = await _build_query_message(
        current_message=prompt,
        session=_make_session(history),
        use_resume=True,
        transcript_msg_count=46,  # fully current, so there is no gap to fill
        session_id="test-session-id",
    )

    assert (
        "<conversation_history>" not in query_text
    ), "No gap-fill expected when watermark is current"
    assert query_text == "next message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_inflated_watermark_suppresses_gap_fill():
    """Pin down the original bug: a DB-sized watermark hides the gap.

    The 'Test' turn uploaded watermark=len(session.messages)=46 although the
    JSONL only held 26 messages. On the next turn 46 < 47 - 1 = 46 is false,
    so no history is injected.
    """
    prompt = "memory test"
    history = _make_messages(23, current_user=prompt)

    query_text, _ = await _build_query_message(
        current_message=prompt,
        session=_make_session(history),
        use_resume=True,
        transcript_msg_count=46,  # buggy value: inflated to the DB count
        session_id="test-session-id",
    )

    assert (
        "<conversation_history>" not in query_text
    ), "With inflated watermark, gap-fill is suppressed — this documents the bug"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fixed_watermark_fills_same_gap():
    """Same 47-message session, but the corrected watermark (26) triggers gap-fill."""
    prompt = "memory test"
    history = _make_messages(23, current_user=prompt)

    query_text, _ = await _build_query_message(
        current_message=prompt,
        session=_make_session(history),
        use_resume=True,
        transcript_msg_count=26,  # watermark as uploaded by the fix
        session_id="test-session-id",
    )

    assert (
        "<conversation_history>" in query_text
    ), "With fixed watermark=26, gap-fill triggers and injects missing turns"
|
||||
@@ -155,6 +155,7 @@
|
||||
"@types/twemoji": "13.1.2",
|
||||
"@vitejs/plugin-react": "5.1.2",
|
||||
"@vitest/coverage-v8": "4.0.17",
|
||||
"agentation": "3.0.2",
|
||||
"axe-playwright": "2.2.2",
|
||||
"chromatic": "13.3.3",
|
||||
"concurrently": "9.2.1",
|
||||
|
||||
19
autogpt_platform/frontend/pnpm-lock.yaml
generated
19
autogpt_platform/frontend/pnpm-lock.yaml
generated
@@ -376,6 +376,9 @@ importers:
|
||||
'@vitest/coverage-v8':
|
||||
specifier: 4.0.17
|
||||
version: 4.0.17(vitest@4.0.17(@opentelemetry/api@1.9.0)(@types/node@24.10.0)(happy-dom@20.3.4)(jiti@2.6.1)(jsdom@27.4.0)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(terser@5.44.1)(yaml@2.8.2))
|
||||
agentation:
|
||||
specifier: 3.0.2
|
||||
version: 3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
axe-playwright:
|
||||
specifier: 2.2.2
|
||||
version: 2.2.2(playwright@1.56.1)
|
||||
@@ -4119,6 +4122,17 @@ packages:
|
||||
resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==}
|
||||
engines: {node: '>= 14'}
|
||||
|
||||
agentation@3.0.2:
|
||||
resolution: {integrity: sha512-iGzBxFVTuZEIKzLY6AExSLAQH6i6SwxV4pAu7v7m3X6bInZ7qlZXAwrEqyc4+EfP4gM7z2RXBF6SF4DeH0f2lA==}
|
||||
peerDependencies:
|
||||
react: '>=18.0.0'
|
||||
react-dom: '>=18.0.0'
|
||||
peerDependenciesMeta:
|
||||
react:
|
||||
optional: true
|
||||
react-dom:
|
||||
optional: true
|
||||
|
||||
ai@6.0.134:
|
||||
resolution: {integrity: sha512-YalNEaavld/kE444gOcsMKXdVVRGEe0SK77fAFcWYcqLg+a7xKnEet8bdfrEAJTfnMjj01rhgrIL10903w1a5Q==}
|
||||
engines: {node: '>=18'}
|
||||
@@ -13119,6 +13133,11 @@ snapshots:
|
||||
agent-base@7.1.4:
|
||||
optional: true
|
||||
|
||||
agentation@3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
optionalDependencies:
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
ai@6.0.134(zod@3.25.76):
|
||||
dependencies:
|
||||
'@ai-sdk/gateway': 3.0.77(zod@3.25.76)
|
||||
|
||||
@@ -110,7 +110,7 @@ export const Flow = () => {
|
||||
event.preventDefault();
|
||||
}}
|
||||
maxZoom={2}
|
||||
minZoom={0.1}
|
||||
minZoom={0.05}
|
||||
onDragOver={onDragOver}
|
||||
onDrop={onDrop}
|
||||
nodesDraggable={!isLocked}
|
||||
|
||||
@@ -93,6 +93,7 @@ export function CopilotPage() {
|
||||
hasMoreMessages,
|
||||
isLoadingMore,
|
||||
loadMore,
|
||||
forwardPaginated,
|
||||
// Mobile drawer
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
@@ -217,6 +218,7 @@ export function CopilotPage() {
|
||||
hasMoreMessages={hasMoreMessages}
|
||||
isLoadingMore={isLoadingMore}
|
||||
onLoadMore={loadMore}
|
||||
forwardPaginated={forwardPaginated}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
historicalDurations={historicalDurations}
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { useChatSession } from "../useChatSession";
|
||||
|
||||
const mockUseGetV2GetSession = vi.fn();
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
|
||||
useGetV2GetSession: (...args: unknown[]) => mockUseGetV2GetSession(...args),
|
||||
usePostV2CreateSession: () => ({ mutateAsync: vi.fn(), isPending: false }),
|
||||
getGetV2GetSessionQueryKey: (id: string) => ["session", id],
|
||||
getGetV2ListSessionsQueryKey: () => ["sessions"],
|
||||
}));
|
||||
|
||||
vi.mock("@tanstack/react-query", () => ({
|
||||
useQueryClient: () => ({
|
||||
invalidateQueries: vi.fn(),
|
||||
setQueryData: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("nuqs", () => ({
|
||||
parseAsString: { withDefault: (v: unknown) => v },
|
||||
useQueryState: () => ["sess-1", vi.fn()],
|
||||
}));
|
||||
|
||||
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
|
||||
convertChatSessionMessagesToUiMessages: vi.fn(() => ({
|
||||
messages: [],
|
||||
historicalDurations: new Map(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("../helpers", () => ({
|
||||
resolveSessionDryRun: vi.fn(() => false),
|
||||
}));
|
||||
|
||||
vi.mock("@sentry/nextjs", () => ({
|
||||
captureException: vi.fn(),
|
||||
}));
|
||||
|
||||
/**
 * Build the object the mocked `useGetV2GetSession` hook returns.
 *
 * A non-null `data` is wrapped in a `{ status: 200, data }` envelope; `null`
 * yields `data: undefined`. All loading/error flags are false.
 */
function makeQueryResult(data: object | null) {
  const envelope = data ? { status: 200, data } : undefined;

  return {
    data: envelope,
    isLoading: false,
    isError: false,
    isFetching: false,
    refetch: vi.fn(),
  };
}
|
||||
|
||||
describe("useChatSession — newestSequence and forwardPaginated", () => {
  // Fields shared by every stubbed session payload; each test layers the
  // pagination fields it cares about on top.
  const COMMON_FIELDS = { messages: [], active_stream: null };

  // Stub the generated query hook with `payload`, then render useChatSession.
  function renderWithSession(payload: object | null) {
    mockUseGetV2GetSession.mockReturnValue(makeQueryResult(payload));
    return renderHook(() => useChatSession()).result;
  }

  beforeEach(() => {
    vi.clearAllMocks();
  });

  it("returns null / false when no session data", () => {
    const hook = renderWithSession(null);

    expect(hook.current.newestSequence).toBeNull();
    expect(hook.current.forwardPaginated).toBe(false);
  });

  it("returns newestSequence from session data", () => {
    const hook = renderWithSession({
      ...COMMON_FIELDS,
      has_more_messages: true,
      oldest_sequence: 0,
      newest_sequence: 99,
      forward_paginated: false,
    });

    expect(hook.current.newestSequence).toBe(99);
  });

  it("returns null for newestSequence when field is missing", () => {
    const hook = renderWithSession({
      ...COMMON_FIELDS,
      has_more_messages: false,
      oldest_sequence: 0,
      newest_sequence: null,
      forward_paginated: false,
    });

    expect(hook.current.newestSequence).toBeNull();
  });

  it("returns forwardPaginated=true when session is forward-paginated", () => {
    const hook = renderWithSession({
      ...COMMON_FIELDS,
      has_more_messages: true,
      oldest_sequence: 0,
      newest_sequence: 49,
      forward_paginated: true,
    });

    expect(hook.current.forwardPaginated).toBe(true);
  });

  it("returns forwardPaginated=false when session is backward-paginated", () => {
    const hook = renderWithSession({
      ...COMMON_FIELDS,
      has_more_messages: true,
      oldest_sequence: 50,
      newest_sequence: 99,
      forward_paginated: false,
    });

    expect(hook.current.forwardPaginated).toBe(false);
  });
});
|
||||
@@ -0,0 +1,202 @@
|
||||
import { act, renderHook, waitFor } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { useCopilotPage } from "../useCopilotPage";
|
||||
|
||||
const mockUseChatSession = vi.fn();
|
||||
const mockUseCopilotStream = vi.fn();
|
||||
const mockUseLoadMoreMessages = vi.fn();
|
||||
|
||||
vi.mock("../useChatSession", () => ({
|
||||
useChatSession: (...args: unknown[]) => mockUseChatSession(...args),
|
||||
}));
|
||||
vi.mock("../useCopilotStream", () => ({
|
||||
useCopilotStream: (...args: unknown[]) => mockUseCopilotStream(...args),
|
||||
}));
|
||||
vi.mock("../useLoadMoreMessages", () => ({
|
||||
useLoadMoreMessages: (...args: unknown[]) => mockUseLoadMoreMessages(...args),
|
||||
}));
|
||||
vi.mock("../useCopilotNotifications", () => ({
|
||||
useCopilotNotifications: () => undefined,
|
||||
}));
|
||||
vi.mock("../useWorkflowImportAutoSubmit", () => ({
|
||||
useWorkflowImportAutoSubmit: () => undefined,
|
||||
}));
|
||||
vi.mock("../store", () => ({
|
||||
useCopilotUIStore: () => ({
|
||||
sessionToDelete: null,
|
||||
setSessionToDelete: vi.fn(),
|
||||
isDrawerOpen: false,
|
||||
setDrawerOpen: vi.fn(),
|
||||
copilotChatMode: "chat",
|
||||
copilotLlmModel: null,
|
||||
isDryRun: false,
|
||||
}),
|
||||
}));
|
||||
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
|
||||
concatWithAssistantMerge: (a: unknown[], b: unknown[]) => [...a, ...b],
|
||||
}));
|
||||
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
|
||||
useDeleteV2DeleteSession: () => ({ mutate: vi.fn(), isPending: false }),
|
||||
useGetV2ListSessions: () => ({ data: undefined, isLoading: false }),
|
||||
getGetV2ListSessionsQueryKey: () => ["sessions"],
|
||||
}));
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
toast: vi.fn(),
|
||||
}));
|
||||
vi.mock("@/lib/direct-upload", () => ({
|
||||
uploadFileDirect: vi.fn(),
|
||||
}));
|
||||
vi.mock("@/lib/hooks/useBreakpoint", () => ({
|
||||
useBreakpoint: () => "lg",
|
||||
}));
|
||||
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
|
||||
useSupabase: () => ({ isUserLoading: false, isLoggedIn: true }),
|
||||
}));
|
||||
vi.mock("@tanstack/react-query", () => ({
|
||||
useQueryClient: () => ({ invalidateQueries: vi.fn() }),
|
||||
}));
|
||||
vi.mock("@/services/feature-flags/use-get-flag", () => ({
|
||||
Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" },
|
||||
useGetFlag: () => false,
|
||||
}));
|
||||
|
||||
/**
 * Default return value for the mocked `useChatSession` hook: an idle
 * "sess-1" session with no messages and no pagination. Keys in `overrides`
 * replace the defaults.
 */
function makeBaseChatSession(overrides: Record<string, unknown> = {}) {
  const defaults = {
    sessionId: "sess-1",
    setSessionId: vi.fn(),
    hydratedMessages: [],
    rawSessionMessages: [],
    historicalDurations: new Map(),
    hasActiveStream: false,
    hasMoreMessages: false,
    oldestSequence: null,
    newestSequence: null,
    forwardPaginated: false,
    isLoadingSession: false,
    isSessionError: false,
    createSession: vi.fn(),
    isCreatingSession: false,
    refetchSession: vi.fn(),
    sessionDryRun: false,
  };

  // Overrides are spread last so they win over the defaults.
  return { ...defaults, ...overrides };
}
|
||||
|
||||
/**
 * Default return value for the mocked `useCopilotStream` hook: status
 * "ready", no messages, no error. Keys in `overrides` replace the defaults.
 */
function makeBaseCopilotStream(overrides: Record<string, unknown> = {}) {
  const defaults = {
    messages: [],
    sendMessage: vi.fn(),
    stop: vi.fn(),
    status: "ready",
    error: undefined,
    isReconnecting: false,
    isSyncing: false,
    isUserStoppingRef: { current: false },
    rateLimitMessage: null,
    dismissRateLimit: vi.fn(),
  };

  // Overrides are spread last so they win over the defaults.
  return { ...defaults, ...overrides };
}
|
||||
|
||||
/**
 * Default return value for the mocked `useLoadMoreMessages` hook: no paged
 * messages, nothing more to load. Keys in `overrides` replace the defaults.
 */
function makeBaseLoadMore(overrides: Record<string, unknown> = {}) {
  const defaults = {
    pagedMessages: [],
    hasMore: false,
    isLoadingMore: false,
    loadMore: vi.fn(),
    resetPaged: vi.fn(),
  };

  // Overrides are spread last so they win over the defaults.
  return { ...defaults, ...overrides };
}
|
||||
|
||||
describe("useCopilotPage — forwardPaginated message ordering", () => {
  type Overrides = Record<string, unknown>;

  // Point the three mocked child hooks at base fixtures, with optional
  // per-hook overrides.
  function stubHooks(parts: {
    session?: Overrides;
    stream?: Overrides;
    loadMore?: Overrides;
  }) {
    mockUseChatSession.mockReturnValue(makeBaseChatSession(parts.session ?? {}));
    mockUseCopilotStream.mockReturnValue(
      makeBaseCopilotStream(parts.stream ?? {}),
    );
    mockUseLoadMoreMessages.mockReturnValue(
      makeBaseLoadMore(parts.loadMore ?? {}),
    );
  }

  beforeEach(() => {
    vi.clearAllMocks();
  });

  it("prepends pagedMessages before currentMessages when forwardPaginated=false", () => {
    const pagedMsg = { id: "paged", role: "user" };
    const currentMsg = { id: "current", role: "assistant" };
    stubHooks({
      session: { forwardPaginated: false },
      stream: { messages: [currentMsg] },
      loadMore: { pagedMessages: [pagedMsg] },
    });

    const { result } = renderHook(() => useCopilotPage());
    const [first, second] = result.current.messages;

    // Backward pagination: the paged (older) messages lead the list.
    expect(first).toEqual(pagedMsg);
    expect(second).toEqual(currentMsg);
  });

  it("appends pagedMessages after currentMessages when forwardPaginated=true", () => {
    const pagedMsg = { id: "paged", role: "assistant" };
    const currentMsg = { id: "current", role: "user" };
    stubHooks({
      session: { forwardPaginated: true },
      stream: { messages: [currentMsg] },
      loadMore: { pagedMessages: [pagedMsg] },
    });

    const { result } = renderHook(() => useCopilotPage());
    const [first, second] = result.current.messages;

    // Forward pagination: the current messages (start of session) lead.
    expect(first).toEqual(currentMsg);
    expect(second).toEqual(pagedMsg);
  });

  it("calls resetPaged when forwardPaginated transitions false→true with paged messages", async () => {
    const mockResetPaged = vi.fn();
    const pagedMsg = { id: "paged", role: "user" };
    stubHooks({
      session: { forwardPaginated: false },
      loadMore: { pagedMessages: [pagedMsg], resetPaged: mockResetPaged },
    });

    const { rerender } = renderHook(() => useCopilotPage());

    // Simulate the session completing: forwardPaginated flips to true.
    mockUseChatSession.mockReturnValue(
      makeBaseChatSession({ forwardPaginated: true }),
    );
    act(() => {
      rerender();
    });

    await waitFor(() => {
      expect(mockResetPaged).toHaveBeenCalled();
    });
  });

  it("does not call resetPaged when forwardPaginated is already true on mount", () => {
    const mockResetPaged = vi.fn();
    stubHooks({
      session: { forwardPaginated: true },
      loadMore: { pagedMessages: [], resetPaged: mockResetPaged },
    });

    renderHook(() => useCopilotPage());

    expect(mockResetPaged).not.toHaveBeenCalled();
  });
});
|
||||
@@ -0,0 +1,568 @@
|
||||
import { act, renderHook, waitFor } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { useLoadMoreMessages } from "../useLoadMoreMessages";
|
||||
|
||||
const mockGetV2GetSession = vi.fn();
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
|
||||
getV2GetSession: (...args: unknown[]) => mockGetV2GetSession(...args),
|
||||
}));
|
||||
|
||||
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
|
||||
convertChatSessionMessagesToUiMessages: vi.fn(() => ({ messages: [] })),
|
||||
extractToolOutputsFromRaw: vi.fn(() => []),
|
||||
}));
|
||||
|
||||
const BASE_ARGS = {
|
||||
sessionId: "sess-1",
|
||||
initialOldestSequence: 0,
|
||||
initialNewestSequence: 49,
|
||||
initialHasMore: true,
|
||||
forwardPaginated: true,
|
||||
initialPageRawMessages: [],
|
||||
};
|
||||
|
||||
/**
 * Build a status-200 response object for the mocked `getV2GetSession`.
 *
 * Any field left out of `overrides` (or given as null/undefined) falls back
 * to: no messages, `has_more_messages: false`, `oldest_sequence: 0`,
 * `newest_sequence: 49`.
 */
function makeSuccessResponse(overrides: {
  messages?: unknown[];
  has_more_messages?: boolean;
  oldest_sequence?: number;
  newest_sequence?: number;
}) {
  const data = {
    messages: overrides.messages ?? [],
    has_more_messages: overrides.has_more_messages ?? false,
    oldest_sequence: overrides.oldest_sequence ?? 0,
    newest_sequence: overrides.newest_sequence ?? 49,
  };

  return { status: 200, data };
}
|
||||
|
||||
describe("useLoadMoreMessages", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("initialises with empty pagedMessages and correct cursors", () => {
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("resetPaged clears paged state and sets hasMore=false during transition", () => {
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
// hasMore must be false during transition to prevent forward loadMore
|
||||
// from firing on the now-active session before forwardPaginated updates.
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("resetPaged exposes a fresh loadMore via incremented epoch", () => {
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
// Just verify resetPaged is callable and doesn't throw.
|
||||
expect(() => {
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
it("resets all state on sessionId change", () => {
|
||||
const { result, rerender } = renderHook(
|
||||
(props) => useLoadMoreMessages(props),
|
||||
{ initialProps: BASE_ARGS },
|
||||
);
|
||||
|
||||
rerender({
|
||||
...BASE_ARGS,
|
||||
sessionId: "sess-2",
|
||||
initialOldestSequence: 10,
|
||||
initialNewestSequence: 59,
|
||||
initialHasMore: false,
|
||||
});
|
||||
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
describe("loadMore — forward pagination", () => {
|
||||
it("calls getV2GetSession with after_sequence and updates newestSequence", async () => {
|
||||
const rawMsg = { role: "user", content: "hi", sequence: 50 };
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [rawMsg],
|
||||
has_more_messages: true,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ after_sequence: 49 }),
|
||||
);
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("sets hasMore=false when response has no more messages", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({ has_more_messages: false }),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("is a no-op when hasMore is false", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
initialHasMore: false,
|
||||
forwardPaginated: true,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — backward pagination", () => {
|
||||
it("calls getV2GetSession with before_sequence", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [{ role: "user", content: "old", sequence: 0 }],
|
||||
has_more_messages: false,
|
||||
oldest_sequence: 0,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: 50,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ before_sequence: 50 }),
|
||||
);
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — error handling", () => {
|
||||
it("does not set hasMore=false on first error", async () => {
|
||||
mockGetV2GetSession.mockRejectedValueOnce(new Error("network error"));
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// First error — hasMore still true
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) errors", async () => {
|
||||
mockGetV2GetSession.mockRejectedValue(new Error("network error"));
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
// Reset the in-flight guard between calls
|
||||
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
|
||||
}
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("ignores non-200 response and increments error count", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce({ status: 500, data: {} });
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// One error, not yet at threshold — hasMore still true
|
||||
expect(result.current.hasMore).toBe(true);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
|
||||
it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) non-200 responses", async () => {
|
||||
mockGetV2GetSession.mockResolvedValue({ status: 503, data: {} });
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
|
||||
}
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("discards in-flight error when epoch changes mid-flight (resetPaged called)", async () => {
|
||||
let rejectRequest!: (e: Error) => void;
|
||||
mockGetV2GetSession.mockReturnValueOnce(
|
||||
new Promise((_, rej) => {
|
||||
rejectRequest = rej;
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
act(() => {
|
||||
result.current.loadMore();
|
||||
});
|
||||
|
||||
// Reset epoch mid-flight
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
|
||||
// Reject the in-flight request — stale error should be discarded
|
||||
await act(async () => {
|
||||
rejectRequest(new Error("network error"));
|
||||
});
|
||||
|
||||
// Stale rejection is discarded: no error is counted and isLoadingMore is cleared
|
||||
expect(result.current.hasMore).toBe(false); // false from resetPaged
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — forward pagination cursor advancement", () => {
|
||||
it("advances newestSequence after a successful forward load", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [{ role: "user", content: "hi", sequence: 50 }],
|
||||
has_more_messages: true,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// A second loadMore should use after_sequence: 99 (advanced cursor)
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ after_sequence: 99 }),
|
||||
);
|
||||
});
|
||||
|
||||
it("does not regress newestSequence when parent refetches after pages loaded", async () => {
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [{ role: "user", content: "msg", sequence: 50 }],
|
||||
has_more_messages: true,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
(props) => useLoadMoreMessages(props),
|
||||
{ initialProps: { ...BASE_ARGS, forwardPaginated: true } },
|
||||
);
|
||||
|
||||
// Load one page — newestSequence advances to 99
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// Parent refetches with a lower newest_sequence (49) — should NOT regress cursor
|
||||
rerender({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialNewestSequence: 49,
|
||||
});
|
||||
|
||||
// Next loadMore should still use the advanced cursor (99)
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).toHaveBeenLastCalledWith(
|
||||
"sess-1",
|
||||
expect.objectContaining({ after_sequence: 99 }),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — MAX_OLDER_MESSAGES truncation", () => {
|
||||
it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => {
|
||||
// Single load of 2001 messages exceeds the limit in one shot.
|
||||
// This avoids relying on cross-render closure staleness: estimatedTotal =
|
||||
// pagedRawMessages.length (0, fresh) + 2001 = 2001 >= 2000 → hasMore=false.
|
||||
const args = { ...BASE_ARGS, forwardPaginated: false };
|
||||
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 2001 }, (_, i) => ({
|
||||
role: "user",
|
||||
content: `msg ${i}`,
|
||||
sequence: i,
|
||||
})),
|
||||
has_more_messages: true,
|
||||
oldest_sequence: 0,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(args));
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(result.current.hasMore).toBe(false);
|
||||
});
|
||||
|
||||
it("forward truncation keeps first MAX_OLDER_MESSAGES items (not last)", async () => {
|
||||
// Page 1990 messages first, then load 20 more forward — total 2010 > 2000.
|
||||
// Forward truncation must keep slice(0, 2000), not slice(-2000),
|
||||
// to preserve the beginning of the conversation.
|
||||
const forwardNearLimitArgs = {
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialNewestSequence: 49,
|
||||
initialOldestSequence: 0,
|
||||
initialHasMore: true,
|
||||
};
|
||||
|
||||
const { result } = renderHook((props) => useLoadMoreMessages(props), {
|
||||
initialProps: forwardNearLimitArgs,
|
||||
});
|
||||
|
||||
// First load: 1990 messages — advances newestSequence to 2039
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 1990 }, (_, i) => ({
|
||||
role: "assistant",
|
||||
content: `msg ${i + 50}`,
|
||||
sequence: i + 50,
|
||||
})),
|
||||
has_more_messages: true,
|
||||
newest_sequence: 2039,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// Second load: 20 more messages pushes total to 2010 > 2000.
|
||||
// Truncation keeps seq 50..2049 (2000 items); discards seq 2050..2059 (10 items).
|
||||
// Even though the server says has_more_messages=false, hasMore stays true
|
||||
// because there are discarded items that need to be re-fetched.
|
||||
// The cursor (newestSequence) advances to 2049 — the last kept item's sequence.
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: Array.from({ length: 20 }, (_, i) => ({
|
||||
role: "assistant",
|
||||
content: `msg ${i + 2040}`,
|
||||
sequence: i + 2040,
|
||||
})),
|
||||
has_more_messages: false,
|
||||
newest_sequence: 2059,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
// Truncation occurred (2010 > 2000): hasMore=true so discarded items can be fetched.
|
||||
// Cursor advances to last kept item (seq 2049), not the server's newest (2059).
|
||||
await waitFor(() => expect(result.current.hasMore).toBe(true));
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — null cursor guard", () => {
|
||||
it("is a no-op when newestSequence is null (forwardPaginated=true)", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialNewestSequence: null,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("is a no-op when oldestSequence is null (forwardPaginated=false)", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: null,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(mockGetV2GetSession).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("pagedMessages — initialPageRawMessages extraToolOutputs", () => {
|
||||
it("calls extractToolOutputsFromRaw for backward pagination with non-empty initialPageRawMessages", async () => {
|
||||
const { extractToolOutputsFromRaw } = await import(
|
||||
"../helpers/convertChatSessionToUiMessages"
|
||||
);
|
||||
|
||||
const rawMsg = { role: "user", content: "old", sequence: 0 };
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [rawMsg],
|
||||
has_more_messages: false,
|
||||
oldest_sequence: 0,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: false,
|
||||
initialOldestSequence: 50,
|
||||
initialPageRawMessages: [{ role: "assistant", content: "response" }],
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(extractToolOutputsFromRaw).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("does NOT call extractToolOutputsFromRaw for forward pagination", async () => {
|
||||
const { extractToolOutputsFromRaw } = await import(
|
||||
"../helpers/convertChatSessionToUiMessages"
|
||||
);
|
||||
|
||||
const rawMsg = { role: "assistant", content: "hi", sequence: 50 };
|
||||
mockGetV2GetSession.mockResolvedValueOnce(
|
||||
makeSuccessResponse({
|
||||
messages: [rawMsg],
|
||||
has_more_messages: false,
|
||||
newest_sequence: 99,
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useLoadMoreMessages({
|
||||
...BASE_ARGS,
|
||||
forwardPaginated: true,
|
||||
initialPageRawMessages: [{ role: "user", content: "hello" }],
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.loadMore();
|
||||
});
|
||||
|
||||
expect(extractToolOutputsFromRaw).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadMore — epoch / stale-response guard", () => {
|
||||
it("discards response when epoch changes during flight (resetPaged called)", async () => {
|
||||
let resolveRequest!: (v: unknown) => void;
|
||||
mockGetV2GetSession.mockReturnValueOnce(
|
||||
new Promise((res) => {
|
||||
resolveRequest = res;
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
|
||||
|
||||
// Start the request without awaiting
|
||||
act(() => {
|
||||
result.current.loadMore();
|
||||
});
|
||||
|
||||
// Reset epoch mid-flight
|
||||
act(() => {
|
||||
result.current.resetPaged();
|
||||
});
|
||||
|
||||
// Now resolve the in-flight request
|
||||
await act(async () => {
|
||||
resolveRequest(
|
||||
makeSuccessResponse({ messages: [{ role: "user", content: "hi" }] }),
|
||||
);
|
||||
});
|
||||
|
||||
// Response discarded — pagedMessages stays empty, isLoadingMore stays false
|
||||
expect(result.current.pagedMessages).toHaveLength(0);
|
||||
expect(result.current.isLoadingMore).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -30,6 +30,7 @@ export interface ChatContainerProps {
|
||||
hasMoreMessages?: boolean;
|
||||
isLoadingMore?: boolean;
|
||||
onLoadMore?: () => void;
|
||||
forwardPaginated?: boolean;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
@@ -54,6 +55,7 @@ export const ChatContainer = ({
|
||||
hasMoreMessages,
|
||||
isLoadingMore,
|
||||
onLoadMore,
|
||||
forwardPaginated,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
historicalDurations,
|
||||
@@ -108,6 +110,7 @@ export const ChatContainer = ({
|
||||
hasMoreMessages={hasMoreMessages}
|
||||
isLoadingMore={isLoadingMore}
|
||||
onLoadMore={onLoadMore}
|
||||
forwardPaginated={forwardPaginated}
|
||||
onRetry={handleRetry}
|
||||
historicalDurations={historicalDurations}
|
||||
/>
|
||||
|
||||
@@ -86,11 +86,11 @@ export function ChatInput({
|
||||
title:
|
||||
next === "advanced"
|
||||
? "Switched to Advanced model"
|
||||
: "Switched to Standard model",
|
||||
: "Switched to Balanced model",
|
||||
description:
|
||||
next === "advanced"
|
||||
? "Using the highest-capability model."
|
||||
: "Using the balanced standard model.",
|
||||
: "Using the balanced default model.",
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -162,10 +162,15 @@ describe("ChatInput mode toggle", () => {
|
||||
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
|
||||
});
|
||||
|
||||
it("hides toggle button when streaming", () => {
|
||||
it("hides toggle buttons when streaming", () => {
|
||||
mockFlagValue = true;
|
||||
render(<ChatInput onSend={mockOnSend} isStreaming />);
|
||||
expect(screen.queryByLabelText(/switch to/i)).toBeNull();
|
||||
expect(
|
||||
screen.queryByLabelText(/switch to (fast|extended thinking) mode/i),
|
||||
).toBeNull();
|
||||
expect(
|
||||
screen.queryByLabelText(/switch to (advanced|balanced|standard) model/i),
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it("shows mode toggle when hasSession is true and not streaming", () => {
|
||||
@@ -234,7 +239,7 @@ describe("ChatInput model toggle", () => {
|
||||
mockFlagValue = true;
|
||||
mockCopilotLlmModel = "advanced";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
|
||||
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
|
||||
expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
|
||||
});
|
||||
|
||||
@@ -288,10 +293,10 @@ describe("ChatInput model toggle", () => {
|
||||
mockFlagValue = true;
|
||||
mockCopilotLlmModel = "advanced";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
|
||||
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
|
||||
expect(toast).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
title: expect.stringMatching(/switched to standard model/i),
|
||||
title: expect.stringMatching(/switched to balanced model/i),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Flask } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
|
||||
// This button is only rendered on NEW chats (no active session).
|
||||
// Once a session exists, it is hidden — the session's dry_run flag is
|
||||
@@ -14,27 +19,31 @@ interface Props {
|
||||
|
||||
export function DryRunToggleButton({ isDryRun, onToggle }: Props) {
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isDryRun}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isDryRun
|
||||
? "bg-amber-100 text-amber-900 hover:bg-amber-200"
|
||||
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={
|
||||
isDryRun ? "Test mode active — click to disable" : "Enable Test mode"
|
||||
}
|
||||
title={
|
||||
isDryRun
|
||||
? "Test mode ON — new chats run agents as simulation (click to disable)"
|
||||
: "Enable Test mode — new chats will run agents as simulation"
|
||||
}
|
||||
>
|
||||
<Flask size={14} />
|
||||
{isDryRun && "Test"}
|
||||
</button>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isDryRun}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
|
||||
isDryRun
|
||||
? "text-amber-900"
|
||||
: "text-neutral-500 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={isDryRun ? "Test mode active" : "Enable Test mode"}
|
||||
>
|
||||
<Flask size={14} />
|
||||
<span className="hidden sm:inline">
|
||||
{isDryRun ? "Test mode enabled" : "Enable test mode"}
|
||||
</span>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{isDryRun
|
||||
? "Test mode on — new sessions run without performing real actions (click to turn off)."
|
||||
: "Turn on test mode to try prompts without performing real actions."}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Brain, Lightning } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import type { CopilotMode } from "../../../store";
|
||||
|
||||
interface Props {
|
||||
@@ -11,37 +16,42 @@ interface Props {
|
||||
|
||||
export function ModeToggleButton({ mode, onToggle }: Props) {
|
||||
const isExtended = mode === "extended_thinking";
|
||||
|
||||
const tooltipText = isExtended
|
||||
? "Extended Thinking — deeper reasoning (click to switch to Fast)"
|
||||
: "Fast mode — quicker responses (click to switch to Thinking)";
|
||||
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isExtended}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isExtended
|
||||
? "bg-purple-100 text-purple-900 hover:bg-purple-200"
|
||||
: "bg-amber-100 text-amber-900 hover:bg-amber-200",
|
||||
)}
|
||||
aria-label={
|
||||
isExtended ? "Switch to Fast mode" : "Switch to Extended Thinking mode"
|
||||
}
|
||||
title={
|
||||
isExtended
|
||||
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
|
||||
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
|
||||
}
|
||||
>
|
||||
{isExtended ? (
|
||||
<>
|
||||
<Brain size={14} />
|
||||
Thinking
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Lightning size={14} />
|
||||
Fast
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isExtended}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"ml-2 inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
|
||||
isExtended ? "text-purple-900" : "text-amber-900",
|
||||
)}
|
||||
aria-label={
|
||||
isExtended
|
||||
? "Switch to Fast mode"
|
||||
: "Switch to Extended Thinking mode"
|
||||
}
|
||||
>
|
||||
{isExtended ? (
|
||||
<>
|
||||
<Brain size={14} />
|
||||
Thinking
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Lightning size={14} />
|
||||
Fast
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{tooltipText}</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Cpu } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import type { CopilotLlmModel } from "../../../store";
|
||||
|
||||
interface Props {
|
||||
@@ -12,27 +17,33 @@ interface Props {
|
||||
export function ModelToggleButton({ model, onToggle }: Props) {
|
||||
const isAdvanced = model === "advanced";
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isAdvanced}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isAdvanced
|
||||
? "bg-sky-100 text-sky-900 hover:bg-sky-200"
|
||||
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={
|
||||
isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
|
||||
}
|
||||
title={
|
||||
isAdvanced
|
||||
? "Advanced model — highest capability (click to switch to Standard)"
|
||||
: "Standard model — click to switch to Advanced"
|
||||
}
|
||||
>
|
||||
<Cpu size={14} />
|
||||
{isAdvanced && "Advanced"}
|
||||
</button>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isAdvanced}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
|
||||
isAdvanced
|
||||
? "text-sky-900"
|
||||
: "text-neutral-500 hover:text-neutral-700",
|
||||
)}
|
||||
aria-label={
|
||||
isAdvanced ? "Switch to Balanced model" : "Switch to Advanced model"
|
||||
}
|
||||
>
|
||||
<Cpu size={14} />
|
||||
<span className="hidden sm:inline">
|
||||
{isAdvanced ? "Advanced" : "Balanced"}
|
||||
</span>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{isAdvanced
|
||||
? "Using the highest-capability model (click to switch to Balanced)."
|
||||
: "Using the balanced default model (click to switch to Advanced)."}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,32 @@
|
||||
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
|
||||
import {
|
||||
render as rtlRender,
|
||||
screen,
|
||||
fireEvent,
|
||||
cleanup,
|
||||
} from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import type { ReactElement } from "react";
|
||||
import { TooltipProvider } from "@/components/ui/tooltip";
|
||||
import { DryRunToggleButton } from "../DryRunToggleButton";
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
function render(ui: ReactElement) {
|
||||
return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
|
||||
}
|
||||
|
||||
// DryRunToggleButton only appears on new chats (no active session).
|
||||
// It has no readOnly/isStreaming props — those scenarios are handled by hiding
|
||||
// the button entirely at the ChatInput level when hasSession is true.
|
||||
describe("DryRunToggleButton", () => {
|
||||
it("shows Test label when isDryRun is true", () => {
|
||||
it("shows enabled label when isDryRun is true", () => {
|
||||
render(<DryRunToggleButton isDryRun={true} onToggle={vi.fn()} />);
|
||||
expect(screen.getByText("Test")).toBeTruthy();
|
||||
expect(screen.getByText("Test mode enabled")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("shows no text label when isDryRun is false", () => {
|
||||
it("shows enable label when isDryRun is false", () => {
|
||||
render(<DryRunToggleButton isDryRun={false} onToggle={vi.fn()} />);
|
||||
expect(screen.queryByText("Test")).toBeNull();
|
||||
expect(screen.getByText("Enable test mode")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("calls onToggle when clicked", () => {
|
||||
|
||||
@@ -1,9 +1,20 @@
|
||||
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
|
||||
import {
|
||||
render as rtlRender,
|
||||
screen,
|
||||
fireEvent,
|
||||
cleanup,
|
||||
} from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import type { ReactElement } from "react";
|
||||
import { TooltipProvider } from "@/components/ui/tooltip";
|
||||
import { ModelToggleButton } from "../ModelToggleButton";
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
function render(ui: ReactElement) {
|
||||
return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
|
||||
}
|
||||
|
||||
describe("ModelToggleButton", () => {
|
||||
it("shows no text label when model is standard", () => {
|
||||
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
|
||||
@@ -31,7 +42,7 @@ describe("ModelToggleButton", () => {
|
||||
|
||||
it("sets aria-pressed=true for advanced", () => {
|
||||
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
|
||||
const btn = screen.getByLabelText("Switch to Standard model");
|
||||
const btn = screen.getByLabelText("Switch to Balanced model");
|
||||
expect(btn.getAttribute("aria-pressed")).toBe("true");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -43,6 +43,10 @@ interface Props {
|
||||
hasMoreMessages?: boolean;
|
||||
isLoadingMore?: boolean;
|
||||
onLoadMore?: () => void;
|
||||
/** When true the load-more sentinel is placed at the bottom (forward
|
||||
* pagination for completed sessions). When false it is at the top
|
||||
* (backward pagination for active sessions). */
|
||||
forwardPaginated?: boolean;
|
||||
onRetry?: () => void;
|
||||
historicalDurations?: Map<string, number>;
|
||||
}
|
||||
@@ -140,11 +144,25 @@ export function LoadMoreSentinel({
|
||||
isLoading,
|
||||
messageCount,
|
||||
onLoadMore,
|
||||
rootMargin = "200px 0px 0px 0px",
|
||||
adjustScroll = true,
|
||||
forwardPaginated = false,
|
||||
}: {
|
||||
hasMore: boolean;
|
||||
isLoading: boolean;
|
||||
messageCount: number;
|
||||
onLoadMore: () => void;
|
||||
/** IntersectionObserver rootMargin. Top sentinel uses "200px 0px 0px 0px"
|
||||
* (pre-trigger when approaching from above); bottom sentinel should use
|
||||
* "0px 0px 200px 0px" (pre-trigger when approaching from below). */
|
||||
rootMargin?: string;
|
||||
/** Whether to adjust scrollTop after load to preserve visual position.
|
||||
* True for backward pagination (prepend above); false for forward
|
||||
* pagination (append below) where no adjustment is needed. */
|
||||
adjustScroll?: boolean;
|
||||
/** When true the button reads "Load newer messages" (forward pagination).
|
||||
* When false (default) it reads "Load older messages". */
|
||||
forwardPaginated?: boolean;
|
||||
}) {
|
||||
const sentinelRef = useRef<HTMLDivElement>(null);
|
||||
const onLoadMoreRef = useRef(onLoadMore);
|
||||
@@ -189,11 +207,11 @@ export function LoadMoreSentinel({
|
||||
if (autoFillRoundsRef.current >= MAX_AUTO_FILL_ROUNDS) return;
|
||||
captureAndLoad(true);
|
||||
},
|
||||
{ rootMargin: "200px 0px 0px 0px" },
|
||||
{ rootMargin },
|
||||
);
|
||||
observer.observe(sentinelRef.current);
|
||||
return () => observer.disconnect();
|
||||
}, [hasMore, isLoading, scrollRef]);
|
||||
}, [hasMore, isLoading, rootMargin, scrollRef]);
|
||||
|
||||
// After React commits new DOM nodes (prepended messages), adjust
|
||||
// scrollTop so the user stays at the same visual position.
|
||||
@@ -206,7 +224,9 @@ export function LoadMoreSentinel({
|
||||
scrollSnapshotRef.current;
|
||||
if (!el || prevHeight === 0) return;
|
||||
const delta = el.scrollHeight - prevHeight;
|
||||
if (delta > 0) {
|
||||
// Only restore scroll position for backward pagination (content prepended
|
||||
// above). Forward pagination appends below — no adjustment needed.
|
||||
if (adjustScroll && delta > 0) {
|
||||
el.scrollTop = prevTop + delta;
|
||||
}
|
||||
// Reset the auto-fill backoff whenever the container becomes
|
||||
@@ -220,7 +240,7 @@ export function LoadMoreSentinel({
|
||||
}
|
||||
scrollSnapshotRef.current = { scrollHeight: 0, scrollTop: 0 };
|
||||
autoTriggeredRef.current = false;
|
||||
}, [messageCount, scrollRef]);
|
||||
}, [adjustScroll, messageCount, scrollRef]);
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -239,7 +259,7 @@ export function LoadMoreSentinel({
|
||||
size="small"
|
||||
onClick={() => captureAndLoad(false)}
|
||||
>
|
||||
Load older messages
|
||||
{forwardPaginated ? "Load newer messages" : "Load older messages"}
|
||||
</Button>
|
||||
)
|
||||
)}
|
||||
@@ -256,6 +276,7 @@ export function ChatMessagesContainer({
|
||||
hasMoreMessages,
|
||||
isLoadingMore,
|
||||
onLoadMore,
|
||||
forwardPaginated,
|
||||
onRetry,
|
||||
historicalDurations,
|
||||
}: Props) {
|
||||
@@ -334,7 +355,7 @@ export function ChatMessagesContainer({
|
||||
}
|
||||
>
|
||||
<ConversationContent className="flex min-h-full flex-1 flex-col gap-6 px-3 py-6">
|
||||
{hasMoreMessages && onLoadMore && (
|
||||
{hasMoreMessages && onLoadMore && !forwardPaginated && (
|
||||
<LoadMoreSentinel
|
||||
hasMore={hasMoreMessages}
|
||||
isLoading={!!isLoadingMore}
|
||||
@@ -497,6 +518,17 @@ export function ChatMessagesContainer({
|
||||
</pre>
|
||||
</details>
|
||||
)}
|
||||
{hasMoreMessages && onLoadMore && forwardPaginated && (
|
||||
<LoadMoreSentinel
|
||||
hasMore={hasMoreMessages}
|
||||
isLoading={!!isLoadingMore}
|
||||
messageCount={messages.length}
|
||||
onLoadMore={onLoadMore}
|
||||
rootMargin="0px 0px 200px 0px"
|
||||
adjustScroll={false}
|
||||
forwardPaginated
|
||||
/>
|
||||
)}
|
||||
</ConversationContent>
|
||||
<ConversationScrollButton />
|
||||
</Conversation>
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { ChatMessagesContainer } from "../ChatMessagesContainer";
|
||||
|
||||
const mockScrollEl = {
|
||||
scrollHeight: 100,
|
||||
scrollTop: 0,
|
||||
clientHeight: 500,
|
||||
};
|
||||
|
||||
vi.mock("use-stick-to-bottom", () => ({
|
||||
useStickToBottomContext: () => ({ scrollRef: { current: mockScrollEl } }),
|
||||
Conversation: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
ConversationContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
ConversationScrollButton: () => null,
|
||||
}));
|
||||
|
||||
vi.mock("@/components/ai-elements/conversation", () => ({
|
||||
Conversation: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
ConversationContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
ConversationScrollButton: () => null,
|
||||
}));
|
||||
|
||||
vi.mock("@/components/ai-elements/message", () => ({
|
||||
Message: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
MessageContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
MessageActions: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock("../components/AssistantMessageActions", () => ({
|
||||
AssistantMessageActions: () => null,
|
||||
}));
|
||||
vi.mock("../components/CopyButton", () => ({ CopyButton: () => null }));
|
||||
vi.mock("../components/CollapsedToolGroup", () => ({
|
||||
CollapsedToolGroup: () => null,
|
||||
}));
|
||||
vi.mock("../components/MessageAttachments", () => ({
|
||||
MessageAttachments: () => null,
|
||||
}));
|
||||
vi.mock("../components/MessagePartRenderer", () => ({
|
||||
MessagePartRenderer: () => null,
|
||||
}));
|
||||
vi.mock("../components/ReasoningCollapse", () => ({
|
||||
ReasoningCollapse: () => null,
|
||||
}));
|
||||
vi.mock("../components/ThinkingIndicator", () => ({
|
||||
ThinkingIndicator: () => null,
|
||||
}));
|
||||
vi.mock("../../JobStatsBar/TurnStatsBar", () => ({
|
||||
TurnStatsBar: () => null,
|
||||
}));
|
||||
vi.mock("../../JobStatsBar/useElapsedTimer", () => ({
|
||||
useElapsedTimer: () => ({ elapsedSeconds: 0 }),
|
||||
}));
|
||||
vi.mock("../../CopilotPendingReviews/CopilotPendingReviews", () => ({
|
||||
CopilotPendingReviews: () => null,
|
||||
}));
|
||||
vi.mock("../helpers", () => ({
|
||||
buildRenderSegments: () => [],
|
||||
getTurnMessages: () => [],
|
||||
parseSpecialMarkers: () => ({ markerType: null }),
|
||||
splitReasoningAndResponse: (parts: unknown[]) => ({
|
||||
reasoningParts: [],
|
||||
responseParts: parts,
|
||||
}),
|
||||
}));
|
||||
|
||||
type ObserverCallback = (entries: { isIntersecting: boolean }[]) => void;
|
||||
class MockIntersectionObserver {
|
||||
static lastCallback: ObserverCallback | null = null;
|
||||
private callback: ObserverCallback;
|
||||
constructor(cb: ObserverCallback) {
|
||||
this.callback = cb;
|
||||
MockIntersectionObserver.lastCallback = cb;
|
||||
}
|
||||
observe() {}
|
||||
disconnect() {}
|
||||
unobserve() {}
|
||||
takeRecords() {
|
||||
return [];
|
||||
}
|
||||
root = null;
|
||||
rootMargin = "";
|
||||
thresholds = [];
|
||||
}
|
||||
|
||||
// Shared props for every test: an empty, idle session that reports more
// messages are available and supplies load-more / retry handlers.
const BASE_PROPS = {
  messages: [],
  status: "ready" as const,
  error: undefined,
  isLoading: false,
  sessionID: "sess-1",
  hasMoreMessages: true,
  isLoadingMore: false,
  onLoadMore: vi.fn(),
  onRetry: vi.fn(),
};

// Matches the accessible name of either load-more button ("older" is used for
// backward pagination, "newer" for forward pagination).
const ANY_LOAD_MORE_BUTTON = /load (older|newer) messages/i;

describe("ChatMessagesContainer", () => {
  beforeEach(() => {
    mockScrollEl.scrollHeight = 100;
    mockScrollEl.scrollTop = 0;
    mockScrollEl.clientHeight = 500;
    MockIntersectionObserver.lastCallback = null;
    vi.stubGlobal("IntersectionObserver", MockIntersectionObserver);
  });

  afterEach(() => {
    cleanup();
    vi.unstubAllGlobals();
  });

  it("renders top sentinel when forwardPaginated is false (backward pagination)", () => {
    render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={false} />);
    expect(
      screen.getByRole("button", { name: /load older messages/i }),
    ).toBeDefined();
  });

  it("renders top sentinel when forwardPaginated is undefined (default, backward)", () => {
    render(<ChatMessagesContainer {...BASE_PROPS} />);
    expect(
      screen.getByRole("button", { name: /load older messages/i }),
    ).toBeDefined();
  });

  it("renders bottom sentinel when forwardPaginated is true (forward pagination)", () => {
    render(<ChatMessagesContainer {...BASE_PROPS} forwardPaginated={true} />);
    expect(
      screen.getByRole("button", { name: /load newer messages/i }),
    ).toBeDefined();
  });

  it("hides sentinel when hasMoreMessages is false", () => {
    render(
      <ChatMessagesContainer
        {...BASE_PROPS}
        hasMoreMessages={false}
        forwardPaginated={true}
      />,
    );
    // In forward mode the button is labelled "newer", so querying only for
    // "older" would pass even if the sentinel were rendered. Match both.
    expect(
      screen.queryByRole("button", { name: ANY_LOAD_MORE_BUTTON }),
    ).toBeNull();
  });

  it("hides sentinel when onLoadMore is not provided", () => {
    render(
      <ChatMessagesContainer
        {...BASE_PROPS}
        onLoadMore={undefined}
        forwardPaginated={true}
      />,
    );
    // Same reasoning as above: neither label may be present.
    expect(
      screen.queryByRole("button", { name: ANY_LOAD_MORE_BUTTON }),
    ).toBeNull();
  });
});
|
||||
@@ -172,6 +172,36 @@ describe("LoadMoreSentinel", () => {
|
||||
expect(mockScrollEl.scrollTop).toBe(200);
|
||||
});
|
||||
|
||||
it("does NOT adjust scroll when adjustScroll=false (forward pagination)", () => {
  mockScrollEl.scrollHeight = 100;
  mockScrollEl.scrollTop = 50;
  const onLoadMore = vi.fn();

  // Props that stay the same across both renders; only messageCount changes.
  const sharedProps = {
    hasMore: true,
    isLoading: false,
    onLoadMore,
    adjustScroll: false,
  };

  const { rerender } = render(
    <LoadMoreSentinel {...sharedProps} messageCount={5} />,
  );

  // Fire observer to capture snapshot.
  MockIntersectionObserver.lastCallback?.([{ isIntersecting: true }]);

  // Simulate DOM growing from appended newer messages (forward load-more).
  mockScrollEl.scrollHeight = 300;
  rerender(<LoadMoreSentinel {...sharedProps} messageCount={10} />);

  // scrollTop should remain unchanged — no jump for forward pagination.
  expect(mockScrollEl.scrollTop).toBe(50);
});
|
||||
|
||||
it("ignores same-frame duplicate triggers until isLoading transitions", () => {
|
||||
const onLoadMore = vi.fn();
|
||||
const { rerender } = render(
|
||||
|
||||
@@ -13,6 +13,10 @@ import {
|
||||
getSuggestionThemes,
|
||||
} from "./helpers";
|
||||
import { SuggestionThemes } from "./components/SuggestionThemes/SuggestionThemes";
|
||||
import { PulseChips } from "../PulseChips/PulseChips";
|
||||
import { usePulseChips } from "../PulseChips/usePulseChips";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { EditNameDialog } from "./components/EditNameDialog/EditNameDialog";
|
||||
|
||||
interface Props {
|
||||
inputLayoutId: string;
|
||||
@@ -34,6 +38,8 @@ export function EmptySession({
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
|
||||
const pulseChips = usePulseChips();
|
||||
|
||||
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
|
||||
useGetV2GetSuggestedPrompts({
|
||||
@@ -75,11 +81,16 @@ export function EmptySession({
|
||||
<div className="mx-auto max-w-[52rem]">
|
||||
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
|
||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||
<EditNameDialog currentName={greetingName} />
|
||||
</Text>
|
||||
<Text variant="h3" className="mb-8 !font-normal">
|
||||
Tell me about your work — I'll find what to automate.
|
||||
</Text>
|
||||
|
||||
{isAgentBriefingEnabled && (
|
||||
<PulseChips chips={pulseChips} onChipClick={onSend} />
|
||||
)}
|
||||
|
||||
<div className="mb-6">
|
||||
<motion.div
|
||||
layoutId={inputLayoutId}
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { PencilSimpleIcon } from "@phosphor-icons/react";
|
||||
import { useState } from "react";
|
||||
|
||||
interface Props {
  /** Name currently displayed; used to prefill the input when the dialog opens. */
  currentName: string;
}

/**
 * Pencil-icon trigger plus a dialog for editing the user's display name.
 *
 * Saving sends `PUT /api/auth/user` with `{ full_name }` and then calls
 * `refreshSession()`. Every outcome is reported through a toast:
 * - request failed or returned a non-OK status: "Failed to update name"
 * - name saved but the session refresh failed: dialog closes with a warning
 * - full success: dialog closes with "Name updated"
 */
export function EditNameDialog({ currentName }: Props) {
  const [isOpen, setIsOpen] = useState(false);
  const [name, setName] = useState(currentName);
  const [isSaving, setIsSaving] = useState(false);
  const { refreshSession } = useSupabase();
  const { toast } = useToast();

  function handleOpenChange(open: boolean) {
    // Discard any unsaved edit from a previous open.
    if (open) setName(currentName);
    setIsOpen(open);
  }

  async function handleSave() {
    const trimmed = name.trim();
    // The Enter-key path does not go through the disabled Save button, so
    // guard against empty input and against re-submitting mid-save here too.
    if (!trimmed || isSaving) return;

    setIsSaving(true);
    // Tracks whether the PUT succeeded, so a later failure is reported as a
    // refresh problem rather than as a failed save.
    let saved = false;
    try {
      const res = await fetch("/api/auth/user", {
        method: "PUT",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ full_name: trimmed }),
      });

      if (!res.ok) {
        // An error response is not guaranteed to carry a JSON body.
        const body = await res.json().catch(() => null);
        toast({
          title: "Failed to update name",
          description: body?.error ?? "Unknown error",
          variant: "destructive",
        });
        return;
      }
      saved = true;

      const session = await refreshSession();
      if (session?.error) {
        toast({
          title: "Name saved, but session refresh failed",
          description: session.error,
          variant: "destructive",
        });
        setIsOpen(false);
        return;
      }
      setIsOpen(false);
      toast({ title: "Name updated" });
    } catch (error) {
      // fetch() rejects on network failure and refreshSession() may throw;
      // without this the rejection would be unhandled and the user would get
      // no feedback.
      toast({
        title: saved
          ? "Name saved, but session refresh failed"
          : "Failed to update name",
        description: error instanceof Error ? error.message : "Unknown error",
        variant: "destructive",
      });
      if (saved) setIsOpen(false);
    } finally {
      setIsSaving(false);
    }
  }

  return (
    <Dialog
      title="Edit display name"
      styling={{ maxWidth: "24rem" }}
      controlled={{ isOpen, set: handleOpenChange }}
    >
      <Dialog.Trigger>
        {/* Icon-only button: needs an explicit accessible name. */}
        <button
          type="button"
          aria-label="Edit name"
          className="ml-1 inline-flex items-center text-violet-500 transition-colors hover:text-violet-700"
        >
          <PencilSimpleIcon size={16} />
        </button>
      </Dialog.Trigger>
      <Dialog.Content>
        <div className="flex flex-col gap-4 px-1">
          <Input
            id="display-name"
            label="Display name"
            placeholder="Your name"
            value={name}
            onChange={(e) => setName(e.target.value)}
            onKeyDown={(e) => {
              if (e.key === "Enter") {
                e.preventDefault();
                void handleSave();
              }
            }}
          />
          <Button
            variant="primary"
            onClick={handleSave}
            disabled={!name.trim() || isSaving}
            loading={isSaving}
          >
            Save
          </Button>
        </div>
      </Dialog.Content>
    </Dialog>
  );
}
|
||||
@@ -0,0 +1,135 @@
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
import {
|
||||
fireEvent,
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { EditNameDialog } from "../EditNameDialog";
|
||||
|
||||
// Created via vi.hoisted so the vi.mock factories below can reference them.
const { mockToast, mockRefreshSession } = vi.hoisted(() => ({
  mockToast: vi.fn(),
  mockRefreshSession: vi.fn(),
}));

// useToast() hands the component the spy as its `toast` function.
vi.mock("@/components/molecules/Toast/use-toast", () => {
  return { useToast: () => ({ toast: mockToast }) };
});

// useSupabase() exposes only the refreshSession spy.
vi.mock("@/lib/supabase/hooks/useSupabase", () => {
  return {
    useSupabase: () => ({ refreshSession: mockRefreshSession }),
  };
});
|
||||
|
||||
/** Register an MSW handler that answers `PUT /api/auth/user` with a user payload. */
function mockUpdateNameSuccess() {
  const successHandler = http.put("/api/auth/user", () =>
    HttpResponse.json({ user: { id: "u1" } }),
  );
  server.use(successHandler);
}
|
||||
|
||||
/**
 * Register an MSW handler that answers `PUT /api/auth/user` with HTTP 400 and
 * `{ error: message }` as the JSON body.
 */
function mockUpdateNameError(message = "Network error") {
  const failureHandler = http.put("/api/auth/user", () =>
    HttpResponse.json({ error: message }, { status: 400 }),
  );
  server.use(failureHandler);
}
|
||||
|
||||
/**
 * Click the trigger button, wait until an element labelled "display name" is
 * present, and return the first `input#display-name` found in the document.
 */
async function openDialogAndGetInput() {
  fireEvent.click(screen.getByRole("button"));
  await screen.findAllByLabelText(/display name/i);
  const matchingInputs =
    document.querySelectorAll<HTMLInputElement>("input#display-name");
  const firstInput = matchingInputs[0];
  return firstInput;
}
|
||||
|
||||
/** Return the first button whose accessible name matches "save". */
function getSaveButton() {
  const [firstSave] = screen.getAllByRole("button", { name: /save/i });
  return firstSave as HTMLButtonElement;
}
|
||||
|
||||
describe("EditNameDialog", () => {
  beforeEach(() => {
    mockToast.mockReset();
    mockRefreshSession.mockReset();
    // Default: the session refresh succeeds.
    mockRefreshSession.mockResolvedValue({ user: { id: "u1" } });
  });

  test("opens dialog with current name prefilled", async () => {
    mockUpdateNameSuccess();
    render(<EditNameDialog currentName="Alice" />);

    const input = await openDialogAndGetInput();
    expect(input.value).toBe("Alice");
  });

  test("saves name via API route and closes dialog", async () => {
    mockUpdateNameSuccess();
    render(<EditNameDialog currentName="Alice" />);

    const input = await openDialogAndGetInput();
    fireEvent.change(input, { target: { value: "Bob" } });
    fireEvent.click(getSaveButton());

    // The success toast fires only after the awaited refreshSession()
    // resolves, so wait for the toast itself instead of asserting it
    // immediately after the refresh call has been observed.
    await waitFor(() => {
      expect(mockToast).toHaveBeenCalledWith({ title: "Name updated" });
    });
    expect(mockRefreshSession).toHaveBeenCalled();
  });

  test("shows error toast when API returns error", async () => {
    mockUpdateNameError("Network error");
    render(<EditNameDialog currentName="Alice" />);

    const input = await openDialogAndGetInput();
    fireEvent.change(input, { target: { value: "Bob" } });
    fireEvent.click(getSaveButton());

    await waitFor(() => {
      expect(mockToast).toHaveBeenCalledWith(
        expect.objectContaining({
          title: "Failed to update name",
          description: "Network error",
          variant: "destructive",
        }),
      );
    });
    // A failed save must not trigger a session refresh.
    expect(mockRefreshSession).not.toHaveBeenCalled();
  });

  test("shows warning toast when refreshSession returns an error", async () => {
    mockUpdateNameSuccess();
    mockRefreshSession.mockResolvedValue({ error: "refresh failed" });

    render(<EditNameDialog currentName="Alice" />);

    const input = await openDialogAndGetInput();
    fireEvent.change(input, { target: { value: "Bob" } });
    fireEvent.click(getSaveButton());

    await waitFor(() => {
      expect(mockToast).toHaveBeenCalledWith(
        expect.objectContaining({
          title: "Name saved, but session refresh failed",
          description: "refresh failed",
          variant: "destructive",
        }),
      );
    });
    expect(mockToast).not.toHaveBeenCalledWith({ title: "Name updated" });
  });

  test("disables Save button while input is empty", async () => {
    mockUpdateNameSuccess();
    render(<EditNameDialog currentName="Alice" />);

    const input = await openDialogAndGetInput();
    // Whitespace-only input counts as empty because the component trims it.
    fireEvent.change(input, { target: { value: " " } });

    expect(getSaveButton().disabled).toBe(true);
  });
});
|
||||
@@ -0,0 +1,93 @@
|
||||
/* Panel wrapper: creates its own stacking context so the ::before border
   (z-index: -1) stays behind the panel content but inside the panel. */
.glassPanel {
  position: relative;
  isolation: isolate;
}

/* Rotating 1px gradient border. The two-layer mask with xor/exclude
   compositing leaves only the padding ring of the pseudo-element visible. */
.glassPanel::before {
  content: "";
  position: absolute;
  inset: 0;
  border-radius: inherit;
  padding: 1px;
  background: conic-gradient(
    from var(--border-angle, 0deg),
    rgba(129, 120, 228, 0.08),
    rgba(129, 120, 228, 0.28),
    rgba(168, 130, 255, 0.18),
    rgba(129, 120, 228, 0.08),
    rgba(99, 102, 241, 0.24),
    rgba(129, 120, 228, 0.08)
  );
  -webkit-mask:
    linear-gradient(#000 0 0) content-box,
    linear-gradient(#000 0 0);
  mask:
    linear-gradient(#000 0 0) content-box,
    linear-gradient(#000 0 0);
  -webkit-mask-composite: xor;
  mask-composite: exclude;
  animation: rotate-border 6s linear infinite;
  pointer-events: none;
  z-index: -1;
}

/* Registered as an <angle> so the custom property can be animated. */
@property --border-angle {
  syntax: "<angle>";
  initial-value: 0deg;
  inherits: false;
}

@keyframes rotate-border {
  to {
    --border-angle: 360deg;
  }
}

/* Respect the user's reduced-motion preference: keep the gradient border but
   stop the endless rotation. */
@media (prefers-reduced-motion: reduce) {
  .glassPanel::before {
    animation: none;
  }
}

.chip {
  overflow: hidden;
}

/* Hover-capable devices: the action bar slides in on hover, so only a small
   bottom padding is needed. */
@media (hover: hover) {
  .chip {
    padding-bottom: 0.9rem;
  }
}

/* Devices without hover: the action bar is always visible, so reserve room
   for it below the content. */
@media (hover: none) {
  .chip {
    padding-bottom: 2.25rem;
  }
}

.chipActions {
  position: absolute;
  inset-inline: 0;
  bottom: 0;
  background: rgba(255, 255, 255, 0.95);
  backdrop-filter: blur(4px);
  -webkit-backdrop-filter: blur(4px);
}

@media (hover: hover) {
  .chipActions {
    opacity: 0;
    transform: translateY(100%);
    transition:
      opacity 0.2s ease-out,
      transform 0.2s ease-out;
  }

  .chip:hover .chipActions {
    opacity: 1;
    transform: translateY(0);
  }

  /* Both properties changed on hover are transitioned; previously only
     `filter` was listed, so the opacity change snapped instantly. */
  .chipContent {
    transition:
      filter 0.2s ease-out,
      opacity 0.2s ease-out;
  }

  .chip:hover .chipContent {
    filter: blur(2px);
    opacity: 0.5;
  }
}
|
||||
@@ -0,0 +1,116 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
ArrowRightIcon,
|
||||
EyeIcon,
|
||||
ChatCircleDotsIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import NextLink from "next/link";
|
||||
import { StatusBadge } from "@/app/(platform)/library/components/StatusBadge/StatusBadge";
|
||||
import styles from "./PulseChips.module.css";
|
||||
import type { PulseChipData } from "./types";
|
||||
|
||||
interface Props {
  /** Chips to show; an empty array renders nothing. */
  chips: PulseChipData[];
  /** Receives the generated prompt when a chip's "Ask" button is clicked. */
  onChipClick?: (prompt: string) => void;
}

/**
 * Panel with a heading, a "View all" link to `/library`, and a horizontally
 * scrollable row containing one chip per entry in `chips`.
 */
export function PulseChips({ chips, onChipClick }: Props) {
  if (!chips.length) return null;

  const panelClassName = `${styles.glassPanel} mx-[0.6875rem] mb-5 rounded-large p-5`;
  const chipNodes = chips.map((chip) => (
    <PulseChip key={chip.id} chip={chip} onAsk={onChipClick} />
  ));

  return (
    <div className={panelClassName}>
      <div className="mb-3 flex items-center gap-3">
        <Text variant="body-medium" className="text-zinc-600">
          What's happening with your agents
        </Text>
        <NextLink
          href="/library"
          className="flex items-center gap-1 text-xs text-zinc-500 hover:text-zinc-700"
        >
          View all <ArrowRightIcon size={12} />
        </NextLink>
      </div>
      <div className="flex gap-2 overflow-x-auto pb-1 scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300">
        {chipNodes}
      </div>
    </div>
  );
}
|
||||
|
||||
interface ChipProps {
  /** Data for the chip being rendered. */
  chip: PulseChipData;
  /** Called with the generated prompt when "Ask" is clicked. */
  onAsk?: (prompt: string) => void;
}

/**
 * Single chip: a status badge, the agent name and short message, and an
 * action row with a "See" link to the agent page and an "Ask" button.
 */
function PulseChip({ chip, onAsk }: ChipProps) {
  const actionClassName =
    "flex items-center gap-1 rounded-md px-2 py-1 text-xs text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700";

  // Chips with "success" priority get a fixed "Completed" badge; all others
  // defer to StatusBadge with the chip's status.
  const badge =
    chip.priority === "success" ? (
      <span className="inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium text-emerald-600">
        <span className="h-1.5 w-1.5 rounded-full bg-emerald-500" />
        Completed
      </span>
    ) : (
      <StatusBadge status={chip.status} />
    );

  return (
    <div
      className={`${styles.chip} relative flex w-[15rem] shrink-0 flex-col items-start gap-2 rounded-medium border border-zinc-100 bg-white px-3 py-2`}
    >
      <div className={`${styles.chipContent} w-full text-left`}>
        {badge}
        <div className="mt-2 min-w-0">
          <Text variant="small-medium" className="truncate text-zinc-900">
            {chip.name}
          </Text>
          <Text variant="small" className="truncate text-zinc-500">
            {chip.shortMessage}
          </Text>
        </div>
      </div>
      <div
        className={`${styles.chipActions} flex items-center justify-center gap-1.5 rounded-b-medium px-3 py-1.5`}
      >
        <NextLink
          href={`/library/agents/${chip.agentID}`}
          className={actionClassName}
        >
          <EyeIcon size={14} />
          See
        </NextLink>
        <button
          type="button"
          onClick={() => onAsk?.(buildChipPrompt(chip))}
          className={actionClassName}
        >
          <ChatCircleDotsIcon size={14} />
          Ask
        </button>
      </div>
    </div>
  );
}
|
||||
|
||||
/**
 * Build the chat prompt sent when a chip's "Ask" button is clicked.
 *
 * A "success" priority takes precedence over the status; otherwise the prompt
 * is chosen by status, with a generic question for any unrecognised status.
 */
function buildChipPrompt(chip: PulseChipData): string {
  const { name, priority, status } = chip;

  if (priority === "success") {
    return `${name} just finished a run — can you summarize what it did?`;
  }
  if (status === "error") {
    return `What happened with ${name}? It has an error — can you check?`;
  }
  if (status === "running") {
    return `Give me a status update on ${name} — what has it done so far?`;
  }
  if (status === "idle") {
    return `${name} hasn't run recently. Should I keep it or update and re-run it?`;
  }
  return `Tell me about ${name} — what's its current status?`;
}
|
||||
@@ -0,0 +1,105 @@
|
||||
import { describe, expect, test, vi } from "vitest";
|
||||
import { render, screen, fireEvent } from "@/tests/integrations/test-utils";
|
||||
import { PulseChips } from "../PulseChips";
|
||||
import type { PulseChipData } from "../types";
|
||||
|
||||
/** Build a chip fixture; any field in `overrides` replaces the default. */
function makeChip(overrides: Partial<PulseChipData> = {}): PulseChipData {
  const defaults: PulseChipData = {
    id: "chip-1",
    agentID: "agent-1",
    name: "Test Agent",
    status: "running",
    priority: "running",
    shortMessage: "Doing work…",
  };
  return { ...defaults, ...overrides };
}
|
||||
|
||||
describe("PulseChips", () => {
  // Render a single chip with a spy handler, click "Ask", and return the spy.
  function renderChipAndClickAsk(overrides: Partial<PulseChipData>) {
    const onChipClick = vi.fn();
    render(
      <PulseChips chips={[makeChip(overrides)]} onChipClick={onChipClick} />,
    );
    fireEvent.click(screen.getByText("Ask"));
    return onChipClick;
  }

  test("renders nothing when chips array is empty", () => {
    const { container } = render(<PulseChips chips={[]} />);
    expect(container.innerHTML).toBe("");
  });

  test("renders chip names and messages", () => {
    const alpha = makeChip({
      id: "1",
      name: "Alpha Bot",
      shortMessage: "Running task A",
    });
    const beta = makeChip({
      id: "2",
      name: "Beta Bot",
      shortMessage: "Running task B",
    });

    render(<PulseChips chips={[alpha, beta]} />);

    for (const text of [
      "Alpha Bot",
      "Running task A",
      "Beta Bot",
      "Running task B",
    ]) {
      expect(screen.getByText(text)).toBeDefined();
    }
  });

  test("renders section heading and View all link", () => {
    render(<PulseChips chips={[makeChip()]} />);

    expect(screen.getByText("What's happening with your agents")).toBeDefined();
    expect(screen.getByText("View all")).toBeDefined();
  });

  test("shows Completed badge for success priority chips", () => {
    const completed = makeChip({ priority: "success", status: "idle" });
    render(<PulseChips chips={[completed]} />);

    expect(screen.getByText("Completed")).toBeDefined();
  });

  test("calls onChipClick with generated prompt when Ask is clicked", () => {
    const onChipClick = renderChipAndClickAsk({
      name: "Error Agent",
      status: "error",
      priority: "error",
    });

    expect(onChipClick).toHaveBeenCalledWith(
      "What happened with Error Agent? It has an error — can you check?",
    );
  });

  test("generates success prompt for completed chips", () => {
    const onChipClick = renderChipAndClickAsk({
      name: "Done Agent",
      priority: "success",
      status: "idle",
    });

    expect(onChipClick).toHaveBeenCalledWith(
      "Done Agent just finished a run — can you summarize what it did?",
    );
  });

  test("renders See link pointing to agent detail page", () => {
    render(<PulseChips chips={[makeChip({ agentID: "agent-xyz" })]} />);

    const anchor = screen.getByText("See").closest("a");
    expect(anchor?.getAttribute("href")).toBe("/library/agents/agent-xyz");
  });
});
|
||||
@@ -0,0 +1,13 @@
|
||||
import type {
|
||||
AgentStatus,
|
||||
SitrepPriority,
|
||||
} from "@/app/(platform)/library/types";
|
||||
|
||||
/** Data needed to render one chip in the PulseChips panel. */
export interface PulseChipData {
  /** Identifier of the chip; used as the React key when chips are listed. */
  id: string;
  /** Agent identifier; used to build the `/library/agents/<agentID>` link. */
  agentID: string;
  /** Agent name shown on the chip and interpolated into generated prompts. */
  name: string;
  /** Agent status; passed to StatusBadge and used to pick the prompt wording. */
  status: AgentStatus;
  /** Priority; the value "success" switches the chip to a "Completed" badge. */
  priority: SitrepPriority;
  /** Short one-line message shown under the name. */
  shortMessage: string;
}
|
||||
@@ -0,0 +1,23 @@
|
||||
"use client";
|
||||
|
||||
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
|
||||
import { useSitrepItems } from "@/app/(platform)/library/components/SitrepItem/useSitrepItems";
|
||||
import type { PulseChipData } from "./types";
|
||||
import { useMemo } from "react";
|
||||
|
||||
/**
 * Hook that produces the chip data for the PulseChips panel.
 *
 * It reads the library agents, asks `useSitrepItems` for items (called with
 * the agents and the number 5), and maps each item to a `PulseChipData`. The
 * mapped array is memoized on the sitrep items reference.
 */
export function usePulseChips(): PulseChipData[] {
  const { agents } = useLibraryAgents();
  const sitrepItems = useSitrepItems(agents, 5);

  return useMemo(
    () =>
      sitrepItems.map(
        ({ id, agentID, agentName, status, priority, message }) => ({
          id,
          agentID,
          name: agentName,
          status,
          priority,
          shortMessage: message,
        }),
      ),
    [sitrepItems],
  );
}
|
||||
@@ -6,6 +6,9 @@ import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
|
||||
import { formatCents } from "../usageHelpers";
|
||||
|
||||
export { formatCents };
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
@@ -18,10 +21,6 @@ interface Props {
|
||||
onCreditChange?: () => void;
|
||||
}
|
||||
|
||||
export function formatCents(cents: number): string {
|
||||
return `$${(cents / 100).toFixed(2)}`;
|
||||
}
|
||||
|
||||
export function RateLimitResetDialog({
|
||||
isOpen,
|
||||
onClose,
|
||||
|
||||
@@ -1,35 +1,10 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import Link from "next/link";
|
||||
import { formatCents } from "../RateLimitResetDialog/RateLimitResetDialog";
|
||||
import { formatCents, formatResetTime } from "../usageHelpers";
|
||||
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
|
||||
|
||||
export function formatResetTime(
|
||||
resetsAt: Date | string,
|
||||
now: Date = new Date(),
|
||||
): string {
|
||||
const resetDate =
|
||||
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
|
||||
const diffMs = resetDate.getTime() - now.getTime();
|
||||
if (diffMs <= 0) return "now";
|
||||
|
||||
const hours = Math.floor(diffMs / (1000 * 60 * 60));
|
||||
|
||||
// Under 24h: show relative time ("in 4h 23m")
|
||||
if (hours < 24) {
|
||||
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
|
||||
if (hours > 0) return `in ${hours}h ${minutes}m`;
|
||||
return `in ${minutes}m`;
|
||||
}
|
||||
|
||||
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
|
||||
return resetDate.toLocaleString(undefined, {
|
||||
weekday: "short",
|
||||
hour: "numeric",
|
||||
minute: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
}
|
||||
export { formatResetTime };
|
||||
|
||||
function UsageBar({
|
||||
label,
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
/** Format an amount given in cents as a dollar string with two decimals. */
export function formatCents(cents: number): string {
  const dollars = cents / 100;
  return "$" + dollars.toFixed(2);
}
|
||||
|
||||
/**
 * Describe when a limit resets, relative to `now`.
 *
 * - Reset time at or before `now`: "now".
 * - Less than 24 hours away: "in Xh Ym", or "in Ym" when under an hour.
 * - Otherwise: weekday, time and short time-zone name in the runtime's locale.
 *
 * @param resetsAt Reset moment as a Date or a string accepted by `new Date`.
 * @param now Reference time; defaults to the current time.
 */
export function formatResetTime(
  resetsAt: Date | string,
  now: Date = new Date(),
): string {
  const MS_PER_MINUTE = 1000 * 60;
  const MS_PER_HOUR = MS_PER_MINUTE * 60;

  let resetDate: Date;
  if (typeof resetsAt === "string") {
    resetDate = new Date(resetsAt);
  } else {
    resetDate = resetsAt;
  }

  const diffMs = resetDate.getTime() - now.getTime();
  if (diffMs <= 0) return "now";

  const hours = Math.floor(diffMs / MS_PER_HOUR);
  if (hours < 24) {
    const minutes = Math.floor((diffMs % MS_PER_HOUR) / MS_PER_MINUTE);
    return hours > 0 ? `in ${hours}h ${minutes}m` : `in ${minutes}m`;
  }

  const absoluteFormat: Intl.DateTimeFormatOptions = {
    weekday: "short",
    hour: "numeric",
    minute: "2-digit",
    timeZoneName: "short",
  };
  return resetDate.toLocaleString(undefined, absoluteFormat);
}
|
||||
@@ -0,0 +1,59 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { convertChatSessionMessagesToUiMessages } from "../convertChatSessionToUiMessages";
|
||||
|
||||
const SESSION_ID = "sess-test";

describe("convertChatSessionMessagesToUiMessages", () => {
  // Convert a completed session holding exactly one message (sequence 0) and
  // return the resulting UI messages.
  function convertSingle(
    role: "user" | "assistant",
    content: string | null,
  ) {
    const { messages } = convertChatSessionMessagesToUiMessages(
      SESSION_ID,
      [{ role, content, sequence: 0 }],
      { isComplete: true },
    );
    return messages;
  }

  it("does not drop user messages with null content", () => {
    const messages = convertSingle("user", null);

    expect(messages).toHaveLength(1);
    expect(messages[0].role).toBe("user");
  });

  it("does not drop user messages with empty string content", () => {
    const messages = convertSingle("user", "");

    expect(messages).toHaveLength(1);
    expect(messages[0].role).toBe("user");
  });

  it("still drops non-user messages with null content", () => {
    const messages = convertSingle("assistant", null);

    expect(messages).toHaveLength(0);
  });

  it("still drops non-user messages with empty string content", () => {
    const messages = convertSingle("assistant", "");

    expect(messages).toHaveLength(0);
  });

  it("includes user message with normal content", () => {
    const messages = convertSingle("user", "hello");

    expect(messages).toHaveLength(1);
    expect(messages[0].role).toBe("user");
  });
});
|
||||
@@ -253,6 +253,11 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
}
|
||||
}
|
||||
|
||||
// User messages must always be rendered, even with empty content, so the
|
||||
// initial prompt is visible when reloading a session.
|
||||
if (parts.length === 0 && msg.role === "user") {
|
||||
parts.push({ type: "text", text: "", state: "done" });
|
||||
}
|
||||
if (parts.length === 0) return;
|
||||
|
||||
// Merge consecutive assistant messages into a single UIMessage
|
||||
|
||||
@@ -86,6 +86,16 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
|
||||
return sessionQuery.data.data.oldest_sequence ?? null;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
const newestSequence = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200) return null;
|
||||
return sessionQuery.data.data.newest_sequence ?? null;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
const forwardPaginated = useMemo(() => {
|
||||
if (sessionQuery.data?.status !== 200) return false;
|
||||
return !!sessionQuery.data.data.forward_paginated;
|
||||
}, [sessionQuery.data]);
|
||||
|
||||
// Memoize so the effect in useCopilotPage doesn't infinite-loop on a new
|
||||
// array reference every render. Re-derives only when query data changes.
|
||||
// When the session is complete (no active stream), mark dangling tool
|
||||
@@ -185,6 +195,8 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
|
||||
hasActiveStream,
|
||||
hasMoreMessages,
|
||||
oldestSequence,
|
||||
newestSequence,
|
||||
forwardPaginated,
|
||||
isLoadingSession: sessionQuery.isLoading,
|
||||
isSessionError: sessionQuery.isError,
|
||||
createSession,
|
||||
|
||||
@@ -56,6 +56,8 @@ export function useCopilotPage() {
|
||||
hasActiveStream,
|
||||
hasMoreMessages,
|
||||
oldestSequence,
|
||||
newestSequence,
|
||||
forwardPaginated,
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
createSession,
|
||||
@@ -84,18 +86,26 @@ export function useCopilotPage() {
|
||||
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
|
||||
});
|
||||
|
||||
const { olderMessages, hasMore, isLoadingMore, loadMore } =
|
||||
const { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged } =
|
||||
useLoadMoreMessages({
|
||||
sessionId,
|
||||
initialOldestSequence: oldestSequence,
|
||||
initialNewestSequence: newestSequence,
|
||||
initialHasMore: hasMoreMessages,
|
||||
forwardPaginated,
|
||||
initialPageRawMessages: rawSessionMessages,
|
||||
});
|
||||
|
||||
// Combine older (paginated) messages with current page messages,
|
||||
// merging consecutive assistant UIMessages at the page boundary so
|
||||
// reasoning + response parts stay in a single bubble.
|
||||
const messages = concatWithAssistantMerge(olderMessages, currentMessages);
|
||||
// Combine paginated messages with current page messages, merging consecutive
|
||||
// assistant UIMessages at the page boundary so reasoning + response parts
|
||||
// stay in a single bubble.
|
||||
// Forward pagination (completed sessions): current page is the beginning,
|
||||
// paged messages are newer pages appended after.
|
||||
// Backward pagination (active sessions): paged messages are older history
|
||||
// prepended before the current page.
|
||||
const messages = forwardPaginated
|
||||
? concatWithAssistantMerge(currentMessages, pagedMessages)
|
||||
: concatWithAssistantMerge(pagedMessages, currentMessages);
|
||||
|
||||
useCopilotNotifications(sessionId);
|
||||
|
||||
@@ -170,6 +180,23 @@ export function useCopilotPage() {
|
||||
}
|
||||
}, [sessionId, pendingMessage, sendMessage]);
|
||||
|
||||
// --- Clear backward-paginated messages when session completes ---
|
||||
// When a session transitions from active (forwardPaginated=false) to complete
|
||||
// (forwardPaginated=true), any backward-paginated older messages would be
|
||||
// appended after currentMessages instead of before, causing chronological
|
||||
// disorder. Reset paged state so the completed session renders cleanly.
|
||||
const prevForwardPaginatedRef = useRef(forwardPaginated);
|
||||
useEffect(() => {
|
||||
if (
|
||||
!prevForwardPaginatedRef.current &&
|
||||
forwardPaginated &&
|
||||
pagedMessages.length > 0
|
||||
) {
|
||||
resetPaged();
|
||||
}
|
||||
prevForwardPaginatedRef.current = forwardPaginated;
|
||||
}, [forwardPaginated, pagedMessages.length, resetPaged]);
|
||||
|
||||
// --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
|
||||
useWorkflowImportAutoSubmit({
|
||||
createSession,
|
||||
@@ -251,6 +278,15 @@ export function useCopilotPage() {
|
||||
isUserStoppingRef.current = false;
|
||||
|
||||
if (sessionId) {
|
||||
// When continuing a completed session that had forward-paginated history
|
||||
// loaded, the paged messages would appear in wrong position relative to
|
||||
// the new streaming turn (pagedMessages are newer pages, so they'd end
|
||||
// up after the streaming turn). Reset paged state so ordering is correct
|
||||
// during streaming; the user can reload history afterward if needed.
|
||||
if (forwardPaginated && pagedMessages.length > 0) {
|
||||
resetPaged();
|
||||
}
|
||||
|
||||
if (files && files.length > 0) {
|
||||
setIsUploadingFiles(true);
|
||||
try {
|
||||
@@ -397,6 +433,7 @@ export function useCopilotPage() {
|
||||
hasMoreMessages: hasMore,
|
||||
isLoadingMore,
|
||||
loadMore,
|
||||
forwardPaginated,
|
||||
// Mobile drawer
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
|
||||
@@ -9,7 +9,11 @@ import {
|
||||
/** Arguments accepted by `useLoadMoreMessages`. */
interface UseLoadMoreMessagesArgs {
  /** Session whose messages are paged; may be `null`, in which case nothing is loaded. */
  sessionId: string | null;

  /** Seeds the backward cursor (sent as `before_sequence` when paging back). */
  initialOldestSequence: number | null;

  /** Seeds the forward cursor (sent as `after_sequence` when paging ahead). */
  initialNewestSequence: number | null;

  /** Seeds the hook's `hasMore` flag. */
  initialHasMore: boolean;

  /**
   * True when the initial page was loaded from sequence 0 forward (completed
   * sessions). False when loaded newest-first (active sessions).
   */
  forwardPaginated: boolean;

  /** Raw messages from the initial page, used for cross-page tool output matching. */
  initialPageRawMessages: unknown[];
}
|
||||
@@ -20,16 +24,21 @@ const MAX_OLDER_MESSAGES = 2000;
|
||||
export function useLoadMoreMessages({
|
||||
sessionId,
|
||||
initialOldestSequence,
|
||||
initialNewestSequence,
|
||||
initialHasMore,
|
||||
forwardPaginated,
|
||||
initialPageRawMessages,
|
||||
}: UseLoadMoreMessagesArgs) {
|
||||
// Store accumulated raw messages from all older pages (in ascending order).
|
||||
// Accumulated raw messages from all extra pages (ascending order).
|
||||
// Re-converting them all together ensures tool outputs are matched across
|
||||
// inter-page boundaries.
|
||||
const [olderRawMessages, setOlderRawMessages] = useState<unknown[]>([]);
|
||||
const [pagedRawMessages, setPagedRawMessages] = useState<unknown[]>([]);
|
||||
const [oldestSequence, setOldestSequence] = useState<number | null>(
|
||||
initialOldestSequence,
|
||||
);
|
||||
const [newestSequence, setNewestSequence] = useState<number | null>(
|
||||
initialNewestSequence,
|
||||
);
|
||||
const [hasMore, setHasMore] = useState(initialHasMore);
|
||||
const [isLoadingMore, setIsLoadingMore] = useState(false);
|
||||
const isLoadingMoreRef = useRef(false);
|
||||
@@ -46,7 +55,7 @@ export function useLoadMoreMessages({
|
||||
// The parent's `initialOldestSequence` drifts forward every time the
|
||||
// session query refetches (e.g. after a stream completes — see
|
||||
// `useCopilotStream` invalidation on `streaming → ready`). If we
|
||||
// wiped `olderRawMessages` every time that happened, users who had
|
||||
// wiped `pagedRawMessages` every time that happened, users who had
|
||||
// scrolled back would lose their loaded history on each new turn and
|
||||
// subsequent `loadMore` calls would fetch messages that overlap with
|
||||
// the AI SDK's retained state in `currentMessages`, producing visible
|
||||
@@ -63,8 +72,9 @@ export function useLoadMoreMessages({
|
||||
// Session changed — full reset
|
||||
prevSessionIdRef.current = sessionId;
|
||||
prevInitialOldestRef.current = initialOldestSequence;
|
||||
setOlderRawMessages([]);
|
||||
setPagedRawMessages([]);
|
||||
setOldestSequence(initialOldestSequence);
|
||||
setNewestSequence(initialNewestSequence);
|
||||
setHasMore(initialHasMore);
|
||||
setIsLoadingMore(false);
|
||||
isLoadingMoreRef.current = false;
|
||||
@@ -75,49 +85,64 @@ export function useLoadMoreMessages({
|
||||
|
||||
prevInitialOldestRef.current = initialOldestSequence;
|
||||
|
||||
// If we haven't paged back yet, mirror the parent so the first
|
||||
// If we haven't paged yet, mirror the parent so the first
|
||||
// `loadMore` starts from the correct cursor.
|
||||
if (olderRawMessages.length === 0) {
|
||||
//
|
||||
// When paged messages exist (pagedRawMessages.length > 0) we intentionally
|
||||
// do NOT update `hasMore` or `newestSequence` from the parent. A parent
|
||||
// refetch (e.g. after a new turn completes) may carry a fresh
|
||||
// `initialHasMore=true` or a larger `initialNewestSequence`, but those
|
||||
// reflect the *initial* page window, not the forward-paged window we have
|
||||
// already advanced into. Overwriting the local cursor here would cause the
|
||||
// next `loadMore` to re-fetch pages we already have. The local cursor is
|
||||
// advanced correctly inside `loadMore` itself via `setNewestSequence`.
|
||||
if (pagedRawMessages.length === 0) {
|
||||
setOldestSequence(initialOldestSequence);
|
||||
// Only regress the forward cursor if we haven't paged ahead yet —
|
||||
// otherwise a parent refetch would reset a cursor we already advanced.
|
||||
setNewestSequence((prev) =>
|
||||
prev !== null && prev > (initialNewestSequence ?? -1)
|
||||
? prev
|
||||
: initialNewestSequence,
|
||||
);
|
||||
setHasMore(initialHasMore);
|
||||
}
|
||||
}, [sessionId, initialOldestSequence, initialHasMore]);
|
||||
}, [sessionId, initialOldestSequence, initialNewestSequence, initialHasMore]);
|
||||
|
||||
// Convert all accumulated raw messages in one pass so tool outputs
|
||||
// are matched across inter-page boundaries. Initial page tool outputs
|
||||
// are included via extraToolOutputs to handle the boundary between
|
||||
// the last older page and the initial/streaming page.
|
||||
const olderMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
|
||||
// are matched across inter-page boundaries.
|
||||
// For backward pagination only: include initial page tool outputs so older
|
||||
// paged pages can match tool calls whose outputs landed in the initial page.
|
||||
// For forward pagination this is unnecessary — tool calls in newer paged
|
||||
// pages cannot have their outputs in the older initial page.
|
||||
const pagedMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
|
||||
useMemo(() => {
|
||||
if (!sessionId || olderRawMessages.length === 0) return [];
|
||||
if (!sessionId || pagedRawMessages.length === 0) return [];
|
||||
const extraToolOutputs =
|
||||
initialPageRawMessages.length > 0
|
||||
!forwardPaginated && initialPageRawMessages.length > 0
|
||||
? extractToolOutputsFromRaw(initialPageRawMessages)
|
||||
: undefined;
|
||||
return convertChatSessionMessagesToUiMessages(
|
||||
sessionId,
|
||||
olderRawMessages,
|
||||
pagedRawMessages,
|
||||
{ isComplete: true, extraToolOutputs },
|
||||
).messages;
|
||||
}, [sessionId, olderRawMessages, initialPageRawMessages]);
|
||||
}, [sessionId, pagedRawMessages, initialPageRawMessages, forwardPaginated]);
|
||||
|
||||
async function loadMore() {
|
||||
if (
|
||||
!sessionId ||
|
||||
!hasMore ||
|
||||
isLoadingMoreRef.current ||
|
||||
oldestSequence === null
|
||||
)
|
||||
return;
|
||||
if (!sessionId || !hasMore || isLoadingMoreRef.current) return;
|
||||
|
||||
const cursor = forwardPaginated ? newestSequence : oldestSequence;
|
||||
if (cursor === null) return;
|
||||
|
||||
const requestEpoch = epochRef.current;
|
||||
isLoadingMoreRef.current = true;
|
||||
setIsLoadingMore(true);
|
||||
try {
|
||||
const response = await getV2GetSession(sessionId, {
|
||||
limit: 50,
|
||||
before_sequence: oldestSequence,
|
||||
});
|
||||
const params = forwardPaginated
|
||||
? { limit: 50, after_sequence: cursor }
|
||||
: { limit: 50, before_sequence: cursor };
|
||||
const response = await getV2GetSession(sessionId, params);
|
||||
|
||||
// Discard response if session/pagination was reset while awaiting
|
||||
if (epochRef.current !== requestEpoch) return;
|
||||
@@ -136,18 +161,66 @@ export function useLoadMoreMessages({
|
||||
consecutiveErrorsRef.current = 0;
|
||||
|
||||
const newRaw = (response.data.messages ?? []) as unknown[];
|
||||
setOlderRawMessages((prev) => {
|
||||
const merged = [...newRaw, ...prev];
|
||||
// Estimate total after merge using the closure-captured pagedRawMessages.length.
|
||||
// This is a safe approximation: worst case it's one page stale (one extra load
|
||||
// allowed), but it avoids the React-18-batching pitfall where a functional
|
||||
// updater's mutations are not visible until the next render.
|
||||
const estimatedTotal = pagedRawMessages.length + newRaw.length;
|
||||
setPagedRawMessages((prev) => {
|
||||
// Forward: append to end. Backward: prepend to start.
|
||||
const merged = forwardPaginated
|
||||
? [...prev, ...newRaw]
|
||||
: [...newRaw, ...prev];
|
||||
if (merged.length > MAX_OLDER_MESSAGES) {
|
||||
return merged.slice(merged.length - MAX_OLDER_MESSAGES);
|
||||
// Backward: discard the oldest (front) items — user has scrolled far
|
||||
// back and we shed the furthest history.
|
||||
// Forward: discard the newest (tail) items — we only ever fetch
|
||||
// forward, so the tail is the most recently appended page; shedding
|
||||
// it means the sentinel stalls, which is safer than discarding the
|
||||
// beginning of the conversation the user is here to read.
|
||||
return forwardPaginated
|
||||
? merged.slice(0, MAX_OLDER_MESSAGES)
|
||||
: merged.slice(merged.length - MAX_OLDER_MESSAGES);
|
||||
}
|
||||
return merged;
|
||||
});
|
||||
setOldestSequence(response.data.oldest_sequence ?? null);
|
||||
if (newRaw.length + olderRawMessages.length >= MAX_OLDER_MESSAGES) {
|
||||
setHasMore(false);
|
||||
|
||||
if (forwardPaginated) {
|
||||
const willTruncateForward = estimatedTotal > MAX_OLDER_MESSAGES;
|
||||
if (willTruncateForward) {
|
||||
// Truncation shed the newest tail. Advance the cursor to the last KEPT
|
||||
// item's sequence so the sentinel re-fetches the discarded items next
|
||||
// time rather than jumping past them.
|
||||
// lastKeptIdx: index within newRaw of the last item that survives.
|
||||
// prev contributes pagedRawMessages.length items; total kept = MAX.
|
||||
const lastKeptIdx = MAX_OLDER_MESSAGES - 1 - pagedRawMessages.length;
|
||||
if (lastKeptIdx >= 0 && lastKeptIdx < newRaw.length) {
|
||||
const lastKeptMsg = newRaw[lastKeptIdx] as { sequence?: number };
|
||||
if (typeof lastKeptMsg?.sequence === "number") {
|
||||
setNewestSequence(lastKeptMsg.sequence);
|
||||
setHasMore(true); // Discarded items still exist — keep sentinel active
|
||||
} else {
|
||||
// Sequence unavailable — fall back; truncated items will be lost
|
||||
setNewestSequence(response.data.newest_sequence ?? null);
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
}
|
||||
} else {
|
||||
// All of newRaw was dropped (already at MAX_OLDER_MESSAGES cap).
|
||||
// Stop to avoid an infinite re-fetch loop at the display cap.
|
||||
setHasMore(false);
|
||||
}
|
||||
} else {
|
||||
setNewestSequence(response.data.newest_sequence ?? null);
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
}
|
||||
} else {
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
setOldestSequence(response.data.oldest_sequence ?? null);
|
||||
if (estimatedTotal >= MAX_OLDER_MESSAGES) {
|
||||
// Backward: accumulated MAX_OLDER_MESSAGES — stop to avoid unbounded memory.
|
||||
setHasMore(false);
|
||||
} else {
|
||||
setHasMore(!!response.data.has_more_messages);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if (epochRef.current !== requestEpoch) return;
|
||||
@@ -164,5 +237,22 @@ export function useLoadMoreMessages({
|
||||
}
|
||||
}
|
||||
|
||||
return { olderMessages, hasMore, isLoadingMore, loadMore };
|
||||
/**
 * Throw away every extra page loaded so far and put the pagination state
 * back to the values the parent supplied for the initial page.
 */
function resetPaged() {
  // Bump the epoch so a `loadMore` response that is still in flight is
  // discarded when it resolves (it compares its captured epoch to this one).
  epochRef.current += 1;
  consecutiveErrorsRef.current = 0;

  // Release the loading guard and spinner so neither stays stuck if a
  // `loadMore` was in flight when the reset happened.
  isLoadingMoreRef.current = false;
  setIsLoadingMore(false);

  // Drop the accumulated pages and rewind both cursors to the initial page.
  setPagedRawMessages([]);
  setOldestSequence(initialOldestSequence);
  setNewestSequence(initialNewestSequence);

  // Keep `hasMore` false during the session-transition window so no
  // `loadMore` fires with forward pagination (`after_sequence`) on the
  // now-active session. The sync effect restores `hasMore` from the parent
  // after the refetch completes and `forwardPaginated` switches to false.
  setHasMore(false);
}
|
||||
|
||||
return { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged };
|
||||
}
|
||||
|
||||
@@ -2,14 +2,17 @@ import { Navbar } from "@/components/layout/Navbar/Navbar";
|
||||
import { NetworkStatusMonitor } from "@/services/network-status/NetworkStatusMonitor";
|
||||
import { ReactNode } from "react";
|
||||
import { AdminImpersonationBanner } from "./admin/components/AdminImpersonationBanner";
|
||||
import { AutoPilotBridgeProvider } from "@/contexts/AutoPilotBridgeContext";
|
||||
|
||||
export default function PlatformLayout({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<main className="flex h-screen w-full flex-col">
|
||||
<NetworkStatusMonitor />
|
||||
<Navbar />
|
||||
<AdminImpersonationBanner />
|
||||
<section className="flex-1">{children}</section>
|
||||
</main>
|
||||
<AutoPilotBridgeProvider>
|
||||
<main className="flex h-screen w-full flex-col">
|
||||
<NetworkStatusMonitor />
|
||||
<Navbar />
|
||||
<AdminImpersonationBanner />
|
||||
<section className="flex-1">{children}</section>
|
||||
</main>
|
||||
</AutoPilotBridgeProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -137,8 +137,10 @@ describe("LibraryPage", () => {
|
||||
user_id: "test-user",
|
||||
name: "Work Agents",
|
||||
agent_count: 3,
|
||||
subfolder_count: 0,
|
||||
color: null,
|
||||
icon: null,
|
||||
parent_id: null,
|
||||
created_at: new Date(),
|
||||
updated_at: new Date(),
|
||||
},
|
||||
@@ -147,8 +149,10 @@ describe("LibraryPage", () => {
|
||||
user_id: "test-user",
|
||||
name: "Personal",
|
||||
agent_count: 1,
|
||||
subfolder_count: 0,
|
||||
color: null,
|
||||
icon: null,
|
||||
parent_id: null,
|
||||
created_at: new Date(),
|
||||
updated_at: new Date(),
|
||||
},
|
||||
@@ -158,12 +162,14 @@ describe("LibraryPage", () => {
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
await waitForAgentsToLoad();
|
||||
|
||||
expect(await screen.findByText("Work Agents")).toBeDefined();
|
||||
expect(screen.getByText("Personal")).toBeDefined();
|
||||
expect(screen.getAllByTestId("library-folder")).toHaveLength(2);
|
||||
});
|
||||
|
||||
test("shows See runs link on agent card", async () => {
|
||||
test("shows See tasks link on agent card", async () => {
|
||||
setupHandlers({
|
||||
agents: [makeAgent({ name: "Linked Agent", can_access_graph: true })],
|
||||
});
|
||||
@@ -172,7 +178,7 @@ describe("LibraryPage", () => {
|
||||
|
||||
await screen.findByText("Linked Agent");
|
||||
|
||||
const runLinks = screen.getAllByText("See runs");
|
||||
const runLinks = screen.getAllByText("See tasks");
|
||||
expect(runLinks.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
@@ -190,7 +196,7 @@ describe("LibraryPage", () => {
|
||||
expect(importButtons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("renders Jump Back In when there is an active execution", async () => {
|
||||
test("renders running agent card when execution is active", async () => {
|
||||
const agent = makeAgent({
|
||||
id: "lib-1",
|
||||
graph_id: "g-1",
|
||||
@@ -218,6 +224,6 @@ describe("LibraryPage", () => {
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText("Jump Back In")).toBeDefined();
|
||||
expect(await screen.findByText("Running Agent")).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
/*
 * Glass panel with a slowly rotating gradient border.
 *
 * The border is painted by the ::before pseudo-element: a conic gradient fills
 * the whole box and a two-layer mask cuts out the content box, leaving only
 * the 1px padding ring visible.
 */
.glassPanel {
  position: relative;
  /* New stacking context so the z-index: -1 ring stays inside the panel. */
  isolation: isolate;
}

.glassPanel::before {
  content: "";
  position: absolute;
  inset: 0;
  border-radius: inherit;
  /* Ring thickness: the mask below keeps only this padding area. */
  padding: 1px;
  background: conic-gradient(
    from var(--border-angle, 0deg),
    rgba(129, 120, 228, 0.04),
    rgba(129, 120, 228, 0.14),
    rgba(168, 130, 255, 0.09),
    rgba(129, 120, 228, 0.04),
    rgba(99, 102, 241, 0.12),
    rgba(129, 120, 228, 0.04)
  );
  -webkit-mask:
    linear-gradient(#000 0 0) content-box,
    linear-gradient(#000 0 0);
  mask:
    linear-gradient(#000 0 0) content-box,
    linear-gradient(#000 0 0);
  -webkit-mask-composite: xor;
  mask-composite: exclude;
  animation: rotate-border 6s linear infinite;
  pointer-events: none;
  z-index: -1;
}

/* Registered as an <angle> so the keyframes can interpolate it. */
@property --border-angle {
  syntax: "<angle>";
  initial-value: 0deg;
  inherits: false;
}

@keyframes rotate-border {
  to {
    --border-angle: 360deg;
  }
}

/* Respect the user's motion preference: keep the border, stop the rotation. */
@media (prefers-reduced-motion: reduce) {
  .glassPanel::before {
    animation: none;
  }
}
|
||||
@@ -0,0 +1,36 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { useState } from "react";
|
||||
import type { FleetSummary, AgentStatusFilter } from "../../types";
|
||||
import { BriefingTabContent } from "./BriefingTabContent";
|
||||
import { StatsGrid } from "./StatsGrid";
|
||||
import styles from "./AgentBriefingPanel.module.css";
|
||||
|
||||
interface Props {
  /** Fleet counters; `running` picks the default tab and the whole object feeds the stats grid. */
  summary: FleetSummary;
  /** Library agents forwarded to the tab content. */
  agents: LibraryAgent[];
}

/**
 * "Agent Briefing" card: a stats grid that doubles as a tab selector, with the
 * content of the selected tab rendered underneath.
 *
 * Until the user picks a tab, the panel shows "running" when at least one
 * agent is running and "all" otherwise.
 */
export function AgentBriefingPanel({ summary, agents }: Props) {
  // `null` means the user has not chosen a tab yet.
  const [userTab, setUserTab] = useState<AgentStatusFilter | null>(null);

  const defaultTab: AgentStatusFilter = summary.running > 0 ? "running" : "all";
  const activeTab: AgentStatusFilter = userTab ?? defaultTab;

  const panelClassName = `${styles.glassPanel} min-h-[14.75rem] rounded-large bg-gradient-to-br from-indigo-50/30 via-white/90 to-purple-50/25 px-5 pb-5 pt-[1.125rem] shadow-sm backdrop-blur-md`;

  return (
    <div className={panelClassName}>
      <Text variant="h5">Agent Briefing</Text>
      <div className="mt-4 space-y-5">
        <StatsGrid
          summary={summary}
          activeTab={activeTab}
          onTabChange={setUserTab}
        />
        <BriefingTabContent activeTab={activeTab} agents={agents} />
      </div>
    </div>
  );
}
|
||||
@@ -0,0 +1,347 @@
|
||||
"use client";
|
||||
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import {
|
||||
formatResetTime,
|
||||
formatCents,
|
||||
} from "@/app/(platform)/copilot/components/usageHelpers";
|
||||
import { useResetRateLimit } from "@/app/(platform)/copilot/hooks/useResetRateLimit";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Badge } from "@/components/atoms/Badge/Badge";
|
||||
import useCredits from "@/hooks/useCredits";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useSitrepItems } from "../SitrepItem/useSitrepItems";
|
||||
import { SitrepItem } from "../SitrepItem/SitrepItem";
|
||||
import { useAgentStatusMap } from "../../hooks/useAgentStatus";
|
||||
import type { AgentStatusFilter } from "../../types";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
|
||||
interface Props {
  /** Tab currently selected in the briefing panel. */
  activeTab: AgentStatusFilter;
  /** Library agents forwarded to the list sections. */
  agents: LibraryAgent[];
}

/**
 * Render the body of the briefing panel for the selected tab.
 *
 * - "all" shows the usage section.
 * - "running", "attention" and "completed" show the execution list.
 * - every other tab shows the agent list.
 *
 * The list sections are keyed by `activeTab`. Without the key React keeps the
 * same component instance when moving between sibling tabs (for example
 * "running" to "completed"), so the local "show all" expansion would carry
 * over to the newly selected tab. The key forces a fresh instance per tab.
 */
export function BriefingTabContent({ activeTab, agents }: Props) {
  switch (activeTab) {
    case "all":
      return <UsageSection />;

    case "running":
    case "attention":
    case "completed":
      return (
        <ExecutionListSection
          key={activeTab}
          activeTab={activeTab}
          agents={agents}
        />
      );

    default:
      return (
        <AgentListSection
          key={activeTab}
          activeTab={activeTab}
          agents={agents}
        />
      );
  }
}
|
||||
|
||||
/**
 * "All" tab content: the plan badge, an optional billing link, the daily and
 * weekly usage meters, and the reset / add-credits footer.
 *
 * Renders nothing until the usage response contains both the daily and the
 * weekly window.
 */
function UsageSection() {
  const usageQuery = useGetV2GetCopilotUsage({
    query: {
      select: (res) => res.data as CoPilotUsageStatus,
      refetchInterval: 30000,
      staleTime: 10000,
    },
  });
  const usage = usageQuery.data;

  const isBillingEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
  const { credits, fetchCredits } = useCredits({ fetchInitialCredits: true });

  // Only flag insufficient credits once both numbers are actually known.
  const resetCost = usage?.reset_cost;
  const hasInsufficientCredits =
    credits !== null && resetCost != null && credits < resetCost;

  if (!usage?.daily || !usage?.weekly) {
    return null;
  }

  const { daily, weekly } = usage;
  // Keep the first character as-is and lower-case the rest for the badge.
  const planName = usage.tier
    ? usage.tier.charAt(0) + usage.tier.slice(1).toLowerCase()
    : "";

  return (
    <div className="py-2">
      <div className="flex items-center gap-2">
        <Text variant="h5" className="text-neutral-800">
          Usage limits
        </Text>
        {usage.tier && (
          <Badge variant="info" size="small" className="bg-[rgb(224,237,255)]">
            {planName} plan
          </Badge>
        )}
        <div className="flex-1" />
        {isBillingEnabled && (
          <Link
            href="/profile/credits"
            className="text-sm text-blue-600 hover:underline"
          >
            Manage billing
          </Link>
        )}
      </div>
      <div className="mt-4 grid grid-cols-1 gap-6 sm:grid-cols-2">
        {daily.limit > 0 && (
          <UsageMeter
            label="Today"
            used={daily.used}
            limit={daily.limit}
            resetsAt={daily.resets_at}
          />
        )}
        {weekly.limit > 0 && (
          <UsageMeter
            label="This week"
            used={weekly.used}
            limit={weekly.limit}
            resetsAt={weekly.resets_at}
          />
        )}
      </div>
      <UsageFooter
        usage={usage}
        hasInsufficientCredits={hasInsufficientCredits}
        onCreditChange={fetchCredits}
      />
    </div>
  );
}
|
||||
|
||||
/** Number of cards shown before the list is collapsed behind "Show all". */
const MAX_VISIBLE = 6;

/**
 * Execution-based tab content ("running", "attention", "completed").
 *
 * Filters the sitrep items down to the priority that matches the tab and
 * shows at most `MAX_VISIBLE` of them until the user expands the list.
 */
function ExecutionListSection({
  activeTab,
  agents,
}: {
  activeTab: AgentStatusFilter;
  agents: LibraryAgent[];
}) {
  const allItems = useSitrepItems(agents, 50);
  const [showAll, setShowAll] = useState(false);

  // Tab -> sitrep priority; any other tab matches nothing.
  const wantedPriority =
    activeTab === "running"
      ? "running"
      : activeTab === "attention"
        ? "error"
        : activeTab === "completed"
          ? "success"
          : null;

  const matchingItems =
    wantedPriority === null
      ? []
      : allItems.filter((item) => item.priority === wantedPriority);

  if (matchingItems.length === 0) {
    return <EmptyMessage tab={activeTab} />;
  }

  const canExpand = matchingItems.length > MAX_VISIBLE;
  const shownItems = showAll
    ? matchingItems
    : matchingItems.slice(0, MAX_VISIBLE);

  function toggleShowAll() {
    setShowAll(!showAll);
  }

  return (
    <div>
      <div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
        {shownItems.map((item) => (
          <SitrepItem key={item.id} item={item} />
        ))}
      </div>
      {canExpand && (
        <div className="mt-3 flex justify-center">
          <Button variant="secondary" size="small" onClick={toggleShowAll}>
            {showAll ? "Collapse" : `Show all (${matchingItems.length})`}
          </Button>
        </div>
      )}
    </div>
  );
}
|
||||
|
||||
/** Message shown on each card of the agent-based tabs. */
const TAB_STATUS_LABEL: Record<string, string> = {
  listening: "Waiting for trigger event",
  scheduled: "Has a scheduled run",
  idle: "No recent activity",
};

/**
 * Agent-based tab content ("listening", "scheduled", "idle").
 *
 * Lists the agents whose current status matches the tab, rendered as
 * synthetic sitrep items, with at most `MAX_VISIBLE` shown until expanded.
 */
function AgentListSection({
  activeTab,
  agents,
}: {
  activeTab: AgentStatusFilter;
  agents: LibraryAgent[];
}) {
  const [showAll, setShowAll] = useState(false);
  const statusMap = useAgentStatusMap(agents);

  // The status map is looked up by graph id, not by library agent id.
  const matchingAgents = agents.filter((agent) => {
    const agentStatus = statusMap.get(agent.graph_id)?.status;
    switch (activeTab) {
      case "listening":
        return agentStatus === "listening";
      case "scheduled":
        return agentStatus === "scheduled";
      case "idle":
        return agentStatus === "idle";
      default:
        return false;
    }
  });

  if (matchingAgents.length === 0) {
    return <EmptyMessage tab={activeTab} />;
  }

  // Status stamped on every card; anything that is not listening or
  // scheduled is presented as idle.
  let status: "listening" | "scheduled" | "idle";
  if (activeTab === "listening") {
    status = "listening";
  } else if (activeTab === "scheduled") {
    status = "scheduled";
  } else {
    status = "idle";
  }

  const message = TAB_STATUS_LABEL[activeTab] ?? "";
  const canExpand = matchingAgents.length > MAX_VISIBLE;
  const shownAgents = showAll
    ? matchingAgents
    : matchingAgents.slice(0, MAX_VISIBLE);

  function toggleShowAll() {
    setShowAll(!showAll);
  }

  return (
    <div>
      <div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
        {shownAgents.map((agent) => (
          <SitrepItem
            key={agent.id}
            item={{
              id: agent.id,
              agentID: agent.id,
              agentName: agent.name,
              agentImageUrl: agent.image_url,
              priority: status,
              message,
              status,
            }}
          />
        ))}
      </div>
      {canExpand && (
        <div className="mt-3 flex justify-center">
          <Button variant="secondary" size="small" onClick={toggleShowAll}>
            {showAll ? "Collapse" : `Show all (${matchingAgents.length})`}
          </Button>
        </div>
      )}
    </div>
  );
}
|
||||
|
||||
/**
 * Call-to-action row under the usage meters.
 *
 * Only relevant when the daily limit is used up while the weekly limit is
 * not: offers a paid reset when the user can afford it, or a link to add
 * credits when they cannot. Renders nothing in every other case.
 */
function UsageFooter({
  usage,
  hasInsufficientCredits,
  onCreditChange,
}: {
  usage: CoPilotUsageStatus;
  hasInsufficientCredits: boolean;
  onCreditChange?: () => void;
}) {
  // Hook first so it runs on every render, ahead of the early return below.
  const { resetUsage, isPending } = useResetRateLimit({ onCreditChange });

  const { daily, weekly } = usage;
  const resetCost = usage.reset_cost ?? 0;

  // A limit of 0 or less is never treated as exhausted.
  const isDailyExhausted = daily.limit > 0 && daily.used >= daily.limit;
  const isWeeklyExhausted = weekly.limit > 0 && weekly.used >= weekly.limit;
  const onlyDailyExhausted = isDailyExhausted && !isWeeklyExhausted;

  const showReset =
    onlyDailyExhausted && resetCost > 0 && !hasInsufficientCredits;
  const showAddCredits = onlyDailyExhausted && hasInsufficientCredits;

  if (!showReset && !showAddCredits) {
    return null;
  }

  const resetLabel = isPending
    ? "Resetting..."
    : `Reset daily limit for ${formatCents(resetCost)}`;

  return (
    <div className="mt-4 flex items-center gap-3">
      {showReset && (
        <Button
          variant="primary"
          size="small"
          onClick={() => resetUsage()}
          loading={isPending}
        >
          {resetLabel}
        </Button>
      )}
      {showAddCredits && (
        <Link
          href="/profile/credits"
          className="inline-flex items-center justify-center rounded-md bg-primary px-3 py-1.5 text-sm font-medium text-primary-foreground hover:bg-primary/90"
        >
          Add credits to reset
        </Link>
      )}
    </div>
  );
}
|
||||
|
||||
/**
 * One labelled progress bar for a usage window.
 *
 * Shows the rounded percentage (capped at 100), the raw `used / limit`
 * figures and the reset time. Returns `null` when `limit` is 0 or negative.
 */
function UsageMeter({
  label,
  used,
  limit,
  resetsAt,
}: {
  label: string;
  used: number;
  limit: number;
  resetsAt: Date | string;
}) {
  if (limit <= 0) {
    return null;
  }

  const percent = Math.min(100, Math.round((used / limit) * 100));
  const hasUsage = used > 0;

  // Non-zero usage that rounds down to 0% is reported as "<1%".
  const percentLabel =
    hasUsage && percent === 0 ? "<1% used" : `${percent}% used`;

  // Give any non-zero usage at least a 1% sliver so the bar is visible.
  const barWidth = `${Math.max(hasUsage ? 1 : 0, percent)}%`;

  // The bar turns orange from 80% upwards.
  const barClassName = `h-full rounded-full transition-[width] duration-300 ease-out ${
    percent >= 80 ? "bg-orange-500" : "bg-blue-500"
  }`;

  return (
    <div className="flex flex-col gap-2">
      <div className="flex items-baseline justify-between">
        <Text variant="body-medium" className="text-neutral-700">
          {label}
        </Text>
        <Text variant="body" className="tabular-nums text-neutral-500">
          {percentLabel}
        </Text>
      </div>
      <div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
        <div className={barClassName} style={{ width: barWidth }} />
      </div>
      <div className="flex items-baseline justify-between">
        <Text variant="small" className="tabular-nums text-neutral-500">
          {used.toLocaleString()} / {limit.toLocaleString()}
        </Text>
        <Text variant="small" className="text-neutral-400">
          Resets {formatResetTime(resetsAt)}
        </Text>
      </div>
    </div>
  );
}
|
||||
|
||||
/** Placeholder copy per tab, shown when the tab has nothing to list. */
const EMPTY_MESSAGES: Record<string, string> = {
  running: "No agents running right now",
  attention: "No agents that need attention",
  completed: "No recently completed runs",
  listening: "No agents listening for events",
  scheduled: "No agents with scheduled runs",
  idle: "No idle agents",
};

/**
 * Centered placeholder for an empty tab. Tabs without an entry in
 * `EMPTY_MESSAGES` get a generic message.
 */
function EmptyMessage({ tab }: { tab: AgentStatusFilter }) {
  const message = EMPTY_MESSAGES[tab] ?? "No agents in this category";

  return (
    <div className="flex items-center justify-center pt-4">
      <Text variant="body-medium" className="text-zinc-600">
        {message}
      </Text>
    </div>
  );
}
|
||||
@@ -0,0 +1,102 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { OverflowText } from "@/components/atoms/OverflowText/OverflowText";
|
||||
import { Emoji } from "@/components/atoms/Emoji/Emoji";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { FleetSummary, AgentStatusFilter } from "../../types";
|
||||
|
||||
interface Props {
|
||||
summary: FleetSummary;
|
||||
activeTab: AgentStatusFilter;
|
||||
onTabChange: (tab: AgentStatusFilter) => void;
|
||||
}
|
||||
|
||||
const TILES: {
|
||||
label: string;
|
||||
key: keyof FleetSummary;
|
||||
format?: (v: number) => string;
|
||||
filter: AgentStatusFilter;
|
||||
emoji: string;
|
||||
color: string;
|
||||
}[] = [
|
||||
{
|
||||
label: "Spent this month",
|
||||
key: "monthlySpend",
|
||||
format: (v) => `$${v.toLocaleString()}`,
|
||||
filter: "all",
|
||||
emoji: "💵",
|
||||
color: "text-zinc-700",
|
||||
},
|
||||
{
|
||||
label: "Running now",
|
||||
key: "running",
|
||||
filter: "running",
|
||||
emoji: "🚩",
|
||||
color: "text-blue-600",
|
||||
},
|
||||
{
|
||||
label: "Recently completed",
|
||||
key: "completed",
|
||||
filter: "completed",
|
||||
emoji: "🗃️",
|
||||
color: "text-green-600",
|
||||
},
|
||||
{
|
||||
label: "Needs attention",
|
||||
key: "error",
|
||||
filter: "attention",
|
||||
emoji: "⚠️",
|
||||
color: "text-red-500",
|
||||
},
|
||||
{
|
||||
label: "Scheduled",
|
||||
key: "scheduled",
|
||||
filter: "scheduled",
|
||||
emoji: "📅",
|
||||
color: "text-yellow-600",
|
||||
},
|
||||
{
|
||||
label: "Idle",
|
||||
key: "idle",
|
||||
filter: "idle",
|
||||
emoji: "💤",
|
||||
color: "text-zinc-400",
|
||||
},
|
||||
];
|
||||
|
||||
/**
 * Grid of clickable summary tiles, one per entry in `TILES`.
 *
 * Each tile shows the matching `summary` field (run through the tile's
 * optional formatter) and calls `onTabChange` with the tile's filter when
 * clicked. The tile whose filter equals `activeTab` gets the highlighted
 * border/background and is reported as pressed to assistive technology.
 */
export function StatsGrid({ summary, activeTab, onTabChange }: Props) {
  return (
    <div className="grid grid-cols-1 gap-3 min-[450px]:grid-cols-2 sm:grid-cols-3 lg:grid-cols-6">
      {TILES.map((tile) => {
        const rawValue = summary[tile.key];
        const value = tile.format ? tile.format(rawValue) : rawValue;
        // The highlight is otherwise purely visual, so the same flag also
        // drives `aria-pressed` below.
        const isActive = activeTab === tile.filter;

        return (
          <button
            key={tile.label}
            type="button"
            aria-pressed={isActive}
            onClick={() => onTabChange(tile.filter)}
            className={cn(
              "flex min-w-0 flex-col gap-1 rounded-medium border p-3 text-left shadow-md transition-all hover:shadow-lg",
              isActive
                ? "border-zinc-900 bg-zinc-50"
                : "border-zinc-100 bg-white",
            )}
          >
            <div className="flex min-w-0 items-center gap-1.5">
              <Emoji text={tile.emoji} size={18} />
              <OverflowText
                value={tile.label}
                variant="body"
                className="text-zinc-800"
              />
            </div>
            <Text variant="h4">{value}</Text>
          </button>
        );
      })}
    </div>
  );
}
|
||||
@@ -0,0 +1,52 @@
|
||||
"use client";
|
||||
|
||||
import type { SelectOption } from "@/components/atoms/Select/Select";
|
||||
import { Select } from "@/components/atoms/Select/Select";
|
||||
import { FunnelIcon } from "@phosphor-icons/react";
|
||||
import type { AgentStatusFilter, FleetSummary } from "../../types";
|
||||
|
||||
/** Props accepted by `AgentFilterMenu`. */
interface Props {
  /** Currently selected status filter. */
  value: AgentStatusFilter;
  /** Invoked with the newly chosen filter. */
  onChange: (value: AgentStatusFilter) => void;
  /** Fleet totals used to show per-status counts in the option labels. */
  summary: FleetSummary;
}
|
||||
|
||||
/**
 * Builds the dropdown options for the agent status filter.
 *
 * Labels for count-bearing statuses embed the matching `summary` field.
 * The list includes `"completed"` so that a filter value selected elsewhere
 * (the stats tiles emit it) always has a matching option here.
 *
 * @param summary Fleet totals used for the counts in the labels.
 * @returns Options in display order, starting with "All Agents".
 */
function buildOptions(summary: FleetSummary): SelectOption[] {
  return [
    { value: "all", label: "All Agents" },
    { value: "running", label: `Running (${summary.running})` },
    { value: "completed", label: `Recently Completed (${summary.completed})` },
    { value: "attention", label: `Needs Attention (${summary.error})` },
    { value: "listening", label: `Listening (${summary.listening})` },
    { value: "scheduled", label: `Scheduled (${summary.scheduled})` },
    { value: "idle", label: `Idle / Stale (${summary.idle})` },
    { value: "healthy", label: "Healthy" },
  ];
}
|
||||
|
||||
/**
 * Compact dropdown for choosing the agent status filter.
 *
 * Shows a "filter" caption on `sm` screens and up, and a funnel icon below
 * that breakpoint. The `Select` reports plain strings, which are narrowed to
 * `AgentStatusFilter` before being handed to `onChange`.
 */
export function AgentFilterMenu({ value, onChange, summary }: Props) {
  const options = buildOptions(summary);

  const handleChange = (next: string) => {
    onChange(next as AgentStatusFilter);
  };

  return (
    <div className="flex items-center" data-testid="agent-filter-dropdown">
      <span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
        filter
      </span>
      <FunnelIcon className="ml-1 h-4 w-4 sm:hidden" />
      <Select
        id="agent-status-filter"
        label="Filter agents"
        hideLabel
        value={value}
        onValueChange={handleChange}
        options={options}
        size="small"
        className="ml-1 w-fit border-none !bg-transparent text-sm underline underline-offset-4 shadow-none"
        wrapperClassName="mb-0"
      />
    </div>
  );
}
|
||||
@@ -0,0 +1,68 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
EyeIcon,
|
||||
ArrowsClockwiseIcon,
|
||||
MonitorPlayIcon,
|
||||
PlayIcon,
|
||||
ArrowCounterClockwiseIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { AgentStatus } from "../../types";
|
||||
|
||||
/** Props accepted by `ContextualActionButton`. */
interface Props {
  /** Agent status; selects the label and icon from `ACTION_CONFIG`. */
  status: AgentStatus;
  /** Library agent id used to build the `/library/agents/...` target path. */
  agentID: string;
  /** When set, added to the target URL as the `activeItem` query parameter. */
  executionID?: string;
  /** Extra class names merged after the button's default classes. */
  className?: string;
}
|
||||
|
||||
/**
 * Small button whose label and icon depend on the agent's status.
 *
 * Clicking navigates to the agent's library page, adding
 * `?activeItem=<executionID>` when an execution id is supplied. Renders
 * nothing when `ACTION_CONFIG` has no entry for the status.
 */
export function ContextualActionButton({
  status,
  agentID,
  executionID,
  className,
}: Props) {
  // Hook must run before the early return below so hook order stays stable.
  const router = useRouter();

  const action = ACTION_CONFIG[status];
  if (!action) return null;

  const { icon: ActionIcon, label } = action;

  function handleClick(event: React.MouseEvent) {
    // Keep the click from also triggering an enclosing handler or link.
    event.preventDefault();
    event.stopPropagation();

    const search = new URLSearchParams();
    if (executionID) search.set("activeItem", executionID);

    const queryString = search.toString();
    const suffix = queryString ? `?${queryString}` : "";
    router.push(`/library/agents/${agentID}${suffix}`);
  }

  return (
    <button
      type="button"
      onClick={handleClick}
      className={cn(
        "inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800",
        className,
      )}
    >
      <ActionIcon size={12} className="shrink-0" />
      {label}
    </button>
  );
}
|
||||
|
||||
/** Label and icon shown by `ContextualActionButton` for one status. */
type ActionDescriptor = { label: string; icon: typeof EyeIcon };

/** Per-status button configuration, keyed by `AgentStatus`. */
const ACTION_CONFIG: Record<AgentStatus, ActionDescriptor> = {
  error: { label: "View error", icon: EyeIcon },
  listening: { label: "Reconnect", icon: ArrowsClockwiseIcon },
  running: { label: "Watch live", icon: MonitorPlayIcon },
  idle: { label: "Start", icon: PlayIcon },
  scheduled: { label: "Start", icon: ArrowCounterClockwiseIcon },
};
|
||||
@@ -1,46 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { ArrowRight, Lightning } from "@phosphor-icons/react";
|
||||
import NextLink from "next/link";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useJumpBackIn } from "./useJumpBackIn";
|
||||
|
||||
/**
 * Banner linking to the execution returned by `useJumpBackIn`.
 *
 * Renders nothing while the hook reports loading or when it returns no
 * execution. The link points at the agent's runs tab with the execution
 * selected, or at "#" when the execution has no library agent id.
 */
export function JumpBackIn() {
  const { execution, isLoading } = useJumpBackIn();

  if (isLoading || !execution) {
    return null;
  }

  const { id, libraryAgentId, statusLabel, duration, agentName } = execution;

  let href = "#";
  if (libraryAgentId) {
    href = `/library/agents/${libraryAgentId}?activeTab=runs&activeItem=${id}`;
  }

  return (
    <div className="rounded-large bg-gradient-to-r from-zinc-200 via-zinc-200/60 to-indigo-200/50 p-[1px]">
      <div className="flex items-center justify-between rounded-large bg-[#F6F7F8] px-5 py-4">
        <div className="flex items-center gap-3">
          <div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
            <Lightning size={18} weight="fill" className="text-white" />
          </div>
          <div className="flex flex-col">
            <Text variant="small" className="text-zinc-500">
              {statusLabel} · {duration}
            </Text>
            <Text variant="body-medium" className="text-zinc-900">
              {agentName}
            </Text>
          </div>
        </div>
        <NextLink href={href}>
          <Button variant="secondary" size="small" className="gap-1.5">
            Jump Back In
            <ArrowRight size={16} />
          </Button>
        </NextLink>
      </div>
    </div>
  );
}
|
||||
@@ -1,82 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
|
||||
import { useMemo } from "react";
|
||||
|
||||
/** Returns true for the RUNNING, QUEUED and REVIEW execution statuses. */
function isActive(status: AgentExecutionStatus) {
  switch (status) {
    case AgentExecutionStatus.RUNNING:
    case AgentExecutionStatus.QUEUED:
    case AgentExecutionStatus.REVIEW:
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/**
 * Formats the time elapsed since `startedAt` as a short label.
 *
 * Returns "" when the start is missing, unparsable, or in the future.
 * Otherwise: "a few seconds" (< 5 s), "Ns" (< 1 min), "Nm Ns" (< 1 h),
 * or "Nh Nm".
 */
function formatDuration(startedAt: Date | string | null | undefined): string {
  if (!startedAt) return "";

  const startMs = new Date(startedAt).getTime();
  if (Number.isNaN(startMs)) return "";

  const elapsedMs = Date.now() - startMs;
  if (elapsedMs < 0) return "";

  const totalSeconds = Math.floor(elapsedMs / 1000);
  if (totalSeconds < 5) return "a few seconds";
  if (totalSeconds < 60) return `${totalSeconds}s`;

  const totalMinutes = Math.floor(totalSeconds / 60);
  if (totalMinutes < 60) return `${totalMinutes}m ${totalSeconds % 60}s`;

  const hours = Math.floor(totalMinutes / 60);
  return `${hours}h ${totalMinutes % 60}m`;
}
|
||||
|
||||
/**
 * Maps an active execution status to its display label.
 * Any status other than RUNNING, QUEUED or REVIEW yields "".
 */
function getStatusLabel(status: AgentExecutionStatus) {
  switch (status) {
    case AgentExecutionStatus.RUNNING:
      return "Running";
    case AgentExecutionStatus.QUEUED:
      return "Queued";
    case AgentExecutionStatus.REVIEW:
      return "Awaiting approval";
    default:
      return "";
  }
}
|
||||
|
||||
/**
 * Picks the most recently started active execution and enriches it with the
 * owning agent's name and library id.
 *
 * @returns `execution` (or `null` when there is no active execution) and an
 * `isLoading` flag that is true while either the executions query is loading
 * or the library agents are refreshing.
 */
export function useJumpBackIn() {
  const { data: allExecutions, isLoading: isLoadingExecutions } =
    useGetV1ListAllExecutions({
      query: { select: okData },
    });

  const { agentInfoMap, isRefreshing: isRefreshingAgents } =
    useLibraryAgents();

  const latestActive = useMemo(() => {
    if (!allExecutions) return null;

    // Executions without a start time sort as if started at the epoch.
    const startedMs = (startedAt: (typeof allExecutions)[number]["started_at"]) =>
      startedAt ? new Date(startedAt).getTime() : 0;

    // `filter` yields a fresh array, so the in-place sort never touches the
    // query cache. Newest first.
    const candidates = allExecutions.filter((e) => isActive(e.status));
    candidates.sort((a, b) => startedMs(b.started_at) - startedMs(a.started_at));

    return candidates[0] ?? null;
  }, [allExecutions]);

  const execution = useMemo(() => {
    if (!latestActive) return null;

    const agentInfo = agentInfoMap.get(latestActive.graph_id);
    const { id, status, started_at } = latestActive;

    return {
      id,
      agentName: agentInfo?.name ?? "Unknown Agent",
      libraryAgentId: agentInfo?.library_agent_id,
      status,
      statusLabel: getStatusLabel(status),
      duration: formatDuration(started_at),
    };
  }, [latestActive, agentInfoMap]);

  return {
    execution,
    isLoading: isLoadingExecutions || isRefreshingAgents,
  };
}
|
||||
@@ -8,7 +8,7 @@ interface Props {
|
||||
export function LibraryActionHeader({ setSearchTerm }: Props) {
|
||||
return (
|
||||
<>
|
||||
<div className="mb-[32px] hidden items-center justify-center gap-4 md:flex">
|
||||
<div className="mb-7 hidden items-center justify-center gap-4 md:flex">
|
||||
<LibrarySearchBar setSearchTerm={setSearchTerm} />
|
||||
<LibraryImportDialog />
|
||||
</div>
|
||||
|
||||
@@ -1,29 +1,40 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CaretCircleRightIcon } from "@phosphor-icons/react";
|
||||
import { EyeIcon, ChatCircleDotsIcon } from "@phosphor-icons/react";
|
||||
import Image from "next/image";
|
||||
import NextLink from "next/link";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { motion } from "framer-motion";
|
||||
|
||||
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import Avatar, {
|
||||
AvatarFallback,
|
||||
AvatarImage,
|
||||
} from "@/components/atoms/Avatar/Avatar";
|
||||
import { Link } from "@/components/atoms/Link/Link";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { AgentCardMenu } from "./components/AgentCardMenu";
|
||||
import { FavoriteButton } from "./components/FavoriteButton";
|
||||
import { useLibraryAgentCard } from "./useLibraryAgentCard";
|
||||
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
|
||||
import { StatusBadge } from "../StatusBadge/StatusBadge";
|
||||
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
|
||||
import type { AgentStatusInfo } from "../../types";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
|
||||
interface Props {
|
||||
agent: LibraryAgent;
|
||||
statusInfo: AgentStatusInfo;
|
||||
draggable?: boolean;
|
||||
}
|
||||
|
||||
export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
const { id, name, graph_id, can_access_graph, image_url } = agent;
|
||||
export function LibraryAgentCard({
|
||||
agent,
|
||||
statusInfo,
|
||||
draggable = true,
|
||||
}: Props) {
|
||||
const { id, name, image_url } = agent;
|
||||
const router = useRouter();
|
||||
const { triggerFavoriteAnimation } = useFavoriteAnimation();
|
||||
|
||||
function handleDragStart(e: React.DragEvent<HTMLDivElement>) {
|
||||
@@ -31,18 +42,14 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
e.dataTransfer.effectAllowed = "move";
|
||||
}
|
||||
|
||||
const {
|
||||
isFromMarketplace,
|
||||
isFavorite,
|
||||
profile,
|
||||
creator_image_url,
|
||||
handleToggleFavorite,
|
||||
} = useLibraryAgentCard({
|
||||
const { isFavorite, handleToggleFavorite } = useLibraryAgentCard({
|
||||
agent,
|
||||
onFavoriteAdd: triggerFavoriteAnimation,
|
||||
});
|
||||
|
||||
return (
|
||||
const hasError = statusInfo.status === "error";
|
||||
|
||||
const card = (
|
||||
<div
|
||||
draggable={draggable}
|
||||
onDragStart={handleDragStart}
|
||||
@@ -52,7 +59,10 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
layoutId={`agent-card-${id}`}
|
||||
data-testid="library-agent-card"
|
||||
data-agent-id={id}
|
||||
className="group relative inline-flex h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border border-zinc-100 bg-white hover:shadow-md"
|
||||
className={cn(
|
||||
"group relative inline-flex h-auto min-h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border bg-white hover:shadow-md",
|
||||
hasError ? "border-red-400" : "border-zinc-100",
|
||||
)}
|
||||
transition={{
|
||||
type: "spring",
|
||||
damping: 25,
|
||||
@@ -61,23 +71,10 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
style={{ willChange: "transform" }}
|
||||
>
|
||||
<NextLink href={`/library/agents/${id}`} className="flex-shrink-0">
|
||||
<div className="relative flex items-center gap-2 px-4 pt-3">
|
||||
<Avatar className="h-4 w-4 rounded-full">
|
||||
<AvatarImage
|
||||
src={
|
||||
isFromMarketplace
|
||||
? creator_image_url || "/avatar-placeholder.png"
|
||||
: profile?.avatar_url || "/avatar-placeholder.png"
|
||||
}
|
||||
alt={`${name} creator avatar`}
|
||||
/>
|
||||
<AvatarFallback size={48}>{name.charAt(0)}</AvatarFallback>
|
||||
</Avatar>
|
||||
<Text
|
||||
variant="small-medium"
|
||||
className="uppercase tracking-wide text-zinc-400"
|
||||
>
|
||||
{isFromMarketplace ? "FROM MARKETPLACE" : "Built by you"}
|
||||
<div className="relative flex items-center gap-3 pl-2 pr-4 pt-3">
|
||||
<StatusBadge status={statusInfo.status} />
|
||||
<Text variant="small" className="text-zinc-400">
|
||||
{statusInfo.totalRuns} tasks
|
||||
</Text>
|
||||
</div>
|
||||
</NextLink>
|
||||
@@ -89,7 +86,7 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
<AgentCardMenu agent={agent} />
|
||||
|
||||
<div className="flex w-full flex-1 flex-col px-4 pb-2">
|
||||
<Link
|
||||
<NextLink
|
||||
href={`/library/agents/${id}`}
|
||||
className="flex w-full items-start justify-between gap-2 no-underline hover:no-underline focus:ring-0"
|
||||
>
|
||||
@@ -126,30 +123,52 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
className="flex-shrink-0 rounded-small object-cover"
|
||||
/>
|
||||
)}
|
||||
</Link>
|
||||
</NextLink>
|
||||
|
||||
<div className="mt-auto flex w-full justify-start gap-6 border-t border-zinc-100 pb-1 pt-3">
|
||||
<Link
|
||||
href={`/library/agents/${id}`}
|
||||
<div className="mt-4 flex w-full items-center justify-end gap-1 border-t border-zinc-100 pb-0 pt-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => router.push(`/library/agents/${id}`)}
|
||||
data-testid="library-agent-card-see-runs-link"
|
||||
className="flex items-center gap-1 text-[13px]"
|
||||
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
|
||||
>
|
||||
See runs <CaretCircleRightIcon size={20} />
|
||||
</Link>
|
||||
|
||||
{can_access_graph && (
|
||||
<Link
|
||||
href={`/build?flowID=${graph_id}`}
|
||||
data-testid="library-agent-card-open-in-builder-link"
|
||||
className="flex items-center gap-1 text-[13px]"
|
||||
isExternal
|
||||
>
|
||||
Open in builder <CaretCircleRightIcon size={20} />
|
||||
</Link>
|
||||
)}
|
||||
<EyeIcon size={14} className="shrink-0" />
|
||||
See tasks
|
||||
</button>
|
||||
<ContextualActionButton
|
||||
status={statusInfo.status}
|
||||
agentID={id}
|
||||
executionID={statusInfo.activeExecutionID ?? undefined}
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
const prompt = encodeURIComponent(
|
||||
`Tell me about ${name}, its current status, recent runs and how can I get the most out of it`,
|
||||
);
|
||||
router.push(`/copilot?autosubmit=true#prompt=${prompt}`);
|
||||
}}
|
||||
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
|
||||
>
|
||||
<ChatCircleDotsIcon size={14} className="shrink-0" />
|
||||
Chat
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
</div>
|
||||
);
|
||||
|
||||
if (hasError && statusInfo.lastError) {
|
||||
return (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>{card}</TooltipTrigger>
|
||||
<TooltipContent className="max-w-xs text-red-600">
|
||||
{statusInfo.lastError}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
return card;
|
||||
}
|
||||
|
||||
@@ -169,6 +169,7 @@ export function AgentCardMenu({ agent }: AgentCardMenuProps) {
|
||||
href={`/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`}
|
||||
target="_blank"
|
||||
className="flex items-center gap-2"
|
||||
data-testid="library-agent-card-open-in-builder-link"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
Edit agent
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
|
||||
import { LibraryAgentCard } from "../LibraryAgentCard/LibraryAgentCard";
|
||||
@@ -16,8 +17,11 @@ import {
|
||||
} from "framer-motion";
|
||||
import { LibraryFolderEditDialog } from "../LibraryFolderEditDialog/LibraryFolderEditDialog";
|
||||
import { LibraryFolderDeleteDialog } from "../LibraryFolderDeleteDialog/LibraryFolderDeleteDialog";
|
||||
import { LibraryTab } from "../../types";
|
||||
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
|
||||
import { useLibraryAgentList } from "./useLibraryAgentList";
|
||||
import { AgentBriefingPanel } from "../AgentBriefingPanel/AgentBriefingPanel";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useAgentStatusMap, getAgentStatus } from "../../hooks/useAgentStatus";
|
||||
|
||||
// cancels the current spring and starts a new one from current state.
|
||||
const containerVariants = {
|
||||
@@ -70,6 +74,10 @@ interface Props {
|
||||
tabs: LibraryTab[];
|
||||
activeTab: string;
|
||||
onTabChange: (tabId: string) => void;
|
||||
statusFilter?: AgentStatusFilter;
|
||||
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
|
||||
fleetSummary?: FleetSummary;
|
||||
briefingAgents?: LibraryAgent[];
|
||||
}
|
||||
|
||||
export function LibraryAgentList({
|
||||
@@ -81,7 +89,12 @@ export function LibraryAgentList({
|
||||
tabs,
|
||||
activeTab,
|
||||
onTabChange,
|
||||
statusFilter = "all",
|
||||
onStatusFilterChange,
|
||||
fleetSummary,
|
||||
briefingAgents,
|
||||
}: Props) {
|
||||
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
|
||||
const shouldReduceMotion = useReducedMotion();
|
||||
const activeContainerVariants = shouldReduceMotion
|
||||
? reducedContainerVariants
|
||||
@@ -95,7 +108,7 @@ export function LibraryAgentList({
|
||||
const {
|
||||
isFavoritesTab,
|
||||
agentLoading,
|
||||
allAgentsCount,
|
||||
displayedCount,
|
||||
favoritesCount,
|
||||
agents,
|
||||
hasNextPage,
|
||||
@@ -116,18 +129,37 @@ export function LibraryAgentList({
|
||||
selectedFolderId,
|
||||
onFolderSelect,
|
||||
activeTab,
|
||||
statusFilter,
|
||||
});
|
||||
|
||||
const agentStatusMap = useAgentStatusMap(agents);
|
||||
|
||||
return (
|
||||
<>
|
||||
{isAgentBriefingEnabled &&
|
||||
!selectedFolderId &&
|
||||
fleetSummary &&
|
||||
briefingAgents &&
|
||||
briefingAgents.length > 0 && (
|
||||
<div className="mb-4">
|
||||
<AgentBriefingPanel
|
||||
summary={fleetSummary}
|
||||
agents={briefingAgents}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!selectedFolderId && (
|
||||
<LibrarySubSection
|
||||
tabs={tabs}
|
||||
activeTab={activeTab}
|
||||
onTabChange={onTabChange}
|
||||
allCount={allAgentsCount}
|
||||
allCount={displayedCount}
|
||||
favoritesCount={favoritesCount}
|
||||
setLibrarySort={setLibrarySort}
|
||||
statusFilter={statusFilter}
|
||||
onStatusFilterChange={onStatusFilterChange}
|
||||
fleetSummary={fleetSummary}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -219,7 +251,13 @@ export function LibraryAgentList({
|
||||
0.04,
|
||||
}}
|
||||
>
|
||||
<LibraryAgentCard agent={agent} />
|
||||
<LibraryAgentCard
|
||||
agent={agent}
|
||||
statusInfo={getAgentStatus(
|
||||
agentStatusMap,
|
||||
agent.graph_id,
|
||||
)}
|
||||
/>
|
||||
</motion.div>
|
||||
))}
|
||||
</motion.div>
|
||||
|
||||
@@ -21,7 +21,12 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useFavoriteAgents } from "../../hooks/useFavoriteAgents";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import type { AgentStatusFilter } from "../../types";
|
||||
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
|
||||
const FILTER_EXHAUST_THRESHOLD = 3;
|
||||
|
||||
interface Props {
|
||||
searchTerm: string;
|
||||
@@ -29,6 +34,7 @@ interface Props {
|
||||
selectedFolderId: string | null;
|
||||
onFolderSelect: (folderId: string | null) => void;
|
||||
activeTab: string;
|
||||
statusFilter?: AgentStatusFilter;
|
||||
}
|
||||
|
||||
export function useLibraryAgentList({
|
||||
@@ -37,12 +43,16 @@ export function useLibraryAgentList({
|
||||
selectedFolderId,
|
||||
onFolderSelect,
|
||||
activeTab,
|
||||
statusFilter = "all",
|
||||
}: Props) {
|
||||
const isFavoritesTab = activeTab === "favorites";
|
||||
const { toast } = useToast();
|
||||
const stableQueryClient = getQueryClient();
|
||||
const queryClient = useQueryClient();
|
||||
const prevSortRef = useRef<LibraryAgentSort | null>(null);
|
||||
const [consecutiveEmptyPages, setConsecutiveEmptyPages] = useState(0);
|
||||
const prevFilteredLengthRef = useRef(0);
|
||||
const prevAgentsLengthRef = useRef(0);
|
||||
|
||||
const [editingFolder, setEditingFolder] = useState<LibraryFolder | null>(
|
||||
null,
|
||||
@@ -199,6 +209,90 @@ export function useLibraryAgentList({
|
||||
|
||||
const showFolders = !isFavoritesTab;
|
||||
|
||||
const { data: executions } = useGetV1ListAllExecutions({
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
const { activeGraphIds, errorGraphIds, completedGraphIds } = useMemo(() => {
|
||||
const active = new Set<string>();
|
||||
const errors = new Set<string>();
|
||||
const completed = new Set<string>();
|
||||
const cutoff = Date.now() - 72 * 60 * 60 * 1000;
|
||||
for (const exec of executions ?? []) {
|
||||
if (
|
||||
exec.status === AgentExecutionStatus.RUNNING ||
|
||||
exec.status === AgentExecutionStatus.QUEUED ||
|
||||
exec.status === AgentExecutionStatus.REVIEW
|
||||
) {
|
||||
active.add(exec.graph_id);
|
||||
}
|
||||
const endedTs = exec.ended_at
|
||||
? exec.ended_at instanceof Date
|
||||
? exec.ended_at.getTime()
|
||||
: new Date(String(exec.ended_at)).getTime()
|
||||
: 0;
|
||||
if (
|
||||
(exec.status === AgentExecutionStatus.FAILED ||
|
||||
exec.status === AgentExecutionStatus.TERMINATED) &&
|
||||
endedTs > cutoff
|
||||
) {
|
||||
errors.add(exec.graph_id);
|
||||
}
|
||||
if (exec.status === AgentExecutionStatus.COMPLETED && endedTs > cutoff) {
|
||||
completed.add(exec.graph_id);
|
||||
}
|
||||
}
|
||||
return {
|
||||
activeGraphIds: active,
|
||||
errorGraphIds: errors,
|
||||
completedGraphIds: completed,
|
||||
};
|
||||
}, [executions]);
|
||||
|
||||
const filteredAgents = filterAgentsByStatus(
|
||||
agents,
|
||||
statusFilter,
|
||||
activeGraphIds,
|
||||
errorGraphIds,
|
||||
completedGraphIds,
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (statusFilter === "all") {
|
||||
setConsecutiveEmptyPages(0);
|
||||
prevFilteredLengthRef.current = filteredAgents.length;
|
||||
prevAgentsLengthRef.current = agents.length;
|
||||
return;
|
||||
}
|
||||
|
||||
if (agents.length > prevAgentsLengthRef.current) {
|
||||
const newFilteredCount = filteredAgents.length;
|
||||
const previousCount = prevFilteredLengthRef.current;
|
||||
|
||||
if (newFilteredCount > previousCount) {
|
||||
setConsecutiveEmptyPages(0);
|
||||
} else {
|
||||
setConsecutiveEmptyPages((prev) => prev + 1);
|
||||
}
|
||||
}
|
||||
|
||||
prevAgentsLengthRef.current = agents.length;
|
||||
prevFilteredLengthRef.current = filteredAgents.length;
|
||||
}, [agents.length, filteredAgents.length, statusFilter]);
|
||||
|
||||
useEffect(() => {
|
||||
setConsecutiveEmptyPages(0);
|
||||
prevFilteredLengthRef.current = 0;
|
||||
prevAgentsLengthRef.current = 0;
|
||||
}, [statusFilter]);
|
||||
|
||||
const filteredExhausted =
|
||||
statusFilter !== "all" && consecutiveEmptyPages >= FILTER_EXHAUST_THRESHOLD;
|
||||
|
||||
// When a filter is active, show the filtered count instead of the API total.
|
||||
const displayedCount =
|
||||
statusFilter === "all" ? allAgentsCount : filteredAgents.length;
|
||||
|
||||
function handleFolderDeleted() {
|
||||
if (selectedFolderId === deletingFolder?.id) {
|
||||
onFolderSelect(null);
|
||||
@@ -210,9 +304,10 @@ export function useLibraryAgentList({
|
||||
agentLoading,
|
||||
agentCount,
|
||||
allAgentsCount,
|
||||
displayedCount,
|
||||
favoritesCount: favoriteAgentsData.agentCount,
|
||||
agents,
|
||||
hasNextPage: agentsHasNextPage,
|
||||
agents: filteredAgents,
|
||||
hasNextPage: agentsHasNextPage && !filteredExhausted,
|
||||
isFetchingNextPage: agentsIsFetchingNextPage,
|
||||
fetchNextPage: agentsFetchNextPage,
|
||||
foldersData,
|
||||
@@ -226,3 +321,46 @@ export function useLibraryAgentList({
|
||||
handleFolderDeleted,
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Narrows `agents` to those matching `statusFilter`.
 *
 * Status is derived from the three graph-id sets plus each agent's trigger
 * and schedule fields:
 * - "all": the input array is returned unchanged (same reference).
 * - "running": graph id is in `activeGraphIds`.
 * - "attention": graph id is in `errorGraphIds` and not running.
 * - "completed": graph id is in `completedGraphIds`.
 * - "listening" / "scheduled" / "idle": neither running nor errored, then
 *   split by external trigger, recommended cron, or neither.
 * - "healthy": graph id is not in `errorGraphIds`.
 * Any other filter value keeps every agent.
 */
function filterAgentsByStatus<
  T extends {
    graph_id: string;
    has_external_trigger: boolean;
    recommended_schedule_cron?: string | null;
  },
>(
  agents: T[],
  statusFilter: AgentStatusFilter,
  activeGraphIds: Set<string>,
  errorGraphIds: Set<string>,
  completedGraphIds: Set<string>,
): T[] {
  if (statusFilter === "all") return agents;

  return agents.filter((agent) => {
    const isRunning = activeGraphIds.has(agent.graph_id);
    const hasError = errorGraphIds.has(agent.graph_id);
    // Shared precondition of the listening / scheduled / idle buckets.
    const quiet = !isRunning && !hasError;

    switch (statusFilter) {
      case "running":
        return isRunning;
      case "attention":
        return hasError && !isRunning;
      case "completed":
        return completedGraphIds.has(agent.graph_id);
      case "listening":
        return quiet && agent.has_external_trigger;
      case "scheduled":
        return (
          quiet &&
          !agent.has_external_trigger &&
          !!agent.recommended_schedule_cron
        );
      case "idle":
        return (
          quiet &&
          !agent.has_external_trigger &&
          !agent.recommended_schedule_cron
        );
      case "healthy":
        return !hasError;
      default:
        return true;
    }
  });
}
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
FolderIcon,
|
||||
FolderColor,
|
||||
folderCardStyles,
|
||||
resolveColor,
|
||||
} from "./FolderIcon";
|
||||
import { FolderIcon, FolderColor } from "./FolderIcon";
|
||||
import { useState } from "react";
|
||||
import { PencilSimpleIcon, TrashIcon } from "@phosphor-icons/react";
|
||||
import type { AgentStatus } from "../../types";
|
||||
import { StatusBadge } from "../StatusBadge/StatusBadge";
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
@@ -21,6 +18,8 @@ interface Props {
|
||||
onDelete?: () => void;
|
||||
onAgentDrop?: (agentId: string, folderId: string) => void;
|
||||
onClick?: () => void;
|
||||
/** Worst status among child agents (optional, for status aggregation). */
|
||||
worstStatus?: AgentStatus;
|
||||
}
|
||||
|
||||
export function LibraryFolder({
|
||||
@@ -33,11 +32,10 @@ export function LibraryFolder({
|
||||
onDelete,
|
||||
onAgentDrop,
|
||||
onClick,
|
||||
worstStatus,
|
||||
}: Props) {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const [isDragOver, setIsDragOver] = useState(false);
|
||||
const resolvedColor = resolveColor(color);
|
||||
const cardStyle = folderCardStyles[resolvedColor];
|
||||
|
||||
function handleDragOver(e: React.DragEvent<HTMLDivElement>) {
|
||||
if (e.dataTransfer.types.includes("application/agent-id")) {
|
||||
@@ -64,10 +62,10 @@ export function LibraryFolder({
|
||||
<div
|
||||
data-testid="library-folder"
|
||||
data-folder-id={id}
|
||||
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 transition-all duration-200 hover:shadow-md ${
|
||||
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 shadow-sm backdrop-blur-md transition-all duration-200 hover:shadow-md ${
|
||||
isDragOver
|
||||
? "border-blue-400 bg-blue-50 ring-2 ring-blue-200"
|
||||
: `${cardStyle.border} ${cardStyle.bg}`
|
||||
: "border-indigo-200/40 bg-gradient-to-br from-indigo-50/40 via-white/70 to-purple-50/30"
|
||||
}`}
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
@@ -76,7 +74,7 @@ export function LibraryFolder({
|
||||
onDrop={handleDrop}
|
||||
onClick={onClick}
|
||||
>
|
||||
<div className="flex w-full items-start justify-between gap-4">
|
||||
<div className="flex w-full items-center justify-between gap-4">
|
||||
{/* Left side - Folder name and agent count */}
|
||||
<div className="flex flex-1 flex-col gap-2">
|
||||
<Text
|
||||
@@ -86,17 +84,22 @@ export function LibraryFolder({
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
<Text
|
||||
variant="small"
|
||||
className="text-zinc-500"
|
||||
data-testid="library-folder-agent-count"
|
||||
>
|
||||
{agentCount} {agentCount === 1 ? "agent" : "agents"}
|
||||
</Text>
|
||||
<div className="flex items-center gap-2">
|
||||
<Text
|
||||
variant="small"
|
||||
className="text-zinc-500"
|
||||
data-testid="library-folder-agent-count"
|
||||
>
|
||||
{agentCount} {agentCount === 1 ? "agent" : "agents"}
|
||||
</Text>
|
||||
{worstStatus && worstStatus !== "idle" && (
|
||||
<StatusBadge status={worstStatus} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right side - Custom folder icon */}
|
||||
<div className="flex-shrink-0">
|
||||
<div className="relative top-5 flex flex-shrink-0 items-center">
|
||||
<FolderIcon isOpen={isHovered} color={color} icon={icon} />
|
||||
</div>
|
||||
</div>
|
||||
@@ -114,7 +117,7 @@ export function LibraryFolder({
|
||||
e.stopPropagation();
|
||||
onEdit?.();
|
||||
}}
|
||||
className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
|
||||
className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
|
||||
>
|
||||
<PencilSimpleIcon className="h-4 w-4" />
|
||||
</Button>
|
||||
@@ -126,7 +129,7 @@ export function LibraryFolder({
|
||||
e.stopPropagation();
|
||||
onDelete?.();
|
||||
}}
|
||||
className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
|
||||
className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
|
||||
>
|
||||
<TrashIcon className="h-4 w-4" />
|
||||
</Button>
|
||||
|
||||
@@ -19,11 +19,11 @@ export function LibrarySortMenu({ setLibrarySort }: Props) {
|
||||
const { handleSortChange } = useLibrarySortMenu({ setLibrarySort });
|
||||
return (
|
||||
<div className="flex items-center" data-testid="sort-by-dropdown">
|
||||
<span className="hidden whitespace-nowrap text-sm sm:inline">
|
||||
<span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
|
||||
sort by
|
||||
</span>
|
||||
<Select onValueChange={handleSortChange}>
|
||||
<SelectTrigger className="ml-1 w-fit space-x-1 border-none px-0 text-sm underline underline-offset-4 shadow-none">
|
||||
<SelectTrigger className="!m-0 ml-1 w-fit space-x-1 border-none !bg-transparent px-[1rem] text-sm underline underline-offset-4 !shadow-none !ring-offset-transparent">
|
||||
<ArrowDownNarrowWideIcon className="h-4 w-4 sm:hidden" />
|
||||
<SelectValue placeholder="Last Modified" />
|
||||
</SelectTrigger>
|
||||
|
||||
@@ -6,9 +6,10 @@ import {
|
||||
} from "@/components/molecules/TabsLine/TabsLine";
|
||||
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
|
||||
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
|
||||
import { LibraryTab } from "../../types";
|
||||
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
|
||||
import LibraryFolderCreationDialog from "../LibraryFolderCreationDialog/LibraryFolderCreationDialog";
|
||||
import { LibrarySortMenu } from "../LibrarySortMenu/LibrarySortMenu";
|
||||
import { AgentFilterMenu } from "../AgentFilterMenu/AgentFilterMenu";
|
||||
|
||||
interface Props {
|
||||
tabs: LibraryTab[];
|
||||
@@ -17,6 +18,9 @@ interface Props {
|
||||
allCount: number;
|
||||
favoritesCount: number;
|
||||
setLibrarySort: (value: LibraryAgentSort) => void;
|
||||
statusFilter?: AgentStatusFilter;
|
||||
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
|
||||
fleetSummary?: FleetSummary;
|
||||
}
|
||||
|
||||
export function LibrarySubSection({
|
||||
@@ -26,6 +30,9 @@ export function LibrarySubSection({
|
||||
allCount,
|
||||
favoritesCount,
|
||||
setLibrarySort,
|
||||
statusFilter = "all",
|
||||
onStatusFilterChange,
|
||||
fleetSummary,
|
||||
}: Props) {
|
||||
const { registerFavoritesTabRef } = useFavoriteAnimation();
|
||||
const favoritesRef = useRef<HTMLButtonElement>(null);
|
||||
@@ -68,8 +75,15 @@ export function LibrarySubSection({
|
||||
))}
|
||||
</TabsLineList>
|
||||
</TabsLine>
|
||||
<div className="hidden items-center gap-6 md:flex">
|
||||
<div className="relative top-1.5 hidden items-center gap-6 md:flex">
|
||||
<LibraryFolderCreationDialog />
|
||||
{fleetSummary && onStatusFilterChange && (
|
||||
<AgentFilterMenu
|
||||
value={statusFilter}
|
||||
onChange={onStatusFilterChange}
|
||||
summary={fleetSummary}
|
||||
/>
|
||||
)}
|
||||
<LibrarySortMenu setLibrarySort={setLibrarySort} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
/*
 * Ring-shaped loading indicator drawn entirely with gradients.
 * It paints in `currentColor`, so the colour comes from the element's
 * text colour, and `aspect-ratio: 1` keeps it square.
 */
.spinner {
  aspect-ratio: 1;
  border-radius: 50%;
  /* Layer 1: a 3px dot anchored at the top. Layer 2: a conic sweep that is
     transparent for the first 30% and fades into currentColor. */
  background:
    radial-gradient(farthest-side, currentColor 94%, #0000) top/3px 3px
      no-repeat,
    conic-gradient(#0000 30%, currentColor);
  /* The mask is transparent up to (100% - 3px) of the radius, leaving only
     the outer 3px band of the background visible. */
  -webkit-mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
  mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
  animation: spin 1s infinite linear;
}

/* One full clockwise turn per animation cycle. */
@keyframes spin {
  to {
    transform: rotate(1turn);
  }
}
|
||||
@@ -0,0 +1,172 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
WarningCircleIcon,
|
||||
ClockCountdownIcon,
|
||||
CheckCircleIcon,
|
||||
ChatCircleDotsIcon,
|
||||
EarIcon,
|
||||
CalendarDotsIcon,
|
||||
MoonIcon,
|
||||
EyeIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import NextLink from "next/link";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { SitrepItemData, SitrepPriority } from "../../types";
|
||||
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
|
||||
import styles from "./SitrepItem.module.css";
|
||||
|
||||
/** Props for the SitrepItem component: the single sitrep entry to render. */
interface Props {
  item: SitrepItemData;
}

/**
 * Visual treatment for one sitrep priority.
 *
 * `icon` is optional because a priority may render the CSS spinner
 * (`cssSpinner: true`) instead of a Phosphor icon.
 */
type PriorityVisual = {
  icon?: typeof WarningCircleIcon;
  color: string;
  bg: string;
  cssSpinner?: boolean;
};

/** Lookup table from sitrep priority to the classes/icon used for its badge. */
const PRIORITY_CONFIG: Record<SitrepPriority, PriorityVisual> = {
  error: { icon: WarningCircleIcon, color: "text-red-500", bg: "bg-red-50" },
  // "running" has no icon and no background: it shows the CSS spinner instead.
  running: { color: "text-zinc-800", bg: "", cssSpinner: true },
  stale: {
    icon: ClockCountdownIcon,
    color: "text-yellow-600",
    bg: "bg-yellow-50",
  },
  success: {
    icon: CheckCircleIcon,
    color: "text-green-600",
    bg: "bg-green-50",
  },
  listening: { icon: EarIcon, color: "text-purple-500", bg: "bg-purple-50" },
  scheduled: {
    icon: CalendarDotsIcon,
    color: "text-yellow-600",
    bg: "bg-yellow-50",
  },
  idle: { icon: MoonIcon, color: "text-zinc-400", bg: "bg-zinc-100" },
};
|
||||
|
||||
/**
 * Renders one sitrep row: an avatar (agent image, or a priority badge when the
 * agent has no image), the agent name and message, and the row's actions.
 *
 * Actions: "See task" (a link to the agent page) for `success` items, the
 * ContextualActionButton for every other priority, plus an "Ask AutoPilot"
 * button that navigates to /copilot with a pre-built prompt.
 */
export function SitrepItem({ item }: Props) {
  const router = useRouter();
  const config = PRIORITY_CONFIG[item.priority];

  // Both text actions ("See task" and "Ask AutoPilot") share the same styling.
  const actionClassName =
    "inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800";

  // The execution is selected through the `activeItem` query param when known.
  const taskHref = `/library/agents/${item.agentID}${item.executionID ? `?activeItem=${item.executionID}` : ""}`;

  function handleAskAutoPilot() {
    // The prompt travels in the URL fragment; `autosubmit=true` is a query param.
    const encodedPrompt = encodeURIComponent(buildAutoPilotPrompt(item));
    router.push(`/copilot?autosubmit=true#prompt=${encodedPrompt}`);
  }

  function renderAvatar() {
    if (item.agentImageUrl) {
      return (
        <img
          src={item.agentImageUrl}
          alt={item.agentName}
          className="h-6 w-6 flex-shrink-0 rounded-full object-cover"
        />
      );
    }

    // No image: fall back to the priority badge (spinner or icon).
    const glyph = config.cssSpinner ? (
      <div className={cn(styles.spinner, "h-[21px] w-[21px] text-zinc-800")} />
    ) : (
      config.icon && (
        <config.icon size={14} className={config.color} weight="fill" />
      )
    );

    return (
      <div
        className={cn(
          "flex h-6 w-6 flex-shrink-0 items-center justify-center rounded-full",
          config.bg,
        )}
      >
        {glyph}
      </div>
    );
  }

  function renderPrimaryAction() {
    if (item.priority === "success") {
      return (
        <NextLink href={taskHref} className={actionClassName}>
          <EyeIcon size={14} className="shrink-0" />
          See task
        </NextLink>
      );
    }

    return (
      <ContextualActionButton
        status={item.status}
        agentID={item.agentID}
        executionID={item.executionID}
      />
    );
  }

  return (
    <div
      className={cn(
        "flex flex-col gap-2 rounded-medium border border-zinc-200/50 bg-transparent p-2 sm:flex-row sm:items-center sm:gap-3",
      )}
    >
      <div className="flex min-w-0 flex-1 items-center gap-3">
        {renderAvatar()}

        <div className="min-w-0 flex-1">
          <Text variant="body-medium" className="leading-tight text-zinc-900">
            {item.agentName}
          </Text>
          <Text variant="small" className="leading-tight text-zinc-500">
            {item.message}
          </Text>
        </div>
      </div>

      <div className="flex flex-shrink-0 flex-wrap items-center justify-center gap-1.5 sm:flex-nowrap sm:justify-end">
        {renderPrimaryAction()}
        <button
          type="button"
          onClick={handleAskAutoPilot}
          className={actionClassName}
        >
          <ChatCircleDotsIcon size={14} className="shrink-0" />
          Ask AutoPilot
        </button>
      </div>
    </div>
  );
}
|
||||
|
||||
/**
 * Builds the natural-language prompt sent to AutoPilot for a sitrep item.
 *
 * The wording depends on `item.priority`; every variant mentions the agent
 * name, and the `error` variant also quotes `item.message`.
 */
function buildAutoPilotPrompt(item: SitrepItemData): string {
  const { agentName, message } = item;

  // One prompt per priority, looked up by the item's priority below.
  const promptByPriority: Record<SitrepPriority, string> = {
    error: `What happened with ${agentName}? It says "${message}" — can you check the logs and tell me what to fix?`,
    running: `Give me a status update on the ${agentName} run — what has it found so far?`,
    stale: `${agentName} hasn't run recently. Should I keep it or update and re-run it?`,
    success: `Show me what ${agentName} found in its last run — summarize the results and any key takeaways.`,
    listening: `What is ${agentName} listening for? Give me a summary of its trigger configuration.`,
    scheduled: `When is ${agentName} scheduled to run next?`,
    idle: `${agentName} has been idle. Should I keep it or update and re-run it?`,
  };

  return promptByPriority[item.priority];
}
|
||||
@@ -0,0 +1,34 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { ClockCounterClockwise } from "@phosphor-icons/react";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { useSitrepItems } from "./useSitrepItems";
|
||||
import { SitrepItem } from "./SitrepItem";
|
||||
|
||||
/** Props for SitrepList: the agents to summarise and an optional item cap. */
interface Props {
  agents: LibraryAgent[];
  maxItems?: number;
}

/**
 * "Recent tasks" section: a heading followed by a grid of SitrepItem rows
 * derived from `agents` by useSitrepItems. Renders nothing when there are no
 * items. `maxItems` defaults to 10.
 */
export function SitrepList({ agents, maxItems = 10 }: Props) {
  const sitrepItems = useSitrepItems(agents, maxItems);

  const hasItems = sitrepItems.length > 0;
  if (!hasItems) {
    return null;
  }

  const rows = sitrepItems.map((entry) => (
    <SitrepItem key={entry.id} item={entry} />
  ));

  return (
    <div>
      <div className="mb-2 flex items-center gap-1.5">
        <ClockCounterClockwise size={16} className="text-zinc-700" />
        <Text variant="body-medium" className="text-zinc-700">
          Recent tasks
        </Text>
      </div>
      <div className="grid grid-cols-1 gap-1 lg:grid-cols-2">{rows}</div>
    </div>
  );
}
|
||||
@@ -0,0 +1,133 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useMemo } from "react";
|
||||
import type { SitrepItemData, SitrepPriority } from "../../types";
|
||||
import {
|
||||
isActive,
|
||||
isFailed,
|
||||
toEndTime,
|
||||
endedAfter,
|
||||
runningMessage,
|
||||
SEVENTY_TWO_HOURS_MS,
|
||||
} from "../../hooks/executionHelpers";
|
||||
|
||||
/** Sort rank per priority: lower numbers are listed first. */
const SITREP_PRIORITY_RANK: Record<SitrepPriority, number> = {
  error: 0,
  running: 1,
  stale: 2,
  success: 3,
  listening: 4,
  scheduled: 5,
  idle: 6,
};

/**
 * Derives the sitrep entries shown for a set of library agents.
 *
 * Loads executions through `useGetV1ListAllExecutions`, groups them by the
 * agent whose `graph_id` they belong to, builds at most one item per agent,
 * orders the items by priority rank and keeps the first `maxItems`.
 *
 * @returns An empty array while executions are unavailable or when `agents`
 *          is empty.
 */
export function useSitrepItems(
  agents: LibraryAgent[],
  maxItems: number,
): SitrepItemData[] {
  const executionsQuery = useGetV1ListAllExecutions({
    query: { select: okData },
  });
  const executions = executionsQuery.data;

  return useMemo(() => {
    if (!executions || agents.length === 0) {
      return [];
    }

    const agentByGraphId = new Map(
      agents.map((agent) => [agent.graph_id, agent]),
    );
    const executionsPerAgent = groupByAgent(executions, agentByGraphId);

    // Agents for which no item can be built yield null and are dropped.
    const sitrepItems = Array.from(executionsPerAgent, ([agent, agentExecs]) =>
      buildSitrepFromExecutions(agent, agentExecs),
    ).filter((entry): entry is SitrepItemData => Boolean(entry));

    sitrepItems.sort(
      (left, right) =>
        SITREP_PRIORITY_RANK[left.priority] -
        SITREP_PRIORITY_RANK[right.priority],
    );

    return sitrepItems.slice(0, maxItems);
  }, [agents, executions, maxItems]);
}
|
||||
|
||||
/**
 * Buckets executions by the library agent that owns their graph.
 *
 * Executions whose `graph_id` has no entry in `graphIdToAgent` are ignored.
 * Within a bucket, executions keep the order they had in `executions`.
 */
function groupByAgent(
  executions: GraphExecutionMeta[],
  graphIdToAgent: Map<string, LibraryAgent>,
): Map<LibraryAgent, GraphExecutionMeta[]> {
  const grouped = new Map<LibraryAgent, GraphExecutionMeta[]>();

  executions.forEach((execution) => {
    const owner = graphIdToAgent.get(execution.graph_id);
    if (!owner) {
      return;
    }

    const bucket = grouped.get(owner);
    if (!bucket) {
      grouped.set(owner, [execution]);
      return;
    }
    bucket.push(execution);
  });

  return grouped;
}
|
||||
|
||||
/**
 * Builds the single sitrep item for an agent from its executions.
 *
 * Precedence:
 *  1. the first execution with an active status -> "running";
 *  2. otherwise the most recently ended failed execution within the last
 *     72 hours -> "error";
 *  3. otherwise the most recently ended COMPLETED execution within the last
 *     72 hours -> "success" (reported with status "idle");
 *  4. otherwise `null`.
 */
function buildSitrepFromExecutions(
  agent: LibraryAgent,
  executions: GraphExecutionMeta[],
): SitrepItemData | null {
  // Shared shape of every returned item; only these three fields vary.
  const toItem = (
    execution: GraphExecutionMeta,
    priority: SitrepPriority,
    message: SitrepItemData["message"],
    status: SitrepItemData["status"],
  ): SitrepItemData => ({
    id: `${agent.id}-${execution.id}`,
    agentID: agent.id,
    agentName: agent.name,
    executionID: execution.id,
    priority,
    message,
    status,
  });

  const running = executions.find((execution) => isActive(execution.status));
  if (running) {
    return toItem(
      running,
      "running",
      running.stats?.activity_status ??
        runningMessage(running.status, running.started_at),
      "running",
    );
  }

  // Only executions that ended within the last 72 hours count, newest first.
  const windowStart = Date.now() - SEVENTY_TWO_HOURS_MS;
  const recentNewestFirst = executions
    .filter((execution) => endedAfter(execution, windowStart))
    .sort((left, right) => toEndTime(right) - toEndTime(left));

  const failure = recentNewestFirst.find((execution) =>
    isFailed(execution.status),
  );
  if (failure) {
    const reason =
      failure.stats?.error ??
      failure.stats?.activity_status ??
      "Execution failed";
    // A non-string value is replaced by the generic text.
    const text = typeof reason === "string" ? reason : "Execution failed";
    return toItem(failure, "error", text, "error");
  }

  const completion = recentNewestFirst.find(
    (execution) => execution.status === AgentExecutionStatus.COMPLETED,
  );
  if (completion) {
    const outcome =
      completion.stats?.activity_status ?? "Completed successfully";
    const text = typeof outcome === "string" ? outcome : "Completed successfully";
    return toItem(completion, "success", text, "idle");
  }

  return null;
}
|
||||
@@ -0,0 +1,84 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { AgentStatus } from "../../types";
|
||||
|
||||
/** Label and classes for one agent status; `pulse` animates the status dot. */
type StatusVisual = {
  label: string;
  bg: string;
  text: string;
  pulse: boolean;
};

/**
 * Presentation table for StatusBadge, keyed by agent status.
 * Every entry currently has an empty `bg`, so badges get no background class.
 */
const STATUS_CONFIG: Record<AgentStatus, StatusVisual> = {
  running: { label: "Running", bg: "", text: "text-blue-600", pulse: true },
  error: { label: "Error", bg: "", text: "text-red-500", pulse: false },
  listening: {
    label: "Listening",
    bg: "",
    text: "text-purple-500",
    pulse: true,
  },
  scheduled: {
    label: "Scheduled",
    bg: "",
    text: "text-yellow-600",
    pulse: false,
  },
  idle: { label: "Idle", bg: "", text: "text-zinc-500", pulse: false },
};
|
||||
|
||||
/** Props for StatusBadge: the status to show and optional extra classes. */
interface Props {
  status: AgentStatus;
  className?: string;
}

/**
 * Pill-shaped badge showing a coloured dot followed by the status label.
 * The dot pulses for statuses whose config sets `pulse`; `className` is
 * appended last so callers can extend the badge's classes.
 */
export function StatusBadge({ status, className }: Props) {
  const { label, bg, text, pulse } = STATUS_CONFIG[status];

  const badgeClasses = cn(
    "inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium",
    bg,
    text,
    className,
  );
  const dotClasses = cn(
    "inline-block h-1.5 w-1.5 rounded-full",
    pulse && "animate-pulse",
    statusDotColor(status),
  );

  return (
    <span className={badgeClasses}>
      <span className={dotClasses} />
      {label}
    </span>
  );
}
|
||||
|
||||
/** Returns the background-colour class of the status dot for `status`. */
function statusDotColor(status: AgentStatus): string {
  const dotClassByStatus: Record<AgentStatus, string> = {
    running: "bg-blue-500",
    error: "bg-red-500",
    listening: "bg-purple-500",
    scheduled: "bg-yellow-500",
    idle: "bg-zinc-400",
  };

  return dotClassByStatus[status];
}
|
||||
@@ -0,0 +1,59 @@
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
||||
|
||||
/** 72 hours expressed in milliseconds. */
export const SEVENTY_TWO_HOURS_MS = 72 * 60 * 60 * 1000;

/**
 * Tells whether an execution status counts as active: RUNNING, QUEUED or
 * REVIEW. Any other value yields `false`.
 */
export function isActive(status: string): boolean {
  switch (status) {
    case AgentExecutionStatus.RUNNING:
    case AgentExecutionStatus.QUEUED:
    case AgentExecutionStatus.REVIEW:
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/**
 * Tells whether an execution status counts as a failure: FAILED or
 * TERMINATED. Any other value yields `false`.
 */
export function isFailed(status: string): boolean {
  switch (status) {
    case AgentExecutionStatus.FAILED:
    case AgentExecutionStatus.TERMINATED:
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/**
 * Returns the execution's end time as epoch milliseconds.
 *
 * `ended_at` may be a `Date` or a date string. The result is `0` when
 * `ended_at` is missing or cannot be parsed into a valid date. Returning `0`
 * rather than `NaN` keeps comparators such as
 * `(a, b) => toEndTime(b) - toEndTime(a)` consistent: a `NaN` difference
 * gives `Array.prototype.sort` an inconsistent comparator, which makes the
 * resulting order implementation-defined.
 */
export function toEndTime(exec: GraphExecutionMeta): number {
  const endedAt = exec.ended_at;
  if (!endedAt) return 0;

  const timestamp =
    endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime();

  // An invalid Date reports NaN from getTime(); treat it like "no end time".
  return Number.isNaN(timestamp) ? 0 : timestamp;
}
|
||||
|
||||
/**
 * Tells whether the execution has an end time strictly later than `cutoff`
 * (epoch milliseconds). Executions without `ended_at` never qualify.
 */
export function endedAfter(exec: GraphExecutionMeta, cutoff: number): boolean {
  const hasEnded = Boolean(exec.ended_at);
  return hasEnded && toEndTime(exec) > cutoff;
}
|
||||
|
||||
/**
 * Produces the short message describing an active execution.
 *
 * QUEUED and REVIEW have fixed texts. For any other status the message is
 * "Running for <duration>" measured from `startedAt` to now, or
 * "Currently executing" when no start time is given.
 */
export function runningMessage(
  status: string,
  startedAt?: string | Date | null,
): string {
  switch (status) {
    case AgentExecutionStatus.QUEUED:
      return "Queued for execution";
    case AgentExecutionStatus.REVIEW:
      return "Awaiting review";
  }

  if (!startedAt) {
    return "Currently executing";
  }

  const startedMs =
    startedAt instanceof Date
      ? startedAt.getTime()
      : new Date(startedAt).getTime();
  const elapsedMs = Date.now() - startedMs;

  return `Running for ${formatRelativeDuration(elapsedMs)}`;
}
|
||||
|
||||
/**
 * Formats a duration in milliseconds as a compact label.
 *
 * - under one minute (including negative values): "a few seconds"
 * - under one hour: "<m>m"
 * - under one day: "<h>h" or "<h>h <m>m" when minutes remain
 * - otherwise: "<d>d <h>h"
 *
 * Non-finite input (`NaN`, `±Infinity`) also yields "a few seconds". `NaN`
 * arises when the caller derives `ms` from an unparsable date, and without
 * this guard the result would read "NaNd NaNh".
 */
export function formatRelativeDuration(ms: number): string {
  if (!Number.isFinite(ms)) return "a few seconds";

  const seconds = Math.floor(ms / 1000);
  if (seconds < 60) return "a few seconds";

  const minutes = Math.floor(seconds / 60);
  if (minutes < 60) return `${minutes}m`;

  const hours = Math.floor(minutes / 60);
  const remainingMin = minutes % 60;
  if (hours < 24) {
    return remainingMin > 0 ? `${hours}h ${remainingMin}m` : `${hours}h`;
  }

  const days = Math.floor(hours / 24);
  return `${days}d ${hours % 24}h`;
}
|
||||
@@ -0,0 +1,213 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import type {
|
||||
AgentStatus,
|
||||
AgentHealth,
|
||||
AgentStatusInfo,
|
||||
FleetSummary,
|
||||
} from "../types";
|
||||
import {
|
||||
isActive,
|
||||
isFailed,
|
||||
toEndTime,
|
||||
SEVENTY_TWO_HOURS_MS,
|
||||
} from "./executionHelpers";
|
||||
|
||||
/**
 * Maps an agent's status and last run time to a health value.
 *
 * - "attention" when the status is "error";
 * - "stale" when the status is "idle" and the last run was more than 14 days
 *   ago;
 * - "good" in every other case, including idle agents with no recorded run.
 */
function deriveHealth(
  status: AgentStatus,
  lastRunAt: string | null,
): AgentHealth {
  if (status === "error") {
    return "attention";
  }
  if (status !== "idle" || !lastRunAt) {
    return "good";
  }

  const elapsedMs = Date.now() - new Date(lastRunAt).getTime();
  const daysSince = elapsedMs / (1000 * 60 * 60 * 24);

  return daysSince > 14 ? "stale" : "good";
}
|
||||
|
||||
/**
 * Computes the status summary for one agent from its executions.
 *
 * Status precedence:
 *  1. "running" when any execution is active;
 *  2. "error" when a failed execution ended within the last 72 hours;
 *  3. "listening" when the agent has an external trigger;
 *  4. "scheduled" when the agent has a recommended schedule cron;
 *  5. "idle" otherwise.
 *
 * `lastRunAt` is the end time of the most recently ended execution, and
 * `totalRuns` prefers `agent.execution_count`, falling back to the number of
 * executions passed in. `progress`, `monthlySpend` and `nextScheduledRun` are
 * always returned as `null`, `0` and `null`: nothing here computes them.
 */
function computeAgentStatus(
  agent: LibraryAgent,
  agentExecutions: GraphExecutionMeta[],
): AgentStatusInfo {
  const activeExec = agentExecutions.find((e) => isActive(e.status));
  const activeExecutionID = activeExec?.id ?? null;

  let status: AgentStatus;
  let lastError: string | null = null;

  if (activeExec) {
    status = "running";
  } else {
    const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
    // toEndTime() is 0 for executions without an end time, so they can never
    // pass the cutoff comparison.
    const recentFailed = agentExecutions.find(
      (e) => isFailed(e.status) && toEndTime(e) > cutoff,
    );

    if (recentFailed) {
      status = "error";
      const rawError =
        recentFailed.stats?.error ??
        recentFailed.stats?.activity_status ??
        "Execution failed";
      // Check at runtime instead of casting: a non-string payload must not
      // leak into a field typed (and rendered) as a string.
      lastError = typeof rawError === "string" ? rawError : "Execution failed";
    } else if (agent.has_external_trigger) {
      status = "listening";
    } else if (agent.recommended_schedule_cron) {
      status = "scheduled";
    } else {
      status = "idle";
    }
  }

  // Single pass for the latest ended execution. On equal end times the
  // earlier element wins, matching what a stable descending sort would pick.
  let latestEnded: GraphExecutionMeta | undefined;
  for (const exec of agentExecutions) {
    if (!exec.ended_at) continue;
    if (!latestEnded || toEndTime(exec) > toEndTime(latestEnded)) {
      latestEnded = exec;
    }
  }

  let lastRunAt: string | null = null;
  if (latestEnded) {
    const endedAt = latestEnded.ended_at;
    lastRunAt =
      endedAt instanceof Date ? endedAt.toISOString() : String(endedAt);
  }

  const totalRuns = agent.execution_count ?? agentExecutions.length;

  return {
    status,
    health: deriveHealth(status, lastRunAt),
    progress: null,
    totalRuns,
    lastRunAt,
    lastError,
    activeExecutionID,
    monthlySpend: 0,
    nextScheduledRun: null,
    triggerType: agent.has_external_trigger ? "webhook" : null,
  };
}
|
||||
|
||||
/**
 * Returns a map from `graph_id` to the computed status info of each agent.
 *
 * Executions come from `useGetV1ListAllExecutions`. While they are
 * unavailable every agent is evaluated against an empty execution list.
 */
export function useAgentStatusMap(
  agents: LibraryAgent[],
): Map<string, AgentStatusInfo> {
  const executionsQuery = useGetV1ListAllExecutions({
    query: { select: okData },
  });
  const executions = executionsQuery.data;

  return useMemo(() => {
    // Index executions by graph id once so each agent lookup is direct.
    const executionsByGraph = new Map<string, GraphExecutionMeta[]>();
    (executions ?? []).forEach((execution) => {
      const bucket = executionsByGraph.get(execution.graph_id);
      if (bucket) {
        bucket.push(execution);
        return;
      }
      executionsByGraph.set(execution.graph_id, [execution]);
    });

    const statusByGraph = new Map<string, AgentStatusInfo>();
    agents.forEach((agent) => {
      const ownExecutions = executionsByGraph.get(agent.graph_id) ?? [];
      statusByGraph.set(
        agent.graph_id,
        computeAgentStatus(agent, ownExecutions),
      );
    });

    return statusByGraph;
  }, [agents, executions]);
}
|
||||
|
||||
/** Status info returned for graphs that have no entry in the status map. */
const DEFAULT_STATUS: AgentStatusInfo = {
  status: "idle",
  health: "good",
  progress: null,
  totalRuns: 0,
  lastRunAt: null,
  lastError: null,
  activeExecutionID: null,
  monthlySpend: 0,
  nextScheduledRun: null,
  triggerType: null,
};

/**
 * Looks up the status info for `graphID`, falling back to the shared
 * DEFAULT_STATUS object when the map has no (or a nullish) entry.
 */
export function getAgentStatus(
  statusMap: Map<string, AgentStatusInfo>,
  graphID: string,
): AgentStatusInfo {
  const known = statusMap.get(graphID);
  return known == null ? DEFAULT_STATUS : known;
}
|
||||
|
||||
/**
 * Counts agents per status bucket for the fleet overview.
 *
 * Each agent lands in exactly one of running / error / listening /
 * scheduled / idle, checked in that order. `completed` is counted
 * separately for every agent with a COMPLETED execution that ended within the
 * last 72 hours, whatever bucket the agent is in. `monthlySpend` is always 0.
 */
export function useFleetSummary(agents: LibraryAgent[]): FleetSummary {
  const executionsQuery = useGetV1ListAllExecutions({
    query: { select: okData },
  });
  const executions = executionsQuery.data;

  return useMemo(() => {
    const runningGraphs = new Set<string>();
    const failedGraphs = new Set<string>();
    const completedGraphs = new Set<string>();

    if (executions) {
      const windowStart = Date.now() - SEVENTY_TWO_HOURS_MS;

      executions.forEach((execution) => {
        const graphId = execution.graph_id;
        // toEndTime() yields 0 when there is no end time, which never falls
        // inside the 72-hour window.
        const endedInWindow = toEndTime(execution) > windowStart;

        if (isActive(execution.status)) {
          runningGraphs.add(graphId);
        }
        if (endedInWindow && isFailed(execution.status)) {
          failedGraphs.add(graphId);
        }
        if (
          endedInWindow &&
          execution.status === AgentExecutionStatus.COMPLETED
        ) {
          completedGraphs.add(graphId);
        }
      });
    }

    const summary: FleetSummary = {
      running: 0,
      error: 0,
      completed: 0,
      listening: 0,
      scheduled: 0,
      idle: 0,
      monthlySpend: 0,
    };

    agents.forEach((agent) => {
      const graphId = agent.graph_id;

      if (runningGraphs.has(graphId)) {
        summary.running += 1;
      } else if (failedGraphs.has(graphId)) {
        summary.error += 1;
      } else if (agent.has_external_trigger) {
        summary.listening += 1;
      } else if (agent.recommended_schedule_cron) {
        summary.scheduled += 1;
      } else {
        summary.idle += 1;
      }

      // Independent of the bucket chosen above.
      if (completedGraphs.has(graphId)) {
        summary.completed += 1;
      }
    });

    return summary;
  }, [agents, executions]);
}
|
||||
|
||||
export { deriveHealth };
|
||||
@@ -0,0 +1,116 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
getGetV1ListAllExecutionsQueryKey,
|
||||
useGetV1ListAllExecutions,
|
||||
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useExecutionEvents } from "@/hooks/useExecutionEvents";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useCallback, useMemo } from "react";
|
||||
import type { FleetSummary } from "../types";
|
||||
import { isActive, isFailed, SEVENTY_TWO_HOURS_MS } from "./executionHelpers";
|
||||
|
||||
/**
 * Tells whether a failed execution ended less than 72 hours ago.
 * Returns `false` for non-failure statuses and for executions with no end
 * time.
 */
function isRecentFailure(
  status: string,
  endedAt?: string | Date | null,
): boolean {
  if (!endedAt || !isFailed(status)) {
    return false;
  }

  const endedMs = (
    endedAt instanceof Date ? endedAt : new Date(endedAt)
  ).getTime();
  const ageMs = Date.now() - endedMs;

  return ageMs < SEVENTY_TWO_HOURS_MS;
}
|
||||
|
||||
/**
 * Tells whether a COMPLETED execution ended less than 72 hours ago.
 * Returns `false` for any other status and for executions with no end time.
 */
function isRecentCompletion(
  status: string,
  endedAt?: string | Date | null,
): boolean {
  const isCompleted = status === AgentExecutionStatus.COMPLETED;
  if (!isCompleted || !endedAt) {
    return false;
  }

  const endedMs = (
    endedAt instanceof Date ? endedAt : new Date(endedAt)
  ).getTime();
  const ageMs = Date.now() - endedMs;

  return ageMs < SEVENTY_TWO_HOURS_MS;
}
|
||||
|
||||
/**
 * Fleet summary for the library page.
 *
 * Loads executions through `useGetV1ListAllExecutions` and registers, via
 * `useExecutionEvents`, an `onExecutionUpdate` callback that invalidates that
 * query for the agents' graph IDs (the subscription is disabled when there
 * are no agents).
 *
 * @returns `undefined` until the executions query has succeeded; afterwards
 *          the per-status agent counts. `monthlySpend` is always 0.
 */
export function useLibraryFleetSummary(
  agents: LibraryAgent[],
): FleetSummary | undefined {
  const queryClient = useQueryClient();

  const executionsQuery = useGetV1ListAllExecutions({
    query: { select: okData },
  });
  const executions = executionsQuery.data;
  const isSuccess = executionsQuery.isSuccess;

  const graphIDs = useMemo(
    () => agents.map((agent) => agent.graph_id),
    [agents],
  );
  const hasAgents = graphIDs.length > 0;

  const handleExecutionUpdate = useCallback(() => {
    queryClient.invalidateQueries({
      queryKey: getGetV1ListAllExecutionsQueryKey(),
    });
  }, [queryClient]);

  useExecutionEvents({
    graphIds: hasAgents ? graphIDs : undefined,
    enabled: hasAgents,
    onExecutionUpdate: handleExecutionUpdate,
  });

  return useMemo(() => {
    if (!isSuccess || !executions) {
      return undefined;
    }

    const runningGraphs = new Set<string>();
    const failedGraphs = new Set<string>();
    const completedGraphs = new Set<string>();

    executions.forEach(({ status, ended_at, graph_id }) => {
      if (isActive(status)) {
        runningGraphs.add(graph_id);
      }
      if (isRecentFailure(status, ended_at)) {
        failedGraphs.add(graph_id);
      }
      if (isRecentCompletion(status, ended_at)) {
        completedGraphs.add(graph_id);
      }
    });

    const summary: FleetSummary = {
      running: 0,
      error: 0,
      completed: 0,
      listening: 0,
      scheduled: 0,
      idle: 0,
      monthlySpend: 0,
    };

    agents.forEach((agent) => {
      const isRunning = runningGraphs.has(agent.graph_id);
      const hasRecentFailure = failedGraphs.has(agent.graph_id);

      if (isRunning) {
        summary.running += 1;
      } else if (hasRecentFailure) {
        summary.error += 1;
      } else if (agent.has_external_trigger) {
        summary.listening += 1;
      } else if (agent.recommended_schedule_cron) {
        summary.scheduled += 1;
      } else {
        summary.idle += 1;
      }

      // `completed` is a parallel counter: it never overlaps with running or
      // error (mirroring the sitrep priority order), but an agent counted as
      // listening, scheduled or idle can also be counted here.
      const countsAsCompleted =
        !isRunning && !hasRecentFailure && completedGraphs.has(agent.graph_id);
      if (countsAsCompleted) {
        summary.completed += 1;
      }
    });

    return summary;
  }, [agents, executions, isSuccess]);
}
|
||||
@@ -2,12 +2,14 @@
|
||||
|
||||
import { useEffect, useState, useCallback } from "react";
|
||||
import { HeartIcon, ListIcon } from "@phosphor-icons/react";
|
||||
import { JumpBackIn } from "./components/JumpBackIn/JumpBackIn";
|
||||
import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryActionHeader";
|
||||
import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList";
|
||||
import { useLibraryListPage } from "./components/useLibraryListPage";
|
||||
import { FavoriteAnimationProvider } from "./context/FavoriteAnimationContext";
|
||||
import { LibraryTab } from "./types";
|
||||
import type { LibraryTab, AgentStatusFilter } from "./types";
|
||||
import { useLibraryFleetSummary } from "./hooks/useLibraryFleetSummary";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
|
||||
|
||||
const LIBRARY_TABS: LibraryTab[] = [
|
||||
{ id: "all", title: "All", icon: ListIcon },
|
||||
@@ -19,6 +21,10 @@ export default function LibraryPage() {
|
||||
useLibraryListPage();
|
||||
const [selectedFolderId, setSelectedFolderId] = useState<string | null>(null);
|
||||
const [activeTab, setActiveTab] = useState(LIBRARY_TABS[0].id);
|
||||
const [statusFilter, setStatusFilter] = useState<AgentStatusFilter>("all");
|
||||
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
|
||||
const { agents } = useLibraryAgents();
|
||||
const fleetSummary = useLibraryFleetSummary(agents);
|
||||
|
||||
useEffect(() => {
|
||||
document.title = "Library – AutoGPT Platform";
|
||||
@@ -40,7 +46,6 @@ export default function LibraryPage() {
|
||||
>
|
||||
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
|
||||
<LibraryActionHeader setSearchTerm={setSearchTerm} />
|
||||
<JumpBackIn />
|
||||
<LibraryAgentList
|
||||
searchTerm={searchTerm}
|
||||
librarySort={librarySort}
|
||||
@@ -50,6 +55,10 @@ export default function LibraryPage() {
|
||||
tabs={LIBRARY_TABS}
|
||||
activeTab={activeTab}
|
||||
onTabChange={handleTabChange}
|
||||
statusFilter={statusFilter}
|
||||
onStatusFilterChange={setStatusFilter}
|
||||
fleetSummary={isAgentBriefingEnabled ? fleetSummary : undefined}
|
||||
briefingAgents={isAgentBriefingEnabled ? agents : undefined}
|
||||
/>
|
||||
</main>
|
||||
</FavoriteAnimationProvider>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user