Compare commits
6 Commits
test-scree
...
chore/sdk-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6dc0b6cffd | ||
|
|
a6e306d28a | ||
|
|
d6f0fcb052 | ||
|
|
feb247d56e | ||
|
|
fdb3590693 | ||
|
|
b319c26cab |
100
autogpt_platform/analytics/queries/platform_cost_log.sql
Normal file
@@ -0,0 +1,100 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.platform_cost_log
|
||||
-- Looker source alias: ds115 | Charts: 0
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per platform cost log entry (last 90 days).
|
||||
-- Tracks real API spend at the call level: provider, model,
|
||||
-- token counts (including Anthropic cache tokens), cost in
|
||||
-- microdollars, and the block/execution that incurred the cost.
|
||||
-- Joins the User table to provide email for per-user breakdowns.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.PlatformCostLog — Per-call cost records
|
||||
-- platform.User — User email
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Log entry UUID
|
||||
-- createdAt TIMESTAMPTZ When the cost was recorded
|
||||
-- userId TEXT User who incurred the cost (nullable)
|
||||
-- email TEXT User email (nullable)
|
||||
-- graphExecId TEXT Graph execution UUID (nullable)
|
||||
-- nodeExecId TEXT Node execution UUID (nullable)
|
||||
-- blockName TEXT Block that made the API call (nullable)
|
||||
-- provider TEXT API provider, lowercase (e.g. 'openai', 'anthropic')
|
||||
-- model TEXT Model name (nullable)
|
||||
-- trackingType TEXT Cost unit: 'tokens' | 'cost_usd' | 'characters' | etc.
|
||||
-- costMicrodollars BIGINT Cost in microdollars (divide by 1,000,000 for USD)
|
||||
-- costUsd FLOAT Cost in USD (costMicrodollars / 1,000,000)
|
||||
-- inputTokens INT Prompt/input tokens (nullable)
|
||||
-- outputTokens INT Completion/output tokens (nullable)
|
||||
-- cacheReadTokens INT Anthropic cache-read tokens billed at 10% (nullable)
|
||||
-- cacheCreationTokens INT Anthropic cache-write tokens billed at 125% (nullable)
|
||||
-- totalTokens INT inputTokens + outputTokens (NULL if either is NULL)
|
||||
-- duration FLOAT API call duration in seconds (nullable)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days ("createdAt" > CURRENT_DATE - INTERVAL '90 days', i.e. since midnight 90 days ago)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Total spend by provider (last 90 days)
|
||||
-- SELECT provider, SUM("costUsd") AS total_usd, COUNT(*) AS calls
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- GROUP BY 1 ORDER BY total_usd DESC;
|
||||
--
|
||||
-- -- Spend by model
|
||||
-- SELECT provider, model, SUM("costUsd") AS total_usd,
|
||||
-- SUM("inputTokens") AS input_tokens,
|
||||
-- SUM("outputTokens") AS output_tokens
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- WHERE model IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_usd DESC;
|
||||
--
|
||||
-- -- Top 20 users by spend
|
||||
-- SELECT "userId", email, SUM("costUsd") AS total_usd, COUNT(*) AS calls
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- WHERE "userId" IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_usd DESC LIMIT 20;
|
||||
--
|
||||
-- -- Daily spend trend
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
|
||||
-- SUM("costUsd") AS daily_usd,
|
||||
-- COUNT(*) AS calls
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
--
|
||||
-- -- Cache hit rate for Anthropic (cache reads vs total reads)
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
|
||||
-- SUM("cacheReadTokens")::float /
|
||||
-- NULLIF(SUM("inputTokens" + COALESCE("cacheReadTokens", 0)), 0) AS cache_hit_rate
|
||||
-- FROM analytics.platform_cost_log
|
||||
-- WHERE provider = 'anthropic'
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
-- NOTE: camelCase output aliases are double-quoted on purpose. PostgreSQL folds
-- unquoted identifiers to lower case, so an unquoted `AS costUsd` would be
-- exposed as `costusd`, and references such as "costUsd" / "createdAt" (used in
-- the header's OUTPUT COLUMNS and EXAMPLE QUERIES) would fail to resolve.
SELECT
    p."id"                                   AS "id",
    p."createdAt"                            AS "createdAt",
    p."userId"                               AS "userId",
    u."email"                                AS "email",
    p."graphExecId"                          AS "graphExecId",
    p."nodeExecId"                           AS "nodeExecId",
    p."blockName"                            AS "blockName",
    p."provider"                             AS "provider",
    p."model"                                AS "model",
    p."trackingType"                         AS "trackingType",
    p."costMicrodollars"                     AS "costMicrodollars",
    -- 1 USD = 1,000,000 microdollars
    p."costMicrodollars"::float / 1000000.0  AS "costUsd",
    p."inputTokens"                          AS "inputTokens",
    p."outputTokens"                         AS "outputTokens",
    p."cacheReadTokens"                      AS "cacheReadTokens",
    p."cacheCreationTokens"                  AS "cacheCreationTokens",
    -- `+` propagates NULL, so this is NULL whenever either count is NULL
    p."inputTokens" + p."outputTokens"       AS "totalTokens",
    p."duration"                             AS "duration"
FROM platform."PlatformCostLog" p
-- LEFT JOIN keeps cost rows whose userId is NULL or has no matching User row
LEFT JOIN platform."User" u ON u."id" = p."userId"
-- Rolling window: rows created after midnight 90 days before the current date
WHERE p."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated, Any, cast
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
@@ -29,12 +29,6 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.pending_messages import (
|
||||
MAX_PENDING_MESSAGES,
|
||||
PendingMessage,
|
||||
PendingMessageContext,
|
||||
push_pending_message,
|
||||
)
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
@@ -90,27 +84,6 @@ _UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
# Call-frequency cap for the pending-message endpoint. The token-budget
|
||||
# check in queue_pending_message guards against overspend, but does not
|
||||
# prevent rapid-fire pushes from a client with a large budget. This cap
|
||||
# (per user, per 60-second window) limits the rate a caller can hammer the
|
||||
# endpoint independently of token consumption.
|
||||
_PENDING_CALL_LIMIT = 30 # pushes per minute per user
|
||||
_PENDING_CALL_WINDOW_SECONDS = 60
|
||||
_PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:"
|
||||
|
||||
# Lua script for atomic INCR + conditional EXPIRE.
|
||||
# Using a single EVAL ensures the counter never persists without a TTL —
|
||||
# a bare INCR followed by a separate EXPIRE can leave the key without
|
||||
# an expiry if the process crashes between the two commands.
|
||||
_CALL_INCR_LUA = """
|
||||
local count = redis.call('INCR', KEYS[1])
|
||||
if count == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[1]))
|
||||
end
|
||||
return count
|
||||
"""
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
@@ -123,29 +96,6 @@ async def _validate_and_get_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _resolve_workspace_files(
    user_id: str,
    file_ids: list[str],
) -> list[UserWorkspaceFile]:
    """Look up the caller's own workspace files among *file_ids*.

    IDs that do not match ``_UUID_RE`` are discarded first. The survivors are
    then queried in one batch, restricted to the workspace returned by
    ``get_or_create_workspace(user_id)`` and to non-deleted records, so IDs
    belonging to another workspace never come back.

    Returns:
        The matching ``UserWorkspaceFile`` records; an empty list when no ID
        is UUID-shaped (in which case no workspace or DB lookup is made).
    """
    uuid_shaped = list(filter(_UUID_RE.match, file_ids))
    if len(uuid_shaped) == 0:
        return []

    owner_workspace = await get_or_create_workspace(user_id)
    matches = await UserWorkspaceFile.prisma().find_many(
        where={
            "id": {"in": uuid_shaped},
            "workspaceId": owner_workspace.id,
            "isDeleted": False,
        }
    )
    return matches
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
@@ -169,64 +119,6 @@ class StreamChatRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class QueuePendingMessageRequest(BaseModel):
    """Payload for pushing a message into a session's pending buffer.

    This does **not** open a new turn (contrast ``StreamChatRequest``): the
    message is appended to a per-session pending buffer that the executor
    currently processing the turn will drain between tool rounds.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    message: str = Field(min_length=1, max_length=16_000)
    context: dict[str, str] | None = Field(
        default=None,
        description="Optional page context: expected keys are 'url' and 'content'.",
    )
    file_ids: list[str] | None = Field(default=None, max_length=20)

    @field_validator("context")
    @classmethod
    def _validate_context_length(
        cls, v: dict[str, str] | None
    ) -> dict[str, str] | None:
        """Reject oversized ``url`` / ``content`` values in the page context.

        The caps (2,000 characters for ``url``, 32,000 for ``content``) keep a
        large page payload from stuffing the LLM context window. A missing key
        is treated as an empty string; ``None`` passes through untouched.
        """
        if v is not None:
            # Checked in this order so an oversized url is reported first.
            for key, limit in (("url", 2_000), ("content", 32_000)):
                if len(v.get(key, "")) > limit:
                    raise ValueError(
                        f"context.{key} exceeds maximum length of {limit} characters"
                    )
        return v
|
||||
|
||||
|
||||
class QueuePendingMessageResponse(BaseModel):
    """Body returned by the pending-message endpoint.

    Fields:
        buffer_length: number of messages in the session's pending buffer
            once this push has been applied.
        max_buffer_length: server-side constant capping the buffer per
            session.
        turn_in_flight: whether a copilot turn was running at the moment of
            the check. Informational only (UX feedback) — a ``False`` value
            does not mean the message was dropped; it stays queued and the
            next turn drains it.
    """

    buffer_length: int
    max_buffer_length: int
    turn_in_flight: bool
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""Request model for creating a new chat session.
|
||||
|
||||
@@ -894,21 +786,33 @@ async def stream_chat_post(
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if request.file_ids:
|
||||
files = await _resolve_workspace_files(user_id, request.file_ids)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
request.message += files_block
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
@@ -1108,129 +1012,6 @@ async def stream_chat_post(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/sessions/{session_id}/messages/pending",
    response_model=QueuePendingMessageResponse,
    status_code=202,
)
async def queue_pending_message(
    session_id: str,
    request: QueuePendingMessageRequest,
    user_id: str = Security(auth.get_user_id),
):
    """Queue a new user message into an in-flight copilot turn.

    When a user sends a follow-up message while a turn is still
    streaming, we don't want to block them or start a separate turn —
    this endpoint appends the message to a per-session pending buffer.
    The executor currently running the turn (baseline path) drains the
    buffer between tool-call rounds and appends the message to the
    conversation before the next LLM call. On the SDK path the buffer
    is drained at the *start* of the next turn (the long-lived
    ``ClaudeSDKClient.receive_response`` iterator returns after a
    ``ResultMessage`` so there is no safe point to inject mid-stream
    into an existing connection).

    Returns 202. Enforces the same per-user daily/weekly token rate
    limit as the regular ``/stream`` endpoint so a client can't bypass
    it by batching messages through here.

    Args:
        session_id: Session whose pending buffer receives the message.
        request: Validated message, optional page context and file IDs.
        user_id: Authenticated caller, injected by ``auth.get_user_id``.

    Returns:
        ``QueuePendingMessageResponse`` with the buffer length after the
        push, the buffer cap, and whether a turn was running.

    Raises:
        HTTPException: 429 when ``check_rate_limit`` raises
            ``RateLimitExceeded`` or when the per-user call-frequency cap
            is exceeded. Anything raised by ``_validate_and_get_session``
            propagates unchanged.
    """
    await _validate_and_get_session(session_id, user_id)

    # Pre-turn rate-limit check — mirrors stream_chat_post. Without
    # this, a client could bypass per-turn token limits by batching
    # their extra context through this endpoint while a cheap stream
    # is in flight.
    # user_id is guaranteed non-empty by Security(auth.get_user_id) — no guard needed.
    try:
        daily_limit, weekly_limit, _tier = await get_global_rate_limits(
            user_id, config.daily_token_limit, config.weekly_token_limit
        )
        await check_rate_limit(
            user_id=user_id,
            daily_token_limit=daily_limit,
            weekly_token_limit=weekly_limit,
        )
    except RateLimitExceeded as e:
        raise HTTPException(status_code=429, detail=str(e)) from e

    # Call-frequency cap: prevent rapid-fire pushes that would bypass the
    # token-budget check (which only fires per-turn, not per-push).
    # Uses an atomic Lua EVAL (INCR + EXPIRE) so the key can never be
    # orphaned without a TTL; fails open if Redis is down.
    try:
        _redis = await get_redis_async()
        _call_key = f"{_PENDING_CALL_KEY_PREFIX}{user_id}"
        _call_count = int(
            await cast(
                "Any",
                _redis.eval(
                    _CALL_INCR_LUA,
                    1,
                    _call_key,
                    str(_PENDING_CALL_WINDOW_SECONDS),
                ),
            )
        )
        if _call_count > _PENDING_CALL_LIMIT:
            raise HTTPException(
                status_code=429,
                detail=f"Too many pending messages: limit is {_PENDING_CALL_LIMIT} per {_PENDING_CALL_WINDOW_SECONDS}s",
            )
    except HTTPException:
        # The 429 raised just above must reach the client, not be swallowed.
        raise
    except Exception:
        # Still fail open (the request proceeds), but leave a trace: a
        # silently disabled rate limiter is otherwise invisible to operators.
        logger.warning(
            "queue_pending_message: call-frequency check failed; allowing "
            "request (session=%s)",
            session_id,
            exc_info=True,
        )

    track_user_message(
        user_id=user_id,
        session_id=session_id,
        message_length=len(request.message),
    )

    # Sanitise file IDs to the user's own workspace so injection doesn't
    # surface other users' files. _resolve_workspace_files handles UUID
    # filtering and the workspace-scoped DB lookup.
    sanitized_file_ids: list[str] = []
    if request.file_ids:
        valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.match(fid))
        files = await _resolve_workspace_files(user_id, request.file_ids)
        sanitized_file_ids = [wf.id for wf in files]
        if len(sanitized_file_ids) != valid_id_count:
            logger.warning(
                "queue_pending_message: dropped %d file id(s) not in "
                "caller's workspace (session=%s)",
                valid_id_count - len(sanitized_file_ids),
                session_id,
            )

    # Redis is the single source of truth for pending messages. We do
    # NOT persist to ``session.messages`` here — the drain-at-start
    # path in the baseline/SDK executor is the sole writer for pending
    # content. Persisting both here AND in the drain would cause
    # double injection (executor sees the message in ``session.messages``
    # *and* drains it from Redis) unless we also dedupe. The dedup in
    # ``maybe_append_user_message`` only checks trailing same-role
    # repeats, so relying on it is fragile. Keeping the endpoint
    # Redis-only avoids the whole consistency-bug class.
    pending = PendingMessage(
        content=request.message,
        file_ids=sanitized_file_ids,
        context=PendingMessageContext(**request.context) if request.context else None,
    )
    buffer_length = await push_pending_message(session_id, pending)

    # Check whether a turn is currently running for UX feedback.
    active_session = await stream_registry.get_session(session_id)
    turn_in_flight = bool(active_session and active_session.status == "running")

    return QueuePendingMessageResponse(
        buffer_length=buffer_length,
        max_buffer_length=MAX_PENDING_MESSAGES,
        turn_in_flight=turn_in_flight,
    )
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
|
||||
@@ -579,300 +579,3 @@ class TestStreamChatRequestModeValidation:
|
||||
|
||||
req = StreamChatRequest(message="hi")
|
||||
assert req.mode is None
|
||||
|
||||
|
||||
# ─── QueuePendingMessageRequest validation ────────────────────────────
|
||||
|
||||
|
||||
class TestQueuePendingMessageRequest:
    """Field-level validation checks for ``QueuePendingMessageRequest``."""

    @staticmethod
    def _model():
        # Imported lazily so the routes module is loaded at call time,
        # inside the test rather than at collection.
        from backend.api.features.chat.routes import QueuePendingMessageRequest

        return QueuePendingMessageRequest

    @staticmethod
    def _validation_error():
        import pydantic

        return pydantic.ValidationError

    def test_accepts_valid_message(self) -> None:
        built = self._model()(message="hello")
        assert built.message == "hello"

    def test_rejects_empty_message(self) -> None:
        expected_error = self._validation_error()
        model = self._model()

        with pytest.raises(expected_error):
            model(message="")

    def test_rejects_message_over_limit(self) -> None:
        expected_error = self._validation_error()
        model = self._model()

        # One character past the 16_000 cap.
        with pytest.raises(expected_error):
            model(message="x" * 16_001)

    def test_accepts_valid_context(self) -> None:
        page = {"url": "https://example.com", "content": "page text"}
        built = self._model()(message="hi", context=page)
        assert built.context is not None
        assert built.context["url"] == "https://example.com"

    def test_rejects_context_url_over_limit(self) -> None:
        expected_error = self._validation_error()
        model = self._model()

        with pytest.raises(expected_error, match="url"):
            model(
                message="hi",
                context={"url": "https://example.com/" + "x" * 2_000},
            )

    def test_rejects_context_content_over_limit(self) -> None:
        expected_error = self._validation_error()
        model = self._model()

        with pytest.raises(expected_error, match="content"):
            model(
                message="hi",
                context={"content": "x" * 32_001},
            )

    def test_rejects_extra_fields(self) -> None:
        """extra='forbid' should reject unknown fields."""
        expected_error = self._validation_error()
        model = self._model()

        with pytest.raises(expected_error):
            model(**{"message": "hi", "unknown_field": "bad"})

    def test_accepts_up_to_20_file_ids(self) -> None:
        ids = [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)]
        built = self._model()(message="hi", file_ids=ids)
        assert built.file_ids is not None
        assert len(built.file_ids) == 20

    def test_rejects_more_than_20_file_ids(self) -> None:
        expected_error = self._validation_error()
        model = self._model()

        with pytest.raises(expected_error):
            model(
                message="hi",
                file_ids=[f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
            )
|
||||
|
||||
|
||||
# ─── queue_pending_message endpoint ──────────────────────────────────
|
||||
|
||||
|
||||
def _mock_pending_internals(
    mocker: pytest_mock.MockerFixture,
    *,
    session_exists: bool = True,
    call_count: int = 1,
):
    """Patch every async collaborator the pending-message endpoint touches.

    Args:
        mocker: pytest-mock fixture used to install the patches.
        session_exists: when ``False``, session validation raises a 404
            ``HTTPException`` instead of returning a mock session.
        call_count: value the mocked Redis ``eval`` returns, i.e. the
            per-user call counter seen by the frequency cap.
    """
    routes = "backend.api.features.chat.routes"

    session_target = f"{routes}._validate_and_get_session"
    if not session_exists:
        mocker.patch(
            session_target,
            side_effect=fastapi.HTTPException(
                status_code=404, detail="Session not found."
            ),
        )
    else:
        fake_session = mocker.MagicMock()
        fake_session.id = "sess-1"
        mocker.patch(
            session_target,
            new_callable=AsyncMock,
            return_value=fake_session,
        )

    # Token rate limiting: limits resolve to (0, 0, None) and the check passes.
    mocker.patch(
        f"{routes}.get_global_rate_limits",
        new_callable=AsyncMock,
        return_value=(0, 0, None),
    )
    mocker.patch(
        f"{routes}.check_rate_limit",
        new_callable=AsyncMock,
        return_value=None,
    )

    # Redis stand-in for the per-user call-frequency limit (atomic Lua EVAL).
    fake_redis = mocker.MagicMock()
    fake_redis.eval = mocker.AsyncMock(return_value=call_count)
    mocker.patch(
        f"{routes}.get_redis_async",
        new_callable=AsyncMock,
        return_value=fake_redis,
    )

    mocker.patch(f"{routes}.track_user_message", return_value=None)
    mocker.patch(
        f"{routes}.push_pending_message",
        new_callable=AsyncMock,
        return_value=1,
    )

    # Registry reports no active session, so turn_in_flight comes back False.
    fake_registry = mocker.MagicMock()
    fake_registry.get_session = mocker.AsyncMock(return_value=None)
    mocker.patch(f"{routes}.stream_registry", fake_registry)
|
||||
|
||||
|
||||
def test_queue_pending_message_returns_202(mocker: pytest_mock.MockerFixture) -> None:
    """A well-formed push is accepted with 202 and reports the buffer state."""
    _mock_pending_internals(mocker)

    payload = {"message": "follow-up"}
    resp = client.post("/sessions/sess-1/messages/pending", json=payload)

    assert resp.status_code == 202
    body = resp.json()
    assert body["buffer_length"] == 1
    assert body["turn_in_flight"] is False
|
||||
|
||||
|
||||
def test_queue_pending_message_empty_body_returns_422() -> None:
    """An empty ``message`` string fails request validation with 422."""
    payload = {"message": ""}
    resp = client.post("/sessions/sess-1/messages/pending", json=payload)
    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_missing_message_returns_422() -> None:
    """A body without the required ``message`` field yields 422."""
    empty_body: dict = {}
    resp = client.post("/sessions/sess-1/messages/pending", json=empty_body)
    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_session_not_found_returns_404(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A 404 raised by session validation is returned to the caller."""
    _mock_pending_internals(mocker, session_exists=False)

    payload = {"message": "hi"}
    resp = client.post("/sessions/bad-sess/messages/pending", json=payload)

    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_queue_pending_message_rate_limited_returns_429(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """``RateLimitExceeded`` from the token check is translated into 429."""
    from backend.copilot.rate_limit import RateLimitExceeded

    _mock_pending_internals(mocker)

    # Override the permissive check installed by the helper.
    resets_at = datetime.now(UTC) + timedelta(hours=1)
    mocker.patch(
        "backend.api.features.chat.routes.check_rate_limit",
        side_effect=RateLimitExceeded("daily", resets_at),
    )

    payload = {"message": "hi"}
    resp = client.post("/sessions/sess-1/messages/pending", json=payload)

    assert resp.status_code == 429
|
||||
|
||||
|
||||
def test_queue_pending_message_call_frequency_limit_returns_429(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A call counter one above the per-user cap is rejected with 429."""
    from backend.api.features.chat.routes import _PENDING_CALL_LIMIT

    over_the_cap = _PENDING_CALL_LIMIT + 1
    _mock_pending_internals(mocker, call_count=over_the_cap)

    payload = {"message": "hi"}
    resp = client.post("/sessions/sess-1/messages/pending", json=payload)

    assert resp.status_code == 429
    assert "Too many pending messages" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_queue_pending_message_context_url_too_long_returns_422() -> None:
    """A ``context.url`` longer than the 2,000-character cap is rejected."""
    oversized_url = "https://example.com/" + "x" * 2_000
    payload = {"message": "hi", "context": {"url": oversized_url}}

    resp = client.post("/sessions/sess-1/messages/pending", json=payload)

    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_context_content_too_long_returns_422() -> None:
    """A ``context.content`` longer than the 32,000-character cap is rejected."""
    oversized_content = "x" * 32_001
    payload = {"message": "hi", "context": {"content": oversized_content}}

    resp = client.post("/sessions/sess-1/messages/pending", json=payload)

    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_too_many_file_ids_returns_422() -> None:
    """Sending 21 ``file_ids`` (cap is 20) is rejected with 422."""
    too_many = [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)]
    payload = {"message": "hi", "file_ids": too_many}

    resp = client.post("/sessions/sess-1/messages/pending", json=payload)

    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_file_ids_scoped_to_workspace(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """The file lookup is limited to UUID-shaped IDs in the caller's workspace."""
    _mock_pending_internals(mocker)

    # Minimal stand-in exposing only the ``id`` attribute the route reads.
    workspace_stub = type("W", (), {"id": "ws-1"})()
    mocker.patch(
        "backend.api.features.chat.routes.get_or_create_workspace",
        new_callable=AsyncMock,
        return_value=workspace_stub,
    )

    prisma_stub = mocker.MagicMock()
    prisma_stub.find_many = mocker.AsyncMock(return_value=[])
    mocker.patch(
        "prisma.models.UserWorkspaceFile.prisma",
        return_value=prisma_stub,
    )

    fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
    payload = {"message": "hi", "file_ids": [fid, "not-a-uuid"]}
    client.post("/sessions/sess-1/messages/pending", json=payload)

    # "not-a-uuid" must have been filtered out before the query was issued.
    where = prisma_stub.find_many.call_args.kwargs["where"]
    assert where["id"]["in"] == [fid]
    assert where["workspaceId"] == "ws-1"
    assert where["isDeleted"] is False
|
||||
|
||||
@@ -887,6 +887,21 @@ async def llm_call(
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
|
||||
# Transparent OpenRouter routing for Anthropic models: when an OpenRouter API key
|
||||
# is configured, route direct-Anthropic models through OpenRouter instead. This
|
||||
# gives us the x-total-cost header for free, so provider_cost is always populated
|
||||
# without manual token-rate arithmetic.
|
||||
or_key = settings.secrets.open_router_api_key
|
||||
or_model_id: str | None = None
|
||||
if provider == "anthropic" and or_key:
|
||||
provider = "open_router"
|
||||
credentials = APIKeyCredentials(
|
||||
provider=ProviderName.OPEN_ROUTER,
|
||||
title="OpenRouter (auto)",
|
||||
api_key=SecretStr(or_key),
|
||||
)
|
||||
or_model_id = f"anthropic/{llm_model.value}"
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
result = await compress_context(
|
||||
messages=prompt,
|
||||
@@ -1134,7 +1149,7 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=llm_model.value,
|
||||
model=or_model_id or llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
|
||||
@@ -77,7 +77,11 @@ class TestLLMStatsTracking:
|
||||
mock_response.usage = mock_usage
|
||||
mock_response.stop_reason = "end_turn"
|
||||
|
||||
with patch("anthropic.AsyncAnthropic") as mock_anthropic:
|
||||
with (
|
||||
patch("anthropic.AsyncAnthropic") as mock_anthropic,
|
||||
patch("backend.blocks.llm.settings") as mock_settings,
|
||||
):
|
||||
mock_settings.secrets.open_router_api_key = ""
|
||||
mock_client = AsyncMock()
|
||||
mock_anthropic.return_value = mock_client
|
||||
mock_client.messages.create = AsyncMock(return_value=mock_response)
|
||||
@@ -96,6 +100,56 @@ class TestLLMStatsTracking:
|
||||
assert response.cache_creation_tokens == 50
|
||||
assert response.response == "Test anthropic response"
|
||||
|
||||
@pytest.mark.asyncio
async def test_anthropic_routes_through_openrouter_when_key_present(self):
    """With an OpenRouter key configured, an Anthropic model call goes out
    through the OpenAI-compatible client using the ``anthropic/``-prefixed
    model id."""
    from pydantic import SecretStr

    import backend.blocks.llm as llm
    from backend.data.model import APIKeyCredentials

    # Chat-completions style response: one choice with text and no tool calls.
    choice = MagicMock()
    choice.message.content = "routed response"
    choice.message.tool_calls = None

    usage = MagicMock()
    usage.prompt_tokens = 10
    usage.completion_tokens = 5

    completion = MagicMock()
    completion.choices = [choice]
    completion.usage = usage

    create_mock = AsyncMock(return_value=completion)

    # The caller still supplies direct-Anthropic credentials.
    creds = APIKeyCredentials(
        id="test-anthropic-id",
        provider="anthropic",
        api_key=SecretStr("mock-anthropic-key"),
        title="Mock Anthropic key",
    )

    with (
        patch("openai.AsyncOpenAI") as openai_cls,
        patch("backend.blocks.llm.settings") as settings_mock,
    ):
        settings_mock.secrets.open_router_api_key = "sk-or-test-key"
        openai_client = MagicMock()
        openai_cls.return_value = openai_client
        openai_client.chat.completions.create = create_mock

        await llm.llm_call(
            credentials=creds,
            llm_model=llm.LlmModel.CLAUDE_3_HAIKU,
            prompt=[{"role": "user", "content": "Hello"}],
            max_tokens=100,
        )

    # Verify OpenAI client was used (not Anthropic SDK) and model was prefixed
    openai_cls.assert_called_once()
    sent = create_mock.call_args.kwargs
    assert sent["model"] == "anthropic/claude-3-haiku-20240307"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_structured_response_block_tracks_stats(self):
|
||||
"""Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
|
||||
|
||||
@@ -36,10 +36,6 @@ from backend.copilot.model import (
|
||||
maybe_append_user_message,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.pending_messages import (
|
||||
drain_pending_messages,
|
||||
format_pending_as_user_message,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement, get_graphiti_supplement
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
@@ -934,29 +930,6 @@ async def stream_chat_completion_baseline(
|
||||
message_length=len(message or ""),
|
||||
)
|
||||
|
||||
# Capture count *before* the pending drain so is_first_turn and the
|
||||
# transcript staleness check are not skewed by queued messages.
|
||||
_pre_drain_msg_count = len(session.messages)
|
||||
|
||||
# Drain any messages the user queued via POST /messages/pending
|
||||
# while this session was idle (or during a previous turn whose
|
||||
# mid-loop drains missed them). Atomic LPOP guarantees that a
|
||||
# concurrent push lands *after* the drain and stays queued for the
|
||||
# next turn instead of being lost.
|
||||
drained_at_start = await drain_pending_messages(session_id)
|
||||
if drained_at_start:
|
||||
logger.info(
|
||||
"[Baseline] Draining %d pending message(s) at turn start for session %s",
|
||||
len(drained_at_start),
|
||||
session_id,
|
||||
)
|
||||
for pm in drained_at_start:
|
||||
content = format_pending_as_user_message(pm)["content"]
|
||||
# Append directly — pending messages are atomically-popped from
|
||||
# Redis and are never stale-cache duplicates, so the
|
||||
# maybe_append_user_message dedup is wrong here.
|
||||
session.messages.append(ChatMessage(role="user", content=content))
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Select model based on the per-request mode. 'fast' downgrades to
|
||||
@@ -986,9 +959,7 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
# Use the pre-drain count so queued pending messages don't incorrectly
|
||||
# flip is_first_turn to False on an actual first turn.
|
||||
is_first_turn = _pre_drain_msg_count <= 1
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
# Gate context fetch on both first turn AND user message so that assistant-
|
||||
# role calls (e.g. tool-result submissions) on the first turn don't trigger
|
||||
# a needless DB lookup for user understanding.
|
||||
@@ -999,18 +970,14 @@ async def stream_chat_completion_baseline(
|
||||
prompt_task = _build_cacheable_system_prompt(None)
|
||||
|
||||
# Run download + prompt build concurrently — both are independent I/O
|
||||
# on the request critical path. Use the pre-drain count so pending
|
||||
# messages drained at turn start don't spuriously trigger a transcript
|
||||
# load on an actual first turn.
|
||||
if user_id and _pre_drain_msg_count > 1:
|
||||
# on the request critical path.
|
||||
if user_id and len(session.messages) > 1:
|
||||
transcript_covers_prefix, (base_system_prompt, understanding) = (
|
||||
await asyncio.gather(
|
||||
_load_prior_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
# Use pre-drain count so pending messages don't falsely
|
||||
# mark the stored transcript as stale and prevent upload.
|
||||
session_msg_count=_pre_drain_msg_count,
|
||||
session_msg_count=len(session.messages),
|
||||
transcript_builder=transcript_builder,
|
||||
),
|
||||
prompt_task,
|
||||
@@ -1022,16 +989,6 @@ async def stream_chat_completion_baseline(
|
||||
# Append user message to transcript after context injection below so the
|
||||
# transcript receives the prefixed message when user context is available.
|
||||
|
||||
# Mirror any messages drained at turn start (see above) into the
|
||||
# transcript — otherwise the loaded prior transcript would be
|
||||
# missing them and a mid-turn upload could leave a malformed
|
||||
# assistant-after-assistant structure on the next turn.
|
||||
if drained_at_start:
|
||||
for pm in drained_at_start:
|
||||
transcript_builder.append_user(
|
||||
content=format_pending_as_user_message(pm)["content"]
|
||||
)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
@@ -1052,10 +1009,8 @@ async def stream_chat_completion_baseline(
|
||||
graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else ""
|
||||
system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement
|
||||
|
||||
# Warm context: pre-load relevant facts from Graphiti on first turn.
|
||||
# Use the pre-drain count so pending messages drained at turn start
|
||||
# don't prevent warm context injection on an actual first turn.
|
||||
if graphiti_enabled and user_id and _pre_drain_msg_count <= 1:
|
||||
# Warm context: pre-load relevant facts from Graphiti on first turn
|
||||
if graphiti_enabled and user_id and len(session.messages) <= 1:
|
||||
from backend.copilot.graphiti.context import fetch_warm_context
|
||||
|
||||
warm_ctx = await fetch_warm_context(user_id, message or "")
|
||||
@@ -1248,64 +1203,6 @@ async def stream_chat_completion_baseline(
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
|
||||
# Inject any messages the user queued while the turn was
|
||||
# running. ``tool_call_loop`` mutates ``openai_messages``
|
||||
# in-place, so appending here means the model sees the new
|
||||
# messages on its next LLM call.
|
||||
#
|
||||
# IMPORTANT: skip when the loop has already finished (no
|
||||
# more LLM calls are coming). ``tool_call_loop`` yields
|
||||
# a final ``ToolCallLoopResult`` on both paths:
|
||||
# - natural finish: ``finished_naturally=True``
|
||||
# - hit max_iterations: ``finished_naturally=False``
|
||||
# and ``iterations >= max_iterations``
|
||||
# In either case the loop is about to return on the next
|
||||
# ``async for`` step, so draining here would silently
|
||||
# lose the message (the user sees 202 but the model never
|
||||
# reads the text). Those messages stay in the buffer and
|
||||
# get picked up at the start of the next turn.
|
||||
if loop_result is None:
|
||||
continue
|
||||
is_final_yield = (
|
||||
loop_result.finished_naturally
|
||||
or loop_result.iterations >= _MAX_TOOL_ROUNDS
|
||||
)
|
||||
if is_final_yield:
|
||||
continue
|
||||
pending = await drain_pending_messages(session_id)
|
||||
if pending:
|
||||
for pm in pending:
|
||||
# ``format_pending_as_user_message`` embeds file
|
||||
# attachments and context URL/page content into the
|
||||
# content string so the in-session transcript is
|
||||
# a faithful copy of what the model actually saw.
|
||||
formatted = format_pending_as_user_message(pm)
|
||||
content_for_db = formatted["content"]
|
||||
# Append directly — pending messages are atomically-popped
|
||||
# from Redis and are never stale-cache duplicates, so the
|
||||
# maybe_append_user_message dedup is wrong here and would
|
||||
# cause openai_messages/transcript to diverge from session.
|
||||
session.messages.append(
|
||||
ChatMessage(role="user", content=content_for_db)
|
||||
)
|
||||
openai_messages.append(formatted)
|
||||
transcript_builder.append_user(content=content_for_db)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as persist_err:
|
||||
logger.warning(
|
||||
"[Baseline] Failed to persist pending messages for "
|
||||
"session %s: %s",
|
||||
session_id,
|
||||
persist_err,
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] Injected %d pending message(s) into "
|
||||
"session %s mid-turn",
|
||||
len(pending),
|
||||
session_id,
|
||||
)
|
||||
|
||||
if loop_result and not loop_result.finished_naturally:
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
@@ -1346,11 +1243,6 @@ async def stream_chat_completion_baseline(
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Pending messages are drained atomically at turn start and
|
||||
# between tool rounds, so there's nothing to clear in finally.
|
||||
# Any message pushed after the final drain window stays in the
|
||||
# buffer and gets picked up at the start of the next turn.
|
||||
|
||||
# Set cost attributes on OTEL span before closing
|
||||
if _trace_ctx is not None:
|
||||
try:
|
||||
|
||||
@@ -172,6 +172,20 @@ class ChatConfig(BaseSettings):
|
||||
description="Maximum number of retries for transient API errors "
|
||||
"(429, 5xx, ECONNRESET) before surfacing the error to the user.",
|
||||
)
|
||||
claude_agent_cli_path: str | None = Field(
|
||||
default=None,
|
||||
description="Optional explicit path to a Claude Code CLI binary. "
|
||||
"When set, the SDK uses this binary instead of the version bundled "
|
||||
"with the installed `claude-agent-sdk` package — letting us pin "
|
||||
"the Python SDK and the CLI independently. Critical for keeping "
|
||||
"OpenRouter compatibility while still picking up newer SDK API "
|
||||
"features (the bundled CLI version in 0.1.46+ is broken against "
|
||||
"OpenRouter — see PR #12294 and "
|
||||
"anthropics/claude-agent-sdk-python#789). Falls back to the "
|
||||
"bundled binary when unset. Reads from `CHAT_CLAUDE_AGENT_CLI_PATH` "
|
||||
"or the unprefixed `CLAUDE_AGENT_CLI_PATH` environment variable "
|
||||
"(same pattern as `api_key` / `base_url`).",
|
||||
)
|
||||
use_openrouter: bool = Field(
|
||||
default=True,
|
||||
description="Enable routing API calls through the OpenRouter proxy. "
|
||||
@@ -294,6 +308,26 @@ class ChatConfig(BaseSettings):
|
||||
v = OPENROUTER_BASE_URL
|
||||
return v
|
||||
|
||||
@field_validator("claude_agent_cli_path", mode="before")
@classmethod
def get_claude_agent_cli_path(cls, v):
    """Fill in the Claude Code CLI override path from the environment.

    A truthy incoming value is returned untouched. Otherwise the value
    falls back first to ``CHAT_CLAUDE_AGENT_CLI_PATH`` and then to the
    unprefixed ``CLAUDE_AGENT_CLI_PATH``. The result of the last lookup
    performed is returned as-is, so it may be ``None`` or empty when
    neither variable holds a usable value.
    """
    # ``or`` short-circuits left to right, which mirrors "use the explicit
    # value, else the prefixed variable, else the unprefixed variable".
    return (
        v
        or os.getenv("CHAT_CLAUDE_AGENT_CLI_PATH")
        or os.getenv("CLAUDE_AGENT_CLI_PATH")
    )
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -174,13 +174,25 @@ class CoPilotProcessor:
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
def _prewarm_cli(self) -> None:
|
||||
"""Run the bundled CLI binary once to warm OS page caches."""
|
||||
try:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import (
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
"""Run the Claude Code CLI binary once to warm OS page caches.
|
||||
|
||||
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
Honours the ``claude_agent_cli_path`` config override (which lets
|
||||
us run a pinned CLI version independent of the bundled one in the
|
||||
installed ``claude-agent-sdk`` wheel — see
|
||||
``ChatConfig.claude_agent_cli_path`` for the rationale). Falls
|
||||
back to the bundled binary when no override is set.
|
||||
"""
|
||||
try:
|
||||
from backend.copilot.config import ChatConfig
|
||||
|
||||
cfg = ChatConfig()
|
||||
cli_path: str | None = cfg.claude_agent_cli_path
|
||||
if not cli_path:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import (
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
|
||||
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
if cli_path:
|
||||
result = subprocess.run(
|
||||
[cli_path, "-v"],
|
||||
|
||||
@@ -1,222 +0,0 @@
|
||||
"""Pending-message buffer for in-flight copilot turns.
|
||||
|
||||
When a user sends a new message while a copilot turn is already executing,
|
||||
instead of blocking the frontend (or queueing a brand-new turn after the
|
||||
current one finishes), we want the new message to be *injected into the
|
||||
running turn* — appended between tool-call rounds so the model sees it
|
||||
before its next LLM call.
|
||||
|
||||
This module provides the cross-process buffer that makes that possible:
|
||||
|
||||
- **Producer** (chat API route): pushes a pending message to Redis and
|
||||
publishes a notification on a pub/sub channel.
|
||||
- **Consumer** (executor running the turn): on each tool-call round,
|
||||
drains the buffer and appends the pending messages to the conversation.
|
||||
|
||||
The Redis list is the durable store; the pub/sub channel is a fast
|
||||
wake-up hint for long-idle consumers (not used by default, but available
|
||||
for future blocking-wait semantics).
|
||||
|
||||
A hard cap of ``MAX_PENDING_MESSAGES`` per session prevents abuse. The
|
||||
buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-session cap. Higher values risk a runaway consumer; lower values
|
||||
# risk dropping user input under heavy typing. 10 was chosen as a
|
||||
# reasonable ceiling — a user typing faster than the copilot can drain
|
||||
# between tool rounds is already an unusual usage pattern.
|
||||
MAX_PENDING_MESSAGES = 10
|
||||
|
||||
# Redis key + TTL. The buffer is ephemeral: if a turn completes or the
|
||||
# executor dies, the pending messages should either have been drained
|
||||
# already or are safe to drop (the user can resend).
|
||||
_PENDING_KEY_PREFIX = "copilot:pending:"
|
||||
_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:"
|
||||
_PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default
|
||||
|
||||
# Payload sent on the pub/sub notify channel. Subscribers treat any
|
||||
# message as a wake-up hint; the value itself is not meaningful.
|
||||
_NOTIFY_PAYLOAD = "1"
|
||||
|
||||
|
||||
class PendingMessageContext(BaseModel):
    """Optional page context that travels with a pending message."""

    # Address of the page; rendered as ``[Page URL: ...]`` when present.
    url: str | None = None
    # Page text; rendered under a ``[Page content]`` header when present.
    content: str | None = None
|
||||
|
||||
|
||||
class PendingMessage(BaseModel):
    """One user message waiting in the buffer to be injected into a running turn."""

    # The message text: must be non-empty and at most 16,000 characters.
    content: str = Field(min_length=1, max_length=16_000)
    # Identifiers of files attached to the message (empty when there are none).
    file_ids: list[str] = Field(default_factory=list)
    # Optional structured page context.
    context: PendingMessageContext | None = None
|
||||
|
||||
|
||||
def _buffer_key(session_id: str) -> str:
    """Return the Redis list key that stores *session_id*'s pending messages."""
    prefix = _PENDING_KEY_PREFIX
    return f"{prefix}{session_id}"
|
||||
|
||||
|
||||
def _notify_channel(session_id: str) -> str:
    """Return the pub/sub channel used to announce pushes for *session_id*."""
    prefix = _PENDING_CHANNEL_PREFIX
    return f"{prefix}{session_id}"
|
||||
|
||||
|
||||
# Lua script: push-then-trim-then-expire-then-length, atomically.
|
||||
# Redis serializes EVAL commands, so a concurrent ``LPOP`` drain
|
||||
# observes either the pre-push or post-push state of the list — never
|
||||
# a partial state where the RPUSH has landed but LTRIM hasn't run.
|
||||
_PUSH_LUA = """
|
||||
redis.call('RPUSH', KEYS[1], ARGV[1])
|
||||
redis.call('LTRIM', KEYS[1], -tonumber(ARGV[2]), -1)
|
||||
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[3]))
|
||||
return redis.call('LLEN', KEYS[1])
|
||||
"""
|
||||
|
||||
|
||||
async def push_pending_message(
    session_id: str,
    message: PendingMessage,
) -> int:
    """Queue *message* on the session's pending buffer and return its new length.

    The push, the trim to the newest ``MAX_PENDING_MESSAGES`` entries, the
    TTL refresh and the length read all run inside one Lua ``EVAL``
    (``_PUSH_LUA``), so a concurrent drain never sees a half-applied push.
    When the cap is exceeded the oldest entries are the ones dropped.

    After the push a notification is published on the session's notify
    channel. A failure to publish is logged and otherwise ignored, because
    the list itself is the source of truth.
    """
    redis = await get_redis_async()

    # Script arguments in the order the Lua script expects:
    # KEYS[1], then ARGV[1] payload, ARGV[2] cap, ARGV[3] TTL seconds.
    script_args = (
        _buffer_key(session_id),
        message.model_dump_json(),
        str(MAX_PENDING_MESSAGES),
        str(_PENDING_TTL_SECONDS),
    )
    raw_length = await cast("Any", redis.eval(_PUSH_LUA, 1, *script_args))
    new_length = int(raw_length)

    # Best-effort wake-up hint; losing it does not lose the message.
    try:
        await redis.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD)
    except Exception as e:  # pragma: no cover
        logger.warning("pending_messages: publish failed for %s: %s", session_id, e)

    logger.info(
        "pending_messages: pushed message to session=%s (buffer_len=%d)",
        session_id,
        new_length,
    )
    return new_length
|
||||
|
||||
|
||||
async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
    """Atomically pop all pending messages for *session_id*.

    Returns the messages in enqueue order (oldest first). ``LPOP`` is
    issued with a count of ``MAX_PENDING_MESSAGES`` so the read and the
    delete happen in a single Redis round trip. If the list is empty or
    missing, returns ``[]``.

    Entries that cannot be decoded, parsed or validated are logged and
    skipped individually. Because the whole batch has already been removed
    from Redis by the time it is parsed, one bad entry must never abort
    the loop — doing so would silently discard the valid messages popped
    alongside it.
    """
    redis = await get_redis_async()
    key = _buffer_key(session_id)

    # Redis LPOP with count (Redis 6.2+) returns None for a missing key or
    # the popped items. redis-py's async lpop overload with a count
    # collapses the return type in pyright; cast the awaitable so strict
    # type-checking stays clean without changing runtime behaviour.
    lpop_result = await cast(
        "Any",
        redis.lpop(key, MAX_PENDING_MESSAGES),
    )
    if not lpop_result:
        return []

    messages: list[PendingMessage] = []
    for item in list(lpop_result):
        try:
            # redis-py may return bytes or str depending on decode_responses.
            # The decode happens inside the ``try`` so that a non-UTF-8
            # payload (UnicodeDecodeError is a ValueError) drops only that
            # entry instead of raising after the batch was already popped.
            payload = item.decode("utf-8") if isinstance(item, bytes) else str(item)
            messages.append(PendingMessage(**json.loads(payload)))
        except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
            logger.warning(
                "pending_messages: dropping malformed entry for %s: %s",
                session_id,
                e,
            )

    if messages:
        logger.info(
            "pending_messages: drained %d messages for session=%s",
            len(messages),
            session_id,
        )
    return messages
|
||||
|
||||
|
||||
async def peek_pending_count(session_id: str) -> int:
    """Report how many messages are buffered for *session_id*, leaving them queued."""
    redis = await get_redis_async()
    buffer_key = _buffer_key(session_id)
    return int(await cast("Any", redis.llen(buffer_key)))
|
||||
|
||||
|
||||
async def clear_pending_messages(session_id: str) -> None:
    """Delete the whole pending buffer of *session_id*.

    The regular turn flow does not call this: it consumes the buffer via
    the atomic ``LPOP`` drain, and anything pushed after that drain is
    meant for the next turn. This helper exists for manually resetting a
    stuck session and for use in the unit tests.
    """
    buffer_key = _buffer_key(session_id)
    redis = await get_redis_async()
    await redis.delete(buffer_key)
|
||||
|
||||
|
||||
def format_pending_as_user_message(message: PendingMessage) -> dict[str, Any]:
    """Render a ``PendingMessage`` as an OpenAI-format ``user`` message dict.

    The returned ``content`` starts with the message text. When present,
    the page URL, the page content and the list of attached file ids are
    appended to it (in that order), each as its own bracketed section, so
    everything reaches the model in a single block.
    """
    text = message.content

    page = message.context
    if page and page.url:
        text += f"\n\n[Page URL: {page.url}]"
    if page and page.content:
        text += f"\n\n[Page content]\n{page.content}"

    if message.file_ids:
        file_lines = "\n".join(f"- file_id={fid}" for fid in message.file_ids)
        text += (
            "\n\n[Attached files]\n"
            + file_lines
            + "\nUse read_workspace_file with the file_id to access file contents."
        )

    return {"role": "user", "content": text}
|
||||
@@ -1,246 +0,0 @@
|
||||
"""Tests for the copilot pending-messages buffer.
|
||||
|
||||
Uses a fake async Redis client so the tests don't require a real Redis
|
||||
instance (the backend test suite's DB/Redis fixtures are heavyweight
|
||||
and pull in the full app startup).
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot import pending_messages as pm_module
|
||||
from backend.copilot.pending_messages import (
|
||||
MAX_PENDING_MESSAGES,
|
||||
PendingMessage,
|
||||
PendingMessageContext,
|
||||
clear_pending_messages,
|
||||
drain_pending_messages,
|
||||
format_pending_as_user_message,
|
||||
peek_pending_count,
|
||||
push_pending_message,
|
||||
)
|
||||
|
||||
# ── Fake Redis ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
# Values are ``str | bytes`` because real redis-py returns
|
||||
# bytes when ``decode_responses=False``; the drain path must
|
||||
# handle both and our tests exercise both.
|
||||
self.lists: dict[str, list[str | bytes]] = {}
|
||||
self.published: list[tuple[str, str]] = []
|
||||
|
||||
async def eval(self, script: str, num_keys: int, *args: Any) -> Any:
|
||||
"""Emulate the push Lua script.
|
||||
|
||||
The real Lua script runs atomically in Redis; the fake
|
||||
implementation just runs the equivalent list operations in
|
||||
order and returns the final LLEN. That's enough to exercise
|
||||
the cap + ordering invariants the tests care about.
|
||||
"""
|
||||
key = args[0]
|
||||
payload = args[1]
|
||||
max_len = int(args[2])
|
||||
# ARGV[3] is TTL — fake doesn't enforce expiry
|
||||
lst = self.lists.setdefault(key, [])
|
||||
lst.append(payload)
|
||||
if len(lst) > max_len:
|
||||
# RPUSH + LTRIM(-N, -1) = keep only last N
|
||||
self.lists[key] = lst[-max_len:]
|
||||
return len(self.lists[key])
|
||||
|
||||
async def publish(self, channel: str, payload: str) -> int:
|
||||
self.published.append((channel, payload))
|
||||
return 1
|
||||
|
||||
async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
|
||||
lst = self.lists.get(key)
|
||||
if not lst:
|
||||
return None
|
||||
popped = lst[:count]
|
||||
self.lists[key] = lst[count:]
|
||||
return popped
|
||||
|
||||
async def llen(self, key: str) -> int:
|
||||
return len(self.lists.get(key, []))
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
if key in self.lists:
|
||||
del self.lists[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
@pytest.fixture()
def fake_redis(monkeypatch: pytest.MonkeyPatch) -> _FakeRedis:
    """Swap the module's ``get_redis_async`` for one returning a shared fake.

    The same ``_FakeRedis`` instance is handed to the test, so tests can
    seed or inspect the state the code under test operates on.
    """
    fake = _FakeRedis()

    async def _provide_fake() -> _FakeRedis:
        return fake

    monkeypatch.setattr(pm_module, "get_redis_async", _provide_fake)
    return fake
|
||||
|
||||
|
||||
# ── Basic push / drain ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_push_and_drain_single_message(fake_redis: _FakeRedis) -> None:
    """One pushed message is counted, drained intact, and then gone."""
    pushed_length = await push_pending_message(
        "sess1", PendingMessage(content="hello")
    )
    assert pushed_length == 1
    assert await peek_pending_count("sess1") == 1

    drained = await drain_pending_messages("sess1")
    assert len(drained) == 1
    (only_message,) = drained
    assert only_message.content == "hello"
    assert await peek_pending_count("sess1") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_push_and_drain_preserves_order(fake_redis: _FakeRedis) -> None:
    """Messages come back from the drain in the order they were pushed."""
    for index in range(3):
        queued = PendingMessage(content=f"msg {index}")
        await push_pending_message("sess2", queued)

    drained = await drain_pending_messages("sess2")
    drained_texts = [entry.content for entry in drained]
    assert drained_texts == ["msg 0", "msg 1", "msg 2"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_empty_returns_empty_list(fake_redis: _FakeRedis) -> None:
    """Draining a session that never had a buffer yields an empty list."""
    drained = await drain_pending_messages("nope")
    assert drained == []
|
||||
|
||||
|
||||
# ── Buffer cap ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None:
    """Pushing past the cap keeps only the newest ``MAX_PENDING_MESSAGES`` entries."""
    # Overshoot the cap by three messages.
    pushed_total = MAX_PENDING_MESSAGES + 3
    for index in range(pushed_total):
        await push_pending_message("sess3", PendingMessage(content=f"m{index}"))

    # The buffer never grows beyond the cap.
    assert await peek_pending_count("sess3") == MAX_PENDING_MESSAGES

    drained = await drain_pending_messages("sess3")
    assert len(drained) == MAX_PENDING_MESSAGES

    # The three oldest are gone, so the survivors are m3..m(MAX+2).
    oldest_kept, newest_kept = drained[0], drained[-1]
    assert oldest_kept.content == "m3"
    assert newest_kept.content == f"m{MAX_PENDING_MESSAGES + 2}"
|
||||
|
||||
|
||||
# ── Clear ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None:
    """Clearing a populated buffer leaves nothing to count."""
    for text in ("x", "y"):
        await push_pending_message("sess4", PendingMessage(content=text))

    await clear_pending_messages("sess4")

    remaining = await peek_pending_count("sess4")
    assert remaining == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None:
    """Clearing a buffer that does not exist, twice in a row, must not raise."""
    for _attempt in range(2):
        await clear_pending_messages("sess_empty")
|
||||
|
||||
|
||||
# ── Publish hook ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_push_publishes_notification(fake_redis: _FakeRedis) -> None:
    """A push publishes the wake-up payload on the session's notify channel."""
    await push_pending_message("sess5", PendingMessage(content="hi"))

    expected_event = ("copilot:pending:notify:sess5", "1")
    assert expected_event in fake_redis.published
|
||||
|
||||
|
||||
# ── Format helper ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_format_pending_plain_text() -> None:
    """With no context or files the content is passed through untouched."""
    formatted = format_pending_as_user_message(PendingMessage(content="just text"))
    assert formatted == {"role": "user", "content": "just text"}
|
||||
|
||||
|
||||
def test_format_pending_with_context_url() -> None:
    """A context URL is added to the content inside a ``[Page URL: ...]`` block."""
    page = PendingMessageContext(url="https://example.com")
    formatted = format_pending_as_user_message(
        PendingMessage(content="see this page", context=page)
    )

    assert formatted["role"] == "user"
    body = formatted["content"]
    assert "see this page" in body
    # The URL should appear verbatim in the [Page URL: ...] block.
    assert "[Page URL: https://example.com]" in body
|
||||
|
||||
|
||||
def test_format_pending_with_file_ids() -> None:
    """Every attached file id is listed in the formatted content."""
    pending = PendingMessage(content="look here", file_ids=["a", "b"])
    body = format_pending_as_user_message(pending)["content"]
    assert "file_id=a" in body
    assert "file_id=b" in body
|
||||
|
||||
|
||||
def test_format_pending_with_all_fields() -> None:
    """All fields (content + context url/content + file_ids) should all appear."""
    page = PendingMessageContext(
        url="https://example.com/page",
        content="headline text",
    )
    pending = PendingMessage(
        content="summarise this",
        context=page,
        file_ids=["f1", "f2"],
    )

    formatted = format_pending_as_user_message(pending)

    assert formatted["role"] == "user"
    text = formatted["content"]
    assert "summarise this" in text
    assert "[Page URL: https://example.com/page]" in text
    assert "[Page content]\nheadline text" in text
    assert "file_id=f1" in text
    assert "file_id=f2" in text
|
||||
|
||||
|
||||
# ── Malformed payload handling ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_skips_malformed_entries(
    fake_redis: _FakeRedis,
) -> None:
    """Unparseable entries are dropped while the valid ones around them survive."""
    # Put one broken payload between two well-formed ones.
    first_valid = json.dumps({"content": "valid"})
    broken = "{not valid json"
    second_valid = json.dumps({"content": "also valid", "file_ids": ["a"]})
    fake_redis.lists["copilot:pending:bad"] = [first_valid, broken, second_valid]

    drained = await drain_pending_messages("bad")

    assert len(drained) == 2
    kept_first, kept_second = drained
    assert kept_first.content == "valid"
    assert kept_second.content == "also valid"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_decodes_bytes_payloads(
    fake_redis: _FakeRedis,
) -> None:
    """Payloads stored as ``bytes`` are decoded and parsed like ``str`` ones.

    Real redis-py hands back ``bytes`` when ``decode_responses=False``, so
    the buffer is seeded with a UTF-8 encoded payload to cover the
    ``decode("utf-8")`` branch of ``drain_pending_messages``.
    """
    encoded_payload = json.dumps({"content": "from bytes"}).encode("utf-8")
    fake_redis.lists["copilot:pending:bytes_sess"] = [encoded_payload]

    drained = await drain_pending_messages("bytes_sess")

    assert len(drained) == 1
    (decoded_message,) = drained
    assert decoded_message.content == "from bytes"
|
||||
@@ -0,0 +1,577 @@
|
||||
"""Reproduction test for the OpenRouter incompatibility in newer
|
||||
``claude-agent-sdk`` / Claude Code CLI versions.
|
||||
|
||||
Background — there are two stacked regressions that block us from
|
||||
upgrading the ``claude-agent-sdk`` package above ``0.1.45``:
|
||||
|
||||
1. **`tool_reference` content blocks** introduced by CLI ``2.1.69`` (=
|
||||
SDK ``0.1.46``). The CLI's built-in ``ToolSearch`` tool returns
|
||||
``{"type": "tool_reference", "tool_name": "..."}`` content blocks in
|
||||
``tool_result.content``. OpenRouter's stricter Zod validation
|
||||
rejects this with::
|
||||
|
||||
messages[N].content[0].content: Invalid input: expected string, received array
|
||||
|
||||
This is the regression that originally pinned us at 0.1.45 — see
|
||||
https://github.com/Significant-Gravitas/AutoGPT/pull/12294 for the
|
||||
full forensic write-up. CLI 2.1.70 added proxy detection that
|
||||
*should* disable the offending blocks when ``ANTHROPIC_BASE_URL`` is
|
||||
set, but our subsequent attempts at 0.1.55 / 0.1.56 still failed.
|
||||
|
||||
2. **`context-management-2025-06-27` beta header** — some CLI version
|
||||
after ``2.1.91`` started injecting this header / beta flag, which
|
||||
OpenRouter rejects with::
|
||||
|
||||
400 No endpoints available that support Anthropic's context
|
||||
management features (context-management-2025-06-27). Context
|
||||
management requires a supported provider (Anthropic).
|
||||
|
||||
Tracked upstream at
|
||||
https://github.com/anthropics/claude-agent-sdk-python/issues/789.
|
||||
Still open at the time of writing, no upstream PR linked, no
|
||||
workaround documented.
|
||||
|
||||
The purpose of this test:
|
||||
* Spin up a tiny in-process HTTP server that pretends to be the
|
||||
Anthropic Messages API.
|
||||
* Capture every request body the CLI sends.
|
||||
* Inspect the captured bodies for the two forbidden patterns above.
|
||||
* Fail loudly if either is present, with a pointer to the issue
|
||||
tracker.
|
||||
|
||||
This is the reproduction we use as a CI gate when bisecting which SDK /
|
||||
CLI version is safe to upgrade to. It runs against the bundled CLI by
|
||||
default (or against ``ChatConfig.claude_agent_cli_path`` when set), so
|
||||
it doubles as a regression guard for the ``cli_path`` override
|
||||
mechanism.
|
||||
|
||||
The test does **not** need an OpenRouter API key — it reproduces the
|
||||
mechanism (forbidden content blocks / headers in the *outgoing*
|
||||
request) rather than the symptom (the 400 OpenRouter would return).
|
||||
This keeps it deterministic, free, and CI-runnable without secrets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forbidden patterns we scan for in captured request bodies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Match the `tool_reference` content block that breaks OpenRouter's stricter
|
||||
# Zod validation in tool_result.content. PR #12294 root-cause.
|
||||
#
|
||||
# We use a whitespace-tolerant regex rather than a literal substring because
|
||||
# the Claude Code CLI is Node.js and `JSON.stringify` without an indent
|
||||
# argument emits no whitespace between the key, colon, and value
|
||||
# (`{"type":"tool_reference"}`), while a Python serializer would emit
|
||||
# `{"type": "tool_reference"}`. A naive substring with one specific spacing
|
||||
# would silently miss the real regression.
|
||||
_FORBIDDEN_TOOL_REFERENCE_RE: re.Pattern[str] = re.compile(r'"type"\s*:\s*"tool_reference"')
|
||||
|
||||
# Beta string OpenRouter rejects in upstream issue #789. Can appear in
|
||||
# either `betas` arrays or the `anthropic-beta` header value. This is a
|
||||
# unique opaque token (no JSON punctuation around it that could vary), so
|
||||
# a plain substring match is robust.
|
||||
_FORBIDDEN_CONTEXT_MANAGEMENT_BETA: str = "context-management-2025-06-27"
|
||||
|
||||
|
||||
def _scan_request_for_forbidden_patterns(
    body_text: str,
    headers: dict[str, str],
) -> list[str]:
    """Collect every OpenRouter-incompatible feature found in one request.

    Args:
        body_text: Raw request body as text.
        headers: Request headers; names are matched case-insensitively.

    Returns:
        One human-readable finding per hit.  Body findings come first
        (``tool_reference``, then the context-management beta), followed
        by one finding per offending ``anthropic-beta`` header.  An empty
        list means the request is clean.
    """
    hits: list[str] = []

    tool_reference_present = (
        _FORBIDDEN_TOOL_REFERENCE_RE.search(body_text) is not None
    )
    if tool_reference_present:
        hits.append(
            "`tool_reference` content block in request body — "
            "PR #12294 / CLI 2.1.69 regression"
        )

    beta_in_body = _FORBIDDEN_CONTEXT_MANAGEMENT_BETA in body_text
    if beta_in_body:
        hits.append(
            f"{_FORBIDDEN_CONTEXT_MANAGEMENT_BETA!r} in request body — "
            "anthropics/claude-agent-sdk-python#789"
        )

    # HTTP header *names* are case-insensitive, so compare them lowercased;
    # the header value is searched exactly as received.
    hits.extend(
        f"{_FORBIDDEN_CONTEXT_MANAGEMENT_BETA!r} in "
        "`anthropic-beta` header — issue #789"
        for name, value in headers.items()
        if name.lower() == "anthropic-beta"
        and _FORBIDDEN_CONTEXT_MANAGEMENT_BETA in value
    )
    return hits
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake Anthropic Messages API
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# We need to give the CLI a *successful* response so it doesn't error out
|
||||
# before we get a chance to inspect the request. The minimal thing the
|
||||
# CLI accepts is a streamed (SSE) message-start → content-block-delta →
|
||||
# message-stop sequence.
|
||||
#
|
||||
# We don't strictly *need* the CLI to accept the response — we already
|
||||
# have the request body by the time we send any reply — but giving it a
|
||||
# valid stream means the assertion failure (if any) is the *only*
|
||||
# failure mode in the test, not "CLI exited 1 because we sent garbage".
|
||||
|
||||
|
||||
def _build_streaming_message_response() -> str:
|
||||
"""Return an SSE-formatted body containing a minimal Anthropic
|
||||
Messages API streamed response.
|
||||
|
||||
This is the smallest stream that the Claude Code CLI will accept
|
||||
end-to-end without errors. Each line is one SSE event."""
|
||||
events: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": "claude-test",
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": 1, "output_tokens": 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": "ok"},
|
||||
},
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
|
||||
"usage": {"output_tokens": 1},
|
||||
},
|
||||
{"type": "message_stop"},
|
||||
]
|
||||
return "".join(
|
||||
f"event: {evt['type']}\ndata: {json.dumps(evt)}\n\n" for evt in events
|
||||
)
|
||||
|
||||
|
||||
class _CapturedRequest:
|
||||
"""One request the fake server received."""
|
||||
|
||||
def __init__(self, path: str, headers: dict[str, str], body: str) -> None:
|
||||
self.path = path
|
||||
self.headers = headers
|
||||
self.body = body
|
||||
|
||||
|
||||
async def _start_fake_anthropic_server(
    captured: list[_CapturedRequest],
) -> tuple[web.AppRunner, int]:
    """Start a loopback aiohttp server that impersonates the Anthropic API.

    Every POST to ``/v1/messages`` is appended to *captured* and answered
    with a minimal valid SSE stream; every other route is answered 404.

    Args:
        captured: List that receives one ``_CapturedRequest`` per
            ``/v1/messages`` POST.

    Returns:
        ``(runner, port)``.  The caller must ``await runner.cleanup()``
        when finished.

    Raises:
        Whatever socket setup or aiohttp startup raises; the bound socket
        and the runner are released before the error propagates.
    """
    import socket

    async def messages_handler(request: web.Request) -> web.StreamResponse:
        body = await request.text()
        # ``request.headers`` is a multidict: the same header name may occur
        # on several field lines.  A plain dict comprehension would keep only
        # the last one and could hide a forbidden ``anthropic-beta`` value,
        # so repeated names are folded into one comma-separated value.
        merged_headers: dict[str, str] = {}
        for name, value in request.headers.items():
            if name in merged_headers:
                merged_headers[name] = f"{merged_headers[name]}, {value}"
            else:
                merged_headers[name] = value
        captured.append(
            _CapturedRequest(
                path=request.path,
                headers=merged_headers,
                body=body,
            )
        )
        # Stream a minimal valid response so the CLI doesn't error out
        # before we can inspect what it sent.
        response = web.StreamResponse(
            status=200,
            headers={
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
        await response.prepare(request)
        await response.write(_build_streaming_message_response().encode("utf-8"))
        await response.write_eof()
        return response

    async def fallback_handler(_request: web.Request) -> web.Response:
        # OAuth/profile endpoints the CLI may probe — answer 404 so it
        # falls through quickly without retrying.
        return web.Response(status=404)

    app = web.Application()
    app.router.add_post("/v1/messages", messages_handler)
    app.router.add_route("*", "/{tail:.*}", fallback_handler)

    # Bind an ephemeral port ourselves so we can read it back via the
    # public ``getsockname`` API rather than reaching into ``site._server``
    # private aiohttp internals.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(("127.0.0.1", 0))
        port: int = sock.getsockname()[1]

        runner = web.AppRunner(app)
        await runner.setup()
        try:
            site = web.SockSite(runner, sock)
            await site.start()
        except BaseException:
            # Startup failed after the runner was set up: release it so
            # nothing lingers on the event loop.
            await runner.cleanup()
            raise
    except BaseException:
        # The caller never receives the runner on failure, so nobody else
        # could close this socket.
        sock.close()
        raise

    return runner, port
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI invocation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_cli_path() -> Path | None:
|
||||
"""Return the Claude Code CLI binary the SDK would use.
|
||||
|
||||
Honours the same override mechanism as ``service.py`` /
|
||||
``ChatConfig.claude_agent_cli_path``: checks either the Pydantic-
|
||||
prefixed ``CHAT_CLAUDE_AGENT_CLI_PATH`` or the unprefixed
|
||||
``CLAUDE_AGENT_CLI_PATH`` env var first, then falls back to the
|
||||
bundled binary that ships with the installed ``claude-agent-sdk``
|
||||
wheel. The two env var names are accepted at the config layer via
|
||||
``ChatConfig.get_claude_agent_cli_path`` and mirrored here so the
|
||||
reproduction test picks up the same override regardless of which
|
||||
form an operator sets.
|
||||
"""
|
||||
override = os.environ.get("CHAT_CLAUDE_AGENT_CLI_PATH") or os.environ.get(
|
||||
"CLAUDE_AGENT_CLI_PATH"
|
||||
)
|
||||
if override:
|
||||
candidate = Path(override)
|
||||
return candidate if candidate.is_file() else None
|
||||
|
||||
try:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import ( # type: ignore[import-untyped]
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
|
||||
bundled = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
return Path(bundled) if bundled else None
|
||||
except Exception as e: # pragma: no cover - import-time guard
|
||||
logger.warning("Could not locate bundled Claude CLI: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
async def _run_cli_against_fake_server(
|
||||
cli_path: Path,
|
||||
fake_server_port: int,
|
||||
timeout_seconds: float,
|
||||
) -> tuple[int, str, str]:
|
||||
"""Spawn the CLI pointed at the fake Anthropic server and feed it a
|
||||
single ``user`` message via stream-json on stdin.
|
||||
|
||||
Returns ``(returncode, stdout, stderr)``. The return code is not
|
||||
asserted by the test — we only care that the CLI made at least one
|
||||
POST to ``/v1/messages`` so the fake server captured the body.
|
||||
"""
|
||||
fake_url = f"http://127.0.0.1:{fake_server_port}"
|
||||
env = {
|
||||
# Inherit basic shell variables so the CLI can find its tools,
|
||||
# but force network/auth at our fake endpoint.
|
||||
**os.environ,
|
||||
"ANTHROPIC_BASE_URL": fake_url,
|
||||
"ANTHROPIC_API_KEY": "sk-test-fake-key-not-real",
|
||||
# Disable any features that would phone home to a different host
|
||||
# mid-test (telemetry, plugin marketplace fetch).
|
||||
"DISABLE_TELEMETRY": "1",
|
||||
"CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1",
|
||||
}
|
||||
|
||||
# The CLI accepts stream-json input on stdin in `query` mode. A
|
||||
# minimal user-message envelope is enough to trigger an API call.
|
||||
stdin_payload = (
|
||||
json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "hello"},
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
str(cli_path),
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--input-format",
|
||||
"stream-json",
|
||||
"--verbose",
|
||||
"--print",
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
try:
|
||||
assert proc.stdin is not None
|
||||
proc.stdin.write(stdin_payload.encode("utf-8"))
|
||||
await proc.stdin.drain()
|
||||
proc.stdin.close()
|
||||
|
||||
stdout_bytes, stderr_bytes = await asyncio.wait_for(
|
||||
proc.communicate(), timeout=timeout_seconds
|
||||
)
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
# Best-effort kill — we already have whatever requests the CLI
|
||||
# managed to send before stalling.
|
||||
try:
|
||||
proc.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
# Reap the process to avoid leaving a zombie + open pipe FDs.
|
||||
# Without this the asyncio transport keeps the stdout/stderr
|
||||
# pipes alive until the loop exits, and in CI loops where this
|
||||
# test runs many times the file-descriptor count creeps up.
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=5.0)
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
pass
|
||||
stdout_bytes, stderr_bytes = b"", b""
|
||||
|
||||
return (
|
||||
proc.returncode if proc.returncode is not None else -1,
|
||||
stdout_bytes.decode("utf-8", errors="replace"),
|
||||
stderr_bytes.decode("utf-8", errors="replace"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# The actual test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cli_does_not_send_openrouter_incompatible_features(caplog):
    """Reproduce the OpenRouter incompatibility end to end.

    Runs the resolved Claude Code CLI against the fake Anthropic server,
    records every ``/v1/messages`` request it makes, and fails if any of
    them carries a `tool_reference` content block or the
    `context-management-2025-06-27` beta.

    The test is skipped when no CLI binary can be resolved or when the CLI
    never reached the fake server, because there is then nothing to check.
    The CLI's exit code is logged but deliberately not asserted.
    """
    cli_path = _resolve_cli_path()
    cli_unavailable = cli_path is None or not cli_path.is_file()
    if cli_unavailable:
        pytest.skip(
            "No Claude Code CLI binary available (neither bundled nor "
            "overridden via CLAUDE_AGENT_CLI_PATH / "
            "CHAT_CLAUDE_AGENT_CLI_PATH); cannot reproduce."
        )

    captured: list[_CapturedRequest] = []
    runner, port = await _start_fake_anthropic_server(captured)
    try:
        returncode, stdout, stderr = await _run_cli_against_fake_server(
            cli_path=cli_path,
            fake_server_port=port,
            timeout_seconds=30.0,
        )
    finally:
        # Always stop the fake server, even if spawning the CLI failed.
        await runner.cleanup()

    logger.info(
        "CLI exited rc=%d; captured %d requests; stdout=%d bytes; stderr=%d bytes",
        returncode,
        len(captured),
        len(stdout),
        len(stderr),
    )

    if len(captured) == 0:
        pytest.skip(
            "Bundled CLI did not make any HTTP requests to the fake server "
            f"(rc={returncode}). The CLI may have failed before reaching "
            f"the network — stderr tail: {stderr[-500:]!r}. "
            "Nothing to assert; treating as inconclusive rather than "
            "either passing or failing."
        )

    # One "<path>: <finding>" entry per forbidden pattern per request.
    all_findings: list[str] = [
        f"{req.path}: {finding}"
        for req in captured
        for finding in _scan_request_for_forbidden_patterns(req.body, req.headers)
    ]

    summary = (
        f"Bundled Claude Code CLI sent OpenRouter-incompatible features in "
        f"{len(all_findings)} request(s):\n  - "
    )
    guidance = (
        "\n\nThis is the regression that prevents us from upgrading "
        "`claude-agent-sdk` above 0.1.45. See "
        "https://github.com/Significant-Gravitas/AutoGPT/pull/12294 and "
        "https://github.com/anthropics/claude-agent-sdk-python/issues/789. "
        "If you intended to upgrade, you must use a known-good CLI binary "
        "via `claude_agent_cli_path` (env: `CLAUDE_AGENT_CLI_PATH` or "
        "`CHAT_CLAUDE_AGENT_CLI_PATH`) instead of the bundled one."
    )
    assert not all_findings, summary + "\n  - ".join(all_findings) + guidance
|
||||
|
||||
|
||||
def test_subprocess_module_available():
    """Sentinel test: this runner must be able to spawn a child process.

    The main reproduction test launches the CLI as a subprocess.  Checking
    only that the ``subprocess`` module imported would pass even on a
    sandboxed runner that forbids process creation, so this test also runs
    a trivial child (the current Python interpreter executing ``pass``) and
    requires it to exit cleanly.
    """
    import sys

    assert subprocess.__name__ == "subprocess"

    completed = subprocess.run(
        [sys.executable, "-c", "pass"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        timeout=30,
        check=False,
    )
    assert completed.returncode == 0, (
        "Could not run a trivial child process "
        f"(rc={completed.returncode}); the CLI reproduction test cannot "
        "spawn the Claude Code CLI on this runner."
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure helper unit tests — pin the forbidden-pattern detection so any
|
||||
# future drift in the scanner is caught fast, even when the slow
|
||||
# end-to-end CLI subprocess test isn't runnable.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestScanRequestForForbiddenPatterns:
    """Unit tests pinning ``_scan_request_for_forbidden_patterns``."""

    def test_clean_body_returns_empty_findings(self):
        """A request with neither forbidden pattern yields no findings."""
        body = '{"model": "claude-opus-4.6", "messages": [{"role": "user", "content": "hi"}]}'
        assert _scan_request_for_forbidden_patterns(body, {}) == []

    def test_detects_tool_reference_in_body(self):
        """A `tool_reference` block serialized with spaces is reported."""
        body = (
            '{"messages": [{"role": "user", "content": ['
            '{"type": "tool_reference", "tool_name": "find"}'
            "]}]}"
        )
        findings = _scan_request_for_forbidden_patterns(body, {})
        assert len(findings) == 1
        assert "tool_reference" in findings[0]
        assert "PR #12294" in findings[0]

    def test_detects_tool_reference_without_whitespace(self):
        """A `tool_reference` block in compact JSON is reported too.

        Compact output (no space after ``:``) is what ``JSON.stringify``
        produces, so this is the form a real request body would carry.
        """
        body = (
            '{"messages":[{"role":"user","content":['
            '{"type":"tool_reference","tool_name":"find"}'
            "]}]}"
        )
        findings = _scan_request_for_forbidden_patterns(body, {})
        assert len(findings) == 1
        assert "tool_reference" in findings[0]
        assert "PR #12294" in findings[0]

    def test_detects_context_management_in_body(self):
        """The beta string inside the body is reported."""
        body = '{"betas": ["context-management-2025-06-27"]}'
        findings = _scan_request_for_forbidden_patterns(body, {})
        assert len(findings) == 1
        assert "context-management-2025-06-27" in findings[0]
        assert "#789" in findings[0]

    def test_detects_context_management_in_anthropic_beta_header(self):
        """The beta string in the ``anthropic-beta`` header is reported."""
        findings = _scan_request_for_forbidden_patterns(
            body_text="{}",
            headers={"anthropic-beta": "context-management-2025-06-27"},
        )
        assert len(findings) == 1
        assert "anthropic-beta" in findings[0]

    def test_detects_context_management_in_uppercase_header_name(self):
        """Header-name matching ignores case and tolerates extra betas."""
        # HTTP header names are case-insensitive — make sure the
        # scanner handles a server that didn't normalise names.
        findings = _scan_request_for_forbidden_patterns(
            body_text="{}",
            headers={"Anthropic-Beta": "context-management-2025-06-27, other"},
        )
        assert len(findings) == 1

    def test_ignores_unrelated_header_values(self):
        """Other headers and other beta values produce no findings."""
        findings = _scan_request_for_forbidden_patterns(
            body_text="{}",
            headers={
                "authorization": "Bearer secret",
                "anthropic-beta": "fine-grained-tool-streaming-2025",
            },
        )
        assert findings == []

    def test_detects_both_patterns_simultaneously(self):
        """Both body patterns are reported, `tool_reference` first."""
        body = (
            '{"betas": ["context-management-2025-06-27"], '
            '"messages": [{"role": "user", "content": ['
            '{"type": "tool_reference", "tool_name": "find"}'
            "]}]}"
        )
        findings = _scan_request_for_forbidden_patterns(body, {})
        # Both patterns hit, in stable order: tool_reference then betas.
        assert len(findings) == 2
        assert "tool_reference" in findings[0]
        assert "context-management-2025-06-27" in findings[1]
|
||||
|
||||
|
||||
class TestResolveCliPath:
    """Unit tests for the env-var override handling in ``_resolve_cli_path``."""

    @staticmethod
    def _write_fake_cli(directory, filename):
        """Create an executable shell stub in *directory* and return its path."""
        script = directory / filename
        script.write_text("#!/bin/sh\necho fake\n")
        script.chmod(0o755)
        return script

    def test_honours_explicit_env_var_when_file_exists(self, tmp_path, monkeypatch):
        """The unprefixed ``CLAUDE_AGENT_CLI_PATH`` override is used."""
        stub = self._write_fake_cli(tmp_path, "fake-claude")
        monkeypatch.delenv("CHAT_CLAUDE_AGENT_CLI_PATH", raising=False)
        monkeypatch.setenv("CLAUDE_AGENT_CLI_PATH", str(stub))
        assert _resolve_cli_path() == stub

    def test_honours_chat_prefixed_env_var_when_file_exists(
        self, tmp_path, monkeypatch
    ):
        """The ``CHAT_``-prefixed override is used as well.

        ``_resolve_cli_path`` reads ``CHAT_CLAUDE_AGENT_CLI_PATH`` in
        addition to the unprefixed ``CLAUDE_AGENT_CLI_PATH``; this covers
        the prefixed name on its own.
        """
        stub = self._write_fake_cli(tmp_path, "fake-claude-prefixed")
        monkeypatch.delenv("CLAUDE_AGENT_CLI_PATH", raising=False)
        monkeypatch.setenv("CHAT_CLAUDE_AGENT_CLI_PATH", str(stub))
        assert _resolve_cli_path() == stub

    def test_returns_none_when_env_var_points_to_missing_file(self, monkeypatch):
        """An override naming a missing file resolves to ``None``."""
        monkeypatch.delenv("CHAT_CLAUDE_AGENT_CLI_PATH", raising=False)
        monkeypatch.setenv("CLAUDE_AGENT_CLI_PATH", "/nonexistent/path/to/claude")
        # The resolver must not quietly fall back to the bundled binary:
        # the operator asked for a specific path, so a bad path has to be
        # visible.  The strict ``is None`` check guards that behaviour.
        outcome = _resolve_cli_path()
        assert outcome is None

    def test_falls_back_to_bundled_when_env_var_unset(self, monkeypatch):
        """With no override set, the bundled binary (if any) is returned."""
        for variable in ("CLAUDE_AGENT_CLI_PATH", "CHAT_CLAUDE_AGENT_CLI_PATH"):
            monkeypatch.delenv(variable, raising=False)
        # Whether a bundled binary exists depends on what is installed in
        # the test environment, so both outcomes are accepted.
        outcome = _resolve_cli_path()
        assert outcome is None or outcome.is_file()
|
||||
@@ -226,111 +226,6 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
assert was_compacted is False # mock returns False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_prevents_pending_duplication():
    """Pending messages must not leak into the gap when a ceiling is given.

    The session holds two historical messages, the current message and two
    pending messages that were appended afterwards.  The transcript covers
    the two historical messages.  With ``session_msg_ceiling=3`` the gap
    between transcript and current message is empty, so the query is just
    the current message (which already contains the pending text).
    """
    current_text = "current msg with pending1 pending2"
    # Order: [hist1, hist2, current_msg, pending1, pending2]
    history = [
        ChatMessage(role="user", content="hist1"),
        ChatMessage(role="assistant", content="hist2"),
        ChatMessage(role="user", content="current msg with pending1 pending2"),
        ChatMessage(role="user", content="pending1"),
        ChatMessage(role="user", content="pending2"),
    ]
    session = _make_session(history)

    # Transcript covers hist1 + hist2; the ceiling counts those plus current_msg.
    result, was_compacted = await _build_query_message(
        current_text,
        session,
        use_resume=True,
        transcript_msg_count=2,
        session_id="test-session",
        session_msg_ceiling=3,
    )

    # Empty gap: nothing is prepended to the current message.
    assert result == "current msg with pending1 pending2"
    assert was_compacted is False
    # No pending text may precede the current message.
    prefix = result.split("current msg")[0]
    assert "pending1" not in prefix
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_preserves_real_gap():
    """A genuine gap is still surfaced when a ceiling is given.

    The session holds four historical messages, the current message and
    two pending ones; the transcript covers only the first two.  With
    ``session_msg_ceiling=5`` the gap is ``hist3`` and ``hist4``, and the
    pending messages stay out of the result.
    """
    history = [
        ChatMessage(role="user", content="hist1"),
        ChatMessage(role="assistant", content="hist2"),
        ChatMessage(role="user", content="hist3"),
        ChatMessage(role="assistant", content="hist4"),
        ChatMessage(role="user", content="current"),
        ChatMessage(role="user", content="pending1"),
        ChatMessage(role="user", content="pending2"),
    ]
    session = _make_session(history)

    result, was_compacted = await _build_query_message(
        "current",
        session,
        use_resume=True,
        transcript_msg_count=2,
        session_id="test-session",
        session_msg_ceiling=5,  # counts hist1..hist4 plus current
    )

    # The gap (hist3, hist4) is rendered ahead of the current message.
    assert "<conversation_history>" in result
    for expected in ("hist3", "hist4"):
        assert expected in result
    assert "Now, the user says:\ncurrent" in result
    # Pending messages stay out of the query entirely.
    for unexpected in ("pending1", "pending2"):
        assert unexpected not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_suppresses_spurious_no_resume_fallback():
    """The ceiling keeps the no-resume fallback quiet on a first turn.

    The session holds the first message plus one pending message appended
    afterwards.  With ``session_msg_ceiling=1`` the effective message count
    is 1, so the compression fallback (which needs more than one message)
    must not run and the current message is returned unchanged.
    """
    combined_text = "What is 2 plus 2?\n\nWhat is 7 plus 7?"
    # Order: [current_msg, pending_msg]
    history = [
        ChatMessage(role="user", content="What is 2 plus 2?"),
        ChatMessage(role="user", content="What is 7 plus 7?"),
    ]
    session = _make_session(history)

    result, was_compacted = await _build_query_message(
        combined_text,
        session,
        use_resume=False,
        transcript_msg_count=0,
        session_id="test-session",
        session_msg_ceiling=1,  # only one message counted toward history
    )

    # Returned verbatim: no history wrapper, no compaction.
    assert result == "What is 2 plus 2?\n\nWhat is 7 plus 7?"
    assert was_compacted is False
    assert "<conversation_history>" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
|
||||
"""When compression actually compacts, was_compacted should be True."""
|
||||
|
||||
@@ -1031,12 +1031,6 @@ def _make_sdk_patches(
|
||||
),
|
||||
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
|
||||
# Stub pending-message drain so retry tests don't hit Redis.
|
||||
# Returns an empty list → no mid-turn injection happens.
|
||||
(
|
||||
f"{_SVC}.drain_pending_messages",
|
||||
dict(new_callable=AsyncMock, return_value=[]),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -196,3 +196,79 @@ def test_sdk_exports_hook_event_type(hook_event: str):
|
||||
# HookEvent is a Literal type — check that our events are valid values.
|
||||
# We can't easily inspect Literal at runtime, so just verify the type exists.
|
||||
assert HookEvent is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenRouter compatibility — bundled CLI version pin
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# We're stuck on ``claude-agent-sdk==0.1.45`` (bundled CLI ``2.1.63``)
|
||||
# because every version above introduces a 400 against OpenRouter:
|
||||
#
|
||||
# 1. CLI ``2.1.69`` (= SDK ``0.1.46``) shipped a `tool_reference` content
|
||||
# block in `tool_result.content` that OpenRouter's stricter Zod
|
||||
# validation rejects. See PR
|
||||
# https://github.com/Significant-Gravitas/AutoGPT/pull/12294 for the
|
||||
# forensic write-up that originally pinned us. CLI ``2.1.70`` added
|
||||
# proxy detection that *should* disable the offending block, but two
|
||||
# later attempts (Dependabot bumps to 0.1.55 / 0.1.56) still failed.
|
||||
#
|
||||
# 2. A second regression — the ``context-management-2025-06-27`` beta
|
||||
# header — appeared in some CLI version after ``2.1.91``. Tracked
|
||||
# upstream at
|
||||
# https://github.com/anthropics/claude-agent-sdk-python/issues/789
|
||||
# (still open at the time of writing, no upstream PR yet).
|
||||
#
|
||||
# This test is the cheapest possible regression guard: it pins the
|
||||
# bundled CLI to a known-good version. If anyone bumps
|
||||
# ``claude-agent-sdk`` in ``pyproject.toml``, the bundled CLI version in
|
||||
# ``_cli_version.py`` will change and this test will fail with a clear
|
||||
# message that points the next person at the OpenRouter compat issue
|
||||
# instead of letting them silently re-break production.
|
||||
#
|
||||
# Workaround for actually upgrading: set the
|
||||
# ``claude_agent_cli_path`` config option (or the matching env var) to
|
||||
# point at a separately-installed Claude Code CLI binary at a known-good
|
||||
# version, so the SDK Python API surface and the CLI binary version can
|
||||
# be picked independently.
|
||||
|
||||
# CLI versions verified to work against OpenRouter from production
|
||||
# traffic. When upstream lands a fix and we can confirm a newer version
|
||||
# works, add it to this set rather than blanket-removing the assertion.
|
||||
_KNOWN_GOOD_BUNDLED_CLI_VERSIONS: frozenset[str] = frozenset({"2.1.63"})
|
||||
|
||||
|
||||
def test_bundled_cli_version_is_known_good_against_openrouter():
    """Fail fast when the bundled CLI version leaves the known-good set.

    Reads the CLI version shipped with ``claude_agent_sdk`` and requires
    it to be listed in ``_KNOWN_GOOD_BUNDLED_CLI_VERSIONS``; the failure
    message explains how to proceed after an intentional SDK bump.
    """
    from claude_agent_sdk._cli_version import __cli_version__

    is_known_good = __cli_version__ in _KNOWN_GOOD_BUNDLED_CLI_VERSIONS
    explanation = (
        f"Bundled Claude Code CLI version is {__cli_version__!r}, which is "
        f"not in the OpenRouter-known-good set "
        f"{sorted(_KNOWN_GOOD_BUNDLED_CLI_VERSIONS)!r}. "
        "If you intentionally bumped `claude-agent-sdk`, verify the new "
        "bundled CLI works with OpenRouter against the reproduction test "
        "in `cli_openrouter_compat_test.py`, then add the new CLI version "
        "to `_KNOWN_GOOD_BUNDLED_CLI_VERSIONS`. If you cannot make the "
        "bundled CLI work, set `claude_agent_cli_path` to a known-good "
        "binary instead and skip the bundled one. See "
        "https://github.com/anthropics/claude-agent-sdk-python/issues/789 "
        "and https://github.com/Significant-Gravitas/AutoGPT/pull/12294."
    )
    assert is_known_good, explanation
|
||||
|
||||
|
||||
def test_sdk_exposes_cli_path_option():
    """Check that ``ClaudeAgentOptions`` still accepts ``cli_path``.

    The OpenRouter workaround relies on passing ``cli_path``; this test
    fails if that parameter disappears from the constructor signature.
    """
    import inspect

    from claude_agent_sdk import ClaudeAgentOptions

    accepted_parameters = inspect.signature(ClaudeAgentOptions).parameters
    assert "cli_path" in accepted_parameters, (
        "ClaudeAgentOptions no longer accepts `cli_path` — our "
        "claude_agent_cli_path config override would be silently ignored. "
        "Either find an alternative override mechanism or pin the SDK to a "
        "version that still exposes it."
    )
|
||||
|
||||
@@ -34,10 +34,6 @@ from opentelemetry import trace as otel_trace
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.pending_messages import (
|
||||
drain_pending_messages,
|
||||
format_pending_as_user_message,
|
||||
)
|
||||
from backend.copilot.permissions import apply_tool_permissions
|
||||
from backend.copilot.rate_limit import get_user_tier
|
||||
from backend.copilot.transcript import (
|
||||
@@ -959,33 +955,17 @@ async def _build_query_message(
|
||||
use_resume: bool,
|
||||
transcript_msg_count: int,
|
||||
session_id: str,
|
||||
*,
|
||||
session_msg_ceiling: int | None = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Build the query message with appropriate context.
|
||||
|
||||
Args:
|
||||
session_msg_ceiling: If provided, treat ``session.messages`` as if it
|
||||
only has this many entries when computing the gap slice. Pass
|
||||
``len(session.messages)`` captured *before* appending any pending
|
||||
messages so that mid-turn drains do not skew the gap calculation
|
||||
and cause pending messages to be duplicated in both the gap context
|
||||
and ``current_message``.
|
||||
|
||||
Returns:
|
||||
Tuple of (query_message, was_compacted).
|
||||
"""
|
||||
msg_count = len(session.messages)
|
||||
# Use the ceiling if supplied (prevents pending-message duplication when
|
||||
# messages were appended to session.messages after the drain but before
|
||||
# this function is called).
|
||||
effective_count = (
|
||||
session_msg_ceiling if session_msg_ceiling is not None else msg_count
|
||||
)
|
||||
|
||||
if use_resume and transcript_msg_count > 0:
|
||||
if transcript_msg_count < effective_count - 1:
|
||||
gap = session.messages[transcript_msg_count : effective_count - 1]
|
||||
if transcript_msg_count < msg_count - 1:
|
||||
gap = session.messages[transcript_msg_count:-1]
|
||||
compressed, was_compressed = await _compress_messages(gap)
|
||||
gap_context = _format_conversation_context(compressed)
|
||||
if gap_context:
|
||||
@@ -1001,14 +981,12 @@ async def _build_query_message(
|
||||
f"{gap_context}\n\nNow, the user says:\n{current_message}",
|
||||
was_compressed,
|
||||
)
|
||||
elif not use_resume and effective_count > 1:
|
||||
elif not use_resume and msg_count > 1:
|
||||
logger.warning(
|
||||
f"[SDK] Using compression fallback for session "
|
||||
f"{session_id} ({effective_count} messages) — no transcript for --resume"
|
||||
)
|
||||
compressed, was_compressed = await _compress_messages(
|
||||
session.messages[: effective_count - 1]
|
||||
f"{session_id} ({msg_count} messages) — no transcript for --resume"
|
||||
)
|
||||
compressed, was_compressed = await _compress_messages(session.messages[:-1])
|
||||
history_context = _format_conversation_context(compressed)
|
||||
if history_context:
|
||||
return (
|
||||
@@ -2064,7 +2042,6 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
assert session is not None # narrowed at line 1898
|
||||
if not (
|
||||
config.claude_agent_use_resume and user_id and len(session.messages) > 1
|
||||
):
|
||||
@@ -2268,6 +2245,12 @@ async def stream_chat_completion_sdk(
|
||||
sdk_options_kwargs["env"] = sdk_env
|
||||
if use_resume and resume_file:
|
||||
sdk_options_kwargs["resume"] = resume_file
|
||||
# Optional explicit Claude Code CLI binary path (decouples the
|
||||
# bundled SDK version from the CLI version we run — needed because
|
||||
# the CLI bundled in 0.1.46+ is broken against OpenRouter). Falls
|
||||
# back to the bundled binary when unset.
|
||||
if config.claude_agent_cli_path:
|
||||
sdk_options_kwargs["cli_path"] = config.claude_agent_cli_path
|
||||
|
||||
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
@@ -2300,61 +2283,6 @@ async def stream_chat_completion_sdk(
|
||||
if last_user:
|
||||
current_message = last_user[-1].content or ""
|
||||
|
||||
# Capture the message count *before* draining so _build_query_message
|
||||
# can compute the gap slice without including the newly-drained pending
|
||||
# messages. Pending messages are both appended to session.messages AND
|
||||
# concatenated into current_message; without the ceiling the gap slice
|
||||
# would extend into the pending messages and duplicate them in the
|
||||
# model's input context (gap_context + current_message both containing
|
||||
# them).
|
||||
_pre_drain_msg_count = len(session.messages)
|
||||
|
||||
# Drain any messages the user queued via POST /messages/pending
|
||||
# while the previous turn was running (or since the session was
|
||||
# idle). Messages are drained ATOMICALLY — one LPOP with count
|
||||
# removes them all at once, so a concurrent push lands *after*
|
||||
# the drain and stays queued for the next turn instead of being
|
||||
# lost between LPOP and clear. File IDs and context are
|
||||
# preserved via format_pending_as_user_message.
|
||||
#
|
||||
# The drained content is concatenated into ``current_message``
|
||||
# so the SDK CLI sees it in the new user message, AND appended
|
||||
# directly to ``session.messages`` (no dedup — pending messages are
|
||||
# atomically-popped from Redis and are never stale-cache duplicates)
|
||||
# so the durable transcript records it too. Session is persisted
|
||||
# immediately after the drain so a crash doesn't lose the messages.
|
||||
# The endpoint deliberately does NOT persist to session.messages —
|
||||
# Redis is the single source of truth until this drain runs.
|
||||
pending_at_start = await drain_pending_messages(session_id)
|
||||
if pending_at_start:
|
||||
logger.info(
|
||||
"%s Draining %d pending message(s) at turn start",
|
||||
log_prefix,
|
||||
len(pending_at_start),
|
||||
)
|
||||
pending_texts: list[str] = [
|
||||
format_pending_as_user_message(pm)["content"] for pm in pending_at_start
|
||||
]
|
||||
for pt in pending_texts:
|
||||
# Append directly — pending messages are atomically-popped from
|
||||
# Redis and are never stale-cache duplicates, so the
|
||||
# maybe_append_user_message dedup is wrong here.
|
||||
session.messages.append(ChatMessage(role="user", content=pt))
|
||||
if current_message.strip():
|
||||
current_message = current_message + "\n\n" + "\n\n".join(pending_texts)
|
||||
else:
|
||||
current_message = "\n\n".join(pending_texts)
|
||||
# Persist immediately so a crash between here and the finally block
|
||||
# doesn't lose messages that were already drained from Redis.
|
||||
try:
|
||||
session = await upsert_chat_session(session)
|
||||
except Exception as _persist_err:
|
||||
logger.warning(
|
||||
"%s Failed to persist drained pending messages: %s",
|
||||
log_prefix,
|
||||
_persist_err,
|
||||
)
|
||||
|
||||
if not current_message.strip():
|
||||
yield StreamError(
|
||||
errorText="Message cannot be empty.",
|
||||
@@ -2368,7 +2296,6 @@ async def stream_chat_completion_sdk(
|
||||
use_resume,
|
||||
transcript_msg_count,
|
||||
session_id,
|
||||
session_msg_ceiling=_pre_drain_msg_count,
|
||||
)
|
||||
# On the first turn inject user context into the message instead of the
|
||||
# system prompt — the system prompt is now static (same for all users)
|
||||
@@ -2506,7 +2433,6 @@ async def stream_chat_completion_sdk(
|
||||
state.use_resume,
|
||||
state.transcript_msg_count,
|
||||
session_id,
|
||||
session_msg_ceiling=_pre_drain_msg_count,
|
||||
)
|
||||
if attachments.hint:
|
||||
state.query_message = f"{state.query_message}\n\n{attachments.hint}"
|
||||
@@ -2836,11 +2762,6 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
raise
|
||||
finally:
|
||||
# Pending messages are drained atomically at the start of each
|
||||
# turn (see drain_pending_messages call above), so there's
|
||||
# nothing to clean up here — any message pushed after that
|
||||
# point belongs to the next turn.
|
||||
|
||||
# --- Close OTEL context (with cost attributes) ---
|
||||
if _otel_ctx is not None:
|
||||
try:
|
||||
|
||||
@@ -4,10 +4,10 @@ from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from prisma.models import PlatformCostLog as PrismaLog
|
||||
from prisma.types import PlatformCostLogCreateInput
|
||||
from prisma.models import User as PrismaUser
|
||||
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
MICRODOLLARS_PER_USD = 1_000_000
|
||||
|
||||
# Dashboard query limits — keep in sync with the SQL queries below
|
||||
# Dashboard query limits
|
||||
MAX_PROVIDER_ROWS = 500
|
||||
MAX_USER_ROWS = 100
|
||||
|
||||
@@ -169,53 +169,61 @@ class PlatformCostDashboard(BaseModel):
|
||||
total_users: int
|
||||
|
||||
|
||||
def _build_where(
|
||||
def _si(row: dict, field: str) -> int:
|
||||
"""Extract an integer from a Prisma group_by _sum dict.
|
||||
|
||||
Prisma Python serialises BigInt/Int aggregate sums as strings; coerce to int.
|
||||
"""
|
||||
return int((row.get("_sum") or {}).get(field) or 0)
|
||||
|
||||
|
||||
def _sf(row: dict, field: str) -> float:
|
||||
"""Extract a float from a Prisma group_by _sum dict."""
|
||||
return float((row.get("_sum") or {}).get(field) or 0.0)
|
||||
|
||||
|
||||
def _ca(row: dict) -> int:
|
||||
"""Extract _count._all from a Prisma group_by row."""
|
||||
c = row.get("_count") or {}
|
||||
return int(c.get("_all") or 0) if isinstance(c, dict) else int(c or 0)
|
||||
|
||||
|
||||
def _build_prisma_where(
|
||||
start: datetime | None,
|
||||
end: datetime | None,
|
||||
provider: str | None,
|
||||
user_id: str | None,
|
||||
table_alias: str = "",
|
||||
model: str | None = None,
|
||||
block_name: str | None = None,
|
||||
tracking_type: str | None = None,
|
||||
) -> tuple[str, list[Any]]:
|
||||
prefix = f"{table_alias}." if table_alias else ""
|
||||
clauses: list[str] = []
|
||||
params: list[Any] = []
|
||||
idx = 1
|
||||
) -> PlatformCostLogWhereInput:
|
||||
"""Build a Prisma WhereInput for PlatformCostLog filters."""
|
||||
where: PlatformCostLogWhereInput = {}
|
||||
|
||||
if start and end:
|
||||
where["createdAt"] = {"gte": start, "lte": end}
|
||||
elif start:
|
||||
where["createdAt"] = {"gte": start}
|
||||
elif end:
|
||||
where["createdAt"] = {"lte": end}
|
||||
|
||||
if start:
|
||||
clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
|
||||
params.append(start)
|
||||
idx += 1
|
||||
if end:
|
||||
clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
|
||||
params.append(end)
|
||||
idx += 1
|
||||
if provider:
|
||||
# Provider names are normalized to lowercase at write time so a plain
|
||||
# equality check is sufficient and the (provider, createdAt) index is used.
|
||||
clauses.append(f'{prefix}"provider" = ${idx}')
|
||||
params.append(provider.lower())
|
||||
idx += 1
|
||||
if user_id:
|
||||
clauses.append(f'{prefix}"userId" = ${idx}')
|
||||
params.append(user_id)
|
||||
idx += 1
|
||||
if model:
|
||||
clauses.append(f'{prefix}"model" = ${idx}')
|
||||
params.append(model)
|
||||
idx += 1
|
||||
if block_name:
|
||||
clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})')
|
||||
params.append(block_name)
|
||||
idx += 1
|
||||
if tracking_type:
|
||||
clauses.append(f'{prefix}"trackingType" = ${idx}')
|
||||
params.append(tracking_type)
|
||||
idx += 1
|
||||
where["provider"] = provider.lower()
|
||||
|
||||
return (" AND ".join(clauses) if clauses else "TRUE", params)
|
||||
if user_id:
|
||||
where["userId"] = user_id
|
||||
|
||||
if model:
|
||||
where["model"] = model
|
||||
|
||||
if block_name:
|
||||
# Case-insensitive match — mirrors the original LOWER() SQL filter.
|
||||
where["blockName"] = {"equals": block_name, "mode": "insensitive"}
|
||||
|
||||
if tracking_type:
|
||||
where["trackingType"] = tracking_type
|
||||
|
||||
return where
|
||||
|
||||
|
||||
@cached(ttl_seconds=30)
|
||||
@@ -241,110 +249,107 @@ async def get_platform_cost_dashboard(
|
||||
"""
|
||||
if start is None:
|
||||
start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
|
||||
where_p, params_p = _build_where(
|
||||
start, end, provider, user_id, "p", model, block_name, tracking_type
|
||||
|
||||
where = _build_prisma_where(
|
||||
start, end, provider, user_id, model, block_name, tracking_type
|
||||
)
|
||||
|
||||
by_provider_rows, by_user_rows, total_user_rows, total_agg_rows = (
|
||||
sum_fields = {
|
||||
"costMicrodollars": True,
|
||||
"inputTokens": True,
|
||||
"outputTokens": True,
|
||||
"cacheReadTokens": True,
|
||||
"cacheCreationTokens": True,
|
||||
"duration": True,
|
||||
"trackingAmount": True,
|
||||
}
|
||||
|
||||
# Run all four aggregation queries in parallel.
|
||||
by_provider_groups, by_user_groups, total_user_groups, total_agg_groups = (
|
||||
await asyncio.gather(
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT
|
||||
p."provider",
|
||||
p."trackingType" AS tracking_type,
|
||||
p."model",
|
||||
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
|
||||
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(p."cacheReadTokens"), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(p."cacheCreationTokens"), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(p."duration"), 0)::float AS total_duration,
|
||||
COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount,
|
||||
COUNT(*)::bigint AS request_count
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
WHERE {where_p}
|
||||
GROUP BY p."provider", p."trackingType", p."model"
|
||||
ORDER BY total_cost DESC
|
||||
LIMIT {MAX_PROVIDER_ROWS}
|
||||
""",
|
||||
*params_p,
|
||||
# (provider, trackingType, model) aggregation — no ORDER BY in ORM;
|
||||
# sort by total cost descending in Python after fetch.
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["provider", "trackingType", "model"],
|
||||
where=where,
|
||||
sum=sum_fields,
|
||||
count=True,
|
||||
),
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT
|
||||
p."userId" AS user_id,
|
||||
u."email",
|
||||
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
|
||||
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
|
||||
COUNT(*)::bigint AS request_count
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
|
||||
WHERE {where_p}
|
||||
GROUP BY p."userId", u."email"
|
||||
ORDER BY total_cost DESC
|
||||
LIMIT {MAX_USER_ROWS}
|
||||
""",
|
||||
*params_p,
|
||||
# userId aggregation — emails fetched separately below.
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["userId"],
|
||||
where=where,
|
||||
sum=sum_fields,
|
||||
count=True,
|
||||
),
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
WHERE {where_p}
|
||||
""",
|
||||
*params_p,
|
||||
# Distinct user count: group by userId, count groups.
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["userId"],
|
||||
where=where,
|
||||
count=True,
|
||||
),
|
||||
# Separate aggregate query so dashboard totals are never derived
|
||||
# from the capped by_provider_rows list. With model-level grouping,
|
||||
# MAX_PROVIDER_ROWS is hit more easily; summing the capped rows
|
||||
# would silently undercount once >500 (provider, type, model) exist.
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT
|
||||
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
|
||||
COUNT(*)::bigint AS request_count
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
WHERE {where_p}
|
||||
""",
|
||||
*params_p,
|
||||
# Total aggregate: group by provider (no limit) to sum across all
|
||||
# matching rows. Summed in Python to get grand totals.
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["provider"],
|
||||
where=where,
|
||||
sum={"costMicrodollars": True},
|
||||
count=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Use the exact COUNT(DISTINCT userId) so total_users is not capped at
|
||||
# MAX_USER_ROWS (which would silently report 100 for >100 active users).
|
||||
total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0
|
||||
total_cost = int(total_agg_rows[0]["total_cost"]) if total_agg_rows else 0
|
||||
total_requests = int(total_agg_rows[0]["request_count"]) if total_agg_rows else 0
|
||||
# Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.
|
||||
by_provider_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True)
|
||||
by_provider_groups = by_provider_groups[:MAX_PROVIDER_ROWS]
|
||||
|
||||
# Sort by_user by total cost descending and cap at MAX_USER_ROWS.
|
||||
by_user_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True)
|
||||
by_user_groups = by_user_groups[:MAX_USER_ROWS]
|
||||
|
||||
# Batch-fetch emails for the users in by_user.
|
||||
user_ids = [r["userId"] for r in by_user_groups if r.get("userId") is not None]
|
||||
email_by_user_id: dict[str, str | None] = {}
|
||||
if user_ids:
|
||||
users = await PrismaUser.prisma().find_many(
|
||||
where={"id": {"in": user_ids}},
|
||||
)
|
||||
email_by_user_id = {u.id: u.email for u in users}
|
||||
|
||||
# Total distinct users — exclude the NULL-userId group (deleted users).
|
||||
total_users = len([g for g in total_user_groups if g.get("userId") is not None])
|
||||
|
||||
# Grand totals — sum across all provider groups (no LIMIT applied above).
|
||||
total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups)
|
||||
total_requests = sum(_ca(r) for r in total_agg_groups)
|
||||
|
||||
return PlatformCostDashboard(
|
||||
by_provider=[
|
||||
ProviderCostSummary(
|
||||
provider=r["provider"],
|
||||
tracking_type=r.get("tracking_type"),
|
||||
tracking_type=r.get("trackingType"),
|
||||
model=r.get("model"),
|
||||
total_cost_microdollars=r["total_cost"],
|
||||
total_input_tokens=r["total_input_tokens"],
|
||||
total_output_tokens=r["total_output_tokens"],
|
||||
total_cache_read_tokens=r.get("total_cache_read_tokens", 0),
|
||||
total_cache_creation_tokens=r.get("total_cache_creation_tokens", 0),
|
||||
total_duration_seconds=r.get("total_duration", 0.0),
|
||||
total_tracking_amount=r.get("total_tracking_amount", 0.0),
|
||||
request_count=r["request_count"],
|
||||
total_cost_microdollars=_si(r, "costMicrodollars"),
|
||||
total_input_tokens=_si(r, "inputTokens"),
|
||||
total_output_tokens=_si(r, "outputTokens"),
|
||||
total_cache_read_tokens=_si(r, "cacheReadTokens"),
|
||||
total_cache_creation_tokens=_si(r, "cacheCreationTokens"),
|
||||
total_duration_seconds=_sf(r, "duration"),
|
||||
total_tracking_amount=_sf(r, "trackingAmount"),
|
||||
request_count=_ca(r),
|
||||
)
|
||||
for r in by_provider_rows
|
||||
for r in by_provider_groups
|
||||
],
|
||||
by_user=[
|
||||
UserCostSummary(
|
||||
user_id=r.get("user_id"),
|
||||
email=_mask_email(r.get("email")),
|
||||
total_cost_microdollars=r["total_cost"],
|
||||
total_input_tokens=r["total_input_tokens"],
|
||||
total_output_tokens=r["total_output_tokens"],
|
||||
request_count=r["request_count"],
|
||||
user_id=r.get("userId"),
|
||||
email=_mask_email(email_by_user_id.get(r.get("userId") or "")),
|
||||
total_cost_microdollars=_si(r, "costMicrodollars"),
|
||||
total_input_tokens=_si(r, "inputTokens"),
|
||||
total_output_tokens=_si(r, "outputTokens"),
|
||||
request_count=_ca(r),
|
||||
)
|
||||
for r in by_user_rows
|
||||
for r in by_user_groups
|
||||
],
|
||||
total_cost_microdollars=total_cost,
|
||||
total_requests=total_requests,
|
||||
@@ -365,73 +370,41 @@ async def get_platform_cost_logs(
|
||||
) -> tuple[list[CostLogRow], int]:
|
||||
if start is None:
|
||||
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
|
||||
where_sql, params = _build_where(
|
||||
start, end, provider, user_id, "p", model, block_name, tracking_type
|
||||
)
|
||||
|
||||
where = _build_prisma_where(
|
||||
start, end, provider, user_id, model, block_name, tracking_type
|
||||
)
|
||||
offset = (page - 1) * page_size
|
||||
limit_idx = len(params) + 1
|
||||
offset_idx = len(params) + 2
|
||||
|
||||
count_rows, rows = await asyncio.gather(
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(*)::bigint AS cnt
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
WHERE {where_sql}
|
||||
""",
|
||||
*params,
|
||||
),
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT
|
||||
p."id",
|
||||
p."createdAt" AS created_at,
|
||||
p."userId" AS user_id,
|
||||
u."email",
|
||||
p."graphExecId" AS graph_exec_id,
|
||||
p."nodeExecId" AS node_exec_id,
|
||||
p."blockName" AS block_name,
|
||||
p."provider",
|
||||
p."trackingType" AS tracking_type,
|
||||
p."costMicrodollars" AS cost_microdollars,
|
||||
p."inputTokens" AS input_tokens,
|
||||
p."outputTokens" AS output_tokens,
|
||||
p."cacheReadTokens" AS cache_read_tokens,
|
||||
p."cacheCreationTokens" AS cache_creation_tokens,
|
||||
p."duration",
|
||||
p."model"
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
|
||||
WHERE {where_sql}
|
||||
ORDER BY p."createdAt" DESC, p."id" DESC
|
||||
LIMIT ${limit_idx} OFFSET ${offset_idx}
|
||||
""",
|
||||
*params,
|
||||
page_size,
|
||||
offset,
|
||||
total, rows = await asyncio.gather(
|
||||
PrismaLog.prisma().count(where=where),
|
||||
PrismaLog.prisma().find_many(
|
||||
where=where,
|
||||
include={"User": True},
|
||||
order=[{"createdAt": "desc"}, {"id": "desc"}],
|
||||
take=page_size,
|
||||
skip=offset,
|
||||
),
|
||||
)
|
||||
total = count_rows[0]["cnt"] if count_rows else 0
|
||||
|
||||
logs = [
|
||||
CostLogRow(
|
||||
id=r["id"],
|
||||
created_at=r["created_at"],
|
||||
user_id=r.get("user_id"),
|
||||
email=_mask_email(r.get("email")),
|
||||
graph_exec_id=r.get("graph_exec_id"),
|
||||
node_exec_id=r.get("node_exec_id"),
|
||||
block_name=r["block_name"],
|
||||
provider=r["provider"],
|
||||
tracking_type=r.get("tracking_type"),
|
||||
cost_microdollars=r.get("cost_microdollars"),
|
||||
input_tokens=r.get("input_tokens"),
|
||||
output_tokens=r.get("output_tokens"),
|
||||
cache_read_tokens=r.get("cache_read_tokens"),
|
||||
cache_creation_tokens=r.get("cache_creation_tokens"),
|
||||
duration=r.get("duration"),
|
||||
model=r.get("model"),
|
||||
id=r.id,
|
||||
created_at=r.createdAt,
|
||||
user_id=r.userId,
|
||||
email=_mask_email(r.User.email if r.User else None),
|
||||
graph_exec_id=r.graphExecId,
|
||||
node_exec_id=r.nodeExecId,
|
||||
block_name=r.blockName or "",
|
||||
provider=r.provider,
|
||||
tracking_type=r.trackingType,
|
||||
cost_microdollars=r.costMicrodollars,
|
||||
input_tokens=r.inputTokens,
|
||||
output_tokens=r.outputTokens,
|
||||
cache_read_tokens=getattr(r, "cacheReadTokens", None),
|
||||
cache_creation_tokens=getattr(r, "cacheCreationTokens", None),
|
||||
duration=r.duration,
|
||||
model=r.model,
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
@@ -457,38 +430,16 @@ async def get_platform_cost_logs_for_export(
|
||||
"""
|
||||
if start is None:
|
||||
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
|
||||
where_sql, params = _build_where(
|
||||
start, end, provider, user_id, "p", model, block_name, tracking_type
|
||||
)
|
||||
limit_idx = len(params) + 1
|
||||
|
||||
rows = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT
|
||||
p."id",
|
||||
p."createdAt" AS created_at,
|
||||
p."userId" AS user_id,
|
||||
u."email",
|
||||
p."graphExecId" AS graph_exec_id,
|
||||
p."nodeExecId" AS node_exec_id,
|
||||
p."blockName" AS block_name,
|
||||
p."provider",
|
||||
p."trackingType" AS tracking_type,
|
||||
p."costMicrodollars" AS cost_microdollars,
|
||||
p."inputTokens" AS input_tokens,
|
||||
p."outputTokens" AS output_tokens,
|
||||
p."cacheReadTokens" AS cache_read_tokens,
|
||||
p."cacheCreationTokens" AS cache_creation_tokens,
|
||||
p."duration",
|
||||
p."model"
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
|
||||
WHERE {where_sql}
|
||||
ORDER BY p."createdAt" DESC, p."id" DESC
|
||||
LIMIT ${limit_idx}
|
||||
""",
|
||||
*params,
|
||||
EXPORT_MAX_ROWS + 1,
|
||||
where = _build_prisma_where(
|
||||
start, end, provider, user_id, model, block_name, tracking_type
|
||||
)
|
||||
|
||||
rows = await PrismaLog.prisma().find_many(
|
||||
where=where,
|
||||
include={"User": True},
|
||||
order=[{"createdAt": "desc"}, {"id": "desc"}],
|
||||
take=EXPORT_MAX_ROWS + 1,
|
||||
)
|
||||
|
||||
truncated = len(rows) > EXPORT_MAX_ROWS
|
||||
@@ -496,22 +447,80 @@ async def get_platform_cost_logs_for_export(
|
||||
|
||||
return [
|
||||
CostLogRow(
|
||||
id=r["id"],
|
||||
created_at=r["created_at"],
|
||||
user_id=r.get("user_id"),
|
||||
email=_mask_email(r.get("email")),
|
||||
graph_exec_id=r.get("graph_exec_id"),
|
||||
node_exec_id=r.get("node_exec_id"),
|
||||
block_name=r["block_name"],
|
||||
provider=r["provider"],
|
||||
tracking_type=r.get("tracking_type"),
|
||||
cost_microdollars=r.get("cost_microdollars"),
|
||||
input_tokens=r.get("input_tokens"),
|
||||
output_tokens=r.get("output_tokens"),
|
||||
cache_read_tokens=r.get("cache_read_tokens"),
|
||||
cache_creation_tokens=r.get("cache_creation_tokens"),
|
||||
duration=r.get("duration"),
|
||||
model=r.get("model"),
|
||||
id=r.id,
|
||||
created_at=r.createdAt,
|
||||
user_id=r.userId,
|
||||
email=_mask_email(r.User.email if r.User else None),
|
||||
graph_exec_id=r.graphExecId,
|
||||
node_exec_id=r.nodeExecId,
|
||||
block_name=r.blockName or "",
|
||||
provider=r.provider,
|
||||
tracking_type=r.trackingType,
|
||||
cost_microdollars=r.costMicrodollars,
|
||||
input_tokens=r.inputTokens,
|
||||
output_tokens=r.outputTokens,
|
||||
cache_read_tokens=getattr(r, "cacheReadTokens", None),
|
||||
cache_creation_tokens=getattr(r, "cacheCreationTokens", None),
|
||||
duration=r.duration,
|
||||
model=r.model,
|
||||
)
|
||||
for r in rows
|
||||
], truncated
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers kept for backward-compatibility with existing tests.
|
||||
# New code should not use these — use _build_prisma_where instead.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_where(
|
||||
start: datetime | None,
|
||||
end: datetime | None,
|
||||
provider: str | None,
|
||||
user_id: str | None,
|
||||
table_alias: str = "",
|
||||
model: str | None = None,
|
||||
block_name: str | None = None,
|
||||
tracking_type: str | None = None,
|
||||
) -> tuple[str, list[Any]]:
|
||||
"""Legacy SQL WHERE builder — retained so existing unit tests still pass.
|
||||
|
||||
Only used by tests that verify the SQL-string generation logic. All
|
||||
production code uses _build_prisma_where instead.
|
||||
"""
|
||||
prefix = f"{table_alias}." if table_alias else ""
|
||||
clauses: list[str] = []
|
||||
params: list[Any] = []
|
||||
idx = 1
|
||||
|
||||
if start:
|
||||
clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
|
||||
params.append(start)
|
||||
idx += 1
|
||||
if end:
|
||||
clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
|
||||
params.append(end)
|
||||
idx += 1
|
||||
if provider:
|
||||
clauses.append(f'{prefix}"provider" = ${idx}')
|
||||
params.append(provider.lower())
|
||||
idx += 1
|
||||
if user_id:
|
||||
clauses.append(f'{prefix}"userId" = ${idx}')
|
||||
params.append(user_id)
|
||||
idx += 1
|
||||
if model:
|
||||
clauses.append(f'{prefix}"model" = ${idx}')
|
||||
params.append(model)
|
||||
idx += 1
|
||||
if block_name:
|
||||
clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})')
|
||||
params.append(block_name)
|
||||
idx += 1
|
||||
if tracking_type:
|
||||
clauses.append(f'{prefix}"trackingType" = ${idx}')
|
||||
params.append(tracking_type)
|
||||
idx += 1
|
||||
|
||||
return (" AND ".join(clauses) if clauses else "TRUE", params)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Unit tests for helpers and async functions in platform_cost module."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma import Json
|
||||
@@ -224,6 +224,41 @@ class TestLogPlatformCostSafe:
|
||||
mock_create.assert_awaited_once()
|
||||
|
||||
|
||||
def _make_group_by_row(
|
||||
provider: str = "openai",
|
||||
tracking_type: str | None = "tokens",
|
||||
model: str | None = None,
|
||||
cost: int = 5000,
|
||||
input_tokens: int = 1000,
|
||||
output_tokens: int = 500,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
duration: float = 10.5,
|
||||
tracking_amount: float = 0.0,
|
||||
count: int = 3,
|
||||
user_id: str | None = None,
|
||||
) -> dict:
|
||||
row: dict = {
|
||||
"_sum": {
|
||||
"costMicrodollars": cost,
|
||||
"inputTokens": input_tokens,
|
||||
"outputTokens": output_tokens,
|
||||
"cacheReadTokens": cache_read_tokens,
|
||||
"cacheCreationTokens": cache_creation_tokens,
|
||||
"duration": duration,
|
||||
"trackingAmount": tracking_amount,
|
||||
},
|
||||
"_count": {"_all": count},
|
||||
}
|
||||
if user_id is not None:
|
||||
row["userId"] = user_id
|
||||
else:
|
||||
row["provider"] = provider
|
||||
row["trackingType"] = tracking_type
|
||||
row["model"] = model
|
||||
return row
|
||||
|
||||
|
||||
class TestGetPlatformCostDashboard:
|
||||
def setup_method(self):
|
||||
# @cached stores results in-process; clear between tests to avoid bleed.
|
||||
@@ -231,35 +266,44 @@ class TestGetPlatformCostDashboard:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_dashboard_with_data(self):
|
||||
provider_rows = [
|
||||
{
|
||||
"provider": "openai",
|
||||
"tracking_type": "tokens",
|
||||
"total_cost": 5000,
|
||||
"total_input_tokens": 1000,
|
||||
"total_output_tokens": 500,
|
||||
"total_duration": 10.5,
|
||||
"request_count": 3,
|
||||
}
|
||||
]
|
||||
user_rows = [
|
||||
{
|
||||
"user_id": "u1",
|
||||
"email": "a@b.com",
|
||||
"total_cost": 5000,
|
||||
"total_input_tokens": 1000,
|
||||
"total_output_tokens": 500,
|
||||
"request_count": 3,
|
||||
}
|
||||
]
|
||||
# Dashboard runs 4 queries: by_provider, by_user, COUNT(DISTINCT userId),
|
||||
# and a separate total aggregate (total_cost + request_count with no LIMIT).
|
||||
agg_rows = [{"total_cost": 5000, "request_count": 3}]
|
||||
mock_query = AsyncMock(
|
||||
side_effect=[provider_rows, user_rows, [{"cnt": 1}], agg_rows]
|
||||
provider_row = _make_group_by_row(
|
||||
provider="openai",
|
||||
tracking_type="tokens",
|
||||
cost=5000,
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
duration=10.5,
|
||||
count=3,
|
||||
)
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
user_row = _make_group_by_row(user_id="u1", cost=5000, count=3)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "u1"
|
||||
mock_user.email = "a@b.com"
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(
|
||||
side_effect=[
|
||||
[provider_row], # by_provider
|
||||
[user_row], # by_user
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[provider_row], # total agg
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[mock_user])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard()
|
||||
|
||||
assert dashboard.total_cost_microdollars == 5000
|
||||
assert dashboard.total_requests == 3
|
||||
assert dashboard.total_users == 1
|
||||
@@ -271,10 +315,67 @@ class TestGetPlatformCostDashboard:
|
||||
assert dashboard.by_user[0].email == "a***@b.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_dashboard(self):
|
||||
mock_query = AsyncMock(side_effect=[[], [], [], []])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
async def test_cache_tokens_aggregated_not_hardcoded(self):
|
||||
"""cache_read_tokens and cache_creation_tokens must be read from the
|
||||
DB aggregation, not hardcoded to 0 (regression guard for Sentry report)."""
|
||||
provider_row = _make_group_by_row(
|
||||
provider="anthropic",
|
||||
tracking_type="tokens",
|
||||
cost=1000,
|
||||
input_tokens=800,
|
||||
output_tokens=200,
|
||||
cache_read_tokens=400,
|
||||
cache_creation_tokens=100,
|
||||
count=1,
|
||||
)
|
||||
user_row = _make_group_by_row(user_id="u2", cost=1000, count=1)
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(
|
||||
side_effect=[
|
||||
[provider_row], # by_provider
|
||||
[user_row], # by_user
|
||||
[{"userId": "u2"}], # distinct users
|
||||
[provider_row], # total agg
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard()
|
||||
|
||||
assert len(dashboard.by_provider) == 1
|
||||
row = dashboard.by_provider[0]
|
||||
assert row.total_cache_read_tokens == 400
|
||||
assert row.total_cache_creation_tokens == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_dashboard(self):
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard()
|
||||
|
||||
assert dashboard.total_cost_microdollars == 0
|
||||
assert dashboard.total_requests == 0
|
||||
assert dashboard.total_users == 0
|
||||
@@ -284,160 +385,228 @@ class TestGetPlatformCostDashboard:
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_filters_to_queries(self):
|
||||
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
mock_query = AsyncMock(side_effect=[[], [], [], []])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
):
|
||||
await get_platform_cost_dashboard(
|
||||
start=start, provider="openai", user_id="u1"
|
||||
)
|
||||
assert mock_query.await_count == 4
|
||||
first_call_sql = mock_query.call_args_list[0][0][0]
|
||||
assert "createdAt" in first_call_sql
|
||||
|
||||
# group_by called 4 times (by_provider, by_user, distinct users, totals)
|
||||
assert mock_actions.group_by.await_count == 4
|
||||
# The where dict passed to the first call should include createdAt
|
||||
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
|
||||
assert "createdAt" in first_call_kwargs.get("where", {})
|
||||
|
||||
|
||||
def _make_prisma_log_row(
|
||||
i: int = 0,
|
||||
user_email: str | None = None,
|
||||
) -> MagicMock:
|
||||
row = MagicMock()
|
||||
row.id = f"log-{i}"
|
||||
row.createdAt = datetime(2026, 3, 1, tzinfo=timezone.utc)
|
||||
row.userId = "u1"
|
||||
row.graphExecId = None
|
||||
row.nodeExecId = None
|
||||
row.blockName = "TestBlock"
|
||||
row.provider = "openai"
|
||||
row.trackingType = "tokens"
|
||||
row.costMicrodollars = 1000
|
||||
row.inputTokens = 10
|
||||
row.outputTokens = 5
|
||||
row.duration = 0.5
|
||||
row.model = "gpt-4"
|
||||
# cacheReadTokens / cacheCreationTokens may not exist on older Prisma clients
|
||||
row.configure_mock(**{"cacheReadTokens": None, "cacheCreationTokens": None})
|
||||
if user_email is not None:
|
||||
row.User = MagicMock()
|
||||
row.User.email = user_email
|
||||
else:
|
||||
row.User = None
|
||||
return row
|
||||
|
||||
|
||||
class TestGetPlatformCostLogs:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_logs_and_total(self):
|
||||
count_rows = [{"cnt": 1}]
|
||||
log_rows = [
|
||||
{
|
||||
"id": "log-1",
|
||||
"created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
|
||||
"user_id": "u1",
|
||||
"email": "a@b.com",
|
||||
"graph_exec_id": "g1",
|
||||
"node_exec_id": "n1",
|
||||
"block_name": "TestBlock",
|
||||
"provider": "openai",
|
||||
"tracking_type": "tokens",
|
||||
"cost_microdollars": 5000,
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"duration": 1.5,
|
||||
"model": "gpt-4",
|
||||
}
|
||||
]
|
||||
mock_query = AsyncMock(side_effect=[count_rows, log_rows])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
row = _make_prisma_log_row(0, user_email="a@b.com")
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.count = AsyncMock(return_value=1)
|
||||
mock_actions.find_many = AsyncMock(return_value=[row])
|
||||
|
||||
with patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
):
|
||||
logs, total = await get_platform_cost_logs(page=1, page_size=10)
|
||||
|
||||
assert total == 1
|
||||
assert len(logs) == 1
|
||||
assert logs[0].id == "log-1"
|
||||
assert logs[0].id == "log-0"
|
||||
assert logs[0].provider == "openai"
|
||||
assert logs[0].model == "gpt-4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_when_no_data(self):
|
||||
mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.count = AsyncMock(return_value=0)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
):
|
||||
logs, total = await get_platform_cost_logs()
|
||||
|
||||
assert total == 0
|
||||
assert logs == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pagination_offset(self):
|
||||
mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
logs, total = await get_platform_cost_logs(page=3, page_size=25)
|
||||
assert total == 100
|
||||
second_call_args = mock_query.call_args_list[1][0]
|
||||
assert 25 in second_call_args # page_size
|
||||
assert 50 in second_call_args # offset = (3-1) * 25
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.count = AsyncMock(return_value=100)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_count_returns_zero(self):
|
||||
mock_query = AsyncMock(side_effect=[[], []])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
logs, total = await get_platform_cost_logs()
|
||||
assert total == 0
|
||||
with patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
):
|
||||
logs, total = await get_platform_cost_logs(page=3, page_size=25)
|
||||
|
||||
assert total == 100
|
||||
find_many_call = mock_actions.find_many.call_args[1]
|
||||
assert find_many_call["take"] == 25
|
||||
assert find_many_call["skip"] == 50 # (3-1) * 25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_start_skips_default(self):
|
||||
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.count = AsyncMock(return_value=0)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
):
|
||||
logs, total = await get_platform_cost_logs(start=start)
|
||||
|
||||
assert total == 0
|
||||
|
||||
|
||||
def _make_log_row(i: int = 0) -> dict:
|
||||
return {
|
||||
"id": f"log-{i}",
|
||||
"created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
|
||||
"user_id": "u1",
|
||||
"email": None,
|
||||
"graph_exec_id": None,
|
||||
"node_exec_id": None,
|
||||
"block_name": "TestBlock",
|
||||
"provider": "openai",
|
||||
"tracking_type": "tokens",
|
||||
"cost_microdollars": 1000,
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"duration": 0.5,
|
||||
"model": "gpt-4",
|
||||
"cache_read_tokens": None,
|
||||
"cache_creation_tokens": None,
|
||||
}
|
||||
where = mock_actions.count.call_args[1]["where"]
|
||||
# start provided — should appear in the where filter
|
||||
assert "createdAt" in where
|
||||
|
||||
|
||||
class TestGetPlatformCostLogsForExport:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_logs_not_truncated(self):
|
||||
rows = [_make_log_row(0)]
|
||||
mock_query = AsyncMock(return_value=rows)
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
row = _make_prisma_log_row(0)
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.find_many = AsyncMock(return_value=[row])
|
||||
|
||||
with patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
):
|
||||
logs, truncated = await get_platform_cost_logs_for_export()
|
||||
|
||||
assert len(logs) == 1
|
||||
assert truncated is False
|
||||
assert logs[0].id == "log-0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_not_truncated(self):
|
||||
mock_query = AsyncMock(return_value=[])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
):
|
||||
logs, truncated = await get_platform_cost_logs_for_export()
|
||||
|
||||
assert logs == []
|
||||
assert truncated is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncates_at_export_max_rows(self):
|
||||
rows = [_make_log_row(i) for i in range(3)]
|
||||
mock_query = AsyncMock(return_value=rows)
|
||||
with patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema", new=mock_query
|
||||
), patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2):
|
||||
rows = [_make_prisma_log_row(i) for i in range(3)]
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.find_many = AsyncMock(return_value=rows)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2),
|
||||
):
|
||||
logs, truncated = await get_platform_cost_logs_for_export()
|
||||
|
||||
assert len(logs) == 2
|
||||
assert truncated is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_model_block_tracking_filters(self):
|
||||
mock_query = AsyncMock(return_value=[])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
):
|
||||
await get_platform_cost_logs_for_export(
|
||||
model="gpt-4", block_name="LLMBlock", tracking_type="tokens"
|
||||
)
|
||||
call_args = mock_query.call_args[0]
|
||||
assert "gpt-4" in call_args
|
||||
assert "LLMBlock" in call_args
|
||||
assert "tokens" in call_args
|
||||
|
||||
where = mock_actions.find_many.call_args[1]["where"]
|
||||
assert where.get("model") == "gpt-4"
|
||||
assert where.get("trackingType") == "tokens"
|
||||
# blockName uses a dict filter for case-insensitive match
|
||||
assert "blockName" in where
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maps_cache_tokens(self):
|
||||
row = _make_log_row(0)
|
||||
row["cache_read_tokens"] = 50
|
||||
row["cache_creation_tokens"] = 25
|
||||
mock_query = AsyncMock(return_value=[row])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
row = _make_prisma_log_row(0)
|
||||
row.configure_mock(**{"cacheReadTokens": 50, "cacheCreationTokens": 25})
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.find_many = AsyncMock(return_value=[row])
|
||||
|
||||
with patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
):
|
||||
logs, _ = await get_platform_cost_logs_for_export()
|
||||
|
||||
assert logs[0].cache_read_tokens == 50
|
||||
assert logs[0].cache_creation_tokens == 25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_start_skips_default(self):
|
||||
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
mock_query = AsyncMock(return_value=[])
|
||||
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
):
|
||||
logs, truncated = await get_platform_cost_logs_for_export(start=start)
|
||||
|
||||
assert logs == []
|
||||
assert truncated is False
|
||||
where = mock_actions.find_many.call_args[1]["where"]
|
||||
assert "createdAt" in where
|
||||
|
||||
@@ -1605,56 +1605,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}/messages/pending": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Queue Pending Message",
|
||||
"description": "Queue a new user message into an in-flight copilot turn.\n\nWhen a user sends a follow-up message while a turn is still\nstreaming, we don't want to block them or start a separate turn —\nthis endpoint appends the message to a per-session pending buffer.\nThe executor currently running the turn (baseline path) drains the\nbuffer between tool-call rounds and appends the message to the\nconversation before the next LLM call. On the SDK path the buffer\nis drained at the *start* of the next turn (the long-lived\n``ClaudeSDKClient.receive_response`` iterator returns after a\n``ResultMessage`` so there is no safe point to inject mid-stream\ninto an existing connection).\n\nReturns 202. Enforces the same per-user daily/weekly token rate\nlimit as the regular ``/stream`` endpoint so a client can't bypass\nit by batching messages through here.",
|
||||
"operationId": "postV2QueuePendingMessage",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/QueuePendingMessageRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"202": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/QueuePendingMessageResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}/stream": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -12718,57 +12668,6 @@
|
||||
"required": ["providers", "pagination"],
|
||||
"title": "ProviderResponse"
|
||||
},
|
||||
"QueuePendingMessageRequest": {
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"maxLength": 16000,
|
||||
"minLength": 1,
|
||||
"title": "Message"
|
||||
},
|
||||
"context": {
|
||||
"anyOf": [
|
||||
{
|
||||
"additionalProperties": { "type": "string" },
|
||||
"type": "object"
|
||||
},
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Context",
|
||||
"description": "Optional page context: expected keys are 'url' and 'content'."
|
||||
},
|
||||
"file_ids": {
|
||||
"anyOf": [
|
||||
{
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"maxItems": 20
|
||||
},
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "File Ids"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"required": ["message"],
|
||||
"title": "QueuePendingMessageRequest",
|
||||
"description": "Request model for queueing a message into an in-flight turn.\n\nUnlike ``StreamChatRequest`` this endpoint does **not** start a new\nturn — the message is appended to a per-session pending buffer that\nthe executor currently processing the turn will drain between tool\nrounds."
|
||||
},
|
||||
"QueuePendingMessageResponse": {
|
||||
"properties": {
|
||||
"buffer_length": { "type": "integer", "title": "Buffer Length" },
|
||||
"max_buffer_length": {
|
||||
"type": "integer",
|
||||
"title": "Max Buffer Length"
|
||||
},
|
||||
"turn_in_flight": { "type": "boolean", "title": "Turn In Flight" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["buffer_length", "max_buffer_length", "turn_in_flight"],
|
||||
"title": "QueuePendingMessageResponse",
|
||||
"description": "Response for the pending-message endpoint.\n\n- ``buffer_length``: how many messages are now in the session's\n pending buffer (after this push)\n- ``max_buffer_length``: the per-session cap (server-side constant)\n- ``turn_in_flight``: ``True`` if a copilot turn was running when\n we checked — purely informational for UX feedback. Even when\n ``False`` the message is still queued: the next turn drains it."
|
||||
},
|
||||
"RateLimitResetResponse": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
|
||||
|
Before Width: | Height: | Size: 114 KiB |
|
Before Width: | Height: | Size: 46 KiB |
|
Before Width: | Height: | Size: 66 KiB |
|
Before Width: | Height: | Size: 82 KiB |
|
Before Width: | Height: | Size: 78 KiB |
|
Before Width: | Height: | Size: 90 KiB |
|
Before Width: | Height: | Size: 75 KiB |
|
Before Width: | Height: | Size: 79 KiB |
|
Before Width: | Height: | Size: 82 KiB |
|
Before Width: | Height: | Size: 77 KiB |
|
Before Width: | Height: | Size: 80 KiB |
|
Before Width: | Height: | Size: 79 KiB |
|
Before Width: | Height: | Size: 85 KiB |
|
Before Width: | Height: | Size: 80 KiB |
|
Before Width: | Height: | Size: 89 KiB |
|
Before Width: | Height: | Size: 88 KiB |
|
Before Width: | Height: | Size: 65 KiB |
|
Before Width: | Height: | Size: 70 KiB |
|
Before Width: | Height: | Size: 94 KiB |