From 749a78723a2d1adb96652e31be0f876a0cc3639f Mon Sep 17 00:00:00 2001
From: Zamil Majdy
Date: Thu, 12 Feb 2026 19:26:29 +0400
Subject: [PATCH] refactor(chat/sdk): deduplicate code and remove anthropic fallback
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Extract shared `make_session_path()` into sandbox.py (single source of truth for workspace path sanitization), replace duplicate in service.py
- Delete anthropic_fallback.py (~360 lines) — redundant third code path; routes.py already falls back to non-SDK service
- Remove dead `traced_session()`, `get_tool_definitions()`, `get_tool_handlers()`, `_current_tool_call_id` ContextVar
- Fix hardcoded model in tracing — pass actual resolved model
- Fix inconsistent model name splitting in anthropic fallback
---
 .../features/chat/sdk/anthropic_fallback.py   | 356 ------------------
 .../backend/api/features/chat/sdk/service.py  |  42 +--
 .../api/features/chat/sdk/tool_adapter.py     |  45 +--
 .../backend/api/features/chat/sdk/tracing.py  |  29 +-
 .../api/features/chat/tools/sandbox.py        |  35 +-
 5 files changed, 47 insertions(+), 460 deletions(-)
 delete mode 100644 autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py

diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py b/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py
deleted file mode 100644
index 3dbe22bbb7..0000000000
--- a/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py
+++ /dev/null
@@ -1,356 +0,0 @@
-"""Anthropic SDK fallback implementation.
-
-This module provides the fallback streaming implementation using the Anthropic SDK
-directly when the Claude Agent SDK is not available.
-"""
-
-import json
-import logging
-import uuid
-from collections.abc import AsyncGenerator
-from typing import Any, cast
-
-import anthropic
-
-from ..config import ChatConfig
-from ..model import ChatMessage, ChatSession
-from ..response_model import (
-    StreamBaseResponse,
-    StreamError,
-    StreamFinish,
-    StreamTextDelta,
-    StreamTextEnd,
-    StreamTextStart,
-    StreamToolInputAvailable,
-    StreamToolInputStart,
-    StreamToolOutputAvailable,
-)
-from .tool_adapter import get_tool_definitions, get_tool_handlers
-
-logger = logging.getLogger(__name__)
-config = ChatConfig()
-
-# Maximum tool-call iterations before stopping to prevent infinite loops
-_MAX_TOOL_ITERATIONS = 10
-
-
-async def stream_with_anthropic(
-    session: ChatSession,
-    system_prompt: str,
-    text_block_id: str,
-) -> AsyncGenerator[StreamBaseResponse, None]:
-    """Stream using Anthropic SDK directly with tool calling support.
-
-    This function accumulates messages into the session for persistence.
-    The caller should NOT yield an additional StreamFinish - this function handles it.
-    """
-    # Use config.api_key (CHAT_API_KEY > OPEN_ROUTER_API_KEY > OPENAI_API_KEY)
-    # with config.base_url for OpenRouter routing — matching the non-SDK path.
-    api_key = config.api_key
-    if not api_key:
-        yield StreamError(
-            errorText="No API key configured (set CHAT_API_KEY or OPENAI_API_KEY)",
-            code="config_error",
-        )
-        yield StreamFinish()
-        return
-
-    # Build kwargs for the Anthropic client — use base_url if configured
-    client_kwargs: dict[str, Any] = {"api_key": api_key}
-    if config.base_url:
-        # Strip /v1 suffix — Anthropic SDK adds its own version path
-        base = config.base_url.rstrip("/")
-        if base.endswith("/v1"):
-            base = base[:-3]
-        client_kwargs["base_url"] = base
-
-    client = anthropic.AsyncAnthropic(**client_kwargs)
-    tool_definitions = get_tool_definitions()
-    tool_handlers = get_tool_handlers()
-
-    anthropic_tools = [
-        {
-            "name": t["name"],
-            "description": t["description"],
-            "input_schema": t["inputSchema"],
-        }
-        for t in tool_definitions
-    ]
-
-    anthropic_messages = _convert_session_to_anthropic(session)
-
-    if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
-        anthropic_messages.append(
-            {"role": "user", "content": "Continue with the task."}
-        )
-
-    has_started_text = False
-    accumulated_text = ""
-    accumulated_tool_calls: list[dict[str, Any]] = []
-
-    for _ in range(_MAX_TOOL_ITERATIONS):
-        try:
-            async with client.messages.stream(
-                model=(
-                    config.model.split("/")[-1] if "/" in config.model else config.model
-                ),
-                max_tokens=4096,
-                system=system_prompt,
-                messages=cast(Any, anthropic_messages),
-                tools=cast(Any, anthropic_tools) if anthropic_tools else [],
-            ) as stream:
-                async for event in stream:
-                    if event.type == "content_block_start":
-                        block = event.content_block
-                        if hasattr(block, "type"):
-                            if block.type == "text" and not has_started_text:
-                                yield StreamTextStart(id=text_block_id)
-                                has_started_text = True
-                            elif block.type == "tool_use":
-                                yield StreamToolInputStart(
-                                    toolCallId=block.id, toolName=block.name
-                                )
-
-                    elif event.type == "content_block_delta":
-                        delta = event.delta
-                        if hasattr(delta, "type") and delta.type == "text_delta":
-                            accumulated_text += delta.text
-                            yield StreamTextDelta(id=text_block_id, delta=delta.text)
-
-                final_message = await stream.get_final_message()
-
-                if final_message.stop_reason == "tool_use":
-                    if has_started_text:
-                        yield StreamTextEnd(id=text_block_id)
-                        has_started_text = False
-                        text_block_id = str(uuid.uuid4())
-
-                    tool_results = []
-                    assistant_content: list[dict[str, Any]] = []
-
-                    for block in final_message.content:
-                        if block.type == "text":
-                            assistant_content.append(
-                                {"type": "text", "text": block.text}
-                            )
-                        elif block.type == "tool_use":
-                            assistant_content.append(
-                                {
-                                    "type": "tool_use",
-                                    "id": block.id,
-                                    "name": block.name,
-                                    "input": block.input,
-                                }
-                            )
-
-                            # Track tool call for session persistence
-                            accumulated_tool_calls.append(
-                                {
-                                    "id": block.id,
-                                    "type": "function",
-                                    "function": {
-                                        "name": block.name,
-                                        "arguments": json.dumps(
-                                            block.input
-                                            if isinstance(block.input, dict)
-                                            else {}
-                                        ),
-                                    },
-                                }
-                            )
-
-                            yield StreamToolInputAvailable(
-                                toolCallId=block.id,
-                                toolName=block.name,
-                                input=(
-                                    block.input if isinstance(block.input, dict) else {}
-                                ),
-                            )
-
-                            output, is_error = await _execute_tool(
-                                block.name, block.input, tool_handlers
-                            )
-
-                            yield StreamToolOutputAvailable(
-                                toolCallId=block.id,
-                                toolName=block.name,
-                                output=output,
-                                success=not is_error,
-                            )
-
-                            # Save tool result to session
-                            session.messages.append(
-                                ChatMessage(
-                                    role="tool",
-                                    content=output,
-                                    tool_call_id=block.id,
-                                )
-                            )
-
-                            tool_results.append(
-                                {
-                                    "type": "tool_result",
-                                    "tool_use_id": block.id,
-                                    "content": output,
"is_error": is_error, - } - ) - - # Save assistant message with tool calls to session - session.messages.append( - ChatMessage( - role="assistant", - content=accumulated_text or None, - tool_calls=( - accumulated_tool_calls - if accumulated_tool_calls - else None - ), - ) - ) - # Reset for next iteration - accumulated_text = "" - accumulated_tool_calls = [] - - anthropic_messages.append( - {"role": "assistant", "content": assistant_content} - ) - anthropic_messages.append({"role": "user", "content": tool_results}) - continue - - else: - if has_started_text: - yield StreamTextEnd(id=text_block_id) - - # Save final assistant response to session - if accumulated_text: - session.messages.append( - ChatMessage(role="assistant", content=accumulated_text) - ) - - yield StreamFinish() - return - - except Exception as e: - logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True) - yield StreamError( - errorText="An error occurred. Please try again.", - code="anthropic_error", - ) - yield StreamFinish() - return - - yield StreamError(errorText="Max tool iterations reached", code="max_iterations") - yield StreamFinish() - - -def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]: - """Convert session messages to Anthropic format. - - Handles merging consecutive same-role messages (Anthropic requires alternating roles). - """ - messages: list[dict[str, Any]] = [] - - for msg in session.messages: - if msg.role == "user": - new_msg = {"role": "user", "content": msg.content or ""} - elif msg.role == "assistant": - content: list[dict[str, Any]] = [] - if msg.content: - content.append({"type": "text", "text": msg.content}) - if msg.tool_calls: - for tc in msg.tool_calls: - func = tc.get("function", {}) - args = func.get("arguments", {}) - if isinstance(args, str): - try: - args = json.loads(args) - except json.JSONDecodeError: - args = {} - content.append( - { - "type": "tool_use", - "id": tc.get("id", str(uuid.uuid4())), - "name": func.get("name", ""), - "input": args, - } - ) - if content: - new_msg = {"role": "assistant", "content": content} - else: - continue # Skip empty assistant messages - elif msg.role == "tool": - new_msg = { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": msg.tool_call_id or "", - "content": msg.content or "", - } - ], - } - else: - continue - - messages.append(new_msg) - - # Merge consecutive same-role messages (Anthropic requires alternating roles) - return _merge_consecutive_roles(messages) - - -def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Merge consecutive messages with the same role. - - Anthropic API requires alternating user/assistant roles. 
- """ - if not messages: - return [] - - merged: list[dict[str, Any]] = [] - for msg in messages: - if merged and merged[-1]["role"] == msg["role"]: - # Merge with previous message - prev_content = merged[-1]["content"] - new_content = msg["content"] - - # Normalize both to list-of-blocks form - if isinstance(prev_content, str): - prev_content = [{"type": "text", "text": prev_content}] - if isinstance(new_content, str): - new_content = [{"type": "text", "text": new_content}] - - # Ensure both are lists - if not isinstance(prev_content, list): - prev_content = [prev_content] - if not isinstance(new_content, list): - new_content = [new_content] - - merged[-1]["content"] = prev_content + new_content - else: - merged.append(msg) - - return merged - - -async def _execute_tool( - tool_name: str, tool_input: Any, handlers: dict[str, Any] -) -> tuple[str, bool]: - """Execute a tool and return (output, is_error).""" - handler = handlers.get(tool_name) - if not handler: - return f"Unknown tool: {tool_name}", True - - try: - result = await handler(tool_input) - # Safely extract output - handle empty or missing content - content = result.get("content") or [] - if content and isinstance(content, list) and len(content) > 0: - first_item = content[0] - output = first_item.get("text", "") if isinstance(first_item, dict) else "" - else: - output = "" - is_error = result.get("isError", False) - return output, is_error - except Exception as e: - return f"Error: {str(e)}", True diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py index 0ab376cbef..a3580d2e64 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py @@ -4,7 +4,6 @@ import asyncio import json import logging import os -import re import uuid from collections.abc import AsyncGenerator from typing import Any @@ -29,8 +28,8 @@ from ..response_model import ( StreamToolOutputAvailable, ) from ..service import _build_system_prompt, _generate_session_title +from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path from ..tracking import track_user_message -from .anthropic_fallback import stream_with_anthropic from .response_adapter import SDKResponseAdapter from .security_hooks import create_security_hooks from .tool_adapter import ( @@ -47,7 +46,7 @@ config = ChatConfig() _background_tasks: set[asyncio.Task[Any]] = set() -_SDK_CWD_PREFIX = "/tmp/copilot-" +_SDK_CWD_PREFIX = WORKSPACE_PREFIX # Appended to the system prompt to inform the agent about Bash restrictions. # The SDK already describes each tool (Read, Write, Edit, Glob, Grep, Bash), @@ -109,24 +108,12 @@ def _build_sdk_env() -> dict[str, str]: def _make_sdk_cwd(session_id: str) -> str: """Create a safe, session-specific working directory path. - Sanitizes session_id, then validates the resulting path stays under /tmp/ - using normpath + startswith (the pattern CodeQL recognises as a sanitizer). + Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path` + (single source of truth for path sanitization) and adds a defence-in-depth + assertion. 
""" - # Step 1: Sanitize - only allow alphanumeric and hyphens - safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id) - if not safe_id: - raise ValueError("Session ID is empty after sanitization") - - # Step 2: Construct path with known-safe prefix - cwd = os.path.normpath(f"{_SDK_CWD_PREFIX}{safe_id}") - - # Step 3: Validate the path is still under our prefix (prevent traversal) - if not cwd.startswith(_SDK_CWD_PREFIX): - raise ValueError(f"Session path escaped prefix: {cwd}") - - # Step 4: Additional assertion for defense-in-depth + cwd = make_session_path(session_id) assert cwd.startswith("/tmp/copilot-"), f"Path validation failed: {cwd}" - return cwd @@ -340,7 +327,6 @@ async def stream_chat_completion_sdk( ) system_prompt += _SDK_TOOL_SUPPLEMENT message_id = str(uuid.uuid4()) - text_block_id = str(uuid.uuid4()) task_id = str(uuid.uuid4()) yield StreamStart(messageId=message_id, taskId=task_id) @@ -351,7 +337,7 @@ async def stream_chat_completion_sdk( sdk_cwd = _make_sdk_cwd(session_id) os.makedirs(sdk_cwd, exist_ok=True) - set_execution_context(user_id, session, None) + set_execution_context(user_id, session) try: try: @@ -371,7 +357,7 @@ async def stream_chat_completion_sdk( sdk_model = _resolve_sdk_model() # Initialize Langfuse tracing (no-op if not configured) - tracer = TracedSession(session_id, user_id, system_prompt) + tracer = TracedSession(session_id, user_id, system_prompt, model=sdk_model) # Merge security hooks with optional tracing hooks security_hooks = create_security_hooks( @@ -510,15 +496,11 @@ async def stream_chat_completion_sdk( session.messages.append(assistant_response) except ImportError: - logger.warning( - "[SDK] claude-agent-sdk not available, using Anthropic fallback" + raise RuntimeError( + "claude-agent-sdk is not installed. " + "Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) " + "to use the OpenAI-compatible fallback." ) - async for response in stream_with_anthropic( - session, system_prompt, text_block_id - ): - if isinstance(response, StreamFinish): - stream_completed = True - yield response await upsert_chat_session(session) logger.debug( diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py index 3e78d11b6f..44fa63fd07 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py @@ -29,10 +29,6 @@ _current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default _current_session: ContextVar[ChatSession | None] = ContextVar( "current_session", default=None ) -_current_tool_call_id: ContextVar[str | None] = ContextVar( - "current_tool_call_id", default=None -) - # Stash for MCP tool outputs before the SDK potentially truncates them. # Keyed by tool_name → full output string. Consumed (popped) by the # response adapter when it builds StreamToolOutputAvailable. @@ -44,7 +40,6 @@ _pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar( def set_execution_context( user_id: str | None, session: ChatSession, - tool_call_id: str | None = None, ) -> None: """Set the execution context for tool calls. 
@@ -53,16 +48,14 @@
     """
     _current_user_id.set(user_id)
     _current_session.set(session)
-    _current_tool_call_id.set(tool_call_id)
     _pending_tool_outputs.set({})
 
 
-def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
+def get_execution_context() -> tuple[str | None, ChatSession | None]:
     """Get the current execution context."""
     return (
         _current_user_id.get(),
         _current_session.get(),
-        _current_tool_call_id.get(),
     )
 
 
@@ -91,7 +84,7 @@ def create_tool_handler(base_tool: BaseTool):
 
     async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
         """Execute the wrapped tool and return MCP-formatted response."""
-        user_id, session, tool_call_id = get_execution_context()
+        user_id, session = get_execution_context()
 
         if session is None:
             return {
@@ -112,7 +105,7 @@ def create_tool_handler(base_tool: BaseTool):
         try:
             # Call the existing tool's execute method
             # Generate unique tool_call_id per invocation for proper correlation
-            effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
+            effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
             result = await base_tool.execute(
                 user_id=user_id,
                 session=session,
@@ -168,38 +161,6 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
     }
 
 
-def get_tool_definitions() -> list[dict[str, Any]]:
-    """Get all tool definitions in MCP format.
-
-    Returns a list of tool definitions that can be used with
-    create_sdk_mcp_server or as raw tool definitions.
-    """
-    tool_definitions = []
-
-    for tool_name, base_tool in TOOL_REGISTRY.items():
-        tool_def = {
-            "name": tool_name,
-            "description": base_tool.description,
-            "inputSchema": _build_input_schema(base_tool),
-        }
-        tool_definitions.append(tool_def)
-
-    return tool_definitions
-
-
-def get_tool_handlers() -> dict[str, Any]:
-    """Get all tool handlers mapped by name.
-
-    Returns a dictionary mapping tool names to their handler functions.
-    """
-    handlers = {}
-
-    for tool_name, base_tool in TOOL_REGISTRY.items():
-        handlers[tool_name] = create_tool_handler(base_tool)
-
-    return handlers
-
-
 async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
     """Read a file with optional offset/limit. Restricted to SDK working directory.
diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/tracing.py b/autogpt_platform/backend/backend/api/features/chat/sdk/tracing.py
index 4c453a787d..7052ea73a3 100644
--- a/autogpt_platform/backend/backend/api/features/chat/sdk/tracing.py
+++ b/autogpt_platform/backend/backend/api/features/chat/sdk/tracing.py
@@ -4,20 +4,18 @@ This module provides modular, non-invasive observability for SDK sessions.
 All tracing is opt-in (only active when Langfuse credentials are configured)
 and designed to not affect the core execution flow.
 
-Usage:
+Usage::
+
     async with TracedSession(session_id, user_id) as tracer:
-        # Your SDK code here
         tracer.log_user_message(message)
         async for sdk_msg in client.receive_messages():
             tracer.log_sdk_message(sdk_msg)
-        tracer.log_result(result_message)
 """
 
 from __future__ import annotations
 
 import logging
 import time
-from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any
@@ -77,10 +75,12 @@ class TracedSession:
         session_id: str,
         user_id: str | None = None,
         system_prompt: str | None = None,
+        model: str | None = None,
     ):
         self.session_id = session_id
         self.user_id = user_id
         self.system_prompt = system_prompt
+        self.model = model
         self.enabled = _is_langfuse_configured()
 
         # Internal state
@@ -265,7 +265,7 @@ class TracedSession:
         if usage or result.total_cost_usd:
             self._trace.generation(
                 name="claude-sdk-completion",
-                model="claude-sonnet-4-20250514",  # SDK default model
+                model=self.model or "claude-sonnet-4-20250514",
                 usage=(
                     {
                         "input": usage.get("input_tokens", 0),
@@ -308,25 +308,6 @@ class TracedSession:
         return str(content) if content else ""
 
 
-@asynccontextmanager
-async def traced_session(
-    session_id: str,
-    user_id: str | None = None,
-    system_prompt: str | None = None,
-):
-    """Convenience async context manager for tracing SDK sessions.
-
-    Usage:
-        async with traced_session(session_id, user_id) as tracer:
-            tracer.log_user_message(message)
-            async for msg in client.receive_messages():
-                tracer.log_sdk_message(msg)
-    """
-    tracer = TracedSession(session_id, user_id, system_prompt)
-    async with tracer:
-        yield tracer
-
-
 def create_tracing_hooks(tracer: TracedSession) -> dict[str, Any]:
     """Create SDK hooks for fine-grained Langfuse tracing.
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py b/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py
index 095c296f41..9ac56eda20 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py
@@ -37,23 +37,42 @@ def has_network_sandbox() -> bool:
     return _UNSHARE_AVAILABLE
 
 
-_WORKSPACE_PREFIX = "/tmp/copilot-"
+WORKSPACE_PREFIX = "/tmp/copilot-"
 
 
-def get_workspace_dir(session_id: str) -> str:
-    """Get or create the workspace directory for a session.
+def make_session_path(session_id: str) -> str:
+    """Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`.
 
-    Uses the same path as the SDK's ``_make_sdk_cwd()`` so that
-    python_exec/bash_exec share the workspace with the SDK file tools.
+    Shared by both the SDK working-directory setup and the sandbox tools so
+    they always resolve to the same directory for a given session.
+
+    Steps:
+    1. Strip all characters except ``[A-Za-z0-9-]``.
+    2. Construct ``/tmp/copilot-<safe_id>``.
+    3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised
+       sanitizer) to prevent path traversal.
+
+    Raises:
+        ValueError: If the resulting path escapes the prefix.
""" import re safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id) if not safe_id: safe_id = "default" - workspace = os.path.normpath(f"{_WORKSPACE_PREFIX}{safe_id}") - if not workspace.startswith(_WORKSPACE_PREFIX): - raise ValueError(f"Session path escaped prefix: {workspace}") + path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}") + if not path.startswith(WORKSPACE_PREFIX): + raise ValueError(f"Session path escaped prefix: {path}") + return path + + +def get_workspace_dir(session_id: str) -> str: + """Get or create the workspace directory for a session. + + Uses :func:`make_session_path` — the same path the SDK uses — so that + python_exec / bash_exec share the workspace with the SDK file tools. + """ + workspace = make_session_path(session_id) os.makedirs(workspace, exist_ok=True) return workspace