mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' into fix/messed-up-copilot
Resolve service.py conflicts: take dev's file-based transcript approach (CapturedTranscript.path + read_transcript_file) and public client API, layer our fixes on top (wait_for_stash race-condition fix, session_id logging, approach logging).
This commit is contained in:
@@ -18,7 +18,7 @@ from backend.copilot.completion_handler import (
|
||||
process_operation_success,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_copilot_task
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -132,6 +132,14 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class CancelTaskResponse(BaseModel):
    """Response model for the cancel task endpoint."""

    # NOTE: the class docstring above is mirrored as the schema description in
    # the generated OpenAPI document, so reword it only together with a
    # regeneration of that document.

    # Whether a cancel request was issued for the session's active task.
    cancelled: bool
    # Identifier of the task the cancel was issued for; left unset when there
    # was nothing to cancel.
    task_id: str | None = None
    # Optional machine-readable qualifier for the outcome (for example
    # "no_active_task" or "cancel_published_not_confirmed").
    reason: str | None = None
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
@@ -314,6 +322,57 @@ async def get_session(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/sessions/{session_id}/cancel",
    status_code=200,
)
async def cancel_session_task(
    session_id: str,
    user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelTaskResponse:
    """Cancel the active streaming task for a session.

    Publishes a cancel event to the executor via RabbitMQ FANOUT, then
    polls Redis until the task status flips from ``running`` or a timeout
    (5 s) is reached. Returns only after the cancellation is confirmed.
    """
    # NOTE: the docstring above is surfaced verbatim as the OpenAPI description
    # of this route, so it is intentionally left untouched here. Be aware that
    # when the timeout elapses the handler still returns ``cancelled=True`` with
    # ``reason="cancel_published_not_confirmed"`` -- i.e. the cancel was
    # published but the stop was *not* observed.

    # Presumably rejects unknown sessions / sessions of other users; the
    # returned session object itself is not needed here.
    await _validate_and_get_session(session_id, user_id)

    active_task, _ = await stream_registry.get_active_task_for_session(
        session_id, user_id
    )
    if not active_task:
        return CancelTaskResponse(cancelled=False, reason="no_active_task")

    task_id = active_task.task_id
    await enqueue_cancel_task(task_id)
    logger.info(
        f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
        f"session ...{session_id[-8:]}"
    )

    # Poll until the executor confirms the task is no longer running.
    # Keep max_wait below typical reverse-proxy read timeouts.
    poll_interval = 0.5
    max_wait = 5.0

    # Measure against the event loop's monotonic clock rather than summing the
    # sleep intervals: the registry lookup itself takes time, and counting only
    # the sleeps would let the real wall-clock wait drift past ``max_wait``.
    loop = asyncio.get_running_loop()
    started_at = loop.time()
    deadline = started_at + max_wait

    while (remaining := deadline - loop.time()) > 0:
        # Never sleep past the deadline on the final iteration.
        await asyncio.sleep(min(poll_interval, remaining))
        task = await stream_registry.get_task(task_id)
        # A missing task entry is treated the same as a stopped task.
        if task is None or task.status != "running":
            waited = loop.time() - started_at
            logger.info(
                f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
                f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
            )
            return CancelTaskResponse(cancelled=True, task_id=task_id)

    logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
    return CancelTaskResponse(
        cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
    )
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
|
||||
@@ -205,3 +205,20 @@ async def enqueue_copilot_task(
|
||||
message=entry.model_dump_json(),
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
)
|
||||
|
||||
|
||||
async def enqueue_cancel_task(task_id: str) -> None:
    """Broadcast a cancellation request for a CoPilot task.

    A ``CancelCoPilotEvent`` carrying ``task_id`` is serialized to JSON and
    published on ``COPILOT_CANCEL_EXCHANGE`` (a FANOUT exchange), so that
    every executor pod sees the signal.

    Args:
        task_id: Identifier of the task that should be cancelled.
    """
    # Imported at call time, as in the rest of this module's publish helpers
    # (presumably to avoid an import cycle -- confirm before hoisting).
    from backend.util.clients import get_async_copilot_queue

    cancel_payload = CancelCoPilotEvent(task_id=task_id).model_dump_json()
    copilot_queue = await get_async_copilot_queue()
    await copilot_queue.publish_message(
        routing_key="",  # FANOUT ignores routing key
        message=cancel_payload,
        exchange=COPILOT_CANCEL_EXCHANGE,
    )
|
||||
|
||||
@@ -7,7 +7,6 @@ import os
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field as dataclass_field
|
||||
from typing import Any
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -52,6 +51,7 @@ from .tool_adapter import (
|
||||
from .transcript import (
|
||||
cleanup_cli_project_dir,
|
||||
download_transcript,
|
||||
read_transcript_file,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -66,20 +66,14 @@ _background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
@dataclass
|
||||
class CapturedTranscript:
|
||||
"""Transcript built from raw SDK output for stateless --resume.
|
||||
|
||||
The CLI does not write JSONL files in SDK mode, so we capture the raw
|
||||
JSON messages from the CLI's stdout and build the transcript ourselves.
|
||||
"""
|
||||
|
||||
raw_entries: list[str] = dataclass_field(default_factory=list)
|
||||
"""Raw JSON lines captured from the SDK output (non-ephemeral only)."""
|
||||
"""Info captured by the SDK Stop hook for stateless --resume."""
|
||||
|
||||
path: str = ""
|
||||
sdk_session_id: str = ""
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
return bool(self.raw_entries)
|
||||
return bool(self.path)
|
||||
|
||||
|
||||
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
||||
@@ -512,7 +506,6 @@ async def stream_chat_completion_sdk(
|
||||
# even if _make_sdk_cwd raises (in that case it stays as "").
|
||||
sdk_cwd = ""
|
||||
use_resume = False
|
||||
current_message = message or ""
|
||||
|
||||
try:
|
||||
# Use a session-specific temp dir to avoid cleanup race conditions
|
||||
@@ -541,14 +534,13 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
# --- Transcript capture from SDK output ---
|
||||
# The CLI does not write JSONL files in SDK mode. Instead
|
||||
# we capture the raw JSON from the CLI stdout and build
|
||||
# the transcript for --resume ourselves.
|
||||
# --- Transcript capture via Stop hook ---
|
||||
captured_transcript = CapturedTranscript()
|
||||
|
||||
def _on_stop(_transcript_path: str, sdk_session_id: str) -> None:
|
||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
||||
captured_transcript.path = transcript_path
|
||||
captured_transcript.sdk_session_id = sdk_session_id
|
||||
logger.debug(f"[SDK] Stop hook: path={transcript_path!r}")
|
||||
|
||||
security_hooks = create_security_hooks(
|
||||
user_id,
|
||||
@@ -561,7 +553,6 @@ async def stream_chat_completion_sdk(
|
||||
resume_file: str | None = None
|
||||
use_resume = False
|
||||
transcript_msg_count = 0 # watermark: session.messages length at upload
|
||||
downloaded_transcript_content: str | None = None
|
||||
|
||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
dl = await download_transcript(user_id, session_id)
|
||||
@@ -572,7 +563,6 @@ async def stream_chat_completion_sdk(
|
||||
if resume_file:
|
||||
use_resume = True
|
||||
transcript_msg_count = dl.message_count
|
||||
downloaded_transcript_content = dl.content
|
||||
logger.debug(
|
||||
f"[SDK] Using --resume ({len(dl.content)}B, "
|
||||
f"msg_count={transcript_msg_count})"
|
||||
@@ -691,47 +681,17 @@ async def stream_chat_completion_sdk(
|
||||
# asyncio.timeout() is preferred over asyncio.wait_for()
|
||||
# because wait_for wraps in a separate Task whose cancellation
|
||||
# can leave the async generator in a broken state.
|
||||
#
|
||||
# TECH DEBT: We use two private SDK internals here:
|
||||
# 1. client._query.receive_messages() — raw dict iterator
|
||||
# 2. _internal.message_parser.parse_message — dict→Message
|
||||
# This is necessary because the public receive_messages()
|
||||
# only yields parsed Messages, but we need the raw dicts
|
||||
# for transcript capture (CLI doesn't write JSONL in SDK
|
||||
# mode) and per-message timeout for heartbeats.
|
||||
# Pin claude-agent-sdk tightly and audit on version bumps.
|
||||
from claude_agent_sdk import AssistantMessage, ResultMessage
|
||||
from claude_agent_sdk._internal.message_parser import (
|
||||
parse_message as _parse_sdk_msg,
|
||||
)
|
||||
|
||||
# NOTE: _query is a private SDK attribute; see tech-debt
|
||||
# comment on the import above.
|
||||
if client._query is None:
|
||||
raise RuntimeError(
|
||||
"SDK client query not initialized — connect() may have failed"
|
||||
)
|
||||
msg_iter = client._query.receive_messages().__aiter__()
|
||||
msg_iter = client.receive_messages().__aiter__()
|
||||
while not stream_completed:
|
||||
try:
|
||||
async with asyncio.timeout(_HEARTBEAT_INTERVAL):
|
||||
raw_data = await msg_iter.__anext__()
|
||||
sdk_msg = await msg_iter.__anext__()
|
||||
except TimeoutError:
|
||||
yield StreamHeartbeat()
|
||||
continue
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
# Capture non-ephemeral entries for transcript.
|
||||
# stream_event = streaming deltas (redundant with
|
||||
# final assistant message).
|
||||
msg_type = raw_data.get("type", "")
|
||||
if msg_type != "stream_event":
|
||||
captured_transcript.raw_entries.append(
|
||||
json.dumps(raw_data, separators=(",", ":"))
|
||||
)
|
||||
|
||||
sdk_msg = _parse_sdk_msg(raw_data)
|
||||
logger.debug(
|
||||
"[SDK] [%s] Received: %s %s",
|
||||
session_id[:12],
|
||||
@@ -745,6 +705,8 @@ async def stream_chat_completion_sdk(
|
||||
# awaits an asyncio.Event signaled by stash_pending_tool_output(),
|
||||
# completing as soon as the hook finishes (typically <1ms).
|
||||
# The sleep(0) after lets any remaining concurrent hooks complete.
|
||||
from claude_agent_sdk import AssistantMessage, ResultMessage
|
||||
|
||||
if adapter.has_unresolved_tool_calls and isinstance(
|
||||
sdk_msg, (AssistantMessage, ResultMessage)
|
||||
):
|
||||
@@ -842,24 +804,33 @@ async def stream_chat_completion_sdk(
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# The CLI does not write JSONL files in SDK mode. Instead
|
||||
# we build the transcript from the raw JSON we captured
|
||||
# during the streaming loop above.
|
||||
# After the `async with` block the SDK task group has exited, so the Stop
|
||||
# hook has already fired and the CLI has been SIGTERMed. The
|
||||
# CLI uses appendFileSync, so all writes are safely on disk.
|
||||
if config.claude_agent_use_resume and user_id:
|
||||
raw_transcript = _build_transcript(
|
||||
captured_entries=captured_transcript.raw_entries,
|
||||
user_message=current_message,
|
||||
session_id=session_id,
|
||||
previous_transcript=downloaded_transcript_content,
|
||||
)
|
||||
# With --resume the CLI appends to the resume file (most
|
||||
# complete). Otherwise use the Stop hook path.
|
||||
if use_resume and resume_file:
|
||||
raw_transcript = read_transcript_file(resume_file)
|
||||
logger.debug("[SDK] Transcript source: resume file")
|
||||
elif captured_transcript.path:
|
||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
||||
logger.debug(
|
||||
"[SDK] Transcript source: stop hook (%s), " "read result: %s",
|
||||
captured_transcript.path,
|
||||
f"{len(raw_transcript)}B" if raw_transcript else "None",
|
||||
)
|
||||
else:
|
||||
raw_transcript = None
|
||||
|
||||
if not raw_transcript:
|
||||
logger.debug(
|
||||
"[SDK] No usable transcript — CLI file had no "
|
||||
"conversation entries (expected for first turn "
|
||||
"without --resume)"
|
||||
)
|
||||
|
||||
if raw_transcript:
|
||||
logger.info(
|
||||
"[SDK] [%s] Uploading transcript (%dB, %d new entries)",
|
||||
session_id[:12],
|
||||
len(raw_transcript),
|
||||
len(captured_transcript.raw_entries),
|
||||
)
|
||||
# Shield the upload from generator cancellation so a
|
||||
# client disconnect / page refresh doesn't lose the
|
||||
# transcript. The upload must finish even if the SSE
|
||||
@@ -872,12 +843,6 @@ async def stream_chat_completion_sdk(
|
||||
message_count=len(session.messages),
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[SDK] No transcript to upload for %s " "(%d captured entries)",
|
||||
session_id,
|
||||
len(captured_transcript.raw_entries),
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
@@ -911,67 +876,6 @@ async def stream_chat_completion_sdk(
|
||||
_cleanup_sdk_tool_results(sdk_cwd)
|
||||
|
||||
|
||||
def _build_transcript(
    captured_entries: list[str],
    user_message: str,
    session_id: str,
    previous_transcript: str | None = None,
) -> str | None:
    """Assemble a JSONL transcript for ``--resume`` from captured CLI output.

    The result is, in order: the previous transcript (if any, with trailing
    newlines trimmed), one synthetic ``user`` entry for the current turn, and
    the captured raw JSON lines. Lines are joined with ``\\n`` and the text
    ends with a single newline.

    Args:
        captured_entries: Raw JSON lines captured from the CLI output.
        user_message: The user's message for this turn.
        session_id: Chat session identifier stored on the synthetic entry.
        previous_transcript: Earlier JSONL transcript content to prepend.

    Returns:
        The JSONL text, or ``None`` when nothing was captured or when
        ``validate_transcript`` rejects the assembled text.
    """
    # Nothing captured means there is no turn to record.
    if not captured_entries:
        return None

    # Older turns come first when a previous transcript is available.
    lines: list[str] = (
        [previous_transcript.rstrip("\n")] if previous_transcript else []
    )

    # The CLI does not echo user messages sent via stdin, so a stand-in user
    # entry is created here. The uuid/parentUuid fields are optional for
    # --resume.
    lines.append(
        json.dumps(
            {
                "type": "user",
                "uuid": str(uuid.uuid4()),
                "parentUuid": "",
                "session_id": session_id,
                "message": {"role": "user", "content": user_message},
            },
            separators=(",", ":"),
        )
    )

    # Finally the raw CLI output entries (system init, assistant, result, ...).
    lines += captured_entries

    transcript = "\n".join(lines) + "\n"
    if validate_transcript(transcript):
        return transcript

    logger.warning(
        "[SDK] Built transcript not valid (%d entries, %dB)",
        len(captured_entries),
        len(transcript),
    )
    return None
|
||||
|
||||
|
||||
async def _try_upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
postV2CancelSessionTask,
|
||||
useDeleteV2DeleteSession,
|
||||
useGetV2ListSessions,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
@@ -109,7 +110,7 @@ export function useCopilotPage() {
|
||||
const {
|
||||
messages,
|
||||
sendMessage,
|
||||
stop,
|
||||
stop: sdkStop,
|
||||
status,
|
||||
error,
|
||||
setMessages,
|
||||
@@ -122,6 +123,52 @@ export function useCopilotPage() {
|
||||
// call resumeStream() manually after hydration + active_stream detection.
|
||||
});
|
||||
|
||||
// Wrap AI SDK's stop() to also cancel the backend executor task.
|
||||
// sdkStop() aborts the SSE fetch instantly (UI feedback), then we fire
|
||||
// the cancel API to actually stop the executor and wait for confirmation.
|
||||
async function stop() {
  // Abort the SSE fetch right away so the UI reacts immediately.
  sdkStop();

  // Tool parts that were still waiting on input would keep their spinners
  // once the stream is gone, so flip them to an error state.
  setMessages((previousMessages) =>
    previousMessages.map((message) => ({
      ...message,
      parts: message.parts.map((part) =>
        "state" in part &&
        (part.state === "input-streaming" || part.state === "input-available")
          ? {
              ...part,
              state: "output-error" as const,
              errorText: "Cancelled",
            }
          : part,
      ),
    })),
  );

  // Without a session there is no backend task to cancel.
  if (!sessionId) {
    return;
  }

  try {
    const response = await postV2CancelSessionTask(sessionId);

    // The backend published the cancel but did not observe the task stopping
    // before its own timeout elapsed.
    const publishedButUnconfirmed =
      response.status === 200 &&
      "reason" in response.data &&
      response.data.reason === "cancel_published_not_confirmed";

    if (publishedButUnconfirmed) {
      toast({
        title: "Stop may take a moment",
        description:
          "The cancel was sent but not yet confirmed. The task should stop shortly.",
      });
    }
  } catch {
    // The cancel request itself failed; warn that the executor may go on.
    toast({
      title: "Could not stop the task",
      description: "The task may still be running in the background.",
      variant: "destructive",
    });
  }
}
|
||||
|
||||
// Abort the stream if the backend doesn't start sending data within 12s.
|
||||
const stopRef = useRef(stop);
|
||||
stopRef.current = stop;
|
||||
|
||||
@@ -1263,6 +1263,44 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}/cancel": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Cancel Session Task",
|
||||
"description": "Cancel the active streaming task for a session.\n\nPublishes a cancel event to the executor via RabbitMQ FANOUT, then\npolls Redis until the task status flips from ``running`` or a timeout\n(5 s) is reached. Returns only after the cancellation is confirmed.",
|
||||
"operationId": "postV2CancelSessionTask",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/CancelTaskResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}/stream": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -7537,6 +7575,23 @@
|
||||
"required": ["file"],
|
||||
"title": "Body_postV2Upload submission media"
|
||||
},
|
||||
"CancelTaskResponse": {
|
||||
"properties": {
|
||||
"cancelled": { "type": "boolean", "title": "Cancelled" },
|
||||
"task_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Task Id"
|
||||
},
|
||||
"reason": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Reason"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["cancelled"],
|
||||
"title": "CancelTaskResponse",
|
||||
"description": "Response model for the cancel task endpoint."
|
||||
},
|
||||
"ChangelogEntry": {
|
||||
"properties": {
|
||||
"version": { "type": "string", "title": "Version" },
|
||||
|
||||
Reference in New Issue
Block a user