Merge branch 'dev' into fix/messed-up-copilot

Resolve service.py conflicts: take dev's file-based transcript approach
(CapturedTranscript.path + read_transcript_file) and public client API,
layer our fixes on top (wait_for_stash race-condition fix, session_id
logging, approach logging).
This commit is contained in:
Zamil Majdy
2026-02-20 09:54:07 +07:00
5 changed files with 216 additions and 134 deletions

View File

@@ -18,7 +18,7 @@ from backend.copilot.completion_handler import (
process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_copilot_task
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -132,6 +132,14 @@ class ListSessionsResponse(BaseModel):
total: int
class CancelTaskResponse(BaseModel):
"""Response model for the cancel task endpoint."""
# True once a cancel event was published for an active task (even when the
# stop could not be confirmed in time); False when there was no active task.
cancelled: bool
# ID of the task the cancel was published for; stays None when the session
# had no active task.
task_id: str | None = None
# Qualifier set by the cancel endpoint: "no_active_task" or
# "cancel_published_not_confirmed"; None when the stop was confirmed.
reason: str | None = None
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
@@ -314,6 +322,57 @@ async def get_session(
)
@router.post(
    "/sessions/{session_id}/cancel",
    status_code=200,
)
async def cancel_session_task(
    session_id: str,
    user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelTaskResponse:
    """Cancel the active streaming task for a session.

    Publishes a cancel event to the executor via RabbitMQ FANOUT, then
    polls Redis until the task status flips from ``running`` or a timeout
    (5 s) is reached. If the timeout elapses first, the response still
    reports ``cancelled=True`` but carries the reason
    ``cancel_published_not_confirmed``.
    """
    # NOTE: this docstring is surfaced as the endpoint description in the
    # generated OpenAPI schema, so parameter/return details live in comments.
    #
    # Outcomes:
    #   * no active task      -> cancelled=False, reason="no_active_task"
    #   * stop confirmed      -> cancelled=True,  task_id set
    #   * timeout before stop -> cancelled=True,  task_id set,
    #                            reason="cancel_published_not_confirmed"

    # Result is unused; called for its validation effect only.
    # NOTE(review): presumably raises when the session is missing or not
    # owned by ``user_id`` - confirm against the helper.
    await _validate_and_get_session(session_id, user_id)

    active_task, _ = await stream_registry.get_active_task_for_session(
        session_id, user_id
    )
    if not active_task:
        return CancelTaskResponse(cancelled=False, reason="no_active_task")

    task_id = active_task.task_id
    await enqueue_cancel_task(task_id)
    logger.info(
        f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
        f"session ...{session_id[-8:]}"
    )

    # Poll until the executor confirms the task is no longer running.
    # Keep max_wait below typical reverse-proxy read timeouts. The budget is
    # tracked against the event loop's monotonic clock so that the time spent
    # in each registry lookup counts towards max_wait, not just the sleeps.
    poll_interval = 0.5
    max_wait = 5.0
    clock = asyncio.get_running_loop().time
    started = clock()
    deadline = started + max_wait
    while (remaining := deadline - clock()) > 0:
        # Never sleep past the deadline; the last poll lands right on it.
        await asyncio.sleep(min(poll_interval, remaining))
        task = await stream_registry.get_task(task_id)
        if task is None or task.status != "running":
            waited = clock() - started
            logger.info(
                f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
                f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
            )
            return CancelTaskResponse(cancelled=True, task_id=task_id)

    logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
    return CancelTaskResponse(
        cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
    )
@router.post(
"/sessions/{session_id}/stream",
)

View File

@@ -205,3 +205,20 @@ async def enqueue_copilot_task(
message=entry.model_dump_json(),
exchange=COPILOT_EXECUTION_EXCHANGE,
)
async def enqueue_cancel_task(task_id: str) -> None:
    """Broadcast a cancellation request for the CoPilot task ``task_id``.

    A ``CancelCoPilotEvent`` is serialized to JSON and published on
    ``COPILOT_CANCEL_EXCHANGE``. That exchange is FANOUT, so the routing
    key is left empty and every bound executor consumer gets the signal.
    """
    # Resolved at call time, inside the function, rather than at module import.
    from backend.util.clients import get_async_copilot_queue

    cancel_event = CancelCoPilotEvent(task_id=task_id)
    publisher = await get_async_copilot_queue()
    payload = cancel_event.model_dump_json()
    await publisher.publish_message(
        exchange=COPILOT_CANCEL_EXCHANGE,
        routing_key="",  # FANOUT ignores routing key
        message=payload,
    )

View File

@@ -7,7 +7,6 @@ import os
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from typing import Any
from backend.util.exceptions import NotFoundError
@@ -52,6 +51,7 @@ from .tool_adapter import (
from .transcript import (
cleanup_cli_project_dir,
download_transcript,
read_transcript_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -66,20 +66,14 @@ _background_tasks: set[asyncio.Task[Any]] = set()
@dataclass
class CapturedTranscript:
"""Transcript built from raw SDK output for stateless --resume.
The CLI does not write JSONL files in SDK mode, so we capture the raw
JSON messages from the CLI's stdout and build the transcript ourselves.
"""
raw_entries: list[str] = dataclass_field(default_factory=list)
"""Raw JSON lines captured from the SDK output (non-ephemeral only)."""
"""Info captured by the SDK Stop hook for stateless --resume."""
path: str = ""
sdk_session_id: str = ""
@property
def available(self) -> bool:
return bool(self.raw_entries)
return bool(self.path)
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
@@ -512,7 +506,6 @@ async def stream_chat_completion_sdk(
# even if _make_sdk_cwd raises (in that case it stays as "").
sdk_cwd = ""
use_resume = False
current_message = message or ""
try:
# Use a session-specific temp dir to avoid cleanup race conditions
@@ -541,14 +534,13 @@ async def stream_chat_completion_sdk(
sdk_model = _resolve_sdk_model()
# --- Transcript capture from SDK output ---
# The CLI does not write JSONL files in SDK mode. Instead
# we capture the raw JSON from the CLI stdout and build
# the transcript for --resume ourselves.
# --- Transcript capture via Stop hook ---
captured_transcript = CapturedTranscript()
def _on_stop(_transcript_path: str, sdk_session_id: str) -> None:
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
captured_transcript.path = transcript_path
captured_transcript.sdk_session_id = sdk_session_id
logger.debug(f"[SDK] Stop hook: path={transcript_path!r}")
security_hooks = create_security_hooks(
user_id,
@@ -561,7 +553,6 @@ async def stream_chat_completion_sdk(
resume_file: str | None = None
use_resume = False
transcript_msg_count = 0 # watermark: session.messages length at upload
downloaded_transcript_content: str | None = None
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
dl = await download_transcript(user_id, session_id)
@@ -572,7 +563,6 @@ async def stream_chat_completion_sdk(
if resume_file:
use_resume = True
transcript_msg_count = dl.message_count
downloaded_transcript_content = dl.content
logger.debug(
f"[SDK] Using --resume ({len(dl.content)}B, "
f"msg_count={transcript_msg_count})"
@@ -691,47 +681,17 @@ async def stream_chat_completion_sdk(
# asyncio.timeout() is preferred over asyncio.wait_for()
# because wait_for wraps in a separate Task whose cancellation
# can leave the async generator in a broken state.
#
# TECH DEBT: We use two private SDK internals here:
# 1. client._query.receive_messages() — raw dict iterator
# 2. _internal.message_parser.parse_message — dict→Message
# This is necessary because the public receive_messages()
# only yields parsed Messages, but we need the raw dicts
# for transcript capture (CLI doesn't write JSONL in SDK
# mode) and per-message timeout for heartbeats.
# Pin claude-agent-sdk tightly and audit on version bumps.
from claude_agent_sdk import AssistantMessage, ResultMessage
from claude_agent_sdk._internal.message_parser import (
parse_message as _parse_sdk_msg,
)
# NOTE: _query is a private SDK attribute; see tech-debt
# comment on the import above.
if client._query is None:
raise RuntimeError(
"SDK client query not initialized — connect() may have failed"
)
msg_iter = client._query.receive_messages().__aiter__()
msg_iter = client.receive_messages().__aiter__()
while not stream_completed:
try:
async with asyncio.timeout(_HEARTBEAT_INTERVAL):
raw_data = await msg_iter.__anext__()
sdk_msg = await msg_iter.__anext__()
except TimeoutError:
yield StreamHeartbeat()
continue
except StopAsyncIteration:
break
# Capture non-ephemeral entries for transcript.
# stream_event = streaming deltas (redundant with
# final assistant message).
msg_type = raw_data.get("type", "")
if msg_type != "stream_event":
captured_transcript.raw_entries.append(
json.dumps(raw_data, separators=(",", ":"))
)
sdk_msg = _parse_sdk_msg(raw_data)
logger.debug(
"[SDK] [%s] Received: %s %s",
session_id[:12],
@@ -745,6 +705,8 @@ async def stream_chat_completion_sdk(
# awaits an asyncio.Event signaled by stash_pending_tool_output(),
# completing as soon as the hook finishes (typically <1ms).
# The sleep(0) after lets any remaining concurrent hooks complete.
from claude_agent_sdk import AssistantMessage, ResultMessage
if adapter.has_unresolved_tool_calls and isinstance(
sdk_msg, (AssistantMessage, ResultMessage)
):
@@ -842,24 +804,33 @@ async def stream_chat_completion_sdk(
session.messages.append(assistant_response)
# --- Upload transcript for next-turn --resume ---
# The CLI does not write JSONL files in SDK mode. Instead
# we build the transcript from the raw JSON we captured
# during the streaming loop above.
# After async with the SDK task group has exited, so the Stop
# hook has already fired and the CLI has been SIGTERMed. The
# CLI uses appendFileSync, so all writes are safely on disk.
if config.claude_agent_use_resume and user_id:
raw_transcript = _build_transcript(
captured_entries=captured_transcript.raw_entries,
user_message=current_message,
session_id=session_id,
previous_transcript=downloaded_transcript_content,
)
# With --resume the CLI appends to the resume file (most
# complete). Otherwise use the Stop hook path.
if use_resume and resume_file:
raw_transcript = read_transcript_file(resume_file)
logger.debug("[SDK] Transcript source: resume file")
elif captured_transcript.path:
raw_transcript = read_transcript_file(captured_transcript.path)
logger.debug(
"[SDK] Transcript source: stop hook (%s), " "read result: %s",
captured_transcript.path,
f"{len(raw_transcript)}B" if raw_transcript else "None",
)
else:
raw_transcript = None
if not raw_transcript:
logger.debug(
"[SDK] No usable transcript — CLI file had no "
"conversation entries (expected for first turn "
"without --resume)"
)
if raw_transcript:
logger.info(
"[SDK] [%s] Uploading transcript (%dB, %d new entries)",
session_id[:12],
len(raw_transcript),
len(captured_transcript.raw_entries),
)
# Shield the upload from generator cancellation so a
# client disconnect / page refresh doesn't lose the
# transcript. The upload must finish even if the SSE
@@ -872,12 +843,6 @@ async def stream_chat_completion_sdk(
message_count=len(session.messages),
)
)
else:
logger.warning(
"[SDK] No transcript to upload for %s " "(%d captured entries)",
session_id,
len(captured_transcript.raw_entries),
)
except ImportError:
raise RuntimeError(
@@ -911,67 +876,6 @@ async def stream_chat_completion_sdk(
_cleanup_sdk_tool_results(sdk_cwd)
def _build_transcript(
    captured_entries: list[str],
    user_message: str,
    session_id: str,
    previous_transcript: str | None = None,
) -> str | None:
    """Assemble a JSONL transcript for ``--resume`` from captured SDK output.

    The result is, in order: the previous transcript (if any, with trailing
    newlines trimmed), one synthetic ``user`` entry for this turn, and the
    raw entries captured from the CLI, joined with newlines and terminated
    by a single trailing newline.

    Args:
        captured_entries: Raw JSON lines captured from the CLI output.
        user_message: Text of the user's message for this turn.
        session_id: Chat session identifier stored on the synthetic entry.
        previous_transcript: JSONL content of earlier turns, if available.

    Returns:
        The JSONL string, or ``None`` when nothing was captured or the
        assembled text fails ``validate_transcript``.
    """
    if not captured_entries:
        return None

    # User messages are not part of the captured output, so a user entry is
    # synthesized here with a fresh uuid and an empty parentUuid.
    synthetic_user_line = json.dumps(
        {
            "type": "user",
            "uuid": str(uuid.uuid4()),
            "parentUuid": "",
            "session_id": session_id,
            "message": {"role": "user", "content": user_message},
        },
        separators=(",", ":"),
    )

    earlier_turns = [previous_transcript.rstrip("\n")] if previous_transcript else []
    transcript = (
        "\n".join([*earlier_turns, synthetic_user_line, *captured_entries]) + "\n"
    )

    if validate_transcript(transcript):
        return transcript

    logger.warning(
        "[SDK] Built transcript not valid (%d entries, %dB)",
        len(captured_entries),
        len(transcript),
    )
    return None
async def _try_upload_transcript(
user_id: str,
session_id: str,

View File

@@ -1,5 +1,6 @@
import {
getGetV2ListSessionsQueryKey,
postV2CancelSessionTask,
useDeleteV2DeleteSession,
useGetV2ListSessions,
} from "@/app/api/__generated__/endpoints/chat/chat";
@@ -109,7 +110,7 @@ export function useCopilotPage() {
const {
messages,
sendMessage,
stop,
stop: sdkStop,
status,
error,
setMessages,
@@ -122,6 +123,52 @@ export function useCopilotPage() {
// call resumeStream() manually after hydration + active_stream detection.
});
// Wrap AI SDK's stop() to also cancel the backend executor task.
// sdkStop() aborts the SSE fetch instantly (UI feedback), then we fire
// the cancel API to actually stop the executor and wait for confirmation.
async function stop() {
  // Abort the local SSE fetch first so the UI reacts immediately.
  sdkStop();

  // Any tool part still in an input state is flipped to an error state so
  // its spinner stops.
  setMessages((current) =>
    current.map((message) => {
      const parts = message.parts.map((part) => {
        if (!("state" in part)) return part;
        if (
          part.state !== "input-streaming" &&
          part.state !== "input-available"
        ) {
          return part;
        }
        return {
          ...part,
          state: "output-error" as const,
          errorText: "Cancelled",
        };
      });
      return { ...message, parts };
    }),
  );

  // Without a session id there is no cancel endpoint to call.
  if (!sessionId) return;

  try {
    const res = await postV2CancelSessionTask(sessionId);
    const unconfirmed =
      res.status === 200 &&
      "reason" in res.data &&
      res.data.reason === "cancel_published_not_confirmed";
    if (!unconfirmed) return;

    // The backend published the cancel but timed out waiting for the
    // executor to acknowledge it.
    toast({
      title: "Stop may take a moment",
      description:
        "The cancel was sent but not yet confirmed. The task should stop shortly.",
    });
  } catch {
    toast({
      title: "Could not stop the task",
      description: "The task may still be running in the background.",
      variant: "destructive",
    });
  }
}
// Abort the stream if the backend doesn't start sending data within 12s.
const stopRef = useRef(stop);
stopRef.current = stop;

View File

@@ -1263,6 +1263,44 @@
}
}
},
"/api/chat/sessions/{session_id}/cancel": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Cancel Session Task",
"description": "Cancel the active streaming task for a session.\n\nPublishes a cancel event to the executor via RabbitMQ FANOUT, then\npolls Redis until the task status flips from ``running`` or a timeout\n(5 s) is reached. If the timeout elapses first, the response still\nreports ``cancelled=True`` but carries the reason\n``cancel_published_not_confirmed``.",
"operationId": "postV2CancelSessionTask",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Session Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/CancelTaskResponse" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/sessions/{session_id}/stream": {
"get": {
"tags": ["v2", "chat", "chat"],
@@ -7537,6 +7575,23 @@
"required": ["file"],
"title": "Body_postV2Upload submission media"
},
"CancelTaskResponse": {
"properties": {
"cancelled": { "type": "boolean", "title": "Cancelled" },
"task_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Task Id"
},
"reason": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Reason"
}
},
"type": "object",
"required": ["cancelled"],
"title": "CancelTaskResponse",
"description": "Response model for the cancel task endpoint."
},
"ChangelogEntry": {
"properties": {
"version": { "type": "string", "title": "Version" },