From 343222ace1568fdb25ef2bc6a3106baea1e3d7a5 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 14:01:09 +0700 Subject: [PATCH 1/4] feat(platform): defer paid-to-paid subscription downgrades + cancel-pending flow (#12865) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Why / What / How **Why:** Only downgrades to FREE were scheduled at period end; paid→paid downgrades (e.g. BUSINESS→PRO) applied immediately via Stripe proration. The asymmetry meant users lost their higher tier mid-cycle in exchange for a Stripe credit voucher only redeemable on a future subscription — a confusing pattern that produces negative-value paths for users actually cancelling. There was also no way to cancel a pending downgrade or paid→FREE cancellation once scheduled. **What:** Standardize on "upgrade = immediate, downgrade = next cycle" and let users cancel a pending change by clicking their current tier. Harden the new code against conflicting subscription state, concurrent tab races, flaky Stripe calls, and hot-path latency regressions. **How:** Subscription state machine: - **Upgrade** (PRO→BUSINESS) — `stripe.Subscription.modify` with immediate proration (unchanged). If a downgrade schedule is already attached, release it first so the upgrade wins. - **Paid→paid downgrade** (BUSINESS→PRO) — creates a `stripe.SubscriptionSchedule` with two phases (current tier until `current_period_end`, target tier after). No mid-cycle tier demotion. Defensive pre-clear: existing schedule → release; `cancel_at_period_end=True` → set to False. - **Paid→FREE** — unchanged: `cancel_at_period_end=True`. - **Same-tier update** — reuses the existing `POST /credits/subscription` route. When `target_tier == current_tier`, backend calls `release_pending_subscription_schedule` (idempotent) and returns status. No dedicated cancel-pending endpoint — "Keep my current tier" IS the cancel operation. - `release_pending_subscription_schedule` is idempotent on terminal-state schedules and clears both `schedule` and `cancel_at_period_end` atomically per call. API surface: - New fields on `SubscriptionStatusResponse`: `pending_tier` + `pending_tier_effective_at` (pulled from the schedule's next-phase `start_date` so dashboard-authored schedules report the correct timestamp). - `POST /credits/subscription` now returns `SubscriptionStatusResponse` (previously `SubscriptionCheckoutResponse`); the response still carries `url` for checkout flows and adds the status fields inline. - `get_pending_subscription_change` is cached with a 30s TTL — avoids hammering Stripe on every home-page load. - Webhook dispatches `subscription_schedule.{released,completed,updated}` through the main `sync_subscription_from_stripe` flow so both event sources converge to the same DB state. Implementation notes: - New Stripe calls use native async (`stripe.Subscription.list_async` etc.) and typed attribute access — no `run_in_threadpool` wrapping in the new helpers. - Shared `_get_active_subscription` helper collapses the "list active/trialing subs, take first" pattern used by 4 callers. Frontend: - `PendingChangeBanner` sub-component above the tier grid with formatted effective date + "Keep [CurrentTier]" button. `aria-live="polite"` for screen readers; locale pinned to `en-US` to avoid SSR/CSR hydration mismatch. - "Keep [CurrentTier]" also available as a button on the current tier card. - Other tier buttons disabled while a change is pending — user must resolve pending first to prevent stacked schedules. - `cancelPendingChange` reuses `useUpdateSubscriptionTier` with `tier: current_tier`; awaits `refetch()` on both success and error paths so the UI reconciles even if the server succeeded but the client didn't receive the response. ### Changes **Backend (`credit.py`, `v1.py`)** - Tier-ordering helpers (`is_tier_upgrade`/`is_tier_downgrade`). - `modify_stripe_subscription_for_tier` routes downgrades through `_schedule_downgrade_at_period_end`; upgrade path releases any pending schedule first. - `_schedule_downgrade_at_period_end` defensively releases pre-existing schedules and clears `cancel_at_period_end` before creating the new schedule. - `release_pending_subscription_schedule` idempotent on terminal-state schedules; logs partial-failure outcomes. - `_next_phase_tier_and_start` returns both tier and phase-start timestamp; warns on unknown prices. - `get_pending_subscription_change` cached (30s TTL), narrow exception handling. - `sync_subscription_schedule_from_stripe` delegates to `sync_subscription_from_stripe` for convergence with the main webhook path. - Shared `_get_active_subscription` + `_release_schedule_ignoring_terminal` helpers. - `POST /credits/subscription` absorbs the same-tier "cancel pending change" branch. **Frontend (`SubscriptionTierSection/*`)** - `PendingChangeBanner` new sub-component (a11y, locale-pinned date, paid→FREE vs paid→paid copy split, non-null effective-date assertion, no `dark:` utilities). - "Keep [CurrentTier]" button on current tier card. - `useSubscriptionTierSection` — `cancelPendingChange` reuses the update-tier mutation. - Copy: downgrade dialog + status hint updated. - `helpers.ts` extracted from the main component. **Tests** - Backend: +24 tests (95/95 passing): upgrade-releases-pending-schedule, schedule-releases-existing-schedule, cancel-at-period-end collision, terminal-state release idempotency, unknown-price logging, status response population, same-tier-POST-with-pending, webhook delegation. - Frontend: +5 integration tests (21/21 passing): banner render/hide, Keep-button click from banner + current card, paid→paid dialog copy. ### Checklist - [x] Backend unit tests: 95 pass - [x] Frontend integration tests: 21 pass - [x] `poetry run format` / `poetry run lint` clean - [x] `pnpm format` / `pnpm lint` / `pnpm types` clean - [ ] Manual E2E on live Stripe (dev env) — pending deploy: BUSINESS→PRO creates schedule, DB tier unchanged until period end - [ ] Manual E2E: "Keep BUSINESS" in banner releases schedule - [ ] Manual E2E: cancel pending paid→FREE flips `cancel_at_period_end` back to false - [ ] Manual E2E: BUSINESS→PRO (scheduled) then attempt BUSINESS→FREE clears the PRO schedule, sets cancel_at_period_end - [ ] Manual E2E: BUSINESS→PRO (scheduled) then upgrade back to BUSINESS releases the schedule --- .../api/features/subscription_routes_test.py | 339 ++++- .../backend/backend/api/features/v1.py | 107 +- .../backend/backend/copilot/rate_limit.py | 124 +- .../backend/copilot/rate_limit_test.py | 74 + .../backend/backend/data/credit.py | 558 +++++++- .../backend/data/credit_subscription_test.py | 1274 ++++++++++++++++- .../SubscriptionTierSection.tsx | 154 +- .../SubscriptionTierSection.test.tsx | 235 ++- .../PendingChangeBanner.tsx | 60 + .../SubscriptionTierSection/helpers.ts | 54 + .../useSubscriptionTierSection.ts | 42 + .../frontend/src/app/api/openapi.json | 30 +- 12 files changed, 2907 insertions(+), 144 deletions(-) create mode 100644 autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/components/PendingChangeBanner/PendingChangeBanner.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/helpers.ts diff --git a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py index c20e0d0ceb..96fd8763eb 100644 --- a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py @@ -47,6 +47,40 @@ def _configure_frontend_origin(mocker: pytest_mock.MockFixture) -> None: ) +@pytest.fixture(autouse=True) +def _stub_pending_subscription_change(mocker: pytest_mock.MockFixture) -> None: + """Default pending-change lookup to None so tests don't hit Stripe/DB. + + Individual tests can override via their own mocker.patch call. + """ + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=None, + ) + + +@pytest.fixture(autouse=True) +def _stub_subscription_status_lookups(mocker: pytest_mock.MockFixture) -> None: + """Stub Stripe price + proration lookups used by get_subscription_status. + + The POST /credits/subscription handler now returns the full subscription + status payload from every branch (same-tier, FREE downgrade, paid→paid + modify, checkout creation), so every POST test implicitly hits these + helpers. Individual tests can override via their own mocker.patch call. + """ + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + + @pytest.mark.parametrize( "url,expected", [ @@ -407,30 +441,77 @@ def test_update_subscription_tier_enterprise_blocked( set_tier_mock.assert_not_awaited() -def test_update_subscription_tier_same_tier_is_noop( +def test_update_subscription_tier_same_tier_releases_pending_change( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """POST /credits/subscription for the user's current paid tier returns 200 with empty URL. + """POST /credits/subscription for the user's current tier releases any pending change. - Without this guard a duplicate POST (double-click, browser retry, stale page) would - create a second Stripe Checkout Session for the same price, potentially billing the - user twice until the webhook reconciliation fires. + "Stay on my current tier" — the collapsed replacement for the old + /credits/subscription/cancel-pending route. Always calls + release_pending_subscription_schedule (idempotent when nothing is pending) + and returns the refreshed status with url="". Never creates a Checkout + Session — that would double-charge a user who double-clicks their own tier. """ mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.PRO - - async def mock_feature_enabled(*args, **kwargs): - return True + mock_user.subscription_tier = SubscriptionTier.BUSINESS mocker.patch( "backend.api.features.v1.get_user_by_id", new_callable=AsyncMock, return_value=mock_user, ) - mocker.patch( + release_mock = mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + feature_mock = mocker.patch( "backend.api.features.v1.is_feature_enabled", - side_effect=mock_feature_enabled, + new_callable=AsyncMock, + return_value=True, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "BUSINESS" + assert data["url"] == "" + release_mock.assert_awaited_once_with(TEST_USER_ID) + checkout_mock.assert_not_awaited() + # Same-tier branch short-circuits before the payment-flag check. + feature_mock.assert_not_awaited() + + +def test_update_subscription_tier_same_tier_no_pending_change_returns_status( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Same-tier request when nothing is pending still returns status with url="".""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + release_mock = mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + new_callable=AsyncMock, + return_value=False, ) checkout_mock = mocker.patch( "backend.api.features.v1.create_subscription_checkout", @@ -447,10 +528,50 @@ def test_update_subscription_tier_same_tier_is_noop( ) assert response.status_code == 200 - assert response.json()["url"] == "" + data = response.json() + assert data["tier"] == "PRO" + assert data["url"] == "" + assert data["pending_tier"] is None + release_mock.assert_awaited_once_with(TEST_USER_ID) checkout_mock.assert_not_awaited() +def test_update_subscription_tier_same_tier_stripe_error_returns_502( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Same-tier request surfaces a 502 when Stripe release fails. + + Carries forward the error contract from the removed + /credits/subscription/cancel-pending route so clients keep seeing 502 for + transient Stripe failures. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + side_effect=stripe.StripeError("network"), + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 502 + assert "contact support" in response.json()["detail"].lower() + + def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_not_update_db( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, @@ -803,3 +924,197 @@ def test_update_subscription_tier_free_no_stripe_subscription( cancel_mock.assert_awaited_once_with(TEST_USER_ID) # DB tier must be updated immediately — no webhook will fire for a missing sub set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.FREE) + + +def test_get_subscription_status_includes_pending_tier( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """GET /credits/subscription exposes pending_tier and pending_tier_effective_at.""" + import datetime as dt + + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + effective_at = dt.datetime(2030, 1, 1, tzinfo=dt.timezone.utc) + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return None + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=(SubscriptionTier.PRO, effective_at), + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["pending_tier"] == "PRO" + assert data["pending_tier_effective_at"] is not None + + +def test_get_subscription_status_no_pending_tier( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """When no pending change exists the response omits pending_tier.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["pending_tier"] is None + assert data["pending_tier_effective_at"] is None + + +def test_update_subscription_tier_downgrade_paid_to_paid_schedules( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """A BUSINESS→PRO downgrade request dispatches to modify_stripe_subscription_for_tier.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "" + modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.PRO) + checkout_mock.assert_not_awaited() + + +def test_stripe_webhook_dispatches_subscription_schedule_released( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """subscription_schedule.released routes to sync_subscription_schedule_from_stripe.""" + schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"} + event = { + "type": "subscription_schedule.released", + "data": {"object": schedule_obj}, + } + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_schedule_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_awaited_once_with(schedule_obj) + + +def test_stripe_webhook_ignores_subscription_schedule_updated( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """subscription_schedule.updated must NOT dispatch: our own + SubscriptionSchedule.create/.modify calls fire this event and would + otherwise loop redundant traffic through the sync handler. State + transitions we care about surface via .released/.completed, and phase + advance to a new price is already covered by customer.subscription.updated. + """ + schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"} + event = { + "type": "subscription_schedule.updated", + "data": {"object": schedule_obj}, + } + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_schedule_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_not_awaited() diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index ab0b69071d..3559071043 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -26,7 +26,7 @@ from fastapi import ( ) from fastapi.concurrency import run_in_threadpool from prisma.enums import SubscriptionTier -from pydantic import BaseModel +from pydantic import BaseModel, Field from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND from typing_extensions import Optional, TypedDict @@ -49,20 +49,24 @@ from backend.data.auth import api_key as api_key_db from backend.data.block import BlockInput, CompletedBlockOutput from backend.data.credit import ( AutoTopUpConfig, + PendingChangeUnknown, RefundRequest, TransactionHistory, UserCredit, cancel_stripe_subscription, create_subscription_checkout, get_auto_top_up, + get_pending_subscription_change, get_proration_credit_cents, get_subscription_price_id, get_user_credit_model, handle_subscription_payment_failure, modify_stripe_subscription_for_tier, + release_pending_subscription_schedule, set_auto_top_up, set_subscription_tier, sync_subscription_from_stripe, + sync_subscription_schedule_from_stripe, ) from backend.data.graph import GraphSettings from backend.data.model import CredentialsMetaInput, UserOnboarding @@ -698,15 +702,21 @@ class SubscriptionTierRequest(BaseModel): cancel_url: str = "" -class SubscriptionCheckoutResponse(BaseModel): - url: str - - class SubscriptionStatusResponse(BaseModel): tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"] monthly_cost: int # amount in cents (Stripe convention) tier_costs: dict[str, int] # tier name -> amount in cents proration_credit_cents: int # unused portion of current sub to convert on upgrade + pending_tier: Optional[Literal["FREE", "PRO", "BUSINESS"]] = None + pending_tier_effective_at: Optional[datetime] = None + url: str = Field( + default="", + description=( + "Populated only when POST /credits/subscription starts a Stripe Checkout" + " Session (FREE → paid upgrade). Empty string in all other branches —" + " the client redirects to this URL when non-empty." + ), + ) def _validate_checkout_redirect_url(url: str) -> bool: @@ -804,17 +814,42 @@ async def get_subscription_status( current_monthly_cost = tier_costs.get(tier.value, 0) proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost) - return SubscriptionStatusResponse( + try: + pending = await get_pending_subscription_change(user_id) + except (stripe.StripeError, PendingChangeUnknown): + # Swallow Stripe-side failures (rate limits, transient network) AND + # PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both + # propagate past the cache so the next request retries fresh instead + # of serving a stale None for the TTL window. Let real bugs (KeyError, + # AttributeError, etc.) propagate so they surface in Sentry. + logger.exception( + "get_subscription_status: failed to resolve pending change for user %s", + user_id, + ) + pending = None + + response = SubscriptionStatusResponse( tier=tier.value, monthly_cost=current_monthly_cost, tier_costs=tier_costs, proration_credit_cents=proration_credit, ) + if pending is not None: + pending_tier_enum, pending_effective_at = pending + if pending_tier_enum == SubscriptionTier.FREE: + response.pending_tier = "FREE" + elif pending_tier_enum == SubscriptionTier.PRO: + response.pending_tier = "PRO" + elif pending_tier_enum == SubscriptionTier.BUSINESS: + response.pending_tier = "BUSINESS" + if response.pending_tier is not None: + response.pending_tier_effective_at = pending_effective_at + return response @v1_router.post( path="/credits/subscription", - summary="Start a Stripe Checkout session to upgrade subscription tier", + summary="Update subscription tier or start a Stripe Checkout session", operation_id="updateSubscriptionTier", tags=["credits"], dependencies=[Security(requires_user)], @@ -822,7 +857,7 @@ async def get_subscription_status( async def update_subscription_tier( request: SubscriptionTierRequest, user_id: Annotated[str, Security(get_user_id)], -) -> SubscriptionCheckoutResponse: +) -> SubscriptionStatusResponse: # Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type. tier = SubscriptionTier(request.tier) @@ -834,6 +869,29 @@ async def update_subscription_tier( detail="ENTERPRISE subscription changes must be managed by an administrator", ) + # Same-tier request = "stay on my current tier" = cancel any pending + # scheduled change (paid→paid downgrade or paid→FREE cancel). This is the + # collapsed behaviour that replaces the old /credits/subscription/cancel-pending + # route. Safe when no pending change exists: release_pending_subscription_schedule + # returns False and we simply return the current status. + if (user.subscription_tier or SubscriptionTier.FREE) == tier: + try: + await release_pending_subscription_schedule(user_id) + except stripe.StripeError as e: + logger.exception( + "Stripe error releasing pending subscription change for user %s: %s", + user_id, + e, + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to cancel the pending subscription change right now. " + "Please try again or contact support." + ), + ) + return await get_subscription_status(user_id) + payment_enabled = await is_feature_enabled( Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False ) @@ -871,9 +929,9 @@ async def update_subscription_tier( # admin-granted tier. Update DB immediately since the # subscription.deleted webhook will never fire. await set_subscription_tier(user_id, tier) - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) await set_subscription_tier(user_id, tier) - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) # Paid tier changes require payment to be enabled — block self-service upgrades # when the flag is off. Admins use the /api/admin/ routes to set tiers directly. @@ -883,15 +941,6 @@ async def update_subscription_tier( detail=f"Subscription not available for tier {tier}", ) - # No-op short-circuit: if the user is already on the requested paid tier, - # do NOT create a new Checkout Session. Without this guard, a duplicate - # request (double-click, retried POST, stale page) creates a second - # subscription for the same price; the user would be charged for both - # until `_cleanup_stale_subscriptions` runs from the resulting webhook — - # which only fires after the second charge has cleared. - if (user.subscription_tier or SubscriptionTier.FREE) == tier: - return SubscriptionCheckoutResponse(url="") - # Paid→paid tier change: if the user already has a Stripe subscription, # modify it in-place with proration instead of creating a new Checkout # Session. This preserves remaining paid time and avoids double-charging. @@ -901,14 +950,14 @@ async def update_subscription_tier( try: modified = await modify_stripe_subscription_for_tier(user_id, tier) if modified: - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) # modify_stripe_subscription_for_tier returns False when no active # Stripe subscription exists — i.e. the user has an admin-granted # paid tier with no Stripe record. In that case, update the DB # tier directly (same as the FREE-downgrade path for admin-granted # users) rather than sending them through a new Checkout Session. await set_subscription_tier(user_id, tier) - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) except stripe.StripeError as e: @@ -978,7 +1027,9 @@ async def update_subscription_tier( ), ) - return SubscriptionCheckoutResponse(url=url) + status = await get_subscription_status(user_id) + status.url = url + return status @v1_router.post( @@ -1043,6 +1094,18 @@ async def stripe_webhook(request: Request): ): await sync_subscription_from_stripe(data_object) + # `subscription_schedule.updated` is deliberately omitted: our own + # `SubscriptionSchedule.create` + `.modify` calls in + # `_schedule_downgrade_at_period_end` would fire that event right back at us + # and loop redundant traffic through this handler. We only care about state + # transitions (released / completed); phase advance to the new price is + # already covered by `customer.subscription.updated`. + if event_type in ( + "subscription_schedule.released", + "subscription_schedule.completed", + ): + await sync_subscription_schedule_from_stripe(data_object) + if event_type == "invoice.payment_failed": await handle_subscription_payment_failure(data_object) diff --git a/autogpt_platform/backend/backend/copilot/rate_limit.py b/autogpt_platform/backend/backend/copilot/rate_limit.py index 3124c28992..c08cb1b3a8 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit.py @@ -17,6 +17,7 @@ from redis.exceptions import RedisError from backend.data.db_accessors import user_db from backend.data.redis_client import get_redis_async +from backend.data.user import get_user_by_id from backend.util.cache import cached logger = logging.getLogger(__name__) @@ -459,8 +460,20 @@ get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr- async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None: """Persist the user's rate-limit tier to the database. - Also invalidates the ``get_user_tier`` cache for this user so that - subsequent rate-limit checks immediately see the new tier. + Invalidates every cache that keys off the user's subscription tier so the + change is visible immediately: this function's own ``get_user_tier``, the + shared ``get_user_by_id`` (which exposes ``user.subscription_tier``), and + ``get_pending_subscription_change`` (since an admin override can invalidate + a cached ``cancel_at_period_end`` or schedule-based pending state). + + If the user has an active Stripe subscription whose current price does not + match ``tier``, Stripe will keep billing the old price and the next + ``customer.subscription.updated`` webhook will overwrite the DB tier back + to whatever Stripe has. Proper reconciliation (cancelling or modifying the + Stripe subscription when an admin overrides the tier) is out of scope for + this PR — it changes the admin contract and needs its own test coverage. + For now we emit a ``WARNING`` so drift surfaces via Sentry until that + follow-up lands. Raises: prisma.errors.RecordNotFoundError: If the user does not exist. @@ -469,8 +482,113 @@ async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None: where={"id": user_id}, data={"subscriptionTier": tier.value}, ) - # Invalidate cached tier so rate-limit checks pick up the change immediately. get_user_tier.cache_delete(user_id) # type: ignore[attr-defined] + # Local import required: backend.data.credit imports backend.copilot.rate_limit + # (via get_user_tier in credit.py's _invalidate_user_tier_caches), so a + # top-level ``from backend.data.credit import ...`` here would create a + # circular import at module-load time. + from backend.data.credit import get_pending_subscription_change + + get_user_by_id.cache_delete(user_id) # type: ignore[attr-defined] + get_pending_subscription_change.cache_delete(user_id) # type: ignore[attr-defined] + + # The DB write above is already committed; the drift check is best-effort + # diagnostic logging. Fire-and-forget so admin bulk ops don't wait on a + # Stripe roundtrip. The inner helper wraps its body in a timeout + broad + # except so background task errors still surface via logs rather than as + # "task exception never retrieved" warnings. Cancellation on request + # shutdown is acceptable — the drift warning is non-load-bearing. + asyncio.ensure_future(_drift_check_background(user_id, tier)) + + +async def _drift_check_background(user_id: str, tier: SubscriptionTier) -> None: + """Run the Stripe drift check in the background, logging rather than raising.""" + try: + await asyncio.wait_for( + _warn_if_stripe_subscription_drifts(user_id, tier), + timeout=5.0, + ) + logger.debug( + "set_user_tier: drift check completed for user=%s admin_tier=%s", + user_id, + tier.value, + ) + except asyncio.TimeoutError: + logger.warning( + "set_user_tier: drift check timed out for user=%s admin_tier=%s", + user_id, + tier.value, + ) + except asyncio.CancelledError: + # Request may have completed and the event loop is cancelling tasks — + # the drift log is non-critical, so accept cancellation silently. + raise + except Exception: + logger.exception( + "set_user_tier: drift check background task failed for" + " user=%s admin_tier=%s", + user_id, + tier.value, + ) + + +async def _warn_if_stripe_subscription_drifts( + user_id: str, new_tier: SubscriptionTier +) -> None: + """Emit a WARNING when an admin tier override leaves an active Stripe sub on a + mismatched price. + + The warning is diagnostic only: Stripe remains the billing source of truth, + so the next ``customer.subscription.updated`` webhook will reset the DB + tier. Surfacing the drift here lets ops catch admin overrides that bypass + the intended Checkout / Portal cancel flows before users notice surprise + charges. + """ + # Local imports: see note in ``set_user_tier`` about the credit <-> rate_limit + # circular. These helpers (``_get_active_subscription``, + # ``get_subscription_price_id``) live in credit.py alongside the rest of + # the Stripe billing code. + from backend.data.credit import _get_active_subscription, get_subscription_price_id + + try: + user = await get_user_by_id(user_id) + if not getattr(user, "stripe_customer_id", None): + return + sub = await _get_active_subscription(user.stripe_customer_id) + if sub is None: + return + items = sub["items"].data + if not items: + return + price = items[0].price + current_price_id = price if isinstance(price, str) else price.id + # The LaunchDarkly-backed price lookup must live inside this try/except: + # an LD SDK failure (network, token revoked) here would otherwise + # propagate past set_user_tier's already-committed DB write and turn a + # best-effort diagnostic into a 500 on admin tier writes. + expected_price_id = await get_subscription_price_id(new_tier) + except Exception: + logger.debug( + "_warn_if_stripe_subscription_drifts: drift lookup failed for" + " user=%s; skipping drift warning", + user_id, + exc_info=True, + ) + return + if expected_price_id is not None and expected_price_id == current_price_id: + return + logger.warning( + "Admin tier override will drift from Stripe: user=%s admin_tier=%s" + " stripe_sub=%s stripe_price=%s expected_price=%s — the next" + " customer.subscription.updated webhook will reconcile the DB tier" + " back to whatever Stripe has; cancel or modify the Stripe subscription" + " if you intended the admin override to stick.", + user_id, + new_tier.value, + sub.id, + current_price_id, + expected_price_id, + ) async def get_global_rate_limits( diff --git a/autogpt_platform/backend/backend/copilot/rate_limit_test.py b/autogpt_platform/backend/backend/copilot/rate_limit_test.py index ea87658710..577093c752 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit_test.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit_test.py @@ -581,6 +581,80 @@ class TestSetUserTier: assert tier_after == SubscriptionTier.ENTERPRISE + @pytest.mark.asyncio + async def test_drift_check_swallows_launchdarkly_failure(self): + """LaunchDarkly price-id lookup failures inside the drift check must + never bubble up and 500 the admin tier write — the DB update is + already committed by the time we check drift.""" + mock_prisma = AsyncMock() + mock_prisma.update = AsyncMock(return_value=None) + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + mock_sub = MagicMock() + mock_sub.id = "sub_abc" + mock_sub["items"].data = [MagicMock(price=MagicMock(id="price_mismatch"))] + + with ( + patch( + "backend.copilot.rate_limit.PrismaUser.prisma", + return_value=mock_prisma, + ), + patch( + "backend.copilot.rate_limit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit._get_active_subscription", + new_callable=AsyncMock, + return_value=mock_sub, + ), + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + side_effect=RuntimeError("LD SDK not initialized"), + ), + ): + # Must NOT raise — drift check is best-effort diagnostic only. + await set_user_tier(_USER, SubscriptionTier.PRO) + + mock_prisma.update.assert_awaited_once() + + @pytest.mark.asyncio + async def test_drift_check_timeout_is_bounded(self): + """A Stripe call that stalls on the 80s SDK default must not block the + admin tier write — set_user_tier wraps the drift check in a 5s timeout + and logs + returns on TimeoutError.""" + import asyncio as _asyncio + + mock_prisma = AsyncMock() + mock_prisma.update = AsyncMock(return_value=None) + + async def _never_returns(_user_id: str, _tier): + await _asyncio.sleep(60) + + with ( + patch( + "backend.copilot.rate_limit.PrismaUser.prisma", + return_value=mock_prisma, + ), + patch( + "backend.copilot.rate_limit._warn_if_stripe_subscription_drifts", + side_effect=_never_returns, + ), + patch( + "backend.copilot.rate_limit.asyncio.wait_for", + new_callable=AsyncMock, + side_effect=_asyncio.TimeoutError, + ), + ): + await set_user_tier(_USER, SubscriptionTier.PRO) + + # Set_user_tier still completed — the drift timeout did not propagate. + mock_prisma.update.assert_awaited_once() + # --------------------------------------------------------------------------- # get_global_rate_limits with tiers diff --git a/autogpt_platform/backend/backend/data/credit.py b/autogpt_platform/backend/backend/data/credit.py index e97578d5cc..a42ba91be8 100644 --- a/autogpt_platform/backend/backend/data/credit.py +++ b/autogpt_platform/backend/backend/data/credit.py @@ -15,7 +15,7 @@ from prisma.enums import ( OnboardingStep, SubscriptionTier, ) -from prisma.errors import UniqueViolationError +from prisma.errors import PrismaError, UniqueViolationError from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput from pydantic import BaseModel @@ -1280,6 +1280,12 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None: from backend.copilot.rate_limit import get_user_tier # local import avoids circular get_user_tier.cache_delete(user_id) # type: ignore[attr-defined] + # Invalidate the pending-change cache too — an admin tier override or the + # webhook-driven phase transition means any cached pending-change state + # (schedule, cancel_at_period_end) is likely stale. Without this the + # billing page can show a pending change for up to 30s after the tier + # has already flipped. + get_pending_subscription_change.cache_delete(user_id) async def _cancel_customer_subscriptions( @@ -1330,6 +1336,21 @@ async def _cancel_customer_subscriptions( continue seen_ids.add(sub_id) if at_period_end: + # Stripe rejects modify(cancel_at_period_end=True) with 400 when a + # Subscription Schedule is attached (e.g. the user previously + # queued a paid→paid downgrade and is now clicking "Cancel"). + # Release the schedule first so the cancel flag can be set; the + # schedule's pending phase change is superseded by the cancel. + existing_schedule = sub.schedule + if existing_schedule: + schedule_id = ( + existing_schedule + if isinstance(existing_schedule, str) + else existing_schedule.id + ) + await _release_schedule_ignoring_terminal( + schedule_id, "_cancel_customer_subscriptions" + ) await run_in_threadpool( stripe.Subscription.modify, sub_id, cancel_at_period_end=True ) @@ -1366,6 +1387,8 @@ async def cancel_stripe_subscription(user_id: str) -> bool: cancelled_count = await _cancel_customer_subscriptions( customer_id, at_period_end=True ) + if cancelled_count > 0: + get_pending_subscription_change.cache_delete(user_id) return cancelled_count > 0 except stripe.StripeError: logger.warning( @@ -1415,18 +1438,224 @@ async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> i return 0 +# Ordered from least- to most-privileged. Used to distinguish upgrades +# (move right) from downgrades (move left); ENTERPRISE is admin-managed and +# never reached via self-service flows. +_TIER_ORDER: tuple[SubscriptionTier, ...] = ( + SubscriptionTier.FREE, + SubscriptionTier.PRO, + SubscriptionTier.BUSINESS, + SubscriptionTier.ENTERPRISE, +) + + +def _tier_rank(tier: SubscriptionTier) -> int: + return _TIER_ORDER.index(tier) + + +def is_tier_upgrade(current: SubscriptionTier, target: SubscriptionTier) -> bool: + return _tier_rank(target) > _tier_rank(current) + + +def is_tier_downgrade(current: SubscriptionTier, target: SubscriptionTier) -> bool: + return _tier_rank(target) < _tier_rank(current) + + +class PendingChangeUnknown(Exception): + """Raised when pending-change state cannot be determined (e.g. LaunchDarkly + price-id lookup failed). Propagates past the @cached wrapper so the next + request retries instead of serving a stale `None` for the TTL window.""" + + +async def _get_active_subscription(customer_id: str) -> stripe.Subscription | None: + """Return the customer's active or trialing subscription, or None.""" + for status in ("active", "trialing"): + subs = await stripe.Subscription.list_async( + customer=customer_id, status=status, limit=1 + ) + if subs.data: + return subs.data[0] + return None + + +# Substrings Stripe uses in InvalidRequestError messages when the schedule is +# already in a terminal state (released / completed / canceled) and therefore +# cannot be released again. We only swallow the error when one of these appears; +# anything else (typo'd schedule id, wrong subscription, 404, etc.) must +# propagate so bugs aren't masked as silent no-ops. +_TERMINAL_SCHEDULE_ERROR_SUBSTRINGS = ( + "already been released", + "already released", + "already been completed", + "already completed", + "already been canceled", + "already been cancelled", + "already canceled", + "already cancelled", + "is not active", + "is not in a state", +) + + +async def _release_schedule_ignoring_terminal( + schedule_id: str, log_context: str +) -> bool: + """Release a Stripe schedule; swallow InvalidRequestError on terminal state. + + Returns True if the release call succeeded, False if the schedule was + already in a terminal (released / completed / canceled) state. Any other + Stripe error — including non-terminal InvalidRequestErrors such as typo'd + ids or 404s — propagates so the caller can surface the failure instead of + silently masking a bug. + """ + try: + await stripe.SubscriptionSchedule.release_async(schedule_id) + return True + except stripe.InvalidRequestError as e: + message = getattr(e, "user_message", None) or str(e) + if not any( + marker in message.lower() for marker in _TERMINAL_SCHEDULE_ERROR_SUBSTRINGS + ): + logger.warning( + "%s: schedule %s release failed with non-terminal" + " InvalidRequestError (%s); re-raising", + log_context, + schedule_id, + message, + ) + raise + logger.warning( + "%s: schedule %s not releasable (%s); treating as already released", + log_context, + schedule_id, + message, + ) + return False + + +async def _schedule_downgrade_at_period_end( + sub: stripe.Subscription, + new_price_id: str, + user_id: str, + tier: SubscriptionTier, +) -> None: + """Create a Subscription Schedule that defers a tier change to period end. + + Stripe's Subscription Schedule drives an existing subscription through a + series of phases. By keeping the current price for the remainder of the + billing period and switching to ``new_price_id`` afterwards, the user does + NOT receive an immediate proration charge and keeps their current tier + until period end. + + Stripe allows at most one active schedule per subscription and rejects + ``SubscriptionSchedule.create`` if either (a) a schedule is already + attached to the subscription or (b) ``cancel_at_period_end=True`` is set. + Both conditions mean the user is overwriting a pending change they made + earlier (e.g. BUSINESS→FREE cancel, now switching to BUSINESS→PRO + downgrade). We clear the conflicting state first so the new schedule can + be created. These defensive reads serialize through Stripe's own atomic + operations — by the time modify/release returns, the subscription is in a + known-clean state for the subsequent create. + """ + sub_id = sub.id + # ``sub["items"]`` (dict-item) rather than ``sub.items`` because the latter + # is shadowed by Python's dict.items() method on StripeObject. + items = sub["items"].data + if not items: + raise ValueError(f"Subscription {sub_id} has no items; cannot schedule") + price = items[0].price + current_price_id = price if isinstance(price, str) else price.id + period_start: int = sub["current_period_start"] + period_end: int = sub["current_period_end"] + + if sub.cancel_at_period_end: + await stripe.Subscription.modify_async(sub_id, cancel_at_period_end=False) + logger.info( + "_schedule_downgrade_at_period_end: cleared cancel_at_period_end" + " on sub %s for user %s before scheduling downgrade", + sub_id, + user_id, + ) + if sub.schedule: + existing_schedule_id = ( + sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id + ) + await _release_schedule_ignoring_terminal( + existing_schedule_id, "_schedule_downgrade_at_period_end" + ) + + # Create + modify as a two-step transaction. If modify fails (network, + # Stripe 500) the created schedule is orphaned AND attached to the + # subscription, which blocks any future Stripe-side change until manually + # released. Roll back by releasing the orphan, then re-raise so the caller + # sees the original failure. + schedule = await stripe.SubscriptionSchedule.create_async(from_subscription=sub_id) + try: + await stripe.SubscriptionSchedule.modify_async( + schedule.id, + phases=[ + { + "items": [{"price": current_price_id, "quantity": 1}], + "start_date": period_start, + "end_date": period_end, + "proration_behavior": "none", + }, + { + "items": [{"price": new_price_id, "quantity": 1}], + "proration_behavior": "none", + }, + ], + metadata={"user_id": user_id, "pending_tier": tier.value}, + ) + except stripe.StripeError: + logger.exception( + "_schedule_downgrade_at_period_end: modify failed for schedule %s" + " on sub %s user %s; attempting rollback release", + schedule.id, + sub_id, + user_id, + ) + try: + await _release_schedule_ignoring_terminal( + schedule.id, "_schedule_downgrade_at_period_end_rollback" + ) + except stripe.StripeError: + logger.exception( + "_schedule_downgrade_at_period_end: rollback release also failed" + " for orphaned schedule %s on sub %s user %s; manual cleanup" + " required", + schedule.id, + sub_id, + user_id, + ) + raise + logger.info( + "modify_stripe_subscription_for_tier: scheduled sub %s downgrade for user %s → %s at %d", + sub_id, + user_id, + tier, + period_end, + ) + + async def modify_stripe_subscription_for_tier( user_id: str, tier: SubscriptionTier ) -> bool: - """Modify an existing Stripe subscription to a new paid tier using proration. + """Change a Stripe subscription to a new paid tier. - For paid→paid tier changes (e.g. PRO↔BUSINESS), modifying the existing - subscription is preferable to cancelling + creating a new one via Checkout: - Stripe handles proration automatically, crediting unused time on the old plan - and charging the pro-rated amount for the new plan in the same billing cycle. + Upgrades (e.g. PRO→BUSINESS) apply immediately via ``stripe.Subscription.modify`` + with ``proration_behavior="create_prorations"``: Stripe credits unused time on + the old plan and charges the pro-rated amount for the new plan in the same + billing cycle. + + Downgrades (e.g. BUSINESS→PRO) are deferred to the end of the current billing + period via a Stripe Subscription Schedule: the user keeps their current tier + for the time they already paid for, and the new tier takes effect when the + next invoice is generated. The DB tier flip happens via the webhook fired + when the schedule advances to its next phase. Returns: - True — a subscription was found and modified successfully. + True — a subscription was found and modified/scheduled successfully. False — no active/trialing subscription exists (e.g. admin-granted tier or first-time paid signup); caller should fall back to Checkout. @@ -1437,41 +1666,262 @@ async def modify_stripe_subscription_for_tier( if not price_id: raise ValueError(f"No Stripe price ID configured for tier {tier}") - # Guard: only proceed if the user already has a Stripe customer ID. Calling - # get_stripe_customer_id for a user with no Stripe record (e.g. admin-granted tier) - # would create an orphaned customer object if the subsequent Subscription.list call - # fails. Return False early so the API layer falls back to Checkout instead. user = await get_user_by_id(user_id) if not user.stripe_customer_id: return False + current_tier = user.subscription_tier or SubscriptionTier.FREE - customer_id = user.stripe_customer_id - for status in ("active", "trialing"): - subscriptions = await run_in_threadpool( - stripe.Subscription.list, customer=customer_id, status=status, limit=1 - ) - if not subscriptions.data: - continue - sub = subscriptions.data[0] - sub_id = sub["id"] - items = sub.get("items", {}).get("data", []) - if not items: - continue - item_id = items[0]["id"] - await run_in_threadpool( - stripe.Subscription.modify, - sub_id, - items=[{"id": item_id, "price": price_id}], - proration_behavior="create_prorations", - ) + sub = await _get_active_subscription(user.stripe_customer_id) + if sub is None: + return False + items = sub["items"].data + if not items: + return False + sub_id = sub.id + + # Invalidate the cache unconditionally on exit (success OR failure): any + # Stripe mutation below — clearing cancel_at_period_end, releasing an old + # schedule, creating a new one — may have landed partially before an error + # was raised, and the cached pending-change state would otherwise go stale + # for up to 30s until the TTL expires. + try: + if is_tier_downgrade(current_tier, tier): + await _schedule_downgrade_at_period_end(sub, price_id, user_id, tier) + return True + + # Upgrade path. If a schedule is attached from a previous pending + # downgrade, release it first — an upgrade expresses the user's + # intent to be on this tier immediately, which overrides any pending + # deferred change. Ignore terminal-state errors from release. + if sub.schedule: + existing_schedule_id = ( + sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id + ) + await _release_schedule_ignoring_terminal( + existing_schedule_id, "modify_stripe_subscription_for_tier" + ) + + # If a paid→FREE cancel is pending (cancel_at_period_end=True), clear it + # as part of the upgrade — the user is explicitly choosing to stay on a + # paid tier. Without this, the sub would be upgraded AND still cancelled + # at period end, leaving a confusing dual state. + modify_kwargs: dict = { + "items": [{"id": items[0].id, "price": price_id}], + "proration_behavior": "create_prorations", + } + if sub.cancel_at_period_end: + modify_kwargs["cancel_at_period_end"] = False + + await stripe.Subscription.modify_async(sub_id, **modify_kwargs) + # Flip the DB tier immediately. The customer.subscription.updated webhook + # will also fire and set it again — idempotent. Without this synchronous + # update, the UI refetches before the webhook lands and shows the old + # tier, making the upgrade look like a no-op to the user. + # + # Swallow DB-write exceptions here: Stripe is authoritative and the + # modify above already succeeded (the user has been charged). If the + # DB write fails and we re-raised, the API would return 5xx and the UI + # would surface a failed upgrade to a user who was already charged. + # The customer.subscription.updated webhook will reconcile the DB shortly. + # + # Only catch actual DB/connection failures — letting KeyError, + # AttributeError etc. propagate so programming errors surface in Sentry + # instead of being silently masked as benign DB-write-swallow events. + try: + await set_subscription_tier(user_id, tier) + except (PrismaError, ConnectionError, asyncio.TimeoutError): + logger.exception( + "modify_stripe_subscription_for_tier: Stripe modify on sub %s" + " succeeded for user %s → %s but DB tier flip failed; webhook" + " will reconcile", + sub_id, + user_id, + tier, + ) logger.info( - "modify_stripe_subscription_for_tier: modified sub %s for user %s → %s", + "modify_stripe_subscription_for_tier: upgraded sub %s for user %s → %s", sub_id, user_id, tier, ) return True - return False + finally: + get_pending_subscription_change.cache_delete(user_id) + + +async def release_pending_subscription_schedule(user_id: str) -> bool: + """Cancel any pending subscription change (scheduled downgrade or cancellation). + + Two pending-change mechanisms can be attached to a Stripe subscription: + + - **Subscription Schedule** (paid→paid downgrade): ``stripe.SubscriptionSchedule.release`` + detaches the schedule and lets the subscription continue on its current + phase's price. + - **cancel_at_period_end=True** (paid→FREE cancel): clearing that flag via + ``stripe.Subscription.modify`` keeps the subscription active indefinitely. + + Returns True if a pending change was found and reverted, False otherwise. + """ + user = await get_user_by_id(user_id) + if not user.stripe_customer_id: + return False + + sub = await _get_active_subscription(user.stripe_customer_id) + if sub is None: + return False + + sub_id = sub.id + did_anything = False + schedule_released = False + schedule_id: str | None = None + try: + if sub.schedule: + schedule_id = ( + sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id + ) + schedule_released = await _release_schedule_ignoring_terminal( + schedule_id, "release_pending_subscription_schedule" + ) + if schedule_released: + logger.info( + "release_pending_subscription_schedule: released schedule %s for user %s", + schedule_id, + user_id, + ) + did_anything = True + if sub.cancel_at_period_end: + try: + await stripe.Subscription.modify_async( + sub_id, cancel_at_period_end=False + ) + except stripe.StripeError: + if schedule_released: + logger.exception( + "release_pending_subscription_schedule: partial release" + " — schedule %s released but cancel_at_period_end clear" + " failed on sub %s for user %s; manual reconciliation" + " may be needed", + schedule_id, + sub_id, + user_id, + ) + raise + did_anything = True + logger.info( + "release_pending_subscription_schedule: cleared cancel_at_period_end" + " on sub %s for user %s", + sub_id, + user_id, + ) + finally: + if did_anything: + get_pending_subscription_change.cache_delete(user_id) + return did_anything + + +@cached(ttl_seconds=30, maxsize=512, cache_none=True, shared_cache=True) +async def get_pending_subscription_change( + user_id: str, +) -> tuple[SubscriptionTier, datetime] | None: + """Return ``(pending_tier, effective_at)`` when a change is queued, else ``None``. + + Reflects both Subscription Schedule phase transitions (paid→paid downgrade) + and ``cancel_at_period_end=True`` (paid→FREE cancel). + + Cached for 30 seconds per user_id. *Why the cache exists:* this function + runs on every dashboard/home fetch and would otherwise fire + 2× Subscription.list + 1× Schedule.retrieve per page load. A busy user + polling the billing page would quickly brush up against Stripe's per-API + rate limits; the 30s TTL absorbs dashboard polling while being short + enough that the UI reconciles quickly after a downgrade / cancel action. + + *Invalidation contract.* Every call-site that mutates Stripe state which + could change the pending-change answer MUST call + ``get_pending_subscription_change.cache_delete(user_id)`` so the UI never + shows a stale pending badge after a user-visible action. Current + invalidators (keep this list in sync when adding new mutators): + + - ``set_subscription_tier`` — admin or webhook-driven tier flip. + - ``modify_stripe_subscription_for_tier`` — ``finally`` block (covers + upgrade path clear + downgrade-schedule create + any partial failure). + - ``release_pending_subscription_schedule`` — ``finally`` block when a + schedule release OR ``cancel_at_period_end`` clear succeeded. + - ``cancel_stripe_subscription`` — after scheduling period-end cancel. + - ``sync_subscription_from_stripe`` — webhook entry point. + - ``set_user_tier`` (``backend.copilot.rate_limit``) — admin tier override + invalidates any cached pending state keyed off the old tier. + """ + user = await get_user_by_id(user_id) + if not user.stripe_customer_id: + # Short-circuit for users with no Stripe customer (admin-granted tiers, + # FREE-only users): skip the Stripe API calls entirely. + return None + + pro_price, biz_price = await asyncio.gather( + get_subscription_price_id(SubscriptionTier.PRO), + get_subscription_price_id(SubscriptionTier.BUSINESS), + ) + price_to_tier: dict[str, SubscriptionTier] = {} + if pro_price: + price_to_tier[pro_price] = SubscriptionTier.PRO + if biz_price: + price_to_tier[biz_price] = SubscriptionTier.BUSINESS + if not price_to_tier: + logger.warning( + "get_pending_subscription_change: no Stripe price IDs resolvable for" + " PRO/BUSINESS (LaunchDarkly fetch failed?); raising to bypass the" + " None cache so the next request retries fresh" + ) + raise PendingChangeUnknown( + "Stripe price lookup failed; pending-change state cannot be determined" + ) + + sub = await _get_active_subscription(user.stripe_customer_id) + if sub is None: + return None + period_end = sub.current_period_end + if not isinstance(period_end, int): + return None + effective_at = datetime.fromtimestamp(period_end, tz=timezone.utc) + if sub.cancel_at_period_end: + return SubscriptionTier.FREE, effective_at + if not sub.schedule: + return None + schedule_id = sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id + schedule = await stripe.SubscriptionSchedule.retrieve_async(schedule_id) + return _next_phase_tier_and_start(schedule, price_to_tier) + + +def _next_phase_tier_and_start( + schedule: stripe.SubscriptionSchedule, + price_to_tier: dict[str, SubscriptionTier], +) -> tuple[SubscriptionTier, datetime] | None: + """Return (tier, start_datetime) of the phase that follows the active one. + + Using the phase's own ``start_date`` (not the subscription's current_period_end) + is correct even for schedules created outside this flow — a dashboard-authored + schedule can have phase transitions at arbitrary timestamps. + """ + now = int(time.time()) + for phase in schedule.phases or []: + if not isinstance(phase.start_date, int) or phase.start_date <= now: + continue + # ``phase["items"]`` because ``phase.items`` is shadowed by dict.items(). + items = phase["items"] or [] + if not items: + continue + price = items[0].price + price_id = price if isinstance(price, str) else price.id + if price_id in price_to_tier: + return price_to_tier[price_id], datetime.fromtimestamp( + phase.start_date, tz=timezone.utc + ) + logger.warning( + "next_phase_tier_and_start: unknown price %s on schedule %s", + price_id, + schedule.id, + ) + return None async def get_auto_top_up(user_id: str) -> AutoTopUpConfig: @@ -1732,6 +2182,50 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None: # cancel the old sub. await _cleanup_stale_subscriptions(customer_id, new_sub_id) await set_subscription_tier(user.id, tier) + # Tier changed — bust any cached pending-change view so the next + # dashboard fetch reflects the new state immediately. + get_pending_subscription_change.cache_delete(user.id) + + +async def sync_subscription_schedule_from_stripe(stripe_schedule: dict) -> None: + """Sync the DB tier from a ``subscription_schedule.*`` webhook event. + + Stripe fires ``subscription_schedule.released`` / ``.completed`` / + ``.updated`` when a schedule advances phases or is detached. The regular + ``customer.subscription.updated`` webhook with the new price covers the + phase transition in most cases, but listening to schedule events is a + safety net that also catches releases done via the Stripe dashboard. + + The schedule payload doesn't carry the active price directly — it carries + a ``subscription`` id that we look up to get the current item. + + Webhook-ordering safety: we deliberately funnel both event sources through + ``sync_subscription_from_stripe`` so they share one code path and one DB + write. That function is idempotent — it no-ops when ``current_tier == + tier`` — so concurrent or out-of-order deliveries of + ``subscription_schedule.*`` and ``customer.subscription.updated`` converge + to the same DB state regardless of which arrives first. + """ + # When a schedule is released, Stripe clears `subscription` and moves the id + # to `released_subscription`. Fall back to that so `.released` events — the + # main reason we listen to schedule webhooks as a safety net — are processed. + sub_id = stripe_schedule.get("subscription") or stripe_schedule.get( + "released_subscription" + ) + if not isinstance(sub_id, str) or not sub_id: + logger.warning( + "sync_subscription_schedule_from_stripe: no 'subscription' id; skipping" + ) + return + try: + sub = await stripe.Subscription.retrieve_async(sub_id) + except stripe.StripeError: + logger.warning( + "sync_subscription_schedule_from_stripe: failed to retrieve sub %s", + sub_id, + ) + return + await sync_subscription_from_stripe(dict(sub)) async def handle_subscription_payment_failure(invoice: dict) -> None: diff --git a/autogpt_platform/backend/backend/data/credit_subscription_test.py b/autogpt_platform/backend/backend/data/credit_subscription_test.py index a9634afcb4..d38f71d09e 100644 --- a/autogpt_platform/backend/backend/data/credit_subscription_test.py +++ b/autogpt_platform/backend/backend/data/credit_subscription_test.py @@ -12,11 +12,16 @@ from prisma.models import User from backend.data.credit import ( cancel_stripe_subscription, create_subscription_checkout, + get_pending_subscription_change, get_proration_credit_cents, handle_subscription_payment_failure, + is_tier_downgrade, + is_tier_upgrade, modify_stripe_subscription_for_tier, + release_pending_subscription_schedule, set_subscription_tier, sync_subscription_from_stripe, + sync_subscription_schedule_from_stripe, ) @@ -310,7 +315,11 @@ def _make_user_with_stripe(stripe_customer_id: str | None = "cus_123") -> MagicM @pytest.mark.asyncio async def test_cancel_stripe_subscription_cancels_active(): mock_subscriptions = MagicMock() - mock_subscriptions.data = [{"id": "sub_abc123"}] + mock_subscriptions.data = [ + stripe.Subscription.construct_from( + {"id": "sub_abc123", "schedule": None}, "sk_test" + ) + ] mock_subscriptions.has_more = False with ( @@ -346,7 +355,14 @@ async def test_cancel_stripe_subscription_no_customer_id_returns_false(): async def test_cancel_stripe_subscription_multi_partial_failure(): """First modify raises → error propagates and subsequent subs are not scheduled.""" mock_subscriptions = MagicMock() - mock_subscriptions.data = [{"id": "sub_first"}, {"id": "sub_second"}] + mock_subscriptions.data = [ + stripe.Subscription.construct_from( + {"id": "sub_first", "schedule": None}, "sk_test" + ), + stripe.Subscription.construct_from( + {"id": "sub_second", "schedule": None}, "sk_test" + ), + ] mock_subscriptions.has_more = False with ( @@ -428,7 +444,11 @@ async def test_cancel_stripe_subscription_cancels_trialing(): active_subs.data = [] active_subs.has_more = False trialing_subs = MagicMock() - trialing_subs.data = [{"id": "sub_trial_123"}] + trialing_subs.data = [ + stripe.Subscription.construct_from( + {"id": "sub_trial_123", "schedule": None}, "sk_test" + ) + ] trialing_subs.has_more = False def list_side_effect(*args, **kwargs): @@ -454,10 +474,18 @@ async def test_cancel_stripe_subscription_cancels_trialing(): async def test_cancel_stripe_subscription_cancels_active_and_trialing(): """Both active AND trialing subs present → both get scheduled for cancellation, no duplicates.""" active_subs = MagicMock() - active_subs.data = [{"id": "sub_active_1"}] + active_subs.data = [ + stripe.Subscription.construct_from( + {"id": "sub_active_1", "schedule": None}, "sk_test" + ) + ] active_subs.has_more = False trialing_subs = MagicMock() - trialing_subs.data = [{"id": "sub_trial_2"}] + trialing_subs.data = [ + stripe.Subscription.construct_from( + {"id": "sub_trial_2", "schedule": None}, "sk_test" + ) + ] trialing_subs.has_more = False def list_side_effect(*args, **kwargs): @@ -480,6 +508,62 @@ async def test_cancel_stripe_subscription_cancels_active_and_trialing(): assert modified_ids == {"sub_active_1", "sub_trial_2"} +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_releases_attached_schedule_first(): + """Pre-existing Subscription Schedule must be released before cancel_at_period_end. + + Stripe rejects ``modify(cancel_at_period_end=True)`` with HTTP 400 when the + subscription has an attached schedule (e.g. user queued a BUSINESS→PRO + downgrade and now clicks "Downgrade to FREE"). Without the pre-release, + the API handler would surface a 502 to the user. + """ + mock_subscriptions = MagicMock() + mock_subscriptions.data = [ + stripe.Subscription.construct_from( + {"id": "sub_abc123", "schedule": "sub_sched_abc"}, "sk_test" + ) + ] + mock_subscriptions.has_more = False + + call_order: list[str] = [] + + async def record_release(schedule_id): + call_order.append(f"release:{schedule_id}") + + def record_modify(sub_id, **kwargs): + call_order.append(f"modify:{sub_id}:{kwargs}") + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=_make_user_with_stripe("cus_123"), + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=mock_subscriptions, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + side_effect=record_release, + ) as mock_release, + patch( + "backend.data.credit.stripe.Subscription.modify", + side_effect=record_modify, + ) as mock_modify, + ): + await cancel_stripe_subscription("user-1") + + mock_release.assert_awaited_once_with("sub_sched_abc") + mock_modify.assert_called_once_with("sub_abc123", cancel_at_period_end=True) + # Release must happen before modify, else Stripe returns 400. + assert call_order == [ + "release:sub_sched_abc", + "modify:sub_abc123:{'cancel_at_period_end': True}", + ] + + @pytest.mark.asyncio async def test_get_proration_credit_cents_no_stripe_customer_returns_zero(): """Admin-granted tier users without stripe_customer_id get 0 without creating a customer.""" @@ -878,7 +962,11 @@ async def test_cancel_stripe_subscription_raises_on_cancel_error(): import stripe as stripe_mod mock_subscriptions = MagicMock() - mock_subscriptions.data = [{"id": "sub_abc123"}] + mock_subscriptions.data = [ + stripe.Subscription.construct_from( + {"id": "sub_abc123", "schedule": None}, "sk_test" + ) + ] mock_subscriptions.has_more = False with ( @@ -1099,15 +1187,21 @@ async def test_handle_subscription_payment_failure_passes_invoice_id_as_transact @pytest.mark.asyncio async def test_modify_stripe_subscription_for_tier_modifies_existing_sub(): """modify_stripe_subscription_for_tier calls Subscription.modify and returns True.""" - mock_sub = { - "id": "sub_abc", - "items": {"data": [{"id": "si_abc"}]}, - } + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_abc", + "items": {"data": [{"id": "si_abc"}]}, + "schedule": None, + "cancel_at_period_end": False, + }, + "k", + ) mock_list = MagicMock() mock_list.data = [mock_sub] mock_user = MagicMock(spec=User) mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.FREE with ( patch( @@ -1121,12 +1215,18 @@ async def test_modify_stripe_subscription_for_tier_modifies_existing_sub(): return_value=mock_user, ), patch( - "backend.data.credit.stripe.Subscription.list", + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, return_value=mock_list, ), patch( - "backend.data.credit.stripe.Subscription.modify", + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, ) as mock_modify, + patch( + "backend.data.credit.set_subscription_tier", + new_callable=AsyncMock, + ) as mock_set_tier, ): result = await modify_stripe_subscription_for_tier( "user-1", SubscriptionTier.PRO @@ -1138,6 +1238,66 @@ async def test_modify_stripe_subscription_for_tier_modifies_existing_sub(): items=[{"id": "si_abc", "price": "price_pro_monthly"}], proration_behavior="create_prorations", ) + mock_set_tier.assert_awaited_once_with("user-1", SubscriptionTier.PRO) + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_clears_cancel_at_period_end_on_upgrade(): + """Upgrading from a sub with cancel_at_period_end=True clears the flag so the + upgrade isn't silently cancelled at period end and the DB tier flips immediately.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_upgrading", + "items": {"data": [{"id": "si_abc"}]}, + "schedule": None, + "cancel_at_period_end": True, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.PRO + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_biz_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.set_subscription_tier", + new_callable=AsyncMock, + ) as mock_set_tier, + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.BUSINESS + ) + + assert result is True + mock_modify.assert_called_once_with( + "sub_upgrading", + items=[{"id": "si_abc", "price": "price_biz_monthly"}], + proration_behavior="create_prorations", + cancel_at_period_end=False, + ) + mock_set_tier.assert_awaited_once_with("user-1", SubscriptionTier.BUSINESS) @pytest.mark.asyncio @@ -1178,6 +1338,7 @@ async def test_modify_stripe_subscription_for_tier_returns_false_when_no_sub(): mock_user = MagicMock(spec=User) mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.FREE with ( patch( @@ -1191,7 +1352,8 @@ async def test_modify_stripe_subscription_for_tier_returns_false_when_no_sub(): return_value=mock_user, ), patch( - "backend.data.credit.stripe.Subscription.list", + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, return_value=mock_list, ), ): @@ -1212,3 +1374,1089 @@ async def test_modify_stripe_subscription_for_tier_raises_on_missing_price_id(): ): with pytest.raises(ValueError, match="No Stripe price ID configured"): await modify_stripe_subscription_for_tier("user-1", SubscriptionTier.PRO) + + +def test_tier_order_helpers(): + assert is_tier_upgrade(SubscriptionTier.FREE, SubscriptionTier.PRO) is True + assert is_tier_upgrade(SubscriptionTier.PRO, SubscriptionTier.BUSINESS) is True + assert is_tier_upgrade(SubscriptionTier.BUSINESS, SubscriptionTier.PRO) is False + assert is_tier_downgrade(SubscriptionTier.BUSINESS, SubscriptionTier.PRO) is True + assert is_tier_downgrade(SubscriptionTier.PRO, SubscriptionTier.FREE) is True + assert is_tier_downgrade(SubscriptionTier.PRO, SubscriptionTier.BUSINESS) is False + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_downgrade_creates_schedule(): + """Paid→paid downgrade (BUSINESS→PRO) creates a Subscription Schedule rather than proration.""" + import time as time_mod + + now = int(time_mod.time()) + period_end = now + 27 * 24 * 3600 + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "items": {"data": [{"id": "si_biz", "price": {"id": "price_biz_monthly"}}]}, + "current_period_start": now - 3 * 24 * 3600, + "current_period_end": period_end, + "schedule": None, + "cancel_at_period_end": False, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mock_schedule = stripe.SubscriptionSchedule.construct_from( + {"id": "sub_sched_1"}, "k" + ) + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + return_value=mock_schedule, + ) as mock_schedule_create, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.modify_async", + new_callable=AsyncMock, + ) as mock_schedule_modify, + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.PRO + ) + + assert result is True + # Did NOT call Subscription.modify with proration (no immediate tier change). + mock_modify.assert_not_called() + mock_schedule_create.assert_called_once_with(from_subscription="sub_biz") + assert mock_schedule_modify.call_count == 1 + _, kwargs = mock_schedule_modify.call_args + phases = kwargs["phases"] + assert phases[0]["items"][0]["price"] == "price_biz_monthly" + assert phases[0]["end_date"] == period_end + assert phases[1]["items"][0]["price"] == "price_pro_monthly" + assert phases[0]["proration_behavior"] == "none" + assert phases[1]["proration_behavior"] == "none" + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_upgrade_immediate_proration(): + """PRO→BUSINESS upgrade still uses Subscription.modify with proration (no schedule).""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "items": {"data": [{"id": "si_pro", "price": {"id": "price_pro_monthly"}}]}, + "schedule": None, + "cancel_at_period_end": False, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.PRO + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_biz_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + ) as mock_schedule_create, + patch( + "backend.data.credit.set_subscription_tier", + new_callable=AsyncMock, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.BUSINESS + ) + + assert result is True + mock_modify.assert_called_once_with( + "sub_pro", + items=[{"id": "si_pro", "price": "price_biz_monthly"}], + proration_behavior="create_prorations", + ) + mock_schedule_create.assert_not_called() + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_releases_downgrade_schedule(): + """release_pending_subscription_schedule releases the Stripe schedule if one is attached.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": "sub_sched_1", + "cancel_at_period_end": False, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + ): + result = await release_pending_subscription_schedule("user-1") + + assert result is True + mock_release.assert_called_once_with("sub_sched_1") + mock_modify.assert_not_called() + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_clears_cancel_at_period_end(): + """release_pending_subscription_schedule reverts a pending paid→FREE cancel.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "schedule": None, + "cancel_at_period_end": True, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + ): + result = await release_pending_subscription_schedule("user-1") + + assert result is True + mock_modify.assert_called_once_with("sub_pro", cancel_at_period_end=False) + mock_release.assert_not_called() + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_no_pending_change_returns_false(): + """release_pending_subscription_schedule returns False when no schedule/cancel is set.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "schedule": None, + "cancel_at_period_end": False, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + ): + result = await release_pending_subscription_schedule("user-1") + + assert result is False + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_no_stripe_customer_returns_false(): + mock_user = MagicMock() + mock_user.stripe_customer_id = None + + with patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ): + result = await release_pending_subscription_schedule("user-1") + + assert result is False + + +@pytest.mark.asyncio +async def test_get_pending_subscription_change_cancel_at_period_end(): + """cancel_at_period_end=True maps to pending FREE at current_period_end.""" + import time as time_mod + + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + now = int(time_mod.time()) + period_end = now + 10 * 24 * 3600 + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "current_period_end": period_end, + "cancel_at_period_end": True, + "schedule": None, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + ): + result = await get_pending_subscription_change("user-1") + + assert result is not None + pending_tier, effective_at = result + assert pending_tier == SubscriptionTier.FREE + assert int(effective_at.timestamp()) == period_end + + +@pytest.mark.asyncio +async def test_get_pending_subscription_change_from_schedule(): + """A schedule whose next phase uses the PRO price maps to pending_tier=PRO.""" + import time as time_mod + + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + now = int(time_mod.time()) + period_end = now + 10 * 24 * 3600 + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "current_period_end": period_end, + "cancel_at_period_end": False, + "schedule": "sub_sched_1", + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_schedule = stripe.SubscriptionSchedule.construct_from( + { + "id": "sub_sched_1", + "phases": [ + { + "start_date": now - 3 * 24 * 3600, + "end_date": period_end, + "items": [{"price": "price_biz_monthly"}], + }, + { + "start_date": period_end, + "items": [{"price": "price_pro_monthly"}], + }, + ], + }, + "k", + ) + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.retrieve_async", + new_callable=AsyncMock, + return_value=mock_schedule, + ), + ): + result = await get_pending_subscription_change("user-1") + + assert result is not None + pending_tier, effective_at = result + assert pending_tier == SubscriptionTier.PRO + assert int(effective_at.timestamp()) == period_end + + +@pytest.mark.asyncio +async def test_get_pending_subscription_change_none_when_no_schedule_or_cancel(): + """Returns None when neither a schedule nor cancel_at_period_end is set.""" + import time as time_mod + + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + now = int(time_mod.time()) + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "current_period_end": now + 10 * 24 * 3600, + "cancel_at_period_end": False, + "schedule": None, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return { + SubscriptionTier.PRO: "price_pro", + SubscriptionTier.BUSINESS: "price_biz", + }.get(tier) + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + ): + result = await get_pending_subscription_change("user-1") + + assert result is None + + +@pytest.mark.asyncio +async def test_sync_subscription_schedule_from_stripe_retrieves_and_delegates(): + """subscription_schedule.released triggers a sync via the active subscription object.""" + stripe_schedule = {"id": "sub_sched_1", "subscription": "sub_pro"} + retrieved_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "customer": "cus_abc", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + }, + "k", + ) + + with ( + patch( + "backend.data.credit.stripe.Subscription.retrieve_async", + new_callable=AsyncMock, + return_value=retrieved_sub, + ) as mock_retrieve, + patch( + "backend.data.credit.sync_subscription_from_stripe", + new_callable=AsyncMock, + ) as mock_sync, + ): + await sync_subscription_schedule_from_stripe(stripe_schedule) + + mock_retrieve.assert_called_once_with("sub_pro") + mock_sync.assert_awaited_once() + forwarded = mock_sync.call_args.args[0] + assert forwarded["id"] == "sub_pro" + assert forwarded["customer"] == "cus_abc" + + +@pytest.mark.asyncio +async def test_sync_subscription_schedule_from_stripe_uses_released_subscription_fallback(): + """subscription_schedule.released events clear `subscription` and set + `released_subscription`; the sync handler must fall back to that id.""" + stripe_schedule = { + "id": "sub_sched_1", + "subscription": None, + "released_subscription": "sub_pro_released", + } + retrieved_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro_released", + "customer": "cus_abc", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + }, + "k", + ) + + with ( + patch( + "backend.data.credit.stripe.Subscription.retrieve_async", + new_callable=AsyncMock, + return_value=retrieved_sub, + ) as mock_retrieve, + patch( + "backend.data.credit.sync_subscription_from_stripe", + new_callable=AsyncMock, + ) as mock_sync, + ): + await sync_subscription_schedule_from_stripe(stripe_schedule) + + mock_retrieve.assert_called_once_with("sub_pro_released") + mock_sync.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_sync_subscription_schedule_from_stripe_missing_sub_id_returns(): + """A schedule event with no 'subscription' field is logged and ignored.""" + with patch( + "backend.data.credit.stripe.Subscription.retrieve_async", + new_callable=AsyncMock, + ) as mock_retrieve: + await sync_subscription_schedule_from_stripe({"id": "sub_sched_1"}) + mock_retrieve.assert_not_called() + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_phase_transition_updates_tier(): + """When a schedule advances phases, Stripe fires customer.subscription.updated with + the new price — the existing sync handler must update the DB tier accordingly.""" + mock_user = _make_user(tier=SubscriptionTier.BUSINESS) + stripe_sub = { + "id": "sub_pro", + "customer": "cus_abc", + "status": "active", + # Phase advanced: price is now PRO (was BUSINESS before). + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=empty_list, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.PRO) + + +@pytest.mark.asyncio +async def test_release_schedule_idempotent_on_terminal_state(): + """SubscriptionSchedule.release raising InvalidRequestError on a terminal-state + schedule is treated as success; we still continue to the cancel_at_period_end clear. + """ + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": "sub_sched_terminal", + "cancel_at_period_end": True, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + side_effect=stripe.InvalidRequestError( + "Schedule has already been released", + param="schedule", + ), + ) as mock_release, + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + ): + result = await release_pending_subscription_schedule("user-1") + + # Terminal-state release is treated as idempotent success; modify still runs. + assert result is True + mock_release.assert_called_once_with("sub_sched_terminal") + mock_modify.assert_called_once_with("sub_biz", cancel_at_period_end=False) + + +@pytest.mark.asyncio +async def test_schedule_downgrade_releases_existing_schedule(): + """_schedule_downgrade_at_period_end releases any pre-existing schedule first.""" + import time as time_mod + + now = int(time_mod.time()) + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": "sub_sched_old", + "cancel_at_period_end": False, + "items": {"data": [{"id": "si_biz", "price": {"id": "price_biz_monthly"}}]}, + "current_period_start": now - 3 * 24 * 3600, + "current_period_end": now + 27 * 24 * 3600, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mock_new_schedule = stripe.SubscriptionSchedule.construct_from( + {"id": "sub_sched_new"}, "k" + ) + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + return_value=mock_new_schedule, + ) as mock_create, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.modify_async", + new_callable=AsyncMock, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.PRO + ) + + assert result is True + # Existing schedule released before creating the new one. + mock_release.assert_called_once_with("sub_sched_old") + mock_create.assert_called_once_with(from_subscription="sub_biz") + # cancel_at_period_end was False, so Subscription.modify should not be called. + mock_modify.assert_not_called() + + +@pytest.mark.asyncio +async def test_schedule_downgrade_clears_cancel_at_period_end(): + """_schedule_downgrade_at_period_end clears cancel_at_period_end before scheduling.""" + import time as time_mod + + now = int(time_mod.time()) + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": None, + "cancel_at_period_end": True, + "items": {"data": [{"id": "si_biz", "price": {"id": "price_biz_monthly"}}]}, + "current_period_start": now - 3 * 24 * 3600, + "current_period_end": now + 27 * 24 * 3600, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mock_new_schedule = stripe.SubscriptionSchedule.construct_from( + {"id": "sub_sched_new"}, "k" + ) + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + return_value=mock_new_schedule, + ) as mock_create, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.modify_async", + new_callable=AsyncMock, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.PRO + ) + + assert result is True + # cancel_at_period_end cleared before new schedule is created. + mock_modify.assert_called_once_with("sub_biz", cancel_at_period_end=False) + mock_create.assert_called_once_with(from_subscription="sub_biz") + + +@pytest.mark.asyncio +async def test_schedule_downgrade_rolls_back_orphan_on_modify_failure(): + """If SubscriptionSchedule.modify fails after a successful create, the + orphaned schedule must be released so it doesn't stay attached and block + future changes. The original StripeError re-raises to the caller. + """ + import time as time_mod + + now = int(time_mod.time()) + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": None, + "cancel_at_period_end": False, + "items": {"data": [{"id": "si_biz", "price": {"id": "price_biz_monthly"}}]}, + "current_period_start": now - 3 * 24 * 3600, + "current_period_end": now + 27 * 24 * 3600, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mock_new_schedule = stripe.SubscriptionSchedule.construct_from( + {"id": "sub_sched_new"}, "k" + ) + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + return_value=mock_new_schedule, + ) as mock_create, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.modify_async", + new_callable=AsyncMock, + side_effect=stripe.APIConnectionError("network down"), + ) as mock_schedule_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + ): + with pytest.raises(stripe.APIConnectionError): + await modify_stripe_subscription_for_tier("user-1", SubscriptionTier.PRO) + + mock_create.assert_called_once_with(from_subscription="sub_biz") + mock_schedule_modify.assert_called_once() + # Rollback must release the freshly-created (and now orphaned) schedule + # id, not the pre-existing one (there was none here). + mock_release.assert_called_once_with("sub_sched_new") + + +@pytest.mark.asyncio +async def test_release_ignoring_terminal_reraises_non_terminal_error(): + """_release_schedule_ignoring_terminal only swallows terminal-state errors. + Typos / wrong ids / 404s surface so bugs aren't silently masked. + """ + from backend.data.credit import _release_schedule_ignoring_terminal + + with patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + side_effect=stripe.InvalidRequestError( + "No such subscription_schedule: 'sub_sched_typo'", + param="schedule", + ), + ): + with pytest.raises(stripe.InvalidRequestError): + await _release_schedule_ignoring_terminal("sub_sched_typo", "test_context") + + +@pytest.mark.asyncio +async def test_release_ignoring_terminal_swallows_terminal_error(): + """Terminal-state messages are treated as idempotent success and return False.""" + from backend.data.credit import _release_schedule_ignoring_terminal + + with patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + side_effect=stripe.InvalidRequestError( + "Schedule has already been released", + param="schedule", + ), + ): + result = await _release_schedule_ignoring_terminal( + "sub_sched_done", "test_context" + ) + + assert result is False + + +@pytest.mark.asyncio +async def test_upgrade_releases_pending_schedule(): + """modify_stripe_subscription_for_tier upgrade path releases attached schedule first.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "schedule": "sub_sched_pending_downgrade", + "cancel_at_period_end": False, + "items": {"data": [{"id": "si_pro", "price": {"id": "price_pro_monthly"}}]}, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.PRO + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_biz_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + patch( + "backend.data.credit.set_subscription_tier", + new_callable=AsyncMock, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.BUSINESS + ) + + assert result is True + # Pending schedule released before the upgrade modify call. + mock_release.assert_called_once_with("sub_sched_pending_downgrade") + mock_modify.assert_called_once_with( + "sub_pro", + items=[{"id": "si_pro", "price": "price_biz_monthly"}], + proration_behavior="create_prorations", + ) + + +@pytest.mark.asyncio +async def test_next_phase_tier_and_start_logs_unknown_price(caplog): + """_next_phase_tier_and_start emits a warning when the next-phase price is unmapped.""" + import logging + import time as time_mod + + from backend.data.credit import _next_phase_tier_and_start + + now = int(time_mod.time()) + schedule = stripe.SubscriptionSchedule.construct_from( + { + "id": "sub_sched_unknown", + "phases": [ + { + "start_date": now - 3 * 24 * 3600, + "end_date": now + 27 * 24 * 3600, + "items": [{"price": "price_current"}], + }, + { + "start_date": now + 27 * 24 * 3600, + "items": [{"price": "price_unknown"}], + }, + ], + }, + "k", + ) + price_to_tier = {"price_pro_monthly": SubscriptionTier.PRO} + + with caplog.at_level(logging.WARNING, logger="backend.data.credit"): + result = _next_phase_tier_and_start(schedule, price_to_tier) + + assert result is None + assert any( + "next_phase_tier_and_start: unknown price price_unknown" in record.message + and "sub_sched_unknown" in record.message + for record in caplog.records + ) + + +@pytest.mark.asyncio +async def test_get_pending_subscription_change_raises_when_price_lookups_fail(): + """When both LD price lookups return None, raise PendingChangeUnknown so the + @cached wrapper doesn't store None and hide pending changes for 30s.""" + from backend.data.credit import PendingChangeUnknown + + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return None + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + pytest.raises(PendingChangeUnknown), + ): + await get_pending_subscription_change("user-price-fail") + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_invalidates_cache_on_partial_failure(): + """If schedule.release succeeds but cancel_at_period_end clear fails, the + cache must still be invalidated — otherwise the UI shows a stale pending + banner for up to 30s even though the schedule was actually released.""" + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + import time as time_mod + + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_mixed", + "schedule": "sub_sched_to_release", + "cancel_at_period_end": True, + "current_period_end": int(time_mod.time()) + 10 * 24 * 3600, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + return_value=MagicMock(), + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + side_effect=stripe.APIConnectionError("transient Stripe error"), + ), + patch.object( + get_pending_subscription_change, "cache_delete" + ) as mock_cache_delete, + ): + with pytest.raises(stripe.APIConnectionError): + await release_pending_subscription_schedule("user-partial") + + mock_cache_delete.assert_called_once_with("user-partial") diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx index 58a4b9d58b..d8aab67b22 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx @@ -4,42 +4,14 @@ import { Button } from "@/components/ui/button"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; import { Skeleton } from "@/components/atoms/Skeleton/Skeleton"; import { useSubscriptionTierSection } from "./useSubscriptionTierSection"; - -type TierInfo = { - key: string; - label: string; - multiplier: string; - description: string; -}; - -const TIERS: TierInfo[] = [ - { - key: "FREE", - label: "Free", - multiplier: "1x", - description: "Base AutoPilot capacity with standard rate limits", - }, - { - key: "PRO", - label: "Pro", - multiplier: "5x", - description: "5x AutoPilot capacity — run 5× more tasks per day/week", - }, - { - key: "BUSINESS", - label: "Business", - multiplier: "20x", - description: "20x AutoPilot capacity — ideal for teams and heavy workloads", - }, -]; - -const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"]; - -function formatCost(cents: number, tierKey: string): string { - if (tierKey === "FREE") return "Free"; - if (cents === 0) return "Pricing available soon"; - return `$${(cents / 100).toFixed(2)}/mo`; -} +import { PendingChangeBanner } from "./components/PendingChangeBanner/PendingChangeBanner"; +import { + TIERS, + TIER_ORDER, + formatCost, + formatPendingDate, + getTierLabel, +} from "./helpers"; export function SubscriptionTierSection() { const { @@ -55,10 +27,14 @@ export function SubscriptionTierSection() { isPaymentEnabled, changeTier, handleTierChange, + cancelPendingChange, } = useSubscriptionTierSection(); const [confirmDowngradeTo, setConfirmDowngradeTo] = useState( null, ); + const [confirmReplacePendingTo, setConfirmReplacePendingTo] = useState< + string | null + >(null); if (isLoading) { return ( @@ -115,6 +91,34 @@ export function SubscriptionTierSection() { await changeTier(tier); } + async function confirmReplacePending() { + if (!confirmReplacePendingTo) return; + const tier = confirmReplacePendingTo; + setConfirmReplacePendingTo(null); + handleTierChange(tier, currentTier, setConfirmDowngradeTo); + } + + const pendingTierFromSubscription = subscription.pending_tier ?? null; + const hasPendingChange = + pendingTierFromSubscription !== null && + pendingTierFromSubscription !== currentTier; + + function onTierButtonClick(targetTierKey: string) { + // If a pending change is queued and the user clicks a DIFFERENT non-current, + // non-pending tier, surface a confirmation so they don't silently overwrite + // their own scheduled change. The on-card button for the pending tier itself + // is already disabled; the primary cancel path is the banner. + if ( + hasPendingChange && + targetTierKey !== pendingTierFromSubscription && + targetTierKey !== currentTier + ) { + setConfirmReplacePendingTo(targetTierKey); + return; + } + handleTierChange(targetTierKey, currentTier, setConfirmDowngradeTo); + } + return (

Subscription Plan

@@ -128,6 +132,16 @@ export function SubscriptionTierSection() {

)} + {hasPendingChange && pendingTierFromSubscription ? ( + void cancelPendingChange()} + isBusy={isPending} + /> + ) : null} +
{TIERS.map((tier) => { const isCurrent = currentTier === tier.key; @@ -137,6 +151,8 @@ export function SubscriptionTierSection() { const isUpgrade = targetIdx > currentIdx; const isDowngrade = targetIdx < currentIdx; const isThisPending = pendingTier === tier.key; + const isScheduledTier = + hasPendingChange && pendingTierFromSubscription === tier.key; return (
- handleTierChange( - tier.key, - currentTier, - setConfirmDowngradeTo, - ) - } + disabled={isPending || isScheduledTier} + onClick={() => onTierButtonClick(tier.key)} > {isThisPending ? "Updating..." - : isUpgrade - ? `Upgrade to ${tier.label}` - : isDowngrade - ? `Downgrade to ${tier.label}` - : `Switch to ${tier.label}`} + : isScheduledTier + ? "Scheduled" + : isUpgrade + ? `Upgrade to ${tier.label}` + : isDowngrade + ? `Downgrade to ${tier.label}` + : `Switch to ${tier.label}`} )}
@@ -196,9 +208,9 @@ export function SubscriptionTierSection() { {currentTier !== "FREE" && isPaymentEnabled && (

- Your subscription is managed through Stripe. Upgrades and paid-tier - changes take effect immediately; downgrades to Free are scheduled for - the end of the current billing period. + Your subscription is managed through Stripe. Upgrades take effect + immediately. Downgrades take effect at the end of your current billing + period.

)} @@ -215,7 +227,7 @@ export function SubscriptionTierSection() {

{confirmDowngradeTo === "FREE" ? "Downgrading to Free will schedule your subscription to cancel at the end of your current billing period. You keep your current plan until then." - : `Switching to ${TIERS.find((t) => t.key === confirmDowngradeTo)?.label ?? confirmDowngradeTo} will take effect immediately.`}{" "} + : `Switching to ${TIERS.find((t) => t.key === confirmDowngradeTo)?.label ?? confirmDowngradeTo} will take effect at the end of your current billing period. You keep your current plan until then.`}{" "} Are you sure?

@@ -235,6 +247,42 @@ export function SubscriptionTierSection() { + { + if (!open) setConfirmReplacePendingTo(null); + }, + }} + > + +

+ You have a pending change to{" "} + {getTierLabel(pendingTierFromSubscription ?? "")} + {subscription.pending_tier_effective_at + ? ` scheduled for ${formatPendingDate(subscription.pending_tier_effective_at)}` + : ""} + . Switching to {getTierLabel(confirmReplacePendingTo ?? "")} will + replace it. Continue? +

+ + + + +
+
+ ; prorationCreditCents?: number; + pendingTier?: string | null; + pendingTierEffectiveAt?: Date | string | null; } = {}) { return { tier, monthly_cost: monthlyCost, tier_costs: tierCosts, proration_credit_cents: prorationCreditCents, + pending_tier: pendingTier, + pending_tier_effective_at: pendingTierEffectiveAt, }; } @@ -92,6 +98,7 @@ function setupMocks({ mutateFn = vi.fn().mockResolvedValue({ status: 200, data: { url: "" } }), isPending = false, variables = undefined as { data?: { tier?: string } } | undefined, + refetchFn = vi.fn(), } = {}) { // The hook uses select: (data) => (data.status === 200 ? data.data : null) // so the data value returned by the hook is already the transformed subscription object. @@ -100,13 +107,14 @@ function setupMocks({ data: subscription, isLoading, error: queryError, - refetch: vi.fn(), + refetch: refetchFn, }); mockUseUpdateSubscriptionTier.mockReturnValue({ mutateAsync: mutateFn, isPending, variables, }); + return { refetchFn, mutateFn }; } afterEach(() => { @@ -355,4 +363,229 @@ describe("SubscriptionTierSection", () => { // No toast should fire — the user simply abandoned checkout expect(mockToast).not.toHaveBeenCalled(); }); + + it("renders pending-change banner when pending_tier is set", () => { + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + }); + render(); + expect(screen.getByText(/scheduled to downgrade to/i)).toBeDefined(); + // Banner "Keep Business" button — the only Keep button, since the on-card + // duplicate was removed in favour of the banner. + expect( + screen.getAllByRole("button", { name: /keep business/i }), + ).toHaveLength(1); + }); + + it("does not render pending-change banner when pending_tier is null", () => { + setupMocks({ + subscription: makeSubscription({ tier: "BUSINESS", pendingTier: null }), + }); + render(); + expect(screen.queryByText(/scheduled to downgrade/i)).toBeNull(); + expect(screen.queryByRole("button", { name: /keep business/i })).toBeNull(); + }); + + it("clicking Keep [CurrentTier] in banner submits a same-tier update and refetches", async () => { + // The cancel-pending route was collapsed into POST /credits/subscription as + // a same-tier request. Clicking "Keep BUSINESS" calls useUpdateSubscriptionTier + // with tier === current tier so the backend releases any pending schedule. + const mutateFn = vi + .fn() + .mockResolvedValue({ status: 200, data: { url: "", tier: "BUSINESS" } }); + const refetchFn = vi.fn(); + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + mutateFn, + refetchFn, + }); + render(); + + fireEvent.click(screen.getByRole("button", { name: /keep business/i })); + + await waitFor(() => { + expect(mutateFn).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ tier: "BUSINESS" }), + }), + ); + expect(refetchFn).toHaveBeenCalled(); + }); + expect(mockToast).toHaveBeenCalledWith( + expect.objectContaining({ + title: "Pending subscription change cancelled.", + }), + ); + }); + + it("uses end-of-period copy for paid→paid downgrade confirmation", () => { + setupMocks({ subscription: makeSubscription({ tier: "BUSINESS" }) }); + render(); + + fireEvent.click(screen.getByRole("button", { name: /downgrade to pro/i })); + + const dialog = screen.getByRole("dialog"); + expect(dialog.textContent).toMatch( + /switching to pro will take effect at the end of your current billing period/i, + ); + expect(dialog.textContent).toMatch( + /you keep your current plan until then/i, + ); + expect(dialog.textContent).not.toMatch(/take effect immediately/i); + }); + + it("shows destructive toast, tierError and still refetches when cancel-pending fails", async () => { + // The catch branch inside cancelPendingChange is load-bearing: it surfaces + // the error to the user AND re-issues a refetch so the UI reconciles if + // the server actually succeeded (webhook delivered after our client-side + // error). + const mutateFn = vi + .fn() + .mockRejectedValue(new Error("Stripe webhook failed")); + const refetchFn = vi.fn(); + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + mutateFn, + refetchFn, + }); + render(); + + const keepButtons = screen.getAllByRole("button", { + name: /keep business/i, + }); + fireEvent.click(keepButtons[0]); + + await waitFor(() => { + expect(screen.getByRole("alert")).toBeDefined(); + expect(screen.getByText(/stripe webhook failed/i)).toBeDefined(); + }); + expect(mockToast).toHaveBeenCalledWith( + expect.objectContaining({ + title: "Failed to cancel pending change", + variant: "destructive", + }), + ); + expect(refetchFn).toHaveBeenCalled(); + }); + + it("disables the tier button that matches the pending tier so users can't overwrite their own scheduled change by mis-click", () => { + // User is on BUSINESS and has a pending downgrade to PRO. The "Downgrade + // to Pro" button must be disabled + labelled "Scheduled" so the primary + // cancel path stays the banner. Other tier buttons (FREE here) remain + // clickable — the user can still overwrite their pending change by + // picking a different target; backend handles that. + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + }); + render(); + + const scheduledBtn = screen.getByRole("button", { name: /scheduled/i }); + expect(scheduledBtn).toBeDefined(); + expect((scheduledBtn as HTMLButtonElement).disabled).toBe(true); + + // The non-pending tier (FREE) button is still clickable. + const freeBtn = screen.getByRole("button", { name: /downgrade to free/i }); + expect((freeBtn as HTMLButtonElement).disabled).toBe(false); + }); + + it("shows replace-pending dialog when clicking a non-pending tier while a pending change exists, and fires the mutation after confirm", async () => { + // User is on BUSINESS with a pending downgrade to PRO. Clicking FREE (a + // tier that is neither current nor the pending target) must NOT silently + // overwrite the pending schedule — it must open a confirmation dialog. + // Only after the user explicitly confirms should changeTier (→ its own + // downgrade confirm for paid→FREE) fire. + const mutateFn = vi + .fn() + .mockResolvedValue({ status: 200, data: { url: "" } }); + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + mutateFn, + }); + render(); + + // Clicking FREE while PRO is pending surfaces the replace-pending dialog + // before anything mutates. + fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i })); + expect(screen.getByRole("dialog")).toBeDefined(); + expect(screen.getByText(/replace pending change/i)).toBeDefined(); + expect(mutateFn).not.toHaveBeenCalled(); + + // Confirm the replace: the replace-pending dialog closes and the + // downgrade-to-FREE dialog takes over (because FREE is a downgrade). + fireEvent.click( + screen.getByRole("button", { name: /replace pending change/i }), + ); + + // Now the "Confirm Downgrade" dialog should be open — confirm it to fire + // the mutation. + fireEvent.click(screen.getByRole("button", { name: /confirm downgrade/i })); + + await waitFor(() => { + expect(mutateFn).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ tier: "FREE" }), + }), + ); + }); + }); + + it("dismisses replace-pending dialog on Cancel without mutating", () => { + const mutateFn = vi + .fn() + .mockResolvedValue({ status: 200, data: { url: "" } }); + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + mutateFn, + }); + render(); + + fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i })); + expect(screen.getByRole("dialog")).toBeDefined(); + + fireEvent.click(screen.getByRole("button", { name: /^cancel$/i })); + expect(screen.queryByRole("dialog")).toBeNull(); + expect(mutateFn).not.toHaveBeenCalled(); + }); + + it("renders FREE cancellation copy in banner when pending_tier is FREE", () => { + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "FREE", + pendingTierEffectiveAt: new Date("2026-05-15T00:00:00Z"), + }), + }); + render(); + // Cancellation copy — distinct from the generic downgrade phrasing. + expect( + screen.getByText(/scheduled to cancel your subscription on/i), + ).toBeDefined(); + expect(screen.getByText(/May 15, 2026/)).toBeDefined(); + // Must NOT render the "downgrade to" phrasing on FREE cancellation. + expect(screen.queryByText(/scheduled to downgrade to/i)).toBeNull(); + }); }); diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/components/PendingChangeBanner/PendingChangeBanner.tsx b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/components/PendingChangeBanner/PendingChangeBanner.tsx new file mode 100644 index 0000000000..0088ad7666 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/components/PendingChangeBanner/PendingChangeBanner.tsx @@ -0,0 +1,60 @@ +import { Button } from "@/components/ui/button"; +import { formatPendingDate, getTierLabel } from "../../helpers"; + +interface Props { + currentTier: string; + pendingTier: string; + pendingEffectiveAt: Date | string | null | undefined; + onKeepCurrent: () => void; + isBusy: boolean; +} + +export function PendingChangeBanner({ + currentTier, + pendingTier, + pendingEffectiveAt, + onKeepCurrent, + isBusy, +}: Props) { + // Backend invariant: pending_tier_effective_at is always populated when + // pending_tier is set. Bail early if the date is missing so the sentence + // always reads with a date instead of a null-fallback branch. + if (!pendingEffectiveAt) return null; + + const pendingLabel = getTierLabel(pendingTier); + const currentLabel = getTierLabel(currentTier); + const dateText = formatPendingDate(pendingEffectiveAt); + + const isCancellation = pendingTier === "FREE"; + + return ( +
+

+ {isCancellation ? ( + <> + Scheduled to cancel your subscription on{" "} + {dateText}. + + ) : ( + <> + Scheduled to downgrade to{" "} + {pendingLabel} on{" "} + {dateText}. + + )} +

+ +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/helpers.ts new file mode 100644 index 0000000000..fde4674a8b --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/helpers.ts @@ -0,0 +1,54 @@ +export interface TierInfo { + key: string; + label: string; + multiplier: string; + description: string; +} + +export const TIERS: TierInfo[] = [ + { + key: "FREE", + label: "Free", + multiplier: "1x", + description: "Base AutoPilot capacity with standard rate limits", + }, + { + key: "PRO", + label: "Pro", + multiplier: "5x", + description: "5x AutoPilot capacity — run 5× more tasks per day/week", + }, + { + key: "BUSINESS", + label: "Business", + multiplier: "20x", + description: "20x AutoPilot capacity — ideal for teams and heavy workloads", + }, +]; + +export const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"]; + +export function formatCost(cents: number, tierKey: string): string { + if (tierKey === "FREE") return "Free"; + if (cents === 0) return "Pricing available soon"; + return `$${(cents / 100).toFixed(2)}/mo`; +} + +export function getTierLabel(tierKey: string): string { + return ( + TIERS.find((t) => t.key === tierKey)?.label ?? + tierKey.charAt(0) + tierKey.slice(1).toLowerCase() + ); +} + +export function formatPendingDate(value: Date | string): string { + const date = value instanceof Date ? value : new Date(value); + // Pin to en-US so SSR and CSR produce the same string — passing `undefined` + // picks up the server's locale during prerender and the browser's locale on + // hydration, which triggers a React hydration mismatch warning. + return date.toLocaleDateString("en-US", { + year: "numeric", + month: "short", + day: "numeric", + }); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts index 862551c7e3..d51a2a6051 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts @@ -117,6 +117,47 @@ export function useSubscriptionTierSection() { await changeTier(tier); } + async function cancelPendingChange() { + if (!subscription) return; + setTierError(null); + try { + // "Stay on my current tier" is a same-tier POST: the backend collapses + // cancel-pending into update-tier and releases any pending schedule. + // success_url/cancel_url are unused in this branch (no Stripe Checkout + // is created) but are sent to satisfy the request schema. + await doUpdateTier({ + data: { + tier: subscription.tier as SubscriptionTierRequestTier, + success_url: `${window.location.origin}${window.location.pathname}`, + cancel_url: `${window.location.origin}${window.location.pathname}`, + }, + }); + await refetch(); + toast({ + title: "Pending subscription change cancelled.", + }); + } catch (e: unknown) { + const msg = + e instanceof Error + ? e.message + : "Failed to cancel pending subscription change"; + setTierError(msg); + toast({ + title: "Failed to cancel pending change", + description: msg, + variant: "destructive", + }); + // Refetch on error so the UI reconciles if the server actually + // succeeded (e.g. webhook delivered after our client-side error). + // Swallow refetch errors — we already have the primary error for display. + try { + await refetch(); + } catch { + // intentional + } + } + } + const pendingTier = isPending && variables?.data?.tier ? variables.data.tier : null; @@ -133,5 +174,6 @@ export function useSubscriptionTierSection() { isPaymentEnabled, changeTier, handleTierChange, + cancelPendingChange, }; } diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 920348db25..f20f34a805 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -2470,7 +2470,7 @@ }, "post": { "tags": ["v1", "credits"], - "summary": "Start a Stripe Checkout session to upgrade subscription tier", + "summary": "Update subscription tier or start a Stripe Checkout session", "operationId": "updateSubscriptionTier", "requestBody": { "content": { @@ -2488,7 +2488,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/SubscriptionCheckoutResponse" + "$ref": "#/components/schemas/SubscriptionStatusResponse" } } } @@ -14208,12 +14208,6 @@ "enum": ["DRAFT", "PENDING", "APPROVED", "REJECTED"], "title": "SubmissionStatus" }, - "SubscriptionCheckoutResponse": { - "properties": { "url": { "type": "string", "title": "Url" } }, - "type": "object", - "required": ["url"], - "title": "SubscriptionCheckoutResponse" - }, "SubscriptionStatusResponse": { "properties": { "tier": { @@ -14230,6 +14224,26 @@ "proration_credit_cents": { "type": "integer", "title": "Proration Credit Cents" + }, + "pending_tier": { + "anyOf": [ + { "type": "string", "enum": ["FREE", "PRO", "BUSINESS"] }, + { "type": "null" } + ], + "title": "Pending Tier" + }, + "pending_tier_effective_at": { + "anyOf": [ + { "type": "string", "format": "date-time" }, + { "type": "null" } + ], + "title": "Pending Tier Effective At" + }, + "url": { + "type": "string", + "title": "Url", + "description": "Populated only when POST /credits/subscription starts a Stripe Checkout Session (FREE → paid upgrade). Empty string in all other branches — the client redirects to this URL when non-empty.", + "default": "" } }, "type": "object", From 01f1289aac2e8408adbf2aa50d5fa5b2344ec488 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 14:34:43 +0700 Subject: [PATCH 2/4] feat(copilot): real OpenRouter cost + cost-based rate limits (percent-only public API) (#12864) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Why After d7653acd0 removed cost estimation, most baseline turns log with `tracking_type="tokens"` and no authoritative USD figure (see: dashboard flipped from `cost_usd` to `tokens` after 4/14/2026). Rate-limit counters were also token-weighted with hand-rolled cache discounts (cache_read @ 10%, cache_create @ 25%) and a 5× Opus multiplier — a proxy for cost that drifts from real OpenRouter billing. This PR wires real generation cost from OpenRouter into both the cost-tracking log and the rate limiter, and hides raw spend figures from the user-facing API so clients can't reverse-engineer per-turn cost or platform margins. ## What 1. **Real cost from OpenRouter** — baseline passes `extra_body={"usage": {"include": True}}` and reads `chunk.usage.cost` from the final streaming chunk. `x-total-cost` header path removed. Missing cost logs an error and skips the counter update (vs the old estimator that silently under-counted). 2. **Cost-based rate limiting** — `record_token_usage(...)` → `record_cost_usage(cost_microdollars)`. The weighted-token math, cache discount factors, and `_OPUS_COST_MULTIPLIER` are gone; real USD already reflects model + cache pricing. 3. **Redis key migration** — `copilot:usage:*` → `copilot:cost:*` so stale token counters can't be misinterpreted as microdollars. 4. **LD flags + config** — renamed to `copilot-daily-cost-limit-microdollars` / `copilot-weekly-cost-limit-microdollars` (unit in the LD key so values can't accidentally be set in dollars or cents). 5. **Public `/usage` hides raw $$** — new `CoPilotUsagePublic` / `UsageWindowPublic` schemas expose only `percent_used` (0-100) + `resets_at` + `tier` + `reset_cost`. Admin endpoint keeps raw microdollars for debugging. 6. **Admin API contract** — `UserRateLimitResponse` fields renamed `daily/weekly_token_limit` → `daily/weekly_cost_limit_microdollars`, `daily/weekly_tokens_used` → `daily/weekly_cost_used_microdollars`. Admin UI displays `$X.XX`. ## How - `baseline/service.py` — pass `extra_body`, extract cost from `chunk.usage.cost`, drop the `x-total-cost` header fallback entirely. - `rate_limit.py` — rewritten around `record_cost_usage`, `check_rate_limit(daily_cost_limit, weekly_cost_limit)`, new Redis key prefix. Adds `CoPilotUsagePublic.from_status()` projector for the public API. - `token_tracking.py` — converts `cost_usd` → microdollars via `usd_to_microdollars` and calls `record_cost_usage` only when cost is present. - `sdk/service.py` — deletes `_OPUS_COST_MULTIPLIER` and simplifies `_resolve_model_and_multiplier` to `_resolve_sdk_model_for_request`. - Chat routes: `/usage` and `/usage/reset` return `CoPilotUsagePublic`. Internal server-side limit checks still use the raw microdollar `CoPilotUsageStatus`. - Admin routes: unchanged response shape (renamed fields only). - Frontend: `UsagePanelContent`, `UsageLimits`, `CopilotPage`, `BriefingTabContent`, `credits/page.tsx` consume the new public schema and render "N% used" + progress bar. Admin `RateLimitDisplay` / `UsageBar` keep `$X.XX`. Helper `formatMicrodollarsAsUsd` retained for admin use. - Tests + snapshots rewritten; new assertions explicitly check that raw `used`/`limit` keys are absent from the public payload. ## Deploy notes 1. **Before rolling this out, create the new LD flags:** `copilot-daily-cost-limit-microdollars` (default `500000`) and `copilot-weekly-cost-limit-microdollars` (default `2500000`). Old `copilot-*-token-limit` flags can stay in LD for rollback. 2. **One-time Redis cleanup (optional):** token-based counters under `copilot:usage:*` are orphaned and will TTL out within 7 days. Safe to ignore or delete manually. ## Test plan - [x] `poetry run test` — all impacted backend tests pass (182/182 in targeted scope) - [x] `pnpm test:unit` — all 1628 integration tests pass - [x] `poetry run format` / `pnpm format` / `pnpm types` clean - [x] Manual sanity against dev env — Baseline turn logged $0.1221 for 40K/139 tokens on Sonnet 4 (matches expected pricing) - [ ] `/pr-test --fix` end-to-end against local native stack --- .../features/admin/rate_limit_admin_routes.py | 32 +- .../admin/rate_limit_admin_routes_test.py | 18 +- .../backend/api/features/chat/routes.py | 58 ++- .../backend/api/features/chat/routes_test.py | 40 +- .../backend/copilot/baseline/service.py | 106 +++- .../copilot/baseline/service_unit_test.py | 476 +++++++++--------- .../backend/backend/copilot/config.py | 32 +- .../backend/backend/copilot/rate_limit.py | 270 +++++----- .../backend/copilot/rate_limit_test.py | 100 ++-- .../backend/copilot/reset_usage_test.py | 12 +- .../backend/backend/copilot/sdk/service.py | 37 +- .../backend/backend/copilot/token_tracking.py | 83 +-- .../backend/copilot/token_tracking_test.py | 100 ++-- .../backend/backend/util/feature_flag.py | 4 +- .../backend/snapshots/get_rate_limit | 8 +- .../reset_user_usage_daily_and_weekly | 8 +- .../snapshots/reset_user_usage_daily_only | 8 +- .../(platform)/admin/components/UsageBar.tsx | 10 +- .../components/__tests__/UsageBar.test.tsx | 31 ++ .../components/RateLimitDisplay.tsx | 17 +- .../__tests__/RateLimitDisplay.test.tsx | 18 +- .../__tests__/RateLimitManager.test.tsx | 16 +- .../__tests__/useRateLimitManager.test.ts | 20 +- .../app/(platform)/copilot/CopilotPage.tsx | 8 +- .../copilot/__tests__/CopilotPage.test.tsx | 22 +- .../components/UsageLimits/UsageLimits.tsx | 10 +- .../UsageLimits/UsagePanelContent.tsx | 50 +- .../__tests__/UsageLimits.test.tsx | 75 +-- .../UsagePanelContentRender.test.tsx | 68 ++- .../components/__tests__/usageHelpers.test.ts | 76 +++ .../copilot/components/usageHelpers.ts | 6 + .../AgentBriefingPanel/BriefingTabContent.tsx | 58 +-- .../__tests__/BriefingTabContent.test.tsx | 212 ++++++++ .../profile/(user)/credits/page.tsx | 10 +- .../frontend/src/app/api/openapi.json | 80 +-- 35 files changed, 1330 insertions(+), 849 deletions(-) create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/components/__tests__/UsageBar.test.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/components/__tests__/usageHelpers.test.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/__tests__/BriefingTabContent.test.tsx diff --git a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py index 379b9e9257..3b9c762f21 100644 --- a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py +++ b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py @@ -32,10 +32,10 @@ router = APIRouter( class UserRateLimitResponse(BaseModel): user_id: str user_email: Optional[str] = None - daily_token_limit: int - weekly_token_limit: int - daily_tokens_used: int - weekly_tokens_used: int + daily_cost_limit_microdollars: int + weekly_cost_limit_microdollars: int + daily_cost_used_microdollars: int + weekly_cost_used_microdollars: int tier: SubscriptionTier @@ -101,17 +101,19 @@ async def get_user_rate_limit( logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id) daily_limit, weekly_limit, tier = await get_global_rate_limits( - resolved_id, config.daily_token_limit, config.weekly_token_limit + resolved_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier) return UserRateLimitResponse( user_id=resolved_id, user_email=resolved_email, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, - daily_tokens_used=usage.daily.used, - weekly_tokens_used=usage.weekly.used, + daily_cost_limit_microdollars=daily_limit, + weekly_cost_limit_microdollars=weekly_limit, + daily_cost_used_microdollars=usage.daily.used, + weekly_cost_used_microdollars=usage.weekly.used, tier=tier, ) @@ -141,7 +143,9 @@ async def reset_user_rate_limit( raise HTTPException(status_code=500, detail="Failed to reset usage") from e daily_limit, weekly_limit, tier = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier) @@ -154,10 +158,10 @@ async def reset_user_rate_limit( return UserRateLimitResponse( user_id=user_id, user_email=resolved_email, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, - daily_tokens_used=usage.daily.used, - weekly_tokens_used=usage.weekly.used, + daily_cost_limit_microdollars=daily_limit, + weekly_cost_limit_microdollars=weekly_limit, + daily_cost_used_microdollars=usage.daily.used, + weekly_cost_used_microdollars=usage.weekly.used, tier=tier, ) diff --git a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py index 77e4a656fb..c6c920829d 100644 --- a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py @@ -85,10 +85,10 @@ def test_get_rate_limit( data = response.json() assert data["user_id"] == target_user_id assert data["user_email"] == _TARGET_EMAIL - assert data["daily_token_limit"] == 2_500_000 - assert data["weekly_token_limit"] == 12_500_000 - assert data["daily_tokens_used"] == 500_000 - assert data["weekly_tokens_used"] == 3_000_000 + assert data["daily_cost_limit_microdollars"] == 2_500_000 + assert data["weekly_cost_limit_microdollars"] == 12_500_000 + assert data["daily_cost_used_microdollars"] == 500_000 + assert data["weekly_cost_used_microdollars"] == 3_000_000 assert data["tier"] == "FREE" configured_snapshot.assert_match( @@ -117,7 +117,7 @@ def test_get_rate_limit_by_email( data = response.json() assert data["user_id"] == target_user_id assert data["user_email"] == _TARGET_EMAIL - assert data["daily_token_limit"] == 2_500_000 + assert data["daily_cost_limit_microdollars"] == 2_500_000 def test_get_rate_limit_by_email_not_found( @@ -160,9 +160,9 @@ def test_reset_user_usage_daily_only( assert response.status_code == 200 data = response.json() - assert data["daily_tokens_used"] == 0 + assert data["daily_cost_used_microdollars"] == 0 # Weekly is untouched - assert data["weekly_tokens_used"] == 3_000_000 + assert data["weekly_cost_used_microdollars"] == 3_000_000 assert data["tier"] == "FREE" mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False) @@ -192,8 +192,8 @@ def test_reset_user_usage_daily_and_weekly( assert response.status_code == 200 data = response.json() - assert data["daily_tokens_used"] == 0 - assert data["weekly_tokens_used"] == 0 + assert data["daily_cost_used_microdollars"] == 0 + assert data["weekly_cost_used_microdollars"] == 0 assert data["tier"] == "FREE" mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True) diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index eceedb828c..6ef15f0999 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -34,7 +34,7 @@ from backend.copilot.pending_message_helpers import ( ) from backend.copilot.pending_messages import peek_pending_messages from backend.copilot.rate_limit import ( - CoPilotUsageStatus, + CoPilotUsagePublic, RateLimitExceeded, acquire_reset_lock, check_rate_limit, @@ -536,23 +536,27 @@ async def get_session( ) async def get_copilot_usage( user_id: Annotated[str, Security(auth.get_user_id)], -) -> CoPilotUsageStatus: +) -> CoPilotUsagePublic: """Get CoPilot usage status for the authenticated user. - Returns current token usage vs limits for daily and weekly windows. - Global defaults sourced from LaunchDarkly (falling back to config). - Includes the user's rate-limit tier. + Returns the percentage of the daily/weekly allowance used — not the + raw spend or cap — so clients cannot derive per-turn cost or platform + margins. Global defaults sourced from LaunchDarkly (falling back to + config). Includes the user's rate-limit tier. """ daily_limit, weekly_limit, tier = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) - return await get_usage_status( + status = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, rate_limit_reset_cost=config.rate_limit_reset_cost, tier=tier, ) + return CoPilotUsagePublic.from_status(status) class RateLimitResetResponse(BaseModel): @@ -561,7 +565,9 @@ class RateLimitResetResponse(BaseModel): success: bool credits_charged: int = Field(description="Credits charged (in cents)") remaining_balance: int = Field(description="Credit balance after charge (in cents)") - usage: CoPilotUsageStatus = Field(description="Updated usage status after reset") + usage: CoPilotUsagePublic = Field( + description="Updated usage status after reset (percentages only)" + ) @router.post( @@ -585,7 +591,7 @@ async def reset_copilot_usage( ) -> RateLimitResetResponse: """Reset the daily CoPilot rate limit by spending credits. - Allows users who have hit their daily token limit to spend credits + Allows users who have hit their daily cost limit to spend credits to reset their daily usage counter and continue working. Returns 400 if the feature is disabled or the user is not over the limit. Returns 402 if the user has insufficient credits. @@ -604,7 +610,9 @@ async def reset_copilot_usage( ) daily_limit, weekly_limit, tier = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) if daily_limit <= 0: @@ -641,8 +649,8 @@ async def reset_copilot_usage( # used for limit checks, not returned to the client.) usage_status = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, tier=tier, ) if daily_limit > 0 and usage_status.daily.used < daily_limit: @@ -677,7 +685,7 @@ async def reset_copilot_usage( # Reset daily usage in Redis. If this fails, refund the credits # so the user is not charged for a service they did not receive. - if not await reset_daily_usage(user_id, daily_token_limit=daily_limit): + if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit): # Compensate: refund the charged credits. refunded = False try: @@ -713,11 +721,11 @@ async def reset_copilot_usage( finally: await release_reset_lock(user_id) - # Return updated usage status. + # Return updated usage status (public schema — percentages only). updated_usage = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, rate_limit_reset_cost=config.rate_limit_reset_cost, tier=tier, ) @@ -726,7 +734,7 @@ async def reset_copilot_usage( success=True, credits_charged=cost, remaining_balance=remaining, - usage=updated_usage, + usage=CoPilotUsagePublic.from_status(updated_usage), ) @@ -787,7 +795,7 @@ async def cancel_session_task( ), }, 404: {"description": "Session not found or access denied"}, - 429: {"description": "Token rate-limit or call-frequency cap exceeded"}, + 429: {"description": "Cost rate-limit or call-frequency cap exceeded"}, }, ) async def stream_chat_post( @@ -861,18 +869,20 @@ async def stream_chat_post( }, ) - # Pre-turn rate limit check (token-based). + # Pre-turn rate limit check (cost-based, microdollars). # check_rate_limit short-circuits internally when both limits are 0. # Global defaults sourced from LaunchDarkly, falling back to config. if user_id: try: daily_limit, weekly_limit, _ = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) await check_rate_limit( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, ) except RateLimitExceeded as e: raise HTTPException(status_code=429, detail=str(e)) from e diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index 4dc6547515..88c4ef5f14 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -296,8 +296,8 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerF _mock_stream_internals(mocker) # Ensure the rate-limit branch is entered by setting a non-zero limit. - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)), @@ -318,8 +318,8 @@ def test_stream_chat_returns_429_on_weekly_rate_limit( from backend.copilot.rate_limit import RateLimitExceeded _mock_stream_internals(mocker) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) resets_at = datetime.now(UTC) + timedelta(days=3) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", @@ -341,8 +341,8 @@ def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture): from backend.copilot.rate_limit import RateLimitExceeded _mock_stream_internals(mocker) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", side_effect=RateLimitExceeded( @@ -402,23 +402,33 @@ def test_usage_returns_daily_and_weekly( mocker: pytest_mock.MockerFixture, test_user_id: str, ) -> None: - """GET /usage returns daily and weekly usage.""" + """GET /usage returns percentages for daily and weekly windows only. + + The raw used/limit microdollar values MUST NOT leak — clients should not + be able to derive per-turn cost or platform margins from the public API. + """ mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) response = client.get("/usage") assert response.status_code == 200 data = response.json() - assert data["daily"]["used"] == 500 - assert data["weekly"]["used"] == 2000 + # 500 / 10000 = 5%, 2000 / 50000 = 4% + assert data["daily"]["percent_used"] == 5.0 + assert data["weekly"]["percent_used"] == 4.0 + # Raw spend/limit must not be exposed. + assert "used" not in data["daily"] + assert "limit" not in data["daily"] + assert "used" not in data["weekly"] + assert "limit" not in data["weekly"] mock_get.assert_called_once_with( user_id=test_user_id, - daily_token_limit=10000, - weekly_token_limit=50000, + daily_cost_limit=10000, + weekly_cost_limit=50000, rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost, tier=SubscriptionTier.FREE, ) @@ -438,8 +448,8 @@ def test_usage_uses_config_limits( assert response.status_code == 200 mock_get.assert_called_once_with( user_id=test_user_id, - daily_token_limit=99999, - weekly_token_limit=77777, + daily_cost_limit=99999, + weekly_cost_limit=77777, rate_limit_reset_cost=500, tier=SubscriptionTier.FREE, ) diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 7d27beac8b..8a26002e25 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -22,7 +22,9 @@ from typing import TYPE_CHECKING, Any, cast import orjson from langfuse import propagate_attributes +from openai.types import CompletionUsage from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam +from openai.types.completion_usage import PromptTokensDetails from opentelemetry import trace as otel_trace from backend.copilot.config import CopilotLlmModel, CopilotMode @@ -126,6 +128,53 @@ _MAX_INLINE_IMAGE_BYTES = 20 * 1024 * 1024 # Matches characters unsafe for filenames. _UNSAFE_FILENAME = re.compile(r"[^\w.\-]") +# OpenRouter-specific extra_body flag that embeds the real generation cost +# into the final usage chunk. Module-level constant so we don't reallocate +# an identical dict on every streaming call. +_OPENROUTER_INCLUDE_USAGE_COST = {"usage": {"include": True}} + + +def _extract_usage_cost(usage: CompletionUsage) -> float | None: + """Return the provider-reported USD cost on a streaming usage chunk. + + OpenRouter piggybacks a ``cost`` field on the OpenAI-compatible usage + object when the request body includes ``usage: {"include": True}``. + The OpenAI SDK's typed ``CompletionUsage`` does not declare it, so we + read it off ``model_extra`` (the pydantic v2 container for extras) to + keep the access fully typed — no ``getattr``. + + Returns ``None`` when the field is absent, explicitly null, + non-numeric, non-finite, or negative. Invalid values (including + present-but-null) are logged here — they indicate a provider bug + worth chasing; plain absences are silent so the caller can dedupe + the "missing cost" warning per stream. + """ + extras = usage.model_extra or {} + if "cost" not in extras: + return None + raw = extras["cost"] + if raw is None: + logger.error("[Baseline] usage.cost is present but null") + return None + try: + val = float(raw) + except (TypeError, ValueError): + logger.error("[Baseline] usage.cost is not numeric: %r", raw) + return None + if not math.isfinite(val) or val < 0: + logger.error("[Baseline] usage.cost is non-finite or negative: %r", val) + return None + return val + + +def _extract_cache_creation_tokens(ptd: PromptTokensDetails) -> int: + """Read Anthropic's ``cache_creation_input_tokens`` off an OpenAI + ``PromptTokensDetails`` — it's a provider-specific extra, not in the + typed model, so we read it via ``model_extra`` rather than + ``getattr``. + """ + return int((ptd.model_extra or {}).get("cache_creation_input_tokens") or 0) + async def _prepare_baseline_attachments( file_ids: list[str], @@ -267,6 +316,10 @@ class _BaselineStreamState: turn_cache_read_tokens: int = 0 turn_cache_creation_tokens: int = 0 cost_usd: float | None = None + # Tracks whether we've already warned about a missing `cost` field in + # the usage chunk this stream, so non-OpenRouter providers don't + # generate one warning per streaming call. + cost_missing_logged: bool = False thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper) session_messages: list[ChatMessage] = field(default_factory=list) # Tracks how much of ``assistant_text`` has already been flushed to @@ -292,10 +345,12 @@ async def _baseline_llm_caller( state.thinking_stripper = _ThinkingStripper() round_text = "" - response = None # initialized before try so finally block can access it try: client = _get_openai_client() typed_messages = cast(list[ChatCompletionMessageParam], messages) + # extra_body `usage.include=true` asks OpenRouter to embed the real + # generation cost into the final usage chunk. Without this we only get + # token counts and have no authoritative cost for rate limiting. if tools: typed_tools = cast(list[ChatCompletionToolParam], tools) response = await client.chat.completions.create( @@ -304,6 +359,7 @@ async def _baseline_llm_caller( tools=typed_tools, stream=True, stream_options={"include_usage": True}, + extra_body=_OPENROUTER_INCLUDE_USAGE_COST, ) else: response = await client.chat.completions.create( @@ -311,6 +367,7 @@ async def _baseline_llm_caller( messages=typed_messages, stream=True, stream_options={"include_usage": True}, + extra_body=_OPENROUTER_INCLUDE_USAGE_COST, ) tool_calls_by_index: dict[int, dict[str, str]] = {} @@ -323,18 +380,33 @@ async def _baseline_llm_caller( if chunk.usage: state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0 state.turn_completion_tokens += chunk.usage.completion_tokens or 0 - # Extract cache token details when available (OpenAI / - # OpenRouter include these in prompt_tokens_details). - ptd = getattr(chunk.usage, "prompt_tokens_details", None) + ptd = chunk.usage.prompt_tokens_details if ptd: - state.turn_cache_read_tokens += ( - getattr(ptd, "cached_tokens", 0) or 0 - ) - # cache_creation_input_tokens is reported by some providers - # (e.g. Anthropic native) but not standard OpenAI streaming. + state.turn_cache_read_tokens += ptd.cached_tokens or 0 state.turn_cache_creation_tokens += ( - getattr(ptd, "cache_creation_input_tokens", 0) or 0 + _extract_cache_creation_tokens(ptd) ) + cost = _extract_usage_cost(chunk.usage) + if cost is not None: + state.cost_usd = (state.cost_usd or 0.0) + cost + elif ( + "cost" not in (chunk.usage.model_extra or {}) + and not state.cost_missing_logged + ): + # Field absent (non-OpenRouter route, or OpenRouter + # misconfigured) — warn once per stream so error + # monitoring picks up persistent misses without + # flooding. Invalid values already logged inside + # _extract_usage_cost, so no duplicate warning here. + logger.warning( + "[Baseline] usage chunk missing cost (model=%s, " + "prompt=%s, completion=%s) — rate-limit will " + "skip this call", + state.model, + chunk.usage.prompt_tokens, + chunk.usage.completion_tokens, + ) + state.cost_missing_logged = True delta = chunk.choices[0].delta if chunk.choices else None if not delta: @@ -394,20 +466,6 @@ async def _baseline_llm_caller( state.text_started = False state.text_block_id = str(uuid.uuid4()) finally: - # Extract OpenRouter cost from response headers (in finally so we - # capture cost even when the stream errors mid-way — we already paid). - # Accumulate across multi-round tool-calling turns. - try: - # Access undocumented _response attribute — same pattern as - # extract_openrouter_cost() in blocks/llm.py. - cost_header = response._response.headers.get("x-total-cost") # type: ignore[attr-defined] - if cost_header: - cost = float(cost_header) - if math.isfinite(cost) and cost >= 0: - state.cost_usd = (state.cost_usd or 0.0) + cost - except (AttributeError, ValueError): - pass - # Always persist partial text so the session history stays consistent, # even when the stream is interrupted by an exception. state.assistant_text += round_text diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py index a0e55d843f..e21618c367 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionToolParam from backend.copilot.baseline.service import ( _baseline_conversation_updater, + _baseline_llm_caller, _BaselineStreamState, _compress_session_messages, ) @@ -574,37 +575,80 @@ class TestPrepareBaselineAttachments: assert blocks == [] +_COST_MISSING = object() + + +def _make_usage_chunk( + *, + prompt_tokens: int = 0, + completion_tokens: int = 0, + cost: float | str | None | object = _COST_MISSING, + cached_tokens: int | None = None, + cache_creation_input_tokens: int | None = None, +): + """Build a mock streaming chunk carrying usage (and optionally cost). + + Provider-specific fields (``cost`` on usage, ``cache_creation_input_tokens`` + on prompt_tokens_details) are set on ``model_extra`` because that's where + the baseline helper reads them from (typed ``CompletionUsage.model_extra`` + rather than ``getattr``). Pass ``cost=None`` to emit an explicit-null cost + key; omit ``cost`` entirely to leave the key absent. + """ + chunk = MagicMock() + chunk.choices = [] + chunk.usage = MagicMock() + chunk.usage.prompt_tokens = prompt_tokens + chunk.usage.completion_tokens = completion_tokens + usage_extras: dict[str, float | str | None] = {} + if cost is not _COST_MISSING: + usage_extras["cost"] = cost # type: ignore[assignment] + chunk.usage.model_extra = usage_extras + + if cached_tokens is not None or cache_creation_input_tokens is not None: + ptd = MagicMock() + ptd.cached_tokens = cached_tokens or 0 + ptd.model_extra = { + "cache_creation_input_tokens": cache_creation_input_tokens or 0 + } + chunk.usage.prompt_tokens_details = ptd + else: + chunk.usage.prompt_tokens_details = None + + return chunk + + +def _make_stream_mock(*chunks): + """Build an async streaming response mock that yields *chunks* in order.""" + stream = MagicMock() + stream.close = AsyncMock() + + async def aiter(): + for c in chunks: + yield c + + stream.__aiter__ = lambda self: aiter() + return stream + + class TestBaselineCostExtraction: - """Tests for x-total-cost header extraction in _baseline_llm_caller.""" + """Tests for ``usage.cost`` extraction in ``_baseline_llm_caller``. + + Cost is read from the OpenRouter ``usage.cost`` field on the final + streaming chunk when the request body includes ``usage: {include: true}`` + (handled by the baseline service via ``extra_body``). + """ @pytest.mark.asyncio - async def test_cost_usd_extracted_from_response_header(self): - """state.cost_usd is set from x-total-cost header when present.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - + async def test_cost_usd_extracted_from_usage_chunk(self): + """state.cost_usd is set from chunk.usage.cost when present.""" state = _BaselineStreamState(model="gpt-4o-mini") - - # Build a mock raw httpx response with the cost header - mock_raw_response = MagicMock() - mock_raw_response.headers = {"x-total-cost": "0.0123"} - - # Build a mock async streaming response that yields no chunks but has - # a _response attribute pointing to the mock httpx response - mock_stream_response = MagicMock() - mock_stream_response._response = mock_raw_response - - async def empty_aiter(): - return - yield # make it an async generator - - mock_stream_response.__aiter__ = lambda self: empty_aiter() + chunk = _make_usage_chunk( + prompt_tokens=1000, completion_tokens=200, cost=0.0123 + ) mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock( - return_value=mock_stream_response + return_value=_make_stream_mock(chunk) ) with patch( @@ -622,29 +666,14 @@ class TestBaselineCostExtraction: @pytest.mark.asyncio async def test_cost_usd_accumulates_across_calls(self): """cost_usd accumulates when _baseline_llm_caller is called multiple times.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - state = _BaselineStreamState(model="gpt-4o-mini") - def make_stream_mock(cost: str) -> MagicMock: - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": cost} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - async def empty_aiter(): - return - yield - - mock_stream.__aiter__ = lambda self: empty_aiter() - return mock_stream - mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock( - side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")] + side_effect=[ + _make_stream_mock(_make_usage_chunk(prompt_tokens=500, cost=0.01)), + _make_stream_mock(_make_usage_chunk(prompt_tokens=600, cost=0.02)), + ] ) with patch( @@ -665,28 +694,64 @@ class TestBaselineCostExtraction: assert state.cost_usd == pytest.approx(0.03) @pytest.mark.asyncio - async def test_no_cost_when_header_absent(self): - """state.cost_usd remains None when response has no x-total-cost header.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - + async def test_cost_usd_accepts_string_value(self): + """OpenRouter may emit cost as a string — it should still parse.""" state = _BaselineStreamState(model="gpt-4o-mini") - - mock_raw = MagicMock() - mock_raw.headers = {} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - async def empty_aiter(): - return - yield - - mock_stream.__aiter__ = lambda self: empty_aiter() + chunk = _make_usage_chunk(prompt_tokens=10, cost="0.005") mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd == pytest.approx(0.005) + + @pytest.mark.asyncio + async def test_cost_usd_none_when_usage_cost_missing(self): + """state.cost_usd stays None when the usage chunk lacks a cost field.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4") + chunk = _make_usage_chunk(prompt_tokens=1000, completion_tokens=500) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + # Token accumulators are still populated so the caller can log them. + assert state.turn_prompt_tokens == 1000 + assert state.turn_completion_tokens == 500 + + @pytest.mark.asyncio + async def test_invalid_cost_string_leaves_cost_none(self): + """A non-numeric cost value is rejected without raising.""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk(prompt_tokens=10, cost="not-a-number") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) with patch( "backend.copilot.baseline.service._get_openai_client", @@ -701,28 +766,73 @@ class TestBaselineCostExtraction: assert state.cost_usd is None @pytest.mark.asyncio - async def test_cost_extracted_even_when_stream_raises(self): - """cost_usd is captured in the finally block even when streaming fails.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, + async def test_negative_cost_is_ignored(self): + """Guard against negative cost values (shouldn't happen but be safe).""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk(prompt_tokens=10, cost=-0.01) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) ) + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + + @pytest.mark.asyncio + async def test_explicit_null_cost_is_logged_and_ignored(self, caplog): + """`{"cost": null}` is rejected and logged (not silently dropped).""" + state = _BaselineStreamState(model="openrouter/auto") + chunk = _make_usage_chunk(prompt_tokens=10, cost=None) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with ( + patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ), + caplog.at_level("ERROR", logger="backend.copilot.baseline.service"), + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + assert any( + "usage.cost is present but null" in rec.message for rec in caplog.records + ) + + @pytest.mark.asyncio + async def test_cost_not_captured_when_stream_raises_mid_chunk(self): + """If the stream aborts before emitting the usage chunk there is no cost.""" state = _BaselineStreamState(model="gpt-4o-mini") - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": "0.005"} - mock_stream = MagicMock() - mock_stream._response = mock_raw + stream = MagicMock() + stream.close = AsyncMock() async def failing_aiter(): raise RuntimeError("stream error") yield # make it an async generator - mock_stream.__aiter__ = lambda self: failing_aiter() + stream.__aiter__ = lambda self: failing_aiter() mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock(return_value=stream) with ( patch( @@ -737,16 +847,12 @@ class TestBaselineCostExtraction: state=state, ) - assert state.cost_usd == pytest.approx(0.005) + # Stream aborted before yielding the usage chunk — cost stays None. + assert state.cost_usd is None @pytest.mark.asyncio async def test_no_cost_when_api_call_raises_before_stream(self): - """finally block is safe when response is None (API call failed before yielding).""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - + """The helper is safe when the create() call itself raises.""" state = _BaselineStreamState(model="gpt-4o-mini") mock_client = MagicMock() @@ -767,84 +873,23 @@ class TestBaselineCostExtraction: state=state, ) - # response was never assigned so cost extraction must not raise - assert state.cost_usd is None - - @pytest.mark.asyncio - async def test_no_cost_when_header_missing(self): - """cost_usd remains None when x-total-cost is absent.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - - state = _BaselineStreamState(model="anthropic/claude-sonnet-4") - - mock_raw = MagicMock() - mock_raw.headers = {} # no x-total-cost - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 500 - mock_chunk.usage.prompt_tokens_details = None - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) - - with patch( - "backend.copilot.baseline.service._get_openai_client", - return_value=mock_client, - ): - await _baseline_llm_caller( - messages=[{"role": "user", "content": "hi"}], - tools=[], - state=state, - ) - assert state.cost_usd is None @pytest.mark.asyncio async def test_cache_tokens_extracted_from_usage_details(self): """cache tokens are extracted from prompt_tokens_details.cached_tokens.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk( + prompt_tokens=1000, + completion_tokens=200, + cost=0.01, + cached_tokens=800, ) - state = _BaselineStreamState(model="openai/gpt-4o") - - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": "0.01"} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - # Create a chunk with prompt_tokens_details - mock_ptd = MagicMock() - mock_ptd.cached_tokens = 800 - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 200 - mock_chunk.usage.prompt_tokens_details = mock_ptd - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) with patch( "backend.copilot.baseline.service._get_openai_client", @@ -861,37 +906,20 @@ class TestBaselineCostExtraction: @pytest.mark.asyncio async def test_cache_creation_tokens_extracted_from_usage_details(self): - """cache_creation_tokens are extracted from prompt_tokens_details.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, + """cache_creation_input_tokens is extracted from prompt_tokens_details.""" + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk( + prompt_tokens=1000, + completion_tokens=200, + cost=0.01, + cached_tokens=0, + cache_creation_input_tokens=500, ) - state = _BaselineStreamState(model="openai/gpt-4o") - - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": "0.01"} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_ptd = MagicMock() - mock_ptd.cached_tokens = 0 - mock_ptd.cache_creation_input_tokens = 500 - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 200 - mock_chunk.usage.prompt_tokens_details = mock_ptd - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) with patch( "backend.copilot.baseline.service._get_openai_client", @@ -908,37 +936,17 @@ class TestBaselineCostExtraction: @pytest.mark.asyncio async def test_token_accumulators_track_across_multiple_calls(self): """Token accumulators grow correctly across multiple _baseline_llm_caller calls.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - state = _BaselineStreamState(model="anthropic/claude-sonnet-4") - def make_stream(prompt_tokens: int, completion_tokens: int): - mock_raw = MagicMock() - mock_raw.headers = {} # no x-total-cost - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = prompt_tokens - mock_chunk.usage.completion_tokens = completion_tokens - mock_chunk.usage.prompt_tokens_details = None - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - return mock_stream - mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock( side_effect=[ - make_stream(1000, 200), - make_stream(1100, 300), + _make_stream_mock( + _make_usage_chunk(prompt_tokens=1000, completion_tokens=200) + ), + _make_stream_mock( + _make_usage_chunk(prompt_tokens=1100, completion_tokens=300) + ), ] ) @@ -957,45 +965,33 @@ class TestBaselineCostExtraction: state=state, ) - # No x-total-cost header and empty pricing table -- cost_usd remains None + # No usage.cost on either chunk → cost stays None, tokens still accumulate. assert state.cost_usd is None - # Accumulators hold all tokens across both turns assert state.turn_prompt_tokens == 2100 assert state.turn_completion_tokens == 500 + @pytest.mark.parametrize( + "tools", + [ + pytest.param([], id="no_tools"), + pytest.param([_make_tool("search")], id="with_tools"), + ], + ) @pytest.mark.asyncio - async def test_cost_usd_remains_none_when_header_missing(self): - """cost_usd stays None when x-total-cost header is absent. + async def test_baseline_requests_usage_include_extra_body( + self, tools: list[ChatCompletionToolParam] + ): + """The baseline call must pass extra_body={'usage': {'include': True}}. - Token counts are still tracked; persist_and_record_usage handles - the None cost by falling back to tracking_type='tokens'. + This guards the contract with OpenRouter that triggers inclusion of + the authoritative cost on the final usage chunk. Without it the + rate-limit counter stays at zero. Exercise both the no-tools and + tool-calling branches so a regression in either path trips the test. """ - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - - state = _BaselineStreamState(model="anthropic/claude-sonnet-4") - - mock_raw = MagicMock() - mock_raw.headers = {} # no x-total-cost - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 500 - mock_chunk.usage.prompt_tokens_details = None - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - + state = _BaselineStreamState(model="gpt-4o-mini") + create_mock = AsyncMock(return_value=_make_stream_mock()) mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = create_mock with patch( "backend.copilot.baseline.service._get_openai_client", @@ -1003,13 +999,15 @@ class TestBaselineCostExtraction: ): await _baseline_llm_caller( messages=[{"role": "user", "content": "hi"}], - tools=[], + tools=tools, state=state, ) - assert state.cost_usd is None - assert state.turn_prompt_tokens == 1000 - assert state.turn_completion_tokens == 500 + create_mock.assert_awaited_once() + await_args = create_mock.await_args + assert await_args is not None + assert await_args.kwargs["extra_body"] == {"usage": {"include": True}} + assert await_args.kwargs["stream_options"] == {"include_usage": True} class TestMidLoopPendingFlushOrdering: diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index ee4c717dbe..3277854172 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -101,25 +101,31 @@ class ChatConfig(BaseSettings): description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)", ) - # Rate limiting — token-based limits per day and per week. - # Per-turn token cost varies with context size: ~10-15K for early turns, - # ~30-50K mid-session, up to ~100K pre-compaction. Average across a - # session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily - # allows ~70-100 turns/day. + # Rate limiting — cost-based limits per day and per week, stored in + # microdollars (1 USD = 1_000_000). The counter tracks the real + # generation cost reported by the provider (OpenRouter ``usage.cost`` + # or Claude Agent SDK ``total_cost_usd``), so cache discounts and + # cross-model price differences are already reflected — no token + # weighting or model multiplier is applied on top. # Checked at the HTTP layer (routes.py) before each turn. # - # These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS, + # These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS, # ENTERPRISE) multiply these by their tier multiplier (see - # rate_limit.TIER_MULTIPLIERS). User tier is stored in the + # rate_limit.TIER_MULTIPLIERS). User tier is stored in the # User.subscriptionTier DB column and resolved inside # get_global_rate_limits(). - daily_token_limit: int = Field( - default=2_500_000, - description="Max tokens per day, resets at midnight UTC (0 = unlimited)", + # + # These defaults act as the ceiling when LaunchDarkly is unreachable; + # the live per-tier values come from the COPILOT_*_COST_LIMIT flags. + daily_cost_limit_microdollars: int = Field( + default=1_000_000, + description="Max cost per day in microdollars, resets at midnight UTC " + "(0 = unlimited).", ) - weekly_token_limit: int = Field( - default=12_500_000, - description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)", + weekly_cost_limit_microdollars: int = Field( + default=5_000_000, + description="Max cost per week in microdollars, resets Monday 00:00 UTC " + "(0 = unlimited).", ) # Cost (in credits / cents) to reset the daily rate limit using credits. diff --git a/autogpt_platform/backend/backend/copilot/rate_limit.py b/autogpt_platform/backend/backend/copilot/rate_limit.py index c08cb1b3a8..472ddf79b0 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit.py @@ -1,9 +1,16 @@ -"""CoPilot rate limiting based on token usage. +"""CoPilot rate limiting based on generation cost. -Uses Redis fixed-window counters to track per-user token consumption -with configurable daily and weekly limits. Daily windows reset at -midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00 -UTC). Fails open when Redis is unavailable to avoid blocking users. +Uses Redis fixed-window counters to track per-user USD spend (stored as +microdollars, matching ``PlatformCostLog.cost_microdollars``) with +configurable daily and weekly limits. Daily windows reset at midnight UTC; +weekly windows reset at ISO week boundary (Monday 00:00 UTC). Fails open +when Redis is unavailable to avoid blocking users. + +Storing microdollars rather than tokens means the counter already reflects +real model pricing (including cache discounts and provider surcharges), so +this module carries no pricing table — the cost comes from OpenRouter's +``usage.cost`` field (baseline) or the Claude Agent SDK's reported total +cost (SDK path). """ import asyncio @@ -22,8 +29,10 @@ from backend.util.cache import cached logger = logging.getLogger(__name__) -# Redis key prefixes -_USAGE_KEY_PREFIX = "copilot:usage" +# Redis key prefixes. Bumped from "copilot:usage" (token-based) to +# "copilot:cost" on the token→cost migration so stale counters do not +# get misinterpreted as microdollars (which would dramatically under-count). +_USAGE_KEY_PREFIX = "copilot:cost" # --------------------------------------------------------------------------- @@ -32,7 +41,7 @@ _USAGE_KEY_PREFIX = "copilot:usage" class SubscriptionTier(str, Enum): - """Subscription tiers with increasing token allowances. + """Subscription tiers with increasing cost allowances. Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``. Once ``prisma generate`` is run, this can be replaced with:: @@ -46,9 +55,9 @@ class SubscriptionTier(str, Enum): ENTERPRISE = "ENTERPRISE" -# Multiplier applied to the base limits (from LD / config) for each tier. -# Intentionally int (not float): keeps limits as whole token counts and avoids -# floating-point rounding. If fractional multipliers are ever needed, change +# Multiplier applied to the base cost limits (from LD / config) for each tier. +# Intentionally int (not float): keeps limits as whole microdollars and avoids +# floating-point rounding. If fractional multipliers are ever needed, change # the type and round the result in get_global_rate_limits(). TIER_MULTIPLIERS: dict[SubscriptionTier, int] = { SubscriptionTier.FREE: 1, @@ -61,17 +70,27 @@ DEFAULT_TIER = SubscriptionTier.FREE class UsageWindow(BaseModel): - """Usage within a single time window.""" + """Usage within a single time window. + + ``used`` and ``limit`` are in microdollars (1 USD = 1_000_000). + """ used: int limit: int = Field( - description="Maximum tokens allowed in this window. 0 means unlimited." + description="Maximum microdollars of spend allowed in this window. " + "0 means unlimited." ) resets_at: datetime class CoPilotUsageStatus(BaseModel): - """Current usage status for a user across all windows.""" + """Current usage status for a user across all windows. + + Internal representation used by server-side code that needs to compare + usage against limits (e.g. the reset-credits endpoint). The public API + returns ``CoPilotUsagePublic`` instead so that raw spend and limit + figures never leak to clients. + """ daily: UsageWindow weekly: UsageWindow @@ -82,6 +101,68 @@ class CoPilotUsageStatus(BaseModel): ) +class UsageWindowPublic(BaseModel): + """Public view of a usage window — only the percentage and reset time. + + Hides the raw spend and the cap so clients cannot derive per-turn cost + or reverse-engineer platform margins. ``percent_used`` is capped at 100. + """ + + percent_used: float = Field( + ge=0.0, + le=100.0, + description="Percentage of the window's allowance used (0-100). " + "Clamped at 100 when over the cap.", + ) + resets_at: datetime + + +class CoPilotUsagePublic(BaseModel): + """Current usage status for a user — public (client-safe) shape.""" + + daily: UsageWindowPublic | None = Field( + default=None, + description="Null when no daily cap is configured (unlimited).", + ) + weekly: UsageWindowPublic | None = Field( + default=None, + description="Null when no weekly cap is configured (unlimited).", + ) + tier: SubscriptionTier = DEFAULT_TIER + reset_cost: int = Field( + default=0, + description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.", + ) + + @classmethod + def from_status(cls, status: CoPilotUsageStatus) -> "CoPilotUsagePublic": + """Project the internal status onto the client-safe schema.""" + + def window(w: UsageWindow) -> UsageWindowPublic | None: + if w.limit <= 0: + return None + # When at/over the cap, snap to exactly 100.0 so the UI's + # rounded display and its exhaustion check (`percent_used >= 100`) + # agree. Without this, e.g. 99.95% would render as "100% used" + # via Math.round but fail the exhaustion check, leaving the + # reset button hidden while the bar appears full. + if w.used >= w.limit: + pct = 100.0 + else: + pct = round(100.0 * w.used / w.limit, 1) + return UsageWindowPublic( + percent_used=pct, + resets_at=w.resets_at, + ) + + return cls( + daily=window(status.daily), + weekly=window(status.weekly), + tier=status.tier, + reset_cost=status.reset_cost, + ) + + class RateLimitExceeded(Exception): """Raised when a user exceeds their CoPilot usage limit.""" @@ -103,8 +184,8 @@ class RateLimitExceeded(Exception): async def get_usage_status( user_id: str, - daily_token_limit: int, - weekly_token_limit: int, + daily_cost_limit: int, + weekly_cost_limit: int, rate_limit_reset_cost: int = 0, tier: SubscriptionTier = DEFAULT_TIER, ) -> CoPilotUsageStatus: @@ -112,13 +193,13 @@ async def get_usage_status( Args: user_id: The user's ID. - daily_token_limit: Max tokens per day (0 = unlimited). - weekly_token_limit: Max tokens per week (0 = unlimited). + daily_cost_limit: Max microdollars of spend per day (0 = unlimited). + weekly_cost_limit: Max microdollars of spend per week (0 = unlimited). rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled). tier: The user's rate-limit tier (included in the response). Returns: - CoPilotUsageStatus with current usage and limits. + CoPilotUsageStatus with current usage and limits in microdollars. """ now = datetime.now(UTC) daily_used = 0 @@ -137,12 +218,12 @@ async def get_usage_status( return CoPilotUsageStatus( daily=UsageWindow( used=daily_used, - limit=daily_token_limit, + limit=daily_cost_limit, resets_at=_daily_reset_time(now=now), ), weekly=UsageWindow( used=weekly_used, - limit=weekly_token_limit, + limit=weekly_cost_limit, resets_at=_weekly_reset_time(now=now), ), tier=tier, @@ -152,22 +233,22 @@ async def get_usage_status( async def check_rate_limit( user_id: str, - daily_token_limit: int, - weekly_token_limit: int, + daily_cost_limit: int, + weekly_cost_limit: int, ) -> None: """Check if user is within rate limits. Raises RateLimitExceeded if not. This is a pre-turn soft check. The authoritative usage counter is updated - by ``record_token_usage()`` after the turn completes. Under concurrency, + by ``record_cost_usage()`` after the turn completes. Under concurrency, two parallel turns may both pass this check against the same snapshot. - This is acceptable because token-based limits are approximate by nature - (the exact token count is unknown until after generation). + This is acceptable because cost-based limits are approximate by nature + (the exact cost is unknown until after generation). Fails open: if Redis is unavailable, allows the request. """ # Short-circuit: when both limits are 0 (unlimited) skip the Redis # round-trip entirely. - if daily_token_limit <= 0 and weekly_token_limit <= 0: + if daily_cost_limit <= 0 and weekly_cost_limit <= 0: return now = datetime.now(UTC) @@ -183,26 +264,25 @@ async def check_rate_limit( logger.warning("Redis unavailable for rate limit check, allowing request") return - # Worst-case overshoot: N concurrent requests × ~15K tokens each. - if daily_token_limit > 0 and daily_used >= daily_token_limit: + if daily_cost_limit > 0 and daily_used >= daily_cost_limit: raise RateLimitExceeded("daily", _daily_reset_time(now=now)) - if weekly_token_limit > 0 and weekly_used >= weekly_token_limit: + if weekly_cost_limit > 0 and weekly_used >= weekly_cost_limit: raise RateLimitExceeded("weekly", _weekly_reset_time(now=now)) -async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool: - """Reset a user's daily token usage counter in Redis. +async def reset_daily_usage(user_id: str, daily_cost_limit: int = 0) -> bool: + """Reset a user's daily cost usage counter in Redis. Called after a user pays credits to extend their daily limit. - Also reduces the weekly usage counter by ``daily_token_limit`` tokens + Also reduces the weekly usage counter by ``daily_cost_limit`` microdollars (clamped to 0) so the user effectively gets one extra day's worth of weekly capacity. Args: user_id: The user's ID. - daily_token_limit: The configured daily token limit. When positive, - the weekly counter is reduced by this amount. + daily_cost_limit: The configured daily cost limit in microdollars. + When positive, the weekly counter is reduced by this amount. Returns False if Redis is unavailable so the caller can handle compensation (fail-closed for billed operations, unlike the read-only @@ -218,12 +298,12 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool: # counter is not decremented — which would let the caller refund # credits even though the daily limit was already reset. d_key = _daily_key(user_id, now=now) - w_key = _weekly_key(user_id, now=now) if daily_token_limit > 0 else None + w_key = _weekly_key(user_id, now=now) if daily_cost_limit > 0 else None pipe = redis.pipeline(transaction=True) pipe.delete(d_key) if w_key is not None: - pipe.decrby(w_key, daily_token_limit) + pipe.decrby(w_key, daily_cost_limit) results = await pipe.execute() # Clamp negative weekly counter to 0 (best-effort; not critical). @@ -296,84 +376,40 @@ async def increment_daily_reset_count(user_id: str) -> None: logger.warning("Redis unavailable for tracking reset count") -async def record_token_usage( +async def record_cost_usage( user_id: str, - prompt_tokens: int, - completion_tokens: int, - *, - cache_read_tokens: int = 0, - cache_creation_tokens: int = 0, - model_cost_multiplier: float = 1.0, + cost_microdollars: int, ) -> None: - """Record token usage for a user across all windows. + """Record a user's generation spend against daily and weekly counters. - Uses cost-weighted counting so cached tokens don't unfairly penalise - multi-turn conversations. Anthropic's pricing: - - uncached input: 100% - - cache creation: 25% - - cache read: 10% - - output: 100% - - ``prompt_tokens`` should be the *uncached* input count (``input_tokens`` - from the API response). Cache counts are passed separately. - - ``model_cost_multiplier`` scales the final weighted total to reflect - relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet) - so that Opus turns deplete the rate limit faster, proportional to cost. + ``cost_microdollars`` is the real generation cost reported by the + provider (OpenRouter's ``usage.cost`` or the Claude Agent SDK's + ``total_cost_usd`` converted to microdollars). Because the provider + cost already reflects model pricing and cache discounts, this function + carries no pricing table or weighting — it just increments counters. Args: user_id: The user's ID. - prompt_tokens: Uncached input tokens. - completion_tokens: Output tokens. - cache_read_tokens: Tokens served from prompt cache (10% cost). - cache_creation_tokens: Tokens written to prompt cache (25% cost). - model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus). + cost_microdollars: Spend to record in microdollars (1 USD = 1_000_000). + Non-positive values are ignored. """ - prompt_tokens = max(0, prompt_tokens) - completion_tokens = max(0, completion_tokens) - cache_read_tokens = max(0, cache_read_tokens) - cache_creation_tokens = max(0, cache_creation_tokens) - - weighted_input = ( - prompt_tokens - + round(cache_creation_tokens * 0.25) - + round(cache_read_tokens * 0.1) - ) - total = round( - (weighted_input + completion_tokens) * max(1.0, model_cost_multiplier) - ) - if total <= 0: + cost_microdollars = max(0, cost_microdollars) + if cost_microdollars <= 0: return - raw_total = ( - prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens - ) - logger.info( - "Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx " - "(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)", - user_id[:8], - raw_total, - total, - model_cost_multiplier, - prompt_tokens, - cache_read_tokens, - cache_creation_tokens, - completion_tokens, - ) + logger.info("Recording copilot spend: %d microdollars", cost_microdollars) now = datetime.now(UTC) try: redis = await get_redis_async() - # transaction=False: these are independent INCRBY+EXPIRE pairs on - # separate keys — no cross-key atomicity needed. Skipping - # MULTI/EXEC avoids the overhead. If the connection drops between - # INCRBY and EXPIRE the key survives until the next date-based key - # rotation (daily/weekly), so the memory-leak risk is negligible. - pipe = redis.pipeline(transaction=False) + # Use MULTI/EXEC so each INCRBY/EXPIRE pair is atomic — guarantees + # the TTL is set even if the connection drops mid-pipeline, so + # counters can never survive past their date-based rotation window. + pipe = redis.pipeline(transaction=True) # Daily counter (expires at next midnight UTC) d_key = _daily_key(user_id, now=now) - pipe.incrby(d_key, total) + pipe.incrby(d_key, cost_microdollars) seconds_until_daily_reset = int( (_daily_reset_time(now=now) - now).total_seconds() ) @@ -381,7 +417,7 @@ async def record_token_usage( # Weekly counter (expires end of week) w_key = _weekly_key(user_id, now=now) - pipe.incrby(w_key, total) + pipe.incrby(w_key, cost_microdollars) seconds_until_weekly_reset = int( (_weekly_reset_time(now=now) - now).total_seconds() ) @@ -390,8 +426,8 @@ async def record_token_usage( await pipe.execute() except (RedisError, ConnectionError, OSError): logger.warning( - "Redis unavailable for recording token usage (tokens=%d)", - total, + "Redis unavailable for recording cost usage (microdollars=%d)", + cost_microdollars, ) @@ -598,37 +634,41 @@ async def get_global_rate_limits( ) -> tuple[int, int, SubscriptionTier]: """Resolve global rate limits from LaunchDarkly, falling back to config. - The base limits (from LD or config) are multiplied by the user's - tier multiplier so that higher tiers receive proportionally larger - allowances. + Values are microdollars. The base limits (from LD or config) are + multiplied by the user's tier multiplier so that higher tiers receive + proportionally larger allowances. Args: user_id: User ID for LD flag evaluation context. - config_daily: Fallback daily limit from ChatConfig. - config_weekly: Fallback weekly limit from ChatConfig. + config_daily: Fallback daily cost limit (microdollars) from ChatConfig. + config_weekly: Fallback weekly cost limit (microdollars) from ChatConfig. Returns: - (daily_token_limit, weekly_token_limit, tier) 3-tuple. + (daily_cost_limit, weekly_cost_limit, tier) — limits in microdollars. """ # Lazy import to avoid circular dependency: # rate_limit -> feature_flag -> settings -> ... -> rate_limit from backend.util.feature_flag import Flag, get_feature_flag_value - daily_raw = await get_feature_flag_value( - Flag.COPILOT_DAILY_TOKEN_LIMIT.value, user_id, config_daily - ) - weekly_raw = await get_feature_flag_value( - Flag.COPILOT_WEEKLY_TOKEN_LIMIT.value, user_id, config_weekly + # Fetch daily + weekly flags in parallel — each LD evaluation is an + # independent network round-trip, so gather cuts latency roughly in half. + daily_raw, weekly_raw = await asyncio.gather( + get_feature_flag_value( + Flag.COPILOT_DAILY_COST_LIMIT.value, user_id, config_daily + ), + get_feature_flag_value( + Flag.COPILOT_WEEKLY_COST_LIMIT.value, user_id, config_weekly + ), ) try: daily = max(0, int(daily_raw)) except (TypeError, ValueError): - logger.warning("Invalid LD value for daily token limit: %r", daily_raw) + logger.warning("Invalid LD value for daily cost limit: %r", daily_raw) daily = config_daily try: weekly = max(0, int(weekly_raw)) except (TypeError, ValueError): - logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw) + logger.warning("Invalid LD value for weekly cost limit: %r", weekly_raw) weekly = config_weekly # Apply tier multiplier diff --git a/autogpt_platform/backend/backend/copilot/rate_limit_test.py b/autogpt_platform/backend/backend/copilot/rate_limit_test.py index 577093c752..3787796c17 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit_test.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit_test.py @@ -24,7 +24,7 @@ from .rate_limit import ( get_usage_status, get_user_tier, increment_daily_reset_count, - record_token_usage, + record_cost_usage, release_reset_lock, reset_daily_usage, reset_user_usage, @@ -82,7 +82,7 @@ class TestGetUsageStatus: return_value=mock_redis, ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert isinstance(status, CoPilotUsageStatus) @@ -98,7 +98,7 @@ class TestGetUsageStatus: side_effect=ConnectionError("Redis down"), ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert status.daily.used == 0 @@ -115,7 +115,7 @@ class TestGetUsageStatus: return_value=mock_redis, ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert status.daily.used == 0 @@ -132,7 +132,7 @@ class TestGetUsageStatus: return_value=mock_redis, ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert status.daily.used == 500 @@ -148,7 +148,7 @@ class TestGetUsageStatus: return_value=mock_redis, ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) now = datetime.now(UTC) @@ -174,7 +174,7 @@ class TestCheckRateLimit: ): # Should not raise await check_rate_limit( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) @pytest.mark.asyncio @@ -188,7 +188,7 @@ class TestCheckRateLimit: ): with pytest.raises(RateLimitExceeded) as exc_info: await check_rate_limit( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert exc_info.value.window == "daily" @@ -203,7 +203,7 @@ class TestCheckRateLimit: ): with pytest.raises(RateLimitExceeded) as exc_info: await check_rate_limit( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert exc_info.value.window == "weekly" @@ -216,7 +216,7 @@ class TestCheckRateLimit: ): # Should not raise await check_rate_limit( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) @pytest.mark.asyncio @@ -229,15 +229,15 @@ class TestCheckRateLimit: return_value=mock_redis, ): # Should not raise — limits of 0 mean unlimited - await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0) + await check_rate_limit(_USER, daily_cost_limit=0, weekly_cost_limit=0) # --------------------------------------------------------------------------- -# record_token_usage +# record_cost_usage # --------------------------------------------------------------------------- -class TestRecordTokenUsage: +class TestRecordCostUsage: @staticmethod def _make_pipeline_mock() -> MagicMock: """Create a pipeline mock with sync methods and async execute.""" @@ -255,27 +255,40 @@ class TestRecordTokenUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50) + await record_cost_usage(_USER, cost_microdollars=123_456) - # Should call incrby twice (daily + weekly) with total=150 + # Should call incrby twice (daily + weekly) with the same cost incrby_calls = mock_pipe.incrby.call_args_list assert len(incrby_calls) == 2 - assert incrby_calls[0].args[1] == 150 # daily - assert incrby_calls[1].args[1] == 150 # weekly + assert incrby_calls[0].args[1] == 123_456 # daily + assert incrby_calls[1].args[1] == 123_456 # weekly @pytest.mark.asyncio - async def test_skips_when_zero_tokens(self): + async def test_skips_when_cost_is_zero(self): mock_redis = AsyncMock() with patch( "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0) + await record_cost_usage(_USER, cost_microdollars=0) # Should not call pipeline at all mock_redis.pipeline.assert_not_called() + @pytest.mark.asyncio + async def test_skips_when_cost_is_negative(self): + """Negative costs are clamped to zero and skip the pipeline.""" + mock_redis = AsyncMock() + + with patch( + "backend.copilot.rate_limit.get_redis_async", + return_value=mock_redis, + ): + await record_cost_usage(_USER, cost_microdollars=-10) + + mock_redis.pipeline.assert_not_called() + @pytest.mark.asyncio async def test_sets_expire_on_both_keys(self): """Pipeline should call expire for both daily and weekly keys.""" @@ -287,7 +300,7 @@ class TestRecordTokenUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50) + await record_cost_usage(_USER, cost_microdollars=5_000) expire_calls = mock_pipe.expire.call_args_list assert len(expire_calls) == 2 @@ -308,32 +321,7 @@ class TestRecordTokenUsage: side_effect=ConnectionError("Redis down"), ): # Should not raise - await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50) - - @pytest.mark.asyncio - async def test_cost_weighted_counting(self): - """Cached tokens should be weighted: cache_read=10%, cache_create=25%.""" - mock_pipe = self._make_pipeline_mock() - mock_redis = AsyncMock() - mock_redis.pipeline = lambda **_kw: mock_pipe - - with patch( - "backend.copilot.rate_limit.get_redis_async", - return_value=mock_redis, - ): - await record_token_usage( - _USER, - prompt_tokens=100, # uncached → 100 - completion_tokens=50, # output → 50 - cache_read_tokens=10000, # 10% → 1000 - cache_creation_tokens=400, # 25% → 100 - ) - - # Expected weighted total: 100 + 1000 + 100 + 50 = 1250 - incrby_calls = mock_pipe.incrby.call_args_list - assert len(incrby_calls) == 2 - assert incrby_calls[0].args[1] == 1250 # daily - assert incrby_calls[1].args[1] == 1250 # weekly + await record_cost_usage(_USER, cost_microdollars=5_000) @pytest.mark.asyncio async def test_handles_redis_error_during_pipeline_execute(self): @@ -348,7 +336,7 @@ class TestRecordTokenUsage: return_value=mock_redis, ): # Should not raise — fail-open - await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50) + await record_cost_usage(_USER, cost_microdollars=5_000) # --------------------------------------------------------------------------- @@ -819,7 +807,7 @@ class TestTierLimitsRespected: assert tier == SubscriptionTier.PRO # Should NOT raise — 3M < 12.5M await check_rate_limit( - _USER, daily_token_limit=daily, weekly_token_limit=weekly + _USER, daily_cost_limit=daily, weekly_cost_limit=weekly ) @pytest.mark.asyncio @@ -853,7 +841,7 @@ class TestTierLimitsRespected: # Should raise — 2.5M >= 2.5M with pytest.raises(RateLimitExceeded): await check_rate_limit( - _USER, daily_token_limit=daily, weekly_token_limit=weekly + _USER, daily_cost_limit=daily, weekly_cost_limit=weekly ) @pytest.mark.asyncio @@ -885,7 +873,7 @@ class TestTierLimitsRespected: assert tier == SubscriptionTier.ENTERPRISE # Should NOT raise — 100M < 150M await check_rate_limit( - _USER, daily_token_limit=daily, weekly_token_limit=weekly + _USER, daily_cost_limit=daily, weekly_cost_limit=weekly ) @@ -912,7 +900,7 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - result = await reset_daily_usage(_USER, daily_token_limit=10000) + result = await reset_daily_usage(_USER, daily_cost_limit=10000) assert result is True mock_pipe.delete.assert_called_once() @@ -928,7 +916,7 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await reset_daily_usage(_USER, daily_token_limit=10000) + await reset_daily_usage(_USER, daily_cost_limit=10000) mock_pipe.decrby.assert_called_once() mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed @@ -944,14 +932,14 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await reset_daily_usage(_USER, daily_token_limit=10000) + await reset_daily_usage(_USER, daily_cost_limit=10000) mock_pipe.decrby.assert_called_once() mock_redis.set.assert_called_once() @pytest.mark.asyncio async def test_no_weekly_reduction_when_daily_limit_zero(self): - """When daily_token_limit is 0, weekly counter should not be touched.""" + """When daily_cost_limit is 0, weekly counter should not be touched.""" mock_pipe = self._make_pipeline_mock() mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result mock_redis = AsyncMock() @@ -961,7 +949,7 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await reset_daily_usage(_USER, daily_token_limit=0) + await reset_daily_usage(_USER, daily_cost_limit=0) mock_pipe.delete.assert_called_once() mock_pipe.decrby.assert_not_called() @@ -972,7 +960,7 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", side_effect=ConnectionError("Redis down"), ): - result = await reset_daily_usage(_USER, daily_token_limit=10000) + result = await reset_daily_usage(_USER, daily_cost_limit=10000) assert result is False diff --git a/autogpt_platform/backend/backend/copilot/reset_usage_test.py b/autogpt_platform/backend/backend/copilot/reset_usage_test.py index cbbf714df0..d5b4ee140e 100644 --- a/autogpt_platform/backend/backend/copilot/reset_usage_test.py +++ b/autogpt_platform/backend/backend/copilot/reset_usage_test.py @@ -16,14 +16,14 @@ from backend.util.exceptions import InsufficientBalanceError # Minimal config mock matching ChatConfig fields used by the endpoint. def _make_config( rate_limit_reset_cost: int = 500, - daily_token_limit: int = 2_500_000, - weekly_token_limit: int = 12_500_000, + daily_cost_limit_microdollars: int = 10_000_000, + weekly_cost_limit_microdollars: int = 50_000_000, max_daily_resets: int = 5, ): cfg = MagicMock() cfg.rate_limit_reset_cost = rate_limit_reset_cost - cfg.daily_token_limit = daily_token_limit - cfg.weekly_token_limit = weekly_token_limit + cfg.daily_cost_limit_microdollars = daily_cost_limit_microdollars + cfg.weekly_cost_limit_microdollars = weekly_cost_limit_microdollars cfg.max_daily_resets = max_daily_resets return cfg @@ -77,10 +77,10 @@ class TestResetCopilotUsage: assert "not available" in exc_info.value.detail async def test_no_daily_limit_returns_400(self): - """When daily_token_limit=0 (unlimited), endpoint returns 400.""" + """When daily_cost_limit=0 (unlimited), endpoint returns 400.""" with ( - patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)), + patch(f"{_MODULE}.config", _make_config(daily_cost_limit_microdollars=0)), patch(f"{_MODULE}.settings", _mock_settings()), _mock_rate_limits(daily=0), ): diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index ea0a135559..e4f29a2b65 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -165,11 +165,6 @@ _MAX_STREAM_ATTEMPTS = 3 # self-correct. The limit is generous to allow recovery attempts. _EMPTY_TOOL_CALL_LIMIT = 5 -# Cost multiplier for Opus model turns — Opus is ~5× more expensive than Sonnet -# ($15/$75 vs $3/$15 per M tokens). Applied to rate-limit counters so Opus -# turns deplete quota proportionally faster. -_OPUS_COST_MULTIPLIER = 5.0 - # User-facing error shown when the empty-tool-call circuit breaker trips. _CIRCUIT_BREAKER_ERROR_MSG = ( "AutoPilot was unable to complete the tool call " @@ -725,22 +720,20 @@ def _resolve_fallback_model() -> str | None: return _normalize_model_name(raw) -async def _resolve_model_and_multiplier( +async def _resolve_sdk_model_for_request( model: "CopilotLlmModel | None", session_id: str, -) -> tuple[str | None, float]: - """Resolve the SDK model string and rate-limit cost multiplier for a turn. +) -> str | None: + """Resolve the SDK model string for a turn. Priority (highest first): 1. Explicit per-request ``model`` tier from the frontend toggle. 2. Global config default (``_resolve_sdk_model()``). - Returns a ``(sdk_model, cost_multiplier)`` pair. - ``sdk_model`` is ``None`` when the Claude Code subscription default applies. - ``cost_multiplier`` is 5.0 for Opus, 1.0 otherwise. + Returns ``None`` when the Claude Code subscription default applies. + Rate-limit accounting no longer applies a multiplier — the real turn + cost (reported by the SDK) already reflects model-pricing differences. """ - sdk_model = _resolve_sdk_model() - if model == "advanced": sdk_model = _normalize_model_name(config.advanced_model) logger.info( @@ -748,7 +741,7 @@ async def _resolve_model_and_multiplier( session_id[:12] if session_id else "?", sdk_model, ) - return sdk_model, _OPUS_COST_MULTIPLIER + return sdk_model if model == "standard": # Reset to config default — respects subscription mode (None = CLI default). @@ -758,13 +751,9 @@ async def _resolve_model_and_multiplier( session_id[:12] if session_id else "?", sdk_model or "subscription-default", ) - return sdk_model, 1.0 + return sdk_model - # No per-request override; derive multiplier from final resolved model. - cost_multiplier = ( - _OPUS_COST_MULTIPLIER if sdk_model and "opus" in sdk_model else 1.0 - ) - return sdk_model, cost_multiplier + return _resolve_sdk_model() _MAX_TRANSIENT_BACKOFF_SECONDS = 30 @@ -2895,7 +2884,6 @@ async def stream_chat_completion_sdk( # Defaults ensure the finally block can always reference these safely even when # an early return (e.g. sdk_cwd error) skips their normal assignment below. sdk_model: str | None = None - model_cost_multiplier: float = 1.0 # Make sure there is no more code between the lock acquisition and try-block. try: @@ -3012,10 +3000,8 @@ async def stream_chat_completion_sdk( mcp_server = create_copilot_mcp_server(use_e2b=use_e2b) - # Resolve model and cost multiplier (request tier → config default). - sdk_model, model_cost_multiplier = await _resolve_model_and_multiplier( - model, session_id - ) + # Resolve model (request tier → config default). + sdk_model = await _resolve_sdk_model_for_request(model, session_id) # Track SDK-internal compaction (PreCompact hook → start, next msg → end) compaction = CompactionTracker() @@ -3813,7 +3799,6 @@ async def stream_chat_completion_sdk( cost_usd=turn_cost_usd, model=sdk_model or config.model, provider="anthropic", - model_cost_multiplier=model_cost_multiplier, ) # --- Persist session messages --- diff --git a/autogpt_platform/backend/backend/copilot/token_tracking.py b/autogpt_platform/backend/backend/copilot/token_tracking.py index 19406ced93..f5ace5e749 100644 --- a/autogpt_platform/backend/backend/copilot/token_tracking.py +++ b/autogpt_platform/backend/backend/copilot/token_tracking.py @@ -1,9 +1,9 @@ -"""Shared token-usage persistence and rate-limit recording. +"""Shared usage persistence and rate-limit recording. Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to: 1. Append a ``Usage`` record to the session. - 2. Log the turn's token counts. - 3. Record weighted usage in Redis for rate-limiting. + 2. Log the turn's token counts and cost. + 3. Record the real generation cost in Redis for rate-limiting. 4. Write a PlatformCostLog entry for admin cost tracking. This module extracts that common logic so both paths stay in sync. @@ -19,7 +19,7 @@ from backend.data.db_accessors import platform_cost_db from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars from .model import ChatSession, Usage -from .rate_limit import record_token_usage +from .rate_limit import record_cost_usage logger = logging.getLogger(__name__) @@ -96,9 +96,14 @@ async def persist_and_record_usage( cost_usd: float | str | None = None, model: str | None = None, provider: str = "open_router", - model_cost_multiplier: float = 1.0, ) -> int: - """Persist token usage to session and record for rate limiting. + """Persist token usage to session and record generation cost for rate limiting. + + Rate-limit counters are charged in microdollars against the provider's + reported cost (``cost_usd``), so cache discounts and cross-model pricing + differences are already reflected. When cost is unknown the turn is + logged but the rate-limit counter is left alone — the caller logs an + error at the point the absence is detected. Args: session: The chat session to append usage to (may be None on error). @@ -108,11 +113,11 @@ async def persist_and_record_usage( cache_read_tokens: Tokens served from prompt cache (Anthropic only). cache_creation_tokens: Tokens written to prompt cache (Anthropic only). log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]"). - cost_usd: Optional cost for logging (float from SDK, str otherwise). + cost_usd: Real generation cost for the turn (float from SDK or parsed + from OpenRouter usage.cost). ``None`` means the provider did not + report a cost and rate limiting is skipped for this turn. + model: Model identifier for cost log attribution. provider: Cost provider name (e.g. "anthropic", "open_router"). - model_cost_multiplier: Relative model cost factor for rate limiting - (1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so - more expensive models deplete the rate limit proportionally faster. Returns: The computed total_tokens (prompt + completion; cache excluded). @@ -156,37 +161,51 @@ async def persist_and_record_usage( else: logger.info( f"{log_prefix} Turn usage: prompt={prompt_tokens}, completion={completion_tokens}," - f" total={total_tokens}" + f" total={total_tokens}, cost_usd={cost_usd}" ) - if user_id: + cost_float: float | None = None + if cost_usd is not None: try: - await record_token_usage( - user_id=user_id, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - cache_read_tokens=cache_read_tokens, - cache_creation_tokens=cache_creation_tokens, - model_cost_multiplier=model_cost_multiplier, + val = float(cost_usd) + except (ValueError, TypeError): + logger.error( + "%s cost_usd is not numeric: %r — rate limit skipped", + log_prefix, + cost_usd, ) - except Exception as usage_err: - logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err) + else: + if not math.isfinite(val): + logger.error( + "%s cost_usd is non-finite: %r — rate limit skipped", + log_prefix, + val, + ) + elif val < 0: + logger.warning( + "%s cost_usd %s is negative — skipping rate-limit + cost log", + log_prefix, + val, + ) + else: + cost_float = val + + cost_microdollars = usd_to_microdollars(cost_float) + + if user_id and cost_microdollars is not None and cost_microdollars > 0: + # record_cost_usage() owns its fail-open handling for Redis/network + # errors. Don't wrap with a broad except here — unexpected accounting + # bugs should surface instead of being silently logged as warnings. + await record_cost_usage( + user_id=user_id, + cost_microdollars=cost_microdollars, + ) # Log to PlatformCostLog for admin cost dashboard. # Include entries where cost_usd is set even if token count is 0 # (e.g. fully-cached Anthropic responses where only cache tokens # accumulate a charge without incrementing total_tokens). - if user_id and (total_tokens > 0 or cost_usd is not None): - cost_float = None - if cost_usd is not None: - try: - val = float(cost_usd) - if math.isfinite(val) and val >= 0: - cost_float = val - except (ValueError, TypeError): - pass - - cost_microdollars = usd_to_microdollars(cost_float) + if user_id and (total_tokens > 0 or cost_float is not None): session_id = session.session_id if session else None if cost_float is not None: diff --git a/autogpt_platform/backend/backend/copilot/token_tracking_test.py b/autogpt_platform/backend/backend/copilot/token_tracking_test.py index 11757ce541..ff5957e1f5 100644 --- a/autogpt_platform/backend/backend/copilot/token_tracking_test.py +++ b/autogpt_platform/backend/backend/copilot/token_tracking_test.py @@ -37,7 +37,7 @@ class TestTotalTokens: async def test_returns_prompt_plus_completion(self): """total_tokens = prompt + completion (cache excluded from total).""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -63,7 +63,7 @@ class TestTotalTokens: async def test_cache_tokens_excluded_from_total(self): """Cache tokens are stored separately and not added to total_tokens.""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -81,7 +81,7 @@ class TestTotalTokens: async def test_baseline_path_no_cache(self): """Baseline (OpenRouter) path passes no cache tokens; total = prompt + completion.""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -97,7 +97,7 @@ class TestTotalTokens: async def test_sdk_path_with_cache(self): """SDK (Anthropic) path passes cache tokens; total still = prompt + completion.""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -123,7 +123,7 @@ class TestSessionPersistence: async def test_appends_usage_to_session(self): session = _make_session() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): await persist_and_record_usage( @@ -144,7 +144,7 @@ class TestSessionPersistence: async def test_appends_cache_breakdown_to_session(self): session = _make_session() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): await persist_and_record_usage( @@ -163,7 +163,7 @@ class TestSessionPersistence: async def test_multiple_turns_append_multiple_records(self): session = _make_session() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): await persist_and_record_usage( @@ -178,7 +178,7 @@ class TestSessionPersistence: async def test_none_session_does_not_raise(self): """When session is None (e.g. error path), no exception should be raised.""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -210,10 +210,11 @@ class TestSessionPersistence: class TestRateLimitRecording: @pytest.mark.asyncio - async def test_calls_record_token_usage_when_user_id_present(self): + async def test_calls_record_cost_usage_when_cost_and_user_id_present(self): + """Rate-limit counter is charged with the real provider cost (microdollars).""" mock_record = AsyncMock() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new=mock_record, ): await persist_and_record_usage( @@ -223,22 +224,35 @@ class TestRateLimitRecording: completion_tokens=50, cache_read_tokens=1000, cache_creation_tokens=200, + cost_usd=0.0123, ) mock_record.assert_awaited_once_with( user_id="user-abc", - prompt_tokens=100, - completion_tokens=50, - cache_read_tokens=1000, - cache_creation_tokens=200, - model_cost_multiplier=1.0, + cost_microdollars=12_300, ) + @pytest.mark.asyncio + async def test_skips_record_when_cost_is_missing(self): + """Without a provider cost we have no authoritative figure to charge.""" + mock_record = AsyncMock() + with patch( + "backend.copilot.token_tracking.record_cost_usage", + new=mock_record, + ): + await persist_and_record_usage( + session=None, + user_id="user-abc", + prompt_tokens=100, + completion_tokens=50, + ) + mock_record.assert_not_awaited() + @pytest.mark.asyncio async def test_skips_record_when_user_id_is_none(self): """Anonymous sessions should not create Redis keys.""" mock_record = AsyncMock() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new=mock_record, ): await persist_and_record_usage( @@ -246,32 +260,38 @@ class TestRateLimitRecording: user_id=None, prompt_tokens=100, completion_tokens=50, + cost_usd=0.001, ) mock_record.assert_not_awaited() @pytest.mark.asyncio - async def test_record_failure_does_not_raise(self): - """A Redis error in record_token_usage should be swallowed (fail-open).""" - mock_record = AsyncMock(side_effect=ConnectionError("Redis down")) + async def test_record_usage_bubbles_unexpected_error(self): + """Unexpected errors from record_cost_usage must propagate. + + record_cost_usage() owns its own (RedisError, ConnectionError, OSError) + fail-open handling. Anything else is a real accounting bug and + should not be silently swallowed at this layer. + """ + mock_record = AsyncMock(side_effect=RuntimeError("boom")) with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new=mock_record, ): - # Should not raise - total = await persist_and_record_usage( - session=None, - user_id="user-xyz", - prompt_tokens=100, - completion_tokens=50, - ) - assert total == 150 + with pytest.raises(RuntimeError, match="boom"): + await persist_and_record_usage( + session=None, + user_id="user-xyz", + prompt_tokens=100, + completion_tokens=50, + cost_usd=0.002, + ) @pytest.mark.asyncio - async def test_skips_record_when_zero_tokens(self): - """Returns 0 before calling record_token_usage when tokens are zero.""" + async def test_skips_record_when_zero_tokens_and_no_cost(self): + """Returns 0 before calling record_cost_usage when there is nothing to record.""" mock_record = AsyncMock() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new=mock_record, ): await persist_and_record_usage( @@ -295,7 +315,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -336,7 +356,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -369,7 +389,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -394,7 +414,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -423,7 +443,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -452,7 +472,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -479,7 +499,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -509,7 +529,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -545,7 +565,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( diff --git a/autogpt_platform/backend/backend/util/feature_flag.py b/autogpt_platform/backend/backend/util/feature_flag.py index c341666cdb..1e29ff4102 100644 --- a/autogpt_platform/backend/backend/util/feature_flag.py +++ b/autogpt_platform/backend/backend/util/feature_flag.py @@ -42,8 +42,8 @@ class Flag(str, Enum): CHAT = "chat" CHAT_MODE_OPTION = "chat-mode-option" COPILOT_SDK = "copilot-sdk" - COPILOT_DAILY_TOKEN_LIMIT = "copilot-daily-token-limit" - COPILOT_WEEKLY_TOKEN_LIMIT = "copilot-weekly-token-limit" + COPILOT_DAILY_COST_LIMIT = "copilot-daily-cost-limit-microdollars" + COPILOT_WEEKLY_COST_LIMIT = "copilot-weekly-cost-limit-microdollars" STRIPE_PRICE_PRO = "stripe-price-id-pro" STRIPE_PRICE_BUSINESS = "stripe-price-id-business" GRAPHITI_MEMORY = "graphiti-memory" diff --git a/autogpt_platform/backend/snapshots/get_rate_limit b/autogpt_platform/backend/snapshots/get_rate_limit index 5bae448ba2..3ac1b94222 100644 --- a/autogpt_platform/backend/snapshots/get_rate_limit +++ b/autogpt_platform/backend/snapshots/get_rate_limit @@ -1,9 +1,9 @@ { - "daily_token_limit": 2500000, - "daily_tokens_used": 500000, + "daily_cost_limit_microdollars": 2500000, + "daily_cost_used_microdollars": 500000, "tier": "FREE", "user_email": "target@example.com", "user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c", - "weekly_token_limit": 12500000, - "weekly_tokens_used": 3000000 + "weekly_cost_limit_microdollars": 12500000, + "weekly_cost_used_microdollars": 3000000 } diff --git a/autogpt_platform/backend/snapshots/reset_user_usage_daily_and_weekly b/autogpt_platform/backend/snapshots/reset_user_usage_daily_and_weekly index c73be30be5..b5361be34a 100644 --- a/autogpt_platform/backend/snapshots/reset_user_usage_daily_and_weekly +++ b/autogpt_platform/backend/snapshots/reset_user_usage_daily_and_weekly @@ -1,9 +1,9 @@ { - "daily_token_limit": 2500000, - "daily_tokens_used": 0, + "daily_cost_limit_microdollars": 2500000, + "daily_cost_used_microdollars": 0, "tier": "FREE", "user_email": "target@example.com", "user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c", - "weekly_token_limit": 12500000, - "weekly_tokens_used": 0 + "weekly_cost_limit_microdollars": 12500000, + "weekly_cost_used_microdollars": 0 } diff --git a/autogpt_platform/backend/snapshots/reset_user_usage_daily_only b/autogpt_platform/backend/snapshots/reset_user_usage_daily_only index 5b205a8bfb..256d8e893d 100644 --- a/autogpt_platform/backend/snapshots/reset_user_usage_daily_only +++ b/autogpt_platform/backend/snapshots/reset_user_usage_daily_only @@ -1,9 +1,9 @@ { - "daily_token_limit": 2500000, - "daily_tokens_used": 0, + "daily_cost_limit_microdollars": 2500000, + "daily_cost_used_microdollars": 0, "tier": "FREE", "user_email": "target@example.com", "user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c", - "weekly_token_limit": 12500000, - "weekly_tokens_used": 3000000 + "weekly_cost_limit_microdollars": 12500000, + "weekly_cost_used_microdollars": 3000000 } diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/components/UsageBar.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/components/UsageBar.tsx index de95cf0e47..442ebf43bc 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/components/UsageBar.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/components/UsageBar.tsx @@ -1,10 +1,6 @@ "use client"; -export function formatTokens(tokens: number): string { - if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`; - if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(0)}K`; - return tokens.toString(); -} +import { formatMicrodollarsAsUsd } from "@/app/(platform)/copilot/components/usageHelpers"; export function UsageBar({ used, limit }: { used: number; limit: number }) { if (limit === 0) { @@ -17,8 +13,8 @@ export function UsageBar({ used, limit }: { used: number; limit: number }) { return (
- {formatTokens(used)} used - {formatTokens(limit)} limit + {formatMicrodollarsAsUsd(used)} spent + {formatMicrodollarsAsUsd(limit)} limit
{ + it('renders "Unlimited" when limit is 0', () => { + render(); + expect(screen.getByText("Unlimited")).toBeDefined(); + }); + + it("renders spent + limit in USD", () => { + render(); + expect(screen.getByText("$1.50 spent")).toBeDefined(); + expect(screen.getByText("$10.00 limit")).toBeDefined(); + }); + + it("renders the computed percentage", () => { + render(); + expect(screen.getByText("50.0% used")).toBeDefined(); + }); + + it("clamps percentage at 100% when over limit", () => { + render(); + expect(screen.getByText("100.0% used")).toBeDefined(); + }); + + it("clamps percentage at 0% for negative used", () => { + render(); + expect(screen.getByText("0.0% used")).toBeDefined(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/RateLimitDisplay.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/RateLimitDisplay.tsx index b216745c35..024b819699 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/RateLimitDisplay.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/RateLimitDisplay.tsx @@ -88,8 +88,9 @@ export function RateLimitDisplay({ } const nothingToReset = resetWeekly - ? data.daily_tokens_used === 0 && data.weekly_tokens_used === 0 - : data.daily_tokens_used === 0; + ? data.daily_cost_used_microdollars === 0 && + data.weekly_cost_used_microdollars === 0 + : data.daily_cost_used_microdollars === 0; return (
@@ -133,17 +134,17 @@ export function RateLimitDisplay({
-

Daily Usage

+

Daily Spend

-

Weekly Usage

+

Weekly Spend

diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitDisplay.test.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitDisplay.test.tsx index 5425a14ff2..08b5db312b 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitDisplay.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitDisplay.test.tsx @@ -30,10 +30,10 @@ function makeData( return { user_id: "user-abc-123", user_email: "alice@example.com", - daily_token_limit: 10000, - weekly_token_limit: 50000, - daily_tokens_used: 2500, - weekly_tokens_used: 10000, + daily_cost_limit_microdollars: 10_000_000, + weekly_cost_limit_microdollars: 50_000_000, + daily_cost_used_microdollars: 2_500_000, + weekly_cost_used_microdollars: 10_000_000, tier: "FREE", ...overrides, }; @@ -113,8 +113,8 @@ describe("RateLimitDisplay", () => { it("renders daily and weekly usage sections", () => { render(); - expect(screen.getByText("Daily Usage")).toBeDefined(); - expect(screen.getByText("Weekly Usage")).toBeDefined(); + expect(screen.getByText("Daily Spend")).toBeDefined(); + expect(screen.getByText("Weekly Spend")).toBeDefined(); }); it("renders reset scope dropdown and reset button", () => { @@ -126,7 +126,7 @@ describe("RateLimitDisplay", () => { it("disables reset button when nothing to reset", () => { render( , ); @@ -137,7 +137,7 @@ describe("RateLimitDisplay", () => { it("enables reset button when there is usage to reset", () => { render( , ); @@ -174,7 +174,7 @@ describe("RateLimitDisplay", () => { render( , ); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitManager.test.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitManager.test.tsx index ab996748f1..8435e6dc6d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitManager.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitManager.test.tsx @@ -174,10 +174,10 @@ describe("RateLimitManager", () => { rateLimitData: { user_id: "user-123", user_email: "alice@example.com", - daily_token_limit: 10000, - weekly_token_limit: 50000, - daily_tokens_used: 2500, - weekly_tokens_used: 10000, + daily_cost_limit_microdollars: 10_000_000, + weekly_cost_limit_microdollars: 50_000_000, + daily_cost_used_microdollars: 2_500_000, + weekly_cost_used_microdollars: 10_000_000, tier: "FREE", }, }); @@ -197,10 +197,10 @@ describe("RateLimitManager", () => { rateLimitData: { user_id: "user-123", user_email: "alice@example.com", - daily_token_limit: 10000, - weekly_token_limit: 50000, - daily_tokens_used: 2500, - weekly_tokens_used: 10000, + daily_cost_limit_microdollars: 10_000_000, + weekly_cost_limit_microdollars: 50_000_000, + daily_cost_used_microdollars: 2_500_000, + weekly_cost_used_microdollars: 10_000_000, tier: "FREE", }, }); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/useRateLimitManager.test.ts b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/useRateLimitManager.test.ts index d09a74b507..523af7514b 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/useRateLimitManager.test.ts +++ b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/useRateLimitManager.test.ts @@ -28,10 +28,10 @@ function makeRateLimitResponse(overrides = {}) { return { user_id: "user-123", user_email: "alice@example.com", - daily_token_limit: 10000, - weekly_token_limit: 50000, - daily_tokens_used: 2500, - weekly_tokens_used: 10000, + daily_cost_limit_microdollars: 10_000_000, + weekly_cost_limit_microdollars: 50_000_000, + daily_cost_used_microdollars: 2_500_000, + weekly_cost_used_microdollars: 10_000_000, tier: "FREE", ...overrides, }; @@ -229,8 +229,12 @@ describe("useRateLimitManager", () => { }); it("handleReset calls reset endpoint and updates data", async () => { - const initial = makeRateLimitResponse({ daily_tokens_used: 5000 }); - const after = makeRateLimitResponse({ daily_tokens_used: 0 }); + const initial = makeRateLimitResponse({ + daily_cost_used_microdollars: 5_000_000, + }); + const after = makeRateLimitResponse({ + daily_cost_used_microdollars: 0, + }); mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data: initial }); mockPostV2ResetUserRateLimitUsage.mockResolvedValue({ status: 200, @@ -338,7 +342,9 @@ describe("useRateLimitManager", () => { }); it("handleReset throws when endpoint returns non-200 status", async () => { - const initial = makeRateLimitResponse({ daily_tokens_used: 5000 }); + const initial = makeRateLimitResponse({ + daily_cost_used_microdollars: 5_000_000, + }); mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data: initial }); mockPostV2ResetUserRateLimitUsage.mockResolvedValue({ status: 500 }); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx index 158d0b2392..c3ac603073 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx @@ -1,6 +1,6 @@ "use client"; -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat"; import { toast } from "@/components/molecules/Toast/use-toast"; import useCredits from "@/hooks/useCredits"; @@ -125,7 +125,7 @@ export function CopilotPage() { isError: usageError, } = useGetV2GetCopilotUsage({ query: { - select: (res) => res.data as CoPilotUsageStatus, + select: (res) => res.data as CoPilotUsagePublic, refetchInterval: 30000, staleTime: 10000, }, @@ -258,9 +258,7 @@ export function CopilotPage() { resetCost={resetCost ?? 0} resetMessage={rateLimitMessage ?? ""} isWeeklyExhausted={ - hasUsage && - usage.weekly.limit > 0 && - usage.weekly.used >= usage.weekly.limit + hasUsage && !!usage.weekly && usage.weekly.percent_used >= 100 } hasInsufficientCredits={hasInsufficientCredits} isBillingEnabled={isBillingEnabled} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/CopilotPage.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/CopilotPage.test.tsx index 71791b5694..bef9a2a848 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/CopilotPage.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/CopilotPage.test.tsx @@ -39,13 +39,23 @@ vi.mock("@/components/ui/sidebar", () => ({ ), })); -// Mock hooks that hit the network +// Mock hooks that hit the network. Exercise the `select` callback so its +// line counts as covered alongside the rest of the options. vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ - useGetV2GetCopilotUsage: () => ({ - data: undefined, - isSuccess: false, - isError: false, - }), + useGetV2GetCopilotUsage: (opts: { + query?: { select?: (r: { data: unknown }) => unknown }; + }) => { + const data = { + daily: null, + weekly: null, + tier: "FREE", + reset_cost: 0, + }; + if (typeof opts?.query?.select === "function") { + opts.query.select({ data }); + } + return { data: undefined, isSuccess: false, isError: false }; + }, })); vi.mock("@/hooks/useCredits", () => ({ default: () => ({ credits: null, fetchCredits: vi.fn() }), diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsageLimits.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsageLimits.tsx index 1420e626b3..711c36c26e 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsageLimits.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsageLimits.tsx @@ -1,4 +1,4 @@ -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat"; import useCredits from "@/hooks/useCredits"; import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; @@ -14,9 +14,9 @@ import { UsagePanelContent } from "./UsagePanelContent"; export { UsagePanelContent, formatResetTime } from "./UsagePanelContent"; export function UsageLimits() { - const { data: usage, isLoading } = useGetV2GetCopilotUsage({ + const { data: usage, isSuccess } = useGetV2GetCopilotUsage({ query: { - select: (res) => res.data as CoPilotUsageStatus, + select: (res) => res.data as CoPilotUsagePublic, refetchInterval: 30000, staleTime: 10000, }, @@ -28,8 +28,8 @@ export function UsageLimits() { const hasInsufficientCredits = credits !== null && resetCost != null && credits < resetCost; - if (isLoading || !usage?.daily || !usage?.weekly) return null; - if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null; + if (!isSuccess || !usage) return null; + if (!usage.daily && !usage.weekly) return null; return ( diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx index 91187816da..9a1c0d1c87 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx @@ -1,4 +1,4 @@ -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import { Button } from "@/components/atoms/Button/Button"; import Link from "next/link"; import { formatCents, formatResetTime } from "../usageHelpers"; @@ -8,22 +8,17 @@ export { formatResetTime }; function UsageBar({ label, - used, - limit, + percentUsed, resetsAt, }: { label: string; - used: number; - limit: number; + percentUsed: number; resetsAt: Date | string; }) { - if (limit <= 0) return null; - - const rawPercent = (used / limit) * 100; - const percent = Math.min(100, Math.round(rawPercent)); + const percent = Math.min(100, Math.max(0, Math.round(percentUsed))); const isHigh = percent >= 80; const percentLabel = - used > 0 && percent === 0 ? "<1% used" : `${percent}% used`; + percentUsed > 0 && percent === 0 ? "<1% used" : `${percent}% used`; return (
@@ -38,10 +33,15 @@ function UsageBar({
0 ? 1 : 0, percent)}%` }} + style={{ width: `${Math.max(percent > 0 ? 1 : 0, percent)}%` }} />
@@ -79,21 +79,19 @@ export function UsagePanelContent({ isBillingEnabled = false, onCreditChange, }: { - usage: CoPilotUsageStatus; + usage: CoPilotUsagePublic; showBillingLink?: boolean; hasInsufficientCredits?: boolean; isBillingEnabled?: boolean; onCreditChange?: () => void; }) { - const hasDailyLimit = usage.daily.limit > 0; - const hasWeeklyLimit = usage.weekly.limit > 0; - const isDailyExhausted = - hasDailyLimit && usage.daily.used >= usage.daily.limit; - const isWeeklyExhausted = - hasWeeklyLimit && usage.weekly.used >= usage.weekly.limit; + const daily = usage.daily; + const weekly = usage.weekly; + const isDailyExhausted = !!daily && daily.percent_used >= 100; + const isWeeklyExhausted = !!weekly && weekly.percent_used >= 100; const resetCost = usage.reset_cost ?? 0; - if (!hasDailyLimit && !hasWeeklyLimit) { + if (!daily && !weekly) { return (
No usage limits configured
); @@ -113,20 +111,18 @@ export function UsagePanelContent({ {tierLabel} plan )}
- {hasDailyLimit && ( + {daily && ( )} - {hasWeeklyLimit && ( + {weekly && ( )} {isDailyExhausted && diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsageLimits.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsageLimits.test.tsx index 9c7a78599f..67595dceec 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsageLimits.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsageLimits.test.tsx @@ -2,10 +2,19 @@ import { render, screen, cleanup } from "@/tests/integrations/test-utils"; import { afterEach, describe, expect, it, vi } from "vitest"; import { UsageLimits } from "../UsageLimits"; -// Mock the generated Orval hook +// Mock the generated Orval hook, exercising the `select` callback so its +// line counts as covered alongside the rest of the options. const mockUseGetV2GetCopilotUsage = vi.fn(); vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ - useGetV2GetCopilotUsage: (opts: unknown) => mockUseGetV2GetCopilotUsage(opts), + useGetV2GetCopilotUsage: (opts: { + query?: { select?: (r: { data: unknown }) => unknown }; + }) => { + const ret = mockUseGetV2GetCopilotUsage(opts) as { data?: unknown }; + if (ret?.data !== undefined && typeof opts?.query?.select === "function") { + opts.query.select({ data: ret.data }); + } + return ret; + }, })); // Mock Popover to render children directly (Radix portals don't work in happy-dom) @@ -27,22 +36,24 @@ afterEach(() => { }); function makeUsage({ - dailyUsed = 500, - dailyLimit = 10000, - weeklyUsed = 2000, - weeklyLimit = 50000, + dailyPercent = 5, + weeklyPercent = 4, tier = "FREE", }: { - dailyUsed?: number; - dailyLimit?: number; - weeklyUsed?: number; - weeklyLimit?: number; + dailyPercent?: number | null; + weeklyPercent?: number | null; tier?: string; } = {}) { - const future = new Date(Date.now() + 3600 * 1000); // 1h from now + const future = new Date(Date.now() + 3600 * 1000).toISOString(); return { - daily: { used: dailyUsed, limit: dailyLimit, resets_at: future }, - weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future }, + daily: + dailyPercent === null + ? null + : { percent_used: dailyPercent, resets_at: future }, + weekly: + weeklyPercent === null + ? null + : { percent_used: weeklyPercent, resets_at: future }, tier, }; } @@ -51,7 +62,7 @@ describe("UsageLimits", () => { it("renders nothing while loading", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ data: undefined, - isLoading: true, + isSuccess: false, }); const { container } = render(); expect(container.innerHTML).toBe(""); @@ -59,8 +70,8 @@ describe("UsageLimits", () => { it("renders nothing when no limits are configured", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ - data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }), - isLoading: false, + data: makeUsage({ dailyPercent: null, weeklyPercent: null }), + isSuccess: true, }); const { container } = render(); expect(container.innerHTML).toBe(""); @@ -69,16 +80,16 @@ describe("UsageLimits", () => { it("renders the usage button when limits exist", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ data: makeUsage(), - isLoading: false, + isSuccess: true, }); render(); expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined(); }); - it("displays daily and weekly usage percentages", () => { + it("displays daily and weekly percentage", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ - data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }), - isLoading: false, + data: makeUsage({ dailyPercent: 50, weeklyPercent: 4 }), + isSuccess: true, }); render(); @@ -88,14 +99,10 @@ describe("UsageLimits", () => { expect(screen.getByText("Usage limits")).toBeDefined(); }); - it("shows only weekly bar when daily limit is 0", () => { + it("shows only weekly bar when daily is null", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ - data: makeUsage({ - dailyLimit: 0, - weeklyUsed: 25000, - weeklyLimit: 50000, - }), - isLoading: false, + data: makeUsage({ dailyPercent: null, weeklyPercent: 50 }), + isSuccess: true, }); render(); @@ -103,20 +110,22 @@ describe("UsageLimits", () => { expect(screen.queryByText("Today")).toBeNull(); }); - it("caps percentage at 100% when over limit", () => { + it("caps bar width at 100% when over limit", () => { + // 150% exercises the clamp — 100% exactly is merely exhausted, not over. mockUseGetV2GetCopilotUsage.mockReturnValue({ - data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }), - isLoading: false, + data: makeUsage({ dailyPercent: 150 }), + isSuccess: true, }); render(); - expect(screen.getByText("100% used")).toBeDefined(); + const dailyBar = screen.getByRole("progressbar", { name: /today usage/i }); + expect(dailyBar.getAttribute("aria-valuenow")).toBe("100"); }); it("displays the user tier label", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ data: makeUsage({ tier: "PRO" }), - isLoading: false, + isSuccess: true, }); render(); @@ -126,7 +135,7 @@ describe("UsageLimits", () => { it("shows learn more link to credits page", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ data: makeUsage(), - isLoading: false, + isSuccess: true, }); render(); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsagePanelContentRender.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsagePanelContentRender.test.tsx index 9230663381..db2d4241a8 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsagePanelContentRender.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsagePanelContentRender.test.tsx @@ -6,7 +6,7 @@ import { } from "@/tests/integrations/test-utils"; import { afterEach, describe, expect, it, vi } from "vitest"; import { UsagePanelContent } from "../UsagePanelContent"; -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; const mockResetUsage = vi.fn(); vi.mock("../../../hooks/useResetRateLimit", () => ({ @@ -20,36 +20,38 @@ afterEach(() => { function makeUsage( overrides: Partial<{ - dailyUsed: number; - dailyLimit: number; - weeklyUsed: number; - weeklyLimit: number; + dailyPercent: number | null; + weeklyPercent: number | null; tier: string; resetCost: number; }> = {}, -): CoPilotUsageStatus { +): CoPilotUsagePublic { const { - dailyUsed = 500, - dailyLimit = 10000, - weeklyUsed = 2000, - weeklyLimit = 50000, + dailyPercent = 5, + weeklyPercent = 4, tier = "FREE", resetCost = 100, } = overrides; - const future = new Date(Date.now() + 3600 * 1000); + const future = new Date(Date.now() + 3600 * 1000).toISOString(); return { - daily: { used: dailyUsed, limit: dailyLimit, resets_at: future }, - weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future }, + daily: + dailyPercent === null + ? null + : { percent_used: dailyPercent, resets_at: future }, + weekly: + weeklyPercent === null + ? null + : { percent_used: weeklyPercent, resets_at: future }, tier, reset_cost: resetCost, - } as CoPilotUsageStatus; + } as CoPilotUsagePublic; } describe("UsagePanelContent", () => { - it("renders 'No usage limits configured' when both limits are zero", () => { + it("renders 'No usage limits configured' when both windows are null", () => { render( , ); expect(screen.getByText("No usage limits configured")).toBeDefined(); @@ -58,11 +60,7 @@ describe("UsagePanelContent", () => { it("renders the reset button when daily limit is exhausted", () => { render( , ); expect(screen.getByText(/Reset daily limit/)).toBeDefined(); @@ -72,10 +70,8 @@ describe("UsagePanelContent", () => { render( , @@ -86,11 +82,7 @@ describe("UsagePanelContent", () => { it("calls resetUsage when the reset button is clicked", () => { render( , ); fireEvent.click(screen.getByText(/Reset daily limit/)); @@ -100,15 +92,21 @@ describe("UsagePanelContent", () => { it("renders 'Add credits' link when insufficient credits", () => { render( , ); expect(screen.getByText("Add credits to reset")).toBeDefined(); }); + + it("renders percent used in the usage bar", () => { + render(); + expect(screen.getByText("25% used")).toBeDefined(); + }); + + it("renders '<1% used' when usage is greater than 0 but rounds to 0", () => { + render(); + expect(screen.getByText("<1% used")).toBeDefined(); + }); }); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/__tests__/usageHelpers.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/__tests__/usageHelpers.test.ts new file mode 100644 index 0000000000..eecdb70245 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/__tests__/usageHelpers.test.ts @@ -0,0 +1,76 @@ +import { describe, expect, it } from "vitest"; +import { + formatCents, + formatMicrodollarsAsUsd, + formatResetTime, +} from "../usageHelpers"; + +describe("formatCents", () => { + it("formats whole dollars", () => { + expect(formatCents(500)).toBe("$5.00"); + }); + + it("formats zero", () => { + expect(formatCents(0)).toBe("$0.00"); + }); + + it("formats fractional cents", () => { + expect(formatCents(1999)).toBe("$19.99"); + }); +}); + +describe("formatMicrodollarsAsUsd", () => { + it("formats zero as $0.00", () => { + expect(formatMicrodollarsAsUsd(0)).toBe("$0.00"); + }); + + it("formats whole dollar amounts", () => { + expect(formatMicrodollarsAsUsd(1_500_000)).toBe("$1.50"); + }); + + it("formats amounts that round to $0.00 but are > 0 as <$0.01", () => { + expect(formatMicrodollarsAsUsd(999)).toBe("<$0.01"); + }); + + it("formats exactly one cent as $0.01", () => { + expect(formatMicrodollarsAsUsd(10_000)).toBe("$0.01"); + }); + + it("formats negative input with toFixed semantics (no special case)", () => { + // Negative should never come from the backend, but the helper is + // safe — it simply passes through `toFixed`. + expect(formatMicrodollarsAsUsd(-1_500_000)).toBe("$-1.50"); + }); + + it("formats very large values without truncating", () => { + expect(formatMicrodollarsAsUsd(1_234_567_890)).toBe("$1234.57"); + }); +}); + +describe("formatResetTime", () => { + it("returns 'now' when reset time is in the past", () => { + const now = new Date("2026-04-21T12:00:00Z"); + const past = new Date("2026-04-21T11:59:00Z"); + expect(formatResetTime(past, now)).toBe("now"); + }); + + it("renders sub-hour resets as minutes", () => { + const now = new Date("2026-04-21T12:00:00Z"); + const future = new Date("2026-04-21T12:15:00Z"); + expect(formatResetTime(future, now)).toBe("in 15m"); + }); + + it("renders same-day resets as 'Xh Ym'", () => { + const now = new Date("2026-04-21T12:00:00Z"); + const future = new Date("2026-04-21T14:30:00Z"); + expect(formatResetTime(future, now)).toBe("in 2h 30m"); + }); + + it("renders future-day resets as a localized date string", () => { + const now = new Date("2026-04-21T12:00:00Z"); + const future = new Date("2026-04-24T12:00:00Z"); + // Not asserting exact format (localized), just that it's not the + // minute/hour form. + expect(formatResetTime(future, now)).not.toMatch(/^in \d/); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts index 599442075f..f25df85e9b 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts @@ -2,6 +2,12 @@ export function formatCents(cents: number): string { return `$${(cents / 100).toFixed(2)}`; } +export function formatMicrodollarsAsUsd(microdollars: number): string { + const dollars = microdollars / 1_000_000; + if (microdollars > 0 && dollars < 0.01) return "<$0.01"; + return `$${dollars.toFixed(2)}`; +} + export function formatResetTime( resetsAt: Date | string, now: Date = new Date(), diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx index 939ec5403f..fc6e26424d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx @@ -1,6 +1,6 @@ "use client"; -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat"; import { @@ -42,9 +42,9 @@ export function BriefingTabContent({ activeTab, agents }: Props) { } function UsageSection() { - const { data: usage } = useGetV2GetCopilotUsage({ + const { data: usage, isSuccess } = useGetV2GetCopilotUsage({ query: { - select: (res) => res.data as CoPilotUsageStatus, + select: (res) => res.data as CoPilotUsagePublic, refetchInterval: 30000, staleTime: 10000, }, @@ -56,7 +56,8 @@ function UsageSection() { const hasInsufficientCredits = credits !== null && resetCost != null && credits < resetCost; - if (!usage?.daily || !usage?.weekly) return null; + if (!isSuccess || !usage) return null; + if (!usage.daily && !usage.weekly) return null; return (
@@ -80,19 +81,17 @@ function UsageSection() { )}
- {usage.daily.limit > 0 && ( + {usage.daily && ( )} - {usage.weekly.limit > 0 && ( + {usage.weekly && ( )} @@ -244,14 +243,12 @@ function UsageFooter({ hasInsufficientCredits, onCreditChange, }: { - usage: CoPilotUsageStatus; + usage: CoPilotUsagePublic; hasInsufficientCredits: boolean; onCreditChange?: () => void; }) { - const isDailyExhausted = - usage.daily.limit > 0 && usage.daily.used >= usage.daily.limit; - const isWeeklyExhausted = - usage.weekly.limit > 0 && usage.weekly.used >= usage.weekly.limit; + const isDailyExhausted = !!usage.daily && usage.daily.percent_used >= 100; + const isWeeklyExhausted = !!usage.weekly && usage.weekly.percent_used >= 100; const resetCost = usage.reset_cost ?? 0; const { resetUsage, isPending } = useResetRateLimit({ onCreditChange }); @@ -294,22 +291,17 @@ function UsageFooter({ function UsageMeter({ label, - used, - limit, + percentUsed, resetsAt, }: { label: string; - used: number; - limit: number; + percentUsed: number; resetsAt: Date | string; }) { - if (limit <= 0) return null; - - const rawPercent = (used / limit) * 100; - const percent = Math.min(100, Math.round(rawPercent)); + const percent = Math.min(100, Math.max(0, Math.round(percentUsed))); const isHigh = percent >= 80; const percentLabel = - used > 0 && percent === 0 ? "<1% used" : `${percent}% used`; + percentUsed > 0 && percent === 0 ? "<1% used" : `${percent}% used`; return (
@@ -323,20 +315,20 @@ function UsageMeter({
0 ? 1 : 0, percent)}%` }} + style={{ width: `${Math.max(percent > 0 ? 1 : 0, percent)}%` }} />
-
- - {used.toLocaleString()} / {limit.toLocaleString()} - - - Resets {formatResetTime(resetsAt)} - -
+ + Resets {formatResetTime(resetsAt)} +
); } diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/__tests__/BriefingTabContent.test.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/__tests__/BriefingTabContent.test.tsx new file mode 100644 index 0000000000..5dbb3bab17 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/__tests__/BriefingTabContent.test.tsx @@ -0,0 +1,212 @@ +import { render, screen, cleanup } from "@/tests/integrations/test-utils"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { BriefingTabContent } from "../BriefingTabContent"; + +const mockUseGetV2GetCopilotUsage = vi.fn(); +vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ + useGetV2GetCopilotUsage: (opts: { + query?: { select?: (r: { data: unknown }) => unknown }; + }) => { + const ret = mockUseGetV2GetCopilotUsage(opts) as { data?: unknown }; + // Exercise the `select` callback so its line counts as covered. + if (ret?.data !== undefined && typeof opts?.query?.select === "function") { + opts.query.select({ data: ret.data }); + } + return ret; + }, +})); + +const mockUseGetFlag = vi.fn(); +vi.mock("@/services/feature-flags/use-get-flag", async () => { + const actual = await vi.importActual< + typeof import("@/services/feature-flags/use-get-flag") + >("@/services/feature-flags/use-get-flag"); + return { + ...actual, + useGetFlag: (flag: unknown) => mockUseGetFlag(flag), + }; +}); + +const mockUseCredits = vi.fn(); +vi.mock("@/hooks/useCredits", () => ({ + default: (opts: unknown) => mockUseCredits(opts), +})); + +const mockResetUsage = vi.fn(); +vi.mock("@/app/(platform)/copilot/hooks/useResetRateLimit", () => ({ + useResetRateLimit: () => ({ + resetUsage: mockResetUsage, + isPending: false, + }), +})); + +afterEach(() => { + cleanup(); + mockUseGetV2GetCopilotUsage.mockReset(); + mockUseGetFlag.mockReset(); + mockUseCredits.mockReset(); + mockResetUsage.mockReset(); +}); + +function makeUsage({ + dailyPercent = 5, + weeklyPercent = 4, + tier = "FREE", + resetCost = 500, +}: { + dailyPercent?: number | null; + weeklyPercent?: number | null; + tier?: string; + resetCost?: number; +} = {}) { + const future = new Date(Date.now() + 3600 * 1000).toISOString(); + return { + daily: + dailyPercent === null + ? null + : { percent_used: dailyPercent, resets_at: future }, + weekly: + weeklyPercent === null + ? null + : { percent_used: weeklyPercent, resets_at: future }, + tier, + reset_cost: resetCost, + }; +} + +describe("BriefingTabContent — UsageSection", () => { + it("renders nothing when usage fetch has not succeeded", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: undefined, + isSuccess: false, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + const { container } = render( + , + ); + expect(container.innerHTML).toBe(""); + }); + + it("renders nothing when both windows are null (no limits configured)", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: null, weeklyPercent: null }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + const { container } = render( + , + ); + expect(container.innerHTML).toBe(""); + }); + + it("renders tier badge + daily+weekly meters at normal usage", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: 12, weeklyPercent: 4, tier: "PRO" }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(true); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + render(); + + expect(screen.getByText("Usage limits")).toBeDefined(); + expect(screen.getByText("Pro plan")).toBeDefined(); + expect(screen.getByText("12% used")).toBeDefined(); + expect(screen.getByText("4% used")).toBeDefined(); + expect(screen.getByText("Today")).toBeDefined(); + expect(screen.getByText("This week")).toBeDefined(); + expect(screen.getByText("Manage billing")).toBeDefined(); + }); + + it("shows reset button when daily limit is exhausted and user has credits", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: 100, weeklyPercent: 40, resetCost: 500 }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(true); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + render(); + + expect(screen.getByText(/Reset daily limit/)).toBeDefined(); + }); + + it("shows 'Add credits' CTA when daily exhausted but user lacks credits", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: 100, weeklyPercent: 40, resetCost: 500 }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(true); + mockUseCredits.mockReturnValue({ credits: 10, fetchCredits: vi.fn() }); + render(); + + expect(screen.getByText("Add credits to reset")).toBeDefined(); + expect(screen.queryByText(/Reset daily limit/)).toBeNull(); + }); + + it("hides reset CTAs when the weekly limit is also exhausted", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ + dailyPercent: 100, + weeklyPercent: 100, + resetCost: 500, + }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(true); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + render(); + + expect(screen.queryByText(/Reset daily limit/)).toBeNull(); + expect(screen.queryByText("Add credits to reset")).toBeNull(); + }); + + it("renders <1% used when percent is >0 but rounds to 0", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: 0.4, weeklyPercent: 0 }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + render(); + + expect(screen.getByText("<1% used")).toBeDefined(); + }); + + it("dispatches to ExecutionListSection for running/attention/completed tabs", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: undefined, + isSuccess: false, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + + for (const tab of ["running", "attention", "completed"] as const) { + const { unmount } = render( + , + ); + // Empty list -> EmptyMessage renders for each of the execution tabs. + expect( + screen.getByText(/No agents|No recently completed/i), + ).toBeDefined(); + unmount(); + } + }); + + it("dispatches to AgentListSection for listening/scheduled/idle tabs", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: undefined, + isSuccess: false, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + + for (const tab of ["listening", "scheduled", "idle"] as const) { + const { unmount } = render( + , + ); + expect(screen.getByText(/No/i)).toBeDefined(); + unmount(); + } + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/page.tsx b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/page.tsx index fb565c048b..f6f9398721 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/page.tsx @@ -13,7 +13,7 @@ import { RefundModal } from "./RefundModal"; import { SubscriptionTierSection } from "./components/SubscriptionTierSection/SubscriptionTierSection"; import { CreditTransaction } from "@/lib/autogpt-server-api"; import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits"; -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat"; import { @@ -27,16 +27,16 @@ import { function CoPilotUsageSection() { const router = useRouter(); - const { data: usage, isLoading } = useGetV2GetCopilotUsage({ + const { data: usage, isSuccess } = useGetV2GetCopilotUsage({ query: { - select: (res) => res.data as CoPilotUsageStatus, + select: (res) => res.data as CoPilotUsagePublic, refetchInterval: 30000, staleTime: 10000, }, }); - if (isLoading || !usage?.daily || !usage?.weekly) return null; - if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null; + if (!isSuccess || !usage) return null; + if (!usage.daily && !usage.weekly) return null; return (
diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index f20f34a805..9103d6f475 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1793,7 +1793,7 @@ } }, "429": { - "description": "Token rate-limit or call-frequency cap exceeded" + "description": "Cost rate-limit or call-frequency cap exceeded" } } } @@ -1879,14 +1879,14 @@ "get": { "tags": ["v2", "chat", "chat"], "summary": "Get Copilot Usage", - "description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.\nGlobal defaults sourced from LaunchDarkly (falling back to config).\nIncludes the user's rate-limit tier.", + "description": "Get CoPilot usage status for the authenticated user.\n\nReturns the percentage of the daily/weekly allowance used — not the\nraw spend or cap — so clients cannot derive per-turn cost or platform\nmargins. Global defaults sourced from LaunchDarkly (falling back to\nconfig). Includes the user's rate-limit tier.", "operationId": "getV2GetCopilotUsage", "responses": { "200": { "description": "Successful Response", "content": { "application/json": { - "schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" } + "schema": { "$ref": "#/components/schemas/CoPilotUsagePublic" } } } }, @@ -1901,7 +1901,7 @@ "post": { "tags": ["v2", "chat", "chat"], "summary": "Reset Copilot Usage", - "description": "Reset the daily CoPilot rate limit by spending credits.\n\nAllows users who have hit their daily token limit to spend credits\nto reset their daily usage counter and continue working.\nReturns 400 if the feature is disabled or the user is not over the limit.\nReturns 402 if the user has insufficient credits.", + "description": "Reset the daily CoPilot rate limit by spending credits.\n\nAllows users who have hit their daily cost limit to spend credits\nto reset their daily usage counter and continue working.\nReturns 400 if the feature is disabled or the user is not over the limit.\nReturns 402 if the user has insufficient credits.", "operationId": "postV2ResetCopilotUsage", "responses": { "200": { @@ -9211,10 +9211,22 @@ "title": "ClarifyingQuestion", "description": "A question that needs user clarification." }, - "CoPilotUsageStatus": { + "CoPilotUsagePublic": { "properties": { - "daily": { "$ref": "#/components/schemas/UsageWindow" }, - "weekly": { "$ref": "#/components/schemas/UsageWindow" }, + "daily": { + "anyOf": [ + { "$ref": "#/components/schemas/UsageWindowPublic" }, + { "type": "null" } + ], + "description": "Null when no daily cap is configured (unlimited)." + }, + "weekly": { + "anyOf": [ + { "$ref": "#/components/schemas/UsageWindowPublic" }, + { "type": "null" } + ], + "description": "Null when no weekly cap is configured (unlimited)." + }, "tier": { "$ref": "#/components/schemas/SubscriptionTier", "default": "FREE" @@ -9227,9 +9239,8 @@ } }, "type": "object", - "required": ["daily", "weekly"], - "title": "CoPilotUsageStatus", - "description": "Current usage status for a user across all windows." + "title": "CoPilotUsagePublic", + "description": "Current usage status for a user — public (client-safe) shape." }, "ContentType": { "type": "string", @@ -12997,8 +13008,8 @@ "description": "Credit balance after charge (in cents)" }, "usage": { - "$ref": "#/components/schemas/CoPilotUsageStatus", - "description": "Updated usage status after reset" + "$ref": "#/components/schemas/CoPilotUsagePublic", + "description": "Updated usage status after reset (percentages only)" } }, "type": "object", @@ -14259,7 +14270,7 @@ "type": "string", "enum": ["FREE", "PRO", "BUSINESS", "ENTERPRISE"], "title": "SubscriptionTier", - "description": "Subscription tiers with increasing token allowances.\n\nMirrors the ``SubscriptionTier`` enum in ``schema.prisma``.\nOnce ``prisma generate`` is run, this can be replaced with::\n\n from prisma.enums import SubscriptionTier" + "description": "Subscription tiers with increasing cost allowances.\n\nMirrors the ``SubscriptionTier`` enum in ``schema.prisma``.\nOnce ``prisma generate`` is run, this can be replaced with::\n\n from prisma.enums import SubscriptionTier" }, "SubscriptionTierRequest": { "properties": { @@ -15886,13 +15897,14 @@ "required": ["timezone"], "title": "UpdateTimezoneRequest" }, - "UsageWindow": { + "UsageWindowPublic": { "properties": { - "used": { "type": "integer", "title": "Used" }, - "limit": { - "type": "integer", - "title": "Limit", - "description": "Maximum tokens allowed in this window. 0 means unlimited." + "percent_used": { + "type": "number", + "maximum": 100.0, + "minimum": 0.0, + "title": "Percent Used", + "description": "Percentage of the window's allowance used (0-100). Clamped at 100 when over the cap." }, "resets_at": { "type": "string", @@ -15901,9 +15913,9 @@ } }, "type": "object", - "required": ["used", "limit", "resets_at"], - "title": "UsageWindow", - "description": "Usage within a single time window." + "required": ["percent_used", "resets_at"], + "title": "UsageWindowPublic", + "description": "Public view of a usage window — only the percentage and reset time.\n\nHides the raw spend and the cap so clients cannot derive per-turn cost\nor reverse-engineer platform margins. ``percent_used`` is capped at 100." }, "UserCostSummary": { "properties": { @@ -16144,31 +16156,31 @@ "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "User Email" }, - "daily_token_limit": { + "daily_cost_limit_microdollars": { "type": "integer", - "title": "Daily Token Limit" + "title": "Daily Cost Limit Microdollars" }, - "weekly_token_limit": { + "weekly_cost_limit_microdollars": { "type": "integer", - "title": "Weekly Token Limit" + "title": "Weekly Cost Limit Microdollars" }, - "daily_tokens_used": { + "daily_cost_used_microdollars": { "type": "integer", - "title": "Daily Tokens Used" + "title": "Daily Cost Used Microdollars" }, - "weekly_tokens_used": { + "weekly_cost_used_microdollars": { "type": "integer", - "title": "Weekly Tokens Used" + "title": "Weekly Cost Used Microdollars" }, "tier": { "$ref": "#/components/schemas/SubscriptionTier" } }, "type": "object", "required": [ "user_id", - "daily_token_limit", - "weekly_token_limit", - "daily_tokens_used", - "weekly_tokens_used", + "daily_cost_limit_microdollars", + "weekly_cost_limit_microdollars", + "daily_cost_used_microdollars", + "weekly_cost_used_microdollars", "tier" ], "title": "UserRateLimitResponse" From f238c153a5bb445a99d1cd71228783584db08e39 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 16:27:01 +0700 Subject: [PATCH 3/4] fix(backend/copilot): release session cluster lock on completion (#12867) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fixes a bug where a chat session gets silently stuck after the user presses Stop mid-turn. **Root cause:** the cancel endpoint marks the session `failed` after polling 5s, but the cluster lock held by the still-running task is only released by `on_run_done` when the task actually finishes. If the task hangs past the 5s poll (slow LLM call, agent-browser step, etc.), the lock lingers for up to 5 min — `stream_chat_post`'s `is_turn_in_flight` check sees the flipped meta (`failed`) and enqueues a new turn, but the run handler sees the stale lock and drops the user's message at `manager.py:379` (`reject+requeue=False`). The new SSE stream hangs until its 60s idle timeout. ### Fix Two cooperating changes: 1. **`mark_session_completed` force-releases the cluster lock** in the same transaction that flips status to `completed`/`failed`. Unconditional delete — by the time we're declaring the session dead, we don't care who the current lock holder is; the lock has to go so the next enqueued turn can acquire. This is what closes the stuck-session window. 2. **`ClusterLock.release()` is now owner-checked** (Lua CAS — `GET == token ? DEL : noop` atomically). Force-release means another pod may legitimately own the key by the time the original task's `on_run_done` eventually fires. Without the CAS, that late `release()` would wipe the successor's lock. With it, the late `release()` is a safe no-op when the owner has changed. Together: prompt release on completion (via force-delete) + safe cleanup when on_run_done catches up (via CAS). That re-syncs the API-level `is_turn_in_flight` check with the actual lock state, so the contention window disappears. No changes to the worker-level contention handler: `stream_chat_post` already queues incoming messages into the pending buffer when a turn is in flight (via `queue_pending_for_http`). With these fixes, the worker never sees contention in the common case; if it does (true multi-pod race), the pre-existing `reject+requeue=False` behaviour still applies — we'll revisit that path with its own PR if it becomes a production symptom. ### Verification - Reproduced the original stuck-session symptom locally (Stop mid-turn → send new message → backend logs `Session … already running on pod …`, user message silently lost, SSE stream idle 60s then closes). - After the fix: cancel → new message → turn starts normally (lock released by `mark_session_completed`). - `poetry run pyright` — 0 errors on edited files. - `pytest backend/copilot/stream_registry_test.py backend/executor/cluster_lock_test.py` — 33 passed (includes the successor-not-wiped test). ## Changes - `autogpt_platform/backend/backend/copilot/executor/utils.py` — extract `get_session_lock_key(session_id)` helper so the lock-key format has a single source of truth. - `autogpt_platform/backend/backend/copilot/executor/manager.py` — use the helper where the cluster lock is created. - `autogpt_platform/backend/backend/copilot/stream_registry.py` — `mark_session_completed` deletes the lock key after the atomic status swap (force-release). - `autogpt_platform/backend/backend/executor/cluster_lock.py` — `ClusterLock.release()` (sync + async) uses a Lua CAS to only delete when `GET == token`, protecting against wiping a successor after a force-release. ## Test plan - [ ] Send a message in /copilot that triggers a long turn (e.g. `run_agent`), press Stop before it finishes, then send another message. Expect: new turn starts promptly (no 5-min wait for lock TTL). - [ ] Happy path regression — send a normal message, verify turn completes and the session lock key is deleted after completion. - [ ] Successor protection — unit test `test_release_does_not_wipe_successor_lock` covers: A acquires, external DEL, B acquires, A.release() is a no-op, B's lock intact. --- .../backend/copilot/executor/manager.py | 3 +- .../backend/backend/copilot/executor/utils.py | 6 + .../backend/copilot/stream_registry.py | 11 +- .../backend/copilot/stream_registry_test.py | 114 ++++++++++++++++++ .../backend/backend/executor/cluster_lock.py | 31 ++++- .../backend/executor/cluster_lock_test.py | 27 +++++ 6 files changed, 185 insertions(+), 7 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/executor/manager.py b/autogpt_platform/backend/backend/copilot/executor/manager.py index da113ccc50..02a2913883 100644 --- a/autogpt_platform/backend/backend/copilot/executor/manager.py +++ b/autogpt_platform/backend/backend/copilot/executor/manager.py @@ -34,6 +34,7 @@ from .utils import ( CancelCoPilotEvent, CoPilotExecutionEntry, create_copilot_queue_config, + get_session_lock_key, ) logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]") @@ -366,7 +367,7 @@ class CoPilotExecutor(AppProcess): # Try to acquire cluster-wide lock cluster_lock = ClusterLock( redis=redis.get_redis(), - key=f"copilot:session:{session_id}:lock", + key=get_session_lock_key(session_id), owner_id=self.executor_id, timeout=settings.config.cluster_lock_timeout, ) diff --git a/autogpt_platform/backend/backend/copilot/executor/utils.py b/autogpt_platform/backend/backend/copilot/executor/utils.py index b96e1821a1..a2b051d82b 100644 --- a/autogpt_platform/backend/backend/copilot/executor/utils.py +++ b/autogpt_platform/backend/backend/copilot/executor/utils.py @@ -82,6 +82,12 @@ COPILOT_CANCEL_EXCHANGE = Exchange( ) COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue" + +def get_session_lock_key(session_id: str) -> str: + """Redis key for the per-session cluster lock held by the executing pod.""" + return f"copilot:session:{session_id}:lock" + + # CoPilot operations can include extended thinking and agent generation # which may take 30+ minutes to complete COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour diff --git a/autogpt_platform/backend/backend/copilot/stream_registry.py b/autogpt_platform/backend/backend/copilot/stream_registry.py index f4a26b7008..424964e075 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry.py @@ -35,7 +35,7 @@ from backend.data.redis_client import get_redis_async from backend.data.redis_helpers import hash_compare_and_set from .config import ChatConfig -from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS +from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS, get_session_lock_key from .response_model import ( ResponseType, StreamBaseResponse, @@ -851,6 +851,15 @@ async def mark_session_completed( logger.debug(f"Session {session_id} already completed/failed, skipping") return False + # Force-release the executor's cluster lock so the next enqueued turn can + # acquire it immediately. The lock holder's on_run_done will also release + # (idempotent delete); doing it here unblocks cases where the task hangs + # past the cancel timeout or a pod crash leaves the lock orphaned. + try: + await redis.delete(get_session_lock_key(session_id)) + except RedisError as e: + logger.warning(f"Failed to release cluster lock for session {session_id}: {e}") + if error_message and not skip_error_publish: try: await publish_chunk(turn_id, StreamError(errorText=error_message)) diff --git a/autogpt_platform/backend/backend/copilot/stream_registry_test.py b/autogpt_platform/backend/backend/copilot/stream_registry_test.py index 28ec199025..db26a5f524 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry_test.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry_test.py @@ -4,8 +4,10 @@ import asyncio from unittest.mock import AsyncMock, patch import pytest +from redis.exceptions import RedisError from backend.copilot import stream_registry +from backend.copilot.executor.utils import get_session_lock_key @pytest.fixture(autouse=True) @@ -221,3 +223,115 @@ async def test_stream_and_publish_consumer_break_then_aclose_releases_inner(): await wrapper.aclose() assert inner_finally_ran.is_set() + + +# --------------------------------------------------------------------------- +# mark_session_completed: the atomic meta flip to completed/failed must also +# release the per-session cluster lock, so the next enqueued turn's run +# handler can acquire it without waiting for the TTL (5 min default). +# --------------------------------------------------------------------------- + + +class _FakeRedis: + """Minimal async-Redis fake: only the calls mark_session_completed makes.""" + + def __init__(self, meta: dict[str, str]): + self._meta = dict(meta) + self.deleted_keys: list[str] = [] + self.delete = AsyncMock(side_effect=self._record_delete) + + async def _record_delete(self, *keys: str): + self.deleted_keys.extend(keys) + for k in keys: + self._meta.pop(k, None) + return len(keys) + + async def hgetall(self, _key: str): + return dict(self._meta) + + +@pytest.mark.asyncio +async def test_mark_session_completed_releases_cluster_lock_on_success(): + """CAS swap must be followed by a DELETE on the session's lock key so a + stuck-because-of-stale-lock session becomes immediately claimable.""" + fake_redis = _FakeRedis({"status": "running", "turn_id": "turn-1"}) + + with ( + patch.object( + stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis) + ), + patch.object( + stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=True) + ), + patch.object(stream_registry, "publish_chunk", new=AsyncMock()), + patch.object( + stream_registry.chat_db(), + "set_turn_duration", + new=AsyncMock(), + create=True, + ), + ): + result = await stream_registry.mark_session_completed("sess-1") + + assert result is True + assert get_session_lock_key("sess-1") in fake_redis.deleted_keys + + +@pytest.mark.asyncio +async def test_mark_session_completed_skips_lock_release_when_already_completed(): + """CAS failure = someone else completed the session first; we must not + delete their already-released lock, and we must NOT publish StreamFinish + twice (the winning caller already published it).""" + fake_redis = _FakeRedis({"status": "completed", "turn_id": "turn-1"}) + publish_mock = AsyncMock() + + with ( + patch.object( + stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis) + ), + patch.object( + stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=False) + ), + patch.object(stream_registry, "publish_chunk", new=publish_mock), + ): + result = await stream_registry.mark_session_completed("sess-1") + + assert result is False + assert get_session_lock_key("sess-1") not in fake_redis.deleted_keys + assert not any( + isinstance(call.args[1], stream_registry.StreamFinish) + for call in publish_mock.call_args_list + ), "StreamFinish must NOT be re-published on the CAS-no-op branch" + + +@pytest.mark.asyncio +async def test_mark_session_completed_survives_lock_release_redis_error(): + """A Redis hiccup during lock DELETE must not prevent the StreamFinish + publish — the client's SSE stream would otherwise hang on the stale meta + status while Redis recovers.""" + fake_redis = _FakeRedis({"status": "running", "turn_id": "turn-1"}) + fake_redis.delete = AsyncMock(side_effect=RedisError("boom")) + publish_mock = AsyncMock() + + with ( + patch.object( + stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis) + ), + patch.object( + stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=True) + ), + patch.object(stream_registry, "publish_chunk", new=publish_mock), + patch.object( + stream_registry.chat_db(), + "set_turn_duration", + new=AsyncMock(), + create=True, + ), + ): + result = await stream_registry.mark_session_completed("sess-1") + + assert result is True + assert any( + isinstance(call.args[1], stream_registry.StreamFinish) + for call in publish_mock.call_args_list + ), "StreamFinish must still be published even if lock DELETE raises" diff --git a/autogpt_platform/backend/backend/executor/cluster_lock.py b/autogpt_platform/backend/backend/executor/cluster_lock.py index 0732c3f6de..9fe8b744c4 100644 --- a/autogpt_platform/backend/backend/executor/cluster_lock.py +++ b/autogpt_platform/backend/backend/executor/cluster_lock.py @@ -4,7 +4,7 @@ import asyncio import logging import threading import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from redis import Redis @@ -12,6 +12,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Lua CAS release: only delete the key if the stored value still matches our +# owner_id. Returns 1 on delete, 0 on no-op. This makes release() safe against +# the race where an external caller (e.g. mark_session_completed's force-release) +# deletes our key and a new owner acquires it before our release() fires — without +# the CAS guard, release() would wipe the successor's valid lock. +_RELEASE_LUA = ( + "if redis.call('get', KEYS[1]) == ARGV[1] then " + "return redis.call('del', KEYS[1]) " + "else return 0 end" +) + class ClusterLock: """Simple Redis-based distributed lock for preventing duplicate execution.""" @@ -116,13 +127,18 @@ class ClusterLock: return False def release(self): - """Release the lock.""" + """Release the lock. + + Owner-checked: only deletes the Redis key if the stored value still + matches our owner_id. Prevents wiping a successor's lock when the + original key was force-released externally and re-acquired. + """ with self._refresh_lock: if self._last_refresh == 0: return try: - self.redis.delete(self.key) + self.redis.eval(_RELEASE_LUA, 1, self.key, self.owner_id) except Exception: pass @@ -237,13 +253,18 @@ class AsyncClusterLock: return False async def release(self): - """Release the lock.""" + """Release the lock. + + Owner-checked: only deletes the Redis key if the stored value still + matches our owner_id. Prevents wiping a successor's lock when the + original key was force-released externally and re-acquired. + """ async with self._refresh_lock: if self._last_refresh == 0: return try: - await self.redis.delete(self.key) + await cast(Any, self.redis.eval(_RELEASE_LUA, 1, self.key, self.owner_id)) except Exception: pass diff --git a/autogpt_platform/backend/backend/executor/cluster_lock_test.py b/autogpt_platform/backend/backend/executor/cluster_lock_test.py index c5d8965f0f..5491c51cad 100644 --- a/autogpt_platform/backend/backend/executor/cluster_lock_test.py +++ b/autogpt_platform/backend/backend/executor/cluster_lock_test.py @@ -108,6 +108,33 @@ class TestClusterLockBasic: new_lock = ClusterLock(redis_client, lock_key, new_owner_id, timeout=60) assert new_lock.try_acquire() == new_owner_id + def test_release_does_not_wipe_successor_lock(self, redis_client, lock_key): + """Releasing after external delete+reacquire must NOT delete successor. + + Race: an external caller force-deletes the lock key, a new owner + acquires it, then the original ClusterLock.release() runs. Owner-checked + release must leave the successor's key intact. + """ + owner_a = str(uuid.uuid4()) + owner_b = str(uuid.uuid4()) + + lock_a = ClusterLock(redis_client, lock_key, owner_a, timeout=60) + assert lock_a.try_acquire() == owner_a + + # External force-release (e.g. mark_session_completed). + redis_client.delete(lock_key) + + # Successor acquires the same key. + lock_b = ClusterLock(redis_client, lock_key, owner_b, timeout=60) + assert lock_b.try_acquire() == owner_b + + # Original releases — must be a no-op on Redis because value != owner_a. + lock_a.release() + + # Successor's lock is still intact. + assert redis_client.exists(lock_key) == 1 + assert redis_client.get(lock_key).decode("utf-8") == owner_b + class TestClusterLockRefresh: """Lock refresh and TTL management.""" From e17e9f13c4c6832eb6bfa869534181fe37b8fa6c Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 16:34:10 +0700 Subject: [PATCH 4/4] fix(backend/copilot): reduce SDK + baseline prompt cache waste (#12866) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Four cost-reduction changes for the copilot feature. Consolidated into one PR at user request; each commit is self-contained and bisectable. ### 1. SDK: full cross-user cache on every turn (CLI 2.1.116 bump) Previous behavior: CLI 2.1.97 crashed when `excludeDynamicSections=True` was combined with `--resume`, so the code fell back to a raw `system_prompt` string on resume, losing Claude Code's default prompt and all cache markers. Every Turn 2+ of an SDK session wrote ~33K tokens to cache instead of reading. Fix: install `@anthropic-ai/claude-code@2.1.116` in the backend Docker image and point the SDK at it via `CHAT_CLAUDE_AGENT_CLI_PATH=/usr/bin/claude`. CLI 2.1.98+ fixes the crash, so we can use the preset with `exclude_dynamic_sections=True` on every turn — Turn 1, 2, 3+ all share the same static prefix and hit the **cross-user** prompt cache. **Local dev requirement:** if `CHAT_CLAUDE_AGENT_CLI_PATH` is unset, the bundled 2.1.97 fallback will crash on `--resume`. Install the CLI globally (`npm install -g @anthropic-ai/claude-code@2.1.116`) or set the env var. ### 2. Baseline: add `cache_control` markers (commit `756b3ecd9` + follow-ups) Baseline path had zero `cache_control` across `backend/copilot/**`. Every turn was full uncached input (~18.6K tokens, ~$0.058). Two ephemeral markers — on the system message (content-blocks form) and the last tool schema — plus `anthropic-beta: prompt-caching-2024-07-31` via `extra_headers` as defense-in-depth. Helpers split into `_mark_tools_*` (precomputed once per session) and `_mark_system_*` (per-round, O(1)). Repeat hellos: ~$0.058 → ~$0.006. ### 3. Drop `get_baseline_supplement()` (commit `6e6c4d791`) `_generate_tool_documentation()` emitted ~4.3K tokens of `(tool_name, description)` pairs that exactly duplicated the tools array already in the same request. Deleted. `SHARED_TOOL_NOTES` (cross-tool workflow rules) is preserved. Baseline "hello" input: ~18.7K → ~14.4K tokens. ### 4. Langfuse "CoPilot Prompt" v26 (published under `review` label) Separate, out-of-repo change. v25 had three duplicate "Example Response" blocks + a 10-step "Internal Reasoning Process" section. v26 collapses to one example + bullet-form reasoning. Char count 20,481 → 7,075 (rough 4 chars/token → ~5,100 → ~1,770 tokens). - v26 is published with label `review` (NOT `production`); v25 remains active. - Promote via `mcp__langfuse__updatePromptLabels(name="CoPilot Prompt", version=26, newLabels=["production"])` after smoke-test. - Rollback: relabel v25 `production`. ## Test plan - [x] Unit tests for `_build_system_prompt_value` (fresh vs resumed turns emit identical preset dict) - [x] SDK compat tests pass including `test_bundled_cli_version_is_known_good_against_openrouter` - [x] `cli_openrouter_compat_test.py` passes against CLI 2.1.116 (locally verified with `CHAT_CLAUDE_AGENT_CLI_PATH=/opt/homebrew/bin/claude`) - [x] 8 new `_mark_*` unit tests + identity regression test for `_fresh_*` helpers - [x] `SHARED_TOOL_NOTES` public-constant test passes; 5 old tool-docs tests removed - [ ] **Manual cost verification (commit 1):** send two consecutive SDK turns; Turn 2 and Turn 3 should both show `cacheReadTokens` ≈ 33K (full cross-user cache hits). - [ ] **Manual cost verification (commit 2):** send two "hello" turns on baseline <5 min apart; Turn 2 reports `cacheReadTokens` ≈ 18K and cost ≈ $0.006. - [ ] **Regression sweep for commit 3:** one turn per tool family — `search_agents`, `run_agent`, `add_memory`/`forget_memory`/`search_memory`, `search_docs`, `read_workspace_file` — to verify no tool-selection regression from dropping the prose tool docs. - [ ] **Langfuse v26 smoke test:** 5-10 varied turns after relabelling to `production`; compare responses vs v25 for regression on persona, concision, capability-gap handling, credential security flows. ## Deployment notes - Production Docker image now installs CLI 2.1.116 (~20 MB added). - `CHAT_CLAUDE_AGENT_CLI_PATH=/usr/bin/claude` set in the Dockerfile; runtime can override via env. - First deploy after this merge needs a fresh image rebuild to pick up the new CLI. --- .../backend/copilot/baseline/service.py | 251 ++++++++++++-- .../copilot/baseline/service_unit_test.py | 309 +++++++++++++++++- .../backend/backend/copilot/config.py | 12 + .../backend/backend/copilot/prompting.py | 55 +--- .../backend/copilot/sdk/sdk_compat_test.py | 23 +- .../backend/backend/copilot/sdk/service.py | 46 +-- .../backend/copilot/sdk/service_test.py | 100 ++---- autogpt_platform/backend/poetry.lock | 20 +- autogpt_platform/backend/pyproject.toml | 2 +- 9 files changed, 622 insertions(+), 196 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 8a26002e25..4e495264c8 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -15,7 +15,7 @@ import re import shutil import tempfile import uuid -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator, Mapping, Sequence from dataclasses import dataclass, field from functools import partial from typing import TYPE_CHECKING, Any, cast @@ -47,7 +47,7 @@ from backend.copilot.pending_messages import ( drain_pending_messages, format_pending_as_user_message, ) -from backend.copilot.prompting import get_baseline_supplement, get_graphiti_supplement +from backend.copilot.prompting import SHARED_TOOL_NOTES, get_graphiti_supplement from backend.copilot.response_model import ( StreamBaseResponse, StreamError, @@ -168,12 +168,37 @@ def _extract_usage_cost(usage: CompletionUsage) -> float | None: def _extract_cache_creation_tokens(ptd: PromptTokensDetails) -> int: - """Read Anthropic's ``cache_creation_input_tokens`` off an OpenAI - ``PromptTokensDetails`` — it's a provider-specific extra, not in the - typed model, so we read it via ``model_extra`` rather than - ``getattr``. + """Return cache-write token count from an OpenAI-compatible + ``PromptTokensDetails``, handling provider-specific field names and + SDK-version shape differences. + + Two shapes we care about: + + - **OpenRouter** (our primary baseline provider) streams the cache-write + count as ``cache_write_tokens``. Newer ``openai-python`` versions + declare this as a typed attribute on ``PromptTokensDetails``; older + versions expose it only in ``model_extra``. Verified empirically: + cold-cache request returns ``cache_write_tokens`` > 0, warm-cache + request returns ``cached_tokens`` > 0 and ``cache_write_tokens`` = 0. + - **Direct Anthropic API** uses ``cache_creation_input_tokens`` — + never a typed attribute on the OpenAI SDK, always lives in + ``model_extra``. + + Lookup order: typed attr → ``model_extra`` (OpenRouter) → ``model_extra`` + (Anthropic-native). ``getattr`` handles both the typed-attr case + (newer SDK) and the no-such-attr case (older SDK) — we can't only use + ``model_extra`` because when the field is typed it's filtered out of + ``model_extra``, leaving us at 0 on the modern happy path. """ - return int((ptd.model_extra or {}).get("cache_creation_input_tokens") or 0) + typed_val = getattr(ptd, "cache_write_tokens", None) + if typed_val: + return int(typed_val) + extras = ptd.model_extra or {} + return int( + extras.get("cache_write_tokens") + or extras.get("cache_creation_input_tokens") + or 0 + ) async def _prepare_baseline_attachments( @@ -327,6 +352,137 @@ class _BaselineStreamState: # block only appends the *new* assistant text (avoiding duplication of # round-1 text when round-1 entries were cleared from session_messages). _flushed_assistant_text_len: int = 0 + # Memoised system-message dict with cache_control applied. The system + # prompt is static within a session, so we build it once on the first + # LLM round and reuse the same dict on subsequent rounds — avoiding + # an O(N) dict-copy of the growing ``messages`` list on every tool-call + # iteration. ``None`` means "not yet computed" (or the first message + # wasn't a system role, so no marking applies). + cached_system_message: dict[str, Any] | None = None + + +def _is_anthropic_model(model: str) -> bool: + """Return True if *model* routes to Anthropic (native or via OpenRouter). + + Cache-control markers on message content + the ``anthropic-beta`` header + are Anthropic-specific. OpenAI rejects the unknown ``cache_control`` + field with a 400 ("Extra inputs are not permitted") and Grok / other + providers behave similarly. OpenRouter strips unknown headers but + passes through ``cache_control`` on the body regardless of provider — + which would also fail when OpenRouter routes to a non-Anthropic model. + + Examples that return True: + - ``anthropic/claude-sonnet-4-6`` (OpenRouter route) + - ``claude-3-5-sonnet-20241022`` (direct Anthropic API) + - ``anthropic.claude-3-5-sonnet`` (Bedrock-style) + + False for ``openai/gpt-4o``, ``google/gemini-2.5-pro``, ``xai/grok-4`` + etc. + """ + lowered = model.lower() + return "claude" in lowered or lowered.startswith("anthropic") + + +def _fresh_ephemeral_cache_control() -> dict[str, str]: + """Return a FRESH ephemeral ``cache_control`` dict each call. + + The ``ttl`` is sourced from :attr:`ChatConfig.baseline_prompt_cache_ttl` + (default ``1h``) so the static prefix stays warm across many users' + requests in the same workspace cache. Anthropic caches are keyed + per-workspace, so every copilot user reading the same system prompt + hits the same cached entry. + + Using a shared module-level dict would let any downstream mutation + (e.g. the OpenAI SDK normalising fields in-place) poison every future + request's marker. Construction is O(1) so the safety margin is free. + """ + return {"type": "ephemeral", "ttl": config.baseline_prompt_cache_ttl} + + +def _fresh_anthropic_caching_headers() -> dict[str, str]: + """Return a FRESH ``extra_headers`` dict requesting the Anthropic + prompt-caching beta. + + Same reasoning as :func:`_fresh_ephemeral_cache_control`: never hand a + shared module-level dict to third-party SDKs. OpenRouter auto-forwards + cache_control for Anthropic routes without this header, but passing it + makes the intent unambiguous on-wire and is a no-op for non-Anthropic + providers (unknown headers are dropped). + """ + return {"anthropic-beta": "prompt-caching-2024-07-31"} + + +def _mark_tools_with_cache_control( + tools: Sequence[Mapping[str, Any]], +) -> list[dict[str, Any]]: + """Return a copy of *tools* with ``cache_control`` on the last entry. + + Marking the last tool is a cache breakpoint that covers the whole tool + schema block as a cacheable prefix segment. Extracted from + :func:`_mark_system_message_with_cache_control` so callers can precompute + the marked tool list once per session — the tool set is static within a + request and the ~43 dict-copies would otherwise run on every LLM round + in the tool-call loop. + + **Only call this for Anthropic model routes.** Non-Anthropic providers + (OpenAI, Grok, Gemini) reject the unknown ``cache_control`` field with + a 400 schema validation error. Gate via :func:`_is_anthropic_model`. + """ + cached: list[dict[str, Any]] = [dict(t) for t in tools] + if cached: + cached[-1] = { + **cached[-1], + "cache_control": _fresh_ephemeral_cache_control(), + } + return cached + + +def _build_cached_system_message( + system_message: Mapping[str, Any], +) -> dict[str, Any]: + """Return a copy of *system_message* with ``cache_control`` applied. + + Anthropic's cache uses prefix-match with up to 4 explicit breakpoints. + Combined with the last-tool marker this gives two cache segments — the + system block alone, and system+all-tools — so requests that share only + the system prefix still get a partial cache hit. + + The system message is rebuilt via spread (``{**original, ...}``) so any + unknown fields the caller set (e.g. ``name``) survive the transformation. + Non-Anthropic models silently ignore the markers. + + Returns the original dict (shallow-copied) unchanged when the content + shape is unsupported (missing / non-string / empty) — callers should + splice it into the message list as-is in that case. + """ + sys_copy = dict(system_message) + sys_content = sys_copy.get("content") + if isinstance(sys_content, str) and sys_content: + sys_copy["content"] = [ + { + "type": "text", + "text": sys_content, + "cache_control": _fresh_ephemeral_cache_control(), + } + ] + return sys_copy + + +def _mark_system_message_with_cache_control( + messages: Sequence[Mapping[str, Any]], +) -> list[dict[str, Any]]: + """Return a copy of *messages* with ``cache_control`` on the system block. + + Thin wrapper around :func:`_build_cached_system_message` that preserves + the original list shape. Prefer the memoised path in + ``_baseline_llm_caller`` (which builds the cached system dict once per + session) for hot-loop callers; this function is retained for call sites + outside the tool-call loop where per-call copying is acceptable. + """ + cached_messages: list[dict[str, Any]] = [dict(m) for m in messages] + if cached_messages and cached_messages[0].get("role") == "system": + cached_messages[0] = _build_cached_system_message(cached_messages[0]) + return cached_messages async def _baseline_llm_caller( @@ -347,28 +503,51 @@ async def _baseline_llm_caller( round_text = "" try: client = _get_openai_client() - typed_messages = cast(list[ChatCompletionMessageParam], messages) - # extra_body `usage.include=true` asks OpenRouter to embed the real - # generation cost into the final usage chunk. Without this we only get - # token counts and have no authoritative cost for rate limiting. - if tools: - typed_tools = cast(list[ChatCompletionToolParam], tools) - response = await client.chat.completions.create( - model=state.model, - messages=typed_messages, - tools=typed_tools, - stream=True, - stream_options={"include_usage": True}, - extra_body=_OPENROUTER_INCLUDE_USAGE_COST, - ) + # Cache markers are Anthropic-specific. For OpenAI/Grok/other + # providers, leaving them on would trigger a 400 ("Extra inputs + # are not permitted" on cache_control). Tools were precomputed + # in stream_chat_completion_baseline via _mark_tools_with_cache_control + # (only when the model was Anthropic), so on non-Anthropic routes + # tools ship without cache_control on the last entry too. + # + # `extra_body` `usage.include=true` asks OpenRouter to embed the real + # generation cost into the final usage chunk — required by the + # cost-based rate limiter in routes.py. Separate from the Anthropic + # caching headers, always sent. + is_anthropic = _is_anthropic_model(state.model) + if is_anthropic: + # Build the cached system dict once per session and splice it in + # on each round. The full ``messages`` list grows with every + # tool call, so copying the entire list just to mutate index 0 + # scales with conversation length (sentry flagged this); this + # splice touches only list slots, not message contents. + if ( + state.cached_system_message is None + and messages + and messages[0].get("role") == "system" + ): + state.cached_system_message = _build_cached_system_message(messages[0]) + if state.cached_system_message is not None and messages: + final_messages = [state.cached_system_message, *messages[1:]] + else: + final_messages = messages + extra_headers = _fresh_anthropic_caching_headers() else: - response = await client.chat.completions.create( - model=state.model, - messages=typed_messages, - stream=True, - stream_options={"include_usage": True}, - extra_body=_OPENROUTER_INCLUDE_USAGE_COST, - ) + final_messages = messages + extra_headers = None + typed_messages = cast(list[ChatCompletionMessageParam], final_messages) + create_kwargs: dict[str, Any] = { + "model": state.model, + "messages": typed_messages, + "stream": True, + "stream_options": {"include_usage": True}, + "extra_body": _OPENROUTER_INCLUDE_USAGE_COST, + } + if extra_headers: + create_kwargs["extra_headers"] = extra_headers + if tools: + create_kwargs["tools"] = cast(list[ChatCompletionToolParam], list(tools)) + response = await client.chat.completions.create(**create_kwargs) tool_calls_by_index: dict[int, dict[str, str]] = {} # Iterate under an inner try/finally so early exits (cancel, tool-call @@ -1170,7 +1349,7 @@ async def stream_chat_completion_baseline( graphiti_enabled = await is_enabled_for_user(user_id) graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else "" - system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement + system_prompt = base_system_prompt + SHARED_TOOL_NOTES + graphiti_supplement # Warm context: pre-load relevant facts from Graphiti on first turn. # Use the pre-drain count so pending messages drained at turn start @@ -1320,6 +1499,18 @@ async def stream_chat_completion_baseline( if permissions is not None: tools = _filter_tools_by_permissions(tools, permissions) + # Pre-mark cache_control on the last tool schema once per session. The + # tool set is static within a request, so doing this here (instead of in + # _baseline_llm_caller) avoids re-copying ~43 tool dicts on every LLM + # round of the tool-call loop. + # + # Only apply to Anthropic routes — OpenAI/Grok/other providers would + # 400 on the unknown ``cache_control`` field inside tool definitions. + if _is_anthropic_model(active_model): + tools = cast( + list[ChatCompletionToolParam], _mark_tools_with_cache_control(tools) + ) + # Propagate execution context so tool handlers can read session-level flags. set_execution_context( user_id, @@ -1707,6 +1898,8 @@ async def stream_chat_completion_baseline( prompt_tokens=billed_prompt, completion_tokens=state.turn_completion_tokens, total_tokens=billed_prompt + state.turn_completion_tokens, + cache_read_tokens=state.turn_cache_read_tokens, + cache_creation_tokens=state.turn_cache_creation_tokens, ) yield StreamFinish() diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py index e21618c367..4e70767426 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -13,7 +13,14 @@ from backend.copilot.baseline.service import ( _baseline_conversation_updater, _baseline_llm_caller, _BaselineStreamState, + _build_cached_system_message, _compress_session_messages, + _extract_cache_creation_tokens, + _fresh_anthropic_caching_headers, + _fresh_ephemeral_cache_control, + _is_anthropic_model, + _mark_system_message_with_cache_control, + _mark_tools_with_cache_control, ) from backend.copilot.model import ChatMessage from backend.copilot.transcript_builder import TranscriptBuilder @@ -605,11 +612,18 @@ def _make_usage_chunk( chunk.usage.model_extra = usage_extras if cached_tokens is not None or cache_creation_input_tokens is not None: - ptd = MagicMock() - ptd.cached_tokens = cached_tokens or 0 - ptd.model_extra = { - "cache_creation_input_tokens": cache_creation_input_tokens or 0 - } + # Build a real ``PromptTokensDetails`` so ``getattr(ptd, + # "cache_write_tokens", None)`` returns ``None`` on this SDK version + # (rather than a truthy MagicMock attribute) and the extraction + # helper's typed-attr vs model_extra fallback resolves correctly. + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": cached_tokens or 0}) + if cache_creation_input_tokens is not None: + if ptd.model_extra is None: + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_creation_input_tokens"] = cache_creation_input_tokens chunk.usage.prompt_tokens_details = ptd else: chunk.usage.prompt_tokens_details = None @@ -1209,3 +1223,288 @@ class TestMidLoopPendingFlushOrdering: assert assistant_msgs[1].tool_calls is None # Crucially: only 2 assistant messages, not 3 (no duplicate) assert len(assistant_msgs) == 2 + + +class TestApplyPromptCacheMarkers: + """Tests for _apply_prompt_cache_markers — Anthropic ephemeral + cache_control markers on baseline OpenRouter requests.""" + + def test_system_message_converted_to_content_blocks(self): + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ] + + cached_messages = _mark_system_message_with_cache_control(messages) + + assert cached_messages[0]["role"] == "system" + assert cached_messages[0]["content"] == [ + { + "type": "text", + "text": "You are helpful.", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + # User message must be untouched. + assert cached_messages[1] == {"role": "user", "content": "hello"} + + def test_system_message_preserves_unknown_fields(self): + # Future-proofing: a system message with extra keys (e.g. "name") must + # keep them after the content-blocks conversion. + messages = [ + {"role": "system", "content": "sys", "name": "developer"}, + ] + + cached_messages = _mark_system_message_with_cache_control(messages) + + assert cached_messages[0]["name"] == "developer" + assert cached_messages[0]["role"] == "system" + + def test_last_tool_gets_cache_control(self): + tools = [ + {"type": "function", "function": {"name": "a"}}, + {"type": "function", "function": {"name": "b"}}, + ] + + cached_tools = _mark_tools_with_cache_control(tools) + + assert "cache_control" not in cached_tools[0] + assert cached_tools[-1]["cache_control"] == { + "type": "ephemeral", + "ttl": "1h", + } + # Last tool's other fields preserved. + assert cached_tools[-1]["function"] == {"name": "b"} + + def test_does_not_mutate_input(self): + messages = [{"role": "system", "content": "sys"}] + tools = [{"type": "function", "function": {"name": "a"}}] + + _mark_system_message_with_cache_control(messages) + _mark_tools_with_cache_control(tools) + + assert messages == [{"role": "system", "content": "sys"}] + assert tools == [{"type": "function", "function": {"name": "a"}}] + + def test_no_system_message_safe(self): + messages = [{"role": "user", "content": "hi"}] + cached_messages = _mark_system_message_with_cache_control(messages) + assert cached_messages == messages + + def test_empty_tools_safe(self): + assert _mark_tools_with_cache_control([]) == [] + + def test_non_string_system_content_left_untouched(self): + # If the content is already a list of blocks (e.g. caller pre-marked), + # the helper must not overwrite it. + pre_marked = [ + { + "type": "text", + "text": "sys", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + messages = [{"role": "system", "content": pre_marked}] + cached_messages = _mark_system_message_with_cache_control(messages) + assert cached_messages[0]["content"] == pre_marked + + def test_is_anthropic_model_matches_claude_and_anthropic_prefix(self): + assert _is_anthropic_model("anthropic/claude-sonnet-4-6") + assert _is_anthropic_model("claude-3-5-sonnet-20241022") + assert _is_anthropic_model("anthropic.claude-3-5-sonnet-20241022-v2:0") + assert _is_anthropic_model("ANTHROPIC/Claude-Opus") # case insensitive + + def test_is_anthropic_model_rejects_other_providers(self): + assert not _is_anthropic_model("openai/gpt-4o") + assert not _is_anthropic_model("openai/gpt-5") + assert not _is_anthropic_model("google/gemini-2.5-pro") + assert not _is_anthropic_model("xai/grok-4") + assert not _is_anthropic_model("meta-llama/llama-3.3-70b-instruct") + + def test_cache_control_uses_configured_ttl(self, monkeypatch): + """TTL comes from ChatConfig.baseline_prompt_cache_ttl — defaults + to 1h so the static prefix (system + tools) stays warm across + workspace users past the 5-min default window.""" + from backend.copilot.baseline import service as bsvc + + assert bsvc.config.baseline_prompt_cache_ttl == "1h" + cc = bsvc._fresh_ephemeral_cache_control() + assert cc == {"type": "ephemeral", "ttl": "1h"} + monkeypatch.setattr(bsvc.config, "baseline_prompt_cache_ttl", "5m") + assert bsvc._fresh_ephemeral_cache_control() == { + "type": "ephemeral", + "ttl": "5m", + } + + def test_fresh_helpers_return_distinct_objects(self): + """Regression guard: the `_fresh_*` helpers must return a NEW dict + on every call. A future refactor returning a module-level constant + would silently reintroduce the shared-mutable-state bug flagged + during earlier review cycles.""" + assert _fresh_ephemeral_cache_control() is not _fresh_ephemeral_cache_control() + assert ( + _fresh_anthropic_caching_headers() is not _fresh_anthropic_caching_headers() + ) + + def test_extract_cache_creation_tokens_openrouter_typed_attr(self): + """Newer ``openai-python`` declares ``cache_write_tokens`` as a + typed attribute on ``PromptTokensDetails`` — it no longer lands in + ``model_extra``. Verified empirically against the production + openai==1.113 installed in this venv: OpenRouter streaming + response populates ``ptd.cache_write_tokens`` directly while + ``ptd.model_extra`` is ``{}``. + """ + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate( + { + "audio_tokens": 0, + "cached_tokens": 0, + "cache_write_tokens": 4432, + "video_tokens": 0, + } + ) + assert getattr(ptd, "cache_write_tokens", None) == 4432 + assert _extract_cache_creation_tokens(ptd) == 4432 + + def test_extract_cache_creation_tokens_openrouter_model_extra(self): + """Older SDKs that don't yet declare ``cache_write_tokens`` as a + typed field leave it in ``model_extra`` — the helper must still + find it there.""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + # Force the value into model_extra (simulates the old SDK shape + # where the field wasn't typed yet). + if ptd.model_extra is None: + # Pydantic v2 sometimes exposes __pydantic_extra__ as None when + # extras are disabled; initialise to a dict to mutate safely. + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_write_tokens"] = 7777 + assert _extract_cache_creation_tokens(ptd) == 7777 + + def test_extract_cache_creation_tokens_anthropic_native_field(self): + """Direct Anthropic API uses ``cache_creation_input_tokens`` — + falls through as the final path when neither + ``cache_write_tokens`` typed attr nor model_extra entry exists.""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + if ptd.model_extra is None: + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_creation_input_tokens"] = 2048 + assert _extract_cache_creation_tokens(ptd) == 2048 + + def test_extract_cache_creation_tokens_absent(self): + """Neither provider field present → 0 (non-Anthropic routes or + cache-miss responses).""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + assert _extract_cache_creation_tokens(ptd) == 0 + + def test_build_cached_system_message_applies_cache_control(self): + """The single-message helper wraps the string content in a text block + with an ephemeral cache_control marker.""" + out = _build_cached_system_message({"role": "system", "content": "hi"}) + assert out["role"] == "system" + assert out["content"] == [ + { + "type": "text", + "text": "hi", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + + def test_build_cached_system_message_preserves_extra_fields(self): + """Unknown keys (e.g. ``name``) survive the transformation.""" + out = _build_cached_system_message( + {"role": "system", "content": "sys", "name": "dev"} + ) + assert out["name"] == "dev" + assert out["role"] == "system" + + def test_build_cached_system_message_non_string_passthrough(self): + """Pre-marked list content is returned as-is (shallow-copied).""" + pre_marked = [ + { + "type": "text", + "text": "sys", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + out = _build_cached_system_message({"role": "system", "content": pre_marked}) + assert out["content"] is pre_marked + + @pytest.mark.asyncio + async def test_baseline_llm_caller_memoises_cached_system_message(self): + """The cached system dict is built once and reused across rounds. + + Guards against the perf regression where the entire (growing) + ``messages`` list was copied on every tool-call iteration just to + mark the static system prompt. + """ + state = _BaselineStreamState(model="anthropic/claude-sonnet-4") + chunk = _make_usage_chunk(prompt_tokens=10, completion_tokens=5) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + side_effect=[_make_stream_mock(chunk), _make_stream_mock(chunk)] + ) + + messages: list[dict] = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hi"}, + ] + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller(messages=messages, tools=[], state=state) + first_cached = state.cached_system_message + assert first_cached is not None + # Simulate the tool-call loop growing ``messages`` between rounds. + messages.append({"role": "assistant", "content": "ok"}) + messages.append({"role": "user", "content": "follow up"}) + await _baseline_llm_caller(messages=messages, tools=[], state=state) + + # Same dict instance reused — not rebuilt per round. + assert state.cached_system_message is first_cached + + # Second call's first message is the memoised system dict (not a new copy). + second_call_messages = mock_client.chat.completions.create.call_args_list[1][1][ + "messages" + ] + assert second_call_messages[0] is first_cached + # And the tail messages were spliced in, not re-copied. + assert second_call_messages[1] is messages[1] + assert second_call_messages[-1] is messages[-1] + + @pytest.mark.asyncio + async def test_baseline_llm_caller_skips_memoisation_for_non_anthropic(self): + """Non-Anthropic routes pass messages through unmodified — no cache + dict is built, no list splicing happens.""" + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk(prompt_tokens=10, completion_tokens=5) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + messages: list[dict] = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller(messages=messages, tools=[], state=state) + + assert state.cached_system_message is None + # The exact same list object reaches the provider (no copy needed). + call_messages = mock_client.chat.completions.create.call_args[1]["messages"] + assert call_messages is messages diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index 3277854172..1080921fd8 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -225,6 +225,18 @@ class ChatConfig(BaseSettings): "from the prefix. Set to False to fall back to passing the system " "prompt as a raw string.", ) + baseline_prompt_cache_ttl: str = Field( + default="1h", + description="TTL for the ephemeral prompt-cache markers on the baseline " + "OpenRouter path. Anthropic supports only `5m` (default, 1.25x input " + "price for the write) or `1h` (2x input price for the write). 1h is " + "strictly cheaper overall when the static prefix gets >7 reads per " + "write-window; since the system prompt + tools array is identical " + "across all users in our workspace, 1h is the default so cross-user " + "reads amortise the higher write cost. Anthropic has no longer " + "(24h, permanent) TTL option — see " + "https://platform.claude.com/docs/en/build-with-claude/prompt-caching.", + ) claude_agent_cli_path: str | None = Field( default=None, description="Optional explicit path to a Claude Code CLI binary. " diff --git a/autogpt_platform/backend/backend/copilot/prompting.py b/autogpt_platform/backend/backend/copilot/prompting.py index 2f52bd460d..399d31c1cc 100644 --- a/autogpt_platform/backend/backend/copilot/prompting.py +++ b/autogpt_platform/backend/backend/copilot/prompting.py @@ -8,10 +8,12 @@ handling the distinction between: from functools import cache -from backend.copilot.tools import TOOL_REGISTRY - -# Shared technical notes that apply to both SDK and baseline modes -_SHARED_TOOL_NOTES = """\ +# Workflow rules appended to the system prompt on every copilot turn +# (baseline appends directly; SDK appends via the storage-supplement +# template). These are cross-tool rules (file sharing, @@agptfile: refs, +# tool-discovery priority, sub-agent etiquette) that don't belong on any +# individual tool schema. +SHARED_TOOL_NOTES = """\ ### Sharing files After `write_workspace_file`, embed the `download_url` in Markdown: @@ -261,7 +263,7 @@ When a tool output contains ``, the full output is in workspace storage (NOT on the local filesystem). To access it: - Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections. - To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy. -{_SHARED_TOOL_NOTES}{extra_notes}""" +{SHARED_TOOL_NOTES}{extra_notes}""" # Pre-built supplements for common environments @@ -312,35 +314,6 @@ def _get_cloud_sandbox_supplement() -> str: ) -def _generate_tool_documentation() -> str: - """Auto-generate tool documentation from TOOL_REGISTRY. - - NOTE: This is ONLY used in baseline mode (direct OpenAI API). - SDK mode doesn't need it since Claude gets tool schemas automatically. - - This generates a complete list of available tools with their descriptions, - ensuring the documentation stays in sync with the actual tool implementations. - All workflow guidance is now embedded in individual tool descriptions. - - Only documents tools that are available in the current environment - (checked via tool.is_available property). - """ - docs = "\n## AVAILABLE TOOLS\n\n" - - # Sort tools alphabetically for consistent output - # Filter by is_available to match get_available_tools() behavior - for name in sorted(TOOL_REGISTRY.keys()): - tool = TOOL_REGISTRY[name] - if not tool.is_available: - continue - schema = tool.as_openai_tool() - desc = schema["function"].get("description", "No description available") - # Format as bullet list with tool name in code style - docs += f"- **`{name}`**: {desc}\n" - - return docs - - _USER_FOLLOW_UP_NOTE = """ # `` blocks in tool output @@ -438,17 +411,3 @@ You have access to persistent temporal memory tools that remember facts across s - group_id is handled automatically by the system — never set it yourself. - When storing, be specific about operational rules and instructions (e.g., "CC Sarah on client communications" not just "Sarah is the assistant"). """ - - -def get_baseline_supplement() -> str: - """Get the supplement for baseline mode (direct OpenAI API). - - Baseline mode INCLUDES auto-generated tool documentation because the - direct API doesn't automatically provide tool schemas to Claude. - Also includes shared technical notes (but NOT SDK-specific environment details). - - Returns: - The supplement string to append to the system prompt - """ - tool_docs = _generate_tool_documentation() - return tool_docs + _SHARED_TOOL_NOTES diff --git a/autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py b/autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py index 5d132aa94d..7cf8af3396 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py @@ -94,21 +94,23 @@ def test_agent_options_accepts_required_fields(): def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_sections(): """Verify ClaudeAgentOptions accepts the exact preset dict _build_system_prompt_value produces. - The production code always includes ``exclude_dynamic_sections=True`` in the preset - dict. This compat test mirrors that exact shape so any SDK version that starts - rejecting unknown keys will be caught here rather than at runtime. + The Turn 1 (non-resume) code path includes ``exclude_dynamic_sections=True`` in + the preset dict for cross-user caching. This compat test mirrors that exact + shape so any SDK version that starts rejecting unknown keys will be caught + here rather than at runtime. """ from claude_agent_sdk import ClaudeAgentOptions from claude_agent_sdk.types import SystemPromptPreset from .service import _build_system_prompt_value - # Call the production helper directly so this test is tied to the real - # dict shape rather than a hand-rolled copy. preset = _build_system_prompt_value("custom system prompt", cross_user_cache=True) assert isinstance( preset, dict ), "_build_system_prompt_value must return a dict when caching is on" + assert preset.get("exclude_dynamic_sections") is True, ( + "Turn 1 must strip dynamic sections to keep the prefix cacheable " "cross-user" + ) sdk_preset = cast(SystemPromptPreset, preset) opts = ClaudeAgentOptions(system_prompt=sdk_preset) @@ -116,8 +118,9 @@ def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_section def test_build_system_prompt_value_returns_plain_string_when_cross_user_cache_off(): - """When cross_user_cache=False (e.g. on --resume turns), the helper must return - a plain string so the preset+resume crash is avoided.""" + """When cross_user_cache=False (feature flag disabled globally), the + helper returns a plain string; the CLI will receive --system-prompt + (replace-mode) and skip the preset entirely.""" from .service import _build_system_prompt_value result = _build_system_prompt_value("my prompt", cross_user_cache=False) @@ -262,6 +265,12 @@ _KNOWN_GOOD_BUNDLED_CLI_VERSIONS: frozenset[str] = frozenset( "2.1.97", # claude-agent-sdk 0.1.58 -- OpenRouter-safe only with # CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 (injected by # build_sdk_env() in env.py). + "2.1.116", # claude-agent-sdk 0.1.64 -- first bundled version that + # fixes the --resume + excludeDynamicSections=True crash + # (introduced in 2.1.98), unlocking cross-user prompt + # cache reads on every resumed SDK turn. Still requires + # CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1. Verified + # OpenRouter-safe via cli_openrouter_compat_test.py. } ) diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index e4f29a2b65..8fe8aa12df 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -836,16 +836,25 @@ def _is_fallback_stderr(line: str) -> bool: def _build_system_prompt_value( system_prompt: str, + *, cross_user_cache: bool, ) -> str | SystemPromptPreset: """Build the ``system_prompt`` argument for :class:`ClaudeAgentOptions`. When *cross_user_cache* is enabled, returns a :class:`SystemPromptPreset` - dict so the Claude Code default prompt becomes a cacheable prefix shared - across all users; our custom *system_prompt* is appended after it. + with ``exclude_dynamic_sections=True`` so every turn — Turn 1 *and* + resumed turns — shares the same static prefix and hits the cross-user + prompt cache. Our custom *system_prompt* is appended after the preset. - When disabled (or if the SDK is too old to support ``SystemPromptPreset``), - the raw *system_prompt* string is returned unchanged. + Requires CLI ≥ 2.1.98 (older CLIs crash when ``excludeDynamicSections`` + is combined with ``--resume``). The SDK bundles CLI 2.1.116 at + ``claude-agent-sdk >= 0.1.64``, so the pin in ``pyproject.toml`` is + the single source of truth — no external install needed. + + When *cross_user_cache* is disabled, the raw *system_prompt* string is + returned. Note this causes the CLI to REPLACE its built-in prompt via + ``--system-prompt`` (vs ``--append-system-prompt`` for the preset), + which loses Claude Code's default prompt and its cache markers entirely. An empty *system_prompt* is accepted: the preset dict will have ``append: ""`` which the SDK treats as no custom suffix. @@ -3036,15 +3045,17 @@ async def stream_chat_completion_sdk( sid, ) - # Use SystemPromptPreset for cross-user prompt caching. - # WORKAROUND: CLI 2.1.97 (sdk 0.1.58) exits code 1 when - # excludeDynamicSections=True is in the initialize request AND - # --resume is active. Disable the preset on resumed turns. - # Turn 1 still gets the preset (no --resume). - _cross_user = config.claude_agent_cross_user_prompt_cache and not use_resume + # Use SystemPromptPreset with exclude_dynamic_sections=True on + # every turn — including resumed ones — so all turns share the + # same static prefix and hit the cross-user prompt cache. + # + # Requires CLI ≥ 2.1.98 (older CLIs crash when excludeDynamicSections + # is combined with --resume). claude-agent-sdk >= 0.1.64 bundles + # CLI 2.1.116, so the pin in pyproject.toml is sufficient — no + # external install or env-var override needed. system_prompt_value = _build_system_prompt_value( system_prompt, - cross_user_cache=_cross_user, + cross_user_cache=config.claude_agent_cross_user_prompt_cache, ) sdk_options_kwargs: dict[str, Any] = { @@ -3401,15 +3412,12 @@ async def stream_chat_completion_sdk( # fail with "Session ID already in use". sdk_options_kwargs_retry.pop("resume", None) sdk_options_kwargs_retry.pop("session_id", None) - # Recompute system_prompt for retry — ctx.use_resume may have - # changed (context reduction enabled --resume). CLI 2.1.97 - # crashes when excludeDynamicSections=True is combined with - # --resume, so disable the cross-user preset on resumed turns. - _cross_user_retry = ( - config.claude_agent_cross_user_prompt_cache and not ctx.use_resume - ) + # Recompute system_prompt for retry — the preset is safe on + # every turn (requires CLI ≥ 2.1.98, installed in the Docker + # image and configured via CHAT_CLAUDE_AGENT_CLI_PATH). sdk_options_kwargs_retry["system_prompt"] = _build_system_prompt_value( - system_prompt, cross_user_cache=_cross_user_retry + system_prompt, + cross_user_cache=config.claude_agent_cross_user_prompt_cache, ) state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs # Retry intentionally omits prior_messages (transcript+gap context) and diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_test.py index f7ebe766f6..d47f67252a 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_test.py @@ -177,70 +177,18 @@ class TestPromptSupplement: assert "## Tool notes" in local_supplement assert "## Tool notes" in e2b_supplement - def test_baseline_supplement_includes_tool_docs(self): - """Baseline mode MUST include tool documentation (direct API needs it).""" - from backend.copilot.prompting import get_baseline_supplement + def test_baseline_supplement_has_shared_notes_no_tool_list(self): + """Baseline now relies on the OpenAI tools array for schemas and only + appends SHARED_TOOL_NOTES (workflow rules not present in any schema). + The old auto-generated ``## AVAILABLE TOOLS`` list is gone — it was + ~4.3K tokens of pure duplication of the tools array.""" + from backend.copilot.prompting import SHARED_TOOL_NOTES - supplement = get_baseline_supplement() - - # MUST have tool list section - assert "## AVAILABLE TOOLS" in supplement - - # Should NOT have environment-specific notes (SDK-only) - assert "## Tool notes" not in supplement - - def test_baseline_supplement_includes_key_tools(self): - """Baseline supplement should document all essential tools.""" - from backend.copilot.prompting import get_baseline_supplement - from backend.copilot.tools import TOOL_REGISTRY - - docs = get_baseline_supplement() - - # Core agent workflow tools (always available) - assert "`create_agent`" in docs - assert "`run_agent`" in docs - assert "`find_library_agent`" in docs - assert "`edit_agent`" in docs - - # MCP integration (always available) - assert "`run_mcp_tool`" in docs - - # Folder management (always available) - assert "`create_folder`" in docs - - # Browser tools only if available (Playwright may not be installed in CI) - if ( - TOOL_REGISTRY.get("browser_navigate") - and TOOL_REGISTRY["browser_navigate"].is_available - ): - assert "`browser_navigate`" in docs - - def test_baseline_supplement_includes_workflows(self): - """Baseline supplement should include workflow guidance in tool descriptions.""" - from backend.copilot.prompting import get_baseline_supplement - - docs = get_baseline_supplement() - - # Workflows are now in individual tool descriptions (not separate sections) - # Check that key workflow concepts appear in tool descriptions - assert "agent_json" in docs or "find_block" in docs - assert "run_mcp_tool" in docs - - def test_baseline_supplement_completeness(self): - """All available tools from TOOL_REGISTRY should appear in baseline supplement.""" - from backend.copilot.prompting import get_baseline_supplement - from backend.copilot.tools import TOOL_REGISTRY - - docs = get_baseline_supplement() - - # Verify each available registered tool is documented - # (matches _generate_tool_documentation which filters by is_available) - for tool_name, tool in TOOL_REGISTRY.items(): - if not tool.is_available: - continue - assert ( - f"`{tool_name}`" in docs - ), f"Tool '{tool_name}' missing from baseline supplement" + assert "## AVAILABLE TOOLS" not in SHARED_TOOL_NOTES + # Keep the high-value workflow rules that are NOT in any tool schema. + assert "@@agptfile:" in SHARED_TOOL_NOTES + assert "Tool Discovery Priority" in SHARED_TOOL_NOTES + assert "run_sub_session" in SHARED_TOOL_NOTES def test_pause_task_scheduled_before_transcript_upload(self): """Pause is scheduled as a background task before transcript upload begins. @@ -284,21 +232,6 @@ class TestPromptSupplement: # concurrently during upload's first yield. The ordering guarantee is # that create_task is CALLED before upload is AWAITED (see source order). - def test_baseline_supplement_no_duplicate_tools(self): - """No tool should appear multiple times in baseline supplement.""" - from backend.copilot.prompting import get_baseline_supplement - from backend.copilot.tools import TOOL_REGISTRY - - docs = get_baseline_supplement() - - # Count occurrences of each available tool in the entire supplement - for tool_name, tool in TOOL_REGISTRY.items(): - if not tool.is_available: - continue - # Count how many times this tool appears as a bullet point - count = docs.count(f"- **`{tool_name}`**") - assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)" - # --------------------------------------------------------------------------- # _cleanup_sdk_tool_results — orchestration + rate-limiting @@ -700,6 +633,17 @@ class TestSystemPromptPreset: assert result["append"] == "" assert result["exclude_dynamic_sections"] is True + def test_resume_and_fresh_share_the_same_static_prefix(self): + """Every turn (fresh + --resume) must emit the same preset dict + so the cross-user cache prefix match works on all turns. This + relies on CLI ≥ 2.1.98 (installed in the Docker image); older + CLIs would crash on --resume + excludeDynamicSections=True.""" + fresh = _build_system_prompt_value("sys", cross_user_cache=True) + resumed = _build_system_prompt_value("sys", cross_user_cache=True) + assert fresh == resumed + assert isinstance(fresh, dict) + assert fresh.get("exclude_dynamic_sections") is True + def test_default_config_is_enabled(self, _clean_config_env): """The default value for claude_agent_cross_user_prompt_cache is True.""" cfg = cfg_mod.ChatConfig( diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index 03c93c286a..a9aafef96f 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "agentmail" @@ -909,18 +909,18 @@ files = [ [[package]] name = "claude-agent-sdk" -version = "0.1.58" +version = "0.1.64" description = "Python SDK for Claude Code" optional = false python-versions = ">=3.10" groups = ["main"] files = [ - {file = "claude_agent_sdk-0.1.58-py3-none-macosx_11_0_arm64.whl", hash = "sha256:69197950809754c4f06bba8261f2d99c3f9605b6cc1c13d3409d0eb82fb4ee64"}, - {file = "claude_agent_sdk-0.1.58-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:75d60883fc5e2070bccd8d9b19505fe16af8e049120c03821e9dc8c826cca434"}, - {file = "claude_agent_sdk-0.1.58-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:7bf4eb0f00ec944a7b63eb94788f120dfb0460c348a525235c7d6641805acc1d"}, - {file = "claude_agent_sdk-0.1.58-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:650d298a3d3c0dcdde4b5f1dbf52f472ff0b0ec82987b27ffa2a4e0e72928408"}, - {file = "claude_agent_sdk-0.1.58-py3-none-win_amd64.whl", hash = "sha256:2c2130a7ffe06ed4f88d56b217a5091c91c9bcb1a69cfd94d5dcf0d2946d8c55"}, - {file = "claude_agent_sdk-0.1.58.tar.gz", hash = "sha256:77bee8fd60be033cb870def46c2ab1625a512fa8a3de4ff8d766664ffb16d6a6"}, + {file = "claude_agent_sdk-0.1.64-py3-none-macosx_11_0_arm64.whl", hash = "sha256:4cf47a9e40c0a683a05afff4fac1e3d5ea7965b1e9f72a8e266c8d2efbf65904"}, + {file = "claude_agent_sdk-0.1.64-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:7fe765c6482c74bc6b0b4491ad3bddd1349c25f4cdf4483191c68ea9c1336825"}, + {file = "claude_agent_sdk-0.1.64-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:605eebf46e7590e4f878572c2743954fba3f3530dfd99e10ff3b8b41a9fee757"}, + {file = "claude_agent_sdk-0.1.64-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:bbb1373ee0b4494e2db24aa10d312d22b86895b4b8f18eb5b58f99f14d827237"}, + {file = "claude_agent_sdk-0.1.64-py3-none-win_amd64.whl", hash = "sha256:453fa251e2a4aeed580c72d4c7b2cb98fc8d8d26012798126f5cb11a9829cd71"}, + {file = "claude_agent_sdk-0.1.64.tar.gz", hash = "sha256:147e513cb45095b57c37d74b8d01dd41b5f3ec7f70e408edce43a6590159c27d"}, ] [package.dependencies] @@ -930,6 +930,8 @@ typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} [package.extras] dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"] +examples = ["asyncpg (>=0.27.0)", "boto3 (>=1.28.0)", "fakeredis (>=2.20.0)", "moto[s3] (>=5.0.0)", "redis (>=4.2.0)"] +otel = ["opentelemetry-api (>=1.20.0)"] [[package]] name = "cleo" @@ -8929,4 +8931,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "c4cc6a0a26869a167ce182b178224554135d89d8ffa4605257d17b3f495cdf59" +content-hash = "529e1acbb1213421ef617f9dab309787cf81ea5d787eeffebc1bd38a42daf976" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index ea81390d81..6e7003a65d 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -18,7 +18,7 @@ apscheduler = "^3.11.1" autogpt-libs = { path = "../autogpt_libs", develop = true } bleach = { extras = ["css"], version = "^6.2.0" } cachetools = "^5.5.0" -claude-agent-sdk = "0.1.58" # latest stable; bundled CLI 2.1.97 -- CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 env var strips the broken context-management beta. See sdk_compat_test.py. +claude-agent-sdk = "^0.1.64" # bundled CLI 2.1.116 -- 2.1.98+ fixes the --resume + excludeDynamicSections crash that used to force a per-turn 33K cache write. CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 env var strips the broken context-management beta. See sdk_compat_test.py. click = "^8.2.0" cryptography = "^46.0" discord-py = "^2.5.2"