mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
26 Commits
ntindle/wa
...
hotfix/tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9123178556 | ||
|
|
b8c65e3d2b | ||
|
|
0eddb6f1bb | ||
|
|
042ed42c0b | ||
|
|
eadd67c70c | ||
|
|
ad7044995e | ||
|
|
203bf2ca32 | ||
|
|
a1e0caa983 | ||
|
|
440a06ad9e | ||
|
|
8ec706c125 | ||
|
|
6c83a91ae3 | ||
|
|
f19a423cae | ||
|
|
87258441f2 | ||
|
|
494978319e | ||
|
|
9135969c34 | ||
|
|
8625a82495 | ||
|
|
c17f19317b | ||
|
|
fc48944b56 | ||
|
|
2f57c1499c | ||
|
|
7042fcecdf | ||
|
|
3e45a28307 | ||
|
|
81aea5dc52 | ||
|
|
08c49a78f8 | ||
|
|
60f950c719 | ||
|
|
0b9e0665dd | ||
|
|
f6f268a1f0 |
@@ -11,7 +11,7 @@ from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
@@ -25,6 +25,7 @@ from backend.copilot.model import (
|
||||
delete_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -141,6 +142,20 @@ class CancelSessionResponse(BaseModel):
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class UpdateSessionTitleRequest(BaseModel):
|
||||
"""Request model for updating a session's title."""
|
||||
|
||||
title: str
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def title_must_not_be_blank(cls, v: str) -> str:
|
||||
stripped = v.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Title must not be blank")
|
||||
return stripped
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -264,6 +279,43 @@ async def delete_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/title",
|
||||
summary="Update session title",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
responses={404: {"description": "Session not found or access denied"}},
|
||||
)
|
||||
async def update_session_title_route(
|
||||
session_id: str,
|
||||
request: UpdateSessionTitleRequest,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Update the title of a chat session.
|
||||
|
||||
Allows the user to rename their chat session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
request: Request body containing the new title.
|
||||
user_id: The authenticated user's ID.
|
||||
|
||||
Returns:
|
||||
dict: Status of the update.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if session not found or not owned by user.
|
||||
"""
|
||||
success = await update_session_title(session_id, user_id, request.title)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Tests for chat route file_ids validation and enrichment."""
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -17,6 +19,7 @@ TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
@@ -24,7 +27,95 @@ def setup_app_auth(mock_jwt_user):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---- file_ids Pydantic validation (B1) ----
|
||||
def _mock_update_session_title(
|
||||
mocker: pytest_mock.MockerFixture, *, success: bool = True
|
||||
):
|
||||
"""Mock update_session_title."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.update_session_title",
|
||||
new_callable=AsyncMock,
|
||||
return_value=success,
|
||||
)
|
||||
|
||||
|
||||
# ─── Update title: success ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_update_title_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "My project"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
|
||||
|
||||
|
||||
def test_update_title_trims_whitespace(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " trimmed "},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
|
||||
|
||||
|
||||
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
|
||||
|
||||
|
||||
def test_update_title_blank_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Whitespace-only titles must be rejected before hitting the DB."""
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " "},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_update_title_empty_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": ""},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ─── Update title: session not found or wrong user → 404 ──────────────
|
||||
|
||||
|
||||
def test_update_title_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
_mock_update_session_title(mocker, success=False)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "New name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_rejects_too_many_file_ids():
|
||||
@@ -92,7 +183,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- UUID format filtering ----
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
@@ -131,7 +222,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
|
||||
|
||||
# ---- Cross-workspace file_ids ----
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
|
||||
@@ -62,8 +62,8 @@ async def _update_title_async(
|
||||
"""Generate and persist a session title in the background."""
|
||||
try:
|
||||
title = await _generate_session_title(message, user_id, session_id)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Failed to update session title: %s", e)
|
||||
|
||||
|
||||
@@ -81,6 +81,35 @@ async def update_chat_session(
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def update_chat_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
Always filters by (session_id, user_id) so callers cannot mutate another
|
||||
user's session even when they know the session_id.
|
||||
|
||||
Args:
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
guard so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns True if a row was updated, False otherwise (session not found,
|
||||
wrong user, or — when only_if_empty — title was already set).
|
||||
"""
|
||||
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
|
||||
if only_if_empty:
|
||||
where["title"] = None
|
||||
result = await PrismaChatSession.prisma().update_many(
|
||||
where=where,
|
||||
data={"title": title, "updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
return result > 0
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
|
||||
@@ -469,8 +469,16 @@ async def upsert_chat_session(
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
# Save to cache (best-effort, even if DB failed).
|
||||
# Title updates (update_session_title) run *outside* this lock because
|
||||
# they only touch the title field, not messages. So a concurrent rename
|
||||
# or auto-title may have written a newer title to Redis while this
|
||||
# upsert was in progress. Always prefer the cached title to avoid
|
||||
# overwriting it with the stale in-memory copy.
|
||||
try:
|
||||
existing_cached = await _get_session_from_cache(session.session_id)
|
||||
if existing_cached and existing_cached.title:
|
||||
session = session.model_copy(update={"title": existing_cached.title})
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
@@ -685,30 +693,48 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
async def update_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
Lightweight operation that doesn't touch messages, avoiding race conditions
|
||||
with concurrent message updates.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
user_id: Owning user — the DB query filters on this.
|
||||
title: The new title to set.
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise.
|
||||
True if updated successfully, False otherwise (not found, wrong user,
|
||||
or — when only_if_empty — title was already set).
|
||||
"""
|
||||
try:
|
||||
result = await chat_db().update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
updated = await chat_db().update_chat_session_title(
|
||||
session_id, user_id, title, only_if_empty=only_if_empty
|
||||
)
|
||||
if not updated:
|
||||
return False
|
||||
|
||||
# Invalidate the cache so the next access reloads from DB with the
|
||||
# updated title. This avoids a read-modify-write on the full session
|
||||
# blob, which could overwrite concurrent message updates.
|
||||
await invalidate_session_cache(session_id)
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
# This prevents race conditions where cache invalidation causes
|
||||
# the frontend to see stale DB data while streaming is still in progress.
|
||||
try:
|
||||
cached = await _get_session_from_cache(session_id)
|
||||
if cached:
|
||||
cached.title = title
|
||||
await cache_chat_session(cached)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Cache title update failed for session {session_id} (non-critical): {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -127,7 +127,6 @@ def create_security_hooks(
|
||||
sdk_cwd: str | None = None,
|
||||
max_subtasks: int = 3,
|
||||
on_compact: Callable[[], None] | None = None,
|
||||
on_stop: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
@@ -136,15 +135,12 @@ def create_security_hooks(
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
||||
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
|
||||
the SDK finishes processing — used to read the JSONL transcript
|
||||
before the CLI process exits.
|
||||
on_compact: Callback invoked when SDK starts compacting context.
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
@@ -311,30 +307,6 @@ def create_security_hooks(
|
||||
on_compact()
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
# --- Stop hook: capture transcript path for stateless resume ---
|
||||
async def stop_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Capture transcript path when SDK finishes processing.
|
||||
|
||||
The Stop hook fires while the CLI process is still alive, giving us
|
||||
a reliable window to read the JSONL transcript before SIGTERM.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
transcript_path = cast(str, input_data.get("transcript_path", ""))
|
||||
sdk_session_id = cast(str, input_data.get("session_id", ""))
|
||||
|
||||
if transcript_path and on_stop:
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: transcript_path={transcript_path}, "
|
||||
f"sdk_session_id={sdk_session_id[:12]}..."
|
||||
)
|
||||
on_stop(transcript_path, sdk_session_id)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
hooks: dict[str, Any] = {
|
||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||
@@ -344,9 +316,6 @@ def create_security_hooks(
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
}
|
||||
|
||||
if on_stop is not None:
|
||||
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
|
||||
|
||||
return hooks
|
||||
except ImportError:
|
||||
# Fallback for when SDK isn't available - return empty hooks
|
||||
|
||||
@@ -12,7 +12,6 @@ import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import openai
|
||||
@@ -21,6 +20,9 @@ from claude_agent_sdk import (
|
||||
ClaudeAgentOptions,
|
||||
ClaudeSDKClient,
|
||||
ResultMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
)
|
||||
from langfuse import propagate_attributes
|
||||
@@ -74,11 +76,11 @@ from .tool_adapter import (
|
||||
from .transcript import (
|
||||
cleanup_cli_project_dir,
|
||||
download_transcript,
|
||||
read_transcript_file,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
@@ -137,19 +139,6 @@ _setup_langfuse_otel()
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapturedTranscript:
|
||||
"""Info captured by the SDK Stop hook for stateless --resume."""
|
||||
|
||||
path: str = ""
|
||||
sdk_session_id: str = ""
|
||||
raw_content: str = ""
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
return bool(self.path)
|
||||
|
||||
|
||||
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
||||
|
||||
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
|
||||
@@ -451,6 +440,43 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
"""Convert SDK content blocks to transcript format.
|
||||
|
||||
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for block in blocks or []:
|
||||
if isinstance(block, TextBlock):
|
||||
result.append({"type": "text", "text": block.text})
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
result.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"input": block.input,
|
||||
}
|
||||
)
|
||||
elif isinstance(block, ToolResultBlock):
|
||||
result.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": block.tool_use_id,
|
||||
"content": block.content,
|
||||
}
|
||||
)
|
||||
elif isinstance(block, ThinkingBlock):
|
||||
result.append(
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": block.thinking,
|
||||
"signature": block.signature,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _compress_messages(
|
||||
messages: list[ChatMessage],
|
||||
) -> tuple[list[ChatMessage], bool]:
|
||||
@@ -806,6 +832,11 @@ async def stream_chat_completion_sdk(
|
||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||
)
|
||||
|
||||
# Structured log prefix: [SDK][<session>][T<turn>]
|
||||
# Turn = number of user messages (1-based), computed AFTER appending the new message.
|
||||
turn = sum(1 for m in session.messages if m.role == "user")
|
||||
log_prefix = f"[SDK][{session_id[:12]}][T{turn}]"
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions (first user message)
|
||||
@@ -823,10 +854,11 @@ async def stream_chat_completion_sdk(
|
||||
message_id = str(uuid.uuid4())
|
||||
stream_id = str(uuid.uuid4())
|
||||
stream_completed = False
|
||||
ended_with_stream_error = False
|
||||
e2b_sandbox = None
|
||||
use_resume = False
|
||||
resume_file: str | None = None
|
||||
captured_transcript = CapturedTranscript()
|
||||
transcript_builder = TranscriptBuilder()
|
||||
sdk_cwd = ""
|
||||
|
||||
# Acquire stream lock to prevent concurrent streams to the same session
|
||||
@@ -841,7 +873,7 @@ async def stream_chat_completion_sdk(
|
||||
if lock_owner != stream_id:
|
||||
# Another stream is active
|
||||
logger.warning(
|
||||
f"[SDK] Session {session_id} already has an active stream: {lock_owner}"
|
||||
f"{log_prefix} Session already has an active stream: {lock_owner}"
|
||||
)
|
||||
yield StreamError(
|
||||
errorText="Another stream is already active for this session. "
|
||||
@@ -865,7 +897,7 @@ async def stream_chat_completion_sdk(
|
||||
sdk_cwd = _make_sdk_cwd(session_id)
|
||||
os.makedirs(sdk_cwd, exist_ok=True)
|
||||
except (ValueError, OSError) as e:
|
||||
logger.error("[SDK] [%s] Invalid SDK cwd: %s", session_id[:12], e)
|
||||
logger.error("%s Invalid SDK cwd: %s", log_prefix, e)
|
||||
yield StreamError(
|
||||
errorText="Unable to initialize working directory.",
|
||||
code="sdk_cwd_error",
|
||||
@@ -909,12 +941,13 @@ async def stream_chat_completion_sdk(
|
||||
):
|
||||
return None
|
||||
try:
|
||||
return await download_transcript(user_id, session_id)
|
||||
return await download_transcript(
|
||||
user_id, session_id, log_prefix=log_prefix
|
||||
)
|
||||
except Exception as transcript_err:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Transcript download failed, continuing without "
|
||||
"--resume: %s",
|
||||
session_id[:12],
|
||||
"%s Transcript download failed, continuing without " "--resume: %s",
|
||||
log_prefix,
|
||||
transcript_err,
|
||||
)
|
||||
return None
|
||||
@@ -936,11 +969,18 @@ async def stream_chat_completion_sdk(
|
||||
transcript_msg_count = 0
|
||||
if dl:
|
||||
is_valid = validate_transcript(dl.content)
|
||||
dl_lines = dl.content.strip().split("\n") if dl.content else []
|
||||
logger.info(
|
||||
"%s Downloaded transcript: %dB, %d lines, " "msg_count=%d, valid=%s",
|
||||
log_prefix,
|
||||
len(dl.content),
|
||||
len(dl_lines),
|
||||
dl.message_count,
|
||||
is_valid,
|
||||
)
|
||||
if is_valid:
|
||||
logger.info(
|
||||
f"[SDK] Transcript available for session {session_id}: "
|
||||
f"{len(dl.content)}B, msg_count={dl.message_count}"
|
||||
)
|
||||
# Load previous FULL context into builder
|
||||
transcript_builder.load_previous(dl.content)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
dl.content, session_id, sdk_cwd
|
||||
)
|
||||
@@ -948,16 +988,14 @@ async def stream_chat_completion_sdk(
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.debug(
|
||||
f"[SDK] Using --resume ({len(dl.content)}B, "
|
||||
f"{log_prefix} Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Transcript downloaded but invalid for {session_id}"
|
||||
)
|
||||
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
|
||||
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
logger.warning(
|
||||
f"[SDK] No transcript available for {session_id} "
|
||||
f"{log_prefix} No transcript available "
|
||||
f"({len(session.messages)} messages in session)"
|
||||
)
|
||||
|
||||
@@ -979,25 +1017,6 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
# --- Transcript capture via Stop hook ---
|
||||
# Read the file content immediately — the SDK may clean up
|
||||
# the file before our finally block runs.
|
||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
||||
captured_transcript.path = transcript_path
|
||||
captured_transcript.sdk_session_id = sdk_session_id
|
||||
content = read_transcript_file(transcript_path)
|
||||
if content:
|
||||
captured_transcript.raw_content = content
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: captured {len(content)}B from "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SDK] Stop hook: transcript file empty/missing at "
|
||||
f"{transcript_path}"
|
||||
)
|
||||
|
||||
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
|
||||
compaction = CompactionTracker()
|
||||
|
||||
@@ -1005,7 +1024,6 @@ async def stream_chat_completion_sdk(
|
||||
user_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
max_subtasks=config.claude_agent_max_subtasks,
|
||||
on_stop=_on_stop if config.claude_agent_use_resume else None,
|
||||
on_compact=compaction.on_compact,
|
||||
)
|
||||
|
||||
@@ -1040,7 +1058,10 @@ async def stream_chat_completion_sdk(
|
||||
session_id=session_id,
|
||||
trace_name="copilot-sdk",
|
||||
tags=["sdk"],
|
||||
metadata={"resume": str(use_resume)},
|
||||
metadata={
|
||||
"resume": str(use_resume),
|
||||
"conversation_turn": str(turn),
|
||||
},
|
||||
)
|
||||
_otel_ctx.__enter__()
|
||||
|
||||
@@ -1074,9 +1095,9 @@ async def stream_chat_completion_sdk(
|
||||
query_message = f"{query_message}\n\n{attachments.hint}"
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, "
|
||||
"%s Sending query — resume=%s, total_msgs=%d, "
|
||||
"query_len=%d, attached_files=%d, image_blocks=%d",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
use_resume,
|
||||
len(session.messages),
|
||||
len(query_message),
|
||||
@@ -1105,8 +1126,12 @@ async def stream_chat_completion_sdk(
|
||||
await client._transport.write( # noqa: SLF001
|
||||
json.dumps(user_msg) + "\n"
|
||||
)
|
||||
# Capture user message in transcript (multimodal)
|
||||
transcript_builder.add_user_message(content=content_blocks)
|
||||
else:
|
||||
await client.query(query_message, session_id=session_id)
|
||||
# Capture user message in transcript (text only)
|
||||
transcript_builder.add_user_message(content=query_message)
|
||||
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
@@ -1150,8 +1175,8 @@ async def stream_chat_completion_sdk(
|
||||
sdk_msg = done.pop().result()
|
||||
except StopAsyncIteration:
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
|
||||
session_id[:12],
|
||||
"%s Stream ended normally (StopAsyncIteration)",
|
||||
log_prefix,
|
||||
)
|
||||
break
|
||||
except Exception as stream_err:
|
||||
@@ -1160,8 +1185,8 @@ async def stream_chat_completion_sdk(
|
||||
# so the session can still be saved and the
|
||||
# frontend gets a clean finish.
|
||||
logger.error(
|
||||
"[SDK] [%s] Stream error from SDK: %s",
|
||||
session_id[:12],
|
||||
"%s Stream error from SDK: %s",
|
||||
log_prefix,
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -1173,9 +1198,9 @@ async def stream_chat_completion_sdk(
|
||||
break
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Received: %s %s "
|
||||
"%s Received: %s %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
type(sdk_msg).__name__,
|
||||
getattr(sdk_msg, "subtype", ""),
|
||||
len(adapter.current_tool_calls)
|
||||
@@ -1184,6 +1209,15 @@ async def stream_chat_completion_sdk(
|
||||
len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Capture AssistantMessage in transcript
|
||||
if isinstance(sdk_msg, AssistantMessage):
|
||||
content_blocks = _format_sdk_content_blocks(sdk_msg.content)
|
||||
model_name = getattr(sdk_msg, "model", "")
|
||||
transcript_builder.add_assistant_message(
|
||||
content_blocks=content_blocks,
|
||||
model=model_name,
|
||||
)
|
||||
|
||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||
# executed asynchronously via start_soon() — the next
|
||||
# message can arrive before the hook stashes output.
|
||||
@@ -1210,10 +1244,10 @@ async def stream_chat_completion_sdk(
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Timed out waiting for "
|
||||
"%s Timed out waiting for "
|
||||
"PostToolUse hook stash "
|
||||
"(%d unresolved tool calls)",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
@@ -1221,9 +1255,9 @@ async def stream_chat_completion_sdk(
|
||||
# Log ResultMessage details for debugging
|
||||
if isinstance(sdk_msg, ResultMessage):
|
||||
logger.info(
|
||||
"[SDK] [%s] Received: ResultMessage %s "
|
||||
"%s Received: ResultMessage %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
sdk_msg.subtype,
|
||||
len(adapter.current_tool_calls)
|
||||
- len(adapter.resolved_tool_calls),
|
||||
@@ -1232,8 +1266,8 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
if sdk_msg.subtype in ("error", "error_during_execution"):
|
||||
logger.error(
|
||||
"[SDK] [%s] SDK execution failed with error: %s",
|
||||
session_id[:12],
|
||||
"%s SDK execution failed with error: %s",
|
||||
log_prefix,
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
)
|
||||
|
||||
@@ -1258,8 +1292,8 @@ async def stream_chat_completion_sdk(
|
||||
out_len = len(str(response.output))
|
||||
extra = f", output_len={out_len}"
|
||||
logger.info(
|
||||
"[SDK] [%s] Tool event: %s, tool=%s%s",
|
||||
session_id[:12],
|
||||
"%s Tool event: %s, tool=%s%s",
|
||||
log_prefix,
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
extra,
|
||||
@@ -1268,8 +1302,8 @@ async def stream_chat_completion_sdk(
|
||||
# Log errors being sent to frontend
|
||||
if isinstance(response, StreamError):
|
||||
logger.error(
|
||||
"[SDK] [%s] Sending error to frontend: %s (code=%s)",
|
||||
session_id[:12],
|
||||
"%s Sending error to frontend: %s (code=%s)",
|
||||
log_prefix,
|
||||
response.errorText,
|
||||
response.code,
|
||||
)
|
||||
@@ -1335,8 +1369,8 @@ async def stream_chat_completion_sdk(
|
||||
# server shutdown). Log and let the safety-net / finally
|
||||
# blocks handle cleanup.
|
||||
logger.warning(
|
||||
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
|
||||
session_id[:12],
|
||||
"%s Streaming loop cancelled (asyncio.CancelledError)",
|
||||
log_prefix,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -1350,7 +1384,8 @@ async def stream_chat_completion_sdk(
|
||||
except (asyncio.CancelledError, StopAsyncIteration):
|
||||
# Expected: task was cancelled or exhausted during cleanup
|
||||
logger.info(
|
||||
"[SDK] Pending __anext__ task completed during cleanup"
|
||||
"%s Pending __anext__ task completed during cleanup",
|
||||
log_prefix,
|
||||
)
|
||||
|
||||
# Safety net: if tools are still unresolved after the
|
||||
@@ -1359,9 +1394,9 @@ async def stream_chat_completion_sdk(
|
||||
# them now so the frontend stops showing spinners.
|
||||
if adapter.has_unresolved_tool_calls:
|
||||
logger.warning(
|
||||
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
|
||||
"%s %d unresolved tool(s) after stream loop — "
|
||||
"flushing as safety net",
|
||||
session_id[:12],
|
||||
log_prefix,
|
||||
len(adapter.current_tool_calls) - len(adapter.resolved_tool_calls),
|
||||
)
|
||||
safety_responses: list[StreamBaseResponse] = []
|
||||
@@ -1372,8 +1407,8 @@ async def stream_chat_completion_sdk(
|
||||
(StreamToolInputAvailable, StreamToolOutputAvailable),
|
||||
):
|
||||
logger.info(
|
||||
"[SDK] [%s] Safety flush: %s, tool=%s",
|
||||
session_id[:12],
|
||||
"%s Safety flush: %s, tool=%s",
|
||||
log_prefix,
|
||||
type(response).__name__,
|
||||
getattr(response, "toolName", "N/A"),
|
||||
)
|
||||
@@ -1386,8 +1421,8 @@ async def stream_chat_completion_sdk(
|
||||
# StreamFinish is published by mark_session_completed in the processor.
|
||||
if not stream_completed and not ended_with_stream_error:
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream ended without ResultMessage (stopped by user)",
|
||||
session_id[:12],
|
||||
"%s Stream ended without ResultMessage (stopped by user)",
|
||||
log_prefix,
|
||||
)
|
||||
closing_responses: list[StreamBaseResponse] = []
|
||||
adapter._end_text_if_open(closing_responses)
|
||||
@@ -1408,69 +1443,36 @@ async def stream_chat_completion_sdk(
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# After async with the SDK task group has exited, so the Stop
|
||||
# hook has already fired and the CLI has been SIGTERMed. The
|
||||
# CLI uses appendFileSync, so all writes are safely on disk.
|
||||
if config.claude_agent_use_resume and user_id:
|
||||
# With --resume the CLI appends to the resume file (most
|
||||
# complete). Otherwise use the Stop hook path.
|
||||
if use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
logger.debug("[SDK] Transcript source: resume file")
|
||||
elif captured_transcript.path:
|
||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
||||
logger.debug(
|
||||
"[SDK] Transcript source: stop hook (%s), read result: %s",
|
||||
captured_transcript.path,
|
||||
f"{len(raw_transcript)}B" if raw_transcript else "None",
|
||||
)
|
||||
else:
|
||||
raw_transcript = None
|
||||
# Transcript upload is handled exclusively in the finally block
|
||||
# to avoid double-uploads (the success path used to upload the
|
||||
# old resume file, then the finally block overwrote it with the
|
||||
# stop hook content — which could be smaller after compaction).
|
||||
|
||||
if not raw_transcript:
|
||||
logger.debug(
|
||||
"[SDK] No usable transcript — CLI file had no "
|
||||
"conversation entries (expected for first turn "
|
||||
"without --resume)"
|
||||
)
|
||||
|
||||
if raw_transcript:
|
||||
# Shield the upload from generator cancellation so a
|
||||
# client disconnect / page refresh doesn't lose the
|
||||
# transcript. The upload must finish even if the SSE
|
||||
# connection is torn down.
|
||||
await asyncio.shield(
|
||||
_try_upload_transcript(
|
||||
user_id,
|
||||
session_id,
|
||||
raw_transcript,
|
||||
message_count=len(session.messages),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Stream completed successfully with %d messages",
|
||||
session_id[:12],
|
||||
len(session.messages),
|
||||
)
|
||||
if ended_with_stream_error:
|
||||
logger.warning(
|
||||
"%s Stream ended with SDK error after %d messages",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s Stream completed successfully with %d messages",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
)
|
||||
except BaseException as e:
|
||||
# Catch BaseException to handle both Exception and CancelledError
|
||||
# (CancelledError inherits from BaseException in Python 3.8+)
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
logger.warning("[SDK] [%s] Session cancelled", session_id[:12])
|
||||
logger.warning("%s Session cancelled", log_prefix)
|
||||
error_msg = "Operation cancelled"
|
||||
else:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
# SDK cleanup RuntimeError is expected during cancellation, log as warning
|
||||
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
|
||||
logger.warning(
|
||||
"[SDK] [%s] SDK cleanup error: %s", session_id[:12], error_msg
|
||||
)
|
||||
logger.warning("%s SDK cleanup error: %s", log_prefix, error_msg)
|
||||
else:
|
||||
logger.error(
|
||||
f"[SDK] [%s] Error: {error_msg}", session_id[:12], exc_info=True
|
||||
)
|
||||
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
|
||||
|
||||
# Append error marker to session (non-invasive text parsing approach)
|
||||
# The finally block will persist the session with this error marker
|
||||
@@ -1481,8 +1483,8 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
"[SDK] [%s] Appended error marker, will be persisted in finally",
|
||||
session_id[:12],
|
||||
"%s Appended error marker, will be persisted in finally",
|
||||
log_prefix,
|
||||
)
|
||||
|
||||
# Yield StreamError for immediate feedback (only for non-cancellation errors)
|
||||
@@ -1514,47 +1516,72 @@ async def stream_chat_completion_sdk(
|
||||
try:
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
logger.info(
|
||||
"[SDK] [%s] Session persisted in finally with %d messages",
|
||||
session_id[:12],
|
||||
"%s Session persisted in finally with %d messages",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
)
|
||||
except Exception as persist_err:
|
||||
logger.error(
|
||||
"[SDK] [%s] Failed to persist session in finally: %s",
|
||||
session_id[:12],
|
||||
"%s Failed to persist session in finally: %s",
|
||||
log_prefix,
|
||||
persist_err,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# This MUST run in finally so the transcript is uploaded even when
|
||||
# the streaming loop raises an exception. The CLI uses
|
||||
# appendFileSync, so whatever was written before the error/SIGTERM
|
||||
# is safely on disk and still useful for the next turn.
|
||||
if config.claude_agent_use_resume and user_id:
|
||||
# the streaming loop raises an exception.
|
||||
# The transcript represents the COMPLETE active context (atomic).
|
||||
if config.claude_agent_use_resume and user_id and session is not None:
|
||||
try:
|
||||
# Prefer content captured in the Stop hook (read before
|
||||
# cleanup removes the file). Fall back to the resume
|
||||
# file when the stop hook didn't fire (e.g. error before
|
||||
# completion) so we don't lose the prior transcript.
|
||||
raw_transcript = captured_transcript.raw_content or None
|
||||
if not raw_transcript and use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
# Build complete transcript from captured SDK messages
|
||||
transcript_content = transcript_builder.to_jsonl()
|
||||
|
||||
if raw_transcript and session is not None:
|
||||
await asyncio.shield(
|
||||
_try_upload_transcript(
|
||||
user_id,
|
||||
session_id,
|
||||
raw_transcript,
|
||||
message_count=len(session.messages),
|
||||
)
|
||||
if not transcript_content:
|
||||
logger.warning(
|
||||
"%s No transcript to upload (builder empty)", log_prefix
|
||||
)
|
||||
elif not validate_transcript(transcript_content):
|
||||
logger.warning(
|
||||
"%s Transcript invalid, skipping upload (entries=%d)",
|
||||
log_prefix,
|
||||
transcript_builder.entry_count,
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[SDK] No transcript to upload for {session_id}")
|
||||
logger.info(
|
||||
"%s Uploading complete transcript (entries=%d, bytes=%d)",
|
||||
log_prefix,
|
||||
transcript_builder.entry_count,
|
||||
len(transcript_content),
|
||||
)
|
||||
# Create task first so we have a reference if timeout occurs
|
||||
upload_task = asyncio.create_task(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=transcript_content,
|
||||
message_count=len(session.messages),
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
)
|
||||
try:
|
||||
async with asyncio.timeout(30):
|
||||
await asyncio.shield(upload_task)
|
||||
except TimeoutError:
|
||||
# Timeout fired but shield keeps upload running - track the
|
||||
# SAME task to prevent garbage collection (no double upload)
|
||||
logger.warning(
|
||||
"%s Transcript upload exceeded 30s timeout, "
|
||||
"continuing in background",
|
||||
log_prefix,
|
||||
)
|
||||
_background_tasks.add(upload_task)
|
||||
upload_task.add_done_callback(_background_tasks.discard)
|
||||
except Exception as upload_err:
|
||||
logger.error(
|
||||
f"[SDK] Transcript upload failed in finally: {upload_err}",
|
||||
"%s Transcript upload failed in finally: %s",
|
||||
log_prefix,
|
||||
upload_err,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@@ -1565,33 +1592,6 @@ async def stream_chat_completion_sdk(
|
||||
await lock.release()
|
||||
|
||||
|
||||
async def _try_upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
raw_content: str,
|
||||
message_count: int = 0,
|
||||
) -> bool:
|
||||
"""Strip progress entries and upload transcript (with timeout).
|
||||
|
||||
Returns True if the upload completed without error.
|
||||
"""
|
||||
try:
|
||||
async with asyncio.timeout(30):
|
||||
await upload_transcript(
|
||||
user_id, session_id, raw_content, message_count=message_count
|
||||
)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[SDK] Transcript upload timed out for {session_id}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[SDK] Failed to upload transcript for {session_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None = None
|
||||
) -> None:
|
||||
@@ -1600,8 +1600,8 @@ async def _update_title_async(
|
||||
title = await _generate_session_title(
|
||||
message, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||
|
||||
@@ -58,41 +58,40 @@ def strip_progress_entries(content: str) -> str:
|
||||
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
|
||||
any remaining child entries so the ``parentUuid`` chain stays intact.
|
||||
Typically reduces transcript size by ~30%.
|
||||
|
||||
Entries that are not stripped or reparented are kept as their original
|
||||
raw JSON line to avoid unnecessary re-serialization that changes
|
||||
whitespace or key ordering.
|
||||
"""
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
entries: list[dict] = []
|
||||
# Parse entries, keeping the original line alongside the parsed dict.
|
||||
parsed: list[tuple[str, dict | None]] = []
|
||||
for line in lines:
|
||||
try:
|
||||
entries.append(json.loads(line))
|
||||
parsed.append((line, json.loads(line)))
|
||||
except json.JSONDecodeError:
|
||||
# Keep unparseable lines as-is (safety)
|
||||
entries.append({"_raw": line})
|
||||
parsed.append((line, None))
|
||||
|
||||
# First pass: identify stripped UUIDs and build parent map.
|
||||
stripped_uuids: set[str] = set()
|
||||
uuid_to_parent: dict[str, str] = {}
|
||||
kept: list[dict] = []
|
||||
|
||||
for entry in entries:
|
||||
if "_raw" in entry:
|
||||
kept.append(entry)
|
||||
for _line, entry in parsed:
|
||||
if entry is None:
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
parent = entry.get("parentUuid", "")
|
||||
entry_type = entry.get("type", "")
|
||||
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
||||
stripped_uuids.add(uid)
|
||||
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
if uid:
|
||||
stripped_uuids.add(uid)
|
||||
else:
|
||||
kept.append(entry)
|
||||
|
||||
# Reparent: walk up chain through stripped entries to find surviving ancestor
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||
# Preserve original line when no reparenting is required.
|
||||
reparented: set[str] = set()
|
||||
for _line, entry in parsed:
|
||||
if entry is None:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
@@ -100,63 +99,32 @@ def strip_progress_entries(content: str) -> str:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
uid = entry.get("uuid", "")
|
||||
if uid:
|
||||
reparented.add(uid)
|
||||
|
||||
result_lines: list[str] = []
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
result_lines.append(entry["_raw"])
|
||||
else:
|
||||
for line, entry in parsed:
|
||||
if entry is None:
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
if uid in reparented:
|
||||
# Re-serialize only entries whose parentUuid was changed.
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
else:
|
||||
result_lines.append(line)
|
||||
|
||||
return "\n".join(result_lines) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
|
||||
# Local file I/O (write temp file for --resume)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_transcript_file(transcript_path: str) -> str | None:
|
||||
"""Read a JSONL transcript file from disk.
|
||||
|
||||
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
|
||||
or only contains metadata (≤2 lines with no conversation messages).
|
||||
"""
|
||||
if not transcript_path or not os.path.isfile(transcript_path):
|
||||
logger.debug(f"[Transcript] File not found: {transcript_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(transcript_path) as f:
|
||||
content = f.read()
|
||||
|
||||
if not content.strip():
|
||||
logger.debug("[Transcript] File is empty: %s", transcript_path)
|
||||
return None
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
# Validate that the transcript has real conversation content
|
||||
# (not just metadata like queue-operation entries).
|
||||
if not validate_transcript(content):
|
||||
logger.debug(
|
||||
"[Transcript] No conversation content (%d lines) in %s",
|
||||
len(lines),
|
||||
transcript_path,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Read {len(lines)} lines, "
|
||||
f"{len(content)} bytes from {transcript_path}"
|
||||
)
|
||||
return content
|
||||
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
"""Sanitize an ID for safe use in file paths.
|
||||
|
||||
@@ -171,14 +139,6 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
@@ -188,7 +148,8 @@ def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""
|
||||
import shutil
|
||||
|
||||
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
|
||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
@@ -248,32 +209,30 @@ def write_transcript_to_tempfile(
|
||||
def validate_transcript(content: str | None) -> bool:
|
||||
"""Check that a transcript has actual conversation messages.
|
||||
|
||||
A valid transcript for resume needs at least one user message and one
|
||||
assistant message (not just queue-operation / file-history-snapshot
|
||||
metadata).
|
||||
A valid transcript needs at least one assistant message (not just
|
||||
queue-operation / file-history-snapshot metadata). We do NOT require
|
||||
a ``type: "user"`` entry because with ``--resume`` the user's message
|
||||
is passed as a CLI query parameter and does not appear in the
|
||||
transcript file.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return False
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return False
|
||||
|
||||
has_user = False
|
||||
has_assistant = False
|
||||
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
msg_type = entry.get("type")
|
||||
if msg_type == "user":
|
||||
has_user = True
|
||||
elif msg_type == "assistant":
|
||||
if entry.get("type") == "assistant":
|
||||
has_assistant = True
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
|
||||
return has_user and has_assistant
|
||||
return has_assistant
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -328,26 +287,43 @@ async def upload_transcript(
|
||||
session_id: str,
|
||||
content: str,
|
||||
message_count: int = 0,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> None:
|
||||
"""Strip progress entries and upload transcript to bucket storage.
|
||||
"""Strip progress entries and upload complete transcript.
|
||||
|
||||
The transcript represents the FULL active context (atomic).
|
||||
Each upload REPLACES the previous transcript entirely.
|
||||
|
||||
The executor holds a cluster lock per session, so concurrent uploads for
|
||||
the same session cannot happen. We always overwrite — with ``--resume``
|
||||
the CLI may compact old tool results, so neither byte size nor line count
|
||||
is a reliable proxy for "newer".
|
||||
the same session cannot happen.
|
||||
|
||||
Args:
|
||||
message_count: ``len(session.messages)`` at upload time — used by
|
||||
the next turn to detect staleness and compress only the gap.
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
try:
|
||||
entry_types.append(json.loads(line).get("type", "?"))
|
||||
except json.JSONDecodeError:
|
||||
entry_types.append("INVALID_JSON")
|
||||
logger.warning(
|
||||
f"[Transcript] Skipping upload — stripped content not valid "
|
||||
f"for session {session_id}"
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
log_prefix,
|
||||
entry_types,
|
||||
len(stripped),
|
||||
len(content),
|
||||
)
|
||||
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
|
||||
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
|
||||
return
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
@@ -373,17 +349,18 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count}) "
|
||||
f"for session {session_id}"
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
)
|
||||
|
||||
|
||||
async def download_transcript(
|
||||
user_id: str, session_id: str
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download transcript and metadata from bucket storage.
|
||||
|
||||
@@ -399,10 +376,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -425,10 +402,7 @@ async def download_transcript(
|
||||
except (FileNotFoundError, json.JSONDecodeError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Downloaded {len(content)}B "
|
||||
f"(msg_count={message_count}) for session {session_id}"
|
||||
)
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
The transcript represents the FULL active context at any point in time.
|
||||
Each upload REPLACES the previous transcript atomically.
|
||||
|
||||
Flow:
|
||||
Turn 1: Upload [msg1, msg2]
|
||||
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
|
||||
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
|
||||
|
||||
The transcript is never incremental - always the complete atomic state.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranscriptEntry(BaseModel):
|
||||
"""Single transcript entry (user or assistant turn)."""
|
||||
|
||||
type: str
|
||||
uuid: str
|
||||
parentUuid: str | None
|
||||
message: dict[str, Any]
|
||||
|
||||
|
||||
class TranscriptBuilder:
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
This builder maintains the FULL conversation state, not incremental changes.
|
||||
The output is always the complete active context.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: list[TranscriptEntry] = []
|
||||
self._last_uuid: str | None = None
|
||||
|
||||
def load_previous(self, content: str) -> None:
|
||||
"""Load complete previous transcript.
|
||||
|
||||
This loads the FULL previous context. As new messages come in,
|
||||
we append to this state. The final output is the complete context
|
||||
(previous + new), not just the delta.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return
|
||||
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse transcript line: %s", line[:100])
|
||||
continue
|
||||
|
||||
# Only load conversation messages (user/assistant)
|
||||
# Skip metadata entries
|
||||
if data.get("type") not in ("user", "assistant"):
|
||||
continue
|
||||
|
||||
entry = TranscriptEntry(
|
||||
type=data["type"],
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._last_uuid = entry.uuid
|
||||
|
||||
logger.info(
|
||||
"Loaded %d entries from previous transcript (last_uuid=%s)",
|
||||
len(self._entries),
|
||||
self._last_uuid[:12] if self._last_uuid else None,
|
||||
)
|
||||
|
||||
def add_user_message(
|
||||
self, content: str | list[dict], uuid: str | None = None
|
||||
) -> None:
|
||||
"""Add user message to the complete context."""
|
||||
msg_uuid = uuid or str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="user",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={"role": "user", "content": content},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def add_assistant_message(
|
||||
self, content_blocks: list[dict], model: str = ""
|
||||
) -> None:
|
||||
"""Add assistant message to the complete context."""
|
||||
msg_uuid = str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="assistant",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": content_blocks,
|
||||
},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def to_jsonl(self) -> str:
|
||||
"""Export complete context as JSONL.
|
||||
|
||||
Returns the FULL conversation state (all entries), not incremental.
|
||||
This output REPLACES any previous transcript.
|
||||
"""
|
||||
if not self._entries:
|
||||
return ""
|
||||
|
||||
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
@property
|
||||
def entry_count(self) -> int:
|
||||
"""Total number of entries in the complete context."""
|
||||
return len(self._entries)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Whether this builder has any entries."""
|
||||
return len(self._entries) == 0
|
||||
@@ -5,7 +5,6 @@ import os
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
read_transcript_file,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -38,49 +37,6 @@ PROGRESS_ENTRY = {
|
||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||
|
||||
|
||||
# --- read_transcript_file ---
|
||||
|
||||
|
||||
class TestReadTranscriptFile:
|
||||
def test_returns_content_for_valid_file(self, tmp_path):
|
||||
path = tmp_path / "session.jsonl"
|
||||
path.write_text(VALID_TRANSCRIPT)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
assert "user" in result
|
||||
|
||||
def test_returns_none_for_missing_file(self):
|
||||
assert read_transcript_file("/nonexistent/path.jsonl") is None
|
||||
|
||||
def test_returns_none_for_empty_path(self):
|
||||
assert read_transcript_file("") is None
|
||||
|
||||
def test_returns_none_for_empty_file(self, tmp_path):
|
||||
path = tmp_path / "empty.jsonl"
|
||||
path.write_text("")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_metadata_only(self, tmp_path):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
||||
path = tmp_path / "meta.jsonl"
|
||||
path.write_text(content)
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_invalid_json(self, tmp_path):
|
||||
path = tmp_path / "bad.jsonl"
|
||||
path.write_text("not json\n{}\n{}\n")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_no_size_limit(self, tmp_path):
|
||||
"""Large files are accepted — bucket storage has no size limit."""
|
||||
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
|
||||
path = tmp_path / "big.jsonl"
|
||||
path.write_text(content)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
|
||||
|
||||
# --- write_transcript_to_tempfile ---
|
||||
|
||||
|
||||
@@ -155,12 +111,56 @@ class TestValidateTranscript:
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_assistant_only_no_user(self):
|
||||
"""With --resume the user message is a CLI query param, not a transcript entry.
|
||||
A transcript with only assistant entries is valid."""
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
|
||||
assert validate_transcript(content) is False
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_resume_transcript_without_user_entry(self):
|
||||
"""Simulates a real --resume stop hook transcript: the CLI session file
|
||||
has summary + assistant entries but no user entry."""
|
||||
summary = {"type": "summary", "uuid": "s1", "text": "context..."}
|
||||
asst1 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {"role": "assistant", "content": "Hello!"},
|
||||
}
|
||||
asst2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "a1",
|
||||
"message": {"role": "assistant", "content": "Sure, let me help."},
|
||||
}
|
||||
content = _make_jsonl(summary, asst1, asst2)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_single_assistant_entry(self):
|
||||
"""A transcript with just one assistant line is valid — the CLI may
|
||||
produce short transcripts for simple responses with no tool use."""
|
||||
content = json.dumps(ASST_MSG) + "\n"
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_invalid_json_returns_false(self):
|
||||
assert validate_transcript("not json\n{}\n{}\n") is False
|
||||
|
||||
def test_malformed_json_after_valid_assistant_returns_false(self):
|
||||
"""Validation must scan all lines - malformed JSON anywhere should fail."""
|
||||
valid_asst = json.dumps(ASST_MSG)
|
||||
malformed = "not valid json"
|
||||
content = valid_asst + "\n" + malformed + "\n"
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_blank_lines_are_skipped(self):
|
||||
"""Transcripts with blank lines should be valid if they contain assistant entries."""
|
||||
content = (
|
||||
json.dumps(USER_MSG)
|
||||
+ "\n\n" # blank line
|
||||
+ json.dumps(ASST_MSG)
|
||||
+ "\n"
|
||||
+ "\n" # another blank line
|
||||
)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
|
||||
# --- strip_progress_entries ---
|
||||
|
||||
@@ -253,3 +253,32 @@ class TestStripProgressEntries:
|
||||
assert "queue-operation" not in result_types
|
||||
assert "user" in result_types
|
||||
assert "assistant" in result_types
|
||||
|
||||
def test_preserves_original_line_formatting(self):
|
||||
"""Non-reparented entries keep their original JSON formatting."""
|
||||
# Use pretty-printed JSON with spaces (as the CLI produces)
|
||||
original_line = json.dumps(USER_MSG) # default formatting with spaces
|
||||
compact_line = json.dumps(USER_MSG, separators=(",", ":"))
|
||||
assert original_line != compact_line # precondition
|
||||
|
||||
content = original_line + "\n" + json.dumps(ASST_MSG) + "\n"
|
||||
result = strip_progress_entries(content)
|
||||
result_lines = result.strip().split("\n")
|
||||
|
||||
# Original line should be byte-identical (not re-serialized)
|
||||
assert result_lines[0] == original_line
|
||||
|
||||
def test_reparented_entries_are_reserialized(self):
|
||||
"""Entries whose parentUuid changes must be re-serialized."""
|
||||
progress = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
|
||||
asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "p1",
|
||||
"message": {"role": "assistant", "content": "done"},
|
||||
}
|
||||
content = _make_jsonl(USER_MSG, progress, asst)
|
||||
result = strip_progress_entries(content)
|
||||
lines = result.strip().split("\n")
|
||||
asst_entry = json.loads(lines[-1])
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
@@ -305,6 +305,7 @@ class DatabaseManager(AppService):
|
||||
delete_chat_session = _(chat_db.delete_chat_session)
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
update_chat_session_title = _(chat_db.update_chat_session_title)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
@@ -475,3 +476,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
delete_chat_session = d.delete_chat_session
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
update_chat_session_title = d.update_chat_session_title
|
||||
|
||||
@@ -184,17 +184,17 @@ async def find_webhook_by_credentials_and_props(
|
||||
credentials_id: str,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: Optional[list[str]],
|
||||
events: list[str] | None = None,
|
||||
) -> Webhook | None:
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
**({"events": {"has_every": events}} if events else {}),
|
||||
},
|
||||
)
|
||||
where: IntegrationWebhookWhereInput = {
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
}
|
||||
if events is not None:
|
||||
where["events"] = {"has_every": events}
|
||||
webhook = await IntegrationWebhook.prisma().find_first(where=where)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
|
||||
@@ -76,7 +76,6 @@ class TelegramWebhooksManager(BaseWebhooksManager):
|
||||
credentials_id=credentials.id,
|
||||
webhook_type=webhook_type,
|
||||
resource=resource,
|
||||
events=None, # Ignore events for this lookup
|
||||
):
|
||||
# Re-register with Telegram using the same URL but new allowed_updates
|
||||
ingress_url = webhook_ingress_url(self.PROVIDER_NAME, existing.id)
|
||||
@@ -143,10 +142,6 @@ class TelegramWebhooksManager(BaseWebhooksManager):
|
||||
elif "video" in message:
|
||||
event_type = "message.video"
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown Telegram webhook payload type; "
|
||||
f"message.keys() = {message.keys()}"
|
||||
)
|
||||
event_type = "message.other"
|
||||
elif "edited_message" in payload:
|
||||
event_type = "message.edited_message"
|
||||
|
||||
@@ -3,6 +3,7 @@ import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useDeleteV2DeleteSession,
|
||||
useGetV2ListSessions,
|
||||
usePatchV2UpdateSessionTitle,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
@@ -17,7 +18,6 @@ import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import {
|
||||
Sidebar,
|
||||
SidebarContent,
|
||||
SidebarFooter,
|
||||
SidebarHeader,
|
||||
SidebarTrigger,
|
||||
useSidebar,
|
||||
@@ -25,8 +25,9 @@ import {
|
||||
import { cn } from "@/lib/utils";
|
||||
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { motion } from "framer-motion";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
|
||||
@@ -65,6 +66,39 @@ export function ChatSidebar() {
|
||||
},
|
||||
});
|
||||
|
||||
const [editingSessionId, setEditingSessionId] = useState<string | null>(null);
|
||||
const [editingTitle, setEditingTitle] = useState("");
|
||||
const renameInputRef = useRef<HTMLInputElement>(null);
|
||||
const renameCancelledRef = useRef(false);
|
||||
|
||||
const { mutate: renameSession } = usePatchV2UpdateSessionTitle({
|
||||
mutation: {
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
setEditingSessionId(null);
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: "Failed to rename chat",
|
||||
description:
|
||||
error instanceof Error ? error.message : "An error occurred",
|
||||
variant: "destructive",
|
||||
});
|
||||
setEditingSessionId(null);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Auto-focus the rename input when editing starts
|
||||
useEffect(() => {
|
||||
if (editingSessionId && renameInputRef.current) {
|
||||
renameInputRef.current.focus();
|
||||
renameInputRef.current.select();
|
||||
}
|
||||
}, [editingSessionId]);
|
||||
|
||||
const sessions =
|
||||
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
|
||||
|
||||
@@ -76,6 +110,26 @@ export function ChatSidebar() {
|
||||
setSessionId(id);
|
||||
}
|
||||
|
||||
function handleRenameClick(
|
||||
e: React.MouseEvent,
|
||||
id: string,
|
||||
title: string | null | undefined,
|
||||
) {
|
||||
e.stopPropagation();
|
||||
renameCancelledRef.current = false;
|
||||
setEditingSessionId(id);
|
||||
setEditingTitle(title || "");
|
||||
}
|
||||
|
||||
function handleRenameSubmit(id: string) {
|
||||
const trimmed = editingTitle.trim();
|
||||
if (trimmed) {
|
||||
renameSession({ sessionId: id, data: { title: trimmed } });
|
||||
} else {
|
||||
setEditingSessionId(null);
|
||||
}
|
||||
}
|
||||
|
||||
function handleDeleteClick(
|
||||
e: React.MouseEvent,
|
||||
id: string,
|
||||
@@ -160,29 +214,42 @@ export function ChatSidebar() {
|
||||
</motion.div>
|
||||
</SidebarHeader>
|
||||
)}
|
||||
{!isCollapsed && (
|
||||
<SidebarHeader className="shrink-0 px-4 pb-4 pt-4 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.1 }}
|
||||
className="flex flex-col gap-3 px-3"
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="relative left-6">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</motion.div>
|
||||
</SidebarHeader>
|
||||
)}
|
||||
|
||||
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{!isCollapsed && (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.1 }}
|
||||
className="flex items-center justify-between px-3"
|
||||
>
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="relative left-6">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
|
||||
{!isCollapsed && (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.15 }}
|
||||
className="mt-4 flex flex-col gap-1"
|
||||
className="flex flex-col gap-1"
|
||||
>
|
||||
{isLoadingSessions ? (
|
||||
<div className="flex min-h-[30rem] items-center justify-center py-4">
|
||||
@@ -203,76 +270,105 @@ export function ChatSidebar() {
|
||||
: "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<button
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === sessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
{session.title || `Untitled chat`}
|
||||
{editingSessionId === session.id ? (
|
||||
<div className="px-3 py-2.5">
|
||||
<input
|
||||
ref={renameInputRef}
|
||||
type="text"
|
||||
aria-label="Rename chat"
|
||||
value={editingTitle}
|
||||
onChange={(e) => setEditingTitle(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
e.currentTarget.blur();
|
||||
} else if (e.key === "Escape") {
|
||||
renameCancelledRef.current = true;
|
||||
setEditingSessionId(null);
|
||||
}
|
||||
}}
|
||||
onBlur={() => {
|
||||
if (renameCancelledRef.current) {
|
||||
renameCancelledRef.current = false;
|
||||
return;
|
||||
}
|
||||
handleRenameSubmit(session.id);
|
||||
}}
|
||||
className="w-full rounded border border-zinc-300 bg-white px-2 py-1 text-sm text-zinc-800 outline-none focus:border-purple-500 focus:ring-1 focus:ring-purple-500"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === sessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
<AnimatePresence mode="wait" initial={false}>
|
||||
<motion.span
|
||||
key={session.title || "untitled"}
|
||||
initial={{ opacity: 0, y: 4 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -4 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="block truncate"
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</motion.span>
|
||||
</AnimatePresence>
|
||||
</Text>
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</button>
|
||||
)}
|
||||
{editingSessionId !== session.id && (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleRenameClick(e, session.id, session.title)
|
||||
}
|
||||
>
|
||||
Rename
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)}
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</motion.div>
|
||||
)}
|
||||
</SidebarContent>
|
||||
{!isCollapsed && sessionId && (
|
||||
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.2 }}
|
||||
>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</motion.div>
|
||||
</SidebarFooter>
|
||||
)}
|
||||
</Sidebar>
|
||||
|
||||
<DeleteChatDialog
|
||||
|
||||
@@ -29,7 +29,6 @@ export function DeleteChatDialog({
|
||||
}
|
||||
},
|
||||
}}
|
||||
onClose={isDeleting ? undefined : onCancel}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<Text variant="body">
|
||||
|
||||
@@ -71,6 +71,17 @@ export function MobileDrawer({
|
||||
<X width="1rem" height="1rem" />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon width="1rem" height="1rem" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
@@ -120,19 +131,6 @@ export function MobileDrawer({
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
{currentSessionId && (
|
||||
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon width="1rem" height="1rem" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</Drawer.Content>
|
||||
</Drawer.Portal>
|
||||
</Drawer.Root>
|
||||
|
||||
@@ -2,6 +2,7 @@ import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useDeleteV2DeleteSession,
|
||||
useGetV2ListSessions,
|
||||
type getV2ListSessionsResponse,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
@@ -15,6 +16,9 @@ import { useCopilotUIStore } from "./store";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
import { useCopilotStream } from "./useCopilotStream";
|
||||
|
||||
const TITLE_POLL_INTERVAL_MS = 2_000;
|
||||
const TITLE_POLL_MAX_ATTEMPTS = 5;
|
||||
|
||||
interface UploadedFile {
|
||||
file_id: string;
|
||||
name: string;
|
||||
@@ -258,6 +262,52 @@ export function useCopilotPage() {
|
||||
const sessions =
|
||||
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
|
||||
|
||||
// Start title polling when stream ends cleanly — sidebar title animates in
|
||||
const titlePollRef = useRef<ReturnType<typeof setInterval>>();
|
||||
const prevStatusRef = useRef(status);
|
||||
|
||||
useEffect(() => {
|
||||
const prev = prevStatusRef.current;
|
||||
prevStatusRef.current = status;
|
||||
|
||||
const wasActive = prev === "streaming" || prev === "submitted";
|
||||
const isNowReady = status === "ready";
|
||||
|
||||
if (!wasActive || !isNowReady || !sessionId || isReconnecting) return;
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey({ limit: 50 }),
|
||||
});
|
||||
const sid = sessionId;
|
||||
let attempts = 0;
|
||||
clearInterval(titlePollRef.current);
|
||||
titlePollRef.current = setInterval(() => {
|
||||
const data = queryClient.getQueryData<getV2ListSessionsResponse>(
|
||||
getGetV2ListSessionsQueryKey({ limit: 50 }),
|
||||
);
|
||||
const hasTitle =
|
||||
data?.status === 200 &&
|
||||
data.data.sessions.some((s) => s.id === sid && s.title);
|
||||
if (hasTitle || attempts >= TITLE_POLL_MAX_ATTEMPTS) {
|
||||
clearInterval(titlePollRef.current);
|
||||
titlePollRef.current = undefined;
|
||||
return;
|
||||
}
|
||||
attempts += 1;
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey({ limit: 50 }),
|
||||
});
|
||||
}, TITLE_POLL_INTERVAL_MS);
|
||||
}, [status, sessionId, isReconnecting, queryClient]);
|
||||
|
||||
// Clean up polling on session change or unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
clearInterval(titlePollRef.current);
|
||||
titlePollRef.current = undefined;
|
||||
};
|
||||
}, [sessionId]);
|
||||
|
||||
// --- Mobile drawer handlers ---
|
||||
function handleOpenDrawer() {
|
||||
setDrawerOpen(true);
|
||||
|
||||
@@ -1305,6 +1305,59 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}/title": {
|
||||
"patch": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Update session title",
|
||||
"description": "Update the title of a chat session.\n\nAllows the user to rename their chat session.\n\nArgs:\n session_id: The session ID to update.\n request: Request body containing the new title.\n user_id: The authenticated user's ID.\n\nReturns:\n dict: Status of the update.\n\nRaises:\n HTTPException: 404 if session not found or not owned by user.",
|
||||
"operationId": "patchV2Update session title",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/UpdateSessionTitleRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Patchv2Update Session Title"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"404": { "description": "Session not found or access denied" },
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -13291,6 +13344,13 @@
|
||||
"required": ["permissions"],
|
||||
"title": "UpdatePermissionsRequest"
|
||||
},
|
||||
"UpdateSessionTitleRequest": {
|
||||
"properties": { "title": { "type": "string", "title": "Title" } },
|
||||
"type": "object",
|
||||
"required": ["title"],
|
||||
"title": "UpdateSessionTitleRequest",
|
||||
"description": "Request model for updating a session's title."
|
||||
},
|
||||
"UpdateTimezoneRequest": {
|
||||
"properties": {
|
||||
"timezone": {
|
||||
|
||||
@@ -7,6 +7,9 @@ import {
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { Button as AtomButton } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { cjk } from "@streamdown/cjk";
|
||||
import { code } from "@/lib/streamdown-code-plugin";
|
||||
@@ -16,6 +19,7 @@ import type { UIMessage } from "ai";
|
||||
import { ChevronLeftIcon, ChevronRightIcon } from "lucide-react";
|
||||
import type { ComponentProps, HTMLAttributes, ReactElement } from "react";
|
||||
import { createContext, memo, useContext, useEffect, useState } from "react";
|
||||
import type { LinkSafetyModalProps } from "streamdown";
|
||||
import { Streamdown } from "streamdown";
|
||||
|
||||
export type MessageProps = HTMLAttributes<HTMLDivElement> & {
|
||||
@@ -307,6 +311,46 @@ function isSameOriginLink(url: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
function ExternalLinkModal({
|
||||
url,
|
||||
isOpen,
|
||||
onClose,
|
||||
onConfirm,
|
||||
}: LinkSafetyModalProps) {
|
||||
return (
|
||||
<Dialog
|
||||
title="Open external link"
|
||||
styling={{ maxWidth: "30rem", minWidth: "auto" }}
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: async (open) => {
|
||||
if (!open) onClose();
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<Text variant="body">
|
||||
You're about to visit an external website:
|
||||
</Text>
|
||||
<Text
|
||||
variant="small"
|
||||
className="mt-2 break-all rounded-md bg-neutral-100 p-3 font-mono"
|
||||
>
|
||||
{url}
|
||||
</Text>
|
||||
<Dialog.Footer>
|
||||
<AtomButton variant="secondary" onClick={onClose}>
|
||||
Cancel
|
||||
</AtomButton>
|
||||
<AtomButton variant="primary" onClick={onConfirm}>
|
||||
Open link
|
||||
</AtomButton>
|
||||
</Dialog.Footer>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
export const MessageResponse = memo(
|
||||
({ className, ...props }: MessageResponseProps) => (
|
||||
<Streamdown
|
||||
@@ -318,6 +362,7 @@ export const MessageResponse = memo(
|
||||
linkSafety={{
|
||||
enabled: true,
|
||||
onLinkCheck: isSameOriginLink,
|
||||
renderModal: (modalProps) => <ExternalLinkModal {...modalProps} />,
|
||||
}}
|
||||
{...props}
|
||||
/>
|
||||
|
||||
Reference in New Issue
Block a user