mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(platform): address remaining PR review comments on subscription billing
Backend: - Cache stripe.Price.retrieve with 5-min TTL via _get_stripe_price_amount to avoid 200-600ms Stripe round-trip on every GET /credits/subscription - Use SubscriptionTier enum .value for FREE/ENTERPRISE in tier_costs dict for consistency (instead of hardcoded strings) - Rename misleading test names: "defaults_to_FREE" → "preserves_current_tier" to reflect actual behaviour (unknown price IDs preserve tier, not reset) - Update subscription_routes_test to mock _get_stripe_price_amount instead of stripe.Price.retrieve directly, avoiding cached-result interference Frontend: - Handle ?subscription=success return from Stripe Checkout: refetch + toast - Add downgrade confirmation Dialog before cancelling paid subscription - Handle ENTERPRISE tier: render dedicated admin-managed plan card, not the FREE/PRO/BUSINESS tier cards (which would show no "Current" badge) - Track pendingTier (via variables) so only the clicked button shows "Updating..." - Show "Pricing available soon" for paid tiers with cost=0 (unconfigured LD flags) instead of misleading "Free" - Move tierError state into the hook, set via changeTier internally - Move TIER_ORDER constant to module scope (was magic array inside render body) - Add aria-current="true" to active tier card for screen reader accessibility - Add role="alert" to all error paragraph elements - Improve tier descriptions with concrete capacity values
This commit is contained in:
@@ -55,12 +55,12 @@ def test_get_subscription_status_pro(
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
mock_price = Mock()
|
||||
mock_price.unit_amount = 1999 # $19.99
|
||||
|
||||
async def mock_price_id(tier: SubscriptionTier) -> str | None:
|
||||
return "price_pro" if tier == SubscriptionTier.PRO else None
|
||||
|
||||
async def mock_stripe_price_amount(price_id: str) -> int:
|
||||
return 1999 if price_id == "price_pro" else 0
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
@@ -71,8 +71,8 @@ def test_get_subscription_status_pro(
|
||||
side_effect=mock_price_id,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.stripe.Price.retrieve",
|
||||
return_value=mock_price,
|
||||
"backend.api.features.v1._get_stripe_price_amount",
|
||||
side_effect=mock_stripe_price_amount,
|
||||
)
|
||||
|
||||
response = client.get("/credits/subscription")
|
||||
|
||||
@@ -729,6 +729,21 @@ def _validate_checkout_redirect_url(url: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=300, maxsize=32)
|
||||
async def _get_stripe_price_amount(price_id: str) -> int:
|
||||
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
|
||||
|
||||
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
|
||||
every GET /credits/subscription page load and reduces quota consumption.
|
||||
"""
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
return price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
logger.warning("Failed to retrieve Stripe price %s — returning 0", price_id)
|
||||
return 0
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/credits/subscription",
|
||||
summary="Get subscription tier, current cost, and all tier costs",
|
||||
@@ -747,15 +762,16 @@ async def get_subscription_status(
|
||||
*[get_subscription_price_id(t) for t in paid_tiers]
|
||||
)
|
||||
|
||||
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
|
||||
for t, price_id in zip(paid_tiers, price_ids):
|
||||
cost = 0
|
||||
if price_id:
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
cost = price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
pass
|
||||
tier_costs: dict[str, int] = {
|
||||
SubscriptionTier.FREE.value: 0,
|
||||
SubscriptionTier.ENTERPRISE.value: 0,
|
||||
}
|
||||
|
||||
async def _cost(pid: str | None) -> int:
|
||||
return await _get_stripe_price_amount(pid) if pid else 0
|
||||
|
||||
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
|
||||
for t, cost in zip(paid_tiers, costs):
|
||||
tier_costs[t.value] = cost
|
||||
|
||||
return SubscriptionStatusResponse(
|
||||
|
||||
@@ -480,8 +480,8 @@ async def test_create_subscription_checkout_no_price_raises():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_subscription_from_stripe_unknown_price_preserves_current_tier():
|
||||
"""Unknown price_id should preserve the current tier (no DB write)."""
|
||||
async def test_sync_subscription_from_stripe_unknown_price_id_preserves_current_tier():
|
||||
"""Unknown price_id should preserve the current tier, not default to FREE (no DB write)."""
|
||||
mock_user = _make_user(tier=SubscriptionTier.PRO)
|
||||
stripe_sub = {
|
||||
"customer": "cus_123",
|
||||
@@ -511,8 +511,8 @@ async def test_sync_subscription_from_stripe_unknown_price_preserves_current_tie
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_subscription_from_stripe_none_ld_price_preserves_current_tier():
|
||||
"""When LD returns None for price IDs, the current tier should be preserved."""
|
||||
async def test_sync_subscription_from_stripe_unconfigured_ld_price_preserves_current_tier():
|
||||
"""When LD flags are unconfigured (None price IDs), the current tier should be preserved, not defaulted to FREE."""
|
||||
mock_user = _make_user(tier=SubscriptionTier.PRO)
|
||||
stripe_sub = {
|
||||
"customer": "cus_123",
|
||||
|
||||
@@ -35,7 +35,6 @@ class TestUsdToMicrodollars:
|
||||
assert usd_to_microdollars(1.0) == 1_000_000
|
||||
|
||||
|
||||
|
||||
class TestMaskEmail:
|
||||
def test_typical_email(self):
|
||||
assert _mask_email("user@example.com") == "us***@example.com"
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
"use client";
|
||||
import { useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import { useSubscriptionTierSection } from "./useSubscriptionTierSection";
|
||||
|
||||
type TierInfo = {
|
||||
@@ -15,31 +23,43 @@ const TIERS: TierInfo[] = [
|
||||
key: "FREE",
|
||||
label: "Free",
|
||||
multiplier: "1x",
|
||||
description: "Base rate limits",
|
||||
description: "Base AutoPilot capacity with standard rate limits",
|
||||
},
|
||||
{
|
||||
key: "PRO",
|
||||
label: "Pro",
|
||||
multiplier: "5x",
|
||||
description: "5x more AutoPilot capacity",
|
||||
description: "5x AutoPilot capacity — run 5× more tasks per day/week",
|
||||
},
|
||||
{
|
||||
key: "BUSINESS",
|
||||
label: "Business",
|
||||
multiplier: "20x",
|
||||
description: "20x more AutoPilot capacity",
|
||||
description: "20x AutoPilot capacity — ideal for teams and heavy workloads",
|
||||
},
|
||||
];
|
||||
|
||||
function formatCost(cents: number): string {
|
||||
if (cents === 0) return "Free";
|
||||
const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
|
||||
|
||||
function formatCost(cents: number, tierKey: string): string {
|
||||
if (tierKey === "FREE") return "Free";
|
||||
if (cents === 0) return "Pricing available soon";
|
||||
return `$${(cents / 100).toFixed(2)}/mo`;
|
||||
}
|
||||
|
||||
export function SubscriptionTierSection() {
|
||||
const { subscription, isLoading, error, isPending, changeTier } =
|
||||
useSubscriptionTierSection();
|
||||
const [tierError, setTierError] = useState<string | null>(null);
|
||||
const {
|
||||
subscription,
|
||||
isLoading,
|
||||
error,
|
||||
tierError,
|
||||
isPending,
|
||||
pendingTier,
|
||||
changeTier,
|
||||
} = useSubscriptionTierSection();
|
||||
const [confirmDowngradeTo, setConfirmDowngradeTo] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
if (isLoading) return null;
|
||||
|
||||
@@ -47,7 +67,10 @@ export function SubscriptionTierSection() {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<h3 className="text-lg font-medium">Subscription Plan</h3>
|
||||
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
|
||||
<p
|
||||
role="alert"
|
||||
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
|
||||
>
|
||||
{error}
|
||||
</p>
|
||||
</div>
|
||||
@@ -56,10 +79,40 @@ export function SubscriptionTierSection() {
|
||||
|
||||
if (!subscription) return null;
|
||||
|
||||
async function handleTierChange(tierKey: string) {
|
||||
setTierError(null);
|
||||
const err = await changeTier(tierKey);
|
||||
if (err) setTierError(err);
|
||||
const currentTier = subscription.tier;
|
||||
|
||||
if (currentTier === "ENTERPRISE") {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<h3 className="text-lg font-medium">Subscription Plan</h3>
|
||||
<div className="rounded-lg border border-violet-500 bg-violet-50 p-4 dark:bg-violet-900/20">
|
||||
<p className="font-semibold text-violet-700 dark:text-violet-200">
|
||||
Enterprise Plan
|
||||
</p>
|
||||
<p className="mt-1 text-sm text-neutral-600 dark:text-neutral-400">
|
||||
Your Enterprise plan is managed by your administrator. Contact your
|
||||
account team for changes.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function handleTierChange(tierKey: string) {
|
||||
const currentIdx = TIER_ORDER.indexOf(currentTier);
|
||||
const targetIdx = TIER_ORDER.indexOf(tierKey);
|
||||
if (targetIdx < currentIdx) {
|
||||
setConfirmDowngradeTo(tierKey);
|
||||
return;
|
||||
}
|
||||
changeTier(tierKey);
|
||||
}
|
||||
|
||||
async function confirmDowngrade() {
|
||||
if (!confirmDowngradeTo) return;
|
||||
const tier = confirmDowngradeTo;
|
||||
setConfirmDowngradeTo(null);
|
||||
await changeTier(tier);
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -67,24 +120,28 @@ export function SubscriptionTierSection() {
|
||||
<h3 className="text-lg font-medium">Subscription Plan</h3>
|
||||
|
||||
{tierError && (
|
||||
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
|
||||
<p
|
||||
role="alert"
|
||||
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
|
||||
>
|
||||
{tierError}
|
||||
</p>
|
||||
)}
|
||||
|
||||
<div className="grid grid-cols-1 gap-3 sm:grid-cols-3">
|
||||
{TIERS.map((tier) => {
|
||||
const isCurrent = subscription.tier === tier.key;
|
||||
const isCurrent = currentTier === tier.key;
|
||||
const cost = subscription.tier_costs[tier.key] ?? 0;
|
||||
const currentTierOrder = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
|
||||
const currentIdx = currentTierOrder.indexOf(subscription.tier);
|
||||
const targetIdx = currentTierOrder.indexOf(tier.key);
|
||||
const currentIdx = TIER_ORDER.indexOf(currentTier);
|
||||
const targetIdx = TIER_ORDER.indexOf(tier.key);
|
||||
const isUpgrade = targetIdx > currentIdx;
|
||||
const isDowngrade = targetIdx < currentIdx;
|
||||
const isThisPending = pendingTier === tier.key;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={tier.key}
|
||||
aria-current={isCurrent ? "true" : undefined}
|
||||
className={`rounded-lg border p-4 ${
|
||||
isCurrent
|
||||
? "border-violet-500 bg-violet-50 dark:bg-violet-900/20"
|
||||
@@ -100,7 +157,9 @@ export function SubscriptionTierSection() {
|
||||
)}
|
||||
</div>
|
||||
|
||||
<p className="mb-1 text-2xl font-bold">{formatCost(cost)}</p>
|
||||
<p className="mb-1 text-2xl font-bold">
|
||||
{formatCost(cost, tier.key)}
|
||||
</p>
|
||||
<p className="mb-1 text-sm font-medium text-neutral-600 dark:text-neutral-400">
|
||||
{tier.multiplier} rate limits
|
||||
</p>
|
||||
@@ -115,7 +174,7 @@ export function SubscriptionTierSection() {
|
||||
disabled={isPending}
|
||||
onClick={() => handleTierChange(tier.key)}
|
||||
>
|
||||
{isPending
|
||||
{isThisPending
|
||||
? "Updating..."
|
||||
: isUpgrade
|
||||
? `Upgrade to ${tier.label}`
|
||||
@@ -129,12 +188,40 @@ export function SubscriptionTierSection() {
|
||||
})}
|
||||
</div>
|
||||
|
||||
{subscription.tier !== "FREE" && (
|
||||
{currentTier !== "FREE" && (
|
||||
<p className="text-sm text-neutral-500">
|
||||
Your subscription is managed through Stripe. Changes take effect
|
||||
immediately.
|
||||
</p>
|
||||
)}
|
||||
|
||||
<Dialog
|
||||
open={!!confirmDowngradeTo}
|
||||
onOpenChange={(open) => !open && setConfirmDowngradeTo(null)}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Confirm Downgrade</DialogTitle>
|
||||
<DialogDescription>
|
||||
{confirmDowngradeTo === "FREE"
|
||||
? "Downgrading to Free will cancel your current Stripe subscription immediately and remove your paid-tier rate limit increases."
|
||||
: `Switching to ${confirmDowngradeTo} will take effect immediately.`}{" "}
|
||||
Are you sure?
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setConfirmDowngradeTo(null)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button variant="destructive" onClick={confirmDowngrade}>
|
||||
Confirm Downgrade
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import {
|
||||
useGetSubscriptionStatus,
|
||||
useUpdateSubscriptionTier,
|
||||
} from "@/app/api/__generated__/endpoints/credits/credits";
|
||||
import type { SubscriptionStatusResponse } from "@/app/api/__generated__/models/subscriptionStatusResponse";
|
||||
import type { SubscriptionTierRequestTier } from "@/app/api/__generated__/models/subscriptionTierRequestTier";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
|
||||
export type SubscriptionStatus = SubscriptionStatusResponse;
|
||||
|
||||
export function useSubscriptionTierSection() {
|
||||
const searchParams = useSearchParams();
|
||||
const subscriptionStatus = searchParams.get("subscription");
|
||||
const { toast } = useToast();
|
||||
const toastShownRef = useRef(false);
|
||||
const [tierError, setTierError] = useState<string | null>(null);
|
||||
|
||||
const {
|
||||
data: subscription,
|
||||
isLoading,
|
||||
@@ -17,11 +26,28 @@ export function useSubscriptionTierSection() {
|
||||
query: { select: (data) => (data.status === 200 ? data.data : null) },
|
||||
});
|
||||
|
||||
const error = queryError ? "Failed to load subscription info" : null;
|
||||
const fetchError = queryError ? "Failed to load subscription info" : null;
|
||||
|
||||
const { mutateAsync: doUpdateTier, isPending } = useUpdateSubscriptionTier();
|
||||
const {
|
||||
mutateAsync: doUpdateTier,
|
||||
isPending,
|
||||
variables,
|
||||
} = useUpdateSubscriptionTier();
|
||||
|
||||
async function changeTier(tier: string): Promise<string | null> {
|
||||
useEffect(() => {
|
||||
if (subscriptionStatus === "success" && !toastShownRef.current) {
|
||||
toastShownRef.current = true;
|
||||
refetch();
|
||||
toast({
|
||||
title: "Subscription upgraded",
|
||||
description:
|
||||
"Your plan has been updated. It may take a moment to reflect.",
|
||||
});
|
||||
}
|
||||
}, [subscriptionStatus, refetch, toast]);
|
||||
|
||||
async function changeTier(tier: string) {
|
||||
setTierError(null);
|
||||
try {
|
||||
const successUrl = `${window.location.origin}${window.location.pathname}?subscription=success`;
|
||||
const cancelUrl = `${window.location.origin}${window.location.pathname}?subscription=cancelled`;
|
||||
@@ -34,22 +60,26 @@ export function useSubscriptionTierSection() {
|
||||
});
|
||||
if (result.status === 200 && result.data.url) {
|
||||
window.location.href = result.data.url;
|
||||
return null;
|
||||
return;
|
||||
}
|
||||
await refetch();
|
||||
return null;
|
||||
} catch (e: unknown) {
|
||||
const msg =
|
||||
e instanceof Error ? e.message : "Failed to change subscription tier";
|
||||
return msg;
|
||||
setTierError(msg);
|
||||
}
|
||||
}
|
||||
|
||||
const pendingTier =
|
||||
isPending && variables?.data?.tier ? variables.data.tier : null;
|
||||
|
||||
return {
|
||||
subscription: subscription ?? null,
|
||||
isLoading,
|
||||
error,
|
||||
error: fetchError,
|
||||
tierError,
|
||||
isPending,
|
||||
pendingTier,
|
||||
changeTier,
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user