Compare commits

..

19 Commits

Author SHA1 Message Date
Zamil Majdy
f8ca9cba85 test: update E2E screenshots for PR #12727 (final test run) 2026-04-16 20:18:37 +07:00
majdyz
d02f245c7b test: add E2E screenshots for PR #12727 billing tier tests 2026-04-14 15:31:47 +07:00
Zamil Majdy
b28c0ac072 test: add E2E screenshots for PR #12727 2026-04-10 01:22:43 +07:00
Zamil Majdy
9ec44dd109 test(backend): add route-level tests for subscription API endpoints
Tests for GET/POST /credits/subscription covering:
- GET returns current tier (PRO, FREE default when None)
- POST FREE skips Stripe when payment disabled
- POST PRO sets tier directly for beta users (payment disabled)
- POST paid tier rejects missing success_url/cancel_url with 422
- POST paid tier creates Stripe Checkout Session and returns URL
- POST FREE with payment enabled cancels active Stripe subscription
2026-04-10 00:19:06 +07:00
Zamil Majdy
bfb82b6246 fix(platform): address reviewer feedback on subscription endpoint
- Remove useCallback from changeTier (not needed per project guidelines)
- Block self-service tier changes for ENTERPRISE users (admin-managed)
- Preserve current tier on unrecognized Stripe price_id instead of
  defaulting to FREE (prevents accidental downgrades during price migration)
2026-04-10 00:08:54 +07:00
Zamil Majdy
63210770ce test(backend): add tests for get_subscription_price_id to improve coverage 2026-04-09 23:54:02 +07:00
Zamil Majdy
68b51ae2d3 test(backend): add coverage for sync_subscription_from_stripe edge cases
Tests for:
- Unknown/mismatched Stripe price_id defaults to FREE (not early return)
- None from LaunchDarkly price flags defaults to FREE
- BUSINESS tier mapping
- StripeError during cancel_stripe_subscription is logged, not raised
2026-04-09 23:52:16 +07:00
Zamil Majdy
63ff214563 fix(backend): default to FREE tier on unknown Stripe price ID in webhook sync
When sync_subscription_from_stripe encounters an unrecognized price_id
(e.g. LD flags unconfigured or price changed), it no longer returns early
leaving the user on a stale tier. Instead it defaults to FREE and logs a
warning, keeping the DB state consistent with Stripe's subscription status.

Also guard against None pro_price/biz_price from LaunchDarkly before
comparison to avoid silent mismatches.
2026-04-09 23:41:51 +07:00
Zamil Majdy
0d89f7bb33 fix(backend): handle customer.subscription.created webhook event
Add customer.subscription.created to the sync handler so user tier is
upgraded immediately when the subscription is first created (not just on
subsequent updates/deletions).
2026-04-09 23:39:16 +07:00
Zamil Majdy
4eabc48053 fix(backend): fix migration conflict with dev's SubscriptionTier migration
dev branch already creates SubscriptionTier enum and subscriptionTier column in
20260326200000_add_rate_limit_tier. Remove duplicate DDL from our migration and
only add SUBSCRIPTION to CreditTransactionType using IF NOT EXISTS guard.
2026-04-09 23:24:12 +07:00
Zamil Majdy
101504ce0b fix(platform): cancel Stripe subscription when downgrading to FREE tier
Add cancel_stripe_subscription() which lists and cancels all active Stripe
subscriptions for the customer, preventing continued billing after downgrade.
Call it from update_subscription_tier() when tier == FREE and payment is
enabled. Add two unit tests covering active and empty subscription scenarios.
2026-04-09 23:21:27 +07:00
Zamil Majdy
e73b5b3692 fix(backend): validate success_url/cancel_url for paid Stripe checkout
Add upfront 422 validation when upgrading to a paid tier without providing
redirect URLs. Also catch stripe.StripeError alongside ValueError to return
a proper 422 instead of a 500 on Stripe API errors.
2026-04-09 23:18:16 +07:00
Zamil Majdy
611a00d930 fix(backend): resolve dev merge conflict and remove credit-based subscription cost
Remove get_subscription_cost (referenced deleted flags SUBSCRIPTION_COST_PRO/BUSINESS).
Subscription pricing is now handled by Stripe. Add GRAPHITI_MEMORY flag from dev.
2026-04-09 23:14:15 +07:00
Zamil Majdy
8d31bdb2dc fix(platform): address remaining review comments on subscription billing
- Remove `# type: ignore[attr-defined]` suppressors from `set_auto_top_up`
  and `set_subscription_tier` — pyright resolves `CachedFunction.cache_delete`
  through the import boundary without the suppressor
- Add `max(0, ...)` guard to `get_subscription_cost` to prevent negative
  LaunchDarkly flag values from yielding negative costs
- Change `SubscriptionTierRequest.tier` from `str` to
  `Literal["FREE", "PRO", "BUSINESS"]` so Pydantic rejects ENTERPRISE and
  any unknown tier with a 422 at the schema layer
- Move `SubscriptionTier` and feature-flag imports from local function scope
  to module-level in v1.py (top-level imports policy)
- Fix `test_sync_subscription_from_stripe_active` mock to use a proper async
  `side_effect` function instead of calling an `AsyncMock` inline
2026-04-09 23:06:40 +07:00
Zamil Majdy
2e64f3add7 feat(frontend): redirect to Stripe checkout when upgrading subscription
POST /credits/subscription now returns {url} when Stripe checkout is needed.
Redirect user to Stripe on non-empty URL, refresh tier on empty URL (beta/FREE).
Remove credit-based tier validation; Stripe handles payment gating.
2026-04-09 22:58:58 +07:00
Zamil Majdy
4942249a60 fix(platform): resolve merge conflicts with dev branch
Merges latest dev branch changes into feat/subscription-tier-billing.
Updates credit_subscription_test.py to match new Stripe-based implementation.
2026-04-09 22:51:06 +07:00
Zamil Majdy
70d53a0926 fix(platform): address round-2 review comments on subscription billing
- Wrap ensure_subscription_paid in spend_credits with try/except (fails open like check_rate_limit)
- Invalidate get_user_by_id cache in set_auto_top_up to prevent stale auto top-up data
- Block ENTERPRISE tier self-service upgrades from POST /credits/subscription API
2026-04-09 20:19:10 +07:00
Zamil Majdy
642c72e5e5 fix(platform): address review comments on subscription billing
- Format error messages as \$X.XX/mo instead of raw cents
- Move get_feature_flag_value import to module level in credit.py
- Add explicit operation_id to subscription FastAPI routes
- Pass autoTopUpConfig as prop to SubscriptionTierSection (avoid duplicate fetch)
- Display fetch error in SubscriptionTierSection instead of silent null
- Add cache hit comment to rate_limit.py hot path
- Add tests: idempotency, free tier no-op, beta grant offset, tier upgrade validation
2026-04-09 20:14:11 +07:00
Zamil Majdy
ba7929205d feat(platform): add subscription tier billing with lazy credit deduction
- Add SubscriptionTier enum (FREE/PRO/BUSINESS/ENTERPRISE) to schema
- Add SUBSCRIPTION CreditTransactionType for monthly charges
- Lazy monthly deduction via ensure_subscription_paid() — idempotent,
  called from spend_credits() and rate-limit checks
- BetaUserCredit grant includes subscription offset so beta usage credits
  are not reduced by subscription cost
- Auto top-up enforced >= subscription cost on tier upgrade and config update
- Subscription cost configurable via LaunchDarkly (subscription-cost-pro,
  subscription-cost-business); 0 = feature off, no separate flag needed
- New endpoints: GET/POST /credits/subscription for tier management
- No proration: full month charged on upgrade, downgrade takes next cycle
- Frontend: SubscriptionTierSection component on billing page with tier
  cards, upgrade/downgrade flow, and auto top-up guard
2026-04-09 19:58:01 +07:00
45 changed files with 1223 additions and 187 deletions

View File

@@ -0,0 +1,266 @@
"""Tests for subscription tier API endpoints."""
from unittest.mock import AsyncMock, Mock
import fastapi
import fastapi.testclient
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma.enums import SubscriptionTier
from .v1 import v1_router
app = fastapi.FastAPI()
app.include_router(v1_router)
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
def setup_auth(app: fastapi.FastAPI):
def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]:
return {"sub": TEST_USER_ID, "role": "user", "email": "test@example.com"}
app.dependency_overrides[get_jwt_payload] = override_get_jwt_payload
def teardown_auth(app: fastapi.FastAPI):
app.dependency_overrides.clear()
def test_get_subscription_status_pro(
mocker: pytest_mock.MockFixture,
) -> None:
"""GET /credits/subscription returns PRO tier for a PRO user."""
setup_auth(app)
try:
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
response = client.get("/credits/subscription")
assert response.status_code == 200
data = response.json()
assert data["tier"] == "PRO"
assert "monthly_cost" in data
assert "tier_costs" in data
finally:
teardown_auth(app)
def test_get_subscription_status_defaults_to_free(
mocker: pytest_mock.MockFixture,
) -> None:
"""GET /credits/subscription when subscription_tier is None defaults to FREE."""
setup_auth(app)
try:
mock_user = Mock()
mock_user.subscription_tier = None
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
response = client.get("/credits/subscription")
assert response.status_code == 200
data = response.json()
assert data["tier"] == SubscriptionTier.FREE.value
finally:
teardown_auth(app)
def test_update_subscription_tier_free_no_payment(
mocker: pytest_mock.MockFixture,
) -> None:
"""POST /credits/subscription to FREE tier when payment disabled skips Stripe."""
setup_auth(app)
try:
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
async def mock_feature_disabled(*args, **kwargs):
return False
async def mock_set_tier(*args, **kwargs):
pass
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_disabled,
)
mocker.patch(
"backend.api.features.v1.set_subscription_tier",
side_effect=mock_set_tier,
)
response = client.post("/credits/subscription", json={"tier": "FREE"})
assert response.status_code == 200
assert response.json()["url"] == ""
finally:
teardown_auth(app)
def test_update_subscription_tier_paid_beta_user(
mocker: pytest_mock.MockFixture,
) -> None:
"""POST /credits/subscription for paid tier when payment disabled sets tier directly."""
setup_auth(app)
try:
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.FREE
async def mock_feature_disabled(*args, **kwargs):
return False
async def mock_set_tier(*args, **kwargs):
pass
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_disabled,
)
mocker.patch(
"backend.api.features.v1.set_subscription_tier",
side_effect=mock_set_tier,
)
response = client.post("/credits/subscription", json={"tier": "PRO"})
assert response.status_code == 200
assert response.json()["url"] == ""
finally:
teardown_auth(app)
def test_update_subscription_tier_paid_requires_urls(
mocker: pytest_mock.MockFixture,
) -> None:
"""POST /credits/subscription for paid tier without success/cancel URLs returns 422."""
setup_auth(app)
try:
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.FREE
async def mock_feature_enabled(*args, **kwargs):
return True
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
response = client.post("/credits/subscription", json={"tier": "PRO"})
assert response.status_code == 422
finally:
teardown_auth(app)
def test_update_subscription_tier_creates_checkout(
mocker: pytest_mock.MockFixture,
) -> None:
"""POST /credits/subscription creates Stripe Checkout Session for paid upgrade."""
setup_auth(app)
try:
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.FREE
async def mock_feature_enabled(*args, **kwargs):
return True
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
new_callable=AsyncMock,
return_value="https://checkout.stripe.com/pay/cs_test_abc",
)
response = client.post(
"/credits/subscription",
json={
"tier": "PRO",
"success_url": "https://app.example.com/success",
"cancel_url": "https://app.example.com/cancel",
},
)
assert response.status_code == 200
assert response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_abc"
finally:
teardown_auth(app)
def test_update_subscription_tier_free_with_payment_cancels_stripe(
mocker: pytest_mock.MockFixture,
) -> None:
"""Downgrading to FREE cancels active Stripe subscription when payment is enabled."""
setup_auth(app)
try:
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
async def mock_feature_enabled(*args, **kwargs):
return True
mock_cancel = mocker.patch(
"backend.api.features.v1.cancel_stripe_subscription",
new_callable=AsyncMock,
)
async def mock_set_tier(*args, **kwargs):
pass
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.set_subscription_tier",
side_effect=mock_set_tier,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
response = client.post("/credits/subscription", json={"tier": "FREE"})
assert response.status_code == 200
mock_cancel.assert_awaited_once()
finally:
teardown_auth(app)

View File

@@ -5,7 +5,7 @@ import time
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Annotated, Any, Sequence, get_args
from typing import Annotated, Any, Literal, Sequence, get_args
import pydantic
import stripe
@@ -24,6 +24,7 @@ from fastapi import (
UploadFile,
)
from fastapi.concurrency import run_in_threadpool
from prisma.enums import SubscriptionTier
from pydantic import BaseModel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
@@ -50,9 +51,13 @@ from backend.data.credit import (
RefundRequest,
TransactionHistory,
UserCredit,
cancel_stripe_subscription,
create_subscription_checkout,
get_auto_top_up,
get_user_credit_model,
set_auto_top_up,
set_subscription_tier,
sync_subscription_from_stripe,
)
from backend.data.graph import GraphSettings
from backend.data.model import CredentialsMetaInput, UserOnboarding
@@ -661,9 +666,12 @@ async def configure_user_auto_top_up(
raise HTTPException(status_code=422, detail=str(e))
raise
await set_auto_top_up(
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
)
try:
await set_auto_top_up(
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
return "Auto top-up settings updated"
@@ -679,6 +687,98 @@ async def get_user_auto_top_up(
return await get_auto_top_up(user_id)
class SubscriptionTierRequest(BaseModel):
tier: Literal["FREE", "PRO", "BUSINESS"]
success_url: str = ""
cancel_url: str = ""
class SubscriptionCheckoutResponse(BaseModel):
url: str
class SubscriptionStatusResponse(BaseModel):
tier: str
monthly_cost: int
tier_costs: dict[str, int]
@v1_router.get(
path="/credits/subscription",
summary="Get subscription tier, current cost, and all tier costs",
operation_id="getSubscriptionStatus",
tags=["credits"],
dependencies=[Security(requires_user)],
)
async def get_subscription_status(
user_id: Annotated[str, Security(get_user_id)],
) -> SubscriptionStatusResponse:
user = await get_user_by_id(user_id)
tier = user.subscription_tier or SubscriptionTier.FREE
return SubscriptionStatusResponse(
tier=tier.value,
monthly_cost=0,
tier_costs={"FREE": 0, "PRO": 0, "BUSINESS": 0, "ENTERPRISE": 0},
)
@v1_router.post(
path="/credits/subscription",
summary="Start a Stripe Checkout session to upgrade subscription tier",
operation_id="updateSubscriptionTier",
tags=["credits"],
dependencies=[Security(requires_user)],
)
async def update_subscription_tier(
request: SubscriptionTierRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> SubscriptionCheckoutResponse:
# Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
tier = SubscriptionTier(request.tier)
# ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
user = await get_user_by_id(user_id)
if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE:
raise HTTPException(
status_code=403,
detail="ENTERPRISE subscription changes must be managed by an administrator",
)
payment_enabled = await is_feature_enabled(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
if tier == SubscriptionTier.FREE:
if payment_enabled:
await cancel_stripe_subscription(user_id)
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
# Beta users (payment not enabled) → update tier directly without Stripe.
if not payment_enabled:
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
# Paid upgrade → create Stripe Checkout Session.
if not request.success_url or not request.cancel_url:
raise HTTPException(
status_code=422,
detail="success_url and cancel_url are required for paid tier upgrades",
)
try:
url = await create_subscription_checkout(
user_id=user_id,
tier=tier,
success_url=request.success_url,
cancel_url=request.cancel_url,
)
except (ValueError, stripe.StripeError) as e:
raise HTTPException(status_code=422, detail=str(e))
return SubscriptionCheckoutResponse(url=url)
@v1_router.post(
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
)
@@ -709,6 +809,13 @@ async def stripe_webhook(request: Request):
):
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
if event["type"] in (
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
):
await sync_subscription_from_stripe(event["data"]["object"])
if event["type"] == "charge.dispute.created":
await UserCredit().handle_dispute(event["data"]["object"])

View File

@@ -53,7 +53,6 @@ from backend.copilot.response_model import (
)
from backend.copilot.service import (
_build_system_prompt,
_get_anthropic_client,
_get_openai_client,
_update_title_async,
config,
@@ -84,8 +83,6 @@ from backend.util.tool_call_loop import (
)
if TYPE_CHECKING:
from langfuse.openai import AsyncOpenAI as LangfuseAsyncOpenAI
from backend.copilot.permissions import CopilotPermissions
logger = logging.getLogger(__name__)
@@ -232,23 +229,6 @@ def _resolve_baseline_model(mode: CopilotMode | None) -> str:
return config.model
def _is_anthropic_model(model: str) -> bool:
"""Return True if *model* should be routed to the Anthropic API directly."""
return model.startswith("claude-") or model.startswith("anthropic/")
def _get_baseline_client(model: str) -> "LangfuseAsyncOpenAI":
"""Return the right OpenAI-compatible client for *model*.
Anthropic models are sent directly to the Anthropic API when an
``ANTHROPIC_API_KEY`` is configured; everything else goes through
OpenRouter.
"""
if _is_anthropic_model(model) and config.anthropic_api_key:
return _get_anthropic_client()
return _get_openai_client()
# Tag pairs to strip from baseline streaming output. Different models use
# different tag names for their internal reasoning (Claude uses <thinking>,
# Gemini uses <internal_reasoning>, etc.).
@@ -379,7 +359,7 @@ async def _baseline_llm_caller(
round_text = ""
response = None # initialized before try so finally block can access it
try:
client = _get_baseline_client(state.model)
client = _get_openai_client()
typed_messages = cast(list[ChatCompletionMessageParam], messages)
if tools:
typed_tools = cast(list[ChatCompletionToolParam], tools)
@@ -749,7 +729,7 @@ async def _compress_session_messages(
result = await compress_context(
messages=messages_dict,
model=model,
client=_get_baseline_client(model),
client=_get_openai_client(),
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)

View File

@@ -828,94 +828,3 @@ class TestBaselineCostExtraction:
# response was never assigned so cost extraction must not raise
assert state.cost_usd is None
class TestGetBaselineClient:
"""Tests for _get_baseline_client routing logic."""
def test_anthropic_model_uses_anthropic_client(self):
from backend.copilot.baseline.service import _get_baseline_client
mock_anthropic = MagicMock()
mock_openai = MagicMock()
with (
patch(
"backend.copilot.baseline.service._get_anthropic_client",
return_value=mock_anthropic,
),
patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_openai,
),
patch(
"backend.copilot.baseline.service.config",
anthropic_api_key="sk-ant-test",
),
):
client = _get_baseline_client("claude-sonnet-4-20250514")
assert client is mock_anthropic
def test_openrouter_model_uses_openai_client(self):
from backend.copilot.baseline.service import _get_baseline_client
mock_anthropic = MagicMock()
mock_openai = MagicMock()
with (
patch(
"backend.copilot.baseline.service._get_anthropic_client",
return_value=mock_anthropic,
),
patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_openai,
),
patch(
"backend.copilot.baseline.service.config",
anthropic_api_key="sk-ant-test",
),
):
client = _get_baseline_client("openai/gpt-4o-mini")
assert client is mock_openai
def test_anthropic_model_without_key_falls_back_to_openrouter(self):
from backend.copilot.baseline.service import _get_baseline_client
mock_anthropic = MagicMock()
mock_openai = MagicMock()
with (
patch(
"backend.copilot.baseline.service._get_anthropic_client",
return_value=mock_anthropic,
),
patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_openai,
),
patch(
"backend.copilot.baseline.service.config",
anthropic_api_key=None,
),
):
client = _get_baseline_client("claude-sonnet-4-20250514")
assert client is mock_openai
class TestIsAnthropicModel:
"""Tests for _is_anthropic_model helper."""
def test_claude_prefix(self):
from backend.copilot.baseline.service import _is_anthropic_model
assert _is_anthropic_model("claude-sonnet-4-20250514") is True
assert _is_anthropic_model("claude-opus-4-20250514") is True
def test_anthropic_slash_prefix(self):
from backend.copilot.baseline.service import _is_anthropic_model
assert _is_anthropic_model("anthropic/claude-sonnet-4") is True
def test_non_anthropic(self):
from backend.copilot.baseline.service import _is_anthropic_model
assert _is_anthropic_model("openai/gpt-4o-mini") is False
assert _is_anthropic_model("google/gemini-2.5-flash") is False

View File

@@ -8,8 +8,6 @@ from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
ANTHROPIC_BASE_URL = "https://api.anthropic.com/v1"
# Per-request routing mode for a single chat turn.
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
# - 'extended_thinking': route to the Claude Agent SDK path with the default
@@ -24,11 +22,11 @@ class ChatConfig(BaseSettings):
# OpenAI API Configuration
model: str = Field(
default="claude-opus-4-20250514",
default="anthropic/claude-opus-4.6",
description="Default model for extended thinking mode",
)
fast_model: str = Field(
default="claude-sonnet-4-20250514",
default="anthropic/claude-sonnet-4",
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
)
title_model: str = Field(
@@ -40,10 +38,6 @@ class ChatConfig(BaseSettings):
description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
anthropic_api_key: str | None = Field(
default=None,
description="Anthropic API key for direct Anthropic API access (baseline path)",
)
base_url: str | None = Field(
default=OPENROUTER_BASE_URL,
description="Base URL for API (e.g., for OpenRouter)",
@@ -285,14 +279,6 @@ class ChatConfig(BaseSettings):
# would pair it with the OpenRouter base_url, causing auth failures.
return v
@field_validator("anthropic_api_key", mode="before")
@classmethod
def get_anthropic_api_key(cls, v):
"""Get Anthropic API key from environment if not provided."""
if not v:
v = os.getenv("ANTHROPIC_API_KEY")
return v
@field_validator("base_url", mode="before")
@classmethod
def get_base_url(cls, v):

View File

@@ -14,7 +14,6 @@ _ENV_VARS_TO_CLEAR = (
"CHAT_API_KEY",
"OPEN_ROUTER_API_KEY",
"OPENAI_API_KEY",
"ANTHROPIC_API_KEY",
"CHAT_BASE_URL",
"OPENROUTER_BASE_URL",
"OPENAI_BASE_URL",
@@ -71,38 +70,6 @@ class TestOpenrouterActive:
assert cfg.openrouter_active is False
class TestAnthropicApiKey:
"""Tests for the anthropic_api_key field and validator."""
def test_reads_from_env(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test")
cfg = ChatConfig()
assert cfg.anthropic_api_key == "sk-ant-test"
def test_none_when_not_set(self):
cfg = ChatConfig()
assert cfg.anthropic_api_key is None
def test_explicit_value_overrides_env(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "from-env")
cfg = ChatConfig(anthropic_api_key="explicit")
assert cfg.anthropic_api_key == "explicit"
class TestDefaultModelNames:
"""Default model names should use direct Anthropic IDs (not OpenRouter format)."""
def test_default_model_is_direct_anthropic(self):
cfg = ChatConfig()
assert "/" not in cfg.model
assert cfg.model.startswith("claude-")
def test_fast_model_is_direct_anthropic(self):
cfg = ChatConfig()
assert "/" not in cfg.fast_model
assert cfg.fast_model.startswith("claude-")
class TestE2BActive:
"""Tests for the e2b_active property — single source of truth for E2B usage."""

View File

@@ -21,7 +21,7 @@ from backend.data.understanding import format_understanding_for_prompt
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.settings import AppEnvironment, Settings
from .config import ANTHROPIC_BASE_URL, ChatConfig
from .config import ChatConfig
from .model import (
ChatSessionInfo,
get_chat_session,
@@ -35,7 +35,6 @@ config = ChatConfig()
settings = Settings()
_client: LangfuseAsyncOpenAI | None = None
_anthropic_client: LangfuseAsyncOpenAI | None = None
_langfuse = None
@@ -46,16 +45,6 @@ def _get_openai_client() -> LangfuseAsyncOpenAI:
return _client
def _get_anthropic_client() -> LangfuseAsyncOpenAI:
"""Return an OpenAI-compatible client pointed at the Anthropic API."""
global _anthropic_client
if _anthropic_client is None:
_anthropic_client = LangfuseAsyncOpenAI(
api_key=config.anthropic_api_key, base_url=ANTHROPIC_BASE_URL
)
return _anthropic_client
def _get_langfuse():
global _langfuse
if _langfuse is None:

View File

@@ -10,6 +10,7 @@ from prisma.enums import (
CreditTransactionType,
NotificationType,
OnboardingStep,
SubscriptionTier,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
@@ -31,7 +32,7 @@ from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.util.exceptions import InsufficientBalanceError
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled
from backend.util.json import SafeJson, dumps
from backend.util.models import Pagination
from backend.util.retry import func_retry
@@ -1144,10 +1145,12 @@ class BetaUserCredit(UserCredit):
if (snapshot_time.year, snapshot_time.month) == (cur_time.year, cur_time.month):
return balance
target = self.num_user_credits_refill
try:
balance, _ = await self._add_transaction(
user_id=user_id,
amount=max(self.num_user_credits_refill - balance, 0),
amount=max(target - balance, 0),
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"MONTHLY-CREDIT-TOP-UP-{cur_time}",
metadata=SafeJson({"reason": "Monthly credit refill"}),
@@ -1250,6 +1253,33 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig):
where={"id": user_id},
data={"topUpConfig": SafeJson(config.model_dump())},
)
get_user_by_id.cache_delete(user_id)
async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
"""Set the user's subscription tier (used by webhook and admin flows)."""
await User.prisma().update(
where={"id": user_id},
data={"subscriptionTier": tier},
)
get_user_by_id.cache_delete(user_id)
async def cancel_stripe_subscription(user_id: str) -> None:
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
customer_id = await get_stripe_customer_id(user_id)
subscriptions = stripe.Subscription.list(
customer=customer_id, status="active", limit=10
)
for sub in subscriptions.auto_paging_iter():
try:
stripe.Subscription.cancel(sub["id"])
except stripe.StripeError:
logger.warning(
"cancel_stripe_subscription: failed to cancel sub %s for user %s",
sub["id"],
user_id,
)
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
@@ -1261,6 +1291,78 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
return AutoTopUpConfig.model_validate(user.top_up_config)
async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
"""Return Stripe Price ID for a tier from LaunchDarkly. None = not configured."""
flag_map = {
SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO,
SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS,
}
flag = flag_map.get(tier)
if flag is None:
return None
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
return price_id if isinstance(price_id, str) and price_id else None
async def create_subscription_checkout(
user_id: str,
tier: SubscriptionTier,
success_url: str,
cancel_url: str,
) -> str:
"""Create a Stripe Checkout Session for a subscription. Returns the redirect URL."""
price_id = await get_subscription_price_id(tier)
if not price_id:
raise ValueError(f"Subscription not available for tier {tier.value}")
customer_id = await get_stripe_customer_id(user_id)
session = stripe.checkout.Session.create(
customer=customer_id,
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
success_url=success_url,
cancel_url=cancel_url,
subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
)
return session.url or ""
async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
"""Update User.subscriptionTier from a Stripe subscription object."""
customer_id = stripe_subscription["customer"]
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"sync_subscription_from_stripe: no user for customer %s", customer_id
)
return
status = stripe_subscription.get("status", "")
if status in ("active", "trialing"):
price_id = ""
items = stripe_subscription.get("items", {}).get("data", [])
if items:
price_id = items[0].get("price", {}).get("id", "")
pro_price = await get_subscription_price_id(SubscriptionTier.PRO)
biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS)
if price_id and pro_price and price_id == pro_price:
tier = SubscriptionTier.PRO
elif price_id and biz_price and price_id == biz_price:
tier = SubscriptionTier.BUSINESS
else:
# Unknown or unconfigured price ID — preserve the user's current tier
# rather than defaulting to FREE. This prevents accidental downgrades
# during a price migration or when LD flags are not yet configured.
logger.warning(
"sync_subscription_from_stripe: unknown price %s for customer %s,"
" preserving current tier",
price_id,
customer_id,
)
return
else:
tier = SubscriptionTier.FREE
await set_subscription_tier(user.id, tier)
async def admin_get_user_history(
page: int = 1,
page_size: int = 20,

View File

@@ -0,0 +1,360 @@
"""
Tests for Stripe-based subscription tier billing.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import SubscriptionTier
from prisma.models import User
from backend.data.credit import (
cancel_stripe_subscription,
create_subscription_checkout,
set_subscription_tier,
sync_subscription_from_stripe,
)
@pytest.mark.asyncio
async def test_set_subscription_tier_updates_db():
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(update=AsyncMock()),
) as mock_prisma,
patch("backend.data.credit.get_user_by_id"),
):
await set_subscription_tier("user-1", SubscriptionTier.PRO)
mock_prisma.return_value.update.assert_awaited_once_with(
where={"id": "user-1"},
data={"subscriptionTier": SubscriptionTier.PRO},
)
@pytest.mark.asyncio
async def test_set_subscription_tier_downgrade():
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(update=AsyncMock()),
),
patch("backend.data.credit.get_user_by_id"),
):
# Downgrade to FREE should not raise
await set_subscription_tier("user-1", SubscriptionTier.FREE)
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_active():
mock_user = MagicMock(spec=User)
mock_user.id = "user-1"
stripe_sub = {
"customer": "cus_123",
"status": "active",
"items": {"data": [{"price": {"id": "price_pro_monthly"}}]},
}
async def mock_price_id(tier: SubscriptionTier) -> str | None:
if tier == SubscriptionTier.PRO:
return "price_pro_monthly"
if tier == SubscriptionTier.BUSINESS:
return "price_biz_monthly"
return None
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.get_subscription_price_id",
side_effect=mock_price_id,
),
patch(
"backend.data.credit.set_subscription_tier", new_callable=AsyncMock
) as mock_set,
):
await sync_subscription_from_stripe(stripe_sub)
mock_set.assert_awaited_once_with("user-1", SubscriptionTier.PRO)
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_cancelled():
mock_user = MagicMock(spec=User)
mock_user.id = "user-1"
stripe_sub = {
"customer": "cus_123",
"status": "canceled",
"items": {"data": []},
}
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.set_subscription_tier", new_callable=AsyncMock
) as mock_set,
):
await sync_subscription_from_stripe(stripe_sub)
mock_set.assert_awaited_once_with("user-1", SubscriptionTier.FREE)
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_unknown_customer():
stripe_sub = {
"customer": "cus_unknown",
"status": "active",
"items": {"data": []},
}
with patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=None)),
):
# Should not raise even if user not found
await sync_subscription_from_stripe(stripe_sub)
@pytest.mark.asyncio
async def test_cancel_stripe_subscription_cancels_active():
mock_sub = {"id": "sub_abc123"}
mock_subscriptions = MagicMock()
mock_subscriptions.auto_paging_iter.return_value = iter([mock_sub])
with (
patch(
"backend.data.credit.get_stripe_customer_id",
new_callable=AsyncMock,
return_value="cus_123",
),
patch(
"backend.data.credit.stripe.Subscription.list",
return_value=mock_subscriptions,
),
patch("backend.data.credit.stripe.Subscription.cancel") as mock_cancel,
):
await cancel_stripe_subscription("user-1")
mock_cancel.assert_called_once_with("sub_abc123")
@pytest.mark.asyncio
async def test_cancel_stripe_subscription_no_active():
mock_subscriptions = MagicMock()
mock_subscriptions.auto_paging_iter.return_value = iter([])
with (
patch(
"backend.data.credit.get_stripe_customer_id",
new_callable=AsyncMock,
return_value="cus_123",
),
patch(
"backend.data.credit.stripe.Subscription.list",
return_value=mock_subscriptions,
),
patch("backend.data.credit.stripe.Subscription.cancel") as mock_cancel,
):
await cancel_stripe_subscription("user-1")
mock_cancel.assert_not_called()
@pytest.mark.asyncio
async def test_create_subscription_checkout_returns_url():
mock_session = MagicMock()
mock_session.url = "https://checkout.stripe.com/pay/cs_test_abc123"
with (
patch(
"backend.data.credit.get_subscription_price_id",
new_callable=AsyncMock,
return_value="price_pro_monthly",
),
patch(
"backend.data.credit.get_stripe_customer_id",
new_callable=AsyncMock,
return_value="cus_123",
),
patch("stripe.checkout.Session.create", return_value=mock_session),
):
url = await create_subscription_checkout(
user_id="user-1",
tier=SubscriptionTier.PRO,
success_url="https://app.example.com/success",
cancel_url="https://app.example.com/cancel",
)
assert url == "https://checkout.stripe.com/pay/cs_test_abc123"
@pytest.mark.asyncio
async def test_create_subscription_checkout_no_price_raises():
with patch(
"backend.data.credit.get_subscription_price_id",
new_callable=AsyncMock,
return_value=None,
):
with pytest.raises(ValueError, match="not available"):
await create_subscription_checkout(
user_id="user-1",
tier=SubscriptionTier.PRO,
success_url="https://app.example.com/success",
cancel_url="https://app.example.com/cancel",
)
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_unknown_price_defaults_to_free():
"""Unknown price_id should default to FREE instead of returning early."""
mock_user = MagicMock(spec=User)
mock_user.id = "user-1"
stripe_sub = {
"customer": "cus_123",
"status": "active",
"items": {"data": [{"price": {"id": "price_unknown"}}]},
}
async def mock_price_id(tier: SubscriptionTier) -> str | None:
return "price_pro_monthly" if tier == SubscriptionTier.PRO else None
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.get_subscription_price_id",
side_effect=mock_price_id,
),
patch(
"backend.data.credit.set_subscription_tier", new_callable=AsyncMock
) as mock_set,
):
await sync_subscription_from_stripe(stripe_sub)
# Unknown price → preserve current tier (early return, no DB write)
mock_set.assert_not_awaited()
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_none_ld_price_defaults_to_free():
"""When LD returns None for price IDs, active subscription should default to FREE."""
mock_user = MagicMock(spec=User)
mock_user.id = "user-1"
stripe_sub = {
"customer": "cus_123",
"status": "active",
"items": {"data": [{"price": {"id": "price_pro_monthly"}}]},
}
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.get_subscription_price_id",
new_callable=AsyncMock,
return_value=None, # LD flags unconfigured
),
patch(
"backend.data.credit.set_subscription_tier", new_callable=AsyncMock
) as mock_set,
):
await sync_subscription_from_stripe(stripe_sub)
# None from LD → comparison guards prevent match → preserve current tier
mock_set.assert_not_awaited()
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_business_tier():
"""BUSINESS price_id should map to BUSINESS tier."""
mock_user = MagicMock(spec=User)
mock_user.id = "user-1"
stripe_sub = {
"customer": "cus_123",
"status": "active",
"items": {"data": [{"price": {"id": "price_biz_monthly"}}]},
}
async def mock_price_id(tier: SubscriptionTier) -> str | None:
if tier == SubscriptionTier.PRO:
return "price_pro_monthly"
if tier == SubscriptionTier.BUSINESS:
return "price_biz_monthly"
return None
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.get_subscription_price_id",
side_effect=mock_price_id,
),
patch(
"backend.data.credit.set_subscription_tier", new_callable=AsyncMock
) as mock_set,
):
await sync_subscription_from_stripe(stripe_sub)
mock_set.assert_awaited_once_with("user-1", SubscriptionTier.BUSINESS)
@pytest.mark.asyncio
async def test_get_subscription_price_id_pro():
from backend.data.credit import get_subscription_price_id
with patch(
"backend.data.credit.get_feature_flag_value",
new_callable=AsyncMock,
return_value="price_pro_monthly",
):
price_id = await get_subscription_price_id(SubscriptionTier.PRO)
assert price_id == "price_pro_monthly"
@pytest.mark.asyncio
async def test_get_subscription_price_id_free_returns_none():
from backend.data.credit import get_subscription_price_id
price_id = await get_subscription_price_id(SubscriptionTier.FREE)
assert price_id is None
@pytest.mark.asyncio
async def test_get_subscription_price_id_empty_flag_returns_none():
from backend.data.credit import get_subscription_price_id
with patch(
"backend.data.credit.get_feature_flag_value",
new_callable=AsyncMock,
return_value="", # LD flag not set
):
price_id = await get_subscription_price_id(SubscriptionTier.BUSINESS)
assert price_id is None
@pytest.mark.asyncio
async def test_cancel_stripe_subscription_handles_stripe_error():
"""Stripe errors during cancellation should be logged, not raised."""
import stripe as stripe_mod
mock_sub = {"id": "sub_abc123"}
mock_subscriptions = MagicMock()
mock_subscriptions.auto_paging_iter.return_value = iter([mock_sub])
with (
patch(
"backend.data.credit.get_stripe_customer_id",
new_callable=AsyncMock,
return_value="cus_123",
),
patch(
"backend.data.credit.stripe.Subscription.list",
return_value=mock_subscriptions,
),
patch(
"backend.data.credit.stripe.Subscription.cancel",
side_effect=stripe_mod.StripeError("network error"),
),
):
# Should not raise — errors are logged as warnings
await cancel_stripe_subscription("user-1")

View File

@@ -71,6 +71,9 @@ class User(BaseModel):
top_up_config: Optional["AutoTopUpConfig"] = Field(
None, description="Top up configuration"
)
subscription_tier: SubscriptionTier = Field(
default=SubscriptionTier.FREE, description="User subscription tier"
)
# Notification preferences
max_emails_per_day: int = Field(default=3, description="Maximum emails per day")
@@ -103,9 +106,6 @@ class User(BaseModel):
description="User timezone (IANA timezone identifier or 'not-set')",
)
# Subscription / rate-limit tier
subscription_tier: SubscriptionTier | None = Field(default=None)
@classmethod
def from_db(cls, prisma_user: "PrismaUser") -> "User":
"""Convert a database User object to application User model."""
@@ -148,6 +148,7 @@ class User(BaseModel):
integrations=prisma_user.integrations or "",
stripe_customer_id=prisma_user.stripeCustomerId,
top_up_config=top_up_config,
subscription_tier=prisma_user.subscriptionTier or SubscriptionTier.FREE,
max_emails_per_day=prisma_user.maxEmailsPerDay or 3,
notify_on_agent_run=prisma_user.notifyOnAgentRun or True,
notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True,
@@ -160,7 +161,6 @@ class User(BaseModel):
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
timezone=prisma_user.timezone or USER_TIMEZONE_NOT_SET,
subscription_tier=prisma_user.subscriptionTier,
)

View File

@@ -43,6 +43,8 @@ class Flag(str, Enum):
COPILOT_SDK = "copilot-sdk"
COPILOT_DAILY_TOKEN_LIMIT = "copilot-daily-token-limit"
COPILOT_WEEKLY_TOKEN_LIMIT = "copilot-weekly-token-limit"
STRIPE_PRICE_PRO = "stripe-price-id-pro"
STRIPE_PRICE_BUSINESS = "stripe-price-id-business"
GRAPHITI_MEMORY = "graphiti-memory"

View File

@@ -0,0 +1,5 @@
-- SubscriptionTier enum and User.subscriptionTier column already created by
-- 20260326200000_add_rate_limit_tier migration. Only add SUBSCRIPTION transaction type.
-- AlterEnum
ALTER TYPE "CreditTransactionType" ADD VALUE IF NOT EXISTS 'SUBSCRIPTION';

View File

@@ -0,0 +1,30 @@
{
"pythonVersion": "3.12",
"venvPath": ".",
"venv": ".venv",
"include": ["backend"],
"ignore": [
"backend/**/*_test.py",
"backend/**/*conftest.py",
"backend/**/conftest.py",
"backend/**/_test_data.py",
"backend/**/*test_data*.py",
"backend/api/features/library/_add_to_library.py",
"backend/api/features/library/db.py",
"backend/api/features/store/db.py",
"backend/blocks/sql_query_helpers.py",
"backend/blocks/stagehand/blocks.py",
"backend/cli/oauth_tool.py",
"backend/copilot/db.py",
"backend/copilot/rate_limit.py",
"backend/data/auth/api_key.py",
"backend/data/auth/oauth.py",
"backend/data/execution.py",
"backend/data/graph.py",
"backend/data/human_review.py",
"backend/data/onboarding.py",
"backend/data/understanding.py",
"backend/data/workspace.py",
"backend/sdk/__init__.py"
]
}

View File

@@ -774,6 +774,7 @@ enum CreditTransactionType {
GRANT
REFUND
CARD_CHECK
SUBSCRIPTION
}
model CreditTransaction {

View File

@@ -95,8 +95,8 @@ export function buildReactArtifactSrcDoc(
}
</style>
<script src="${TAILWIND_CDN_URL}"></script>
<script crossorigin="anonymous" src="https://unpkg.com/react@18.3.1/umd/react.production.min.js" integrity="sha384-DGyLxAyjq0f9SPpVevD6IgztCFlnMF6oW/XQGmfe+IsZ8TqEiDrcHkMLKI6fiB/Z"></script>
<script crossorigin="anonymous" src="https://unpkg.com/react-dom@18.3.1/umd/react-dom.production.min.js" integrity="sha384-gTGxhz21lVGYNMcdJOyq01Edg0jhn/c22nsx0kyqP0TxaV5WVdsSH1fSDUf5YJj1"></script>
<script crossorigin="anonymous" src="https://unpkg.com/react@18.3.1/umd/react.production.min.js" integrity="sha384-DGyLxAyjq0f9SPpVevD6IgztCFlnMF6oW/XQGmfe+IsZ8TqEiDrcHkMLKI6fiB/Z"></script><!-- pragma: allowlist secret -->
<script crossorigin="anonymous" src="https://unpkg.com/react-dom@18.3.1/umd/react-dom.production.min.js" integrity="sha384-gTGxhz21lVGYNMcdJOyq01Edg0jhn/c22nsx0kyqP0TxaV5WVdsSH1fSDUf5YJj1"></script><!-- pragma: allowlist secret -->
</head>
<body>
<div id="root"></div>

View File

@@ -0,0 +1,140 @@
"use client";
import { useState } from "react";
import { Button } from "@/components/__legacy__/ui/button";
import { useSubscriptionTierSection } from "./useSubscriptionTierSection";
type TierInfo = {
key: string;
label: string;
multiplier: string;
description: string;
};
const TIERS: TierInfo[] = [
{
key: "FREE",
label: "Free",
multiplier: "1x",
description: "Base rate limits",
},
{
key: "PRO",
label: "Pro",
multiplier: "5x",
description: "5x more AutoPilot capacity",
},
{
key: "BUSINESS",
label: "Business",
multiplier: "20x",
description: "20x more AutoPilot capacity",
},
];
function formatCost(cents: number): string {
if (cents === 0) return "Free";
return `$${(cents / 100).toFixed(2)}/mo`;
}
export function SubscriptionTierSection() {
const { subscription, isLoading, error, isPending, changeTier } =
useSubscriptionTierSection();
const [tierError, setTierError] = useState<string | null>(null);
if (isLoading) return null;
if (error) {
return (
<div className="space-y-4">
<h3 className="text-lg font-medium">Subscription Plan</h3>
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
{error}
</p>
</div>
);
}
if (!subscription) return null;
async function handleTierChange(tierKey: string) {
setTierError(null);
const err = await changeTier(tierKey);
if (err) setTierError(err);
}
return (
<div className="space-y-4">
<h3 className="text-lg font-medium">Subscription Plan</h3>
{tierError && (
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
{tierError}
</p>
)}
<div className="grid grid-cols-1 gap-3 sm:grid-cols-3">
{TIERS.map((tier) => {
const isCurrent = subscription.tier === tier.key;
const cost = subscription.tier_costs[tier.key] ?? 0;
const currentTierOrder = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
const currentIdx = currentTierOrder.indexOf(subscription.tier);
const targetIdx = currentTierOrder.indexOf(tier.key);
const isUpgrade = targetIdx > currentIdx;
const isDowngrade = targetIdx < currentIdx;
return (
<div
key={tier.key}
className={`rounded-lg border p-4 ${
isCurrent
? "border-violet-500 bg-violet-50 dark:bg-violet-900/20"
: "border-neutral-200 dark:border-neutral-700"
}`}
>
<div className="mb-2 flex items-center justify-between">
<span className="font-semibold">{tier.label}</span>
{isCurrent && (
<span className="rounded-full bg-violet-100 px-2 py-0.5 text-xs font-medium text-violet-700 dark:bg-violet-800 dark:text-violet-200">
Current
</span>
)}
</div>
<p className="mb-1 text-2xl font-bold">{formatCost(cost)}</p>
<p className="mb-1 text-sm font-medium text-neutral-600 dark:text-neutral-400">
{tier.multiplier} rate limits
</p>
<p className="mb-4 text-sm text-neutral-500 dark:text-neutral-400">
{tier.description}
</p>
{!isCurrent && (
<Button
className="w-full"
variant={isUpgrade ? "default" : "outline"}
disabled={isPending}
onClick={() => handleTierChange(tier.key)}
>
{isPending
? "Updating..."
: isUpgrade
? `Upgrade to ${tier.label}`
: isDowngrade
? `Downgrade to ${tier.label}`
: `Switch to ${tier.label}`}
</Button>
)}
</div>
);
})}
</div>
{subscription.tier !== "FREE" && (
<p className="text-sm text-neutral-500">
Your subscription is managed through Stripe. Changes take effect
immediately.
</p>
)}
</div>
);
}

View File

@@ -0,0 +1,55 @@
import {
useGetSubscriptionStatus,
useUpdateSubscriptionTier,
} from "@/app/api/__generated__/endpoints/credits/credits";
import type { SubscriptionStatusResponse } from "@/app/api/__generated__/models/subscriptionStatusResponse";
import type { SubscriptionTierRequestTier } from "@/app/api/__generated__/models/subscriptionTierRequestTier";
export type SubscriptionStatus = SubscriptionStatusResponse;
export function useSubscriptionTierSection() {
const {
data: subscription,
isLoading,
error: queryError,
refetch,
} = useGetSubscriptionStatus({
query: { select: (data) => (data.status === 200 ? data.data : null) },
});
const error = queryError ? "Failed to load subscription info" : null;
const { mutateAsync: doUpdateTier, isPending } = useUpdateSubscriptionTier();
async function changeTier(tier: string): Promise<string | null> {
try {
const successUrl = `${window.location.origin}${window.location.pathname}?subscription=success`;
const cancelUrl = `${window.location.origin}${window.location.pathname}?subscription=cancelled`;
const result = await doUpdateTier({
data: {
tier: tier as SubscriptionTierRequestTier,
success_url: successUrl,
cancel_url: cancelUrl,
},
});
if (result.status === 200 && result.data.url) {
window.location.href = result.data.url;
return null;
}
await refetch();
return null;
} catch (e: unknown) {
const msg =
e instanceof Error ? e.message : "Failed to change subscription tier";
return msg;
}
}
return {
subscription: subscription ?? null,
isLoading,
error,
isPending,
changeTier,
};
}

View File

@@ -10,6 +10,7 @@ import {
} from "@/components/molecules/Toast/use-toast";
import { RefundModal } from "./RefundModal";
import { SubscriptionTierSection } from "./components/SubscriptionTierSection/SubscriptionTierSection";
import { CreditTransaction } from "@/lib/autogpt-server-api";
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
@@ -141,6 +142,11 @@ export default function CreditsPage() {
Billing
</h1>
{/* Subscription Tier */}
<div className="mb-8">
<SubscriptionTierSection />
</div>
<div className="grid grid-cols-1 gap-8 lg:grid-cols-2">
{/* Top-up Form */}
<div className="space-y-4">

View File

@@ -2171,6 +2171,68 @@
}
}
},
"/api/credits/subscription": {
"get": {
"tags": ["v1", "credits"],
"summary": "Get subscription tier, current cost, and all tier costs",
"operationId": "getSubscriptionStatus",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SubscriptionStatusResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
},
"post": {
"tags": ["v1", "credits"],
"summary": "Start a Stripe Checkout session to upgrade subscription tier",
"operationId": "updateSubscriptionTier",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SubscriptionTierRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SubscriptionCheckoutResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/credits/transactions": {
"get": {
"tags": ["v1", "credits"],
@@ -9209,7 +9271,14 @@
},
"CreditTransactionType": {
"type": "string",
"enum": ["TOP_UP", "USAGE", "GRANT", "REFUND", "CARD_CHECK"],
"enum": [
"TOP_UP",
"USAGE",
"GRANT",
"REFUND",
"CARD_CHECK",
"SUBSCRIPTION"
],
"title": "CreditTransactionType"
},
"DeleteFileResponse": {
@@ -13622,12 +13691,54 @@
"enum": ["DRAFT", "PENDING", "APPROVED", "REJECTED"],
"title": "SubmissionStatus"
},
"SubscriptionCheckoutResponse": {
"properties": { "url": { "type": "string", "title": "Url" } },
"type": "object",
"required": ["url"],
"title": "SubscriptionCheckoutResponse"
},
"SubscriptionStatusResponse": {
"properties": {
"tier": { "type": "string", "title": "Tier" },
"monthly_cost": { "type": "integer", "title": "Monthly Cost" },
"tier_costs": {
"additionalProperties": { "type": "integer" },
"type": "object",
"title": "Tier Costs"
}
},
"type": "object",
"required": ["tier", "monthly_cost", "tier_costs"],
"title": "SubscriptionStatusResponse"
},
"SubscriptionTier": {
"type": "string",
"enum": ["FREE", "PRO", "BUSINESS", "ENTERPRISE"],
"title": "SubscriptionTier",
"description": "Subscription tiers with increasing token allowances.\n\nMirrors the ``SubscriptionTier`` enum in ``schema.prisma``.\nOnce ``prisma generate`` is run, this can be replaced with::\n\n from prisma.enums import SubscriptionTier"
},
"SubscriptionTierRequest": {
"properties": {
"tier": {
"type": "string",
"enum": ["FREE", "PRO", "BUSINESS"],
"title": "Tier"
},
"success_url": {
"type": "string",
"title": "Success Url",
"default": ""
},
"cancel_url": {
"type": "string",
"title": "Cancel Url",
"default": ""
}
},
"type": "object",
"required": ["tier"],
"title": "SubscriptionTierRequest"
},
"SuggestedGoalResponse": {
"properties": {
"type": {

View File

@@ -194,6 +194,26 @@ export default class BackendAPI {
return this._request("PATCH", "/credits");
}
getSubscription(): Promise<{
tier: string;
monthly_cost: number;
tier_costs: Record<string, number>;
}> {
return this._get("/credits/subscription");
}
setSubscriptionTier(
tier: string,
successUrl?: string,
cancelUrl?: string,
): Promise<{ url: string }> {
return this._request("POST", "/credits/subscription", {
tier,
success_url: successUrl ?? "",
cancel_url: cancelUrl ?? "",
});
}
////////////////////////////////////////
//////////////// GRAPHS ////////////////
////////////////////////////////////////

Binary file not shown.

After

Width:  |  Height:  |  Size: 191 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 148 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 146 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 92 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 115 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 115 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB