diff --git a/autogpt_platform/analytics/queries/platform_cost_log.sql b/autogpt_platform/analytics/queries/platform_cost_log.sql
new file mode 100644
index 0000000000..b3e33d7515
--- /dev/null
+++ b/autogpt_platform/analytics/queries/platform_cost_log.sql
@@ -0,0 +1,100 @@
+-- =============================================================
+-- View: analytics.platform_cost_log
+-- Looker source alias: ds115 | Charts: 0
+-- =============================================================
+-- DESCRIPTION
+--   One row per platform cost log entry (last 90 days).
+--   Tracks real API spend at the call level: provider, model,
+--   token counts (including Anthropic cache tokens), cost in
+--   microdollars, and the block/execution that incurred the cost.
+--   Joins the User table to provide email for per-user breakdowns.
+--
+-- SOURCE TABLES
+--   platform.PlatformCostLog — Per-call cost records
+--   platform.User            — User email
+--
+-- OUTPUT COLUMNS
+--   id                  TEXT         Log entry UUID
+--   createdAt           TIMESTAMPTZ  When the cost was recorded
+--   userId              TEXT         User who incurred the cost (nullable)
+--   email               TEXT         User email (nullable)
+--   graphExecId         TEXT         Graph execution UUID (nullable)
+--   nodeExecId          TEXT         Node execution UUID (nullable)
+--   blockName           TEXT         Block that made the API call (nullable)
+--   provider            TEXT         API provider, lowercase (e.g. 'openai', 'anthropic')
+--   model               TEXT         Model name (nullable)
+--   trackingType        TEXT         Cost unit: 'tokens' | 'cost_usd' | 'characters' | etc.
+--   costMicrodollars    BIGINT       Cost in microdollars (divide by 1,000,000 for USD)
+--   costUsd             FLOAT        Cost in USD (costMicrodollars / 1,000,000)
+--   inputTokens         INT          Prompt/input tokens (nullable)
+--   outputTokens        INT          Completion/output tokens (nullable)
+--   cacheReadTokens     INT          Anthropic cache-read tokens billed at 10% (nullable)
+--   cacheCreationTokens INT          Anthropic cache-write tokens billed at 125% (nullable)
+--   totalTokens         INT          inputTokens + outputTokens (nullable if either is null)
+--   duration            FLOAT        API call duration in seconds (nullable)
+--
+-- WINDOW
+--   Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
+--
+-- EXAMPLE QUERIES
+--   -- Total spend by provider (last 90 days)
+--   SELECT provider, SUM("costUsd") AS total_usd, COUNT(*) AS calls
+--   FROM analytics.platform_cost_log
+--   GROUP BY 1 ORDER BY total_usd DESC;
+--
+--   -- Spend by model
+--   SELECT provider, model, SUM("costUsd") AS total_usd,
+--          SUM("inputTokens") AS input_tokens,
+--          SUM("outputTokens") AS output_tokens
+--   FROM analytics.platform_cost_log
+--   WHERE model IS NOT NULL
+--   GROUP BY 1, 2 ORDER BY total_usd DESC;
+--
+--   -- Top 20 users by spend
+--   SELECT "userId", email, SUM("costUsd") AS total_usd, COUNT(*) AS calls
+--   FROM analytics.platform_cost_log
+--   WHERE "userId" IS NOT NULL
+--   GROUP BY 1, 2 ORDER BY total_usd DESC LIMIT 20;
+--
+--   -- Daily spend trend
+--   SELECT DATE_TRUNC('day', "createdAt") AS day,
+--          SUM("costUsd") AS daily_usd,
+--          COUNT(*) AS calls
+--   FROM analytics.platform_cost_log
+--   GROUP BY 1 ORDER BY 1;
+--
+--   -- Cache hit rate for Anthropic (cache reads vs total reads)
+--   SELECT DATE_TRUNC('day', "createdAt") AS day,
+--          SUM("cacheReadTokens")::float /
+--            NULLIF(SUM("inputTokens" + COALESCE("cacheReadTokens", 0)), 0) AS cache_hit_rate
+--   FROM analytics.platform_cost_log
+--   WHERE provider = 'anthropic'
+--   GROUP BY 1 ORDER BY 1;
+-- =============================================================
+
+SELECT
+    p."id" AS id,
+    p."createdAt" AS "createdAt",
+    p."userId" AS "userId",
+    u."email" AS email,
+    p."graphExecId" AS "graphExecId",
+    p."nodeExecId" AS "nodeExecId",
+    p."blockName" AS "blockName",
+    p."provider" AS provider,
+    p."model" AS model,
+    p."trackingType" AS "trackingType",
+    p."costMicrodollars" AS "costMicrodollars",
+    p."costMicrodollars"::float / 1000000.0 AS "costUsd",
+    p."inputTokens" AS "inputTokens",
+    p."outputTokens" AS "outputTokens",
+    p."cacheReadTokens" AS "cacheReadTokens",
+    p."cacheCreationTokens" AS "cacheCreationTokens",
+    CASE
+        WHEN p."inputTokens" IS NOT NULL AND p."outputTokens" IS NOT NULL
+        THEN p."inputTokens" + p."outputTokens"
+        ELSE NULL
+    END AS "totalTokens",
+    p."duration" AS duration
+FROM platform."PlatformCostLog" p
+LEFT JOIN platform."User" u ON u."id" = p."userId"
+WHERE p."createdAt" > CURRENT_DATE - INTERVAL '90 days'
diff --git a/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes.py b/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes.py
index fcf13dc9c7..70e7772790 100644
--- a/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes.py
+++ b/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes.py
@@ -10,6 +10,7 @@ from backend.data.platform_cost import (
     PlatformCostDashboard,
     get_platform_cost_dashboard,
     get_platform_cost_logs,
+    get_platform_cost_logs_for_export,
 )
 
 from backend.util.models import Pagination
@@ -39,6 +40,9 @@ async def get_cost_dashboard(
     end: datetime | None = Query(None),
     provider: str | None = Query(None),
     user_id: str | None = Query(None),
+    model: str | None = Query(None),
+    block_name: str | None = Query(None),
+    tracking_type: str | None = Query(None),
 ):
     logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
     return await get_platform_cost_dashboard(
@@ -46,6 +50,9 @@
         end=end,
         provider=provider,
         user_id=user_id,
+        model=model,
+        block_name=block_name,
+        tracking_type=tracking_type,
     )
@@ -62,6 +69,9 @@ async def get_cost_logs(
     user_id: str | None = Query(None),
     page: int = Query(1, ge=1),
     page_size: int = Query(50, ge=1, le=200),
+    model: str | None = Query(None),
+    block_name: str | None = Query(None),
+    tracking_type: str | None = Query(None),
 ):
     logger.info("Admin %s fetching platform cost logs", admin_user_id)
     logs, total = await get_platform_cost_logs(
@@ -71,6 +81,9 @@
         user_id=user_id,
         page=page,
         page_size=page_size,
+        model=model,
+        block_name=block_name,
+        tracking_type=tracking_type,
     )
     total_pages = (total + page_size - 1) // page_size
     return PlatformCostLogsResponse(
@@ -82,3 +95,41 @@
             page_size=page_size,
         ),
     )
+
+
+class PlatformCostExportResponse(BaseModel):
+    logs: list[CostLogRow]
+    total_rows: int
+    truncated: bool
+
+
+@router.get(
+    "/logs/export",
+    response_model=PlatformCostExportResponse,
+    summary="Export Platform Cost Logs",
+)
+async def export_cost_logs(
+    admin_user_id: str = Security(get_user_id),
+    start: datetime | None = Query(None),
+    end: datetime | None = Query(None),
+    provider: str | None = Query(None),
+    user_id: str | None = Query(None),
+    model: str | None = Query(None),
+    block_name: str | None = Query(None),
+    tracking_type: str | None = Query(None),
+):
+    logger.info("Admin %s exporting platform cost logs", admin_user_id)
+    logs, truncated = await get_platform_cost_logs_for_export(
+        start=start,
+        end=end,
+        provider=provider,
+        user_id=user_id,
+        model=model,
+        block_name=block_name,
+        tracking_type=tracking_type,
+    )
+    return PlatformCostExportResponse(
+        logs=logs,
+        total_rows=len(logs),
+        truncated=truncated,
+    )
diff --git a/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes_test.py b/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes_test.py
index 224a754487..8cfc0e47b5 100644
--- a/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes_test.py
+++ b/autogpt_platform/backend/backend/api/features/admin/platform_cost_routes_test.py
@@ -1,3 +1,4 @@
+from datetime import datetime, timezone
 from unittest.mock import AsyncMock
 
 import fastapi
@@ -6,7 +7,7 @@
 import pytest
 import pytest_mock
 from autogpt_libs.auth.jwt_utils import get_jwt_payload
 
-from backend.data.platform_cost import PlatformCostDashboard
+from backend.data.platform_cost import CostLogRow, PlatformCostDashboard
 
 from .platform_cost_routes import router as platform_cost_router
@@ -190,3 +191,101 @@ def test_get_dashboard_repeated_requests(
     assert r2.status_code == 200
     assert r1.json()["total_cost_microdollars"] == 42
     assert r2.json()["total_cost_microdollars"] == 42
+
+
+def _make_cost_log_row() -> CostLogRow:
+    return CostLogRow(
+        id="log-1",
+        created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
+        user_id="user-1",
+        email="u***@example.com",
+        graph_exec_id="graph-1",
+        node_exec_id="node-1",
+        block_name="LlmCallBlock",
+        provider="anthropic",
+        tracking_type="token",
+        cost_microdollars=500,
+        input_tokens=100,
+        output_tokens=50,
+        cache_read_tokens=10,
+        cache_creation_tokens=5,
+        duration=1.5,
+        model="claude-3-5-sonnet-20241022",
+    )
+
+
+def test_export_logs_success(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    row = _make_cost_log_row()
+    mocker.patch(
+        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export",
+        AsyncMock(return_value=([row], False)),
+    )
+
+    response = client.get("/platform-costs/logs/export")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["total_rows"] == 1
+    assert data["truncated"] is False
+    assert len(data["logs"]) == 1
+    assert data["logs"][0]["cache_read_tokens"] == 10
+    assert data["logs"][0]["cache_creation_tokens"] == 5
+
+
+def test_export_logs_truncated(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    rows = [_make_cost_log_row() for _ in range(3)]
+    mocker.patch(
+        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export",
+        AsyncMock(return_value=(rows, True)),
+    )
+
+    response = client.get("/platform-costs/logs/export")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["total_rows"] == 3
+    assert data["truncated"] is True
+
+
+def test_export_logs_with_filters(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    mock_export = AsyncMock(return_value=([], False))
+    mocker.patch(
+        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs_for_export",
+        mock_export,
+    )
+
+    response = client.get(
+        "/platform-costs/logs/export",
+        params={
+            "provider": "anthropic",
+            "model": "claude-3-5-sonnet-20241022",
+            "block_name": "LlmCallBlock",
+            "tracking_type": "token",
+        },
+    )
+    assert response.status_code == 200
+    mock_export.assert_called_once()
+    call_kwargs = mock_export.call_args.kwargs
+    assert call_kwargs["provider"] == "anthropic"
+    assert call_kwargs["model"] == "claude-3-5-sonnet-20241022"
+    assert call_kwargs["block_name"] == "LlmCallBlock"
+    assert call_kwargs["tracking_type"] == "token"
+
+
+def test_export_logs_requires_admin() -> None:
+    import fastapi
+    from fastapi import HTTPException
+
+    def reject_jwt(request: fastapi.Request):
+        raise HTTPException(status_code=401, detail="Not authenticated")
+
+    app.dependency_overrides[get_jwt_payload] = reject_jwt
+    try:
+        response = client.get("/platform-costs/logs/export")
+        assert response.status_code == 401
+    finally:
+        app.dependency_overrides.clear()
diff --git a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py
new file mode 100644
index 0000000000..7a7ec518c6
--- /dev/null
+++ b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py
@@ -0,0 +1,294 @@
+"""Tests for subscription tier API endpoints."""
+
+from unittest.mock import AsyncMock, Mock
+
+import fastapi
+import fastapi.testclient
+import pytest_mock
+from autogpt_libs.auth.jwt_utils import get_jwt_payload
+from prisma.enums import SubscriptionTier
+
+from .v1 import v1_router
+
+app = fastapi.FastAPI()
+app.include_router(v1_router)
+
+client = fastapi.testclient.TestClient(app)
+
+TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
+
+
+def setup_auth(app: fastapi.FastAPI):
+    def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
+        return {"sub": TEST_USER_ID, "role": "user", "email": "test@example.com"}
+
+    app.dependency_overrides[get_jwt_payload] = override_get_jwt_payload
+
+
+def teardown_auth(app: fastapi.FastAPI):
+    app.dependency_overrides.clear()
+
+
+def test_get_subscription_status_pro(
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """GET /credits/subscription returns PRO tier with Stripe price for a PRO user."""
+    setup_auth(app)
+    try:
+        mock_user = Mock()
+        mock_user.subscription_tier = SubscriptionTier.PRO
+
+        mock_price = Mock()
+        mock_price.unit_amount = 1999  # $19.99
+
+        async def mock_price_id(tier: SubscriptionTier) -> str | None:
+            return "price_pro" if tier == SubscriptionTier.PRO else None
+
+        mocker.patch(
+            "backend.api.features.v1.get_user_by_id",
+            new_callable=AsyncMock,
+            return_value=mock_user,
+        )
+        mocker.patch(
+            "backend.api.features.v1.get_subscription_price_id",
+            side_effect=mock_price_id,
+        )
+        mocker.patch(
+            "backend.api.features.v1.stripe.Price.retrieve",
+            return_value=mock_price,
+        )
+
+        response = client.get("/credits/subscription")
+
+        assert response.status_code == 200
+        data = response.json()
+        assert data["tier"] == "PRO"
+        assert data["monthly_cost"] == 1999
+        assert data["tier_costs"]["PRO"] == 1999
+        assert data["tier_costs"]["BUSINESS"] == 0
+        assert data["tier_costs"]["FREE"] == 0
+    finally:
+        teardown_auth(app)
+
+
+def test_get_subscription_status_defaults_to_free(
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """GET /credits/subscription when subscription_tier is None defaults to FREE."""
+    setup_auth(app)
+    try:
+        mock_user = Mock()
+        mock_user.subscription_tier = None
+
+        mocker.patch(
+            "backend.api.features.v1.get_user_by_id",
+            new_callable=AsyncMock,
+            return_value=mock_user,
+        )
+        mocker.patch(
+            "backend.api.features.v1.get_subscription_price_id",
+            new_callable=AsyncMock,
+            return_value=None,
+        )
+
+        response = client.get("/credits/subscription")
+
+        assert response.status_code == 200
+        data = response.json()
+        assert data["tier"] == SubscriptionTier.FREE.value
+        assert data["monthly_cost"] == 0
+        assert data["tier_costs"] == {
+            "FREE": 0,
+            "PRO": 0,
+            "BUSINESS": 0,
+            "ENTERPRISE": 0,
+        }
+    finally:
+        teardown_auth(app)
+
+
+def test_update_subscription_tier_free_no_payment(
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """POST /credits/subscription to FREE tier when payment disabled skips Stripe."""
+    setup_auth(app)
+    try:
+        mock_user = Mock()
+        mock_user.subscription_tier = SubscriptionTier.PRO
+
+        async def mock_feature_disabled(*args, **kwargs):
+            return False
+
+        async def mock_set_tier(*args, **kwargs):
+            pass
+
+        mocker.patch(
+            "backend.api.features.v1.get_user_by_id",
+            new_callable=AsyncMock,
+            return_value=mock_user,
+        )
+        mocker.patch(
+            "backend.api.features.v1.is_feature_enabled",
+            side_effect=mock_feature_disabled,
+        )
+        mocker.patch(
+            "backend.api.features.v1.set_subscription_tier",
+            side_effect=mock_set_tier,
+        )
+
+        response = client.post("/credits/subscription", json={"tier": "FREE"})
+
+        assert response.status_code == 200
+        assert response.json()["url"] == ""
+    finally:
+        teardown_auth(app)
+
+
+def test_update_subscription_tier_paid_beta_user(
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """POST /credits/subscription for paid tier when payment disabled sets tier directly."""
+    setup_auth(app)
+    try:
+        mock_user = Mock()
+        mock_user.subscription_tier = SubscriptionTier.FREE
+
+        async def mock_feature_disabled(*args, **kwargs):
+            return False
+
+        async def mock_set_tier(*args, **kwargs):
+            pass
+
+        mocker.patch(
+            "backend.api.features.v1.get_user_by_id",
+            new_callable=AsyncMock,
+            return_value=mock_user,
+        )
+        mocker.patch(
+            "backend.api.features.v1.is_feature_enabled",
+            side_effect=mock_feature_disabled,
+        )
+        mocker.patch(
+            "backend.api.features.v1.set_subscription_tier",
+            side_effect=mock_set_tier,
+        )
+
+        response = client.post("/credits/subscription", json={"tier": "PRO"})
+
+        assert response.status_code == 200
+        assert response.json()["url"] == ""
+    finally:
+        teardown_auth(app)
+
+
+def test_update_subscription_tier_paid_requires_urls(
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """POST /credits/subscription for paid tier without success/cancel URLs returns 422."""
+    setup_auth(app)
+    try:
+        mock_user = Mock()
+        mock_user.subscription_tier = SubscriptionTier.FREE
+
+        async def mock_feature_enabled(*args, **kwargs):
+            return True
+
+        mocker.patch(
+            "backend.api.features.v1.get_user_by_id",
+            new_callable=AsyncMock,
+            return_value=mock_user,
+        )
+        mocker.patch(
+            "backend.api.features.v1.is_feature_enabled",
+            side_effect=mock_feature_enabled,
+        )
+
+        response = client.post("/credits/subscription", json={"tier": "PRO"})
+
+        assert response.status_code == 422
+    finally:
+        teardown_auth(app)
+
+
+def test_update_subscription_tier_creates_checkout(
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """POST /credits/subscription creates Stripe Checkout Session for paid upgrade."""
+    setup_auth(app)
+    try:
+        mock_user = Mock()
+        mock_user.subscription_tier = SubscriptionTier.FREE
+
+        async def mock_feature_enabled(*args, **kwargs):
+            return True
+
+        mocker.patch(
+            "backend.api.features.v1.get_user_by_id",
+            new_callable=AsyncMock,
+            return_value=mock_user,
+        )
+        mocker.patch(
+            "backend.api.features.v1.is_feature_enabled",
+            side_effect=mock_feature_enabled,
+        )
+        mocker.patch(
+            "backend.api.features.v1.create_subscription_checkout",
+            new_callable=AsyncMock,
+            return_value="https://checkout.stripe.com/pay/cs_test_abc",
+        )
+
+        response = client.post(
+            "/credits/subscription",
+            json={
+                "tier": "PRO",
+                "success_url": "https://app.example.com/success",
+                "cancel_url": "https://app.example.com/cancel",
+            },
+        )
+
+        assert response.status_code == 200
+        assert response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_abc"
+    finally:
+        teardown_auth(app)
+
+
+def test_update_subscription_tier_free_with_payment_cancels_stripe(
+    mocker: pytest_mock.MockFixture,
+) -> None:
+    """Downgrading to FREE cancels active Stripe subscription when payment is enabled."""
+    setup_auth(app)
+    try:
+        mock_user = Mock()
+        mock_user.subscription_tier = SubscriptionTier.PRO
+
+        async def mock_feature_enabled(*args, **kwargs):
+            return True
+
+        mock_cancel = mocker.patch(
+            "backend.api.features.v1.cancel_stripe_subscription",
+            new_callable=AsyncMock,
+        )
+
+        async def mock_set_tier(*args, **kwargs):
+            pass
+
+        mocker.patch(
+            "backend.api.features.v1.get_user_by_id",
+            new_callable=AsyncMock,
+            return_value=mock_user,
+        )
+        mocker.patch(
+            "backend.api.features.v1.set_subscription_tier",
+            side_effect=mock_set_tier,
+        )
+        mocker.patch(
+            "backend.api.features.v1.is_feature_enabled",
+            side_effect=mock_feature_enabled,
+        )
+
+        response = client.post("/credits/subscription", json={"tier": "FREE"})
+
+        assert response.status_code == 200
+        mock_cancel.assert_awaited_once()
+    finally:
+        teardown_auth(app)
diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py
index d208114f95..5767cebd94 100644
--- a/autogpt_platform/backend/backend/api/features/v1.py
+++ b/autogpt_platform/backend/backend/api/features/v1.py
@@ -5,7 +5,7 @@ import time
 import uuid
 from collections import defaultdict
 from datetime import datetime, timezone
-from typing import Annotated, Any, Sequence, get_args
+from typing import Annotated, Any, Literal, Sequence, get_args
 
 import pydantic
 import stripe
@@ -24,6 +24,7 @@ from fastapi import (
     UploadFile,
 )
 from fastapi.concurrency import run_in_threadpool
+from prisma.enums import SubscriptionTier
 from pydantic import BaseModel
 from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
 from typing_extensions import Optional, TypedDict
@@ -50,9 +51,14 @@ from backend.data.credit import (
     RefundRequest,
     TransactionHistory,
     UserCredit,
+    cancel_stripe_subscription,
+    create_subscription_checkout,
     get_auto_top_up,
+    get_subscription_price_id,
     get_user_credit_model,
     set_auto_top_up,
+    set_subscription_tier,
+    sync_subscription_from_stripe,
 )
 from backend.data.graph import GraphSettings
 from backend.data.model import CredentialsMetaInput, UserOnboarding
@@ -661,9 +667,12 @@ async def configure_user_auto_top_up(
             raise HTTPException(status_code=422, detail=str(e))
         raise
 
-    await set_auto_top_up(
-        user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
-    )
+    try:
+        await set_auto_top_up(
+            user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
+        )
+    except ValueError as e:
+        raise HTTPException(status_code=422, detail=str(e))
     return "Auto top-up settings updated"
@@ -679,6 +688,115 @@ async def get_user_auto_top_up(
     return await get_auto_top_up(user_id)
 
 
+class SubscriptionTierRequest(BaseModel):
+    tier: Literal["FREE", "PRO", "BUSINESS"]
+    success_url: str = ""
+    cancel_url: str = ""
+
+
+class SubscriptionCheckoutResponse(BaseModel):
+    url: str
+
+
+class SubscriptionStatusResponse(BaseModel):
+    tier: str
+    monthly_cost: int
+    tier_costs: dict[str, int]
+
+
+@v1_router.get(
+    path="/credits/subscription",
+    summary="Get subscription tier, current cost, and all tier costs",
+    operation_id="getSubscriptionStatus",
+    tags=["credits"],
+    dependencies=[Security(requires_user)],
+)
+async def get_subscription_status(
+    user_id: Annotated[str, Security(get_user_id)],
+) -> SubscriptionStatusResponse:
+    user = await get_user_by_id(user_id)
+    tier = user.subscription_tier or SubscriptionTier.FREE
+
+    paid_tiers = [SubscriptionTier.PRO, SubscriptionTier.BUSINESS]
+    price_ids = await asyncio.gather(
+        *[get_subscription_price_id(t) for t in paid_tiers]
+    )
+
+    tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
+    for t, price_id in zip(paid_tiers, price_ids):
+        cost = 0
+        if price_id:
+            try:
+                price = await run_in_threadpool(stripe.Price.retrieve, price_id)
+                cost = price.unit_amount or 0
+            except stripe.StripeError:
+                pass
+        tier_costs[t.value] = cost
+
+    return SubscriptionStatusResponse(
+        tier=tier.value,
+        monthly_cost=tier_costs.get(tier.value, 0),
+        tier_costs=tier_costs,
+    )
+
+
+@v1_router.post(
+    path="/credits/subscription",
+    summary="Start a Stripe Checkout session to upgrade subscription tier",
+    operation_id="updateSubscriptionTier",
+    tags=["credits"],
+    dependencies=[Security(requires_user)],
+)
+async def update_subscription_tier(
+    request: SubscriptionTierRequest,
+    user_id: Annotated[str, Security(get_user_id)],
+) -> SubscriptionCheckoutResponse:
+    # Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
+    tier = SubscriptionTier(request.tier)
+
+    # ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
+    user = await get_user_by_id(user_id)
+    if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE:
+        raise HTTPException(
+            status_code=403,
+            detail="ENTERPRISE subscription changes must be managed by an administrator",
+        )
+
+    payment_enabled = await is_feature_enabled(
+        Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
+    )
+
+    # Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
+    if tier == SubscriptionTier.FREE:
+        if payment_enabled:
+            await cancel_stripe_subscription(user_id)
+        await set_subscription_tier(user_id, tier)
+        return SubscriptionCheckoutResponse(url="")
+
+    # Beta users (payment not enabled) → update tier directly without Stripe.
+    if not payment_enabled:
+        await set_subscription_tier(user_id, tier)
+        return SubscriptionCheckoutResponse(url="")
+
+    # Paid upgrade → create Stripe Checkout Session.
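+    # Note (illustrative): this handler does not set the DB tier itself; the
+    # tier is persisted once Stripe confirms payment, via the
+    # customer.subscription.* webhook events handled by
+    # sync_subscription_from_stripe below. A successful call returns a hosted
+    # Checkout URL, e.g. {"url": "https://checkout.stripe.com/pay/cs_test_abc"}.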
+    if not request.success_url or not request.cancel_url:
+        raise HTTPException(
+            status_code=422,
+            detail="success_url and cancel_url are required for paid tier upgrades",
+        )
+    try:
+        url = await create_subscription_checkout(
+            user_id=user_id,
+            tier=tier,
+            success_url=request.success_url,
+            cancel_url=request.cancel_url,
+        )
+    except (ValueError, stripe.StripeError) as e:
+        raise HTTPException(status_code=422, detail=str(e))
+
+    return SubscriptionCheckoutResponse(url=url)
+
+
 @v1_router.post(
     path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
 )
@@ -709,6 +827,13 @@ async def stripe_webhook(request: Request):
     ):
         await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
 
+    if event["type"] in (
+        "customer.subscription.created",
+        "customer.subscription.updated",
+        "customer.subscription.deleted",
+    ):
+        await sync_subscription_from_stripe(event["data"]["object"])
+
     if event["type"] == "charge.dispute.created":
         await UserCredit().handle_dispute(event["data"]["object"])
diff --git a/autogpt_platform/backend/backend/blocks/ai_condition.py b/autogpt_platform/backend/backend/blocks/ai_condition.py
index 6d62d4ab77..db8c023b99 100644
--- a/autogpt_platform/backend/backend/blocks/ai_condition.py
+++ b/autogpt_platform/backend/backend/blocks/ai_condition.py
@@ -207,6 +207,9 @@ class AIConditionBlock(AIBlockBase):
             NodeExecutionStats(
                 input_token_count=response.prompt_tokens,
                 output_token_count=response.completion_tokens,
+                cache_read_token_count=response.cache_read_tokens,
+                cache_creation_token_count=response.cache_creation_tokens,
+                provider_cost=response.provider_cost,
             )
         )
         self.prompt = response.prompt
diff --git a/autogpt_platform/backend/backend/blocks/ai_condition_test.py b/autogpt_platform/backend/backend/blocks/ai_condition_test.py
index babb1eb4cf..5520963682 100644
--- a/autogpt_platform/backend/backend/blocks/ai_condition_test.py
+++ b/autogpt_platform/backend/backend/blocks/ai_condition_test.py
@@ -47,7 +47,13 @@ def _make_input(**overrides) -> AIConditionBlock.Input:
     return AIConditionBlock.Input(**defaults)
 
 
-def _mock_llm_response(response_text: str) -> LLMResponse:
+def _mock_llm_response(
+    response_text: str,
+    *,
+    cache_read_tokens: int = 0,
+    cache_creation_tokens: int = 0,
+    provider_cost: float | None = None,
+) -> LLMResponse:
     return LLMResponse(
         raw_response="",
         prompt=[],
@@ -56,6 +62,9 @@
         prompt_tokens=10,
         completion_tokens=5,
         reasoning=None,
+        cache_read_tokens=cache_read_tokens,
+        cache_creation_tokens=cache_creation_tokens,
+        provider_cost=provider_cost,
     )
@@ -145,3 +154,35 @@
         input_data = _make_input()
         with pytest.raises(RuntimeError, match="LLM provider error"):
             await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
+
+
+# ---------------------------------------------------------------------------
+# Regression: cache tokens and provider_cost must be propagated to stats
+# ---------------------------------------------------------------------------
+
+
+class TestCacheTokenPropagation:
+    @pytest.mark.asyncio
+    async def test_cache_tokens_propagated_to_stats(
+        self, monkeypatch: pytest.MonkeyPatch
+    ):
+        """cache_read_tokens and cache_creation_tokens must be forwarded to
+        NodeExecutionStats so that usage dashboards count cached tokens."""
+        block = AIConditionBlock()
+
+        async def spy_llm(**kwargs):
+            return _mock_llm_response(
+                "true",
+                cache_read_tokens=7,
+                cache_creation_tokens=3,
+                provider_cost=0.0012,
+            )
+
+        monkeypatch.setattr(block, "llm_call", spy_llm)
+
+        input_data = _make_input()
+        await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
+
+        assert block.execution_stats.cache_read_token_count == 7
+        assert block.execution_stats.cache_creation_token_count == 3
+        assert block.execution_stats.provider_cost == 0.0012
diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py
index 1e2ca23c37..52e32feb13 100644
--- a/autogpt_platform/backend/backend/blocks/llm.py
+++ b/autogpt_platform/backend/backend/blocks/llm.py
@@ -738,18 +738,20 @@ class LLMResponse(BaseModel):
     tool_calls: Optional[List[ToolContentBlock]] | None
     prompt_tokens: int
     completion_tokens: int
+    cache_read_tokens: int = 0
+    cache_creation_tokens: int = 0
     reasoning: Optional[str] = None
     provider_cost: float | None = None
 
 
 def convert_openai_tool_fmt_to_anthropic(
     openai_tools: list[dict] | None = None,
-) -> Iterable[ToolParam] | anthropic.Omit:
+) -> Iterable[ToolParam] | anthropic.NotGiven:
     """
     Convert OpenAI tool format to Anthropic tool format.
     """
     if not openai_tools or len(openai_tools) == 0:
-        return anthropic.omit
+        return anthropic.NOT_GIVEN
 
     anthropic_tools = []
     for tool in openai_tools:
@@ -885,6 +887,21 @@ async def llm_call(
     provider = llm_model.metadata.provider
     context_window = llm_model.context_window
 
+    # Transparent OpenRouter routing for Anthropic models: when an OpenRouter API key
+    # is configured, route direct-Anthropic models through OpenRouter instead. This
+    # gives us the x-total-cost header for free, so provider_cost is always populated
+    # without manual token-rate arithmetic.
+    or_key = settings.secrets.open_router_api_key
+    or_model_id: str | None = None
+    if provider == "anthropic" and or_key:
+        provider = "open_router"
+        credentials = APIKeyCredentials(
+            provider=ProviderName.OPEN_ROUTER,
+            title="OpenRouter (auto)",
+            api_key=SecretStr(or_key),
+        )
+        or_model_id = f"anthropic/{llm_model.value}"
+
     if compress_prompt_to_fit:
         result = await compress_context(
             messages=prompt,
@@ -972,6 +989,11 @@
     elif provider == "anthropic":
         an_tools = convert_openai_tool_fmt_to_anthropic(tools)
+        # Cache tool definitions alongside the system prompt.
+        # Placing cache_control on the last tool caches all tool schemas as a
+        # single prefix — reads cost 10% of normal input tokens.
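+        # The mutation below makes the final tool entry look like, e.g.:
+        #   {"name": "tool_b", "description": "...", "input_schema": {...},
+        #    "cache_control": {"type": "ephemeral"}}
+        # Anthropic caches the request prefix up to the breakpoint, so one
+        # marker covers every tool schema that precedes it.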
+        if isinstance(an_tools, list) and an_tools:
+            an_tools[-1] = {**an_tools[-1], "cache_control": {"type": "ephemeral"}}
 
         system_messages = [p["content"] for p in prompt if p["role"] == "system"]
         sysprompt = " ".join(system_messages)
@@ -994,14 +1016,22 @@
         client = anthropic.AsyncAnthropic(
             api_key=credentials.api_key.get_secret_value()
         )
-        resp = await client.messages.create(
+        create_kwargs: dict[str, Any] = dict(
             model=llm_model.value,
-            system=sysprompt,
             messages=messages,
             max_tokens=max_tokens,
             tools=an_tools,
             timeout=600,
         )
+        if sysprompt.strip():
+            create_kwargs["system"] = [
+                {
+                    "type": "text",
+                    "text": sysprompt,
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ]
+        resp = await client.messages.create(**create_kwargs)
 
         if not resp.content:
             raise ValueError("No content returned from Anthropic.")
@@ -1046,6 +1076,11 @@
             tool_calls=tool_calls,
             prompt_tokens=resp.usage.input_tokens,
             completion_tokens=resp.usage.output_tokens,
+            cache_read_tokens=getattr(resp.usage, "cache_read_input_tokens", None) or 0,
+            cache_creation_tokens=getattr(
+                resp.usage, "cache_creation_input_tokens", None
+            )
+            or 0,
             reasoning=reasoning,
         )
     elif provider == "groq":
@@ -1114,7 +1149,7 @@
                 "HTTP-Referer": "https://agpt.co",
                 "X-Title": "AutoGPT",
             },
-            model=llm_model.value,
+            model=or_model_id or llm_model.value,
             messages=prompt,  # type: ignore
             max_tokens=max_tokens,
             tools=tools_param,  # type: ignore
@@ -1443,7 +1478,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
         error_feedback_message = ""
         llm_model = input_data.model
-        last_attempt_cost: float | None = None
+        total_provider_cost: float | None = None
 
         for retry_count in range(input_data.retry):
             logger.debug(f"LLM request: {prompt}")
@@ -1461,15 +1496,19 @@
                     max_tokens=input_data.max_tokens,
                 )
                 response_text = llm_response.response
-                # Merge token counts for every attempt (each call costs tokens).
-                # provider_cost (actual USD) is tracked separately and only merged
-                # on success to avoid double-counting across retries.
+                # Accumulate token counts and provider_cost for every attempt
+                # (each call costs tokens and USD, regardless of validation outcome).
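+                # Worked example: a failed attempt billed at $0.01 followed by a
+                # successful one at $0.02 must surface provider_cost == 0.03 (the
+                # sum), not just the final attempt's cost; see
+                # test_retry_cost_accumulates_across_attempts.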
                token_stats = NodeExecutionStats(
                    input_token_count=llm_response.prompt_tokens,
                    output_token_count=llm_response.completion_tokens,
+                    cache_read_token_count=llm_response.cache_read_tokens,
+                    cache_creation_token_count=llm_response.cache_creation_tokens,
                )
                self.merge_stats(token_stats)
-                last_attempt_cost = llm_response.provider_cost
+                if llm_response.provider_cost is not None:
+                    total_provider_cost = (
+                        total_provider_cost or 0.0
+                    ) + llm_response.provider_cost
                logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
 
                if input_data.expected_format:
@@ -1538,7 +1577,7 @@
                     NodeExecutionStats(
                         llm_call_count=retry_count + 1,
                         llm_retry_count=retry_count,
-                        provider_cost=last_attempt_cost,
+                        provider_cost=total_provider_cost,
                     )
                 )
                 yield "response", response_obj
@@ -1559,7 +1598,7 @@
                     NodeExecutionStats(
                         llm_call_count=retry_count + 1,
                         llm_retry_count=retry_count,
-                        provider_cost=last_attempt_cost,
+                        provider_cost=total_provider_cost,
                     )
                 )
                 yield "response", {"response": response_text}
@@ -1591,6 +1630,10 @@
 
                error_feedback_message = f"Error calling LLM: {e}"
 
+        # All retries exhausted or user-error break: persist accumulated cost so
+        # the executor can still charge/report the spend even on failure.
+        if total_provider_cost is not None:
+            self.merge_stats(NodeExecutionStats(provider_cost=total_provider_cost))
        raise RuntimeError(error_feedback_message)
 
    def response_format_instructions(
diff --git a/autogpt_platform/backend/backend/blocks/orchestrator.py b/autogpt_platform/backend/backend/blocks/orchestrator.py
index 9ab0318165..7e51f43b80 100644
--- a/autogpt_platform/backend/backend/blocks/orchestrator.py
+++ b/autogpt_platform/backend/backend/blocks/orchestrator.py
@@ -859,7 +859,10 @@ class OrchestratorBlock(Block):
             NodeExecutionStats(
                 input_token_count=resp.prompt_tokens,
                 output_token_count=resp.completion_tokens,
+                cache_read_token_count=resp.cache_read_tokens,
+                cache_creation_token_count=resp.cache_creation_tokens,
                 llm_call_count=1,
+                provider_cost=resp.provider_cost,
             )
         )
@@ -1635,6 +1638,7 @@
         conversation: list[dict[str, Any]] = list(prompt)  # Start with input prompt
         total_prompt_tokens = 0
         total_completion_tokens = 0
+        total_cost_usd: float | None = None
         sdk_error: Exception | None = None
 
         try:
@@ -1778,6 +1782,8 @@
                     total_completion_tokens += getattr(
                         sdk_msg.usage, "output_tokens", 0
                     )
+                    if sdk_msg.total_cost_usd is not None:
+                        total_cost_usd = sdk_msg.total_cost_usd
         finally:
             if pending_task is not None and not pending_task.done():
                 pending_task.cancel()
@@ -1805,12 +1811,17 @@
         # those stats would under-count resource usage.
         # llm_call_count=1 is approximate; the SDK manages its own
         # multi-turn loop and only exposes aggregate usage.
-        if total_prompt_tokens > 0 or total_completion_tokens > 0:
+        if (
+            total_prompt_tokens > 0
+            or total_completion_tokens > 0
+            or total_cost_usd is not None
+        ):
             self.merge_stats(
                 NodeExecutionStats(
                     input_token_count=total_prompt_tokens,
                     output_token_count=total_completion_tokens,
                     llm_call_count=1,
+                    provider_cost=total_cost_usd,
                 )
             )
         # Clean up execution-specific working directory.
diff --git a/autogpt_platform/backend/backend/blocks/test/test_llm.py b/autogpt_platform/backend/backend/blocks/test/test_llm.py
index a6fb1dd448..e8eea20040 100644
--- a/autogpt_platform/backend/backend/blocks/test/test_llm.py
+++ b/autogpt_platform/backend/backend/blocks/test/test_llm.py
@@ -46,6 +46,110 @@ class TestLLMStatsTracking:
         assert response.completion_tokens == 20
         assert response.response == "Test response"
 
+    @pytest.mark.asyncio
+    async def test_llm_call_anthropic_returns_cache_tokens(self):
+        """Test that llm_call returns cache read/creation tokens from Anthropic."""
+        from pydantic import SecretStr
+
+        import backend.blocks.llm as llm
+        from backend.data.model import APIKeyCredentials
+
+        anthropic_creds = APIKeyCredentials(
+            id="test-anthropic-id",
+            provider="anthropic",
+            api_key=SecretStr("mock-anthropic-key"),
+            title="Mock Anthropic key",
+            expires_at=None,
+        )
+
+        mock_content_block = MagicMock()
+        mock_content_block.type = "text"
+        mock_content_block.text = "Test anthropic response"
+
+        mock_usage = MagicMock()
+        mock_usage.input_tokens = 15
+        mock_usage.output_tokens = 25
+        mock_usage.cache_read_input_tokens = 100
+        mock_usage.cache_creation_input_tokens = 50
+
+        mock_response = MagicMock()
+        mock_response.content = [mock_content_block]
+        mock_response.usage = mock_usage
+        mock_response.stop_reason = "end_turn"
+
+        with (
+            patch("anthropic.AsyncAnthropic") as mock_anthropic,
+            patch("backend.blocks.llm.settings") as mock_settings,
+        ):
+            mock_settings.secrets.open_router_api_key = ""
+            mock_client = AsyncMock()
+            mock_anthropic.return_value = mock_client
+            mock_client.messages.create = AsyncMock(return_value=mock_response)
+
+            response = await llm.llm_call(
+                credentials=anthropic_creds,
+                llm_model=llm.LlmModel.CLAUDE_3_HAIKU,
+                prompt=[{"role": "user", "content": "Hello"}],
+                max_tokens=100,
+            )
+
+        assert isinstance(response, llm.LLMResponse)
+        assert response.prompt_tokens == 15
+        assert response.completion_tokens == 25
+        assert response.cache_read_tokens == 100
+        assert response.cache_creation_tokens == 50
+        assert response.response == "Test anthropic response"
+
+    @pytest.mark.asyncio
+    async def test_anthropic_routes_through_openrouter_when_key_present(self):
+        """When open_router_api_key is set, Anthropic models route via OpenRouter."""
+        from pydantic import SecretStr
+
+        import backend.blocks.llm as llm
+        from backend.data.model import APIKeyCredentials
+
+        anthropic_creds = APIKeyCredentials(
+            id="test-anthropic-id",
+            provider="anthropic",
+            api_key=SecretStr("mock-anthropic-key"),
+            title="Mock Anthropic key",
+        )
+
+        mock_choice = MagicMock()
+        mock_choice.message.content = "routed response"
+        mock_choice.message.tool_calls = None
+
+        mock_usage = MagicMock()
+        mock_usage.prompt_tokens = 10
+        mock_usage.completion_tokens = 5
+
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = mock_usage
+
+        mock_create = AsyncMock(return_value=mock_response)
+
+        with (
+            patch("openai.AsyncOpenAI") as mock_openai,
+            patch("backend.blocks.llm.settings") as mock_settings,
+        ):
+            mock_settings.secrets.open_router_api_key = "sk-or-test-key"
+            mock_client = MagicMock()
+            mock_openai.return_value = mock_client
+            mock_client.chat.completions.create = mock_create
+
+            await llm.llm_call(
+                credentials=anthropic_creds,
+                llm_model=llm.LlmModel.CLAUDE_3_HAIKU,
+                prompt=[{"role": "user", "content": "Hello"}],
+                max_tokens=100,
+            )
+
+        # Verify OpenAI client was used (not Anthropic SDK) and model was prefixed
+        mock_openai.assert_called_once()
+        call_kwargs = mock_create.call_args.kwargs
+        assert call_kwargs["model"] == "anthropic/claude-3-haiku-20240307"
 
     @pytest.mark.asyncio
     async def test_ai_structured_response_block_tracks_stats(self):
         """Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
@@ -200,12 +304,11 @@
         assert block.execution_stats.llm_retry_count == 1
 
     @pytest.mark.asyncio
-    async def test_retry_cost_uses_last_attempt_only(self):
-        """provider_cost is only merged from the final successful attempt.
+    async def test_retry_cost_accumulates_across_attempts(self):
+        """provider_cost accumulates across all retry attempts.
 
-        Intermediate retry costs are intentionally dropped to avoid
-        double-counting: the cost of failed attempts is captured in
-        last_attempt_cost only when the loop eventually succeeds.
+        Each LLM call incurs a real cost, including failed validation attempts.
+        The total cost is the sum of all attempts so no billed USD is lost.
         """
         import backend.blocks.llm as llm
@@ -253,12 +356,86 @@
             async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
                 pass
 
-        # Only the final successful attempt's cost is merged
-        assert block.execution_stats.provider_cost == pytest.approx(0.02)
+        # provider_cost accumulates across all attempts: $0.01 + $0.02 = $0.03
+        assert block.execution_stats.provider_cost == pytest.approx(0.03)
         # Tokens from both attempts accumulate
         assert block.execution_stats.input_token_count == 30
         assert block.execution_stats.output_token_count == 15
 
+    @pytest.mark.asyncio
+    async def test_cache_tokens_accumulated_in_stats(self):
+        """Cache read/creation tokens are tracked per-attempt and accumulated."""
+        import backend.blocks.llm as llm
+
+        block = llm.AIStructuredResponseGeneratorBlock()
+
+        async def mock_llm_call(*args, **kwargs):
+            return llm.LLMResponse(
+                raw_response="",
+                prompt=[],
+                response='{"key1": "v1", "key2": "v2"}',
+                tool_calls=None,
+                prompt_tokens=10,
+                completion_tokens=5,
+                cache_read_tokens=20,
+                cache_creation_tokens=8,
+                reasoning=None,
+                provider_cost=0.005,
+            )
+
+        block.llm_call = mock_llm_call  # type: ignore
+
+        input_data = llm.AIStructuredResponseGeneratorBlock.Input(
+            prompt="Test prompt",
+            expected_format={"key1": "desc1", "key2": "desc2"},
+            model=llm.DEFAULT_LLM_MODEL,
+            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
+            retry=1,
+        )
+
+        with patch("secrets.token_hex", return_value="tok123456"):
+            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
+                pass
+
+        assert block.execution_stats.cache_read_token_count == 20
+        assert block.execution_stats.cache_creation_token_count == 8
+
+    @pytest.mark.asyncio
+    async def test_failure_path_persists_accumulated_cost(self):
+        """When all retries are exhausted, accumulated provider_cost is preserved."""
+        import backend.blocks.llm as llm
+
+        block = llm.AIStructuredResponseGeneratorBlock()
+
+        async def mock_llm_call(*args, **kwargs):
+            return llm.LLMResponse(
+                raw_response="",
+                prompt=[],
+                response="not valid json at all",
+                tool_calls=None,
+                prompt_tokens=10,
+                completion_tokens=5,
+                reasoning=None,
+                provider_cost=0.01,
+            )
+
+        block.llm_call = mock_llm_call  # type: ignore
+
+        input_data = llm.AIStructuredResponseGeneratorBlock.Input(
+            prompt="Test prompt",
+            expected_format={"key1": "desc1"},
+            model=llm.DEFAULT_LLM_MODEL,
+            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
+            retry=2,
+        )
+
+        with pytest.raises(RuntimeError):
+            async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
+                pass
+
+        # Both retry attempts each cost $0.01, total $0.02
+        assert block.execution_stats.provider_cost == pytest.approx(0.02)
+
     @pytest.mark.asyncio
     async def test_ai_text_summarizer_multiple_chunks(self):
         """Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
@@ -1111,3 +1288,181 @@
     def test_returns_none_for_negative_cost(self):
         response = self._mk_response({"x-total-cost": "-0.005"})
         assert llm.extract_openrouter_cost(response) is None
+
+
+class TestAnthropicCacheControl:
+    """Verify that llm_call attaches cache_control to the system prompt block
+    and to the last tool definition when calling the Anthropic API."""
+
+    def _make_anthropic_credentials(self) -> llm.APIKeyCredentials:
+        from pydantic import SecretStr
+
+        return llm.APIKeyCredentials(
+            id="test-anthropic-id",
+            provider="anthropic",
+            api_key=SecretStr("mock-anthropic-key"),
+            title="Mock Anthropic key",
+            expires_at=None,
+        )
+
+    @pytest.mark.asyncio
+    async def test_system_prompt_sent_as_block_with_cache_control(self):
+        """The system prompt is wrapped in a structured block with cache_control ephemeral."""
+        mock_resp = MagicMock()
+        mock_resp.content = [MagicMock(type="text", text="hello")]
+        # Cache-token attributes are pinned to real ints; otherwise MagicMock
+        # auto-attributes would reach LLMResponse's int fields and fail validation.
+        mock_resp.usage = MagicMock(
+            input_tokens=5,
+            output_tokens=3,
+            cache_read_input_tokens=0,
+            cache_creation_input_tokens=0,
+        )
+
+        captured_kwargs: dict = {}
+
+        async def fake_create(**kwargs):
+            captured_kwargs.update(kwargs)
+            return mock_resp
+
+        mock_client = MagicMock()
+        mock_client.messages.create = fake_create
+
+        credentials = self._make_anthropic_credentials()
+
+        with patch("anthropic.AsyncAnthropic", return_value=mock_client):
+            await llm.llm_call(
+                credentials=credentials,
+                llm_model=llm.LlmModel.CLAUDE_4_6_SONNET,
+                prompt=[
+                    {"role": "system", "content": "You are an assistant."},
+                    {"role": "user", "content": "Hello"},
+                ],
+                max_tokens=100,
+            )
+
+        system_arg = captured_kwargs.get("system")
+        assert isinstance(system_arg, list), "system should be a list of blocks"
+        assert len(system_arg) == 1
+        block = system_arg[0]
+        assert block["type"] == "text"
+        assert block["text"] == "You are an assistant."
+        assert block.get("cache_control") == {"type": "ephemeral"}
+
+    @pytest.mark.asyncio
+    async def test_last_tool_gets_cache_control(self):
+        """cache_control is placed on the last tool in the Anthropic tools list."""
+        mock_resp = MagicMock()
+        mock_resp.content = [MagicMock(type="text", text="ok")]
+        mock_resp.usage = MagicMock(
+            input_tokens=10,
+            output_tokens=5,
+            cache_read_input_tokens=0,
+            cache_creation_input_tokens=0,
+        )
+
+        captured_kwargs: dict = {}
+
+        async def fake_create(**kwargs):
+            captured_kwargs.update(kwargs)
+            return mock_resp
+
+        mock_client = MagicMock()
+        mock_client.messages.create = fake_create
+
+        credentials = self._make_anthropic_credentials()
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "tool_a",
+                    "description": "First tool",
+                    "parameters": {"type": "object", "properties": {}, "required": []},
+                },
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "tool_b",
+                    "description": "Second tool",
+                    "parameters": {"type": "object", "properties": {}, "required": []},
+                },
+            },
+        ]
+
+        with patch("anthropic.AsyncAnthropic", return_value=mock_client):
+            await llm.llm_call(
+                credentials=credentials,
+                llm_model=llm.LlmModel.CLAUDE_4_6_SONNET,
+                prompt=[
+                    {"role": "system", "content": "System."},
+                    {"role": "user", "content": "Do something"},
+                ],
+                max_tokens=100,
+                tools=tools,
+            )
+
+        an_tools = captured_kwargs.get("tools")
+        assert isinstance(an_tools, list)
+        assert len(an_tools) == 2
+        assert (
+            an_tools[0].get("cache_control") is None
+        ), "Only last tool gets cache_control"
+        assert an_tools[-1].get("cache_control") == {"type": "ephemeral"}
+
+    @pytest.mark.asyncio
+    async def test_no_tools_no_cache_control_on_tools(self):
+        """When there are no tools, the Anthropic call receives anthropic.NOT_GIVEN for tools."""
+        mock_resp = MagicMock()
+        mock_resp.content = [MagicMock(type="text", text="ok")]
+        mock_resp.usage = MagicMock(
+            input_tokens=5,
+            output_tokens=2,
+            cache_read_input_tokens=0,
+            cache_creation_input_tokens=0,
+        )
+
+        captured_kwargs: dict = {}
+
+        async def fake_create(**kwargs):
+            captured_kwargs.update(kwargs)
+            return mock_resp
+
+        mock_client = MagicMock()
+        mock_client.messages.create = fake_create
+
+        credentials = self._make_anthropic_credentials()
+
+        with patch("anthropic.AsyncAnthropic", return_value=mock_client):
+            await llm.llm_call(
+                credentials=credentials,
+                llm_model=llm.LlmModel.CLAUDE_4_6_SONNET,
+                prompt=[
+                    {"role": "system", "content": "System."},
+                    {"role": "user", "content": "Hello"},
+                ],
+                max_tokens=100,
+                tools=None,
+            )
+
+        tools_arg = captured_kwargs.get("tools")
+        assert tools_arg is llm.convert_openai_tool_fmt_to_anthropic(
+            None
+        ), "Empty tools should pass anthropic.NOT_GIVEN sentinel"
+
+    @pytest.mark.asyncio
+    async def test_empty_system_prompt_omits_system_key(self):
+        """When sysprompt is empty, the 'system' key must not be sent to Anthropic.
+
+        Anthropic rejects empty text blocks; the guard in llm_call must ensure
+        the system argument is omitted entirely when no system messages are present.
+        """
+        mock_resp = MagicMock()
+        mock_resp.content = [MagicMock(type="text", text="ok")]
+        mock_resp.usage = MagicMock(
+            input_tokens=3,
+            output_tokens=2,
+            cache_read_input_tokens=0,
+            cache_creation_input_tokens=0,
+        )
+
+        captured_kwargs: dict = {}
+
+        async def fake_create(**kwargs):
+            captured_kwargs.update(kwargs)
+            return mock_resp
+
+        mock_client = MagicMock()
+        mock_client.messages.create = fake_create
+
+        credentials = self._make_anthropic_credentials()
+
+        with patch("anthropic.AsyncAnthropic", return_value=mock_client):
+            await llm.llm_call(
+                credentials=credentials,
+                llm_model=llm.LlmModel.CLAUDE_4_6_SONNET,
+                prompt=[{"role": "user", "content": "Hi"}],
+                max_tokens=50,
+            )
+
+        assert (
+            "system" not in captured_kwargs
+        ), "system must be omitted when sysprompt is empty to avoid Anthropic 400"
diff --git a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py
index e5a7300732..5407972fa1 100644
--- a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py
+++ b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py
@@ -306,6 +306,9 @@ async def test_output_yielding_with_dynamic_fields():
     mock_response.raw_response = {"role": "assistant", "content": "test"}
     mock_response.prompt_tokens = 100
     mock_response.completion_tokens = 50
+    mock_response.cache_read_tokens = 0
+    mock_response.cache_creation_tokens = 0
+    mock_response.provider_cost = None
 
     # Mock the LLM call
     with patch(
diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py
index a8044d80b7..1f1fe42f59 100644
--- a/autogpt_platform/backend/backend/copilot/baseline/service.py
+++ b/autogpt_platform/backend/backend/copilot/baseline/service.py
@@ -27,6 +27,7 @@ from opentelemetry import trace as otel_trace
 
 from backend.copilot.config import CopilotMode
 from backend.copilot.context import get_workspace_manager, set_execution_context
+from backend.copilot.db import update_message_content_by_sequence
 from backend.copilot.graphiti.config import is_enabled_for_user
 from backend.copilot.model import (
     ChatMessage,
@@ -52,7 +53,7 @@ from backend.copilot.response_model import (
     StreamUsage,
 )
 from backend.copilot.service import (
-    _build_system_prompt,
+    _build_cacheable_system_prompt,
     _get_openai_client,
     _update_title_async,
     config,
@@ -69,6 +70,7 @@ from backend.copilot.transcript import (
     validate_transcript,
 )
 from backend.copilot.transcript_builder import TranscriptBuilder
+from backend.data.understanding import format_understanding_for_prompt
 from backend.util.exceptions import NotFoundError
 from backend.util.prompt import (
     compress_context,
@@ -958,35 +960,34 @@
     # Build system prompt only on the first turn to avoid mid-conversation
     # changes from concurrent chats updating business understanding.
     is_first_turn = len(session.messages) <= 1
-    if is_first_turn:
-        prompt_task = _build_system_prompt(user_id, has_conversation_history=False)
+    # Gate context fetch on both first turn AND user message so that assistant-
+    # role calls (e.g. tool-result submissions) on the first turn don't trigger
+    # a needless DB lookup for user understanding.
+ should_inject_user_context = is_first_turn and is_user_message + if should_inject_user_context: + prompt_task = _build_cacheable_system_prompt(user_id) else: - prompt_task = _build_system_prompt(user_id=None, has_conversation_history=True) + prompt_task = _build_cacheable_system_prompt(None) # Run download + prompt build concurrently — both are independent I/O # on the request critical path. if user_id and len(session.messages) > 1: - transcript_covers_prefix, (base_system_prompt, _) = await asyncio.gather( - _load_prior_transcript( - user_id=user_id, - session_id=session_id, - session_msg_count=len(session.messages), - transcript_builder=transcript_builder, - ), - prompt_task, + transcript_covers_prefix, (base_system_prompt, understanding) = ( + await asyncio.gather( + _load_prior_transcript( + user_id=user_id, + session_id=session_id, + session_msg_count=len(session.messages), + transcript_builder=transcript_builder, + ), + prompt_task, + ) ) else: - base_system_prompt, _ = await prompt_task + base_system_prompt, understanding = await prompt_task - # Append user message to transcript. - # Always append when the message is present and is from the user, - # even on duplicate-suppressed retries (is_new_message=False). - # The loaded transcript may be stale (uploaded before the previous - # attempt stored this message), so skipping it would leave the - # transcript without the user turn, creating a malformed - # assistant-after-assistant structure when the LLM reply is added. - if message and is_user_message: - transcript_builder.append_user(content=message) + # Append user message to transcript after context injection below so the + # transcript receives the prefixed message when user context is available. # Generate title for new sessions if is_user_message and not session.title: @@ -1047,6 +1048,48 @@ async def stream_chat_completion_baseline( elif msg.role == "user" and msg.content: openai_messages.append({"role": msg.role, "content": msg.content}) + # Inject user context into the first user message on first turn. + # Done before attachment/URL injection so the context prefix lands at + # the very start of the message content. + # The prefixed content is also stored back into session.messages and the + # transcript so that resumed sessions and the transcript both carry the + # personalisation beyond the first request. + user_message_for_transcript = message + if should_inject_user_context and understanding: + user_ctx = format_understanding_for_prompt(understanding) + prefixed: str | None = None + for msg in openai_messages: + if msg["role"] == "user": + prefixed = ( + f"\n{user_ctx}\n\n\n{msg['content']}" + ) + msg["content"] = prefixed + break + if prefixed is not None: + # Persist the prefixed content so subsequent turns and --resume + # retain the user context. + # The user message was already saved to DB before context injection + # (at ~line 932); update the DB record so the prefixed content + # survives page reload. + for idx, session_msg in enumerate(session.messages): + if session_msg.role == "user": + session_msg.content = prefixed + await update_message_content_by_sequence(session_id, idx, prefixed) + break + user_message_for_transcript = prefixed + else: + logger.warning("[Baseline] No user message found for context injection") + + # Append user message to transcript. + # Always append when the message is present and is from the user, + # even on duplicate-suppressed retries (is_new_message=False). 
+ # The loaded transcript may be stale (uploaded before the previous + # attempt stored this message), so skipping it would leave the + # transcript without the user turn, creating a malformed + # assistant-after-assistant structure when the LLM reply is added. + if message and is_user_message: + transcript_builder.append_user(content=user_message_for_transcript or message) + # --- File attachments (feature parity with SDK path) --- working_dir: str | None = None attachment_hint = "" diff --git a/autogpt_platform/backend/backend/copilot/db.py b/autogpt_platform/backend/backend/copilot/db.py index a1dd93e752..6ab131beed 100644 --- a/autogpt_platform/backend/backend/copilot/db.py +++ b/autogpt_platform/backend/backend/copilot/db.py @@ -498,6 +498,42 @@ async def update_tool_message_content( return False +async def update_message_content_by_sequence( + session_id: str, + sequence: int, + new_content: str, +) -> bool: + """Update the content of a specific message by its sequence number. + + Used to persist content modifications (e.g. user-context prefix injection) + to a message that was already saved to the DB. + + Args: + session_id: The chat session ID. + sequence: The 0-based sequence number of the message to update. + new_content: The new content to set. + + Returns: + True if a message was updated, False otherwise. + """ + try: + result = await PrismaChatMessage.prisma().update_many( + where={"sessionId": session_id, "sequence": sequence}, + data={"content": sanitize_string(new_content)}, + ) + if result == 0: + logger.warning( + f"No message found to update for session {session_id}, sequence {sequence}" + ) + return False + return True + except Exception as e: + logger.error( + f"Failed to update message for session {session_id}, sequence {sequence}: {e}" + ) + return False + + async def set_turn_duration(session_id: str, duration_ms: int) -> None: """Set durationMs on the last assistant message in a session. 
diff --git a/autogpt_platform/backend/backend/copilot/db_test.py b/autogpt_platform/backend/backend/copilot/db_test.py index 27fa788702..e73249669b 100644 --- a/autogpt_platform/backend/backend/copilot/db_test.py +++ b/autogpt_platform/backend/backend/copilot/db_test.py @@ -14,6 +14,7 @@ from backend.copilot.db import ( PaginatedMessages, get_chat_messages_paginated, set_turn_duration, + update_message_content_by_sequence, ) from backend.copilot.model import ChatMessage as CopilotChatMessage from backend.copilot.model import ChatSession, get_chat_session, upsert_chat_session @@ -386,3 +387,53 @@ async def test_set_turn_duration_no_assistant_message(setup_test_user, test_user assert cached is not None # User message should not have durationMs assert cached.messages[0].duration_ms is None + + +# ---------- update_message_content_by_sequence ---------- + + +@pytest.mark.asyncio +async def test_update_message_content_by_sequence_success(): + """Returns True when update_many reports at least one row updated.""" + with patch.object(PrismaChatMessage, "prisma") as mock_prisma: + mock_prisma.return_value.update_many = AsyncMock(return_value=1) + + result = await update_message_content_by_sequence("sess-1", 0, "new content") + + assert result is True + mock_prisma.return_value.update_many.assert_called_once_with( + where={"sessionId": "sess-1", "sequence": 0}, + data={"content": "new content"}, + ) + + +@pytest.mark.asyncio +async def test_update_message_content_by_sequence_not_found(): + """Returns False and logs a warning when no rows are updated.""" + with ( + patch.object(PrismaChatMessage, "prisma") as mock_prisma, + patch("backend.copilot.db.logger") as mock_logger, + ): + mock_prisma.return_value.update_many = AsyncMock(return_value=0) + + result = await update_message_content_by_sequence("sess-1", 99, "content") + + assert result is False + mock_logger.warning.assert_called_once() + + +@pytest.mark.asyncio +async def test_update_message_content_by_sequence_db_error(): + """Returns False and logs an error when the DB raises an exception.""" + with ( + patch.object(PrismaChatMessage, "prisma") as mock_prisma, + patch("backend.copilot.db.logger") as mock_logger, + ): + mock_prisma.return_value.update_many = AsyncMock( + side_effect=RuntimeError("db error") + ) + + result = await update_message_content_by_sequence("sess-1", 0, "content") + + assert result is False + mock_logger.error.assert_called_once() diff --git a/autogpt_platform/backend/backend/copilot/prompt_cache_test.py b/autogpt_platform/backend/backend/copilot/prompt_cache_test.py new file mode 100644 index 0000000000..7bec927cb5 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/prompt_cache_test.py @@ -0,0 +1,146 @@ +"""Unit tests for the cacheable system prompt building logic. 
+ +These tests verify that _build_cacheable_system_prompt: +- Returns the static _CACHEABLE_SYSTEM_PROMPT when no user_id is given +- Returns the static prompt + understanding when user_id is given +- Falls through to _CACHEABLE_SYSTEM_PROMPT when Langfuse is not configured +- Returns the Langfuse-compiled prompt when Langfuse is configured +- Handles DB errors and Langfuse errors gracefully +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +_SVC = "backend.copilot.service" + + +class TestBuildCacheableSystemPrompt: + @pytest.mark.asyncio + async def test_no_user_id_returns_static_prompt(self): + """When user_id is None, no DB lookup happens and the static prompt is returned.""" + with (patch(f"{_SVC}._is_langfuse_configured", return_value=False),): + from backend.copilot.service import ( + _CACHEABLE_SYSTEM_PROMPT, + _build_cacheable_system_prompt, + ) + + prompt, understanding = await _build_cacheable_system_prompt(None) + + assert prompt == _CACHEABLE_SYSTEM_PROMPT + assert understanding is None + + @pytest.mark.asyncio + async def test_with_user_id_fetches_understanding(self): + """When user_id is provided, understanding is fetched and returned alongside prompt.""" + fake_understanding = MagicMock() + mock_db = MagicMock() + mock_db.get_business_understanding = AsyncMock(return_value=fake_understanding) + + with ( + patch(f"{_SVC}._is_langfuse_configured", return_value=False), + patch(f"{_SVC}.understanding_db", return_value=mock_db), + ): + from backend.copilot.service import ( + _CACHEABLE_SYSTEM_PROMPT, + _build_cacheable_system_prompt, + ) + + prompt, understanding = await _build_cacheable_system_prompt("user-123") + + assert prompt == _CACHEABLE_SYSTEM_PROMPT + assert understanding is fake_understanding + mock_db.get_business_understanding.assert_called_once_with("user-123") + + @pytest.mark.asyncio + async def test_db_error_returns_prompt_with_no_understanding(self): + """When the DB raises an exception, understanding is None and prompt is still returned.""" + mock_db = MagicMock() + mock_db.get_business_understanding = AsyncMock( + side_effect=RuntimeError("db down") + ) + + with ( + patch(f"{_SVC}._is_langfuse_configured", return_value=False), + patch(f"{_SVC}.understanding_db", return_value=mock_db), + ): + from backend.copilot.service import ( + _CACHEABLE_SYSTEM_PROMPT, + _build_cacheable_system_prompt, + ) + + prompt, understanding = await _build_cacheable_system_prompt("user-456") + + assert prompt == _CACHEABLE_SYSTEM_PROMPT + assert understanding is None + + @pytest.mark.asyncio + async def test_langfuse_compiled_prompt_returned(self): + """When Langfuse is configured and returns a prompt, the compiled text is returned.""" + fake_understanding = MagicMock() + mock_db = MagicMock() + mock_db.get_business_understanding = AsyncMock(return_value=fake_understanding) + + langfuse_prompt_text = "You are a Langfuse-sourced assistant." 
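+        # Stand-in for the prompt object returned by Langfuse's get_prompt();
+        # the code under test only calls compile() on it.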
+        mock_prompt_obj = MagicMock()
+        mock_prompt_obj.compile.return_value = langfuse_prompt_text
+
+        mock_langfuse = MagicMock()
+        mock_langfuse.get_prompt.return_value = mock_prompt_obj
+
+        with (
+            patch(f"{_SVC}._is_langfuse_configured", return_value=True),
+            patch(f"{_SVC}.understanding_db", return_value=mock_db),
+            patch(f"{_SVC}._get_langfuse", return_value=mock_langfuse),
+            patch(
+                f"{_SVC}.asyncio.to_thread", new=AsyncMock(return_value=mock_prompt_obj)
+            ),
+        ):
+            from backend.copilot.service import _build_cacheable_system_prompt
+
+            prompt, understanding = await _build_cacheable_system_prompt("user-789")
+
+            assert prompt == langfuse_prompt_text
+            assert understanding is fake_understanding
+            mock_prompt_obj.compile.assert_called_once_with(users_information="")
+
+    @pytest.mark.asyncio
+    async def test_langfuse_error_falls_back_to_static_prompt(self):
+        """When Langfuse raises an error, the fallback _CACHEABLE_SYSTEM_PROMPT is used."""
+        mock_db = MagicMock()
+        mock_db.get_business_understanding = AsyncMock(return_value=None)
+
+        with (
+            patch(f"{_SVC}._is_langfuse_configured", return_value=True),
+            patch(f"{_SVC}.understanding_db", return_value=mock_db),
+            patch(
+                f"{_SVC}.asyncio.to_thread",
+                new=AsyncMock(side_effect=RuntimeError("langfuse down")),
+            ),
+        ):
+            from backend.copilot.service import (
+                _CACHEABLE_SYSTEM_PROMPT,
+                _build_cacheable_system_prompt,
+            )
+
+            prompt, understanding = await _build_cacheable_system_prompt("user-000")
+
+            assert prompt == _CACHEABLE_SYSTEM_PROMPT
+            assert understanding is None
+
+
+class TestCacheableSystemPromptContent:
+    """Smoke-test the _CACHEABLE_SYSTEM_PROMPT constant for key structural requirements."""
+
+    def test_cacheable_prompt_has_no_placeholder(self):
+        """The static cacheable prompt must not contain format placeholders."""
+        from backend.copilot.service import _CACHEABLE_SYSTEM_PROMPT
+
+        assert "{users_information}" not in _CACHEABLE_SYSTEM_PROMPT
+        assert "{" not in _CACHEABLE_SYSTEM_PROMPT
+
+    def test_cacheable_prompt_mentions_user_context(self):
+        """The prompt instructs the model to parse <user_context> blocks."""
+        from backend.copilot.service import _CACHEABLE_SYSTEM_PROMPT
+
+        assert "user_context" in _CACHEABLE_SYSTEM_PROMPT
diff --git a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
index 52a1eff5df..fd831214a6 100644
--- a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
+++ b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
@@ -988,7 +988,7 @@ def _make_sdk_patches(
         dict(return_value=MagicMock(__enter__=MagicMock(), __exit__=MagicMock())),
     ),
     (
-        f"{_SVC}._build_system_prompt",
+        f"{_SVC}._build_cacheable_system_prompt",
         dict(new_callable=AsyncMock, return_value=("system prompt", None)),
     ),
     (
diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py
index c2a60a8ba0..23f8041d53 100644
--- a/autogpt_platform/backend/backend/copilot/sdk/service.py
+++ b/autogpt_platform/backend/backend/copilot/sdk/service.py
@@ -48,6 +48,7 @@ from backend.copilot.transcript import (
 )
 from backend.copilot.transcript_builder import TranscriptBuilder
 from backend.data.redis_client import get_redis_async
+from backend.data.understanding import format_understanding_for_prompt
 from backend.executor.cluster_lock import AsyncClusterLock
 from backend.util.exceptions import NotFoundError
 from backend.util.settings import Settings
@@ -61,6 +62,7 @@ from ..constants import (
     is_transient_api_error,
 )
 from ..context import encode_cwd_for_cli
+from ..db import update_message_content_by_sequence
 from ..graphiti.config import is_enabled_for_user
 from ..model import (
     ChatMessage,
@@ -85,7 +87,11 @@ from ..response_model import (
     StreamToolOutputAvailable,
     StreamUsage,
 )
-from ..service import _build_system_prompt, _is_langfuse_configured, _update_title_async
+from ..service import (
+    _build_cacheable_system_prompt,
+    _is_langfuse_configured,
+    _update_title_async,
+)
 from ..token_tracking import persist_and_record_usage
 from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
 from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
@@ -2052,9 +2058,9 @@ async def stream_chat_completion_sdk(
             )
             return None

-    e2b_sandbox, (base_system_prompt, _), dl = await asyncio.gather(
+    e2b_sandbox, (base_system_prompt, understanding), dl = await asyncio.gather(
         _setup_e2b(),
-        _build_system_prompt(user_id, has_conversation_history=has_history),
+        _build_cacheable_system_prompt(user_id if not has_history else None),
         _fetch_transcript(),
     )

@@ -2285,6 +2291,30 @@
             transcript_msg_count,
             session_id,
         )
+    # On the first turn inject user context into the message instead of the
+    # system prompt — the system prompt is now static (same for all users)
+    # so the LLM can cache it across sessions.
+    # current_message is updated so the transcript and session.messages also
+    # store the prefixed content, preserving personalisation across turns and
+    # on --resume.
+    if not has_history and understanding:
+        user_ctx = format_understanding_for_prompt(understanding)
+        prefixed_message = (
+            f"<user_context>\n{user_ctx}\n</user_context>\n\n{current_message}"
+        )
+        current_message = prefixed_message
+        query_message = prefixed_message
+        # Persist the prefixed content so resumed sessions retain the context.
+        # The user message was already saved to DB before context injection;
+        # update the DB record so the prefixed content survives page reload
+        # and --resume (the save at line ~1926 used the un-prefixed content).
+        for idx, session_msg in enumerate(session.messages):
+            if session_msg.role == "user":
+                session_msg.content = prefixed_message
+                await update_message_content_by_sequence(
+                    session_id, idx, prefixed_message
+                )
+                break
     # If files are attached, prepare them: images become vision
     # content blocks in the user message, other files go to sdk_cwd.
     attachments = await _prepare_file_attachments(
diff --git a/autogpt_platform/backend/backend/copilot/service.py b/autogpt_platform/backend/backend/copilot/service.py
index fdd6fe24b6..b80e484735 100644
--- a/autogpt_platform/backend/backend/copilot/service.py
+++ b/autogpt_platform/backend/backend/copilot/service.py
@@ -70,6 +70,21 @@ Your goal is to help users automate tasks by:
 Be concise, proactive, and action-oriented. Bias toward showing working solutions over lengthy explanations."""

+# Static system prompt for token caching — identical for all users.
+# User-specific context is injected into the first user message instead,
+# so the system prompt never changes and can be cached across all sessions.
+_CACHEABLE_SYSTEM_PROMPT = """You are an AI automation assistant helping users build and run automations.
+
+Your goal is to help users automate tasks by:
+- Understanding their needs and business context
+- Building and running working automations
+- Delivering tangible value through action, not just explanation
+
+Be concise, proactive, and action-oriented. Bias toward showing working solutions over lengthy explanations.
+
+When the user provides a <user_context> block in their message, use it to personalise your responses.
+For users you are meeting for the first time with no context provided, greet them warmly and introduce them to the AutoGPT platform."""
+

 # ---------------------------------------------------------------------------
 # Shared helpers (used by SDK service and baseline)
 # ---------------------------------------------------------------------------
@@ -150,6 +165,50 @@ async def _build_system_prompt(
     return compiled, understanding


+async def _build_cacheable_system_prompt(
+    user_id: str | None,
+) -> tuple[str, Any]:
+    """Build a fully static system prompt suitable for LLM token caching.
+
+    Unlike _build_system_prompt, user-specific context is NOT embedded here.
+    Callers must inject the returned understanding into the first user message
+    via format_understanding_for_prompt() so the system prompt stays identical
+    across all users and sessions, enabling cross-session cache hits.
+
+    Returns:
+        Tuple of (static_prompt, understanding_object_or_None)
+    """
+    understanding = None
+    if user_id:
+        try:
+            understanding = await understanding_db().get_business_understanding(user_id)
+        except Exception as e:
+            logger.warning(f"Failed to fetch business understanding: {e}")
+
+    if _is_langfuse_configured():
+        try:
+            label = (
+                None
+                if settings.config.app_env == AppEnvironment.PRODUCTION
+                else "latest"
+            )
+            prompt = await asyncio.to_thread(
+                _get_langfuse().get_prompt,
+                config.langfuse_prompt_name,
+                label=label,
+                cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
+            )
+            # Pass empty string so existing Langfuse templates stay static
+            compiled = prompt.compile(users_information="")
+            return compiled, understanding
+        except Exception as e:
+            logger.warning(
+                f"Failed to fetch cacheable prompt from Langfuse, using default: {e}"
+            )
+
+    return _CACHEABLE_SYSTEM_PROMPT, understanding
+
+
 async def _generate_session_title(
     message: str,
     user_id: str | None = None,
diff --git a/autogpt_platform/backend/backend/copilot/token_tracking.py b/autogpt_platform/backend/backend/copilot/token_tracking.py
index f48749e712..e84b64d449 100644
--- a/autogpt_platform/backend/backend/copilot/token_tracking.py
+++ b/autogpt_platform/backend/backend/copilot/token_tracking.py
@@ -202,6 +202,8 @@ async def persist_and_record_usage(
         cost_microdollars=cost_microdollars,
         input_tokens=prompt_tokens,
         output_tokens=completion_tokens,
+        cache_read_tokens=cache_read_tokens or None,
+        cache_creation_tokens=cache_creation_tokens or None,
         model=model,
         tracking_type=tracking_type,
         tracking_amount=tracking_amount,
diff --git a/autogpt_platform/backend/backend/copilot/tools/ask_question.py b/autogpt_platform/backend/backend/copilot/tools/ask_question.py
index cf0226533e..edd7edf51a 100644
--- a/autogpt_platform/backend/backend/copilot/tools/ask_question.py
+++ b/autogpt_platform/backend/backend/copilot/tools/ask_question.py
@@ -1,5 +1,6 @@
-"""AskQuestionTool - Ask the user a clarifying question before proceeding."""
+"""AskQuestionTool - Ask the user one or more clarifying questions."""

+import logging
 from typing import Any

 from backend.copilot.model import ChatSession
@@ -7,14 +8,16 @@ from backend.copilot.model import ChatSession
 from .base import BaseTool
 from .models import ClarificationNeededResponse, ClarifyingQuestion, ToolResponseBase

+logger = logging.getLogger(__name__)
+

 class AskQuestionTool(BaseTool):
-    """Ask the user a clarifying question and wait for their answer.
+    """Ask the user one or more clarifying questions and wait for answers.
Use this tool when the user's request is ambiguous and you need more - information before proceeding. Call find_block or other discovery tools - first to ground your question in real platform options, then call this - tool with a concrete question listing those options. + information before proceeding. Call find_block or other discovery tools + first to ground your questions in real platform options, then call this + tool with concrete questions listing those options. """ @property @@ -24,9 +27,9 @@ class AskQuestionTool(BaseTool): @property def description(self) -> str: return ( - "Ask the user a clarifying question. Use when the request is " - "ambiguous and you need to confirm intent, choose between options, " - "or gather missing details before proceeding." + "Ask the user one or more clarifying questions. Use when the " + "request is ambiguous and you need to confirm intent, choose " + "between options, or gather missing details before proceeding." ) @property @@ -34,27 +37,34 @@ class AskQuestionTool(BaseTool): return { "type": "object", "properties": { - "question": { - "type": "string", - "description": ( - "The concrete question to ask the user. Should list " - "real options when applicable." - ), - }, - "options": { + "questions": { "type": "array", - "items": {"type": "string"}, + "items": { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The question text.", + }, + "options": { + "type": "array", + "items": {"type": "string"}, + "description": "Options for this question.", + }, + "keyword": { + "type": "string", + "description": "Short label for this question.", + }, + }, + "required": ["question"], + }, "description": ( - "Options for the user to choose from " - "(e.g. ['Email', 'Slack', 'Google Docs'])." + "One or more clarifying questions. Each item has " + "'question' (required), 'options', and 'keyword'." 
), }, - "keyword": { - "type": "string", - "description": "Short label identifying what the question is about.", - }, }, - "required": ["question"], + "required": ["questions"], } @property @@ -67,27 +77,61 @@ class AskQuestionTool(BaseTool): session: ChatSession, **kwargs: Any, ) -> ToolResponseBase: - del user_id # unused; required by BaseTool contract - question_raw = kwargs.get("question") - if not isinstance(question_raw, str) or not question_raw.strip(): - raise ValueError("ask_question requires a non-empty 'question' string") - question = question_raw.strip() - raw_options = kwargs.get("options", []) - if not isinstance(raw_options, list): - raw_options = [] - options: list[str] = [str(o) for o in raw_options if o] - raw_keyword = kwargs.get("keyword", "") - keyword: str = str(raw_keyword) if raw_keyword else "" - session_id = session.session_id if session else None + del user_id + raw_questions = kwargs.get("questions", []) + if not isinstance(raw_questions, list) or not raw_questions: + raise ValueError("ask_question requires a non-empty 'questions' array") + + questions = _parse_questions(raw_questions) + if not questions: + raise ValueError( + "ask_question requires at least one valid question in 'questions'" + ) - example = ", ".join(options) if options else None - clarifying_question = ClarifyingQuestion( - question=question, - keyword=keyword, - example=example, - ) return ClarificationNeededResponse( - message=question, - session_id=session_id, - questions=[clarifying_question], + message="; ".join(q.question for q in questions), + session_id=session.session_id if session else None, + questions=questions, ) + + +def _parse_questions(raw: list[Any]) -> list[ClarifyingQuestion]: + """Parse and validate raw question dicts into ClarifyingQuestion objects.""" + return [ + q for idx, item in enumerate(raw) if (q := _parse_one(item, idx)) is not None + ] + + +def _parse_one(item: Any, idx: int) -> ClarifyingQuestion | None: + """Parse a single question item, returning None for invalid entries.""" + if not isinstance(item, dict): + logger.warning("ask_question: skipping non-dict item at index %d", idx) + return None + + text = item.get("question") + if not isinstance(text, str) or not text.strip(): + logger.warning( + "ask_question: skipping item at index %d with missing/empty question", + idx, + ) + return None + + raw_keyword = item.get("keyword") + keyword = ( + str(raw_keyword).strip() + if raw_keyword is not None and str(raw_keyword).strip() + else f"question-{idx}" + ) + + raw_options = item.get("options") + options = ( + [str(o) for o in raw_options if o is not None and str(o).strip()] + if isinstance(raw_options, list) + else [] + ) + + return ClarifyingQuestion( + question=text.strip(), + keyword=keyword, + example=", ".join(options) if options else None, + ) diff --git a/autogpt_platform/backend/backend/copilot/tools/ask_question_test.py b/autogpt_platform/backend/backend/copilot/tools/ask_question_test.py index 607d50e872..9cc4c58025 100644 --- a/autogpt_platform/backend/backend/copilot/tools/ask_question_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/ask_question_test.py @@ -17,83 +17,235 @@ def session() -> ChatSession: return ChatSession.new(user_id="test-user", dry_run=False) +# ── Happy paths ────────────────────────────────────────────────────── + + @pytest.mark.asyncio -async def test_execute_with_options(tool: AskQuestionTool, session: ChatSession): +async def test_single_question(tool: AskQuestionTool, session: ChatSession): result = await 
tool._execute( user_id=None, session=session, - question="Which channel?", - options=["Email", "Slack", "Google Docs"], - keyword="channel", + questions=[{"question": "Which channel?", "keyword": "channel"}], ) assert isinstance(result, ClarificationNeededResponse) assert result.message == "Which channel?" assert result.session_id == session.session_id assert len(result.questions) == 1 + assert result.questions[0].question == "Which channel?" + assert result.questions[0].keyword == "channel" + +@pytest.mark.asyncio +async def test_single_question_with_options( + tool: AskQuestionTool, session: ChatSession +): + result = await tool._execute( + user_id=None, + session=session, + questions=[ + { + "question": "Which channel?", + "options": ["Email", "Slack", "Google Docs"], + "keyword": "channel", + } + ], + ) + + assert isinstance(result, ClarificationNeededResponse) q = result.questions[0] - assert q.question == "Which channel?" - assert q.keyword == "channel" assert q.example == "Email, Slack, Google Docs" @pytest.mark.asyncio -async def test_execute_without_options(tool: AskQuestionTool, session: ChatSession): +async def test_multiple_questions(tool: AskQuestionTool, session: ChatSession): result = await tool._execute( user_id=None, session=session, - question="What format do you want?", + questions=[ + { + "question": "Which channel?", + "options": ["Email", "Slack"], + "keyword": "channel", + }, + { + "question": "How often?", + "options": ["Daily", "Weekly"], + "keyword": "frequency", + }, + {"question": "Any extra notes?"}, + ], + ) + + assert isinstance(result, ClarificationNeededResponse) + assert len(result.questions) == 3 + assert result.message == "Which channel?; How often?; Any extra notes?" + + assert result.questions[0].keyword == "channel" + assert result.questions[0].example == "Email, Slack" + assert result.questions[1].keyword == "frequency" + assert result.questions[2].keyword == "question-2" + assert result.questions[2].example is None + + +# ── Keyword handling ───────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_missing_keyword_gets_index_fallback( + tool: AskQuestionTool, session: ChatSession +): + result = await tool._execute( + user_id=None, + session=session, + questions=[{"question": "First?"}, {"question": "Second?"}], + ) + + assert isinstance(result, ClarificationNeededResponse) + assert result.questions[0].keyword == "question-0" + assert result.questions[1].keyword == "question-1" + + +@pytest.mark.asyncio +async def test_null_keyword_gets_index_fallback( + tool: AskQuestionTool, session: ChatSession +): + result = await tool._execute( + user_id=None, + session=session, + questions=[{"question": "First?", "keyword": None}], + ) + + assert isinstance(result, ClarificationNeededResponse) + assert result.questions[0].keyword == "question-0" + + +@pytest.mark.asyncio +async def test_duplicate_keywords_preserved( + tool: AskQuestionTool, session: ChatSession +): + """Frontend normalizeClarifyingQuestions() handles dedup.""" + result = await tool._execute( + user_id=None, + session=session, + questions=[ + {"question": "First?", "keyword": "same"}, + {"question": "Second?", "keyword": "same"}, + ], + ) + + assert isinstance(result, ClarificationNeededResponse) + assert result.questions[0].keyword == "same" + assert result.questions[1].keyword == "same" + + +# ── Options filtering ──────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_options_preserves_falsy_strings( + tool: AskQuestionTool, 
session: ChatSession +): + result = await tool._execute( + user_id=None, + session=session, + questions=[{"question": "Pick", "options": ["0", "1", "2"]}], + ) + + assert isinstance(result, ClarificationNeededResponse) + assert result.questions[0].example == "0, 1, 2" + + +@pytest.mark.asyncio +async def test_options_filters_none_and_empty( + tool: AskQuestionTool, session: ChatSession +): + result = await tool._execute( + user_id=None, + session=session, + questions=[{"question": "Pick", "options": ["Email", "", "Slack", None]}], + ) + + assert isinstance(result, ClarificationNeededResponse) + assert result.questions[0].example == "Email, Slack" + + +@pytest.mark.asyncio +async def test_no_options_gives_none_example( + tool: AskQuestionTool, session: ChatSession +): + result = await tool._execute( + user_id=None, + session=session, + questions=[{"question": "Thoughts?"}], + ) + + assert isinstance(result, ClarificationNeededResponse) + assert result.questions[0].example is None + + +# ── Invalid input handling ─────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_skips_non_dict_items(tool: AskQuestionTool, session: ChatSession): + result = await tool._execute( + user_id=None, + session=session, + questions=["not-a-dict", {"question": "Valid?", "keyword": "v"}], ) assert isinstance(result, ClarificationNeededResponse) - assert result.message == "What format do you want?" assert len(result.questions) == 1 - - q = result.questions[0] - assert q.question == "What format do you want?" - assert q.keyword == "" - assert q.example is None + assert result.questions[0].question == "Valid?" @pytest.mark.asyncio -async def test_execute_with_keyword_only(tool: AskQuestionTool, session: ChatSession): +async def test_skips_empty_question_items(tool: AskQuestionTool, session: ChatSession): result = await tool._execute( user_id=None, session=session, - question="How often should it run?", - keyword="trigger", + questions=[ + {"keyword": "missing-question"}, + {"question": ""}, + {"question": " Valid ", "keyword": "v"}, + ], ) assert isinstance(result, ClarificationNeededResponse) - q = result.questions[0] - assert q.keyword == "trigger" - assert q.example is None + assert len(result.questions) == 1 + assert result.questions[0].question == "Valid" @pytest.mark.asyncio -async def test_execute_rejects_empty_question( - tool: AskQuestionTool, session: ChatSession -): - with pytest.raises(ValueError, match="non-empty"): - await tool._execute(user_id=None, session=session, question="") - - with pytest.raises(ValueError, match="non-empty"): - await tool._execute(user_id=None, session=session, question=" ") +async def test_rejects_all_invalid_items(tool: AskQuestionTool, session: ChatSession): + with pytest.raises(ValueError, match="at least one valid question"): + await tool._execute( + user_id=None, + session=session, + questions=[{"keyword": "no-q"}, "bad"], + ) @pytest.mark.asyncio -async def test_execute_coerces_invalid_options( +async def test_rejects_empty_questions_array( tool: AskQuestionTool, session: ChatSession ): - """LLM may send options as a string instead of a list; should not crash.""" - result = await tool._execute( - user_id=None, - session=session, - question="Pick one", - options="not-a-list", # type: ignore[arg-type] - ) + with pytest.raises(ValueError, match="non-empty"): + await tool._execute(user_id=None, session=session, questions=[]) - assert isinstance(result, ClarificationNeededResponse) - q = result.questions[0] - assert q.example is None + 
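+# The two cases below exercise the isinstance/emptiness guard at the top of
+# _execute ('questions' missing or not a list), before _parse_questions runs.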
+@pytest.mark.asyncio +async def test_rejects_missing_questions(tool: AskQuestionTool, session: ChatSession): + with pytest.raises(ValueError, match="non-empty"): + await tool._execute(user_id=None, session=session) + + +@pytest.mark.asyncio +async def test_rejects_non_list_questions(tool: AskQuestionTool, session: ChatSession): + with pytest.raises(ValueError, match="non-empty"): + await tool._execute( + user_id=None, + session=session, + questions="not-a-list", + ) diff --git a/autogpt_platform/backend/backend/data/credit.py b/autogpt_platform/backend/backend/data/credit.py index 04f91d8d61..0959c15d34 100644 --- a/autogpt_platform/backend/backend/data/credit.py +++ b/autogpt_platform/backend/backend/data/credit.py @@ -10,6 +10,7 @@ from prisma.enums import ( CreditTransactionType, NotificationType, OnboardingStep, + SubscriptionTier, ) from prisma.errors import UniqueViolationError from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance @@ -31,7 +32,7 @@ from backend.data.notifications import NotificationEventModel, RefundRequestData from backend.data.user import get_user_by_id, get_user_email_by_id from backend.notifications.notifications import queue_notification_async from backend.util.exceptions import InsufficientBalanceError -from backend.util.feature_flag import Flag, is_feature_enabled +from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled from backend.util.json import SafeJson, dumps from backend.util.models import Pagination from backend.util.retry import func_retry @@ -1144,10 +1145,12 @@ class BetaUserCredit(UserCredit): if (snapshot_time.year, snapshot_time.month) == (cur_time.year, cur_time.month): return balance + target = self.num_user_credits_refill + try: balance, _ = await self._add_transaction( user_id=user_id, - amount=max(self.num_user_credits_refill - balance, 0), + amount=max(target - balance, 0), transaction_type=CreditTransactionType.GRANT, transaction_key=f"MONTHLY-CREDIT-TOP-UP-{cur_time}", metadata=SafeJson({"reason": "Monthly credit refill"}), @@ -1250,6 +1253,33 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig): where={"id": user_id}, data={"topUpConfig": SafeJson(config.model_dump())}, ) + get_user_by_id.cache_delete(user_id) + + +async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None: + """Set the user's subscription tier (used by webhook and admin flows).""" + await User.prisma().update( + where={"id": user_id}, + data={"subscriptionTier": tier}, + ) + get_user_by_id.cache_delete(user_id) + + +async def cancel_stripe_subscription(user_id: str) -> None: + """Cancel all active Stripe subscriptions for a user (called on downgrade to FREE).""" + customer_id = await get_stripe_customer_id(user_id) + subscriptions = stripe.Subscription.list( + customer=customer_id, status="active", limit=10 + ) + for sub in subscriptions.auto_paging_iter(): + try: + stripe.Subscription.cancel(sub["id"]) + except stripe.StripeError: + logger.warning( + "cancel_stripe_subscription: failed to cancel sub %s for user %s", + sub["id"], + user_id, + ) async def get_auto_top_up(user_id: str) -> AutoTopUpConfig: @@ -1261,6 +1291,78 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig: return AutoTopUpConfig.model_validate(user.top_up_config) +async def get_subscription_price_id(tier: SubscriptionTier) -> str | None: + """Return Stripe Price ID for a tier from LaunchDarkly. 
None = not configured.""" + flag_map = { + SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO, + SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS, + } + flag = flag_map.get(tier) + if flag is None: + return None + price_id = await get_feature_flag_value(flag.value, user_id="", default="") + return price_id if isinstance(price_id, str) and price_id else None + + +async def create_subscription_checkout( + user_id: str, + tier: SubscriptionTier, + success_url: str, + cancel_url: str, +) -> str: + """Create a Stripe Checkout Session for a subscription. Returns the redirect URL.""" + price_id = await get_subscription_price_id(tier) + if not price_id: + raise ValueError(f"Subscription not available for tier {tier.value}") + customer_id = await get_stripe_customer_id(user_id) + session = stripe.checkout.Session.create( + customer=customer_id, + mode="subscription", + line_items=[{"price": price_id, "quantity": 1}], + success_url=success_url, + cancel_url=cancel_url, + subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}}, + ) + return session.url or "" + + +async def sync_subscription_from_stripe(stripe_subscription: dict) -> None: + """Update User.subscriptionTier from a Stripe subscription object.""" + customer_id = stripe_subscription["customer"] + user = await User.prisma().find_first(where={"stripeCustomerId": customer_id}) + if not user: + logger.warning( + "sync_subscription_from_stripe: no user for customer %s", customer_id + ) + return + status = stripe_subscription.get("status", "") + if status in ("active", "trialing"): + price_id = "" + items = stripe_subscription.get("items", {}).get("data", []) + if items: + price_id = items[0].get("price", {}).get("id", "") + pro_price = await get_subscription_price_id(SubscriptionTier.PRO) + biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS) + if price_id and pro_price and price_id == pro_price: + tier = SubscriptionTier.PRO + elif price_id and biz_price and price_id == biz_price: + tier = SubscriptionTier.BUSINESS + else: + # Unknown or unconfigured price ID — preserve the user's current tier + # rather than defaulting to FREE. This prevents accidental downgrades + # during a price migration or when LD flags are not yet configured. + logger.warning( + "sync_subscription_from_stripe: unknown price %s for customer %s," + " preserving current tier", + price_id, + customer_id, + ) + return + else: + tier = SubscriptionTier.FREE + await set_subscription_tier(user.id, tier) + + async def admin_get_user_history( page: int = 1, page_size: int = 20, diff --git a/autogpt_platform/backend/backend/data/credit_subscription_test.py b/autogpt_platform/backend/backend/data/credit_subscription_test.py new file mode 100644 index 0000000000..34ba19b83c --- /dev/null +++ b/autogpt_platform/backend/backend/data/credit_subscription_test.py @@ -0,0 +1,360 @@ +""" +Tests for Stripe-based subscription tier billing. 
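+
+Covers set_subscription_tier, sync_subscription_from_stripe,
+cancel_stripe_subscription, create_subscription_checkout and
+get_subscription_price_id, with all Stripe and Prisma calls mocked.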
+""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from prisma.enums import SubscriptionTier +from prisma.models import User + +from backend.data.credit import ( + cancel_stripe_subscription, + create_subscription_checkout, + set_subscription_tier, + sync_subscription_from_stripe, +) + + +@pytest.mark.asyncio +async def test_set_subscription_tier_updates_db(): + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(update=AsyncMock()), + ) as mock_prisma, + patch("backend.data.credit.get_user_by_id"), + ): + await set_subscription_tier("user-1", SubscriptionTier.PRO) + mock_prisma.return_value.update.assert_awaited_once_with( + where={"id": "user-1"}, + data={"subscriptionTier": SubscriptionTier.PRO}, + ) + + +@pytest.mark.asyncio +async def test_set_subscription_tier_downgrade(): + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(update=AsyncMock()), + ), + patch("backend.data.credit.get_user_by_id"), + ): + # Downgrade to FREE should not raise + await set_subscription_tier("user-1", SubscriptionTier.FREE) + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_active(): + mock_user = MagicMock(spec=User) + mock_user.id = "user-1" + stripe_sub = { + "customer": "cus_123", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.PRO) + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_cancelled(): + mock_user = MagicMock(spec=User) + mock_user.id = "user-1" + stripe_sub = { + "customer": "cus_123", + "status": "canceled", + "items": {"data": []}, + } + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.FREE) + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_unknown_customer(): + stripe_sub = { + "customer": "cus_unknown", + "status": "active", + "items": {"data": []}, + } + with patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=None)), + ): + # Should not raise even if user not found + await sync_subscription_from_stripe(stripe_sub) + + +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_cancels_active(): + mock_sub = {"id": "sub_abc123"} + mock_subscriptions = MagicMock() + mock_subscriptions.auto_paging_iter.return_value = iter([mock_sub]) + + with ( + patch( + "backend.data.credit.get_stripe_customer_id", + new_callable=AsyncMock, + return_value="cus_123", + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=mock_subscriptions, + ), + patch("backend.data.credit.stripe.Subscription.cancel") as 
mock_cancel,
+    ):
+        await cancel_stripe_subscription("user-1")
+        mock_cancel.assert_called_once_with("sub_abc123")
+
+
+@pytest.mark.asyncio
+async def test_cancel_stripe_subscription_no_active():
+    mock_subscriptions = MagicMock()
+    mock_subscriptions.auto_paging_iter.return_value = iter([])
+
+    with (
+        patch(
+            "backend.data.credit.get_stripe_customer_id",
+            new_callable=AsyncMock,
+            return_value="cus_123",
+        ),
+        patch(
+            "backend.data.credit.stripe.Subscription.list",
+            return_value=mock_subscriptions,
+        ),
+        patch("backend.data.credit.stripe.Subscription.cancel") as mock_cancel,
+    ):
+        await cancel_stripe_subscription("user-1")
+        mock_cancel.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_create_subscription_checkout_returns_url():
+    mock_session = MagicMock()
+    mock_session.url = "https://checkout.stripe.com/pay/cs_test_abc123"
+    with (
+        patch(
+            "backend.data.credit.get_subscription_price_id",
+            new_callable=AsyncMock,
+            return_value="price_pro_monthly",
+        ),
+        patch(
+            "backend.data.credit.get_stripe_customer_id",
+            new_callable=AsyncMock,
+            return_value="cus_123",
+        ),
+        patch("stripe.checkout.Session.create", return_value=mock_session),
+    ):
+        url = await create_subscription_checkout(
+            user_id="user-1",
+            tier=SubscriptionTier.PRO,
+            success_url="https://app.example.com/success",
+            cancel_url="https://app.example.com/cancel",
+        )
+        assert url == "https://checkout.stripe.com/pay/cs_test_abc123"
+
+
+@pytest.mark.asyncio
+async def test_create_subscription_checkout_no_price_raises():
+    with patch(
+        "backend.data.credit.get_subscription_price_id",
+        new_callable=AsyncMock,
+        return_value=None,
+    ):
+        with pytest.raises(ValueError, match="not available"):
+            await create_subscription_checkout(
+                user_id="user-1",
+                tier=SubscriptionTier.PRO,
+                success_url="https://app.example.com/success",
+                cancel_url="https://app.example.com/cancel",
+            )
+
+
+@pytest.mark.asyncio
+async def test_sync_subscription_from_stripe_unknown_price_preserves_tier():
+    """An unknown price_id preserves the current tier (early return, no DB write)."""
+    mock_user = MagicMock(spec=User)
+    mock_user.id = "user-1"
+    stripe_sub = {
+        "customer": "cus_123",
+        "status": "active",
+        "items": {"data": [{"price": {"id": "price_unknown"}}]},
+    }
+
+    async def mock_price_id(tier: SubscriptionTier) -> str | None:
+        return "price_pro_monthly" if tier == SubscriptionTier.PRO else None
+
+    with (
+        patch(
+            "backend.data.credit.User.prisma",
+            return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
+        ),
+        patch(
+            "backend.data.credit.get_subscription_price_id",
+            side_effect=mock_price_id,
+        ),
+        patch(
+            "backend.data.credit.set_subscription_tier", new_callable=AsyncMock
+        ) as mock_set,
+    ):
+        await sync_subscription_from_stripe(stripe_sub)
+        # Unknown price → preserve current tier (early return, no DB write)
+        mock_set.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_sync_subscription_from_stripe_none_ld_price_preserves_tier():
+    """When LD returns None for price IDs, the current tier is preserved."""
+    mock_user = MagicMock(spec=User)
+    mock_user.id = "user-1"
+    stripe_sub = {
+        "customer": "cus_123",
+        "status": "active",
+        "items": {"data": [{"price": {"id": "price_pro_monthly"}}]},
+    }
+
+    with (
+        patch(
+            "backend.data.credit.User.prisma",
+            return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
+        ),
+        patch(
+            "backend.data.credit.get_subscription_price_id",
+            new_callable=AsyncMock,
+            return_value=None,  # LD flags unconfigured
+        ),
+        patch(
"backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + # None from LD → comparison guards prevent match → preserve current tier + mock_set.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_business_tier(): + """BUSINESS price_id should map to BUSINESS tier.""" + mock_user = MagicMock(spec=User) + mock_user.id = "user-1" + stripe_sub = { + "customer": "cus_123", + "status": "active", + "items": {"data": [{"price": {"id": "price_biz_monthly"}}]}, + } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.BUSINESS) + + +@pytest.mark.asyncio +async def test_get_subscription_price_id_pro(): + from backend.data.credit import get_subscription_price_id + + with patch( + "backend.data.credit.get_feature_flag_value", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ): + price_id = await get_subscription_price_id(SubscriptionTier.PRO) + assert price_id == "price_pro_monthly" + + +@pytest.mark.asyncio +async def test_get_subscription_price_id_free_returns_none(): + from backend.data.credit import get_subscription_price_id + + price_id = await get_subscription_price_id(SubscriptionTier.FREE) + assert price_id is None + + +@pytest.mark.asyncio +async def test_get_subscription_price_id_empty_flag_returns_none(): + from backend.data.credit import get_subscription_price_id + + with patch( + "backend.data.credit.get_feature_flag_value", + new_callable=AsyncMock, + return_value="", # LD flag not set + ): + price_id = await get_subscription_price_id(SubscriptionTier.BUSINESS) + assert price_id is None + + +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_handles_stripe_error(): + """Stripe errors during cancellation should be logged, not raised.""" + import stripe as stripe_mod + + mock_sub = {"id": "sub_abc123"} + mock_subscriptions = MagicMock() + mock_subscriptions.auto_paging_iter.return_value = iter([mock_sub]) + + with ( + patch( + "backend.data.credit.get_stripe_customer_id", + new_callable=AsyncMock, + return_value="cus_123", + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=mock_subscriptions, + ), + patch( + "backend.data.credit.stripe.Subscription.cancel", + side_effect=stripe_mod.StripeError("network error"), + ), + ): + # Should not raise — errors are logged as warnings + await cancel_stripe_subscription("user-1") diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index add0c6b5cf..f0393133e6 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -71,6 +71,9 @@ class User(BaseModel): top_up_config: Optional["AutoTopUpConfig"] = Field( None, description="Top up configuration" ) + subscription_tier: SubscriptionTier = Field( + default=SubscriptionTier.FREE, description="User subscription tier" + ) # 
Notification preferences max_emails_per_day: int = Field(default=3, description="Maximum emails per day") @@ -103,9 +106,6 @@ class User(BaseModel): description="User timezone (IANA timezone identifier or 'not-set')", ) - # Subscription / rate-limit tier - subscription_tier: SubscriptionTier | None = Field(default=None) - @classmethod def from_db(cls, prisma_user: "PrismaUser") -> "User": """Convert a database User object to application User model.""" @@ -148,6 +148,7 @@ class User(BaseModel): integrations=prisma_user.integrations or "", stripe_customer_id=prisma_user.stripeCustomerId, top_up_config=top_up_config, + subscription_tier=prisma_user.subscriptionTier or SubscriptionTier.FREE, max_emails_per_day=prisma_user.maxEmailsPerDay or 3, notify_on_agent_run=prisma_user.notifyOnAgentRun or True, notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True, @@ -160,7 +161,6 @@ class User(BaseModel): notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True, notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True, timezone=prisma_user.timezone or USER_TIMEZONE_NOT_SET, - subscription_tier=prisma_user.subscriptionTier, ) @@ -850,6 +850,8 @@ class NodeExecutionStats(BaseModel): llm_retry_count: int = 0 input_token_count: int = 0 output_token_count: int = 0 + cache_read_token_count: int = 0 + cache_creation_token_count: int = 0 extra_cost: int = 0 extra_steps: int = 0 provider_cost: float | None = None diff --git a/autogpt_platform/backend/backend/data/platform_cost.py b/autogpt_platform/backend/backend/data/platform_cost.py index 6865967627..17915e115c 100644 --- a/autogpt_platform/backend/backend/data/platform_cost.py +++ b/autogpt_platform/backend/backend/data/platform_cost.py @@ -4,10 +4,10 @@ from datetime import datetime, timedelta, timezone from typing import Any from prisma.models import PlatformCostLog as PrismaLog -from prisma.types import PlatformCostLogCreateInput +from prisma.models import User as PrismaUser +from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput from pydantic import BaseModel -from backend.data.db import query_raw_with_schema from backend.util.cache import cached from backend.util.json import SafeJson @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) MICRODOLLARS_PER_USD = 1_000_000 -# Dashboard query limits — keep in sync with the SQL queries below +# Dashboard query limits MAX_PROVIDER_ROWS = 500 MAX_USER_ROWS = 100 @@ -44,6 +44,8 @@ class PlatformCostEntry(BaseModel): cost_microdollars: int | None = None input_tokens: int | None = None output_tokens: int | None = None + cache_read_tokens: int | None = None + cache_creation_tokens: int | None = None data_size: int | None = None duration: float | None = None model: str | None = None @@ -69,6 +71,8 @@ async def log_platform_cost(entry: PlatformCostEntry) -> None: costMicrodollars=entry.cost_microdollars, inputTokens=entry.input_tokens, outputTokens=entry.output_tokens, + cacheReadTokens=entry.cache_read_tokens, + cacheCreationTokens=entry.cache_creation_tokens, dataSize=entry.data_size, duration=entry.duration, model=entry.model, @@ -118,9 +122,12 @@ def _mask_email(email: str | None) -> str | None: class ProviderCostSummary(BaseModel): provider: str tracking_type: str | None = None + model: str | None = None total_cost_microdollars: int total_input_tokens: int total_output_tokens: int + total_cache_read_tokens: int = 0 + total_cache_creation_tokens: int = 0 total_duration_seconds: float = 0.0 total_tracking_amount: float = 0.0 request_count: int @@ -150,6 
+157,8 @@ class CostLogRow(BaseModel): output_tokens: int | None = None duration: float | None = None model: str | None = None + cache_read_tokens: int | None = None + cache_creation_tokens: int | None = None class PlatformCostDashboard(BaseModel): @@ -160,38 +169,61 @@ class PlatformCostDashboard(BaseModel): total_users: int -def _build_where( +def _si(row: dict, field: str) -> int: + """Extract an integer from a Prisma group_by _sum dict. + + Prisma Python serialises BigInt/Int aggregate sums as strings; coerce to int. + """ + return int((row.get("_sum") or {}).get(field) or 0) + + +def _sf(row: dict, field: str) -> float: + """Extract a float from a Prisma group_by _sum dict.""" + return float((row.get("_sum") or {}).get(field) or 0.0) + + +def _ca(row: dict) -> int: + """Extract _count._all from a Prisma group_by row.""" + c = row.get("_count") or {} + return int(c.get("_all") or 0) if isinstance(c, dict) else int(c or 0) + + +def _build_prisma_where( start: datetime | None, end: datetime | None, provider: str | None, user_id: str | None, - table_alias: str = "", -) -> tuple[str, list[Any]]: - prefix = f"{table_alias}." if table_alias else "" - clauses: list[str] = [] - params: list[Any] = [] - idx = 1 + model: str | None = None, + block_name: str | None = None, + tracking_type: str | None = None, +) -> PlatformCostLogWhereInput: + """Build a Prisma WhereInput for PlatformCostLog filters.""" + where: PlatformCostLogWhereInput = {} + + if start and end: + where["createdAt"] = {"gte": start, "lte": end} + elif start: + where["createdAt"] = {"gte": start} + elif end: + where["createdAt"] = {"lte": end} - if start: - clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz') - params.append(start) - idx += 1 - if end: - clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz') - params.append(end) - idx += 1 if provider: - # Provider names are normalized to lowercase at write time so a plain - # equality check is sufficient and the (provider, createdAt) index is used. - clauses.append(f'{prefix}"provider" = ${idx}') - params.append(provider.lower()) - idx += 1 - if user_id: - clauses.append(f'{prefix}"userId" = ${idx}') - params.append(user_id) - idx += 1 + where["provider"] = provider.lower() - return (" AND ".join(clauses) if clauses else "TRUE", params) + if user_id: + where["userId"] = user_id + + if model: + where["model"] = model + + if block_name: + # Case-insensitive match — mirrors the original LOWER() SQL filter. + where["blockName"] = {"equals": block_name, "mode": "insensitive"} + + if tracking_type: + where["trackingType"] = tracking_type + + return where @cached(ttl_seconds=30) @@ -200,6 +232,9 @@ async def get_platform_cost_dashboard( end: datetime | None = None, provider: str | None = None, user_id: str | None = None, + model: str | None = None, + block_name: str | None = None, + tracking_type: str | None = None, ) -> PlatformCostDashboard: """Aggregate platform cost logs for the admin dashboard. 
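For context, the _si/_sf/_ca helpers above assume roughly the following group_by row shape from prisma-client-py (values illustrative; the string-typed sums are the quirk the _si docstring calls out):

    # Illustrative group_by row; aggregate sums may arrive as strings.
    row = {
        "provider": "anthropic",
        "trackingType": "tokens",
        "model": "claude-3-5-sonnet-20241022",
        "_sum": {"costMicrodollars": "123456", "inputTokens": "1000"},
        "_count": {"_all": 42},
    }
    # _si-style coercion: string-typed and missing sums both normalise to int.
    assert int((row.get("_sum") or {}).get("costMicrodollars") or 0) == 123456
    assert int((row.get("_sum") or {}).get("cacheReadTokens") or 0) == 0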
@@ -214,86 +249,107 @@ async def get_platform_cost_dashboard( """ if start is None: start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS) - where_p, params_p = _build_where(start, end, provider, user_id, "p") - by_provider_rows, by_user_rows, total_user_rows = await asyncio.gather( - query_raw_with_schema( - f""" - SELECT - p."provider", - p."trackingType" AS tracking_type, - COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost, - COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens, - COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens, - COALESCE(SUM(p."duration"), 0)::float AS total_duration, - COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount, - COUNT(*)::bigint AS request_count - FROM {{schema_prefix}}"PlatformCostLog" p - WHERE {where_p} - GROUP BY p."provider", p."trackingType" - ORDER BY total_cost DESC - LIMIT {MAX_PROVIDER_ROWS} - """, - *params_p, - ), - query_raw_with_schema( - f""" - SELECT - p."userId" AS user_id, - u."email", - COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost, - COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens, - COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens, - COUNT(*)::bigint AS request_count - FROM {{schema_prefix}}"PlatformCostLog" p - LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId" - WHERE {where_p} - GROUP BY p."userId", u."email" - ORDER BY total_cost DESC - LIMIT {MAX_USER_ROWS} - """, - *params_p, - ), - query_raw_with_schema( - f""" - SELECT COUNT(DISTINCT p."userId")::bigint AS cnt - FROM {{schema_prefix}}"PlatformCostLog" p - WHERE {where_p} - """, - *params_p, - ), + where = _build_prisma_where( + start, end, provider, user_id, model, block_name, tracking_type ) - # Use the exact COUNT(DISTINCT userId) so total_users is not capped at - # MAX_USER_ROWS (which would silently report 100 for >100 active users). - total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0 - total_cost = sum(r["total_cost"] for r in by_provider_rows) - total_requests = sum(r["request_count"] for r in by_provider_rows) + sum_fields = { + "costMicrodollars": True, + "inputTokens": True, + "outputTokens": True, + "cacheReadTokens": True, + "cacheCreationTokens": True, + "duration": True, + "trackingAmount": True, + } + + # Run all four aggregation queries in parallel. + by_provider_groups, by_user_groups, total_user_groups, total_agg_groups = ( + await asyncio.gather( + # (provider, trackingType, model) aggregation — no ORDER BY in ORM; + # sort by total cost descending in Python after fetch. + PrismaLog.prisma().group_by( + by=["provider", "trackingType", "model"], + where=where, + sum=sum_fields, + count=True, + ), + # userId aggregation — emails fetched separately below. + PrismaLog.prisma().group_by( + by=["userId"], + where=where, + sum=sum_fields, + count=True, + ), + # Distinct user count: group by userId, count groups. + PrismaLog.prisma().group_by( + by=["userId"], + where=where, + count=True, + ), + # Total aggregate: group by provider (no limit) to sum across all + # matching rows. Summed in Python to get grand totals. + PrismaLog.prisma().group_by( + by=["provider"], + where=where, + sum={"costMicrodollars": True}, + count=True, + ), + ) + ) + + # Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS. 
+ by_provider_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True) + by_provider_groups = by_provider_groups[:MAX_PROVIDER_ROWS] + + # Sort by_user by total cost descending and cap at MAX_USER_ROWS. + by_user_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True) + by_user_groups = by_user_groups[:MAX_USER_ROWS] + + # Batch-fetch emails for the users in by_user. + user_ids = [r["userId"] for r in by_user_groups if r.get("userId") is not None] + email_by_user_id: dict[str, str | None] = {} + if user_ids: + users = await PrismaUser.prisma().find_many( + where={"id": {"in": user_ids}}, + ) + email_by_user_id = {u.id: u.email for u in users} + + # Total distinct users — exclude the NULL-userId group (deleted users). + total_users = len([g for g in total_user_groups if g.get("userId") is not None]) + + # Grand totals — sum across all provider groups (no LIMIT applied above). + total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups) + total_requests = sum(_ca(r) for r in total_agg_groups) return PlatformCostDashboard( by_provider=[ ProviderCostSummary( provider=r["provider"], - tracking_type=r.get("tracking_type"), - total_cost_microdollars=r["total_cost"], - total_input_tokens=r["total_input_tokens"], - total_output_tokens=r["total_output_tokens"], - total_duration_seconds=r.get("total_duration", 0.0), - total_tracking_amount=r.get("total_tracking_amount", 0.0), - request_count=r["request_count"], + tracking_type=r.get("trackingType"), + model=r.get("model"), + total_cost_microdollars=_si(r, "costMicrodollars"), + total_input_tokens=_si(r, "inputTokens"), + total_output_tokens=_si(r, "outputTokens"), + total_cache_read_tokens=_si(r, "cacheReadTokens"), + total_cache_creation_tokens=_si(r, "cacheCreationTokens"), + total_duration_seconds=_sf(r, "duration"), + total_tracking_amount=_sf(r, "trackingAmount"), + request_count=_ca(r), ) - for r in by_provider_rows + for r in by_provider_groups ], by_user=[ UserCostSummary( - user_id=r.get("user_id"), - email=_mask_email(r.get("email")), - total_cost_microdollars=r["total_cost"], - total_input_tokens=r["total_input_tokens"], - total_output_tokens=r["total_output_tokens"], - request_count=r["request_count"], + user_id=r.get("userId"), + email=_mask_email(email_by_user_id.get(r.get("userId") or "")), + total_cost_microdollars=_si(r, "costMicrodollars"), + total_input_tokens=_si(r, "inputTokens"), + total_output_tokens=_si(r, "outputTokens"), + request_count=_ca(r), ) - for r in by_user_rows + for r in by_user_groups ], total_cost_microdollars=total_cost, total_requests=total_requests, @@ -308,71 +364,163 @@ async def get_platform_cost_logs( user_id: str | None = None, page: int = 1, page_size: int = 50, + model: str | None = None, + block_name: str | None = None, + tracking_type: str | None = None, ) -> tuple[list[CostLogRow], int]: if start is None: start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS) - where_sql, params = _build_where(start, end, provider, user_id, "p") + where = _build_prisma_where( + start, end, provider, user_id, model, block_name, tracking_type + ) offset = (page - 1) * page_size - limit_idx = len(params) + 1 - offset_idx = len(params) + 2 - count_rows, rows = await asyncio.gather( - query_raw_with_schema( - f""" - SELECT COUNT(*)::bigint AS cnt - FROM {{schema_prefix}}"PlatformCostLog" p - WHERE {where_sql} - """, - *params, - ), - query_raw_with_schema( - f""" - SELECT - p."id", - p."createdAt" AS created_at, - p."userId" AS user_id, - u."email", - 
p."graphExecId" AS graph_exec_id, - p."nodeExecId" AS node_exec_id, - p."blockName" AS block_name, - p."provider", - p."trackingType" AS tracking_type, - p."costMicrodollars" AS cost_microdollars, - p."inputTokens" AS input_tokens, - p."outputTokens" AS output_tokens, - p."duration", - p."model" - FROM {{schema_prefix}}"PlatformCostLog" p - LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId" - WHERE {where_sql} - ORDER BY p."createdAt" DESC, p."id" DESC - LIMIT ${limit_idx} OFFSET ${offset_idx} - """, - *params, - page_size, - offset, + total, rows = await asyncio.gather( + PrismaLog.prisma().count(where=where), + PrismaLog.prisma().find_many( + where=where, + include={"User": True}, + order=[{"createdAt": "desc"}, {"id": "desc"}], + take=page_size, + skip=offset, ), ) - total = count_rows[0]["cnt"] if count_rows else 0 logs = [ CostLogRow( - id=r["id"], - created_at=r["created_at"], - user_id=r.get("user_id"), - email=_mask_email(r.get("email")), - graph_exec_id=r.get("graph_exec_id"), - node_exec_id=r.get("node_exec_id"), - block_name=r["block_name"], - provider=r["provider"], - tracking_type=r.get("tracking_type"), - cost_microdollars=r.get("cost_microdollars"), - input_tokens=r.get("input_tokens"), - output_tokens=r.get("output_tokens"), - duration=r.get("duration"), - model=r.get("model"), + id=r.id, + created_at=r.createdAt, + user_id=r.userId, + email=_mask_email(r.User.email if r.User else None), + graph_exec_id=r.graphExecId, + node_exec_id=r.nodeExecId, + block_name=r.blockName or "", + provider=r.provider, + tracking_type=r.trackingType, + cost_microdollars=r.costMicrodollars, + input_tokens=r.inputTokens, + output_tokens=r.outputTokens, + cache_read_tokens=getattr(r, "cacheReadTokens", None), + cache_creation_tokens=getattr(r, "cacheCreationTokens", None), + duration=r.duration, + model=r.model, ) for r in rows ] return logs, total + + +EXPORT_MAX_ROWS = 100_000 + + +async def get_platform_cost_logs_for_export( + start: datetime | None = None, + end: datetime | None = None, + provider: str | None = None, + user_id: str | None = None, + model: str | None = None, + block_name: str | None = None, + tracking_type: str | None = None, +) -> tuple[list[CostLogRow], bool]: + """Return all matching rows up to EXPORT_MAX_ROWS. + + Returns (rows, truncated) where truncated=True means the result was capped + and the caller should warn the user that not all rows are included. 
+ """ + if start is None: + start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS) + + where = _build_prisma_where( + start, end, provider, user_id, model, block_name, tracking_type + ) + + rows = await PrismaLog.prisma().find_many( + where=where, + include={"User": True}, + order=[{"createdAt": "desc"}, {"id": "desc"}], + take=EXPORT_MAX_ROWS + 1, + ) + + truncated = len(rows) > EXPORT_MAX_ROWS + rows = rows[:EXPORT_MAX_ROWS] + + return [ + CostLogRow( + id=r.id, + created_at=r.createdAt, + user_id=r.userId, + email=_mask_email(r.User.email if r.User else None), + graph_exec_id=r.graphExecId, + node_exec_id=r.nodeExecId, + block_name=r.blockName or "", + provider=r.provider, + tracking_type=r.trackingType, + cost_microdollars=r.costMicrodollars, + input_tokens=r.inputTokens, + output_tokens=r.outputTokens, + cache_read_tokens=getattr(r, "cacheReadTokens", None), + cache_creation_tokens=getattr(r, "cacheCreationTokens", None), + duration=r.duration, + model=r.model, + ) + for r in rows + ], truncated + + +# --------------------------------------------------------------------------- +# Helpers kept for backward-compatibility with existing tests. +# New code should not use these — use _build_prisma_where instead. +# --------------------------------------------------------------------------- + + +def _build_where( + start: datetime | None, + end: datetime | None, + provider: str | None, + user_id: str | None, + table_alias: str = "", + model: str | None = None, + block_name: str | None = None, + tracking_type: str | None = None, +) -> tuple[str, list[Any]]: + """Legacy SQL WHERE builder — retained so existing unit tests still pass. + + Only used by tests that verify the SQL-string generation logic. All + production code uses _build_prisma_where instead. + """ + prefix = f"{table_alias}." 
if table_alias else "" + clauses: list[str] = [] + params: list[Any] = [] + idx = 1 + + if start: + clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz') + params.append(start) + idx += 1 + if end: + clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz') + params.append(end) + idx += 1 + if provider: + clauses.append(f'{prefix}"provider" = ${idx}') + params.append(provider.lower()) + idx += 1 + if user_id: + clauses.append(f'{prefix}"userId" = ${idx}') + params.append(user_id) + idx += 1 + if model: + clauses.append(f'{prefix}"model" = ${idx}') + params.append(model) + idx += 1 + if block_name: + clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})') + params.append(block_name) + idx += 1 + if tracking_type: + clauses.append(f'{prefix}"trackingType" = ${idx}') + params.append(tracking_type) + idx += 1 + + return (" AND ".join(clauses) if clauses else "TRUE", params) diff --git a/autogpt_platform/backend/backend/data/platform_cost_integration_test.py b/autogpt_platform/backend/backend/data/platform_cost_integration_test.py index 10fe35d748..ef457a1105 100644 --- a/autogpt_platform/backend/backend/data/platform_cost_integration_test.py +++ b/autogpt_platform/backend/backend/data/platform_cost_integration_test.py @@ -77,3 +77,25 @@ async def test_log_platform_cost_metadata_none(cost_log_user): rows = await PrismaLog.prisma().find_many(where={"userId": user_id}) assert len(rows) == 1 assert rows[0].metadata == {} + + +@pytest.mark.asyncio(loop_scope="session") +async def test_log_platform_cost_cache_tokens(cost_log_user): + """Verify that cache_read_tokens and cache_creation_tokens are persisted.""" + user_id = cost_log_user + entry = PlatformCostEntry( + user_id=user_id, + block_name="TestBlock", + provider="anthropic", + input_tokens=200, + output_tokens=100, + cache_read_tokens=50, + cache_creation_tokens=25, + model="claude-3-5-sonnet-20241022", + ) + await log_platform_cost(entry) + + rows = await PrismaLog.prisma().find_many(where={"userId": user_id}) + assert len(rows) == 1 + assert rows[0].cacheReadTokens == 50 + assert rows[0].cacheCreationTokens == 25 diff --git a/autogpt_platform/backend/backend/data/platform_cost_test.py b/autogpt_platform/backend/backend/data/platform_cost_test.py index af150346a5..4a2372628b 100644 --- a/autogpt_platform/backend/backend/data/platform_cost_test.py +++ b/autogpt_platform/backend/backend/data/platform_cost_test.py @@ -1,7 +1,7 @@ """Unit tests for helpers and async functions in platform_cost module.""" from datetime import datetime, timezone -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from prisma import Json @@ -14,11 +14,27 @@ from .platform_cost import ( _mask_email, get_platform_cost_dashboard, get_platform_cost_logs, + get_platform_cost_logs_for_export, log_platform_cost, log_platform_cost_safe, + usd_to_microdollars, ) +class TestUsdToMicrodollars: + def test_none_returns_none(self): + assert usd_to_microdollars(None) is None + + def test_zero_returns_zero(self): + assert usd_to_microdollars(0.0) == 0 + + def test_positive_value(self): + assert usd_to_microdollars(0.001) == 1000 + + def test_large_value(self): + assert usd_to_microdollars(1.0) == 1_000_000 + + class TestMaskEmail: def test_typical_email(self): assert _mask_email("user@example.com") == "us***@example.com" @@ -94,6 +110,51 @@ class TestBuildWhere: sql, _ = _build_where(start, end, None, None) assert " AND " in sql + def test_model_only(self): + sql, params = _build_where(None, None, None, 
None, model="gpt-4") + assert '"model" = $1' in sql + assert params == ["gpt-4"] + + def test_block_name_only(self): + sql, params = _build_where(None, None, None, None, block_name="LLMBlock") + assert 'LOWER("blockName") = LOWER($1)' in sql + assert params == ["LLMBlock"] + + def test_tracking_type_only(self): + sql, params = _build_where(None, None, None, None, tracking_type="tokens") + assert '"trackingType" = $1' in sql + assert params == ["tokens"] + + def test_all_new_filters_combined(self): + sql, params = _build_where( + None, + None, + None, + None, + model="gpt-4", + block_name="LLM", + tracking_type="tokens", + ) + assert len(params) == 3 + assert params[0] == "gpt-4" + assert params[1] == "LLM" + assert params[2] == "tokens" + + def test_new_filters_with_alias(self): + sql, params = _build_where( + None, + None, + None, + None, + table_alias="p", + model="gpt-4", + block_name="MyBlock", + tracking_type="cost_usd", + ) + assert 'p."model" = $1' in sql + assert 'LOWER(p."blockName") = LOWER($2)' in sql + assert 'p."trackingType" = $3' in sql + def _make_entry(**overrides: object) -> PlatformCostEntry: return PlatformCostEntry.model_validate( @@ -163,6 +224,41 @@ class TestLogPlatformCostSafe: mock_create.assert_awaited_once() +def _make_group_by_row( + provider: str = "openai", + tracking_type: str | None = "tokens", + model: str | None = None, + cost: int = 5000, + input_tokens: int = 1000, + output_tokens: int = 500, + cache_read_tokens: int = 0, + cache_creation_tokens: int = 0, + duration: float = 10.5, + tracking_amount: float = 0.0, + count: int = 3, + user_id: str | None = None, +) -> dict: + row: dict = { + "_sum": { + "costMicrodollars": cost, + "inputTokens": input_tokens, + "outputTokens": output_tokens, + "cacheReadTokens": cache_read_tokens, + "cacheCreationTokens": cache_creation_tokens, + "duration": duration, + "trackingAmount": tracking_amount, + }, + "_count": {"_all": count}, + } + if user_id is not None: + row["userId"] = user_id + else: + row["provider"] = provider + row["trackingType"] = tracking_type + row["model"] = model + return row + + class TestGetPlatformCostDashboard: def setup_method(self): # @cached stores results in-process; clear between tests to avoid bleed. @@ -170,31 +266,44 @@ class TestGetPlatformCostDashboard: @pytest.mark.asyncio async def test_returns_dashboard_with_data(self): - provider_rows = [ - { - "provider": "openai", - "tracking_type": "tokens", - "total_cost": 5000, - "total_input_tokens": 1000, - "total_output_tokens": 500, - "total_duration": 10.5, - "request_count": 3, - } - ] - user_rows = [ - { - "user_id": "u1", - "email": "a@b.com", - "total_cost": 5000, - "total_input_tokens": 1000, - "total_output_tokens": 500, - "request_count": 3, - } - ] - # Dashboard runs 3 queries: by_provider, by_user, COUNT(DISTINCT userId). 
- mock_query = AsyncMock(side_effect=[provider_rows, user_rows, [{"cnt": 1}]]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + provider_row = _make_group_by_row( + provider="openai", + tracking_type="tokens", + cost=5000, + input_tokens=1000, + output_tokens=500, + duration=10.5, + count=3, + ) + user_row = _make_group_by_row(user_id="u1", cost=5000, count=3) + + mock_user = MagicMock() + mock_user.id = "u1" + mock_user.email = "a@b.com" + + mock_actions = MagicMock() + mock_actions.group_by = AsyncMock( + side_effect=[ + [provider_row], # by_provider + [user_row], # by_user + [{"userId": "u1"}], # distinct users + [provider_row], # total agg + ] + ) + mock_actions.find_many = AsyncMock(return_value=[mock_user]) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch( + "backend.data.platform_cost.PrismaUser.prisma", + return_value=mock_actions, + ), + ): dashboard = await get_platform_cost_dashboard() + assert dashboard.total_cost_microdollars == 5000 assert dashboard.total_requests == 3 assert dashboard.total_users == 1 @@ -206,10 +315,67 @@ class TestGetPlatformCostDashboard: assert dashboard.by_user[0].email == "a***@b.com" @pytest.mark.asyncio - async def test_returns_empty_dashboard(self): - mock_query = AsyncMock(side_effect=[[], [], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + async def test_cache_tokens_aggregated_not_hardcoded(self): + """cache_read_tokens and cache_creation_tokens must be read from the + DB aggregation, not hardcoded to 0 (regression guard for Sentry report).""" + provider_row = _make_group_by_row( + provider="anthropic", + tracking_type="tokens", + cost=1000, + input_tokens=800, + output_tokens=200, + cache_read_tokens=400, + cache_creation_tokens=100, + count=1, + ) + user_row = _make_group_by_row(user_id="u2", cost=1000, count=1) + + mock_actions = MagicMock() + mock_actions.group_by = AsyncMock( + side_effect=[ + [provider_row], # by_provider + [user_row], # by_user + [{"userId": "u2"}], # distinct users + [provider_row], # total agg + ] + ) + mock_actions.find_many = AsyncMock(return_value=[]) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch( + "backend.data.platform_cost.PrismaUser.prisma", + return_value=mock_actions, + ), + ): dashboard = await get_platform_cost_dashboard() + + assert len(dashboard.by_provider) == 1 + row = dashboard.by_provider[0] + assert row.total_cache_read_tokens == 400 + assert row.total_cache_creation_tokens == 100 + + @pytest.mark.asyncio + async def test_returns_empty_dashboard(self): + mock_actions = MagicMock() + mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []]) + mock_actions.find_many = AsyncMock(return_value=[]) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch( + "backend.data.platform_cost.PrismaUser.prisma", + return_value=mock_actions, + ), + ): + dashboard = await get_platform_cost_dashboard() + assert dashboard.total_cost_microdollars == 0 assert dashboard.total_requests == 0 assert dashboard.total_users == 0 @@ -219,68 +385,228 @@ class TestGetPlatformCostDashboard: @pytest.mark.asyncio async def test_passes_filters_to_queries(self): start = datetime(2026, 1, 1, tzinfo=timezone.utc) - mock_query = AsyncMock(side_effect=[[], [], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + + mock_actions = MagicMock() + 
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []]) + mock_actions.find_many = AsyncMock(return_value=[]) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch( + "backend.data.platform_cost.PrismaUser.prisma", + return_value=mock_actions, + ), + ): await get_platform_cost_dashboard( start=start, provider="openai", user_id="u1" ) - assert mock_query.await_count == 3 - first_call_sql = mock_query.call_args_list[0][0][0] - assert "createdAt" in first_call_sql + + # group_by called 4 times (by_provider, by_user, distinct users, totals) + assert mock_actions.group_by.await_count == 4 + # The where dict passed to the first call should include createdAt + first_call_kwargs = mock_actions.group_by.call_args_list[0][1] + assert "createdAt" in first_call_kwargs.get("where", {}) + + +def _make_prisma_log_row( + i: int = 0, + user_email: str | None = None, +) -> MagicMock: + row = MagicMock() + row.id = f"log-{i}" + row.createdAt = datetime(2026, 3, 1, tzinfo=timezone.utc) + row.userId = "u1" + row.graphExecId = None + row.nodeExecId = None + row.blockName = "TestBlock" + row.provider = "openai" + row.trackingType = "tokens" + row.costMicrodollars = 1000 + row.inputTokens = 10 + row.outputTokens = 5 + row.duration = 0.5 + row.model = "gpt-4" + # cacheReadTokens / cacheCreationTokens may not exist on older Prisma clients + row.configure_mock(**{"cacheReadTokens": None, "cacheCreationTokens": None}) + if user_email is not None: + row.User = MagicMock() + row.User.email = user_email + else: + row.User = None + return row class TestGetPlatformCostLogs: @pytest.mark.asyncio async def test_returns_logs_and_total(self): - count_rows = [{"cnt": 1}] - log_rows = [ - { - "id": "log-1", - "created_at": datetime(2026, 3, 1, tzinfo=timezone.utc), - "user_id": "u1", - "email": "a@b.com", - "graph_exec_id": "g1", - "node_exec_id": "n1", - "block_name": "TestBlock", - "provider": "openai", - "tracking_type": "tokens", - "cost_microdollars": 5000, - "input_tokens": 100, - "output_tokens": 50, - "duration": 1.5, - "model": "gpt-4", - } - ] - mock_query = AsyncMock(side_effect=[count_rows, log_rows]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + row = _make_prisma_log_row(0, user_email="a@b.com") + mock_actions = MagicMock() + mock_actions.count = AsyncMock(return_value=1) + mock_actions.find_many = AsyncMock(return_value=[row]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, total = await get_platform_cost_logs(page=1, page_size=10) + assert total == 1 assert len(logs) == 1 - assert logs[0].id == "log-1" + assert logs[0].id == "log-0" assert logs[0].provider == "openai" assert logs[0].model == "gpt-4" @pytest.mark.asyncio async def test_returns_empty_when_no_data(self): - mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + mock_actions = MagicMock() + mock_actions.count = AsyncMock(return_value=0) + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, total = await get_platform_cost_logs() + assert total == 0 assert logs == [] @pytest.mark.asyncio async def test_pagination_offset(self): - mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + mock_actions = MagicMock() + 
mock_actions.count = AsyncMock(return_value=100) + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, total = await get_platform_cost_logs(page=3, page_size=25) + assert total == 100 - second_call_args = mock_query.call_args_list[1][0] - assert 25 in second_call_args # page_size - assert 50 in second_call_args # offset = (3-1) * 25 + find_many_call = mock_actions.find_many.call_args[1] + assert find_many_call["take"] == 25 + assert find_many_call["skip"] == 50 # (3-1) * 25 @pytest.mark.asyncio - async def test_empty_count_returns_zero(self): - mock_query = AsyncMock(side_effect=[[], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): - logs, total = await get_platform_cost_logs() + async def test_explicit_start_skips_default(self): + start = datetime(2026, 1, 1, tzinfo=timezone.utc) + mock_actions = MagicMock() + mock_actions.count = AsyncMock(return_value=0) + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): + logs, total = await get_platform_cost_logs(start=start) + assert total == 0 + where = mock_actions.count.call_args[1]["where"] + # start provided — should appear in the where filter + assert "createdAt" in where + + +class TestGetPlatformCostLogsForExport: + @pytest.mark.asyncio + async def test_returns_logs_not_truncated(self): + row = _make_prisma_log_row(0) + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=[row]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): + logs, truncated = await get_platform_cost_logs_for_export() + + assert len(logs) == 1 + assert truncated is False + assert logs[0].id == "log-0" + + @pytest.mark.asyncio + async def test_returns_empty_not_truncated(self): + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): + logs, truncated = await get_platform_cost_logs_for_export() + + assert logs == [] + assert truncated is False + + @pytest.mark.asyncio + async def test_truncates_at_export_max_rows(self): + rows = [_make_prisma_log_row(i) for i in range(3)] + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=rows) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2), + ): + logs, truncated = await get_platform_cost_logs_for_export() + + assert len(logs) == 2 + assert truncated is True + + @pytest.mark.asyncio + async def test_passes_model_block_tracking_filters(self): + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): + await get_platform_cost_logs_for_export( + model="gpt-4", block_name="LLMBlock", tracking_type="tokens" + ) + + where = mock_actions.find_many.call_args[1]["where"] + assert where.get("model") == "gpt-4" + assert where.get("trackingType") == "tokens" + # blockName uses a dict filter for case-insensitive match + assert "blockName" in where + + @pytest.mark.asyncio + async def test_maps_cache_tokens(self): + row = _make_prisma_log_row(0) + row.configure_mock(**{"cacheReadTokens": 50, "cacheCreationTokens": 25}) + mock_actions = MagicMock() + 
mock_actions.find_many = AsyncMock(return_value=[row]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): + logs, _ = await get_platform_cost_logs_for_export() + + assert logs[0].cache_read_tokens == 50 + assert logs[0].cache_creation_tokens == 25 + + @pytest.mark.asyncio + async def test_explicit_start_skips_default(self): + start = datetime(2026, 1, 1, tzinfo=timezone.utc) + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): + logs, truncated = await get_platform_cost_logs_for_export(start=start) + + assert logs == [] + assert truncated is False + where = mock_actions.find_many.call_args[1]["where"] + assert "createdAt" in where diff --git a/autogpt_platform/backend/backend/executor/cost_tracking.py b/autogpt_platform/backend/backend/executor/cost_tracking.py index b1381d18c0..afe8ab9b10 100644 --- a/autogpt_platform/backend/backend/executor/cost_tracking.py +++ b/autogpt_platform/backend/backend/executor/cost_tracking.py @@ -278,6 +278,8 @@ async def log_system_credential_cost( cost_microdollars=cost_microdollars, input_tokens=stats.input_token_count, output_tokens=stats.output_token_count, + cache_read_tokens=stats.cache_read_token_count or None, + cache_creation_tokens=stats.cache_creation_token_count or None, data_size=stats.output_size if stats.output_size > 0 else None, duration=stats.walltime if stats.walltime > 0 else None, model=model_name, diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 5c30617fdb..f63f8ec76c 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -10,11 +10,13 @@ from contextlib import asynccontextmanager from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast -import sentry_sdk from pika.adapters.blocking_connection import BlockingChannel from pika.spec import Basic, BasicProperties from prometheus_client import Gauge, start_http_server from redis.asyncio.lock import Lock as AsyncRedisLock +from sentry_sdk.api import capture_exception as _sentry_capture_exception +from sentry_sdk.api import flush as _sentry_flush +from sentry_sdk.api import get_current_scope as _sentry_get_current_scope from backend.blocks import get_block from backend.blocks._base import BlockSchema @@ -393,7 +395,7 @@ async def execute_node( output_size = 0 # sentry tracking nonsense to get user counts for blocks because isolation scopes don't work :( - scope = sentry_sdk.get_current_scope() + scope = _sentry_get_current_scope() # save the tags original_user = scope._user @@ -428,8 +430,8 @@ async def execute_node( ex, (NotFoundError, GraphNotFoundError) ) if not is_expected: - sentry_sdk.capture_exception(error=ex, scope=scope) - sentry_sdk.flush() + _sentry_capture_exception(error=ex, scope=scope) + _sentry_flush() # Re-raise to maintain normal error flow raise finally: diff --git a/autogpt_platform/backend/backend/util/feature_flag.py b/autogpt_platform/backend/backend/util/feature_flag.py index 74a0f960ed..27121304ca 100644 --- a/autogpt_platform/backend/backend/util/feature_flag.py +++ b/autogpt_platform/backend/backend/util/feature_flag.py @@ -43,6 +43,8 @@ class Flag(str, Enum): COPILOT_SDK = "copilot-sdk" COPILOT_DAILY_TOKEN_LIMIT = "copilot-daily-token-limit" COPILOT_WEEKLY_TOKEN_LIMIT = 
"copilot-weekly-token-limit" + STRIPE_PRICE_PRO = "stripe-price-id-pro" + STRIPE_PRICE_BUSINESS = "stripe-price-id-business" GRAPHITI_MEMORY = "graphiti-memory" diff --git a/autogpt_platform/backend/backend/util/metrics.py b/autogpt_platform/backend/backend/util/metrics.py index 30a979cffb..3348dd46d1 100644 --- a/autogpt_platform/backend/backend/util/metrics.py +++ b/autogpt_platform/backend/backend/util/metrics.py @@ -1,14 +1,24 @@ import logging from enum import Enum -import sentry_sdk from pydantic import SecretStr +from sentry_sdk._init_implementation import init as _sentry_init +from sentry_sdk.api import capture_exception as _sentry_capture_exception +from sentry_sdk.api import flush as _sentry_flush from sentry_sdk.integrations import DidNotEnable -from sentry_sdk.integrations.anthropic import AnthropicIntegration from sentry_sdk.integrations.asyncio import AsyncioIntegration -from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration from sentry_sdk.integrations.logging import LoggingIntegration +try: + from sentry_sdk.integrations.anthropic import AnthropicIntegration +except ImportError: + AnthropicIntegration = None # type: ignore[assignment,misc] + +try: + from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration +except ImportError: + LaunchDarklyIntegration = None # type: ignore[assignment,misc] + from backend.util import feature_flag from backend.util.settings import BehaveAs, Settings @@ -131,32 +141,34 @@ def _before_send(event, hint): def sentry_init(): sentry_dsn = settings.secrets.sentry_dsn integrations = [] - if feature_flag.is_configured(): + if feature_flag.is_configured() and LaunchDarklyIntegration is not None: try: integrations.append(LaunchDarklyIntegration(feature_flag.get_client())) except DidNotEnable as e: logger.error(f"Error enabling LaunchDarklyIntegration for Sentry: {e}") - sentry_sdk.init( + optional_integrations = ( + [AnthropicIntegration(include_prompts=False)] + if AnthropicIntegration is not None + else [] + ) + _sentry_init( dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0, environment=f"app:{settings.config.app_env.value}-behave:{settings.config.behave_as.value}", - _experiments={"enable_logs": True}, before_send=_before_send, integrations=[ AsyncioIntegration(), - LoggingIntegration(sentry_logs_level=logging.INFO), - AnthropicIntegration( - include_prompts=False, - ), + LoggingIntegration(), ] + + optional_integrations + integrations, ) def sentry_capture_error(error: BaseException): - sentry_sdk.capture_exception(error) - sentry_sdk.flush() + _sentry_capture_exception(error) + _sentry_flush() async def discord_send_alert( diff --git a/autogpt_platform/backend/backend/util/service.py b/autogpt_platform/backend/backend/util/service.py index a1da0c1a68..459e46f01c 100644 --- a/autogpt_platform/backend/backend/util/service.py +++ b/autogpt_platform/backend/backend/util/service.py @@ -26,11 +26,11 @@ from typing import ( ) import httpx -import sentry_sdk import uvicorn from fastapi import FastAPI, Request, responses from prisma.errors import DataError, UniqueViolationError from pydantic import BaseModel, TypeAdapter, create_model +from sentry_sdk.api import capture_exception as _sentry_capture_exception import backend.util.exceptions as exceptions from backend.monitoring.instrumentation import instrument_fastapi @@ -721,7 +721,7 @@ def get_service_client( logger.warning( f"RPC return type validation failed for {type(e).__name__}: {e}" ) - sentry_sdk.capture_exception(e) + _sentry_capture_exception(e) 
return result return result diff --git a/autogpt_platform/backend/migrations/20260409000000_add_cache_tokens_to_platform_cost_log/migration.sql b/autogpt_platform/backend/migrations/20260409000000_add_cache_tokens_to_platform_cost_log/migration.sql new file mode 100644 index 0000000000..21c2f0d6b6 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260409000000_add_cache_tokens_to_platform_cost_log/migration.sql @@ -0,0 +1,2 @@ +ALTER TABLE "PlatformCostLog" ADD COLUMN "cacheReadTokens" INTEGER; +ALTER TABLE "PlatformCostLog" ADD COLUMN "cacheCreationTokens" INTEGER; diff --git a/autogpt_platform/backend/migrations/20260409000000_add_subscription_tier/migration.sql b/autogpt_platform/backend/migrations/20260409000000_add_subscription_tier/migration.sql new file mode 100644 index 0000000000..2240b450ec --- /dev/null +++ b/autogpt_platform/backend/migrations/20260409000000_add_subscription_tier/migration.sql @@ -0,0 +1,5 @@ +-- SubscriptionTier enum and User.subscriptionTier column already created by +-- 20260326200000_add_rate_limit_tier migration. Only add SUBSCRIPTION transaction type. + +-- AlterEnum +ALTER TYPE "CreditTransactionType" ADD VALUE IF NOT EXISTS 'SUBSCRIPTION'; diff --git a/autogpt_platform/backend/pyrightconfig.json b/autogpt_platform/backend/pyrightconfig.json new file mode 100644 index 0000000000..2241b990d2 --- /dev/null +++ b/autogpt_platform/backend/pyrightconfig.json @@ -0,0 +1,30 @@ +{ + "pythonVersion": "3.12", + "venvPath": ".", + "venv": ".venv", + "include": ["backend"], + "ignore": [ + "backend/**/*_test.py", + "backend/**/*conftest.py", + "backend/**/conftest.py", + "backend/**/_test_data.py", + "backend/**/*test_data*.py", + "backend/api/features/library/_add_to_library.py", + "backend/api/features/library/db.py", + "backend/api/features/store/db.py", + "backend/blocks/sql_query_helpers.py", + "backend/blocks/stagehand/blocks.py", + "backend/cli/oauth_tool.py", + "backend/copilot/db.py", + "backend/copilot/rate_limit.py", + "backend/data/auth/api_key.py", + "backend/data/auth/oauth.py", + "backend/data/execution.py", + "backend/data/graph.py", + "backend/data/human_review.py", + "backend/data/onboarding.py", + "backend/data/understanding.py", + "backend/data/workspace.py", + "backend/sdk/__init__.py" + ] +} diff --git a/autogpt_platform/backend/schema.prisma b/autogpt_platform/backend/schema.prisma index ed71f620ba..e224be7d5f 100644 --- a/autogpt_platform/backend/schema.prisma +++ b/autogpt_platform/backend/schema.prisma @@ -774,6 +774,7 @@ enum CreditTransactionType { GRANT REFUND CARD_CHECK + SUBSCRIPTION } model CreditTransaction { @@ -842,6 +843,8 @@ model PlatformCostLog { inputTokens Int? outputTokens Int? + cacheReadTokens Int? // Anthropic cache read tokens (billed at 10% of base) + cacheCreationTokens Int? // Anthropic cache write tokens (billed at 125% of base) dataSize Int? // bytes duration Float? // seconds model String? 
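The new cacheReadTokens / cacheCreationTokens columns carry the multipliers noted in the schema comments: Anthropic bills cache reads at 10% of the base input rate and cache writes at 125%, and the platform stores spend in microdollars (1 USD = 1,000,000 microdollars). A minimal Python sketch of that arithmetic; the function name and the $0.008-per-1K example rate are illustrative, not part of this diff:

MICRODOLLARS_PER_USD = 1_000_000

def estimate_cost_microdollars(
    input_tokens: int,
    output_tokens: int,
    cache_read_tokens: int,
    cache_creation_tokens: int,
    rate_usd_per_1k: float,  # illustrative base rate, e.g. 0.008
) -> int:
    # Mirrors the 10% / 125% multipliers documented on the schema columns;
    # output is billed at the same flat base rate, as in the frontend estimator.
    usd = (
        (input_tokens / 1000) * rate_usd_per_1k                    # uncached input at 100%
        + (cache_read_tokens / 1000) * rate_usd_per_1k * 0.10      # cache reads at 10%
        + (cache_creation_tokens / 1000) * rate_usd_per_1k * 1.25  # cache writes at 125%
        + (output_tokens / 1000) * rate_usd_per_1k                 # output tokens
    )
    return round(usd * MICRODOLLARS_PER_USD)

# 1000 in / 1000 out / 2000 cache reads / 500 cache writes at $0.008 per 1K
# tokens comes to 22,600 microdollars, matching the frontend estimator test.
assert estimate_cost_microdollars(1000, 1000, 2000, 500, 0.008) == 22_600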
diff --git a/autogpt_platform/frontend/package.json b/autogpt_platform/frontend/package.json index 90c2645272..00e9e6fc8a 100644 --- a/autogpt_platform/frontend/package.json +++ b/autogpt_platform/frontend/package.json @@ -150,6 +150,7 @@ "@types/react-dom": "18.3.5", "@types/react-modal": "3.16.3", "@types/react-window": "2.0.0", + "@types/twemoji": "13.1.2", "@vitejs/plugin-react": "5.1.2", "@vitest/coverage-v8": "4.0.17", "axe-playwright": "2.2.2", diff --git a/autogpt_platform/frontend/pnpm-lock.yaml b/autogpt_platform/frontend/pnpm-lock.yaml index 95b49e3a22..057719def1 100644 --- a/autogpt_platform/frontend/pnpm-lock.yaml +++ b/autogpt_platform/frontend/pnpm-lock.yaml @@ -367,6 +367,9 @@ importers: '@types/react-window': specifier: 2.0.0 version: 2.0.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@types/twemoji': + specifier: 13.1.2 + version: 13.1.2 '@vitejs/plugin-react': specifier: 5.1.2 version: 5.1.2(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)) @@ -3705,6 +3708,10 @@ packages: '@types/trusted-types@2.0.7': resolution: {integrity: sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==} + '@types/twemoji@13.1.2': + resolution: {integrity: sha512-vPNsrN08aRI2Gmdo+Ds3zZXzUk6igp1Hg+JPCeHavpiUGfgth/tGiHLQxfSrKzPXeRC0zbLs8WaUZSYxRWPbNg==} + deprecated: This is a stub types definition. twemoji provides its own type definitions, so you do not need this installed. + '@types/unist@2.0.11': resolution: {integrity: sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA==} @@ -12681,6 +12688,10 @@ snapshots: '@types/trusted-types@2.0.7': optional: true + '@types/twemoji@13.1.2': + dependencies: + twemoji: 14.0.2 + '@types/unist@2.0.11': {} '@types/unist@3.0.3': {} diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/__tests__/helpers.test.ts b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/__tests__/helpers.test.ts index 25d4f1e064..4cd7afbaec 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/__tests__/helpers.test.ts +++ b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/__tests__/helpers.test.ts @@ -1,7 +1,9 @@ import { describe, expect, it } from "vitest"; +import type { CostLogRow } from "@/app/api/__generated__/models/costLogRow"; import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary"; import { toDateOrUndefined, + buildCostLogsCsv, formatMicrodollars, formatTokens, formatDuration, @@ -126,6 +128,25 @@ describe("estimateCostForRow", () => { expect(estimateCostForRow(row, {})).toBeNull(); }); + it("uses cache-aware rates when cache tokens present", () => { + // anthropic base rate = $0.008/1K tokens + // uncached input 1000 * 0.008/1K = 0.008 + // cache reads 2000 * 0.008 * 0.1 / 1K = 0.0016 + // cache writes 500 * 0.008 * 1.25 / 1K = 0.005 + // output 1000 * 0.008/1K = 0.008 + // total = 0.0226 USD = 22_600 microdollars + const row = makeRow({ + provider: "anthropic", + tracking_type: "tokens", + total_cost_microdollars: 0, + total_input_tokens: 1000, + total_output_tokens: 1000, + total_cache_read_tokens: 2000, + total_cache_creation_tokens: 500, + }); + expect(estimateCostForRow(row, {})).toBe(22_600); + }); + it("uses per-run override when provided", () => { const row = makeRow({ provider: "google_maps", @@ -133,7 +154,7 @@ describe("estimateCostForRow", () => { request_count: 10, }); // override = 0.05 * 10 * 1_000_000 = 500_000 - 
expect(estimateCostForRow(row, { "google_maps:per_run:": 0.05 })).toBe(
       500_000,
     );
   });
@@ -298,3 +319,50 @@
     expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/);
   });
 });
+
+describe("buildCostLogsCsv", () => {
+  function makeLog(overrides: Partial<CostLogRow>): CostLogRow {
+    return {
+      id: "abc123",
+      created_at: "2026-01-15T10:00:00Z" as unknown as Date,
+      block_name: "LLMBlock",
+      provider: "anthropic",
+      ...overrides,
+    };
+  }
+
+  it("emits a header row and one data row", () => {
+    const csv = buildCostLogsCsv([makeLog({})]);
+    const lines = csv.split("\r\n");
+    expect(lines).toHaveLength(2);
+    expect(lines[0]).toContain("Time (UTC)");
+    expect(lines[0]).toContain("Provider");
+    expect(lines[1]).toContain("anthropic");
+  });
+
+  it("escapes double-quotes in field values", () => {
+    const csv = buildCostLogsCsv([makeLog({ block_name: 'Say "Hello"' })]);
+    expect(csv).toContain('"Say ""Hello"""');
+  });
+
+  it("converts cost_microdollars to USD with 8 decimal places", () => {
+    const csv = buildCostLogsCsv([makeLog({ cost_microdollars: 1_234_567 })]);
+    expect(csv).toContain("1.23456700");
+  });
+
+  it("includes cache token columns", () => {
+    const csv = buildCostLogsCsv([
+      makeLog({ cache_read_tokens: 500, cache_creation_tokens: 100 }),
+    ]);
+    const lines = csv.split("\r\n");
+    expect(lines[0]).toContain("Cache Read Tokens");
+    expect(lines[0]).toContain("Cache Creation Tokens");
+    expect(lines[1]).toContain('"500"');
+    expect(lines[1]).toContain('"100"');
+  });
+
+  it("returns only header for empty log list", () => {
+    const csv = buildCostLogsCsv([]);
+    expect(csv.split("\r\n")).toHaveLength(1);
+  });
+});
diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/LogsTable.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/LogsTable.tsx
index a5942a2fbf..46920d15bc 100644
--- a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/LogsTable.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/LogsTable.tsx
@@ -14,11 +14,33 @@ interface Props {
   logs: CostLogRow[];
   pagination: Pagination | null;
   onPageChange: (page: number) => void;
+  onExport: () => Promise<void>;
+  exporting: boolean;
 }

-function LogsTable({ logs, pagination, onPageChange }: Props) {
+function LogsTable({
+  logs,
+  pagination,
+  onPageChange,
+  onExport,
+  exporting,
+}: Props) {
   return (
+
+ + {pagination + ? `${pagination.total_items.toLocaleString()} total rows` + : ""} + + +
diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/PlatformCostContent.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/PlatformCostContent.tsx index 9e4d24f824..749a2136a3 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/PlatformCostContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/PlatformCostContent.tsx @@ -15,6 +15,9 @@ interface Props { end?: string; provider?: string; user_id?: string; + model?: string; + block_name?: string; + tracking_type?: string; page?: string; tab?: string; }; @@ -37,10 +40,18 @@ export function PlatformCostContent({ searchParams }: Props) { setProviderInput, userInput, setUserInput, + modelInput, + setModelInput, + blockInput, + setBlockInput, + typeInput, + setTypeInput, rateOverrides, handleRateOverride, updateUrl, handleFilter, + exporting, + handleExport, } = usePlatformCostContent(searchParams); return ( @@ -105,6 +116,54 @@ export function PlatformCostContent({ searchParams }: Props) { onChange={(e) => setUserInput(e.target.value)} /> +
+ + setModelInput(e.target.value)} + /> +
+
+ + setBlockInput(e.target.value)} + /> +
+
+ + setTypeInput(e.target.value)} + /> +
+ @@ -55,12 +58,18 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) { // For cost_usd rows the provider reports USD directly so rate // input doesn't apply; otherwise show an editable input. const showRateInput = tt !== "cost_usd"; - const key = rateKey(row.provider, tt); + const key = rateKey(row.provider, tt, row.model); const fallback = defaultRateFor(row.provider, tt); const currentRate = rateOverrides[key] ?? fallback; return ( - + + @@ -115,7 +124,7 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) { {data.length === 0 && (
Provider + Model + Type
{row.provider} + {row.model || "—"} +
No cost data yet
diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/usePlatformCostContent.ts b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/usePlatformCostContent.ts
index 01db1c5130..7b3f92036d 100644
--- a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/usePlatformCostContent.ts
+++ b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/components/usePlatformCostContent.ts
@@ -3,17 +3,26 @@ import { useRouter, useSearchParams } from "next/navigation";
 import { useState } from "react";

 import {
+  getV2ExportPlatformCostLogs,
   useGetV2GetPlatformCostDashboard,
   useGetV2GetPlatformCostLogs,
 } from "@/app/api/__generated__/endpoints/admin/admin";
 import { okData } from "@/app/api/helpers";
-import { estimateCostForRow, toLocalInput, toUtcIso } from "../helpers";
+import {
+  buildCostLogsCsv,
+  estimateCostForRow,
+  toLocalInput,
+  toUtcIso,
+} from "../helpers";

 interface InitialSearchParams {
   start?: string;
   end?: string;
   provider?: string;
   user_id?: string;
+  model?: string;
+  block_name?: string;
+  tracking_type?: string;
   page?: string;
   tab?: string;
 }
@@ -29,14 +38,23 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
   const providerFilter =
     urlParams.get("provider") || searchParams.provider || "";
   const userFilter = urlParams.get("user_id") || searchParams.user_id || "";
+  const modelFilter = urlParams.get("model") || searchParams.model || "";
+  const blockFilter =
+    urlParams.get("block_name") || searchParams.block_name || "";
+  const typeFilter =
+    urlParams.get("tracking_type") || searchParams.tracking_type || "";

   const [startInput, setStartInput] = useState(toLocalInput(startDate));
   const [endInput, setEndInput] = useState(toLocalInput(endDate));
   const [providerInput, setProviderInput] = useState(providerFilter);
   const [userInput, setUserInput] = useState(userFilter);
+  const [modelInput, setModelInput] = useState(modelFilter);
+  const [blockInput, setBlockInput] = useState(blockFilter);
+  const [typeInput, setTypeInput] = useState(typeFilter);
   const [rateOverrides, setRateOverrides] = useState<Record<string, number>>(
     {},
   );
+  const [exporting, setExporting] = useState(false);

   // Pass ISO date strings through `as unknown as Date` so Orval's URL builder
   // forwards them as-is.
  // Date.toString() produces a format FastAPI rejects;
@@ -46,6 +64,9 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
     end: (endDate || undefined) as unknown as Date | undefined,
     provider: providerFilter || undefined,
     user_id: userFilter || undefined,
+    model: modelFilter || undefined,
+    block_name: blockFilter || undefined,
+    tracking_type: typeFilter || undefined,
   };

   const {
@@ -91,6 +112,9 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
       end: toUtcIso(endInput),
       provider: providerInput,
       user_id: userInput,
+      model: modelInput,
+      block_name: blockInput,
+      tracking_type: typeInput,
       page: "1",
     });
   }
@@ -105,6 +129,33 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
     });
   }

+  async function handleExport() {
+    setExporting(true);
+    try {
+      const response = await getV2ExportPlatformCostLogs(filterParams);
+      const data = okData(response);
+      if (!data) throw new Error("Export failed: unexpected response");
+      const csv = buildCostLogsCsv(data.logs);
+      const blob = new Blob([csv], { type: "text/csv;charset=utf-8;" });
+      const url = URL.createObjectURL(blob);
+      const a = document.createElement("a");
+      a.href = url;
+      a.download = `platform_costs_${new Date().toISOString().slice(0, 10)}.csv`;
+      document.body.appendChild(a);
+      a.click();
+      document.body.removeChild(a);
+      URL.revokeObjectURL(url);
+      if (data.truncated) {
+        // eslint-disable-next-line no-console
+        console.warn(
+          `Export truncated: only the first ${data.total_rows} rows were included.`,
+        );
+      }
+    } finally {
+      setExporting(false);
+    }
+  }
+
   const totalEstimatedCost =
     dashboard?.by_provider.reduce((sum, row) => {
       const est = estimateCostForRow(row, rateOverrides);
@@ -128,9 +179,17 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
     setProviderInput,
     userInput,
     setUserInput,
+    modelInput,
+    setModelInput,
+    blockInput,
+    setBlockInput,
+    typeInput,
+    setTypeInput,
     rateOverrides,
     handleRateOverride,
     updateUrl,
     handleFilter,
+    exporting,
+    handleExport,
   };
 }
diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/helpers.ts
index 63d14a82c1..9883a5a952 100644
--- a/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/helpers.ts
+++ b/autogpt_platform/frontend/src/app/(platform)/admin/platform-costs/helpers.ts
@@ -1,3 +1,4 @@
+import type { CostLogRow } from "@/app/api/__generated__/models/costLogRow";
 import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";

 const MICRODOLLARS_PER_USD = 1_000_000;
@@ -111,13 +112,14 @@ export function defaultRateFor(
   }
 }

-// Overrides are keyed on `${provider}:${tracking_type}` since the same
-// provider can have multiple rows with different billing models.
+// Overrides are keyed on `${provider}:${tracking_type}:${model}` since the
+// same provider can have multiple rows with different billing models and model names.
 export function rateKey(
   provider: string,
   trackingType: string | null | undefined,
+  model?: string | null,
 ): string {
-  return `${provider}:${trackingType ?? "per_run"}`;
+  return `${provider}:${trackingType ?? "per_run"}:${model ?? ""}`;
 }

 export function estimateCostForRow(
@@ -136,17 +138,34 @@
   }

   const rate =
-    rateOverrides[rateKey(row.provider, tt)] ??
+    rateOverrides[rateKey(row.provider, tt, row.model)] ??
defaultRateFor(row.provider, tt); if (rate === null || rate === undefined) return null; // Compute the amount for this tracking type, then multiply by rate. let amount: number; switch (tt) { - case "tokens": + case "tokens": { + // Anthropic cache tokens are billed at different rates: + // - cache reads: 10% of base input rate + // - cache writes: 125% of base input rate + // - uncached input: 100% of base input rate + const cacheRead = row.total_cache_read_tokens ?? 0; + const cacheWrite = row.total_cache_creation_tokens ?? 0; + if (cacheRead > 0 || cacheWrite > 0) { + const uncachedInput = row.total_input_tokens; + const output = row.total_output_tokens; + const cost = + (uncachedInput / 1000) * rate + + (cacheRead / 1000) * rate * 0.1 + + (cacheWrite / 1000) * rate * 1.25 + + (output / 1000) * rate; + return Math.round(cost * MICRODOLLARS_PER_USD); + } // Rate is per-1K tokens. amount = (row.total_input_tokens + row.total_output_tokens) / 1000; break; + } case "characters": // Rate is per-1K chars. trackingAmount aggregates char counts. amount = (row.total_tracking_amount || 0) / 1000; @@ -175,6 +194,11 @@ export function trackingValue(row: ProviderCostSummary) { if (tt === "cost_usd") return formatMicrodollars(row.total_cost_microdollars); if (tt === "tokens") { const tokens = row.total_input_tokens + row.total_output_tokens; + const cacheRead = row.total_cache_read_tokens ?? 0; + const cacheWrite = row.total_cache_creation_tokens ?? 0; + if (cacheRead > 0 || cacheWrite > 0) { + return `${formatTokens(tokens)} tokens (+${formatTokens(cacheRead)}r/${formatTokens(cacheWrite)}w cached)`; + } return `${formatTokens(tokens)} tokens`; } if (tt === "sandbox_seconds" || tt === "walltime_seconds") @@ -202,3 +226,54 @@ export function toUtcIso(local: string) { const d = new Date(local); return isNaN(d.getTime()) ? "" : d.toISOString(); } + +const CSV_HEADERS = [ + "Time (UTC)", + "User ID", + "Email", + "Block", + "Provider", + "Type", + "Model", + "Cost (USD)", + "Input Tokens", + "Output Tokens", + "Cache Read Tokens", + "Cache Creation Tokens", + "Duration (s)", + "Graph Exec ID", + "Node Exec ID", +]; + +function csvEscape(val: unknown): string { + const s = val == null ? "" : String(val); + return `"${s.replace(/"/g, '""')}"`; +} + +export function buildCostLogsCsv(logs: CostLogRow[]): string { + const header = CSV_HEADERS.map(csvEscape).join(","); + const rows = logs.map((log) => + [ + log.created_at, + log.user_id, + log.email, + log.block_name, + log.provider, + log.tracking_type, + log.model, + log.cost_microdollars != null + ? 
(log.cost_microdollars / 1_000_000).toFixed(8)
+          : null,
+        log.input_tokens,
+        log.output_tokens,
+        log.cache_read_tokens,
+        log.cache_creation_tokens,
+        log.duration,
+        log.graph_exec_id,
+        log.node_exec_id,
+      ]
+        .map(csvEscape)
+        .join(","),
+  );
+  return [header, ...rows].join("\r\n");
+}
diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderChatPanel/BuilderChatPanel.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderChatPanel/BuilderChatPanel.tsx
new file mode 100644
index 0000000000..23f600dc58
--- /dev/null
+++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderChatPanel/BuilderChatPanel.tsx
@@ -0,0 +1,452 @@
+"use client";
+
+import { Button } from "@/components/atoms/Button/Button";
+import { cn } from "@/lib/utils";
+import {
+  ArrowCounterClockwise,
+  ChatCircle,
+  PaperPlaneTilt,
+  SpinnerGap,
+  StopCircle,
+  X,
+} from "@phosphor-icons/react";
+import { KeyboardEvent, useEffect, useRef } from "react";
+import { ToolUIPart } from "ai";
+import { MessagePartRenderer } from "@/app/(platform)/copilot/components/ChatMessagesContainer/components/MessagePartRenderer";
+import { CopilotChatActionsProvider } from "@/app/(platform)/copilot/components/CopilotChatActionsProvider/CopilotChatActionsProvider";
+import type { CustomNode } from "../FlowEditor/nodes/CustomNode/CustomNode";
+import {
+  GraphAction,
+  SEED_PROMPT_PREFIX,
+  extractTextFromParts,
+  getActionKey,
+  getNodeDisplayName,
+} from "./helpers";
+import { useBuilderChatPanel } from "./useBuilderChatPanel";
+
+interface Props {
+  className?: string;
+  isGraphLoaded?: boolean;
+  onGraphEdited?: () => void;
+}
+
+export function BuilderChatPanel({
+  className,
+  isGraphLoaded,
+  onGraphEdited,
+}: Props) {
+  const panelRef = useRef<HTMLDivElement>(null);
+  const {
+    isOpen,
+    handleToggle,
+    retrySession,
+    messages,
+    stop,
+    error,
+    isCreatingSession,
+    sessionError,
+    nodes,
+    parsedActions,
+    appliedActionKeys,
+    handleApplyAction,
+    undoStack,
+    handleUndoLastAction,
+    inputValue,
+    setInputValue,
+    handleSend,
+    sendRawMessage,
+    handleKeyDown,
+    isStreaming,
+    canSend,
+  } = useBuilderChatPanel({ isGraphLoaded, onGraphEdited, panelRef });
+
+  const messagesEndRef = useRef<HTMLDivElement>(null);
+  const textareaRef = useRef<HTMLTextAreaElement>(null);
+
+  useEffect(() => {
+    messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
+  }, [messages.length]);
+
+  // Move focus to the textarea when the panel opens so keyboard users can type immediately.
+  useEffect(() => {
+    if (isOpen) {
+      textareaRef.current?.focus();
+    }
+  }, [isOpen]);
+
+  return (
+
+ {isOpen && ( + +
+ + + + + +
+
+ )} + + +
+ ); +} + +function PanelHeader({ + onClose, + undoCount, + onUndo, +}: { + onClose: () => void; + undoCount: number; + onUndo: () => void; +}) { + return ( +
+
+ + + Chat with Builder + +
+
+ {undoCount > 0 && ( + + )} + +
+
+  );
+}
+
+interface MessageListProps {
+  messages: ReturnType<typeof useBuilderChatPanel>["messages"];
+  isCreatingSession: boolean;
+  sessionError: boolean;
+  streamError: Error | undefined;
+  nodes: CustomNode[];
+  parsedActions: GraphAction[];
+  appliedActionKeys: Set<string>;
+  onApplyAction: (action: GraphAction) => void;
+  onRetry: () => void;
+  messagesEndRef: React.RefObject<HTMLDivElement | null>;
+  isStreaming: boolean;
+}
+
+function MessageList({
+  messages,
+  isCreatingSession,
+  sessionError,
+  streamError,
+  nodes,
+  parsedActions,
+  appliedActionKeys,
+  onApplyAction,
+  onRetry,
+  messagesEndRef,
+  isStreaming,
+}: MessageListProps) {
+  const visibleMessages = messages.filter((msg) => {
+    const text = extractTextFromParts(msg.parts);
+    if (msg.role === "user" && text.startsWith(SEED_PROMPT_PREFIX))
+      return false;
+    return (
+      Boolean(text) ||
+      (msg.role === "assistant" &&
+        msg.parts?.some((p) => p.type === "dynamic-tool"))
+    );
+  });
+  const lastVisibleRole = visibleMessages.at(-1)?.role;
+  const showTypingIndicator =
+    isStreaming && (!lastVisibleRole || lastVisibleRole === "user");
+
+  return (
+
+ {isCreatingSession && ( +
+ + Setting up chat session... +
+ )} + + {sessionError && ( +
+

Failed to start chat session.

+ +
+ )} + + {streamError && ( +
+ Connection error. Please try sending your message again. +
+ )} + + {visibleMessages.length === 0 && !isCreatingSession && !sessionError && ( +
+ +

Ask me to explain or modify your agent.

+

+ You can say things like “What does this agent do?” or + “Add a step that formats the output.” +

+
+ )} + + {visibleMessages.map((msg) => { + const textParts = extractTextFromParts(msg.parts); + + return ( +
+ {msg.role === "assistant" + ? (msg.parts ?? []).map((part, i) => { + // Normalize dynamic-tool parts → tool-{name} so MessagePartRenderer + // can route them: edit_agent/run_agent get their specific renderers, + // everything else falls through to GenericTool (collapsed accordion). + const renderedPart = + part.type === "dynamic-tool" + ? ({ + ...part, + type: `tool-${(part as { toolName: string }).toolName}`, + } as ToolUIPart) + : (part as ToolUIPart); + return ( + + ); + }) + : textParts} +
+ ); + })} + + {showTypingIndicator && } + + {parsedActions.length > 0 && ( + + )} + +
+
+  );
+}
+
+function ActionList({
+  parsedActions,
+  nodes,
+  appliedActionKeys,
+  onApplyAction,
+}: {
+  parsedActions: GraphAction[];
+  nodes: CustomNode[];
+  appliedActionKeys: Set<string>;
+  onApplyAction: (action: GraphAction) => void;
+}) {
+  const nodeMap = new Map(nodes.map((n) => [n.id, n]));
+  return (
+
+

Suggested changes

+ {parsedActions.map((action) => { + const key = getActionKey(action); + return ( + + ); + })} +
+  );
+}
+
+function ActionItem({
+  action,
+  nodeMap,
+  isApplied,
+  onApply,
+}: {
+  action: GraphAction;
+  nodeMap: Map<string, CustomNode>;
+  isApplied: boolean;
+  onApply: (action: GraphAction) => void;
+}) {
+  const label =
+    action.type === "update_node_input"
+      ? `Set "${getNodeDisplayName(nodeMap.get(action.nodeId), action.nodeId)}" "${action.key}" = ${JSON.stringify(action.value)}`
+      : `Connect "${getNodeDisplayName(nodeMap.get(action.source), action.source)}" → "${getNodeDisplayName(nodeMap.get(action.target), action.target)}"`;
+
+  return (
+
+ {label} + {isApplied ? ( + + Applied + + ) : ( + + )} +
+  );
+}
+
+interface PanelInputProps {
+  value: string;
+  onChange: (v: string) => void;
+  onKeyDown: (e: KeyboardEvent<HTMLTextAreaElement>) => void;
+  onSend: () => void;
+  onStop: () => void;
+  isStreaming: boolean;
+  isDisabled: boolean;
+  textareaRef?: React.RefObject<HTMLTextAreaElement | null>;
+}
+
+function PanelInput({
+  value,
+  onChange,
+  onKeyDown,
+  onSend,
+  onStop,
+  isStreaming,
+  isDisabled,
+  textareaRef,
+}: PanelInputProps) {
+  return (
+
+
+