mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-19 02:54:28 -05:00
feat(copilot): stream resume, transcript staleness detection, WebSearch display
- Enable `resume: true` on `useChat` with `prepareReconnectToStreamRequest` so page refresh reconnects to active backend streams via Redis replay
- Add `message_count` watermark + timestamp metadata to transcript uploads; on download, detect staleness and compress only the gap instead of the full history (hybrid: transcript via --resume + compressed missed turns)
- Fix WebSearch accordion showing empty by extracting text from MCP-style content blocks (`extractMcpText`) with raw JSON fallback
- Revert over-blocking: only `AskUserQuestion` added to SDK_DISALLOWED_TOOLS (removed EnterPlanMode, ExitPlanMode, Skill, NotebookEdit)
- Add defensive TodoItem filter per coderabbit review
- Fix service_test for TranscriptDownload return type change
@@ -515,17 +515,20 @@ async def stream_chat_completion_sdk(
     # --- Resume strategy: download transcript from bucket ---
     resume_file: str | None = None
     use_resume = False
+    transcript_msg_count = 0  # watermark: session.messages length at upload
 
     if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
-        transcript_content = await download_transcript(user_id, session_id)
-        if transcript_content and validate_transcript(transcript_content):
+        dl = await download_transcript(user_id, session_id)
+        if dl and validate_transcript(dl.content):
             resume_file = write_transcript_to_tempfile(
-                transcript_content, session_id, sdk_cwd
+                dl.content, session_id, sdk_cwd
             )
             if resume_file:
                 use_resume = True
+                transcript_msg_count = dl.message_count
                 logger.debug(
-                    f"[SDK] Using --resume ({len(transcript_content)}B)"
+                    f"[SDK] Using --resume ({len(dl.content)}B, "
+                    f"msg_count={transcript_msg_count})"
                 )
 
     sdk_options_kwargs: dict[str, Any] = {
@@ -566,11 +569,35 @@ async def stream_chat_completion_sdk(
         # Build query: with --resume the CLI already has full
         # context, so we only send the new message. Without
         # resume, compress history into a context prefix.
+        #
+        # Hybrid mode: if the transcript is stale (upload missed
+        # some turns), compress only the gap and prepend it so
+        # the agent has transcript context + missed turns.
         query_message = current_message
-        if not use_resume and len(session.messages) > 1:
+        current_msg_count = len(session.messages)
+
+        if use_resume and transcript_msg_count > 0:
+            # Transcript covers messages[0..M-1]. Current session
+            # has N messages (last one is the new user msg).
+            # Gap = messages[M .. N-2] (everything between upload
+            # and the current turn).
+            if transcript_msg_count < current_msg_count - 1:
+                gap = session.messages[transcript_msg_count:-1]
+                gap_context = _format_conversation_context(gap)
+                if gap_context:
+                    logger.info(
+                        f"[SDK] Transcript stale: covers {transcript_msg_count} "
+                        f"of {current_msg_count} messages, compressing "
+                        f"{len(gap)} missed messages"
+                    )
+                    query_message = (
+                        f"{gap_context}\n\n"
+                        f"Now, the user says:\n{current_message}"
+                    )
+        elif not use_resume and current_msg_count > 1:
             logger.warning(
                 f"[SDK] Using compression fallback for session "
-                f"{session_id} ({len(session.messages)} messages) — "
+                f"{session_id} ({current_msg_count} messages) — "
                 f"no transcript available for --resume"
             )
             compressed = await _compress_conversation_history(session)
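To make the watermark arithmetic above concrete, here is a small self-contained sketch (illustrative values only; `messages` stands in for `session.messages`):

```python
# Illustrative stand-ins for the session state used in the diff above.
messages = [f"msg {i}" for i in range(10)]  # N = 10; messages[-1] is the new user msg
transcript_msg_count = 6                    # M = 6: transcript covers messages[0..5]

if transcript_msg_count < len(messages) - 1:
    gap = messages[transcript_msg_count:-1]  # messages[6..8]: turns the upload missed
    print(f"compressing {len(gap)} missed messages")  # -> compressing 3 missed messages
```

So a transcript that covers 6 of 10 messages yields a gap of 3: everything between the upload watermark and the current turn.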
@@ -680,7 +707,12 @@ async def stream_chat_completion_sdk(
                 raw_transcript = None
 
             if raw_transcript:
-                await _try_upload_transcript(user_id, session_id, raw_transcript)
+                await _try_upload_transcript(
+                    user_id,
+                    session_id,
+                    raw_transcript,
+                    message_count=len(session.messages),
+                )
 
     except ImportError:
         raise RuntimeError(
@@ -713,7 +745,10 @@ async def stream_chat_completion_sdk(
 
 
 async def _try_upload_transcript(
-    user_id: str, session_id: str, raw_content: str
+    user_id: str,
+    session_id: str,
+    raw_content: str,
+    message_count: int = 0,
 ) -> bool:
     """Strip progress entries and upload transcript (with timeout).
 
@@ -721,7 +756,9 @@ async def _try_upload_transcript(
     """
     try:
         async with asyncio.timeout(30):
-            await upload_transcript(user_id, session_id, raw_content)
+            await upload_transcript(
+                user_id, session_id, raw_content, message_count=message_count
+            )
             return True
     except asyncio.TimeoutError:
         logger.warning(f"[SDK] Transcript upload timed out for {session_id}")
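The 30-second bound comes from `asyncio.timeout`, the Python 3.11+ context manager used above; a minimal standalone illustration of its behaviour:

```python
import asyncio


async def main() -> None:
    try:
        async with asyncio.timeout(0.1):  # the diff uses 30s for real uploads
            await asyncio.sleep(1)        # stand-in for a slow upload
    except TimeoutError:  # asyncio.TimeoutError is an alias of TimeoutError since 3.11
        print("upload timed out")


asyncio.run(main())
```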
@@ -395,7 +395,12 @@ _SDK_BUILTIN_TOOLS = [
 # network isolation (unshare --net) instead.
 # WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
 # Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
-SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"]
+# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
+SDK_DISALLOWED_TOOLS = [
+    "Bash",
+    "WebFetch",
+    "AskUserQuestion",
+]
 
 # Tools that are blocked entirely in security hooks (defence-in-depth).
 # Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms.
@@ -14,6 +14,8 @@ import json
 import logging
 import os
 import re
+import time
+from dataclasses import dataclass
 
 logger = logging.getLogger(__name__)
 
@@ -31,6 +33,16 @@ STRIPPABLE_TYPES = frozenset(
     {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
 )
 
+
+@dataclass
+class TranscriptDownload:
+    """Result of downloading a transcript with its metadata."""
+
+    content: str
+    message_count: int = 0  # session.messages length when uploaded
+    uploaded_at: float = 0.0  # epoch timestamp of upload
+
+
 # Workspace storage constants — deterministic path from session_id.
 TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
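A quick usage sketch of the new dataclass (assumed values; it takes the dataclass above as given, and the comparison mirrors the hybrid-mode staleness check earlier in this diff):

```python
dl = TranscriptDownload(
    content='{"type": "user"}\n',  # JSONL transcript body
    message_count=6,               # watermark restored from the .meta.json sidecar
    uploaded_at=1_739_900_000.0,
)

current_msg_count = 10  # len(session.messages) this turn
is_stale = 0 < dl.message_count < current_msg_count - 1
print(is_stale)  # -> True: compress only the gap, not the full history
```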
@@ -277,6 +289,15 @@ def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
     )
 
 
+def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
+    """Return (workspace_id, file_id, filename) for a session's transcript metadata."""
+    return (
+        TRANSCRIPT_STORAGE_PREFIX,
+        _sanitize_id(user_id),
+        f"{_sanitize_id(session_id)}.meta.json",
+    )
+
+
 def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
     """Build the full storage path string that ``retrieve()`` expects.
 
@@ -297,13 +318,22 @@ def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
     return f"local://{wid}/{fid}/{fname}"
 
 
-async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
+async def upload_transcript(
+    user_id: str,
+    session_id: str,
+    content: str,
+    message_count: int = 0,
+) -> None:
     """Strip progress entries and upload transcript to bucket storage.
 
     Safety: only overwrites when the new (stripped) transcript is larger than
     what is already stored. Since JSONL is append-only, the latest transcript
     is always the longest. This prevents a slow/stale background task from
     clobbering a newer upload from a concurrent turn.
 
+    Args:
+        message_count: ``len(session.messages)`` at upload time — used by
+            the next turn to detect staleness and compress only the gap.
     """
     from backend.util.workspace_storage import get_workspace_storage
 
@@ -339,16 +369,32 @@ async def upload_transcript(user_id: str, session_id: str, content: str) -> None
         filename=fname,
         content=encoded,
     )
 
+    # Store metadata alongside the transcript so the next turn can detect
+    # staleness and only compress the gap instead of the full history.
+    meta = {"message_count": message_count, "uploaded_at": time.time()}
+    mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
+    await storage.store(
+        workspace_id=mwid,
+        file_id=mfid,
+        filename=mfname,
+        content=json.dumps(meta).encode("utf-8"),
+    )
+
     logger.info(
         f"[Transcript] Uploaded {new_size}B "
-        f"(stripped from {len(content)}B) for session {session_id}"
+        f"(stripped from {len(content)}B, msg_count={message_count}) "
+        f"for session {session_id}"
     )
 
 
-async def download_transcript(user_id: str, session_id: str) -> str | None:
-    """Download transcript from bucket storage.
+async def download_transcript(
+    user_id: str, session_id: str
+) -> TranscriptDownload | None:
+    """Download transcript and metadata from bucket storage.
 
-    Returns the JSONL content string, or ``None`` if not found.
+    Returns a ``TranscriptDownload`` with the JSONL content and the
+    ``message_count`` watermark from the upload, or ``None`` if not found.
     """
     from backend.util.workspace_storage import get_workspace_storage
 
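Concretely, the `.meta.json` sidecar written next to the transcript is a two-field JSON document; with assumed values:

```python
import json
import time

meta = {"message_count": 12, "uploaded_at": time.time()}
print(json.dumps(meta))
# -> {"message_count": 12, "uploaded_at": 1739900000.123}  (timestamp varies)
```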
@@ -358,10 +404,6 @@ async def download_transcript(user_id: str, session_id: str) -> str | None:
     try:
         data = await storage.retrieve(path)
         content = data.decode("utf-8")
-        logger.debug(
-            f"[Transcript] Downloaded {len(content)}B for session {session_id}"
-        )
-        return content
     except FileNotFoundError:
         logger.debug(f"[Transcript] No transcript in storage for {session_id}")
         return None
@@ -369,6 +411,36 @@ async def download_transcript(user_id: str, session_id: str) -> str | None:
         logger.warning(f"[Transcript] Failed to download transcript: {e}")
         return None
 
+    # Try to load metadata (best-effort — old transcripts won't have it)
+    message_count = 0
+    uploaded_at = 0.0
+    try:
+        from backend.util.workspace_storage import GCSWorkspaceStorage
+
+        mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
+        if isinstance(storage, GCSWorkspaceStorage):
+            blob = f"workspaces/{mwid}/{mfid}/{mfname}"
+            meta_path = f"gcs://{storage.bucket_name}/{blob}"
+        else:
+            meta_path = f"local://{mwid}/{mfid}/{mfname}"
+
+        meta_data = await storage.retrieve(meta_path)
+        meta = json.loads(meta_data.decode("utf-8"))
+        message_count = meta.get("message_count", 0)
+        uploaded_at = meta.get("uploaded_at", 0.0)
+    except (FileNotFoundError, json.JSONDecodeError, Exception):
+        pass  # No metadata — treat as unknown (msg_count=0 → always fill gap)
+
+    logger.debug(
+        f"[Transcript] Downloaded {len(content)}B "
+        f"(msg_count={message_count}) for session {session_id}"
+    )
+    return TranscriptDownload(
+        content=content,
+        message_count=message_count,
+        uploaded_at=uploaded_at,
+    )
+
 
 async def delete_transcript(user_id: str, session_id: str) -> None:
     """Delete transcript from bucket storage (e.g. after resume failure)."""
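With assumed IDs, the two `meta_path` shapes the branch above can produce look like this (the bucket name is hypothetical):

```python
# Assumed IDs, mirroring _meta_storage_path_parts above.
mwid, mfid, mfname = "chat-transcripts", "user_123", "session_456.meta.json"

print(f"gcs://my-bucket/workspaces/{mwid}/{mfid}/{mfname}")
# -> gcs://my-bucket/workspaces/chat-transcripts/user_123/session_456.meta.json
print(f"local://{mwid}/{mfid}/{mfname}")
# -> local://chat-transcripts/user_123/session_456.meta.json
```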
@@ -143,7 +143,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
             "Transcript was not uploaded to bucket after turn 1 — "
             "Stop hook may not have fired or transcript was too small"
         )
-    logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
+    logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
 
     # Reload session for turn 2
     session = await get_chat_session(session.session_id, test_user_id)
@@ -342,6 +342,20 @@ function parseOutput(output: unknown): Record<string, unknown> | null {
   return null;
 }
 
+/**
+ * Extract text from MCP-style content blocks.
+ * SDK built-in tools (WebSearch, etc.) may return `{content: [{type:"text", text:"..."}]}`.
+ */
+function extractMcpText(output: Record<string, unknown>): string | null {
+  if (Array.isArray(output.content)) {
+    const texts = (output.content as Array<Record<string, unknown>>)
+      .filter((b) => b.type === "text" && typeof b.text === "string")
+      .map((b) => b.text as string);
+    if (texts.length > 0) return texts.join("\n");
+  }
+  return null;
+}
+
 function getExitCode(output: unknown): number | null {
   const parsed = parseOutput(output);
   if (!parsed) return null;
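The shape handled here is the MCP tool-result convention. For readers on the backend side, a Python analog of the same extraction (a hypothetical helper, not part of this commit):

```python
def extract_mcp_text(output: dict) -> str | None:
    """Join text blocks from {"content": [{"type": "text", "text": ...}, ...]}."""
    blocks = output.get("content")
    if not isinstance(blocks, list):
        return None
    texts = [
        b["text"]
        for b in blocks
        if isinstance(b, dict) and b.get("type") == "text" and isinstance(b.get("text"), str)
    ]
    return "\n".join(texts) if texts else None


print(extract_mcp_text({"content": [{"type": "text", "text": "Result 1"}]}))  # -> Result 1
```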
@@ -429,7 +443,20 @@ function getWebAccordionData(
   const url =
     getStringField(inp as Record<string, unknown>, "url", "query") ??
     "Web content";
-  const content = getStringField(output, "content", "text", "_raw");
+
+  // Try direct string fields first, then MCP content blocks, then raw JSON
+  let content = getStringField(output, "content", "text", "_raw");
+  if (!content) content = extractMcpText(output);
+  if (!content) {
+    // Fallback: render the raw JSON so the accordion isn't empty
+    try {
+      const raw = JSON.stringify(output, null, 2);
+      if (raw !== "{}") content = raw;
+    } catch {
+      /* empty */
+    }
+  }
 
   const statusCode =
     typeof output.status_code === "number" ? output.status_code : null;
   const message = getStringField(output, "message");
@@ -495,7 +522,14 @@ function getTodoAccordionData(input: unknown): AccordionData {
     string,
     unknown
   >;
-  const todos: TodoItem[] = Array.isArray(inp.todos) ? inp.todos : [];
+  const todos: TodoItem[] = Array.isArray(inp.todos)
+    ? inp.todos.filter(
+        (t: unknown): t is TodoItem =>
+          typeof t === "object" &&
+          t !== null &&
+          typeof (t as TodoItem).content === "string",
+      )
+    : [];
 
   const completed = todos.filter((t) => t.status === "completed").length;
   const total = todos.length;
@@ -551,10 +585,13 @@ function getDefaultAccordionData(
 ): AccordionData {
   const message = getStringField(output, "message");
   const raw = output._raw;
+  const mcpText = extractMcpText(output);
 
   let displayContent: string;
   if (typeof raw === "string") {
     displayContent = raw;
+  } else if (mcpText) {
+    displayContent = mcpText;
   } else if (message) {
     displayContent = message;
   } else {
@@ -80,6 +80,12 @@ export function useCopilotPage() {
           },
         };
       },
+      // Resume: GET goes to the same URL as POST (backend uses
+      // method to distinguish). Override the default formula which
+      // would append /{chatId}/stream to the existing path.
+      prepareReconnectToStreamRequest: () => ({
+        api: `/api/chat/sessions/${sessionId}/stream`,
+      }),
     })
     : null,
   [sessionId],
@@ -88,6 +94,7 @@ export function useCopilotPage() {
   const { messages, sendMessage, stop, status, error, setMessages } = useChat({
     id: sessionId ?? undefined,
     transport: transport ?? undefined,
+    resume: true,
   });
 
   // Abort the stream if the backend doesn't start sending data within 12s.
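For orientation, the contract implied by the frontend changes (POST starts a stream; GET on the same path reconnects) could be served like this. A minimal FastAPI sketch with a stubbed event source, not the repo's actual handler; the route shape and Redis-replay behaviour are taken from the commit message:

```python
from typing import AsyncIterator

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

router = APIRouter()


async def _events(session_id: str, replay: bool) -> AsyncIterator[bytes]:
    # Stand-in event source: the real backend replays buffered chunks
    # from Redis on reconnect, then tails the live stream.
    kind = "replayed" if replay else "live"
    yield f"data: {kind} event for {session_id}\n\n".encode()


@router.post("/api/chat/sessions/{session_id}/stream")
async def start_stream(session_id: str) -> StreamingResponse:
    # New turn: run the agent and stream fresh events.
    return StreamingResponse(_events(session_id, replay=False), media_type="text/event-stream")


@router.get("/api/chat/sessions/{session_id}/stream")
async def resume_stream(session_id: str) -> StreamingResponse:
    # Page refresh: same URL, but GET means "reconnect to the active stream".
    return StreamingResponse(_events(session_id, replay=True), media_type="text/event-stream")
```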