Compare commits


6 Commits

Author SHA1 Message Date
majdyz
6dc0b6cffd test(copilot/sdk-compat): tighten reproduction test (regex scan, proc reap, strict assertions, public socket API)
Address self-review findings on cli_openrouter_compat_test.py:

- Switch the tool_reference detection to a whitespace-tolerant regex
  (`"type"\s*:\s*"tool_reference"`). The Claude Code CLI is Node.js
  and `JSON.stringify` without an indent emits no whitespace, producing
  `{"type":"tool_reference"}`. The previous literal substring assumed
  one fixed spacing, so it would silently miss the real regression.

- Reap the subprocess after `proc.kill()` on timeout via
  `await asyncio.wait_for(proc.wait(), timeout=5)` so we don't leak a
  zombie + open pipe FDs across CI runs.

- Tighten `test_returns_none_when_env_var_points_to_missing_file` to
  assert `resolved is None` exactly. The previous
  `is None or .is_file()` was too permissive — it would also accept
  the function silently falling through to the bundled binary, which
  would defeat the explicit-override semantics.

- Replace `site._server` private aiohttp access with the public socket
  API: bind an ephemeral port via `socket.bind` and pass it to
  `web.SockSite`. Reading the port back via `getsockname` is robust to
  aiohttp internal changes.

- Convert the catch-all 404 route handler from a bare lambda to an
  `async def fallback_handler` to silence the aiohttp deprecation
  warning ("Bare functions are deprecated, use async ones").
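
The whitespace issue from the first bullet can be reproduced with the
standard library alone. This is a minimal sketch, assuming Python's
`json.dumps` with compact separators as a stand-in for Node's
`JSON.stringify`; `TOOL_REFERENCE_RE` and `contains_tool_reference` are
illustrative names, not necessarily the identifiers used in the test file:

```python
import json
import re

# Whitespace-tolerant pattern quoted in the commit message above.
TOOL_REFERENCE_RE = re.compile(r'"type"\s*:\s*"tool_reference"')


def contains_tool_reference(raw_body: str) -> bool:
    """Return True if the serialized body contains a tool_reference block."""
    return TOOL_REFERENCE_RE.search(raw_body) is not None


block = {"type": "tool_reference", "tool_name": "example"}

# Compact separators mimic JSON.stringify(obj) called without an indent.
compact = json.dumps(block, separators=(",", ":"))

# Python's default separators put a space after every colon.
spaced = json.dumps(block)

# A literal substring with a space matches only the spaced form,
# while the regex matches both serializations.
literal_hits = ['"type": "tool_reference"' in s for s in (compact, spaced)]
regex_hits = [contains_tool_reference(s) for s in (compact, spaced)]
```

The literal check passes on the spaced form only, which is the form the
Node-based CLI would not be expected to emit.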
2026-04-11 11:43:45 +00:00
majdyz
a6e306d28a fix(copilot): accept unprefixed CLAUDE_AGENT_CLI_PATH in config
The new `claude_agent_cli_path` field inherited the `CHAT_` Pydantic
prefix from `ChatConfig`, so the documented `CLAUDE_AGENT_CLI_PATH`
env var was silently ignored — operators following the PR description
or the field docstring would set the unprefixed form and the config
would fall back to the bundled CLI.

Add a `field_validator` that reads `CHAT_CLAUDE_AGENT_CLI_PATH` first
and falls back to the unprefixed `CLAUDE_AGENT_CLI_PATH`, matching the
same pattern already used by `api_key` and `base_url`. The test helper
`_resolve_cli_path` in `cli_openrouter_compat_test.py` mirrors the
same two-name lookup so the reproduction test picks up the override
regardless of which form is set, and a new test covers the prefixed
variant explicitly.

Flagged by sentry review on #12741 (thread IDs 3067725580 and
3067768817) as two instances of the same bug.
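
The two-name lookup can be sketched without Pydantic. This is an
illustration of the precedence only, not the actual `field_validator`;
the function name `resolve_cli_path_from_env` is hypothetical, and
treating an empty string as "unset" is an assumption:

```python
from __future__ import annotations

import os
from collections.abc import Mapping

PREFIXED_NAME = "CHAT_CLAUDE_AGENT_CLI_PATH"
UNPREFIXED_NAME = "CLAUDE_AGENT_CLI_PATH"


def resolve_cli_path_from_env(env: Mapping[str, str] | None = None) -> str | None:
    """Return the CLI path override, preferring the prefixed variable.

    Falls back to the unprefixed name, and returns None when neither is
    set so the caller can use the bundled CLI.
    """
    env = os.environ if env is None else env
    for name in (PREFIXED_NAME, UNPREFIXED_NAME):
        value = env.get(name)
        if value:  # assumption: an empty string counts as unset
            return value
    return None
```

Passing the environment as a mapping keeps the lookup easy to unit-test
without mutating `os.environ`.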
2026-04-11 10:11:47 +00:00
majdyz
d6f0fcb052 test(copilot/sdk-compat): unit-test the forbidden-pattern scanner
Add direct unit tests for `_scan_request_for_forbidden_patterns` and
`_resolve_cli_path` so the helper logic stays exercised even on CI
runs where the slow end-to-end CLI subprocess test can't capture a
request (sandboxed runner, missing CLI binary, etc).

Brings codecov/patch coverage above the 80% gate. No production
code changes — tests only.
2026-04-11 07:57:04 +00:00
majdyz
feb247d56e chore(backend): drop stray blank line in platform_cost_test.py
Same pre-existing dev-branch lint issue from PR #12739 — black would
reformat this file (extra blank line between two test classes), which
fails the `lint` CI job for any PR branched from current dev.
2026-04-11 07:10:55 +00:00
majdyz
fdb3590693 chore(copilot): add SDK CLI override + OpenRouter compat regression tests
We've been pinned at `claude-agent-sdk==0.1.45` (bundled CLI 2.1.63)
since PR #12294 because every later version triggers a 400 from
OpenRouter. There are two stacked regressions today:

1. CLI 2.1.69 (= SDK 0.1.46) added a `tool_reference` content block in
   `tool_result.content` that OpenRouter's stricter Zod validation
   rejects. CLI 2.1.70 added a proxy-detection workaround but our
   subsequent attempts at 0.1.55 and 0.1.56 still failed.
2. A newer regression — the `context-management-2025-06-27` beta
   header — appears in some CLI version after 2.1.91. Tracked upstream
   at anthropics/claude-agent-sdk-python#789, still open with no fix.

This commit doesn't actually upgrade the SDK — it adds the
infrastructure we need to upgrade safely *when* upstream lands a fix
or when we identify a known-good newer CLI version via bisection:

* `ChatConfig.claude_agent_cli_path` (env: `CLAUDE_AGENT_CLI_PATH`)
  threads through to `ClaudeAgentOptions(cli_path=...)` so we can
  decouple the Python SDK API surface from the CLI binary version.
  `_prewarm_cli` in the CoPilotExecutor honours the same override.

* `test_bundled_cli_version_is_known_good_against_openrouter` pins
  the bundled CLI to a known-good set (`{"2.1.63"}` today). Any
  `claude-agent-sdk` bump that changes the bundled CLI will fail this
  test loudly with a pointer to PR #12294 and issue #789, instead of
  silently re-breaking production.

* `test_sdk_exposes_cli_path_option` is a forward-compat sentinel that
  fails fast if upstream removes the `cli_path` option we depend on
  for the override.

* `cli_openrouter_compat_test.py` is the actual reproduction test:
  spawns the bundled (or `CLAUDE_AGENT_CLI_PATH`-overridden) CLI
  against an in-process aiohttp server pretending to be the Anthropic
  Messages API, captures every request body the CLI sends, and
  asserts that none of them contain the two known forbidden patterns
  (`"type": "tool_reference"` content blocks or
  `"context-management-2025-06-27"` in body or `anthropic-beta`
  header). The fake server returns a minimal valid streamed response
  so the CLI doesn't error out before we can inspect what it sent.
  No OpenRouter API key required — the test reproduces the *mechanism*
  rather than the symptom, so it's deterministic and free to run in CI.

Workflow for verifying a candidate upgrade going forward: bump the
SDK in `pyproject.toml`, push the commit, and watch the CI run for
both tests in `sdk_compat_test.py` and `cli_openrouter_compat_test.py`.
A clean run on both means it's safe to add the new bundled CLI version
to `_KNOWN_GOOD_BUNDLED_CLI_VERSIONS` and merge.
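
The known-good pin described above amounts to a set-membership check.
A rough sketch follows; how the real test discovers the bundled CLI
version is not shown in this commit message, so the version is passed
in as a plain string, and `check_bundled_cli_version` is a hypothetical
name:

```python
# Mirrors the `{"2.1.63"}` set mentioned in the commit message.
KNOWN_GOOD_BUNDLED_CLI_VERSIONS = frozenset({"2.1.63"})


def check_bundled_cli_version(version: str) -> None:
    """Fail loudly when the bundled CLI is not a vetted version."""
    if version not in KNOWN_GOOD_BUNDLED_CLI_VERSIONS:
        raise AssertionError(
            f"Bundled CLI {version} is not in the known-good set "
            f"{sorted(KNOWN_GOOD_BUNDLED_CLI_VERSIONS)}; see PR #12294 and "
            "anthropics/claude-agent-sdk-python#789 before bumping the SDK."
        )


check_bundled_cli_version("2.1.63")  # passes silently

try:
    check_bundled_cli_version("2.1.69")
except AssertionError as exc:
    failure_message = str(exc)
```

An SDK bump that changes the bundled CLI then fails with a message that
points at the history, and the set only grows after the reproduction
test has run clean against the new version.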
2026-04-11 07:05:05 +00:00
Zamil Majdy
b319c26cab feat(platform/admin): per-model cost breakdown, cache token tracking, OrchestratorBlock cost fix (#12726)
## Why

The platform cost tracking system had several gaps that made the admin
dashboard less accurate and harder to reason about:

**Q: Do we have per-model granularity on the provider page?**
The `model` column was stored in `PlatformCostLog` but the SQL
aggregation grouped only by `(provider, tracking_type)`, so all models
for a given provider collapsed into one row. Now grouped by `(provider,
tracking_type, model)` — each model gets its own row.
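
The effect of widening the group-by key can be shown in a few lines.
The rows and model names below are made up for illustration, and the
real aggregation runs in SQL rather than Python:

```python
from collections import defaultdict

# (provider, tracking_type, model, cost_microdollars): illustrative data only.
rows = [
    ("anthropic", "tokens", "model-a", 1200),
    ("anthropic", "tokens", "model-b", 800),
    ("anthropic", "tokens", "model-a", 300),
]


def aggregate(rows, key_len):
    """Sum the cost column, grouping by the first `key_len` fields."""
    totals = defaultdict(int)
    for row in rows:
        totals[row[:key_len]] += row[-1]
    return dict(totals)


by_provider = aggregate(rows, 2)  # old key: (provider, tracking_type)
by_model = aggregate(rows, 3)     # new key: (provider, tracking_type, model)
```

With the two-field key both models collapse into a single row; the
three-field key yields one row per model.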

**Q: Why does Anthropic show `per_run` for OrchestratorBlock?**
Bug: `OrchestratorBlock._call_llm()` was building `NodeExecutionStats`
with only `input_token_count` and `output_token_count` — it dropped
`resp.provider_cost` entirely. For OpenRouter calls this silently
discarded the `cost_usd`. For the SDK (autopilot) path,
`ResultMessage.total_cost_usd` was never read. When `provider_cost` is
None and token counts are 0 (e.g. SDK error path), `resolve_tracking`
falls through to `per_run`. Fixed by propagating all cost/cache fields.

**Q: Why can't we get `cost_usd` for Anthropic direct API calls?**
The Anthropic Messages API does not return a dollar amount — only token
counts. OpenRouter returns cost via response headers, so it uses
`cost_usd` directly. The Claude Agent SDK *does* compute
`total_cost_usd` internally, so SDK-mode OrchestratorBlock runs now get
`cost_usd` tracking. For direct Anthropic LLM blocks the estimate uses
per-token rates (see cache section below).

**Q: What about labeling by source (autopilot vs block)?**
Already tracked: `block_name` stores `copilot:SDK`, `copilot:Baseline`,
or the actual block name. Visible in the raw logs table. Not added to
the provider group-by (would explode row count); use the logs table
filter instead.

**Q: Is there double-counting between `tokens`, `per_run`, and
`cost_usd`?**
No. `resolve_tracking()` uses a strict preference hierarchy — exactly
one tracking type per execution: `cost_usd` > `tokens` > provider
heuristics > `per_run`. A single execution produces exactly one
`PlatformCostLog` row.
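
A simplified sketch of that hierarchy, assuming a stats object with the
three fields shown. The `Stats` class and `resolve_tracking_sketch` are
stand-ins rather than the real `NodeExecutionStats` and
`resolve_tracking()`, and the provider-heuristics step is left out
because its rules are not described here:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Stats:
    """Hypothetical stand-in for the cost-relevant execution stats."""

    provider_cost: float | None = None
    input_token_count: int = 0
    output_token_count: int = 0


def resolve_tracking_sketch(stats: Stats) -> str:
    """Pick exactly one tracking type, highest preference first."""
    if stats.provider_cost is not None:
        return "cost_usd"
    if stats.input_token_count or stats.output_token_count:
        return "tokens"
    # Provider-specific heuristics would be checked here (omitted).
    return "per_run"
```

This also shows why the OrchestratorBlock bug surfaced as `per_run`:
with `provider_cost` dropped and zero token counts, nothing above the
final fallback matches.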

**Q: Should we track Anthropic prompt cache tokens (PR #12725)?**
Yes — PR #12725 adds `cache_control` markers to Anthropic API calls,
which causes the API to return `cache_read_input_tokens` and
`cache_creation_input_tokens` alongside regular `input_tokens`. These
have different billing rates:
- Cache reads: **10%** of base input rate (much cheaper)
- Cache writes: **125%** of base input rate (slightly more expensive,
one-time)
- Uncached input: **100%** of base rate

Without tracking them separately, a flat-rate estimate on
`total_input_tokens` would be wrong in both directions.
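
A worked example of the weighted estimate, using the multipliers listed
above. The base rate of $3 per million input tokens and the token
counts are made-up inputs, the function names are hypothetical, and the
sketch assumes `input_tokens` excludes the cached tokens:

```python
CACHE_READ_MULTIPLIER = 0.10   # cache reads: 10% of the base input rate
CACHE_WRITE_MULTIPLIER = 1.25  # cache writes: 125% of the base input rate


def estimate_input_cost_usd(uncached, cache_read, cache_write, base_rate_per_million):
    """Weight each token type by its multiplier before applying the base rate."""
    weighted_tokens = (
        uncached
        + cache_read * CACHE_READ_MULTIPLIER
        + cache_write * CACHE_WRITE_MULTIPLIER
    )
    return weighted_tokens * base_rate_per_million / 1_000_000


def flat_rate_cost_usd(uncached, cache_read, cache_write, base_rate_per_million):
    """Naive estimate that bills every input token at the base rate."""
    return (uncached + cache_read + cache_write) * base_rate_per_million / 1_000_000


BASE_RATE = 3.0  # assumed USD per million input tokens

# Read-heavy call: the flat rate overestimates.
read_heavy = (1_000, 10_000, 2_000)
read_weighted = estimate_input_cost_usd(*read_heavy, BASE_RATE)
read_flat = flat_rate_cost_usd(*read_heavy, BASE_RATE)

# Write-heavy call: the flat rate underestimates.
write_heavy = (1_000, 0, 10_000)
write_weighted = estimate_input_cost_usd(*write_heavy, BASE_RATE)
write_flat = flat_rate_cost_usd(*write_heavy, BASE_RATE)
```

For these inputs the read-heavy call comes to $0.0135 weighted against
$0.039 flat, while the write-heavy call comes to $0.0405 weighted
against $0.033 flat.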

## What

- **Per-model provider table**: SQL now groups by `(provider,
tracking_type, model)`. `ProviderCostSummary` and the frontend
`ProviderTable` show a model column.
- **Cache token columns**: New `cacheReadTokens` and
`cacheCreationTokens` columns in `PlatformCostLog` with matching
migration.
- **LLM block cache tracking**: `LLMResponse` captures
`cache_read_input_tokens` / `cache_creation_input_tokens` from Anthropic
responses. `NodeExecutionStats` gains `cache_read_token_count` /
`cache_creation_token_count`. Both propagate to `PlatformCostEntry` and
the DB.
- **Copilot path**: `token_tracking.persist_and_record_usage` now writes
cache tokens as dedicated `PlatformCostEntry` fields (was
metadata-only).
- **OrchestratorBlock bug fix**: `_call_llm()` now includes
`resp.provider_cost`, `resp.cache_read_tokens`,
`resp.cache_creation_tokens` in the stats merge. SDK path captures
`ResultMessage.total_cost_usd` as `provider_cost`.
- **Accurate cost estimation**: `estimateCostForRow` uses
token-type-specific rates for `tokens` rows (uncached=100%, reads=10%,
writes=125% of configured base rate).

## How

`resolve_tracking` priority is unchanged. For Anthropic LLM blocks the
tracking type remains `tokens` (Anthropic API returns no dollar amount).
For OrchestratorBlock in SDK/autopilot mode it now correctly uses
`cost_usd` because the Claude Agent SDK computes and returns
`total_cost_usd`. For OpenRouter through OrchestratorBlock it now
correctly uses `cost_usd` (was silently dropped before).

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `ProviderCostSummary` SQL updated
  - [x] Cache token fields present in `PlatformCostEntry` and
    `PlatformCostLogCreateInput`
  - [x] Prisma client regenerated — all type checks pass
  - [x] Frontend `helpers.test.ts` updated for new `rateKey` format
  - [x] Pre-commit hooks pass (Black, Ruff, isort, tsc, Prisma generate)
2026-04-10 23:14:43 +07:00
37 changed files with 1446 additions and 1783 deletions


@@ -0,0 +1,100 @@
-- =============================================================
-- View: analytics.platform_cost_log
-- Looker source alias: ds115 | Charts: 0
-- =============================================================
-- DESCRIPTION
-- One row per platform cost log entry (last 90 days).
-- Tracks real API spend at the call level: provider, model,
-- token counts (including Anthropic cache tokens), cost in
-- microdollars, and the block/execution that incurred the cost.
-- Joins the User table to provide email for per-user breakdowns.
--
-- SOURCE TABLES
-- platform.PlatformCostLog — Per-call cost records
-- platform.User — User email
--
-- OUTPUT COLUMNS
--   id                   TEXT         Log entry UUID
--   createdAt            TIMESTAMPTZ  When the cost was recorded
--   userId               TEXT         User who incurred the cost (nullable)
--   email                TEXT         User email (nullable)
--   graphExecId          TEXT         Graph execution UUID (nullable)
--   nodeExecId           TEXT         Node execution UUID (nullable)
--   blockName            TEXT         Block that made the API call (nullable)
--   provider             TEXT         API provider, lowercase (e.g. 'openai', 'anthropic')
--   model                TEXT         Model name (nullable)
--   trackingType         TEXT         Cost unit: 'tokens' | 'cost_usd' | 'characters' | etc.
--   costMicrodollars     BIGINT       Cost in microdollars (divide by 1,000,000 for USD)
--   costUsd              FLOAT        Cost in USD (costMicrodollars / 1,000,000)
--   inputTokens          INT          Prompt/input tokens (nullable)
--   outputTokens         INT          Completion/output tokens (nullable)
--   cacheReadTokens      INT          Anthropic cache-read tokens billed at 10% (nullable)
--   cacheCreationTokens  INT          Anthropic cache-write tokens billed at 125% (nullable)
--   totalTokens          INT          inputTokens + outputTokens (nullable if either is null)
--   duration             FLOAT        API call duration in seconds (nullable)
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend by provider (last 90 days)
-- SELECT provider, SUM("costUsd") AS total_usd, COUNT(*) AS calls
-- FROM analytics.platform_cost_log
-- GROUP BY 1 ORDER BY total_usd DESC;
--
-- -- Spend by model
-- SELECT provider, model, SUM("costUsd") AS total_usd,
-- SUM("inputTokens") AS input_tokens,
-- SUM("outputTokens") AS output_tokens
-- FROM analytics.platform_cost_log
-- WHERE model IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_usd DESC;
--
-- -- Top 20 users by spend
-- SELECT "userId", email, SUM("costUsd") AS total_usd, COUNT(*) AS calls
-- FROM analytics.platform_cost_log
-- WHERE "userId" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_usd DESC LIMIT 20;
--
-- -- Daily spend trend
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
-- SUM("costUsd") AS daily_usd,
-- COUNT(*) AS calls
-- FROM analytics.platform_cost_log
-- GROUP BY 1 ORDER BY 1;
--
-- -- Cache hit rate for Anthropic (cache reads vs total reads)
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
-- SUM("cacheReadTokens")::float /
-- NULLIF(SUM("inputTokens" + COALESCE("cacheReadTokens", 0)), 0) AS cache_hit_rate
-- FROM analytics.platform_cost_log
-- WHERE provider = 'anthropic'
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
-- Aliases are quoted so the camelCase column names survive PostgreSQL's
-- lowercase folding and match the quoted identifiers in the example queries.
SELECT
    p."id"                                  AS "id",
    p."createdAt"                           AS "createdAt",
    p."userId"                              AS "userId",
    u."email"                               AS "email",
    p."graphExecId"                         AS "graphExecId",
    p."nodeExecId"                          AS "nodeExecId",
    p."blockName"                           AS "blockName",
    p."provider"                            AS "provider",
    p."model"                               AS "model",
    p."trackingType"                        AS "trackingType",
    p."costMicrodollars"                    AS "costMicrodollars",
    p."costMicrodollars"::float / 1000000.0 AS "costUsd",
    p."inputTokens"                         AS "inputTokens",
    p."outputTokens"                        AS "outputTokens",
    p."cacheReadTokens"                     AS "cacheReadTokens",
    p."cacheCreationTokens"                 AS "cacheCreationTokens",
    CASE
        WHEN p."inputTokens" IS NOT NULL AND p."outputTokens" IS NOT NULL
        THEN p."inputTokens" + p."outputTokens"
        ELSE NULL
    END                                     AS "totalTokens",
    p."duration"                            AS "duration"
FROM platform."PlatformCostLog" p
LEFT JOIN platform."User" u ON u."id" = p."userId"
WHERE p."createdAt" > CURRENT_DATE - INTERVAL '90 days'


@@ -4,7 +4,7 @@ import asyncio
import logging
import re
from collections.abc import AsyncGenerator
from typing import Annotated, Any, cast
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
@@ -29,12 +29,6 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.pending_messages import (
MAX_PENDING_MESSAGES,
PendingMessage,
PendingMessageContext,
push_pending_message,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
@@ -90,27 +84,6 @@ _UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
# Call-frequency cap for the pending-message endpoint. The token-budget
# check in queue_pending_message guards against overspend, but does not
# prevent rapid-fire pushes from a client with a large budget. This cap
# (per user, per 60-second window) limits the rate a caller can hammer the
# endpoint independently of token consumption.
_PENDING_CALL_LIMIT = 30 # pushes per minute per user
_PENDING_CALL_WINDOW_SECONDS = 60
_PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:"
# Lua script for atomic INCR + conditional EXPIRE.
# Using a single EVAL ensures the counter never persists without a TTL —
# a bare INCR followed by a separate EXPIRE can leave the key without
# an expiry if the process crashes between the two commands.
_CALL_INCR_LUA = """
local count = redis.call('INCR', KEYS[1])
if count == 1 then
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[1]))
end
return count
"""
async def _validate_and_get_session(
session_id: str,
@@ -123,29 +96,6 @@ async def _validate_and_get_session(
return session
async def _resolve_workspace_files(
user_id: str,
file_ids: list[str],
) -> list[UserWorkspaceFile]:
"""Filter *file_ids* to UUID-valid entries that exist in the caller's workspace.
Returns the matching ``UserWorkspaceFile`` records (empty list if none pass).
Used by both the stream and pending-message endpoints to prevent callers from
referencing other users' files.
"""
valid_ids = [fid for fid in file_ids if _UUID_RE.match(fid)]
if not valid_ids:
return []
workspace = await get_or_create_workspace(user_id)
return await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
router = APIRouter(
tags=["chat"],
)
@@ -169,64 +119,6 @@ class StreamChatRequest(BaseModel):
)
class QueuePendingMessageRequest(BaseModel):
"""Request model for queueing a message into an in-flight turn.
Unlike ``StreamChatRequest`` this endpoint does **not** start a new
turn — the message is appended to a per-session pending buffer that
the executor currently processing the turn will drain between tool
rounds.
"""
model_config = ConfigDict(extra="forbid")
message: str = Field(min_length=1, max_length=16_000)
context: dict[str, str] | None = Field(
default=None,
description="Optional page context: expected keys are 'url' and 'content'.",
)
file_ids: list[str] | None = Field(default=None, max_length=20)
@field_validator("context")
@classmethod
def _validate_context_length(
cls, v: dict[str, str] | None
) -> dict[str, str] | None:
if v is None:
return v
# Cap context values to prevent LLM context-window stuffing via
# large page payloads (url: 2 KB, content: 32 KB).
_URL_LIMIT = 2_000
_CONTENT_LIMIT = 32_000
url = v.get("url", "")
if len(url) > _URL_LIMIT:
raise ValueError(
f"context.url exceeds maximum length of {_URL_LIMIT} characters"
)
content = v.get("content", "")
if len(content) > _CONTENT_LIMIT:
raise ValueError(
f"context.content exceeds maximum length of {_CONTENT_LIMIT} characters"
)
return v
class QueuePendingMessageResponse(BaseModel):
"""Response for the pending-message endpoint.
- ``buffer_length``: how many messages are now in the session's
pending buffer (after this push)
- ``max_buffer_length``: the per-session cap (server-side constant)
- ``turn_in_flight``: ``True`` if a copilot turn was running when
we checked — purely informational for UX feedback. Even when
``False`` the message is still queued: the next turn drains it.
"""
buffer_length: int
max_buffer_length: int
turn_in_flight: bool
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
@@ -894,21 +786,33 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
if request.file_ids:
files = await _resolve_workspace_files(user_id, request.file_ids)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
request.message += files_block
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
@@ -1108,129 +1012,6 @@ async def stream_chat_post(
)
@router.post(
"/sessions/{session_id}/messages/pending",
response_model=QueuePendingMessageResponse,
status_code=202,
)
async def queue_pending_message(
session_id: str,
request: QueuePendingMessageRequest,
user_id: str = Security(auth.get_user_id),
):
"""Queue a new user message into an in-flight copilot turn.
When a user sends a follow-up message while a turn is still
streaming, we don't want to block them or start a separate turn —
this endpoint appends the message to a per-session pending buffer.
The executor currently running the turn (baseline path) drains the
buffer between tool-call rounds and appends the message to the
conversation before the next LLM call. On the SDK path the buffer
is drained at the *start* of the next turn (the long-lived
``ClaudeSDKClient.receive_response`` iterator returns after a
``ResultMessage`` so there is no safe point to inject mid-stream
into an existing connection).
Returns 202. Enforces the same per-user daily/weekly token rate
limit as the regular ``/stream`` endpoint so a client can't bypass
it by batching messages through here.
"""
await _validate_and_get_session(session_id, user_id)
# Pre-turn rate-limit check — mirrors stream_chat_post. Without
# this, a client could bypass per-turn token limits by batching
# their extra context through this endpoint while a cheap stream
# is in flight.
# user_id is guaranteed non-empty by Security(auth.get_user_id) — no guard needed.
try:
daily_limit, weekly_limit, _tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Call-frequency cap: prevent rapid-fire pushes that would bypass the
# token-budget check (which only fires per-turn, not per-push).
# Uses an atomic Lua EVAL (INCR + EXPIRE) so the key can never be
# orphaned without a TTL; fails open if Redis is down.
try:
_redis = await get_redis_async()
_call_key = f"{_PENDING_CALL_KEY_PREFIX}{user_id}"
_call_count = int(
await cast(
"Any",
_redis.eval(
_CALL_INCR_LUA,
1,
_call_key,
str(_PENDING_CALL_WINDOW_SECONDS),
),
)
)
if _call_count > _PENDING_CALL_LIMIT:
raise HTTPException(
status_code=429,
detail=f"Too many pending messages: limit is {_PENDING_CALL_LIMIT} per {_PENDING_CALL_WINDOW_SECONDS}s",
)
except HTTPException:
raise
except Exception:
pass # Redis failure is non-fatal; fail open
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
# Sanitise file IDs to the user's own workspace so injection doesn't
# surface other users' files. _resolve_workspace_files handles UUID
# filtering and the workspace-scoped DB lookup.
sanitized_file_ids: list[str] = []
if request.file_ids:
valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.match(fid))
files = await _resolve_workspace_files(user_id, request.file_ids)
sanitized_file_ids = [wf.id for wf in files]
if len(sanitized_file_ids) != valid_id_count:
logger.warning(
"queue_pending_message: dropped %d file id(s) not in "
"caller's workspace (session=%s)",
valid_id_count - len(sanitized_file_ids),
session_id,
)
# Redis is the single source of truth for pending messages. We do
# NOT persist to ``session.messages`` here — the drain-at-start
# path in the baseline/SDK executor is the sole writer for pending
# content. Persisting both here AND in the drain would cause
# double injection (executor sees the message in ``session.messages``
# *and* drains it from Redis) unless we also dedupe. The dedup in
# ``maybe_append_user_message`` only checks trailing same-role
# repeats, so relying on it is fragile. Keeping the endpoint
# Redis-only avoids the whole consistency-bug class.
pending = PendingMessage(
content=request.message,
file_ids=sanitized_file_ids,
context=PendingMessageContext(**request.context) if request.context else None,
)
buffer_length = await push_pending_message(session_id, pending)
# Check whether a turn is currently running for UX feedback.
active_session = await stream_registry.get_session(session_id)
turn_in_flight = bool(active_session and active_session.status == "running")
return QueuePendingMessageResponse(
buffer_length=buffer_length,
max_buffer_length=MAX_PENDING_MESSAGES,
turn_in_flight=turn_in_flight,
)
@router.get(
"/sessions/{session_id}/stream",
)


@@ -579,300 +579,3 @@ class TestStreamChatRequestModeValidation:
req = StreamChatRequest(message="hi")
assert req.mode is None
# ─── QueuePendingMessageRequest validation ────────────────────────────
class TestQueuePendingMessageRequest:
"""Unit tests for QueuePendingMessageRequest field validation."""
def test_accepts_valid_message(self) -> None:
from backend.api.features.chat.routes import QueuePendingMessageRequest
req = QueuePendingMessageRequest(message="hello")
assert req.message == "hello"
def test_rejects_empty_message(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError):
QueuePendingMessageRequest(message="")
def test_rejects_message_over_limit(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError):
QueuePendingMessageRequest(message="x" * 16_001)
def test_accepts_valid_context(self) -> None:
from backend.api.features.chat.routes import QueuePendingMessageRequest
req = QueuePendingMessageRequest(
message="hi",
context={"url": "https://example.com", "content": "page text"},
)
assert req.context is not None
assert req.context["url"] == "https://example.com"
def test_rejects_context_url_over_limit(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError, match="url"):
QueuePendingMessageRequest(
message="hi",
context={"url": "https://example.com/" + "x" * 2_000},
)
def test_rejects_context_content_over_limit(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError, match="content"):
QueuePendingMessageRequest(
message="hi",
context={"content": "x" * 32_001},
)
def test_rejects_extra_fields(self) -> None:
"""extra='forbid' should reject unknown fields."""
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError):
QueuePendingMessageRequest(message="hi", unknown_field="bad") # type: ignore[call-arg]
def test_accepts_up_to_20_file_ids(self) -> None:
from backend.api.features.chat.routes import QueuePendingMessageRequest
req = QueuePendingMessageRequest(
message="hi",
file_ids=[f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
)
assert req.file_ids is not None
assert len(req.file_ids) == 20
def test_rejects_more_than_20_file_ids(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError):
QueuePendingMessageRequest(
message="hi",
file_ids=[f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
)
# ─── queue_pending_message endpoint ──────────────────────────────────
def _mock_pending_internals(
mocker: pytest_mock.MockerFixture,
*,
session_exists: bool = True,
call_count: int = 1,
):
"""Mock all async dependencies for the pending-message endpoint."""
if session_exists:
mock_session = mocker.MagicMock()
mock_session.id = "sess-1"
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
new_callable=AsyncMock,
return_value=mock_session,
)
else:
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
side_effect=fastapi.HTTPException(
status_code=404, detail="Session not found."
),
)
mocker.patch(
"backend.api.features.chat.routes.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(0, 0, None),
)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
new_callable=AsyncMock,
return_value=None,
)
# Mock Redis for per-user call-frequency rate limit (atomic Lua EVAL)
mock_redis = mocker.MagicMock()
mock_redis.eval = mocker.AsyncMock(return_value=call_count)
mocker.patch(
"backend.api.features.chat.routes.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.push_pending_message",
new_callable=AsyncMock,
return_value=1,
)
mock_registry = mocker.MagicMock()
mock_registry.get_session = mocker.AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
def test_queue_pending_message_returns_202(mocker: pytest_mock.MockerFixture) -> None:
"""Happy path: valid message returns 202 with buffer_length."""
_mock_pending_internals(mocker)
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": "follow-up"},
)
assert response.status_code == 202
data = response.json()
assert data["buffer_length"] == 1
assert data["turn_in_flight"] is False
def test_queue_pending_message_empty_body_returns_422() -> None:
"""Empty message must be rejected by Pydantic before hitting any route logic."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": ""},
)
assert response.status_code == 422
def test_queue_pending_message_missing_message_returns_422() -> None:
"""Missing 'message' field returns 422."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={},
)
assert response.status_code == 422
def test_queue_pending_message_session_not_found_returns_404(
mocker: pytest_mock.MockerFixture,
) -> None:
"""If the session doesn't exist or belong to the user, returns 404."""
_mock_pending_internals(mocker, session_exists=False)
response = client.post(
"/sessions/bad-sess/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 404
def test_queue_pending_message_rate_limited_returns_429(
mocker: pytest_mock.MockerFixture,
) -> None:
"""When rate limit is exceeded, endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_pending_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
)
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 429
def test_queue_pending_message_call_frequency_limit_returns_429(
mocker: pytest_mock.MockerFixture,
) -> None:
"""When per-user call frequency limit is exceeded, endpoint returns 429."""
from backend.api.features.chat.routes import _PENDING_CALL_LIMIT
_mock_pending_internals(mocker, call_count=_PENDING_CALL_LIMIT + 1)
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 429
assert "Too many pending messages" in response.json()["detail"]
def test_queue_pending_message_context_url_too_long_returns_422() -> None:
"""context.url over 2 KB is rejected."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={
"message": "hi",
"context": {"url": "https://example.com/" + "x" * 2_000},
},
)
assert response.status_code == 422
def test_queue_pending_message_context_content_too_long_returns_422() -> None:
"""context.content over 32 KB is rejected."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={
"message": "hi",
"context": {"content": "x" * 32_001},
},
)
assert response.status_code == 422
def test_queue_pending_message_too_many_file_ids_returns_422() -> None:
"""More than 20 file_ids should be rejected."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={
"message": "hi",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
},
)
assert response.status_code == 422
def test_queue_pending_message_file_ids_scoped_to_workspace(
mocker: pytest_mock.MockerFixture,
) -> None:
"""File IDs must be sanitized to the user's workspace before push."""
_mock_pending_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
new_callable=AsyncMock,
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/messages/pending",
json={"message": "hi", "file_ids": [fid, "not-a-uuid"]},
)
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["id"]["in"] == [fid]
assert call_kwargs["where"]["workspaceId"] == "ws-1"
assert call_kwargs["where"]["isDeleted"] is False
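
Reviewer sketch: the last test expects malformed IDs to be dropped before the workspace-scoped lookup. The snippet below shows one way that filtering could look. `filter_valid_file_ids` and `build_workspace_filter` are hypothetical names, and the strict 8-4-4-4-12 UUID pattern is an assumption; the real route may validate differently.

```python
import re

# Canonical 8-4-4-4-12 hex form. This strictness is an assumption made for
# the sketch; the actual route may accept other UUID spellings.
_UUID_RE = re.compile(
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
    re.IGNORECASE,
)


def filter_valid_file_ids(file_ids: list[str]) -> list[str]:
    """Keep only IDs in canonical UUID form, preserving input order."""
    return [fid for fid in file_ids if _UUID_RE.fullmatch(fid)]


def build_workspace_filter(file_ids: list[str], workspace_id: str) -> dict:
    """Build a Prisma-style ``where`` clause scoped to a single workspace."""
    return {
        "id": {"in": filter_valid_file_ids(file_ids)},
        "workspaceId": workspace_id,
        "isDeleted": False,
    }
```

Scoping by `workspaceId` in the same query means a syntactically valid ID that belongs to another user's workspace simply returns no rows.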

View File

@@ -887,6 +887,21 @@ async def llm_call(
provider = llm_model.metadata.provider
context_window = llm_model.context_window
# Transparent OpenRouter routing for Anthropic models: when an OpenRouter API key
# is configured, route direct-Anthropic models through OpenRouter instead. This
# gives us the x-total-cost header for free, so provider_cost is always populated
# without manual token-rate arithmetic.
or_key = settings.secrets.open_router_api_key
or_model_id: str | None = None
if provider == "anthropic" and or_key:
provider = "open_router"
credentials = APIKeyCredentials(
provider=ProviderName.OPEN_ROUTER,
title="OpenRouter (auto)",
api_key=SecretStr(or_key),
)
or_model_id = f"anthropic/{llm_model.value}"
if compress_prompt_to_fit:
result = await compress_context(
messages=prompt,
@@ -1134,7 +1149,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
model=or_model_id or llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
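
Reviewer sketch: the routing decision in this hunk can be read as a small pure function. `resolve_route` is a hypothetical helper written for illustration only; the actual change also swaps the credentials object, which is omitted here.

```python
from __future__ import annotations


def resolve_route(
    provider: str,
    model: str,
    open_router_key: str | None,
) -> tuple[str, str]:
    """Return the (provider, model_id) pair a call would actually use.

    Anthropic models are rerouted through OpenRouter only when a key is
    configured; every other combination is passed through unchanged.
    """
    if provider == "anthropic" and open_router_key:
        # OpenRouter addresses Anthropic models as "anthropic/<model>".
        return "open_router", f"anthropic/{model}"
    return provider, model
```

An empty string counts as "no key", which is why the existing Anthropic stats test patches `open_router_api_key = ""` to stay on the direct path.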

View File

@@ -77,7 +77,11 @@ class TestLLMStatsTracking:
mock_response.usage = mock_usage
mock_response.stop_reason = "end_turn"
with patch("anthropic.AsyncAnthropic") as mock_anthropic:
with (
patch("anthropic.AsyncAnthropic") as mock_anthropic,
patch("backend.blocks.llm.settings") as mock_settings,
):
mock_settings.secrets.open_router_api_key = ""
mock_client = AsyncMock()
mock_anthropic.return_value = mock_client
mock_client.messages.create = AsyncMock(return_value=mock_response)
@@ -96,6 +100,56 @@ class TestLLMStatsTracking:
assert response.cache_creation_tokens == 50
assert response.response == "Test anthropic response"
@pytest.mark.asyncio
async def test_anthropic_routes_through_openrouter_when_key_present(self):
"""When open_router_api_key is set, Anthropic models route via OpenRouter."""
from pydantic import SecretStr
import backend.blocks.llm as llm
from backend.data.model import APIKeyCredentials
anthropic_creds = APIKeyCredentials(
id="test-anthropic-id",
provider="anthropic",
api_key=SecretStr("mock-anthropic-key"),
title="Mock Anthropic key",
)
mock_choice = MagicMock()
mock_choice.message.content = "routed response"
mock_choice.message.tool_calls = None
mock_usage = MagicMock()
mock_usage.prompt_tokens = 10
mock_usage.completion_tokens = 5
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_response.usage = mock_usage
mock_create = AsyncMock(return_value=mock_response)
with (
patch("openai.AsyncOpenAI") as mock_openai,
patch("backend.blocks.llm.settings") as mock_settings,
):
mock_settings.secrets.open_router_api_key = "sk-or-test-key"
mock_client = MagicMock()
mock_openai.return_value = mock_client
mock_client.chat.completions.create = mock_create
await llm.llm_call(
credentials=anthropic_creds,
llm_model=llm.LlmModel.CLAUDE_3_HAIKU,
prompt=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
# Verify OpenAI client was used (not Anthropic SDK) and model was prefixed
mock_openai.assert_called_once()
call_kwargs = mock_create.call_args.kwargs
assert call_kwargs["model"] == "anthropic/claude-3-haiku-20240307"
@pytest.mark.asyncio
async def test_ai_structured_response_block_tracks_stats(self):
"""Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""

View File

@@ -36,10 +36,6 @@ from backend.copilot.model import (
maybe_append_user_message,
upsert_chat_session,
)
from backend.copilot.pending_messages import (
drain_pending_messages,
format_pending_as_user_message,
)
from backend.copilot.prompting import get_baseline_supplement, get_graphiti_supplement
from backend.copilot.response_model import (
StreamBaseResponse,
@@ -934,29 +930,6 @@ async def stream_chat_completion_baseline(
message_length=len(message or ""),
)
# Capture count *before* the pending drain so is_first_turn and the
# transcript staleness check are not skewed by queued messages.
_pre_drain_msg_count = len(session.messages)
# Drain any messages the user queued via POST /messages/pending
# while this session was idle (or during a previous turn whose
# mid-loop drains missed them). Atomic LPOP guarantees that a
# concurrent push lands *after* the drain and stays queued for the
# next turn instead of being lost.
drained_at_start = await drain_pending_messages(session_id)
if drained_at_start:
logger.info(
"[Baseline] Draining %d pending message(s) at turn start for session %s",
len(drained_at_start),
session_id,
)
for pm in drained_at_start:
content = format_pending_as_user_message(pm)["content"]
# Append directly — pending messages are atomically-popped from
# Redis and are never stale-cache duplicates, so the
# maybe_append_user_message dedup is wrong here.
session.messages.append(ChatMessage(role="user", content=content))
session = await upsert_chat_session(session)
# Select model based on the per-request mode. 'fast' downgrades to
@@ -986,9 +959,7 @@ async def stream_chat_completion_baseline(
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
# Use the pre-drain count so queued pending messages don't incorrectly
# flip is_first_turn to False on an actual first turn.
is_first_turn = _pre_drain_msg_count <= 1
is_first_turn = len(session.messages) <= 1
# Gate context fetch on both first turn AND user message so that assistant-
# role calls (e.g. tool-result submissions) on the first turn don't trigger
# a needless DB lookup for user understanding.
@@ -999,18 +970,14 @@ async def stream_chat_completion_baseline(
prompt_task = _build_cacheable_system_prompt(None)
# Run download + prompt build concurrently — both are independent I/O
# on the request critical path. Use the pre-drain count so pending
# messages drained at turn start don't spuriously trigger a transcript
# load on an actual first turn.
if user_id and _pre_drain_msg_count > 1:
# on the request critical path.
if user_id and len(session.messages) > 1:
transcript_covers_prefix, (base_system_prompt, understanding) = (
await asyncio.gather(
_load_prior_transcript(
user_id=user_id,
session_id=session_id,
# Use pre-drain count so pending messages don't falsely
# mark the stored transcript as stale and prevent upload.
session_msg_count=_pre_drain_msg_count,
session_msg_count=len(session.messages),
transcript_builder=transcript_builder,
),
prompt_task,
@@ -1022,16 +989,6 @@ async def stream_chat_completion_baseline(
# Append user message to transcript after context injection below so the
# transcript receives the prefixed message when user context is available.
# Mirror any messages drained at turn start (see above) into the
# transcript — otherwise the loaded prior transcript would be
# missing them and a mid-turn upload could leave a malformed
# assistant-after-assistant structure on the next turn.
if drained_at_start:
for pm in drained_at_start:
transcript_builder.append_user(
content=format_pending_as_user_message(pm)["content"]
)
# Generate title for new sessions
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
@@ -1052,10 +1009,8 @@ async def stream_chat_completion_baseline(
graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else ""
system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement
# Warm context: pre-load relevant facts from Graphiti on first turn.
# Use the pre-drain count so pending messages drained at turn start
# don't prevent warm context injection on an actual first turn.
if graphiti_enabled and user_id and _pre_drain_msg_count <= 1:
# Warm context: pre-load relevant facts from Graphiti on first turn
if graphiti_enabled and user_id and len(session.messages) <= 1:
from backend.copilot.graphiti.context import fetch_warm_context
warm_ctx = await fetch_warm_context(user_id, message or "")
@@ -1248,64 +1203,6 @@ async def stream_chat_completion_baseline(
yield evt
state.pending_events.clear()
# Inject any messages the user queued while the turn was
# running. ``tool_call_loop`` mutates ``openai_messages``
# in-place, so appending here means the model sees the new
# messages on its next LLM call.
#
# IMPORTANT: skip when the loop has already finished (no
# more LLM calls are coming). ``tool_call_loop`` yields
# a final ``ToolCallLoopResult`` on both paths:
# - natural finish: ``finished_naturally=True``
# - hit max_iterations: ``finished_naturally=False``
# and ``iterations >= max_iterations``
# In either case the loop is about to return on the next
# ``async for`` step, so draining here would silently
# lose the message (the user sees 202 but the model never
# reads the text). Those messages stay in the buffer and
# get picked up at the start of the next turn.
if loop_result is None:
continue
is_final_yield = (
loop_result.finished_naturally
or loop_result.iterations >= _MAX_TOOL_ROUNDS
)
if is_final_yield:
continue
pending = await drain_pending_messages(session_id)
if pending:
for pm in pending:
# ``format_pending_as_user_message`` embeds file
# attachments and context URL/page content into the
# content string so the in-session transcript is
# a faithful copy of what the model actually saw.
formatted = format_pending_as_user_message(pm)
content_for_db = formatted["content"]
# Append directly — pending messages are atomically-popped
# from Redis and are never stale-cache duplicates, so the
# maybe_append_user_message dedup is wrong here and would
# cause openai_messages/transcript to diverge from session.
session.messages.append(
ChatMessage(role="user", content=content_for_db)
)
openai_messages.append(formatted)
transcript_builder.append_user(content=content_for_db)
try:
await upsert_chat_session(session)
except Exception as persist_err:
logger.warning(
"[Baseline] Failed to persist pending messages for "
"session %s: %s",
session_id,
persist_err,
)
logger.info(
"[Baseline] Injected %d pending message(s) into "
"session %s mid-turn",
len(pending),
session_id,
)
if loop_result and not loop_result.finished_naturally:
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
@@ -1346,11 +1243,6 @@ async def stream_chat_completion_baseline(
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
# Pending messages are drained atomically at turn start and
# between tool rounds, so there's nothing to clear in finally.
# Any message pushed after the final drain window stays in the
# buffer and gets picked up at the start of the next turn.
# Set cost attributes on OTEL span before closing
if _trace_ctx is not None:
try:

View File

@@ -172,6 +172,20 @@ class ChatConfig(BaseSettings):
description="Maximum number of retries for transient API errors "
"(429, 5xx, ECONNRESET) before surfacing the error to the user.",
)
claude_agent_cli_path: str | None = Field(
default=None,
description="Optional explicit path to a Claude Code CLI binary. "
"When set, the SDK uses this binary instead of the version bundled "
"with the installed `claude-agent-sdk` package — letting us pin "
"the Python SDK and the CLI independently. Critical for keeping "
"OpenRouter compatibility while still picking up newer SDK API "
"features (the bundled CLI version in 0.1.46+ is broken against "
"OpenRouter — see PR #12294 and "
"anthropics/claude-agent-sdk-python#789). Falls back to the "
"bundled binary when unset. Reads from `CHAT_CLAUDE_AGENT_CLI_PATH` "
"or the unprefixed `CLAUDE_AGENT_CLI_PATH` environment variable "
"(same pattern as `api_key` / `base_url`).",
)
use_openrouter: bool = Field(
default=True,
description="Enable routing API calls through the OpenRouter proxy. "
@@ -294,6 +308,26 @@ class ChatConfig(BaseSettings):
v = OPENROUTER_BASE_URL
return v
@field_validator("claude_agent_cli_path", mode="before")
@classmethod
def get_claude_agent_cli_path(cls, v):
"""Resolve the Claude Code CLI override path from environment.
Accepts either the Pydantic-prefixed ``CHAT_CLAUDE_AGENT_CLI_PATH``
or the unprefixed ``CLAUDE_AGENT_CLI_PATH`` (matching the same
fallback pattern used by ``api_key`` / ``base_url``). Keeping the
unprefixed form working is important because the field is
primarily an operator escape hatch set via container/host env,
and the unprefixed name is what the PR description, the field
description, and the reproduction test in
``cli_openrouter_compat_test.py`` refer to.
"""
if not v:
v = os.getenv("CHAT_CLAUDE_AGENT_CLI_PATH")
if not v:
v = os.getenv("CLAUDE_AGENT_CLI_PATH")
return v
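
Reviewer sketch: the lookup order the validator implements, pulled out as a standalone function so it can be checked without constructing `ChatConfig`. `resolve_cli_path` and its `env` parameter are hypothetical, and unlike the validator this sketch normalises an empty result to `None`.

```python
from __future__ import annotations

import os
from typing import Mapping


def resolve_cli_path(
    explicit: str | None,
    env: Mapping[str, str] | None = None,
) -> str | None:
    """Resolve the CLI override: explicit value, then prefixed, then unprefixed env."""
    env = os.environ if env is None else env
    return (
        explicit
        or env.get("CHAT_CLAUDE_AGENT_CLI_PATH")
        or env.get("CLAUDE_AGENT_CLI_PATH")
        or None  # treat empty strings as "unset"
    )
```

Passing the environment as a parameter keeps the precedence rules testable without `monkeypatch`.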
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -174,13 +174,25 @@ class CoPilotProcessor:
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
def _prewarm_cli(self) -> None:
"""Run the bundled CLI binary once to warm OS page caches."""
try:
from claude_agent_sdk._internal.transport.subprocess_cli import (
SubprocessCLITransport,
)
"""Run the Claude Code CLI binary once to warm OS page caches.
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
Honours the ``claude_agent_cli_path`` config override (which lets
us run a pinned CLI version independent of the bundled one in the
installed ``claude-agent-sdk`` wheel — see
``ChatConfig.claude_agent_cli_path`` for the rationale). Falls
back to the bundled binary when no override is set.
"""
try:
from backend.copilot.config import ChatConfig
cfg = ChatConfig()
cli_path: str | None = cfg.claude_agent_cli_path
if not cli_path:
from claude_agent_sdk._internal.transport.subprocess_cli import (
SubprocessCLITransport,
)
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
if cli_path:
result = subprocess.run(
[cli_path, "-v"],

View File

@@ -1,222 +0,0 @@
"""Pending-message buffer for in-flight copilot turns.
When a user sends a new message while a copilot turn is already executing,
instead of blocking the frontend (or queueing a brand-new turn after the
current one finishes), we want the new message to be *injected into the
running turn* — appended between tool-call rounds so the model sees it
before its next LLM call.
This module provides the cross-process buffer that makes that possible:
- **Producer** (chat API route): pushes a pending message to Redis and
publishes a notification on a pub/sub channel.
- **Consumer** (executor running the turn): on each tool-call round,
drains the buffer and appends the pending messages to the conversation.
The Redis list is the durable store; the pub/sub channel is a fast
wake-up hint for long-idle consumers (not used by default, but available
for future blocking-wait semantics).
A hard cap of ``MAX_PENDING_MESSAGES`` per session prevents abuse. The
buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push.
"""
import json
import logging
from typing import Any, cast
from pydantic import BaseModel, Field, ValidationError
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Per-session cap. Higher values risk a runaway consumer; lower values
# risk dropping user input under heavy typing. 10 was chosen as a
# reasonable ceiling — a user typing faster than the copilot can drain
# between tool rounds is already an unusual usage pattern.
MAX_PENDING_MESSAGES = 10
# Redis key + TTL. The buffer is ephemeral: if a turn completes or the
# executor dies, the pending messages should either have been drained
# already or are safe to drop (the user can resend).
_PENDING_KEY_PREFIX = "copilot:pending:"
_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:"
_PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default
# Payload sent on the pub/sub notify channel. Subscribers treat any
# message as a wake-up hint; the value itself is not meaningful.
_NOTIFY_PAYLOAD = "1"
class PendingMessageContext(BaseModel):
"""Structured page context attached to a pending message."""
url: str | None = None
content: str | None = None
class PendingMessage(BaseModel):
"""A user message queued for injection into an in-flight turn."""
content: str = Field(min_length=1, max_length=16_000)
file_ids: list[str] = Field(default_factory=list)
context: PendingMessageContext | None = None
def _buffer_key(session_id: str) -> str:
return f"{_PENDING_KEY_PREFIX}{session_id}"
def _notify_channel(session_id: str) -> str:
return f"{_PENDING_CHANNEL_PREFIX}{session_id}"
# Lua script: push-then-trim-then-expire-then-length, atomically.
# Redis serializes EVAL commands, so a concurrent ``LPOP`` drain
# observes either the pre-push or post-push state of the list — never
# a partial state where the RPUSH has landed but LTRIM hasn't run.
_PUSH_LUA = """
redis.call('RPUSH', KEYS[1], ARGV[1])
redis.call('LTRIM', KEYS[1], -tonumber(ARGV[2]), -1)
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[3]))
return redis.call('LLEN', KEYS[1])
"""
async def push_pending_message(
session_id: str,
message: PendingMessage,
) -> int:
"""Append a pending message to the session's buffer atomically.
Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by
trimming from the left (oldest) — the newest message always wins if
the user has been typing faster than the copilot can drain.
The push + trim + expire + llen are wrapped in a single Lua EVAL so
concurrent LPOP drains from the executor never observe a partial
state.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
payload = message.model_dump_json()
new_length = int(
await cast(
"Any",
redis.eval(
_PUSH_LUA,
1,
key,
payload,
str(MAX_PENDING_MESSAGES),
str(_PENDING_TTL_SECONDS),
),
)
)
# Fire-and-forget notify. Subscribers use this as a wake-up hint;
# the buffer itself is authoritative so a lost notify is harmless.
try:
await redis.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD)
except Exception as e: # pragma: no cover
logger.warning("pending_messages: publish failed for %s: %s", session_id, e)
logger.info(
"pending_messages: pushed message to session=%s (buffer_len=%d)",
session_id,
new_length,
)
return new_length
async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
"""Atomically pop all pending messages for *session_id*.
Returns them in enqueue order (oldest first). Uses ``LPOP`` with a
count so the read+delete is a single Redis round trip. If the list
is empty or missing, returns ``[]``.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
# Redis LPOP with count (Redis 6.2+) returns None for missing key,
# empty list if we somehow race an empty key, or the popped items.
# redis-py's async lpop overload with a count collapses the return
# type in pyright; cast the awaitable so strict type-check stays
# clean without changing runtime behaviour.
lpop_result = await cast(
"Any",
redis.lpop(key, MAX_PENDING_MESSAGES),
)
if not lpop_result:
return []
raw_popped: list[Any] = list(lpop_result)
# redis-py may return bytes or str depending on decode_responses.
decoded: list[str] = [
item.decode("utf-8") if isinstance(item, bytes) else str(item)
for item in raw_popped
]
messages: list[PendingMessage] = []
for payload in decoded:
try:
messages.append(PendingMessage(**json.loads(payload)))
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
logger.warning(
"pending_messages: dropping malformed entry for %s: %s",
session_id,
e,
)
if messages:
logger.info(
"pending_messages: drained %d messages for session=%s",
len(messages),
session_id,
)
return messages
async def peek_pending_count(session_id: str) -> int:
"""Return the current buffer length without consuming it."""
redis = await get_redis_async()
length = await cast("Any", redis.llen(_buffer_key(session_id)))
return int(length)
async def clear_pending_messages(session_id: str) -> None:
"""Drop the session's pending buffer.
Not called by the normal turn flow — the atomic ``LPOP`` drain at
turn start is the primary consumer, and any push that arrives
after the drain window belongs to the next turn by definition.
Retained as an operator/debug escape hatch for manually clearing a
stuck session and as a fixture in the unit tests.
"""
redis = await get_redis_async()
await redis.delete(_buffer_key(session_id))
def format_pending_as_user_message(message: PendingMessage) -> dict[str, Any]:
"""Shape a ``PendingMessage`` into the OpenAI-format user message dict.
Used by the baseline tool-call loop when injecting the buffered
message into the conversation. Context/file metadata (if any) is
embedded into the content so the model sees everything in one block.
"""
parts: list[str] = [message.content]
if message.context:
if message.context.url:
parts.append(f"\n\n[Page URL: {message.context.url}]")
if message.context.content:
parts.append(f"\n\n[Page content]\n{message.context.content}")
if message.file_ids:
parts.append(
"\n\n[Attached files]\n"
+ "\n".join(f"- file_id={fid}" for fid in message.file_ids)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
return {"role": "user", "content": "".join(parts)}
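
Reviewer sketch: the cap-and-drain behaviour of this module (RPUSH, then LTRIM to the newest N, then a bulk LPOP) can be modelled in memory with a bounded deque. `BoundedPendingBuffer` is a hypothetical stand-in for illustration; it has none of the cross-process or atomicity guarantees that the Redis list and Lua script provide.

```python
from __future__ import annotations

from collections import deque


class BoundedPendingBuffer:
    """In-memory model of the pending buffer: newest N kept, FIFO drain."""

    def __init__(self, max_len: int) -> None:
        # A deque with maxlen silently discards the oldest entry on overflow,
        # mirroring RPUSH followed by LTRIM(-N, -1).
        self._items: deque[str] = deque(maxlen=max_len)

    def push(self, item: str) -> int:
        """Append an item and return the new buffer length."""
        self._items.append(item)
        return len(self._items)

    def drain(self) -> list[str]:
        """Return all items oldest-first and leave the buffer empty."""
        drained = list(self._items)
        self._items.clear()
        return drained
```

With a cap of 10 and 13 pushes, the three oldest entries are gone by the time the consumer drains, which is the invariant `test_cap_drops_oldest_when_exceeded` checks.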

View File

@@ -1,246 +0,0 @@
"""Tests for the copilot pending-messages buffer.
Uses a fake async Redis client so the tests don't require a real Redis
instance (the backend test suite's DB/Redis fixtures are heavyweight
and pull in the full app startup).
"""
import json
from typing import Any
import pytest
from backend.copilot import pending_messages as pm_module
from backend.copilot.pending_messages import (
MAX_PENDING_MESSAGES,
PendingMessage,
PendingMessageContext,
clear_pending_messages,
drain_pending_messages,
format_pending_as_user_message,
peek_pending_count,
push_pending_message,
)
# ── Fake Redis ──────────────────────────────────────────────────────
class _FakeRedis:
def __init__(self) -> None:
# Values are ``str | bytes`` because real redis-py returns
# bytes when ``decode_responses=False``; the drain path must
# handle both and our tests exercise both.
self.lists: dict[str, list[str | bytes]] = {}
self.published: list[tuple[str, str]] = []
async def eval(self, script: str, num_keys: int, *args: Any) -> Any:
"""Emulate the push Lua script.
The real Lua script runs atomically in Redis; the fake
implementation just runs the equivalent list operations in
order and returns the final LLEN. That's enough to exercise
the cap + ordering invariants the tests care about.
"""
key = args[0]
payload = args[1]
max_len = int(args[2])
# ARGV[3] is TTL — fake doesn't enforce expiry
lst = self.lists.setdefault(key, [])
lst.append(payload)
if len(lst) > max_len:
# RPUSH + LTRIM(-N, -1) = keep only last N
self.lists[key] = lst[-max_len:]
return len(self.lists[key])
async def publish(self, channel: str, payload: str) -> int:
self.published.append((channel, payload))
return 1
async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
lst = self.lists.get(key)
if not lst:
return None
popped = lst[:count]
self.lists[key] = lst[count:]
return popped
async def llen(self, key: str) -> int:
return len(self.lists.get(key, []))
async def delete(self, key: str) -> int:
if key in self.lists:
del self.lists[key]
return 1
return 0
@pytest.fixture()
def fake_redis(monkeypatch: pytest.MonkeyPatch) -> _FakeRedis:
redis = _FakeRedis()
async def _get_redis_async() -> _FakeRedis:
return redis
monkeypatch.setattr(pm_module, "get_redis_async", _get_redis_async)
return redis
# ── Basic push / drain ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_and_drain_single_message(fake_redis: _FakeRedis) -> None:
length = await push_pending_message("sess1", PendingMessage(content="hello"))
assert length == 1
assert await peek_pending_count("sess1") == 1
drained = await drain_pending_messages("sess1")
assert len(drained) == 1
assert drained[0].content == "hello"
assert await peek_pending_count("sess1") == 0
@pytest.mark.asyncio
async def test_push_and_drain_preserves_order(fake_redis: _FakeRedis) -> None:
for i in range(3):
await push_pending_message("sess2", PendingMessage(content=f"msg {i}"))
drained = await drain_pending_messages("sess2")
assert [m.content for m in drained] == ["msg 0", "msg 1", "msg 2"]
@pytest.mark.asyncio
async def test_drain_empty_returns_empty_list(fake_redis: _FakeRedis) -> None:
assert await drain_pending_messages("nope") == []
# ── Buffer cap ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None:
# Push MAX_PENDING_MESSAGES + 3 messages
for i in range(MAX_PENDING_MESSAGES + 3):
await push_pending_message("sess3", PendingMessage(content=f"m{i}"))
# Buffer should be clamped to MAX
assert await peek_pending_count("sess3") == MAX_PENDING_MESSAGES
drained = await drain_pending_messages("sess3")
assert len(drained) == MAX_PENDING_MESSAGES
# Oldest 3 dropped — we should only see m3..m(MAX+2)
assert drained[0].content == "m3"
assert drained[-1].content == f"m{MAX_PENDING_MESSAGES + 2}"
# ── Clear ───────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None:
await push_pending_message("sess4", PendingMessage(content="x"))
await push_pending_message("sess4", PendingMessage(content="y"))
await clear_pending_messages("sess4")
assert await peek_pending_count("sess4") == 0
@pytest.mark.asyncio
async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None:
# Clearing an already-empty buffer should not raise
await clear_pending_messages("sess_empty")
await clear_pending_messages("sess_empty")
# ── Publish hook ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_publishes_notification(fake_redis: _FakeRedis) -> None:
await push_pending_message("sess5", PendingMessage(content="hi"))
assert ("copilot:pending:notify:sess5", "1") in fake_redis.published
# ── Format helper ───────────────────────────────────────────────────
def test_format_pending_plain_text() -> None:
msg = PendingMessage(content="just text")
out = format_pending_as_user_message(msg)
assert out == {"role": "user", "content": "just text"}
def test_format_pending_with_context_url() -> None:
msg = PendingMessage(
content="see this page",
context=PendingMessageContext(url="https://example.com"),
)
out = format_pending_as_user_message(msg)
content = out["content"]
assert out["role"] == "user"
assert "see this page" in content
# The URL should appear verbatim in the [Page URL: ...] block.
assert "[Page URL: https://example.com]" in content
def test_format_pending_with_file_ids() -> None:
msg = PendingMessage(content="look here", file_ids=["a", "b"])
out = format_pending_as_user_message(msg)
assert "file_id=a" in out["content"]
assert "file_id=b" in out["content"]
def test_format_pending_with_all_fields() -> None:
"""All fields (content + context url/content + file_ids) should all appear."""
msg = PendingMessage(
content="summarise this",
context=PendingMessageContext(
url="https://example.com/page",
content="headline text",
),
file_ids=["f1", "f2"],
)
out = format_pending_as_user_message(msg)
body = out["content"]
assert out["role"] == "user"
assert "summarise this" in body
assert "[Page URL: https://example.com/page]" in body
assert "[Page content]\nheadline text" in body
assert "file_id=f1" in body
assert "file_id=f2" in body
# ── Malformed payload handling ──────────────────────────────────────
@pytest.mark.asyncio
async def test_drain_skips_malformed_entries(
fake_redis: _FakeRedis,
) -> None:
# Seed the fake with a mix of valid and malformed payloads
fake_redis.lists["copilot:pending:bad"] = [
json.dumps({"content": "valid"}),
"{not valid json",
json.dumps({"content": "also valid", "file_ids": ["a"]}),
]
drained = await drain_pending_messages("bad")
assert len(drained) == 2
assert drained[0].content == "valid"
assert drained[1].content == "also valid"
@pytest.mark.asyncio
async def test_drain_decodes_bytes_payloads(
fake_redis: _FakeRedis,
) -> None:
"""Real redis-py returns ``bytes`` when ``decode_responses=False``.
Seed the fake with bytes values to exercise the ``decode("utf-8")``
branch in ``drain_pending_messages`` so a regression there doesn't
slip past CI.
"""
fake_redis.lists["copilot:pending:bytes_sess"] = [
json.dumps({"content": "from bytes"}).encode("utf-8"),
]
drained = await drain_pending_messages("bytes_sess")
assert len(drained) == 1
assert drained[0].content == "from bytes"

View File

@@ -0,0 +1,577 @@
"""Reproduction test for the OpenRouter incompatibility in newer
``claude-agent-sdk`` / Claude Code CLI versions.
Background — there are two stacked regressions that block us from
upgrading the ``claude-agent-sdk`` package above ``0.1.45``:
1. **`tool_reference` content blocks** introduced by CLI ``2.1.69`` (=
SDK ``0.1.46``). The CLI's built-in ``ToolSearch`` tool returns
``{"type": "tool_reference", "tool_name": "..."}`` content blocks in
``tool_result.content``. OpenRouter's stricter Zod validation
rejects this with::
messages[N].content[0].content: Invalid input: expected string, received array
This is the regression that originally pinned us at 0.1.45 — see
https://github.com/Significant-Gravitas/AutoGPT/pull/12294 for the
full forensic write-up. CLI 2.1.70 added proxy detection that
*should* disable the offending blocks when ``ANTHROPIC_BASE_URL`` is
set, but our subsequent attempts at 0.1.55 / 0.1.56 still failed.
2. **`context-management-2025-06-27` beta header** — some CLI version
after ``2.1.91`` started injecting this header / beta flag, which
OpenRouter rejects with::
400 No endpoints available that support Anthropic's context
management features (context-management-2025-06-27). Context
management requires a supported provider (Anthropic).
Tracked upstream at
https://github.com/anthropics/claude-agent-sdk-python/issues/789.
Still open at the time of writing, with no upstream PR linked and no
workaround documented.
The purpose of this test:
* Spin up a tiny in-process HTTP server that pretends to be the
Anthropic Messages API.
* Capture every request body the CLI sends.
* Inspect the captured bodies for the two forbidden patterns above.
* Fail loudly if either is present, with a pointer to the issue
tracker.
This is the reproduction we use as a CI gate when bisecting which SDK /
CLI version is safe to upgrade to. It runs against the bundled CLI by
default (or against ``ChatConfig.claude_agent_cli_path`` when set), so
it doubles as a regression guard for the ``cli_path`` override
mechanism.
The test does **not** need an OpenRouter API key — it reproduces the
mechanism (forbidden content blocks / headers in the *outgoing*
request) rather than the symptom (the 400 OpenRouter would return).
This keeps it deterministic, free, and CI-runnable without secrets.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import subprocess
from pathlib import Path
from typing import Any
import pytest
from aiohttp import web
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Forbidden patterns we scan for in captured request bodies
# ---------------------------------------------------------------------------
# Match the `tool_reference` content block that breaks OpenRouter's stricter
# Zod validation in tool_result.content. PR #12294 root-cause.
#
# We use a whitespace-tolerant regex rather than a literal substring because
# the Claude Code CLI is Node.js and `JSON.stringify` without an indent
# argument emits no whitespace between the key, colon, and value
# (`{"type":"tool_reference"}`), while a Python serializer would emit
# `{"type": "tool_reference"}`. A naive substring with one specific spacing
# would silently miss the real regression.
_FORBIDDEN_TOOL_REFERENCE_RE = re.compile(r'"type"\s*:\s*"tool_reference"')
# Beta string OpenRouter rejects in upstream issue #789. Can appear in
# either `betas` arrays or the `anthropic-beta` header value. This is a
# unique opaque token (no JSON punctuation around it that could vary), so
# a plain substring match is robust.
_FORBIDDEN_CONTEXT_MANAGEMENT_BETA = "context-management-2025-06-27"
def _scan_request_for_forbidden_patterns(
body_text: str,
headers: dict[str, str],
) -> list[str]:
"""Return a list of forbidden patterns found in *body_text* / *headers*.
Empty list = clean request. Non-empty = the CLI is sending one of the
OpenRouter-incompatible features.
"""
findings: list[str] = []
if _FORBIDDEN_TOOL_REFERENCE_RE.search(body_text):
findings.append(
"`tool_reference` content block in request body — "
"PR #12294 / CLI 2.1.69 regression"
)
if _FORBIDDEN_CONTEXT_MANAGEMENT_BETA in body_text:
findings.append(
f"{_FORBIDDEN_CONTEXT_MANAGEMENT_BETA!r} in request body — "
"anthropics/claude-agent-sdk-python#789"
)
# Header names (not values) are case-insensitive in HTTP, and the dict
# passed in may use any casing, so compare the lowercased name. Values
# are matched as-is.
for header_name, header_value in headers.items():
if header_name.lower() == "anthropic-beta":
if _FORBIDDEN_CONTEXT_MANAGEMENT_BETA in header_value:
findings.append(
f"{_FORBIDDEN_CONTEXT_MANAGEMENT_BETA!r} in "
"`anthropic-beta` header — issue #789"
)
return findings
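A small illustration of why the `tool_reference` check uses a regex rather than a literal substring: the same JSON object serialises with different spacing depending on the producer. This is a sketch using only the standard library. The names `spaced`, `compact`, and `regex_hit` are illustrative and not part of the test module, and the compact separators only approximate what `JSON.stringify` emits.

```python
import json
import re

# Same pattern as the scanner: tolerate any whitespace around the colon.
TOOL_REFERENCE_RE = re.compile(r'"type"\s*:\s*"tool_reference"')

block = {"type": "tool_reference", "tool_name": "find"}

# Python's default separators put a space after ':' and ','.
spaced = json.dumps(block)
# Compact separators approximate JSON.stringify without an indent argument.
compact = json.dumps(block, separators=(",", ":"))

# A literal substring that assumes Python-style spacing.
literal = '"type": "tool_reference"'


def regex_hit(text: str) -> bool:
    """Return True when the whitespace-tolerant pattern matches."""
    return TOOL_REFERENCE_RE.search(text) is not None
```

The literal matches only the spaced form, while the regex matches both, which is the failure mode the scanner comment describes.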
# ---------------------------------------------------------------------------
# Fake Anthropic Messages API
# ---------------------------------------------------------------------------
#
# We need to give the CLI a *successful* response so it doesn't error out
# before we get a chance to inspect the request. The minimal thing the
# CLI accepts is a streamed (SSE) message-start → content-block-delta →
# message-stop sequence.
#
# We don't strictly *need* the CLI to accept the response — we already
# have the request body by the time we send any reply — but giving it a
# valid stream means the assertion failure (if any) is the *only*
# failure mode in the test, not "CLI exited 1 because we sent garbage".
def _build_streaming_message_response() -> str:
"""Return an SSE-formatted body containing a minimal Anthropic
Messages API streamed response.
This is a minimal stream that the Claude Code CLI is expected to accept
end-to-end without errors. Each event is framed as an ``event:`` line,
a ``data:`` line, and a blank line."""
events: list[dict[str, Any]] = [
{
"type": "message_start",
"message": {
"id": "msg_test",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-test",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 1, "output_tokens": 1},
},
},
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": "ok"},
},
{"type": "content_block_stop", "index": 0},
{
"type": "message_delta",
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
"usage": {"output_tokens": 1},
},
{"type": "message_stop"},
]
return "".join(
f"event: {evt['type']}\ndata: {json.dumps(evt)}\n\n" for evt in events
)
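The SSE framing used by the response builder can be checked with a round trip. The sketch below is an assumption-laden illustration: `encode_sse` copies the framing expression from the builder, while `decode_sse` is a hypothetical, deliberately naive parser that only handles single-line `data:` payloads. It is not how the CLI parses the stream.

```python
import json
from typing import Any


def encode_sse(events: list[dict[str, Any]]) -> str:
    """Frame each event as an `event:` line, a `data:` line, and a blank line."""
    return "".join(
        f"event: {evt['type']}\ndata: {json.dumps(evt)}\n\n" for evt in events
    )


def decode_sse(stream: str) -> list[tuple[str, dict[str, Any]]]:
    """Naive parser (illustrative only): split on blank lines, read two fields."""
    decoded = []
    for frame in stream.split("\n\n"):
        if not frame.strip():
            continue
        # Split each line on the first ": " only, so JSON payloads stay intact.
        fields = dict(line.split(": ", 1) for line in frame.split("\n"))
        decoded.append((fields["event"], json.loads(fields["data"])))
    return decoded


events = [
    {"type": "message_start", "message": {"id": "msg_test"}},
    {
        "type": "content_block_delta",
        "index": 0,
        "delta": {"type": "text_delta", "text": "ok"},
    },
    {"type": "message_stop"},
]
stream = encode_sse(events)
round_tripped = decode_sse(stream)
```

Because `json.dumps` escapes newlines inside strings, each `data:` payload stays on one line and the blank-line delimiter remains unambiguous.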
class _CapturedRequest:
"""One request the fake server received."""
def __init__(self, path: str, headers: dict[str, str], body: str) -> None:
self.path = path
self.headers = headers
self.body = body
async def _start_fake_anthropic_server(
captured: list[_CapturedRequest],
) -> tuple[web.AppRunner, int]:
"""Start an aiohttp server pretending to be the Anthropic API.
All POSTs to ``/v1/messages`` are recorded into *captured* and
answered with a valid streaming response. Returns ``(runner, port)``
so the caller can ``await runner.cleanup()`` when finished.
"""
import socket
async def messages_handler(request: web.Request) -> web.StreamResponse:
body = await request.text()
captured.append(
_CapturedRequest(
path=request.path,
headers={k: v for k, v in request.headers.items()},
body=body,
)
)
# Stream a minimal valid response so the CLI doesn't error out
# before we can inspect what it sent.
response = web.StreamResponse(
status=200,
headers={
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
await response.prepare(request)
await response.write(_build_streaming_message_response().encode("utf-8"))
await response.write_eof()
return response
async def fallback_handler(_request: web.Request) -> web.Response:
# OAuth/profile endpoints the CLI may probe — answer 404 so it
# falls through quickly without retrying.
return web.Response(status=404)
app = web.Application()
app.router.add_post("/v1/messages", messages_handler)
app.router.add_route("*", "/{tail:.*}", fallback_handler)
# Bind an ephemeral port ourselves so we can read it back via the
# public ``getsockname`` API rather than reaching into ``site._server``
# private aiohttp internals.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("127.0.0.1", 0))
port: int = sock.getsockname()[1]
runner = web.AppRunner(app)
await runner.setup()
site = web.SockSite(runner, sock)
await site.start()
return runner, port
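The ephemeral-port technique above does not depend on aiohttp. A minimal standard-library sketch, assuming a loopback interface is available, is shown below. `bind_ephemeral_port` is a hypothetical helper name, and the plain `listen` call stands in for handing the socket to `web.SockSite`.

```python
import socket


def bind_ephemeral_port(host: str = "127.0.0.1") -> tuple[socket.socket, int]:
    """Bind to port 0 so the OS picks a free port, then read it back."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind((host, 0))
    # getsockname() is public API and reports the port the OS assigned.
    port = sock.getsockname()[1]
    return sock, port


sock, port = bind_ephemeral_port()
try:
    # Stand-in for passing the pre-bound socket to a server framework.
    sock.listen()
    with socket.create_connection(("127.0.0.1", port), timeout=5) as client:
        connected_to = client.getpeername()[1]
finally:
    sock.close()
```

Binding before the server starts means the port is known up front, so the caller never has to inspect framework internals to discover it.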
# ---------------------------------------------------------------------------
# CLI invocation
# ---------------------------------------------------------------------------
def _resolve_cli_path() -> Path | None:
"""Return the Claude Code CLI binary the SDK would use.
Honours the same override mechanism as ``service.py`` /
``ChatConfig.claude_agent_cli_path``: checks either the Pydantic-
prefixed ``CHAT_CLAUDE_AGENT_CLI_PATH`` or the unprefixed
``CLAUDE_AGENT_CLI_PATH`` env var first, then falls back to the
bundled binary that ships with the installed ``claude-agent-sdk``
wheel. The two env var names are accepted at the config layer via
``ChatConfig.get_claude_agent_cli_path`` and mirrored here so the
reproduction test picks up the same override regardless of which
form an operator sets.
"""
override = os.environ.get("CHAT_CLAUDE_AGENT_CLI_PATH") or os.environ.get(
"CLAUDE_AGENT_CLI_PATH"
)
if override:
candidate = Path(override)
return candidate if candidate.is_file() else None
try:
from claude_agent_sdk._internal.transport.subprocess_cli import ( # type: ignore[import-untyped]
SubprocessCLITransport,
)
bundled = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
return Path(bundled) if bundled else None
except Exception as e: # pragma: no cover - import-time guard
logger.warning("Could not locate bundled Claude CLI: %s", e)
return None
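The override semantics can be sketched without the SDK import. In the illustration below, `resolve_cli` is a hypothetical stand-in that takes the environment and the bundled path as arguments. It mirrors only the precedence and the missing-file behaviour of the resolver, under the assumption that a bundled binary exists.

```python
import tempfile
from pathlib import Path
from typing import Mapping, Optional


def resolve_cli(env: Mapping[str, str], bundled: Optional[Path]) -> Optional[Path]:
    """Prefer an explicit override; never fall back when the override is bad."""
    override = env.get("CHAT_CLAUDE_AGENT_CLI_PATH") or env.get(
        "CLAUDE_AGENT_CLI_PATH"
    )
    if override:
        candidate = Path(override)
        # A set-but-missing override yields None rather than the bundled path.
        return candidate if candidate.is_file() else None
    return bundled


with tempfile.TemporaryDirectory() as tmp:
    real = Path(tmp) / "claude"
    real.write_text("#!/bin/sh\n")
    bundled = Path(tmp) / "bundled-claude"
    bundled.write_text("#!/bin/sh\n")

    picked_prefixed = resolve_cli({"CHAT_CLAUDE_AGENT_CLI_PATH": str(real)}, bundled)
    picked_unprefixed = resolve_cli({"CLAUDE_AGENT_CLI_PATH": str(real)}, bundled)
    picked_missing = resolve_cli(
        {"CLAUDE_AGENT_CLI_PATH": str(Path(tmp) / "nope")}, bundled
    )
    picked_default = resolve_cli({}, bundled)
```

Returning `None` for a missing override keeps a typo in the env var visible instead of silently running the bundled binary.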
async def _run_cli_against_fake_server(
cli_path: Path,
fake_server_port: int,
timeout_seconds: float,
) -> tuple[int, str, str]:
"""Spawn the CLI pointed at the fake Anthropic server and feed it a
single ``user`` message via stream-json on stdin.
Returns ``(returncode, stdout, stderr)``. The return code is not
asserted by the test — we only care that the CLI made at least one
POST to ``/v1/messages`` so the fake server captured the body.
"""
fake_url = f"http://127.0.0.1:{fake_server_port}"
env = {
# Inherit basic shell variables so the CLI can find its tools,
# but force network/auth at our fake endpoint.
**os.environ,
"ANTHROPIC_BASE_URL": fake_url,
"ANTHROPIC_API_KEY": "sk-test-fake-key-not-real",
# Disable any features that would phone home to a different host
# mid-test (telemetry, plugin marketplace fetch).
"DISABLE_TELEMETRY": "1",
"CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1",
}
# The CLI accepts stream-json input on stdin in `query` mode. A
# minimal user-message envelope is enough to trigger an API call.
stdin_payload = (
json.dumps(
{
"type": "user",
"message": {"role": "user", "content": "hello"},
}
)
+ "\n"
)
proc = await asyncio.create_subprocess_exec(
str(cli_path),
"--output-format",
"stream-json",
"--input-format",
"stream-json",
"--verbose",
"--print",
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
try:
assert proc.stdin is not None
proc.stdin.write(stdin_payload.encode("utf-8"))
await proc.stdin.drain()
proc.stdin.close()
stdout_bytes, stderr_bytes = await asyncio.wait_for(
proc.communicate(), timeout=timeout_seconds
)
except (asyncio.TimeoutError, TimeoutError):
# Best-effort kill — we already have whatever requests the CLI
# managed to send before stalling.
try:
proc.kill()
except ProcessLookupError:
pass
# Reap the process to avoid leaving a zombie + open pipe FDs.
# Without this the asyncio transport keeps the stdout/stderr
# pipes alive until the loop exits, and in CI loops where this
# test runs many times the file-descriptor count creeps up.
try:
await asyncio.wait_for(proc.wait(), timeout=5.0)
except (asyncio.TimeoutError, TimeoutError):
pass
stdout_bytes, stderr_bytes = b"", b""
return (
proc.returncode if proc.returncode is not None else -1,
stdout_bytes.decode("utf-8", errors="replace"),
stderr_bytes.decode("utf-8", errors="replace"),
)
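The kill-then-reap pattern can be exercised with a child process that simply sleeps. This is a sketch under the assumption that the current Python interpreter can be spawned as a subprocess. `run_with_deadline` is a hypothetical helper, not the function used by the test.

```python
import asyncio
import sys


async def run_with_deadline(argv: list[str], timeout_seconds: float):
    """Run argv; on timeout, kill the child and wait for it to be reaped."""
    proc = await asyncio.create_subprocess_exec(
        *argv,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, _stderr = await asyncio.wait_for(
            proc.communicate(), timeout=timeout_seconds
        )
        timed_out = False
    except (asyncio.TimeoutError, TimeoutError):
        try:
            proc.kill()
        except ProcessLookupError:
            pass  # the child exited on its own in the meantime
        # Reap the child so no zombie or open pipe is left behind.
        await asyncio.wait_for(proc.wait(), timeout=5.0)
        stdout, timed_out = b"", True
    return proc.returncode, stdout, timed_out


async def main():
    sleeper = [sys.executable, "-c", "import time; time.sleep(60)"]
    quick = [sys.executable, "-c", "print('done')"]
    slow = await run_with_deadline(sleeper, timeout_seconds=0.5)
    fast = await run_with_deadline(quick, timeout_seconds=30.0)
    return slow, fast


slow_result, quick_result = asyncio.run(main())
```

After the reap, `proc.returncode` is populated even on the timeout path, which is why the real helper only falls back to `-1` when the wait itself times out.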
# ---------------------------------------------------------------------------
# The actual test
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cli_does_not_send_openrouter_incompatible_features(caplog):
"""End-to-end OpenRouter compatibility reproduction.
Spawns the bundled (or overridden) Claude Code CLI against a fake
Anthropic API server, captures every request body it sends, and
asserts that none of them contain the two known OpenRouter-breaking
features (`tool_reference` content blocks or the
`context-management-2025-06-27` beta header).
Why this matters: pinning the CLI version via
``test_bundled_cli_version_is_known_good_against_openrouter`` only
catches accidental SDK bumps — it doesn't tell us *why* the new
version would fail. This test reproduces the exact mechanism so
bisecting via CI commits gives an actionable signal.
"""
cli_path = _resolve_cli_path()
if cli_path is None or not cli_path.is_file():
pytest.skip(
"No Claude Code CLI binary available (neither bundled nor "
"overridden via CLAUDE_AGENT_CLI_PATH / "
"CHAT_CLAUDE_AGENT_CLI_PATH); cannot reproduce."
)
captured: list[_CapturedRequest] = []
runner, port = await _start_fake_anthropic_server(captured)
try:
returncode, stdout, stderr = await _run_cli_against_fake_server(
cli_path=cli_path,
fake_server_port=port,
timeout_seconds=30.0,
)
finally:
await runner.cleanup()
# We don't assert the CLI's exit code — depending on the CLI version
# and what we send back, the CLI may exit non-zero after a single
# successful round-trip. All we care about is that the captured
# request bodies don't contain the forbidden patterns.
logger.info(
"CLI exited rc=%d; captured %d requests; stdout=%d bytes; stderr=%d bytes",
returncode,
len(captured),
len(stdout),
len(stderr),
)
if not captured:
pytest.skip(
"Claude Code CLI did not make any HTTP requests to the fake server "
f"(rc={returncode}). The CLI may have failed before reaching "
f"the network — stderr tail: {stderr[-500:]!r}. "
"Nothing to assert; treating as inconclusive rather than "
"either passing or failing."
)
all_findings: list[str] = []
for req in captured:
findings = _scan_request_for_forbidden_patterns(req.body, req.headers)
if findings:
all_findings.extend(f"{req.path}: {finding}" for finding in findings)
assert not all_findings, (
f"Bundled Claude Code CLI sent OpenRouter-incompatible features in "
f"{len(all_findings)} request(s):\n - "
+ "\n - ".join(all_findings)
+ "\n\nThis is the regression that prevents us from upgrading "
"`claude-agent-sdk` above 0.1.45. See "
"https://github.com/Significant-Gravitas/AutoGPT/pull/12294 and "
"https://github.com/anthropics/claude-agent-sdk-python/issues/789. "
"If you intended to upgrade, you must use a known-good CLI binary "
"via `claude_agent_cli_path` (env: `CLAUDE_AGENT_CLI_PATH` or "
"`CHAT_CLAUDE_AGENT_CLI_PATH`) instead of the bundled one."
)
def test_subprocess_module_available():
"""Sentinel test: the subprocess module must be importable so the
main reproduction test can spawn the CLI. This only verifies the
import; it does not prove that the runner permits spawning processes."""
assert subprocess.__name__ == "subprocess"
# ---------------------------------------------------------------------------
# Pure helper unit tests — pin the forbidden-pattern detection so any
# future drift in the scanner is caught fast, even when the slow
# end-to-end CLI subprocess test isn't runnable.
# ---------------------------------------------------------------------------
class TestScanRequestForForbiddenPatterns:
def test_clean_body_returns_empty_findings(self):
body = '{"model": "claude-opus-4.6", "messages": [{"role": "user", "content": "hi"}]}'
assert _scan_request_for_forbidden_patterns(body, {}) == []
def test_detects_tool_reference_in_body(self):
body = (
'{"messages": [{"role": "user", "content": ['
'{"type": "tool_reference", "tool_name": "find"}'
"]}]}"
)
findings = _scan_request_for_forbidden_patterns(body, {})
assert len(findings) == 1
assert "tool_reference" in findings[0]
assert "PR #12294" in findings[0]
def test_detects_context_management_in_body(self):
body = '{"betas": ["context-management-2025-06-27"]}'
findings = _scan_request_for_forbidden_patterns(body, {})
assert len(findings) == 1
assert "context-management-2025-06-27" in findings[0]
assert "#789" in findings[0]
def test_detects_context_management_in_anthropic_beta_header(self):
findings = _scan_request_for_forbidden_patterns(
body_text="{}",
headers={"anthropic-beta": "context-management-2025-06-27"},
)
assert len(findings) == 1
assert "anthropic-beta" in findings[0]
def test_detects_context_management_in_uppercase_header_name(self):
# HTTP header names are case-insensitive — make sure the
# scanner handles a server that didn't normalise names.
findings = _scan_request_for_forbidden_patterns(
body_text="{}",
headers={"Anthropic-Beta": "context-management-2025-06-27, other"},
)
assert len(findings) == 1
def test_ignores_unrelated_header_values(self):
findings = _scan_request_for_forbidden_patterns(
body_text="{}",
headers={
"authorization": "Bearer secret",
"anthropic-beta": "fine-grained-tool-streaming-2025",
},
)
assert findings == []
def test_detects_both_patterns_simultaneously(self):
body = (
'{"betas": ["context-management-2025-06-27"], '
'"messages": [{"role": "user", "content": ['
'{"type": "tool_reference", "tool_name": "find"}'
"]}]}"
)
findings = _scan_request_for_forbidden_patterns(body, {})
# Both patterns hit, in stable order: tool_reference then betas.
assert len(findings) == 2
assert "tool_reference" in findings[0]
assert "context-management-2025-06-27" in findings[1]
class TestResolveCliPath:
def test_honours_explicit_env_var_when_file_exists(self, tmp_path, monkeypatch):
fake_cli = tmp_path / "fake-claude"
fake_cli.write_text("#!/bin/sh\necho fake\n")
fake_cli.chmod(0o755)
monkeypatch.delenv("CHAT_CLAUDE_AGENT_CLI_PATH", raising=False)
monkeypatch.setenv("CLAUDE_AGENT_CLI_PATH", str(fake_cli))
resolved = _resolve_cli_path()
assert resolved == fake_cli
def test_honours_chat_prefixed_env_var_when_file_exists(
self, tmp_path, monkeypatch
):
"""The Pydantic ``CHAT_`` prefix variant is also honoured.
Mirrors ``ChatConfig.get_claude_agent_cli_path`` which accepts
either ``CHAT_CLAUDE_AGENT_CLI_PATH`` (prefix applied by
``pydantic_settings``) or the unprefixed ``CLAUDE_AGENT_CLI_PATH``
form documented in the PR and field docstring.
"""
fake_cli = tmp_path / "fake-claude-prefixed"
fake_cli.write_text("#!/bin/sh\necho fake\n")
fake_cli.chmod(0o755)
monkeypatch.delenv("CLAUDE_AGENT_CLI_PATH", raising=False)
monkeypatch.setenv("CHAT_CLAUDE_AGENT_CLI_PATH", str(fake_cli))
resolved = _resolve_cli_path()
assert resolved == fake_cli
def test_returns_none_when_env_var_points_to_missing_file(self, monkeypatch):
monkeypatch.delenv("CHAT_CLAUDE_AGENT_CLI_PATH", raising=False)
monkeypatch.setenv("CLAUDE_AGENT_CLI_PATH", "/nonexistent/path/to/claude")
# When the override is set but the file is missing, the resolver
# returns ``None`` outright — it does NOT silently fall through to
# the bundled binary, because doing so would defeat the purpose of
# the override (the operator explicitly asked for a specific path).
# The strict ``is None`` assertion catches any future regression
# that swaps this fail-loud behaviour for a silent fallback.
resolved = _resolve_cli_path()
assert resolved is None
def test_falls_back_to_bundled_when_env_var_unset(self, monkeypatch):
monkeypatch.delenv("CLAUDE_AGENT_CLI_PATH", raising=False)
monkeypatch.delenv("CHAT_CLAUDE_AGENT_CLI_PATH", raising=False)
# Unlike the missing-file case above, this cannot be asserted
# strictly: the resolver returns the bundled path or None, depending
# on what's installed in the test env.
resolved = _resolve_cli_path()
assert resolved is None or resolved.is_file()


@@ -226,111 +226,6 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
assert was_compacted is False # mock returns False
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_prevents_pending_duplication():
"""session_msg_ceiling stops pending messages from leaking into the gap.
Scenario: transcript covers 2 messages, session has 2 historical + 1 current
+ 2 pending drained at turn start. Without the ceiling the gap would include
the pending messages AND current_message already has them → duplication.
With session_msg_ceiling=3 (pre-drain count) the gap slice is empty and
only current_message carries the pending content.
"""
# session.messages after drain: [hist1, hist2, current_msg, pending1, pending2]
session = _make_session(
[
ChatMessage(role="user", content="hist1"),
ChatMessage(role="assistant", content="hist2"),
ChatMessage(role="user", content="current msg with pending1 pending2"),
ChatMessage(role="user", content="pending1"),
ChatMessage(role="user", content="pending2"),
]
)
# transcript covers hist1+hist2 (2 messages); pre-drain count was 3 (includes current_msg)
result, was_compacted = await _build_query_message(
"current msg with pending1 pending2",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
session_msg_ceiling=3, # len(session.messages) before drain
)
# Gap should be empty (transcript_msg_count == ceiling - 1), so no history prepended
assert result == "current msg with pending1 pending2"
assert was_compacted is False
# Pending messages must NOT appear in gap context
assert "pending1" not in result.split("current msg")[0]
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_preserves_real_gap():
"""session_msg_ceiling still surfaces a genuine stale-transcript gap.
Scenario: transcript covers 2 messages, session has 4 historical + 1 current
+ 2 pending. Ceiling = 5 (pre-drain). Real gap = messages 2-3 (hist3, hist4).
"""
session = _make_session(
[
ChatMessage(role="user", content="hist1"),
ChatMessage(role="assistant", content="hist2"),
ChatMessage(role="user", content="hist3"),
ChatMessage(role="assistant", content="hist4"),
ChatMessage(role="user", content="current"),
ChatMessage(role="user", content="pending1"),
ChatMessage(role="user", content="pending2"),
]
)
result, was_compacted = await _build_query_message(
"current",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
session_msg_ceiling=5, # pre-drain: [hist1..hist4, current]
)
# Gap = session.messages[2:4] = [hist3, hist4]
assert "<conversation_history>" in result
assert "hist3" in result
assert "hist4" in result
assert "Now, the user says:\ncurrent" in result
# Pending messages must NOT appear in gap
assert "pending1" not in result
assert "pending2" not in result
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_suppresses_spurious_no_resume_fallback():
"""session_msg_ceiling prevents the no-resume compression fallback from
firing on the first turn of a session when pending messages inflate msg_count.
Scenario: fresh session (1 message) + 1 pending message drained at turn start.
Without the ceiling: msg_count=2 > 1 → fallback triggers → pending message
leaked into history → wrong context sent to model.
With session_msg_ceiling=1 (pre-drain count): effective_count=1, 1 > 1 is False
→ fallback does not trigger → current_message returned as-is.
"""
# session.messages after drain: [current_msg, pending_msg]
session = _make_session(
[
ChatMessage(role="user", content="What is 2 plus 2?"),
ChatMessage(role="user", content="What is 7 plus 7?"), # pending
]
)
result, was_compacted = await _build_query_message(
"What is 2 plus 2?\n\nWhat is 7 plus 7?",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
session_msg_ceiling=1, # pre-drain: only 1 message existed
)
# Should return current_message directly without wrapping in history context
assert result == "What is 2 plus 2?\n\nWhat is 7 plus 7?"
assert was_compacted is False
# Pending question must NOT appear in a spurious history section
assert "<conversation_history>" not in result
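The gap-slice arithmetic exercised by the three `session_msg_ceiling` tests can be sketched in isolation. `gap_slice` is a hypothetical helper written for this illustration. It works on plain strings rather than `ChatMessage` objects and mirrors only the slicing, not compression or context formatting.

```python
from typing import Optional


def gap_slice(
    messages: list[str],
    transcript_msg_count: int,
    session_msg_ceiling: Optional[int] = None,
) -> list[str]:
    """Messages between the transcript's coverage and the current message."""
    # The ceiling, when given, replaces the post-drain length.
    effective_count = (
        session_msg_ceiling if session_msg_ceiling is not None else len(messages)
    )
    if transcript_msg_count < effective_count - 1:
        return messages[transcript_msg_count : effective_count - 1]
    return []


# After the drain: two historical, one current, two pending messages.
drained = ["hist1", "hist2", "current", "pending1", "pending2"]
with_ceiling = gap_slice(drained, transcript_msg_count=2, session_msg_ceiling=3)
without_ceiling = gap_slice(drained, transcript_msg_count=2)

# A genuinely stale transcript: hist3 and hist4 are not covered.
stale = ["hist1", "hist2", "hist3", "hist4", "current", "pending1", "pending2"]
real_gap = gap_slice(stale, transcript_msg_count=2, session_msg_ceiling=5)
```

Without the ceiling the slice reaches into the drained messages, which is the duplication the first test guards against.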
@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
"""When compression actually compacts, was_compacted should be True."""


@@ -1031,12 +1031,6 @@ def _make_sdk_patches(
),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
# Stub pending-message drain so retry tests don't hit Redis.
# Returns an empty list → no mid-turn injection happens.
(
f"{_SVC}.drain_pending_messages",
dict(new_callable=AsyncMock, return_value=[]),
),
]


@@ -196,3 +196,79 @@ def test_sdk_exports_hook_event_type(hook_event: str):
# HookEvent is a Literal type — check that our events are valid values.
# We can't easily inspect Literal at runtime, so just verify the type exists.
assert HookEvent is not None
# ---------------------------------------------------------------------------
# OpenRouter compatibility — bundled CLI version pin
# ---------------------------------------------------------------------------
#
# We're stuck on ``claude-agent-sdk==0.1.45`` (bundled CLI ``2.1.63``)
# because every version above introduces a 400 against OpenRouter:
#
# 1. CLI ``2.1.69`` (= SDK ``0.1.46``) shipped a `tool_reference` content
# block in `tool_result.content` that OpenRouter's stricter Zod
# validation rejects. See PR
# https://github.com/Significant-Gravitas/AutoGPT/pull/12294 for the
# forensic write-up that originally pinned us. CLI ``2.1.70`` added
# proxy detection that *should* disable the offending block, but two
# later attempts (Dependabot bumps to 0.1.55 / 0.1.56) still failed.
#
# 2. A second regression — the ``context-management-2025-06-27`` beta
# header — appeared in some CLI version after ``2.1.91``. Tracked
# upstream at
# https://github.com/anthropics/claude-agent-sdk-python/issues/789
# (still open at the time of writing, no upstream PR yet).
#
# This test is the cheapest possible regression guard: it pins the
# bundled CLI to a known-good version. If anyone bumps
# ``claude-agent-sdk`` in ``pyproject.toml``, the bundled CLI version in
# ``_cli_version.py`` will change and this test will fail with a clear
# message that points the next person at the OpenRouter compat issue
# instead of letting them silently re-break production.
#
# Workaround for actually upgrading: set the
# ``claude_agent_cli_path`` config option (or the matching env var) to
# point at a separately-installed Claude Code CLI binary at a known-good
# version, so the SDK Python API surface and the CLI binary version can
# be picked independently.
# CLI versions verified to work against OpenRouter from production
# traffic. When upstream lands a fix and we can confirm a newer version
# works, add it to this set rather than blanket-removing the assertion.
_KNOWN_GOOD_BUNDLED_CLI_VERSIONS: frozenset[str] = frozenset({"2.1.63"})
def test_bundled_cli_version_is_known_good_against_openrouter():
"""Pin the bundled CLI version so accidental SDK bumps cause a loud,
fast failure with a pointer to the OpenRouter compatibility issue."""
from claude_agent_sdk._cli_version import __cli_version__
assert __cli_version__ in _KNOWN_GOOD_BUNDLED_CLI_VERSIONS, (
f"Bundled Claude Code CLI version is {__cli_version__!r}, which is "
f"not in the OpenRouter-known-good set "
f"{sorted(_KNOWN_GOOD_BUNDLED_CLI_VERSIONS)!r}. "
"If you intentionally bumped `claude-agent-sdk`, verify the new "
"bundled CLI works with OpenRouter against the reproduction test "
"in `cli_openrouter_compat_test.py`, then add the new CLI version "
"to `_KNOWN_GOOD_BUNDLED_CLI_VERSIONS`. If you cannot make the "
"bundled CLI work, set `claude_agent_cli_path` to a known-good "
"binary instead and skip the bundled one. See "
"https://github.com/anthropics/claude-agent-sdk-python/issues/789 "
"and https://github.com/Significant-Gravitas/AutoGPT/pull/12294."
)
def test_sdk_exposes_cli_path_option():
"""Sanity-check that the SDK still exposes the `cli_path` option we use
for the OpenRouter workaround. If upstream removes it we need to know."""
import inspect
from claude_agent_sdk import ClaudeAgentOptions
sig = inspect.signature(ClaudeAgentOptions)
assert "cli_path" in sig.parameters, (
"ClaudeAgentOptions no longer accepts `cli_path` — our "
"claude_agent_cli_path config override would be silently ignored. "
"Either find an alternative override mechanism or pin the SDK to a "
"version that still exposes it."
)
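The signature check above works on any class whose constructor parameters are introspectable. A minimal sketch follows, using a stand-in dataclass: `FakeOptions` and `accepts_option` are hypothetical names for illustration and say nothing about the real fields of `ClaudeAgentOptions`.

```python
import inspect
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeOptions:
    """Stand-in for an options class (hypothetical, not the SDK's)."""

    model: Optional[str] = None
    cli_path: Optional[str] = None


def accepts_option(options_cls: type, name: str) -> bool:
    """Return True when the class constructor takes a parameter called name."""
    return name in inspect.signature(options_cls).parameters


has_cli_path = accepts_option(FakeOptions, "cli_path")
has_unknown = accepts_option(FakeOptions, "binary_path")
```

Checking the signature instead of instantiating the class keeps the guard free of side effects and independent of required arguments.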


@@ -34,10 +34,6 @@ from opentelemetry import trace as otel_trace
from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
from backend.copilot.pending_messages import (
drain_pending_messages,
format_pending_as_user_message,
)
from backend.copilot.permissions import apply_tool_permissions
from backend.copilot.rate_limit import get_user_tier
from backend.copilot.transcript import (
@@ -959,33 +955,17 @@ async def _build_query_message(
use_resume: bool,
transcript_msg_count: int,
session_id: str,
*,
session_msg_ceiling: int | None = None,
) -> tuple[str, bool]:
"""Build the query message with appropriate context.
Args:
session_msg_ceiling: If provided, treat ``session.messages`` as if it
only has this many entries when computing the gap slice. Pass
``len(session.messages)`` captured *before* appending any pending
messages so that mid-turn drains do not skew the gap calculation
and cause pending messages to be duplicated in both the gap context
and ``current_message``.
Returns:
Tuple of (query_message, was_compacted).
"""
msg_count = len(session.messages)
# Use the ceiling if supplied (prevents pending-message duplication when
# messages were appended to session.messages after the drain but before
# this function is called).
effective_count = (
session_msg_ceiling if session_msg_ceiling is not None else msg_count
)
if use_resume and transcript_msg_count > 0:
if transcript_msg_count < effective_count - 1:
gap = session.messages[transcript_msg_count : effective_count - 1]
if transcript_msg_count < msg_count - 1:
gap = session.messages[transcript_msg_count:-1]
compressed, was_compressed = await _compress_messages(gap)
gap_context = _format_conversation_context(compressed)
if gap_context:
@@ -1001,14 +981,12 @@ async def _build_query_message(
f"{gap_context}\n\nNow, the user says:\n{current_message}",
was_compressed,
)
elif not use_resume and effective_count > 1:
elif not use_resume and msg_count > 1:
logger.warning(
f"[SDK] Using compression fallback for session "
f"{session_id} ({effective_count} messages) — no transcript for --resume"
)
compressed, was_compressed = await _compress_messages(
session.messages[: effective_count - 1]
f"{session_id} ({msg_count} messages) — no transcript for --resume"
)
compressed, was_compressed = await _compress_messages(session.messages[:-1])
history_context = _format_conversation_context(compressed)
if history_context:
return (
@@ -2064,7 +2042,6 @@ async def stream_chat_completion_sdk(
async def _fetch_transcript():
"""Download transcript for --resume if applicable."""
assert session is not None # narrowed at line 1898
if not (
config.claude_agent_use_resume and user_id and len(session.messages) > 1
):
@@ -2268,6 +2245,12 @@ async def stream_chat_completion_sdk(
sdk_options_kwargs["env"] = sdk_env
if use_resume and resume_file:
sdk_options_kwargs["resume"] = resume_file
# Optional explicit Claude Code CLI binary path (decouples the
# bundled SDK version from the CLI version we run — needed because
# the CLI bundled in 0.1.46+ is broken against OpenRouter). Falls
# back to the bundled binary when unset.
if config.claude_agent_cli_path:
sdk_options_kwargs["cli_path"] = config.claude_agent_cli_path
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] # dynamic kwargs
@@ -2300,61 +2283,6 @@ async def stream_chat_completion_sdk(
if last_user:
current_message = last_user[-1].content or ""
# Capture the message count *before* draining so _build_query_message
# can compute the gap slice without including the newly-drained pending
# messages. Pending messages are both appended to session.messages AND
# concatenated into current_message; without the ceiling the gap slice
# would extend into the pending messages and duplicate them in the
# model's input context (gap_context + current_message both containing
# them).
_pre_drain_msg_count = len(session.messages)
# Drain any messages the user queued via POST /messages/pending
# while the previous turn was running (or since the session was
# idle). Messages are drained ATOMICALLY — one LPOP with count
# removes them all at once, so a concurrent push lands *after*
# the drain and stays queued for the next turn instead of being
# lost between LPOP and clear. File IDs and context are
# preserved via format_pending_as_user_message.
#
# The drained content is concatenated into ``current_message``
# so the SDK CLI sees it in the new user message, AND appended
# directly to ``session.messages`` (no dedup — pending messages are
# atomically-popped from Redis and are never stale-cache duplicates)
# so the durable transcript records it too. Session is persisted
# immediately after the drain so a crash doesn't lose the messages.
# The endpoint deliberately does NOT persist to session.messages —
# Redis is the single source of truth until this drain runs.
pending_at_start = await drain_pending_messages(session_id)
if pending_at_start:
logger.info(
"%s Draining %d pending message(s) at turn start",
log_prefix,
len(pending_at_start),
)
pending_texts: list[str] = [
format_pending_as_user_message(pm)["content"] for pm in pending_at_start
]
for pt in pending_texts:
# Append directly — pending messages are atomically-popped from
# Redis and are never stale-cache duplicates, so the
# maybe_append_user_message dedup is wrong here.
session.messages.append(ChatMessage(role="user", content=pt))
if current_message.strip():
current_message = current_message + "\n\n" + "\n\n".join(pending_texts)
else:
current_message = "\n\n".join(pending_texts)
# Persist immediately so a crash between here and the finally block
# doesn't lose messages that were already drained from Redis.
try:
session = await upsert_chat_session(session)
except Exception as _persist_err:
logger.warning(
"%s Failed to persist drained pending messages: %s",
log_prefix,
_persist_err,
)
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
@@ -2368,7 +2296,6 @@ async def stream_chat_completion_sdk(
use_resume,
transcript_msg_count,
session_id,
session_msg_ceiling=_pre_drain_msg_count,
)
# On the first turn inject user context into the message instead of the
# system prompt — the system prompt is now static (same for all users)
@@ -2506,7 +2433,6 @@ async def stream_chat_completion_sdk(
state.use_resume,
state.transcript_msg_count,
session_id,
session_msg_ceiling=_pre_drain_msg_count,
)
if attachments.hint:
state.query_message = f"{state.query_message}\n\n{attachments.hint}"
@@ -2836,11 +2762,6 @@ async def stream_chat_completion_sdk(
raise
finally:
# Pending messages are drained atomically at the start of each
# turn (see drain_pending_messages call above), so there's
# nothing to clean up here — any message pushed after that
# point belongs to the next turn.
# --- Close OTEL context (with cost attributes) ---
if _otel_ctx is not None:
try:

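Review note: the comments in the hunks above describe two things, the pre-drain ceiling (`_pre_drain_msg_count`) and the concatenation of drained texts into `current_message`. The sketch below illustrates both under simplifying assumptions. It is not the code from this diff: `merge_pending` and `gap_slice` are hypothetical helpers, and plain strings stand in for `ChatMessage` objects.

```python
def merge_pending(current_message: str, pending_texts: list[str]) -> str:
    """Join drained pending texts onto the current user message."""
    if not pending_texts:
        return current_message
    joined = "\n\n".join(pending_texts)
    # A blank current message is replaced outright rather than prefixed.
    if current_message.strip():
        return current_message + "\n\n" + joined
    return joined


def gap_slice(messages: list[str], transcript_count: int, ceiling: int) -> list[str]:
    """Messages the transcript has not seen yet, capped at the pre-drain count."""
    return messages[transcript_count:ceiling]


session_messages = ["u: hello", "a: hi", "u: earlier follow-up"]
pre_drain_count = len(session_messages)  # captured BEFORE the drain

pending = ["also do X", "and Y"]         # stand-in for the drained queue
session_messages.extend(pending)         # the durable transcript records them
current = merge_pending("", pending)     # the model sees them in the new message

gap = gap_slice(session_messages, 2, pre_drain_count)
unbounded_gap = session_messages[2:]     # what the slice would be without the ceiling
```

With the ceiling, `gap` holds only the earlier follow-up. Without it, `unbounded_gap` also contains both pending texts, which would then reach the model twice: once via the gap context and once via `current`.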
@@ -4,10 +4,10 @@ from datetime import datetime, timedelta, timezone
from typing import Any
from prisma.models import PlatformCostLog as PrismaLog
from prisma.types import PlatformCostLogCreateInput
from prisma.models import User as PrismaUser
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
from pydantic import BaseModel
from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
from backend.util.json import SafeJson
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
MICRODOLLARS_PER_USD = 1_000_000
# Dashboard query limits — keep in sync with the SQL queries below
# Dashboard query limits
MAX_PROVIDER_ROWS = 500
MAX_USER_ROWS = 100
@@ -169,53 +169,61 @@ class PlatformCostDashboard(BaseModel):
total_users: int
def _build_where(
def _si(row: dict, field: str) -> int:
"""Extract an integer from a Prisma group_by _sum dict.
Prisma Python serialises BigInt/Int aggregate sums as strings; coerce to int.
"""
return int((row.get("_sum") or {}).get(field) or 0)
def _sf(row: dict, field: str) -> float:
"""Extract a float from a Prisma group_by _sum dict."""
return float((row.get("_sum") or {}).get(field) or 0.0)
def _ca(row: dict) -> int:
"""Extract _count._all from a Prisma group_by row."""
c = row.get("_count") or {}
return int(c.get("_all") or 0) if isinstance(c, dict) else int(c or 0)
def _build_prisma_where(
start: datetime | None,
end: datetime | None,
provider: str | None,
user_id: str | None,
table_alias: str = "",
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
) -> tuple[str, list[Any]]:
prefix = f"{table_alias}." if table_alias else ""
clauses: list[str] = []
params: list[Any] = []
idx = 1
) -> PlatformCostLogWhereInput:
"""Build a Prisma WhereInput for PlatformCostLog filters."""
where: PlatformCostLogWhereInput = {}
if start and end:
where["createdAt"] = {"gte": start, "lte": end}
elif start:
where["createdAt"] = {"gte": start}
elif end:
where["createdAt"] = {"lte": end}
if start:
clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
params.append(start)
idx += 1
if end:
clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
params.append(end)
idx += 1
if provider:
# Provider names are normalized to lowercase at write time so a plain
# equality check is sufficient and the (provider, createdAt) index is used.
clauses.append(f'{prefix}"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id:
clauses.append(f'{prefix}"userId" = ${idx}')
params.append(user_id)
idx += 1
if model:
clauses.append(f'{prefix}"model" = ${idx}')
params.append(model)
idx += 1
if block_name:
clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})')
params.append(block_name)
idx += 1
if tracking_type:
clauses.append(f'{prefix}"trackingType" = ${idx}')
params.append(tracking_type)
idx += 1
where["provider"] = provider.lower()
return (" AND ".join(clauses) if clauses else "TRUE", params)
if user_id:
where["userId"] = user_id
if model:
where["model"] = model
if block_name:
# Case-insensitive match — mirrors the original LOWER() SQL filter.
where["blockName"] = {"equals": block_name, "mode": "insensitive"}
if tracking_type:
where["trackingType"] = tracking_type
return where
@cached(ttl_seconds=30)
@@ -241,110 +249,107 @@ async def get_platform_cost_dashboard(
"""
if start is None:
start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_p, params_p = _build_where(
start, end, provider, user_id, "p", model, block_name, tracking_type
where = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type
)
by_provider_rows, by_user_rows, total_user_rows, total_agg_rows = (
sum_fields = {
"costMicrodollars": True,
"inputTokens": True,
"outputTokens": True,
"cacheReadTokens": True,
"cacheCreationTokens": True,
"duration": True,
"trackingAmount": True,
}
# Run all four aggregation queries in parallel.
by_provider_groups, by_user_groups, total_user_groups, total_agg_groups = (
await asyncio.gather(
query_raw_with_schema(
f"""
SELECT
p."provider",
p."trackingType" AS tracking_type,
p."model",
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
COALESCE(SUM(p."cacheReadTokens"), 0)::bigint AS total_cache_read_tokens,
COALESCE(SUM(p."cacheCreationTokens"), 0)::bigint AS total_cache_creation_tokens,
COALESCE(SUM(p."duration"), 0)::float AS total_duration,
COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
GROUP BY p."provider", p."trackingType", p."model"
ORDER BY total_cost DESC
LIMIT {MAX_PROVIDER_ROWS}
""",
*params_p,
# (provider, trackingType, model) aggregation — no ORDER BY in ORM;
# sort by total cost descending in Python after fetch.
PrismaLog.prisma().group_by(
by=["provider", "trackingType", "model"],
where=where,
sum=sum_fields,
count=True,
),
query_raw_with_schema(
f"""
SELECT
p."userId" AS user_id,
u."email",
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_p}
GROUP BY p."userId", u."email"
ORDER BY total_cost DESC
LIMIT {MAX_USER_ROWS}
""",
*params_p,
# userId aggregation — emails fetched separately below.
PrismaLog.prisma().group_by(
by=["userId"],
where=where,
sum=sum_fields,
count=True,
),
query_raw_with_schema(
f"""
SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
""",
*params_p,
# Distinct user count: group by userId, count groups.
PrismaLog.prisma().group_by(
by=["userId"],
where=where,
count=True,
),
# Separate aggregate query so dashboard totals are never derived
# from the capped by_provider_rows list. With model-level grouping,
# MAX_PROVIDER_ROWS is hit more easily; summing the capped rows
# would silently undercount once >500 (provider, type, model) exist.
query_raw_with_schema(
f"""
SELECT
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
""",
*params_p,
# Total aggregate: group by provider (no limit) to sum across all
# matching rows. Summed in Python to get grand totals.
PrismaLog.prisma().group_by(
by=["provider"],
where=where,
sum={"costMicrodollars": True},
count=True,
),
)
)
# Use the exact COUNT(DISTINCT userId) so total_users is not capped at
# MAX_USER_ROWS (which would silently report 100 for >100 active users).
total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0
total_cost = int(total_agg_rows[0]["total_cost"]) if total_agg_rows else 0
total_requests = int(total_agg_rows[0]["request_count"]) if total_agg_rows else 0
# Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.
by_provider_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True)
by_provider_groups = by_provider_groups[:MAX_PROVIDER_ROWS]
# Sort by_user by total cost descending and cap at MAX_USER_ROWS.
by_user_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True)
by_user_groups = by_user_groups[:MAX_USER_ROWS]
# Batch-fetch emails for the users in by_user.
user_ids = [r["userId"] for r in by_user_groups if r.get("userId") is not None]
email_by_user_id: dict[str, str | None] = {}
if user_ids:
users = await PrismaUser.prisma().find_many(
where={"id": {"in": user_ids}},
)
email_by_user_id = {u.id: u.email for u in users}
# Total distinct users — exclude the NULL-userId group (deleted users).
total_users = len([g for g in total_user_groups if g.get("userId") is not None])
# Grand totals — sum across all provider groups (no LIMIT applied above).
total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups)
total_requests = sum(_ca(r) for r in total_agg_groups)
return PlatformCostDashboard(
by_provider=[
ProviderCostSummary(
provider=r["provider"],
tracking_type=r.get("tracking_type"),
tracking_type=r.get("trackingType"),
model=r.get("model"),
total_cost_microdollars=r["total_cost"],
total_input_tokens=r["total_input_tokens"],
total_output_tokens=r["total_output_tokens"],
total_cache_read_tokens=r.get("total_cache_read_tokens", 0),
total_cache_creation_tokens=r.get("total_cache_creation_tokens", 0),
total_duration_seconds=r.get("total_duration", 0.0),
total_tracking_amount=r.get("total_tracking_amount", 0.0),
request_count=r["request_count"],
total_cost_microdollars=_si(r, "costMicrodollars"),
total_input_tokens=_si(r, "inputTokens"),
total_output_tokens=_si(r, "outputTokens"),
total_cache_read_tokens=_si(r, "cacheReadTokens"),
total_cache_creation_tokens=_si(r, "cacheCreationTokens"),
total_duration_seconds=_sf(r, "duration"),
total_tracking_amount=_sf(r, "trackingAmount"),
request_count=_ca(r),
)
for r in by_provider_rows
for r in by_provider_groups
],
by_user=[
UserCostSummary(
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
total_cost_microdollars=r["total_cost"],
total_input_tokens=r["total_input_tokens"],
total_output_tokens=r["total_output_tokens"],
request_count=r["request_count"],
user_id=r.get("userId"),
email=_mask_email(email_by_user_id.get(r.get("userId") or "")),
total_cost_microdollars=_si(r, "costMicrodollars"),
total_input_tokens=_si(r, "inputTokens"),
total_output_tokens=_si(r, "outputTokens"),
request_count=_ca(r),
)
for r in by_user_rows
for r in by_user_groups
],
total_cost_microdollars=total_cost,
total_requests=total_requests,
@@ -365,73 +370,41 @@ async def get_platform_cost_logs(
) -> tuple[list[CostLogRow], int]:
if start is None:
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_sql, params = _build_where(
start, end, provider, user_id, "p", model, block_name, tracking_type
)
where = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type
)
offset = (page - 1) * page_size
limit_idx = len(params) + 1
offset_idx = len(params) + 2
count_rows, rows = await asyncio.gather(
query_raw_with_schema(
f"""
SELECT COUNT(*)::bigint AS cnt
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_sql}
""",
*params,
),
query_raw_with_schema(
f"""
SELECT
p."id",
p."createdAt" AS created_at,
p."userId" AS user_id,
u."email",
p."graphExecId" AS graph_exec_id,
p."nodeExecId" AS node_exec_id,
p."blockName" AS block_name,
p."provider",
p."trackingType" AS tracking_type,
p."costMicrodollars" AS cost_microdollars,
p."inputTokens" AS input_tokens,
p."outputTokens" AS output_tokens,
p."cacheReadTokens" AS cache_read_tokens,
p."cacheCreationTokens" AS cache_creation_tokens,
p."duration",
p."model"
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_sql}
ORDER BY p."createdAt" DESC, p."id" DESC
LIMIT ${limit_idx} OFFSET ${offset_idx}
""",
*params,
page_size,
offset,
total, rows = await asyncio.gather(
PrismaLog.prisma().count(where=where),
PrismaLog.prisma().find_many(
where=where,
include={"User": True},
order=[{"createdAt": "desc"}, {"id": "desc"}],
take=page_size,
skip=offset,
),
)
total = count_rows[0]["cnt"] if count_rows else 0
logs = [
CostLogRow(
id=r["id"],
created_at=r["created_at"],
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
graph_exec_id=r.get("graph_exec_id"),
node_exec_id=r.get("node_exec_id"),
block_name=r["block_name"],
provider=r["provider"],
tracking_type=r.get("tracking_type"),
cost_microdollars=r.get("cost_microdollars"),
input_tokens=r.get("input_tokens"),
output_tokens=r.get("output_tokens"),
cache_read_tokens=r.get("cache_read_tokens"),
cache_creation_tokens=r.get("cache_creation_tokens"),
duration=r.get("duration"),
model=r.get("model"),
id=r.id,
created_at=r.createdAt,
user_id=r.userId,
email=_mask_email(r.User.email if r.User else None),
graph_exec_id=r.graphExecId,
node_exec_id=r.nodeExecId,
block_name=r.blockName or "",
provider=r.provider,
tracking_type=r.trackingType,
cost_microdollars=r.costMicrodollars,
input_tokens=r.inputTokens,
output_tokens=r.outputTokens,
cache_read_tokens=getattr(r, "cacheReadTokens", None),
cache_creation_tokens=getattr(r, "cacheCreationTokens", None),
duration=r.duration,
model=r.model,
)
for r in rows
]
@@ -457,38 +430,16 @@ async def get_platform_cost_logs_for_export(
"""
if start is None:
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_sql, params = _build_where(
start, end, provider, user_id, "p", model, block_name, tracking_type
)
limit_idx = len(params) + 1
rows = await query_raw_with_schema(
f"""
SELECT
p."id",
p."createdAt" AS created_at,
p."userId" AS user_id,
u."email",
p."graphExecId" AS graph_exec_id,
p."nodeExecId" AS node_exec_id,
p."blockName" AS block_name,
p."provider",
p."trackingType" AS tracking_type,
p."costMicrodollars" AS cost_microdollars,
p."inputTokens" AS input_tokens,
p."outputTokens" AS output_tokens,
p."cacheReadTokens" AS cache_read_tokens,
p."cacheCreationTokens" AS cache_creation_tokens,
p."duration",
p."model"
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_sql}
ORDER BY p."createdAt" DESC, p."id" DESC
LIMIT ${limit_idx}
""",
*params,
EXPORT_MAX_ROWS + 1,
where = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type
)
rows = await PrismaLog.prisma().find_many(
where=where,
include={"User": True},
order=[{"createdAt": "desc"}, {"id": "desc"}],
take=EXPORT_MAX_ROWS + 1,
)
truncated = len(rows) > EXPORT_MAX_ROWS
@@ -496,22 +447,80 @@ async def get_platform_cost_logs_for_export(
return [
CostLogRow(
id=r["id"],
created_at=r["created_at"],
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
graph_exec_id=r.get("graph_exec_id"),
node_exec_id=r.get("node_exec_id"),
block_name=r["block_name"],
provider=r["provider"],
tracking_type=r.get("tracking_type"),
cost_microdollars=r.get("cost_microdollars"),
input_tokens=r.get("input_tokens"),
output_tokens=r.get("output_tokens"),
cache_read_tokens=r.get("cache_read_tokens"),
cache_creation_tokens=r.get("cache_creation_tokens"),
duration=r.get("duration"),
model=r.get("model"),
id=r.id,
created_at=r.createdAt,
user_id=r.userId,
email=_mask_email(r.User.email if r.User else None),
graph_exec_id=r.graphExecId,
node_exec_id=r.nodeExecId,
block_name=r.blockName or "",
provider=r.provider,
tracking_type=r.trackingType,
cost_microdollars=r.costMicrodollars,
input_tokens=r.inputTokens,
output_tokens=r.outputTokens,
cache_read_tokens=getattr(r, "cacheReadTokens", None),
cache_creation_tokens=getattr(r, "cacheCreationTokens", None),
duration=r.duration,
model=r.model,
)
for r in rows
], truncated
# ---------------------------------------------------------------------------
# Helpers kept for backward-compatibility with existing tests.
# New code should not use these — use _build_prisma_where instead.
# ---------------------------------------------------------------------------
def _build_where(
start: datetime | None,
end: datetime | None,
provider: str | None,
user_id: str | None,
table_alias: str = "",
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
) -> tuple[str, list[Any]]:
"""Legacy SQL WHERE builder — retained so existing unit tests still pass.
Only used by tests that verify the SQL-string generation logic. All
production code uses _build_prisma_where instead.
"""
prefix = f"{table_alias}." if table_alias else ""
clauses: list[str] = []
params: list[Any] = []
idx = 1
if start:
clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
params.append(start)
idx += 1
if end:
clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
params.append(end)
idx += 1
if provider:
clauses.append(f'{prefix}"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id:
clauses.append(f'{prefix}"userId" = ${idx}')
params.append(user_id)
idx += 1
if model:
clauses.append(f'{prefix}"model" = ${idx}')
params.append(model)
idx += 1
if block_name:
clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})')
params.append(block_name)
idx += 1
if tracking_type:
clauses.append(f'{prefix}"trackingType" = ${idx}')
params.append(tracking_type)
idx += 1
return (" AND ".join(clauses) if clauses else "TRUE", params)

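Review note: the dashboard now sorts and caps group rows in Python and derives grand totals from a separate, uncapped query. The sketch below shows why the totals must not be computed from the capped list. The helper names `sum_int` and `count_all` and the sample rows are assumptions for illustration; the row shape (`_sum`, `_count`) follows the helpers in the diff.

```python
def sum_int(row: dict, field: str) -> int:
    """Read an integer from a group row's _sum dict, tolerating strings and None."""
    return int((row.get("_sum") or {}).get(field) or 0)


def count_all(row: dict) -> int:
    """Read _count._all from a group row, or a bare integer count."""
    c = row.get("_count") or {}
    return int(c.get("_all") or 0) if isinstance(c, dict) else int(c or 0)


groups = [
    {"provider": "openai", "_sum": {"costMicrodollars": "1500"}, "_count": {"_all": 2}},
    {"provider": "anthropic", "_sum": {"costMicrodollars": 4000}, "_count": {"_all": 5}},
    {"provider": "other", "_sum": {"costMicrodollars": 250}, "_count": {"_all": 1}},
]
MAX_ROWS = 2  # stand-in for MAX_PROVIDER_ROWS

# Grand totals come from every group, before any cap is applied.
total_cost = sum(sum_int(g, "costMicrodollars") for g in groups)
total_requests = sum(count_all(g) for g in groups)

# The displayed list is sorted by cost descending, then capped.
top = sorted(groups, key=lambda g: sum_int(g, "costMicrodollars"), reverse=True)[:MAX_ROWS]
capped_cost = sum(sum_int(g, "costMicrodollars") for g in top)
```

Here `capped_cost` falls short of `total_cost` by the cost of the dropped group, which is the undercount the separate total query avoids.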
@@ -1,7 +1,7 @@
"""Unit tests for helpers and async functions in platform_cost module."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma import Json
@@ -224,6 +224,41 @@ class TestLogPlatformCostSafe:
mock_create.assert_awaited_once()
def _make_group_by_row(
provider: str = "openai",
tracking_type: str | None = "tokens",
model: str | None = None,
cost: int = 5000,
input_tokens: int = 1000,
output_tokens: int = 500,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
duration: float = 10.5,
tracking_amount: float = 0.0,
count: int = 3,
user_id: str | None = None,
) -> dict:
row: dict = {
"_sum": {
"costMicrodollars": cost,
"inputTokens": input_tokens,
"outputTokens": output_tokens,
"cacheReadTokens": cache_read_tokens,
"cacheCreationTokens": cache_creation_tokens,
"duration": duration,
"trackingAmount": tracking_amount,
},
"_count": {"_all": count},
}
if user_id is not None:
row["userId"] = user_id
else:
row["provider"] = provider
row["trackingType"] = tracking_type
row["model"] = model
return row
class TestGetPlatformCostDashboard:
def setup_method(self):
# @cached stores results in-process; clear between tests to avoid bleed.
@@ -231,35 +266,44 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_returns_dashboard_with_data(self):
provider_rows = [
{
"provider": "openai",
"tracking_type": "tokens",
"total_cost": 5000,
"total_input_tokens": 1000,
"total_output_tokens": 500,
"total_duration": 10.5,
"request_count": 3,
}
]
user_rows = [
{
"user_id": "u1",
"email": "a@b.com",
"total_cost": 5000,
"total_input_tokens": 1000,
"total_output_tokens": 500,
"request_count": 3,
}
]
# Dashboard runs 4 queries: by_provider, by_user, COUNT(DISTINCT userId),
# and a separate total aggregate (total_cost + request_count with no LIMIT).
agg_rows = [{"total_cost": 5000, "request_count": 3}]
mock_query = AsyncMock(
side_effect=[provider_rows, user_rows, [{"cnt": 1}], agg_rows]
provider_row = _make_group_by_row(
provider="openai",
tracking_type="tokens",
cost=5000,
input_tokens=1000,
output_tokens=500,
duration=10.5,
count=3,
)
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
user_row = _make_group_by_row(user_id="u1", cost=5000, count=3)
mock_user = MagicMock()
mock_user.id = "u1"
mock_user.email = "a@b.com"
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[{"userId": "u1"}], # distinct users
[provider_row], # total agg
]
)
mock_actions.find_many = AsyncMock(return_value=[mock_user])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 5000
assert dashboard.total_requests == 3
assert dashboard.total_users == 1
@@ -271,10 +315,67 @@ class TestGetPlatformCostDashboard:
assert dashboard.by_user[0].email == "a***@b.com"
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_query = AsyncMock(side_effect=[[], [], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
async def test_cache_tokens_aggregated_not_hardcoded(self):
"""cache_read_tokens and cache_creation_tokens must be read from the
DB aggregation, not hardcoded to 0 (regression guard for Sentry report)."""
provider_row = _make_group_by_row(
provider="anthropic",
tracking_type="tokens",
cost=1000,
input_tokens=800,
output_tokens=200,
cache_read_tokens=400,
cache_creation_tokens=100,
count=1,
)
user_row = _make_group_by_row(user_id="u2", cost=1000, count=1)
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[{"userId": "u2"}], # distinct users
[provider_row], # total agg
]
)
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
):
dashboard = await get_platform_cost_dashboard()
assert len(dashboard.by_provider) == 1
row = dashboard.by_provider[0]
assert row.total_cache_read_tokens == 400
assert row.total_cache_creation_tokens == 100
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 0
assert dashboard.total_requests == 0
assert dashboard.total_users == 0
@@ -284,160 +385,228 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_passes_filters_to_queries(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_query = AsyncMock(side_effect=[[], [], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
):
await get_platform_cost_dashboard(
start=start, provider="openai", user_id="u1"
)
assert mock_query.await_count == 4
first_call_sql = mock_query.call_args_list[0][0][0]
assert "createdAt" in first_call_sql
# group_by called 4 times (by_provider, by_user, distinct users, totals)
assert mock_actions.group_by.await_count == 4
# The where dict passed to the first call should include createdAt
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
assert "createdAt" in first_call_kwargs.get("where", {})
def _make_prisma_log_row(
i: int = 0,
user_email: str | None = None,
) -> MagicMock:
row = MagicMock()
row.id = f"log-{i}"
row.createdAt = datetime(2026, 3, 1, tzinfo=timezone.utc)
row.userId = "u1"
row.graphExecId = None
row.nodeExecId = None
row.blockName = "TestBlock"
row.provider = "openai"
row.trackingType = "tokens"
row.costMicrodollars = 1000
row.inputTokens = 10
row.outputTokens = 5
row.duration = 0.5
row.model = "gpt-4"
# cacheReadTokens / cacheCreationTokens may not exist on older Prisma clients
row.configure_mock(**{"cacheReadTokens": None, "cacheCreationTokens": None})
if user_email is not None:
row.User = MagicMock()
row.User.email = user_email
else:
row.User = None
return row
class TestGetPlatformCostLogs:
@pytest.mark.asyncio
async def test_returns_logs_and_total(self):
count_rows = [{"cnt": 1}]
log_rows = [
{
"id": "log-1",
"created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
"user_id": "u1",
"email": "a@b.com",
"graph_exec_id": "g1",
"node_exec_id": "n1",
"block_name": "TestBlock",
"provider": "openai",
"tracking_type": "tokens",
"cost_microdollars": 5000,
"input_tokens": 100,
"output_tokens": 50,
"duration": 1.5,
"model": "gpt-4",
}
]
mock_query = AsyncMock(side_effect=[count_rows, log_rows])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
row = _make_prisma_log_row(0, user_email="a@b.com")
mock_actions = MagicMock()
mock_actions.count = AsyncMock(return_value=1)
mock_actions.find_many = AsyncMock(return_value=[row])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, total = await get_platform_cost_logs(page=1, page_size=10)
assert total == 1
assert len(logs) == 1
assert logs[0].id == "log-1"
assert logs[0].id == "log-0"
assert logs[0].provider == "openai"
assert logs[0].model == "gpt-4"
@pytest.mark.asyncio
async def test_returns_empty_when_no_data(self):
mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
mock_actions = MagicMock()
mock_actions.count = AsyncMock(return_value=0)
mock_actions.find_many = AsyncMock(return_value=[])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, total = await get_platform_cost_logs()
assert total == 0
assert logs == []
@pytest.mark.asyncio
async def test_pagination_offset(self):
mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs(page=3, page_size=25)
assert total == 100
second_call_args = mock_query.call_args_list[1][0]
assert 25 in second_call_args # page_size
assert 50 in second_call_args # offset = (3-1) * 25
mock_actions = MagicMock()
mock_actions.count = AsyncMock(return_value=100)
mock_actions.find_many = AsyncMock(return_value=[])
@pytest.mark.asyncio
async def test_empty_count_returns_zero(self):
mock_query = AsyncMock(side_effect=[[], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
logs, total = await get_platform_cost_logs()
assert total == 0
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, total = await get_platform_cost_logs(page=3, page_size=25)
assert total == 100
find_many_call = mock_actions.find_many.call_args[1]
assert find_many_call["take"] == 25
assert find_many_call["skip"] == 50 # (3-1) * 25
@pytest.mark.asyncio
async def test_explicit_start_skips_default(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
mock_actions = MagicMock()
mock_actions.count = AsyncMock(return_value=0)
mock_actions.find_many = AsyncMock(return_value=[])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, total = await get_platform_cost_logs(start=start)
assert total == 0
def _make_log_row(i: int = 0) -> dict:
return {
"id": f"log-{i}",
"created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
"user_id": "u1",
"email": None,
"graph_exec_id": None,
"node_exec_id": None,
"block_name": "TestBlock",
"provider": "openai",
"tracking_type": "tokens",
"cost_microdollars": 1000,
"input_tokens": 10,
"output_tokens": 5,
"duration": 0.5,
"model": "gpt-4",
"cache_read_tokens": None,
"cache_creation_tokens": None,
}
where = mock_actions.count.call_args[1]["where"]
# start provided — should appear in the where filter
assert "createdAt" in where
class TestGetPlatformCostLogsForExport:
@pytest.mark.asyncio
async def test_returns_logs_not_truncated(self):
rows = [_make_log_row(0)]
mock_query = AsyncMock(return_value=rows)
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
row = _make_prisma_log_row(0)
mock_actions = MagicMock()
mock_actions.find_many = AsyncMock(return_value=[row])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, truncated = await get_platform_cost_logs_for_export()
assert len(logs) == 1
assert truncated is False
assert logs[0].id == "log-0"
@pytest.mark.asyncio
async def test_returns_empty_not_truncated(self):
mock_query = AsyncMock(return_value=[])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
mock_actions = MagicMock()
mock_actions.find_many = AsyncMock(return_value=[])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, truncated = await get_platform_cost_logs_for_export()
assert logs == []
assert truncated is False
@pytest.mark.asyncio
async def test_truncates_at_export_max_rows(self):
rows = [_make_log_row(i) for i in range(3)]
mock_query = AsyncMock(return_value=rows)
with patch(
"backend.data.platform_cost.query_raw_with_schema", new=mock_query
), patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2):
rows = [_make_prisma_log_row(i) for i in range(3)]
mock_actions = MagicMock()
mock_actions.find_many = AsyncMock(return_value=rows)
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2),
):
logs, truncated = await get_platform_cost_logs_for_export()
assert len(logs) == 2
assert truncated is True
@pytest.mark.asyncio
async def test_passes_model_block_tracking_filters(self):
mock_query = AsyncMock(return_value=[])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
mock_actions = MagicMock()
mock_actions.find_many = AsyncMock(return_value=[])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
await get_platform_cost_logs_for_export(
model="gpt-4", block_name="LLMBlock", tracking_type="tokens"
)
call_args = mock_query.call_args[0]
assert "gpt-4" in call_args
assert "LLMBlock" in call_args
assert "tokens" in call_args
where = mock_actions.find_many.call_args[1]["where"]
assert where.get("model") == "gpt-4"
assert where.get("trackingType") == "tokens"
# blockName uses a dict filter for case-insensitive match
assert "blockName" in where
@pytest.mark.asyncio
async def test_maps_cache_tokens(self):
row = _make_log_row(0)
row["cache_read_tokens"] = 50
row["cache_creation_tokens"] = 25
mock_query = AsyncMock(return_value=[row])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
row = _make_prisma_log_row(0)
row.configure_mock(**{"cacheReadTokens": 50, "cacheCreationTokens": 25})
mock_actions = MagicMock()
mock_actions.find_many = AsyncMock(return_value=[row])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, _ = await get_platform_cost_logs_for_export()
assert logs[0].cache_read_tokens == 50
assert logs[0].cache_creation_tokens == 25
@pytest.mark.asyncio
async def test_explicit_start_skips_default(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_query = AsyncMock(return_value=[])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
mock_actions = MagicMock()
mock_actions.find_many = AsyncMock(return_value=[])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, truncated = await get_platform_cost_logs_for_export(start=start)
assert logs == []
assert truncated is False
where = mock_actions.find_many.call_args[1]["where"]
assert "createdAt" in where
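
The tests in this hunk share one mocking pattern: `prisma()` is patched to return a `MagicMock` whose `find_many` is an `AsyncMock`, and the filter is read back from `call_args[1]["where"]`. A minimal self-contained sketch of that pattern follows. `FakeLog` and `export_logs` are hypothetical stand-ins for `PrismaLog` and `get_platform_cost_logs_for_export`; they are assumptions for illustration, not the real implementations.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch


class FakeLog:
    """Hypothetical stand-in for a Prisma model class (not the real PrismaLog)."""

    @staticmethod
    def prisma():
        raise RuntimeError("no database in this sketch")


async def export_logs(model=None, tracking_type=None):
    """Hypothetical query helper: builds a `where` filter and calls find_many."""
    where = {}
    if model is not None:
        where["model"] = model
    if tracking_type is not None:
        where["trackingType"] = tracking_type
    return await FakeLog.prisma().find_many(where=where)


def run_example():
    # find_many must be awaitable, so it is an AsyncMock on a plain MagicMock.
    mock_actions = MagicMock()
    mock_actions.find_many = AsyncMock(return_value=[])
    with patch.object(FakeLog, "prisma", return_value=mock_actions):
        rows = asyncio.run(export_logs(model="gpt-4", tracking_type="tokens"))
    # call_args[1] holds the keyword arguments of the most recent call.
    where = mock_actions.find_many.call_args[1]["where"]
    return rows, where
```

Asserting on the captured `where` dict keeps the test independent of any database while still checking that each filter argument reaches the query.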

@@ -1605,56 +1605,6 @@
}
}
},
"/api/chat/sessions/{session_id}/messages/pending": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Queue Pending Message",
"description": "Queue a new user message into an in-flight copilot turn.\n\nWhen a user sends a follow-up message while a turn is still\nstreaming, we don't want to block them or start a separate turn —\nthis endpoint appends the message to a per-session pending buffer.\nThe executor currently running the turn (baseline path) drains the\nbuffer between tool-call rounds and appends the message to the\nconversation before the next LLM call. On the SDK path the buffer\nis drained at the *start* of the next turn (the long-lived\n``ClaudeSDKClient.receive_response`` iterator returns after a\n``ResultMessage`` so there is no safe point to inject mid-stream\ninto an existing connection).\n\nReturns 202. Enforces the same per-user daily/weekly token rate\nlimit as the regular ``/stream`` endpoint so a client can't bypass\nit by batching messages through here.",
"operationId": "postV2QueuePendingMessage",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Session Id" }
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueuePendingMessageRequest"
}
}
}
},
"responses": {
"202": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueuePendingMessageResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/sessions/{session_id}/stream": {
"get": {
"tags": ["v2", "chat", "chat"],
@@ -12718,57 +12668,6 @@
"required": ["providers", "pagination"],
"title": "ProviderResponse"
},
"QueuePendingMessageRequest": {
"properties": {
"message": {
"type": "string",
"maxLength": 16000,
"minLength": 1,
"title": "Message"
},
"context": {
"anyOf": [
{
"additionalProperties": { "type": "string" },
"type": "object"
},
{ "type": "null" }
],
"title": "Context",
"description": "Optional page context: expected keys are 'url' and 'content'."
},
"file_ids": {
"anyOf": [
{
"items": { "type": "string" },
"type": "array",
"maxItems": 20
},
{ "type": "null" }
],
"title": "File Ids"
}
},
"additionalProperties": false,
"type": "object",
"required": ["message"],
"title": "QueuePendingMessageRequest",
"description": "Request model for queueing a message into an in-flight turn.\n\nUnlike ``StreamChatRequest`` this endpoint does **not** start a new\nturn — the message is appended to a per-session pending buffer that\nthe executor currently processing the turn will drain between tool\nrounds."
},
"QueuePendingMessageResponse": {
"properties": {
"buffer_length": { "type": "integer", "title": "Buffer Length" },
"max_buffer_length": {
"type": "integer",
"title": "Max Buffer Length"
},
"turn_in_flight": { "type": "boolean", "title": "Turn In Flight" }
},
"type": "object",
"required": ["buffer_length", "max_buffer_length", "turn_in_flight"],
"title": "QueuePendingMessageResponse",
"description": "Response for the pending-message endpoint.\n\n- ``buffer_length``: how many messages are now in the session's\n pending buffer (after this push)\n- ``max_buffer_length``: the per-session cap (server-side constant)\n- ``turn_in_flight``: ``True`` if a copilot turn was running when\n we checked — purely informational for UX feedback. Even when\n ``False`` the message is still queued: the next turn drains it."
},
"RateLimitResetResponse": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
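
The removed `QueuePendingMessageRequest` / `QueuePendingMessageResponse` schemas describe a per-session pending buffer with a server-side cap. The sketch below illustrates those semantics under stated assumptions: `PendingBuffer`, the default cap of 10, the in-memory storage, and raising `OverflowError` on a full buffer are all hypothetical, since the diff shows neither the actual storage nor the overflow behaviour. Only the field names and the 1 to 16000 character message limit come from the schemas.

```python
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, List, Set


@dataclass
class QueueResult:
    # Field names mirror QueuePendingMessageResponse.
    buffer_length: int
    max_buffer_length: int
    turn_in_flight: bool


class PendingBuffer:
    """Hypothetical in-memory per-session buffer (illustration only)."""

    def __init__(self, max_length: int = 10):  # cap value is an assumption
        self.max_length = max_length
        self._buffers: Dict[str, Deque[str]] = {}
        self._in_flight: Set[str] = set()

    def start_turn(self, session_id: str) -> None:
        self._in_flight.add(session_id)

    def end_turn(self, session_id: str) -> None:
        self._in_flight.discard(session_id)

    def push(self, session_id: str, message: str) -> QueueResult:
        # Length bounds taken from the schema (minLength 1, maxLength 16000).
        if not 1 <= len(message) <= 16000:
            raise ValueError("message must be 1 to 16000 characters")
        buf = self._buffers.setdefault(session_id, deque())
        if len(buf) >= self.max_length:
            # Overflow handling is an assumption; the diff does not specify it.
            raise OverflowError("pending buffer full")
        buf.append(message)
        return QueueResult(len(buf), self.max_length, session_id in self._in_flight)

    def drain(self, session_id: str) -> List[str]:
        # Returns queued messages in arrival order and empties the buffer.
        return list(self._buffers.pop(session_id, deque()))
```

In this sketch a message pushed while no turn is running is still queued and is returned by the next `drain`, matching the schema's note that `turn_in_flight` is purely informational.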

Binary files not shown (19 images, "Before" previews only).

Sizes: 114 KiB, 46 KiB, 66 KiB, 82 KiB, 78 KiB, 90 KiB, 75 KiB, 79 KiB, 82 KiB, 77 KiB, 80 KiB, 79 KiB, 85 KiB, 80 KiB, 89 KiB, 88 KiB, 65 KiB, 70 KiB, 94 KiB