fix(backend/copilot): address PR review — CancelledError propagation, error status, BackgroundToolStatus in codegen union

- _execute_tool_sync now catches asyncio.CancelledError and cancels the
  unregistered child task before re-raising. Prevents orphans when the
  handler is torn down before the per-tool timeout fires (child is not
  yet in the registry so cancel_all_background_tasks can't clean it up).
- check_background_tool now maps result.success=False to status='error'
  (not 'completed'), so an agent doesn't treat a failed finish as a win.
- BackgroundToolStatus moved to tools/models.py and added to the
  ToolResponseUnion in chat routes so frontend codegen picks it up.
- Tests: replace broad `except (CancelledError, BaseException)` catches
  with contextlib.suppress(asyncio.CancelledError) in the cleanup paths.
- New tests: handler cancellation propagates to child task; success=False
  result reports status='error'.
This commit is contained in:
Zamil Majdy
2026-04-18 06:51:37 +07:00
parent bca21e84e4
commit 453e90d0f4
6 changed files with 123 additions and 45 deletions

View File

@@ -50,6 +50,7 @@ from backend.copilot.tools.models import (
AgentPreviewResponse,
AgentSavedResponse,
AgentsFoundResponse,
BackgroundToolStatus,
BlockDetailsResponse,
BlockListResponse,
BlockOutputResponse,
@@ -1323,6 +1324,7 @@ ToolResponseUnion = (
| MemorySearchResponse
| MemoryForgetCandidatesResponse
| MemoryForgetConfirmResponse
| BackgroundToolStatus
)

View File

@@ -276,24 +276,34 @@ async def _execute_tool_sync(
)
timeout = base_tool.timeout_seconds
if timeout is None:
result = await task
else:
# asyncio.wait (unlike wait_for) does NOT cancel on timeout — the
# task keeps running in the background.
await asyncio.wait({task}, timeout=timeout)
try:
if timeout is None:
result = await task
else:
# asyncio.wait (unlike wait_for) does NOT cancel on timeout — the
# task keeps running in the background.
await asyncio.wait({task}, timeout=timeout)
if not task.done():
bg_id = _register_background_task(task, base_tool.name)
logger.warning(
"Tool %s exceeded %ss budget — parked as "
"background_id=%s (args=%s)",
base_tool.name,
timeout,
bg_id,
_redact_args_for_log(args),
)
return _tool_background_result(base_tool.name, timeout, bg_id)
# Completed within budget — .result() re-raises any exception.
result = task.result()
except asyncio.CancelledError:
# The handler itself was cancelled (e.g. stream teardown) mid-wait.
# Cancel the child so it doesn't keep running untracked — the
# registry hasn't seen it yet, so cancel_all_background_tasks
# couldn't clean it up.
if not task.done():
bg_id = _register_background_task(task, base_tool.name)
logger.warning(
"Tool %s exceeded %ss budget — parked as background_id=%s " "(args=%s)",
base_tool.name,
timeout,
bg_id,
_redact_args_for_log(args),
)
return _tool_background_result(base_tool.name, timeout, bg_id)
# Completed within budget — .result() re-raises any exception.
result = task.result()
task.cancel()
raise
text = (
result.output if isinstance(result.output, str) else json.dumps(result.output)

View File

@@ -466,6 +466,38 @@ class TestToolTimeout:
assert result["isError"] is False
assert "completed" in result["content"][0]["text"]
@pytest.mark.asyncio
async def test_handler_cancellation_cancels_child_task(self):
    """Cancelling the handler mid-wait must also cancel the child task.

    The child has not been parked in the background registry at that
    point, so the handler is the only thing that can cancel it.
    """
    import contextlib

    mock_tool = _make_mock_tool("slow_tool", timeout_seconds=60)
    child_started = asyncio.Event()
    child_cancelled = asyncio.Event()

    async def hang_until_cancelled(*_args, **_kwargs):
        # Signal that the tool body is running so the test can cancel the
        # handler at a known point instead of guessing with a fixed sleep.
        child_started.set()
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            child_cancelled.set()
            raise

    mock_tool.execute = AsyncMock(side_effect=hang_until_cancelled)

    from backend.copilot.sdk.tool_adapter import _execute_tool_sync

    outer_task = asyncio.create_task(
        _execute_tool_sync(mock_tool, "u", _make_mock_session(), {})
    )
    # Deterministic hand-off: the child can only be running once the handler
    # has yielded control. Bounded so a broken setup fails fast, not hangs.
    await asyncio.wait_for(child_started.wait(), timeout=1)

    outer_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await outer_task

    # Give the child a bounded window to observe the cancellation rather
    # than relying on exactly one event-loop tick.
    await asyncio.wait_for(child_cancelled.wait(), timeout=1)
    assert child_cancelled.is_set()
@pytest.mark.asyncio
async def test_fast_tool_within_timeout_succeeds(self):
"""Tools that complete well under the timeout are unaffected."""

View File

@@ -8,9 +8,7 @@ stays in control rather than the handler making an irreversible choice.
import asyncio
import logging
from typing import Any, Literal
from pydantic import Field
from typing import Any
from backend.copilot.model import ChatSession
from backend.copilot.sdk.background_registry import (
@@ -22,26 +20,11 @@ from backend.copilot.sdk.background_registry import (
)
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
from .models import BackgroundToolStatus, ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class BackgroundToolStatus(ToolResponseBase):
    """Status of a backgrounded tool call.

    Carries the task's current state together with the name of the tool
    that was backgrounded and the id under which it is tracked.
    """

    # Response discriminator; defaults to the MCP tool-output response type.
    type: ResponseType = ResponseType.MCP_TOOL_OUTPUT
    # Restricted to exactly these four states.
    status: Literal["completed", "still_running", "cancelled", "error"] = Field(
        description="Current state of the background task."
    )
    tool: str = Field(description="The name of the originally-backgrounded tool.")
    # Identifier of the background task this status refers to (required).
    background_id: str
    # Optional payload; stays None unless a caller supplies an output.
    output: Any | None = Field(
        default=None, description="Tool output when status=completed."
    )
    # Optional; presumably the number of seconds spent waiting on the task
    # before this status was produced — confirm against the tool's wait logic.
    waited_seconds: int | None = Field(default=None)
class CheckBackgroundToolTool(BaseTool):
"""Inspect, wait on, or cancel a backgrounded tool call."""
@@ -207,6 +190,17 @@ def _status_from_finished_task(
)
result = task.result()
# A tool can complete with success=False without raising — preserve
# that as status="error" so the agent doesn't treat it as a win.
if not result.success:
return BackgroundToolStatus(
message=f"'{tool_name}' completed with an error.",
session_id=session.session_id,
status="error",
tool=tool_name,
background_id=background_id,
output=result.output,
)
return BackgroundToolStatus(
message=f"'{tool_name}' completed.",
session_id=session.session_id,

View File

@@ -1,6 +1,7 @@
"""Tests for CheckBackgroundToolTool."""
import asyncio
import contextlib
from unittest.mock import MagicMock
import pytest
@@ -11,7 +12,8 @@ from backend.copilot.sdk.background_registry import (
register_background_task,
)
from .check_background_tool import BackgroundToolStatus, CheckBackgroundToolTool
from .check_background_tool import CheckBackgroundToolTool
from .models import BackgroundToolStatus
def _make_session() -> MagicMock:
@@ -78,10 +80,8 @@ class TestCheckBackgroundTool:
assert response.background_id == bg_id
task.cancel()
try:
with contextlib.suppress(asyncio.CancelledError):
await task
except (asyncio.CancelledError, BaseException):
pass
@pytest.mark.asyncio
async def test_wait_returns_completed_when_task_finishes(self):
@@ -124,15 +124,11 @@ class TestCheckBackgroundTool:
assert response.waited_seconds == 1
task.cancel()
try:
with contextlib.suppress(asyncio.CancelledError):
await task
except (asyncio.CancelledError, BaseException):
pass
@pytest.mark.asyncio
async def test_cancel_true_cancels_and_removes_from_registry(self):
import contextlib
observed_cancel = asyncio.Event()
async def stays_until_cancelled():
@@ -189,3 +185,31 @@ class TestCheckBackgroundTool:
assert isinstance(response, BackgroundToolStatus)
assert response.status == "error"
assert "boom" in response.message
@pytest.mark.asyncio
async def test_finished_task_with_success_false_reports_error(self):
    """A finished task whose result has success=False yields status='error'.

    The tool returns normally (no exception), but the reported status must
    not be 'completed', and the partial output must still be surfaced.
    """

    async def _fail_without_raising():
        failed_output = StreamToolOutputAvailable(
            toolCallId="tc-1",
            output="partial",
            toolName="broken_tool",
            success=False,
        )
        return failed_output

    # Run the coroutine to completion first so the registry only ever sees
    # an already-finished task.
    finished = asyncio.create_task(_fail_without_raising())
    await finished
    registered_id = register_background_task(finished, "broken_tool")

    status = await CheckBackgroundToolTool()._execute(
        user_id="u",
        session=_make_session(),
        background_id=registered_id,
    )

    assert isinstance(status, BackgroundToolStatus)
    assert (status.status, status.output) == ("error", "partial")

View File

@@ -259,6 +259,22 @@ class ErrorResponse(ToolResponseBase):
details: dict[str, Any] | None = None
class BackgroundToolStatus(ToolResponseBase):
    """Status of a backgrounded tool call, returned by ``check_background_tool``.

    Carries the task's current state together with the name of the tool
    that was backgrounded and the id under which it is tracked.
    """

    # Response discriminator; defaults to the MCP tool-output response type.
    type: ResponseType = ResponseType.MCP_TOOL_OUTPUT
    # Restricted to exactly these four states.
    status: Literal["completed", "still_running", "cancelled", "error"] = Field(
        description="Current state of the background task."
    )
    tool: str = Field(description="The name of the originally-backgrounded tool.")
    # Identifier of the background task this status refers to (required).
    background_id: str
    # Optional payload; None by default. Per its description it is filled
    # for both the completed and the error state.
    output: Any | None = Field(
        default=None,
        description="Tool output when status=completed or status=error.",
    )
    # Optional; presumably the number of seconds spent waiting on the task
    # before this status was produced — confirm against the tool's wait logic.
    waited_seconds: int | None = Field(default=None)
class InputValidationErrorResponse(ToolResponseBase):
"""Response when run_agent receives unknown input fields."""