diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py
index eda88f23fe..82b50e4166 100644
--- a/autogpt_platform/backend/backend/api/features/chat/config.py
+++ b/autogpt_platform/backend/backend/api/features/chat/config.py
@@ -97,6 +97,15 @@ class ChatConfig(BaseSettings):
         default=True,
         description="Use Claude Agent SDK for chat completions",
     )
+    sdk_model: str | None = Field(
+        default=None,
+        description="Model for SDK path. If None, derives from the `model` field "
+        "by stripping the OpenRouter provider prefix.",
+    )
+    sdk_max_budget_usd: float | None = Field(
+        default=None,
+        description="Max budget in USD per SDK session (None = unlimited)",
+    )
     sdk_max_buffer_size: int = Field(
         default=10 * 1024 * 1024,  # 10MB (default SDK is 1MB)
         description="Max buffer size in bytes for SDK JSON message parsing. "
diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py b/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py
index c0ef9f531d..34abea6768 100644
--- a/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py
+++ b/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py
@@ -6,7 +6,6 @@ directly when the Claude Agent SDK is not available.
 
 import json
 import logging
-import os
 import uuid
 from collections.abc import AsyncGenerator
 from typing import Any, cast
@@ -46,17 +45,27 @@
     """
     import anthropic
 
-    # Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
-    api_key = os.getenv("ANTHROPIC_API_KEY")
+    # Use config.api_key (CHAT_API_KEY > OPEN_ROUTER_API_KEY > OPENAI_API_KEY)
+    # with config.base_url for OpenRouter routing — matching the non-SDK path.
+    api_key = config.api_key
     if not api_key:
         yield StreamError(
-            errorText="ANTHROPIC_API_KEY not configured for fallback",
+            errorText="No API key configured (set CHAT_API_KEY or OPENAI_API_KEY)",
             code="config_error",
         )
         yield StreamFinish()
         return
 
-    client = anthropic.AsyncAnthropic(api_key=api_key)
+    # Build kwargs for the Anthropic client — use base_url if configured
+    client_kwargs: dict[str, Any] = {"api_key": api_key}
+    if config.base_url:
+        # Strip /v1 suffix — Anthropic SDK adds its own version path
+        base = config.base_url.rstrip("/")
+        if base.endswith("/v1"):
+            base = base[:-3]
+        client_kwargs["base_url"] = base
+
+    client = anthropic.AsyncAnthropic(**client_kwargs)
 
     tool_definitions = get_tool_definitions()
     tool_handlers = get_tool_handlers()
diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py
index f15c27565e..a6b7ea9d7b 100644
--- a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py
+++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py
@@ -33,6 +33,7 @@ from backend.api.features.chat.response_model import (
     StreamToolInputAvailable,
     StreamToolInputStart,
     StreamToolOutputAvailable,
+    StreamUsage,
 )
 from backend.api.features.chat.sdk.tool_adapter import (
     MCP_TOOL_PREFIX,
@@ -148,6 +149,19 @@
                 responses.append(StreamFinishStep())
                 self.step_open = False
 
+            # Emit token usage if the SDK reported it
+            usage = getattr(sdk_message, "usage", None) or {}
+            if usage:
+                input_tokens = usage.get("input_tokens", 0)
+                output_tokens = usage.get("output_tokens", 0)
+                responses.append(
+                    StreamUsage(
+                        promptTokens=input_tokens,
+                        completionTokens=output_tokens,
+                        totalTokens=input_tokens + output_tokens,
+                    )
+                )
+
             if sdk_message.subtype == "success":
                 responses.append(StreamFinish())
             elif sdk_message.subtype in ("error", "error_during_execution"):
diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py
index 19cb3ff99e..f65239ab47 100644
--- a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py
+++ b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py
@@ -15,6 +15,7 @@ from ..config import ChatConfig
 from ..model import (
     ChatMessage,
     ChatSession,
+    Usage,
     get_chat_session,
     update_session_title,
     upsert_chat_session,
@@ -27,6 +28,7 @@ from ..response_model import (
     StreamTextDelta,
     StreamToolInputAvailable,
     StreamToolOutputAvailable,
+    StreamUsage,
 )
 from ..service import _build_system_prompt, _generate_session_title
 from ..tracking import track_user_message
@@ -64,6 +66,41 @@ interpreters (python, node) are NOT available.
 """
 
 
+def _resolve_sdk_model() -> str | None:
+    """Resolve the model name for the SDK CLI.
+
+    Uses ``config.sdk_model`` if set, otherwise derives from ``config.model``
+    by stripping the OpenRouter provider prefix (e.g.,
+    ``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
+    """
+    if config.sdk_model:
+        return config.sdk_model
+    model = config.model
+    if "/" in model:
+        return model.split("/", 1)[1]
+    return model
+
+
+def _build_sdk_env() -> dict[str, str]:
+    """Build env vars for the SDK CLI process.
+
+    Routes API calls through OpenRouter (or a custom base_url) using
+    the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
+    This gives per-call token and cost tracking on the OpenRouter dashboard.
+    """
+    env: dict[str, str] = {}
+    if config.api_key and config.base_url:
+        # Strip /v1 suffix — SDK expects the base URL without a version path
+        base = config.base_url.rstrip("/")
+        if base.endswith("/v1"):
+            base = base[:-3]
+        env["ANTHROPIC_BASE_URL"] = base
+        env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
+        # Must be explicitly empty to prevent the CLI from using a local key
+        env["ANTHROPIC_API_KEY"] = ""
+    return env
+
+
 def _make_sdk_cwd(session_id: str) -> str:
     """Create a safe, session-specific working directory path.
 
@@ -317,8 +354,10 @@
 
     mcp_server = create_copilot_mcp_server()
 
+    sdk_model = _resolve_sdk_model()
+
     # Initialize Langfuse tracing (no-op if not configured)
-    tracer = TracedSession(session_id, user_id, system_prompt)
+    tracer = TracedSession(session_id, user_id, system_prompt, model=sdk_model)
 
     # Merge security hooks with optional tracing hooks
     security_hooks = create_security_hooks(user_id, sdk_cwd=sdk_cwd)
@@ -332,6 +371,10 @@
         hooks=combined_hooks,  # type: ignore[arg-type]
         cwd=sdk_cwd,
         max_buffer_size=config.sdk_max_buffer_size,
+        model=sdk_model,
+        env=_build_sdk_env(),
+        user=user_id or None,
+        max_budget_usd=config.sdk_max_budget_usd,
     )
 
     adapter = SDKResponseAdapter(message_id=message_id)
@@ -438,6 +481,15 @@
                     )
                     has_tool_results = True
 
+                elif isinstance(response, StreamUsage):
+                    session.usage.append(
+                        Usage(
+                            prompt_tokens=response.promptTokens,
+                            completion_tokens=response.completionTokens,
+                            total_tokens=response.totalTokens,
+                        )
+                    )
+
                 elif isinstance(response, StreamFinish):
                     stream_completed = True
 
diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/tracing.py b/autogpt_platform/backend/backend/api/features/chat/sdk/tracing.py
index 4c453a787d..e97fc4d8d5 100644
--- a/autogpt_platform/backend/backend/api/features/chat/sdk/tracing.py
+++ b/autogpt_platform/backend/backend/api/features/chat/sdk/tracing.py
@@ -77,10 +77,12 @@ class TracedSession:
         session_id: str,
         user_id: str | None = None,
         system_prompt: str | None = None,
+        model: str | None = None,
     ):
         self.session_id = session_id
         self.user_id = user_id
         self.system_prompt = system_prompt
+        self.model = model
         self.enabled = _is_langfuse_configured()
 
         # Internal state
@@ -265,7 +267,7 @@
         if usage or result.total_cost_usd:
             self._trace.generation(
                 name="claude-sdk-completion",
-                model="claude-sonnet-4-20250514",  # SDK default model
+                model=self.model or "claude-sonnet-4-20250514",
                 usage=(
                     {
                         "input": usage.get("input_tokens", 0),
@@ -313,6 +315,7 @@
 async def traced_session(
     session_id: str,
     user_id: str | None = None,
     system_prompt: str | None = None,
+    model: str | None = None,
 ):
     """Convenience async context manager for tracing SDK sessions.
@@ -322,7 +325,7 @@
 
         async for msg in client.receive_messages():
             tracer.log_sdk_message(msg)
     """
-    tracer = TracedSession(session_id, user_id, system_prompt)
+    tracer = TracedSession(session_id, user_id, system_prompt, model=model)
    async with tracer:
         yield tracer