diff --git a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py index 2d04bf20aa..563ff323a9 100644 --- a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py @@ -55,12 +55,12 @@ def test_get_subscription_status_pro( mock_user = Mock() mock_user.subscription_tier = SubscriptionTier.PRO - mock_price = Mock() - mock_price.unit_amount = 1999 # $19.99 - async def mock_price_id(tier: SubscriptionTier) -> str | None: return "price_pro" if tier == SubscriptionTier.PRO else None + async def mock_stripe_price_amount(price_id: str) -> int: + return 1999 if price_id == "price_pro" else 0 + mocker.patch( "backend.api.features.v1.get_user_by_id", new_callable=AsyncMock, @@ -71,8 +71,8 @@ def test_get_subscription_status_pro( side_effect=mock_price_id, ) mocker.patch( - "backend.api.features.v1.stripe.Price.retrieve", - return_value=mock_price, + "backend.api.features.v1._get_stripe_price_amount", + side_effect=mock_stripe_price_amount, ) response = client.get("/credits/subscription") diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index 7737f864f1..6a9718f1f6 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -729,6 +729,21 @@ def _validate_checkout_redirect_url(url: str) -> bool: ) +@cached(ttl_seconds=300, maxsize=32) +async def _get_stripe_price_amount(price_id: str) -> int: + """Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes. + + Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on + every GET /credits/subscription page load and reduces quota consumption. + """ + try: + price = await run_in_threadpool(stripe.Price.retrieve, price_id) + return price.unit_amount or 0 + except stripe.StripeError: + logger.warning("Failed to retrieve Stripe price %s — returning 0", price_id) + return 0 + + @v1_router.get( path="/credits/subscription", summary="Get subscription tier, current cost, and all tier costs", @@ -747,15 +762,16 @@ async def get_subscription_status( *[get_subscription_price_id(t) for t in paid_tiers] ) - tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0} - for t, price_id in zip(paid_tiers, price_ids): - cost = 0 - if price_id: - try: - price = await run_in_threadpool(stripe.Price.retrieve, price_id) - cost = price.unit_amount or 0 - except stripe.StripeError: - pass + tier_costs: dict[str, int] = { + SubscriptionTier.FREE.value: 0, + SubscriptionTier.ENTERPRISE.value: 0, + } + + async def _cost(pid: str | None) -> int: + return await _get_stripe_price_amount(pid) if pid else 0 + + costs = await asyncio.gather(*[_cost(pid) for pid in price_ids]) + for t, cost in zip(paid_tiers, costs): tier_costs[t.value] = cost return SubscriptionStatusResponse( diff --git a/autogpt_platform/backend/backend/data/credit_subscription_test.py b/autogpt_platform/backend/backend/data/credit_subscription_test.py index 29babc4bdc..abdea94092 100644 --- a/autogpt_platform/backend/backend/data/credit_subscription_test.py +++ b/autogpt_platform/backend/backend/data/credit_subscription_test.py @@ -480,8 +480,8 @@ async def test_create_subscription_checkout_no_price_raises(): @pytest.mark.asyncio -async def test_sync_subscription_from_stripe_unknown_price_preserves_current_tier(): - """Unknown price_id should preserve the current tier (no DB write).""" +async def test_sync_subscription_from_stripe_unknown_price_id_preserves_current_tier(): + """Unknown price_id should preserve the current tier, not default to FREE (no DB write).""" mock_user = _make_user(tier=SubscriptionTier.PRO) stripe_sub = { "customer": "cus_123", @@ -511,8 +511,8 @@ async def test_sync_subscription_from_stripe_unknown_price_preserves_current_tie @pytest.mark.asyncio -async def test_sync_subscription_from_stripe_none_ld_price_preserves_current_tier(): - """When LD returns None for price IDs, the current tier should be preserved.""" +async def test_sync_subscription_from_stripe_unconfigured_ld_price_preserves_current_tier(): + """When LD flags are unconfigured (None price IDs), the current tier should be preserved, not defaulted to FREE.""" mock_user = _make_user(tier=SubscriptionTier.PRO) stripe_sub = { "customer": "cus_123", diff --git a/autogpt_platform/backend/backend/data/platform_cost_test.py b/autogpt_platform/backend/backend/data/platform_cost_test.py index dacd2c42ea..4a2372628b 100644 --- a/autogpt_platform/backend/backend/data/platform_cost_test.py +++ b/autogpt_platform/backend/backend/data/platform_cost_test.py @@ -35,7 +35,6 @@ class TestUsdToMicrodollars: assert usd_to_microdollars(1.0) == 1_000_000 - class TestMaskEmail: def test_typical_email(self): assert _mask_email("user@example.com") == "us***@example.com" diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx index 774fe01ed9..e3351e4458 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx @@ -1,6 +1,14 @@ "use client"; import { useState } from "react"; import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/__legacy__/ui/dialog"; import { useSubscriptionTierSection } from "./useSubscriptionTierSection"; type TierInfo = { @@ -15,31 +23,43 @@ const TIERS: TierInfo[] = [ key: "FREE", label: "Free", multiplier: "1x", - description: "Base rate limits", + description: "Base AutoPilot capacity with standard rate limits", }, { key: "PRO", label: "Pro", multiplier: "5x", - description: "5x more AutoPilot capacity", + description: "5x AutoPilot capacity — run 5× more tasks per day/week", }, { key: "BUSINESS", label: "Business", multiplier: "20x", - description: "20x more AutoPilot capacity", + description: "20x AutoPilot capacity — ideal for teams and heavy workloads", }, ]; -function formatCost(cents: number): string { - if (cents === 0) return "Free"; +const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"]; + +function formatCost(cents: number, tierKey: string): string { + if (tierKey === "FREE") return "Free"; + if (cents === 0) return "Pricing available soon"; return `$${(cents / 100).toFixed(2)}/mo`; } export function SubscriptionTierSection() { - const { subscription, isLoading, error, isPending, changeTier } = - useSubscriptionTierSection(); - const [tierError, setTierError] = useState(null); + const { + subscription, + isLoading, + error, + tierError, + isPending, + pendingTier, + changeTier, + } = useSubscriptionTierSection(); + const [confirmDowngradeTo, setConfirmDowngradeTo] = useState( + null, + ); if (isLoading) return null; @@ -47,7 +67,10 @@ export function SubscriptionTierSection() { return (

Subscription Plan

-

+

{error}

@@ -56,10 +79,40 @@ export function SubscriptionTierSection() { if (!subscription) return null; - async function handleTierChange(tierKey: string) { - setTierError(null); - const err = await changeTier(tierKey); - if (err) setTierError(err); + const currentTier = subscription.tier; + + if (currentTier === "ENTERPRISE") { + return ( +
+

Subscription Plan

+
+

+ Enterprise Plan +

+

+ Your Enterprise plan is managed by your administrator. Contact your + account team for changes. +

+
+
+ ); + } + + function handleTierChange(tierKey: string) { + const currentIdx = TIER_ORDER.indexOf(currentTier); + const targetIdx = TIER_ORDER.indexOf(tierKey); + if (targetIdx < currentIdx) { + setConfirmDowngradeTo(tierKey); + return; + } + changeTier(tierKey); + } + + async function confirmDowngrade() { + if (!confirmDowngradeTo) return; + const tier = confirmDowngradeTo; + setConfirmDowngradeTo(null); + await changeTier(tier); } return ( @@ -67,24 +120,28 @@ export function SubscriptionTierSection() {

Subscription Plan

{tierError && ( -

+

{tierError}

)}
{TIERS.map((tier) => { - const isCurrent = subscription.tier === tier.key; + const isCurrent = currentTier === tier.key; const cost = subscription.tier_costs[tier.key] ?? 0; - const currentTierOrder = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"]; - const currentIdx = currentTierOrder.indexOf(subscription.tier); - const targetIdx = currentTierOrder.indexOf(tier.key); + const currentIdx = TIER_ORDER.indexOf(currentTier); + const targetIdx = TIER_ORDER.indexOf(tier.key); const isUpgrade = targetIdx > currentIdx; const isDowngrade = targetIdx < currentIdx; + const isThisPending = pendingTier === tier.key; return (
-

{formatCost(cost)}

+

+ {formatCost(cost, tier.key)} +

{tier.multiplier} rate limits

@@ -115,7 +174,7 @@ export function SubscriptionTierSection() { disabled={isPending} onClick={() => handleTierChange(tier.key)} > - {isPending + {isThisPending ? "Updating..." : isUpgrade ? `Upgrade to ${tier.label}` @@ -129,12 +188,40 @@ export function SubscriptionTierSection() { })}
- {subscription.tier !== "FREE" && ( + {currentTier !== "FREE" && (

Your subscription is managed through Stripe. Changes take effect immediately.

)} + + !open && setConfirmDowngradeTo(null)} + > + + + Confirm Downgrade + + {confirmDowngradeTo === "FREE" + ? "Downgrading to Free will cancel your current Stripe subscription immediately and remove your paid-tier rate limit increases." + : `Switching to ${confirmDowngradeTo} will take effect immediately.`}{" "} + Are you sure? + + + + + + + +
); } diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts index b0fe635b72..5ccc37a3d2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts @@ -1,13 +1,22 @@ +import { useEffect, useRef, useState } from "react"; +import { useSearchParams } from "next/navigation"; import { useGetSubscriptionStatus, useUpdateSubscriptionTier, } from "@/app/api/__generated__/endpoints/credits/credits"; import type { SubscriptionStatusResponse } from "@/app/api/__generated__/models/subscriptionStatusResponse"; import type { SubscriptionTierRequestTier } from "@/app/api/__generated__/models/subscriptionTierRequestTier"; +import { useToast } from "@/components/molecules/Toast/use-toast"; export type SubscriptionStatus = SubscriptionStatusResponse; export function useSubscriptionTierSection() { + const searchParams = useSearchParams(); + const subscriptionStatus = searchParams.get("subscription"); + const { toast } = useToast(); + const toastShownRef = useRef(false); + const [tierError, setTierError] = useState(null); + const { data: subscription, isLoading, @@ -17,11 +26,28 @@ export function useSubscriptionTierSection() { query: { select: (data) => (data.status === 200 ? data.data : null) }, }); - const error = queryError ? "Failed to load subscription info" : null; + const fetchError = queryError ? "Failed to load subscription info" : null; - const { mutateAsync: doUpdateTier, isPending } = useUpdateSubscriptionTier(); + const { + mutateAsync: doUpdateTier, + isPending, + variables, + } = useUpdateSubscriptionTier(); - async function changeTier(tier: string): Promise { + useEffect(() => { + if (subscriptionStatus === "success" && !toastShownRef.current) { + toastShownRef.current = true; + refetch(); + toast({ + title: "Subscription upgraded", + description: + "Your plan has been updated. It may take a moment to reflect.", + }); + } + }, [subscriptionStatus, refetch, toast]); + + async function changeTier(tier: string) { + setTierError(null); try { const successUrl = `${window.location.origin}${window.location.pathname}?subscription=success`; const cancelUrl = `${window.location.origin}${window.location.pathname}?subscription=cancelled`; @@ -34,22 +60,26 @@ export function useSubscriptionTierSection() { }); if (result.status === 200 && result.data.url) { window.location.href = result.data.url; - return null; + return; } await refetch(); - return null; } catch (e: unknown) { const msg = e instanceof Error ? e.message : "Failed to change subscription tier"; - return msg; + setTierError(msg); } } + const pendingTier = + isPending && variables?.data?.tier ? variables.data.tier : null; + return { subscription: subscription ?? null, isLoading, - error, + error: fetchError, + tierError, isPending, + pendingTier, changeTier, }; }