fix(platform): address remaining PR review comments on subscription billing

Backend:
- Cache stripe.Price.retrieve with 5-min TTL via _get_stripe_price_amount
  to avoid 200-600ms Stripe round-trip on every GET /credits/subscription
- Use SubscriptionTier enum .value for FREE/ENTERPRISE in tier_costs dict
  for consistency (instead of hardcoded strings)
- Rename misleading test names: "defaults_to_FREE" → "preserves_current_tier"
  to reflect actual behaviour (unknown price IDs preserve tier, not reset)
- Update subscription_routes_test to mock _get_stripe_price_amount instead
  of stripe.Price.retrieve directly, avoiding cached-result interference

Frontend:
- Handle ?subscription=success return from Stripe Checkout: refetch + toast
- Add downgrade confirmation Dialog before cancelling paid subscription
- Handle ENTERPRISE tier: render dedicated admin-managed plan card, not the
  FREE/PRO/BUSINESS tier cards (which would show no "Current" badge)
- Track pendingTier (via variables) so only the clicked button shows "Updating..."
- Show "Pricing available soon" for paid tiers with cost=0 (unconfigured LD flags)
  instead of misleading "Free"
- Move tierError state into the hook, set via changeTier internally
- Move TIER_ORDER constant to module scope (was magic array inside render body)
- Add aria-current="true" to active tier card for screen reader accessibility
- Add role="alert" to all error paragraph elements
- Improve tier descriptions with concrete capacity values
This commit is contained in:
majdyz
2026-04-11 08:57:34 +07:00
parent 329a034ebe
commit 5bb7027f89
6 changed files with 179 additions and 47 deletions

View File

@@ -55,12 +55,12 @@ def test_get_subscription_status_pro(
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
mock_price = Mock()
mock_price.unit_amount = 1999 # $19.99
async def mock_price_id(tier: SubscriptionTier) -> str | None:
return "price_pro" if tier == SubscriptionTier.PRO else None
async def mock_stripe_price_amount(price_id: str) -> int:
return 1999 if price_id == "price_pro" else 0
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
@@ -71,8 +71,8 @@ def test_get_subscription_status_pro(
side_effect=mock_price_id,
)
mocker.patch(
"backend.api.features.v1.stripe.Price.retrieve",
return_value=mock_price,
"backend.api.features.v1._get_stripe_price_amount",
side_effect=mock_stripe_price_amount,
)
response = client.get("/credits/subscription")

View File

@@ -729,6 +729,21 @@ def _validate_checkout_redirect_url(url: str) -> bool:
)
@cached(ttl_seconds=300, maxsize=32)
async def _get_stripe_price_amount(price_id: str) -> int:
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
every GET /credits/subscription page load and reduces quota consumption.
"""
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
return price.unit_amount or 0
except stripe.StripeError:
logger.warning("Failed to retrieve Stripe price %s — returning 0", price_id)
return 0
@v1_router.get(
path="/credits/subscription",
summary="Get subscription tier, current cost, and all tier costs",
@@ -747,15 +762,16 @@ async def get_subscription_status(
*[get_subscription_price_id(t) for t in paid_tiers]
)
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
for t, price_id in zip(paid_tiers, price_ids):
cost = 0
if price_id:
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
cost = price.unit_amount or 0
except stripe.StripeError:
pass
tier_costs: dict[str, int] = {
SubscriptionTier.FREE.value: 0,
SubscriptionTier.ENTERPRISE.value: 0,
}
async def _cost(pid: str | None) -> int:
return await _get_stripe_price_amount(pid) if pid else 0
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
for t, cost in zip(paid_tiers, costs):
tier_costs[t.value] = cost
return SubscriptionStatusResponse(

View File

@@ -480,8 +480,8 @@ async def test_create_subscription_checkout_no_price_raises():
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_unknown_price_preserves_current_tier():
"""Unknown price_id should preserve the current tier (no DB write)."""
async def test_sync_subscription_from_stripe_unknown_price_id_preserves_current_tier():
"""Unknown price_id should preserve the current tier, not default to FREE (no DB write)."""
mock_user = _make_user(tier=SubscriptionTier.PRO)
stripe_sub = {
"customer": "cus_123",
@@ -511,8 +511,8 @@ async def test_sync_subscription_from_stripe_unknown_price_preserves_current_tie
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_none_ld_price_preserves_current_tier():
"""When LD returns None for price IDs, the current tier should be preserved."""
async def test_sync_subscription_from_stripe_unconfigured_ld_price_preserves_current_tier():
"""When LD flags are unconfigured (None price IDs), the current tier should be preserved, not defaulted to FREE."""
mock_user = _make_user(tier=SubscriptionTier.PRO)
stripe_sub = {
"customer": "cus_123",

View File

@@ -35,7 +35,6 @@ class TestUsdToMicrodollars:
assert usd_to_microdollars(1.0) == 1_000_000
class TestMaskEmail:
def test_typical_email(self):
assert _mask_email("user@example.com") == "us***@example.com"

View File

@@ -1,6 +1,14 @@
"use client";
import { useState } from "react";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/__legacy__/ui/dialog";
import { useSubscriptionTierSection } from "./useSubscriptionTierSection";
type TierInfo = {
@@ -15,31 +23,43 @@ const TIERS: TierInfo[] = [
key: "FREE",
label: "Free",
multiplier: "1x",
description: "Base rate limits",
description: "Base AutoPilot capacity with standard rate limits",
},
{
key: "PRO",
label: "Pro",
multiplier: "5x",
description: "5x more AutoPilot capacity",
description: "5x AutoPilot capacity — run 5× more tasks per day/week",
},
{
key: "BUSINESS",
label: "Business",
multiplier: "20x",
description: "20x more AutoPilot capacity",
description: "20x AutoPilot capacity — ideal for teams and heavy workloads",
},
];
function formatCost(cents: number): string {
if (cents === 0) return "Free";
const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
function formatCost(cents: number, tierKey: string): string {
if (tierKey === "FREE") return "Free";
if (cents === 0) return "Pricing available soon";
return `$${(cents / 100).toFixed(2)}/mo`;
}
export function SubscriptionTierSection() {
const { subscription, isLoading, error, isPending, changeTier } =
useSubscriptionTierSection();
const [tierError, setTierError] = useState<string | null>(null);
const {
subscription,
isLoading,
error,
tierError,
isPending,
pendingTier,
changeTier,
} = useSubscriptionTierSection();
const [confirmDowngradeTo, setConfirmDowngradeTo] = useState<string | null>(
null,
);
if (isLoading) return null;
@@ -47,7 +67,10 @@ export function SubscriptionTierSection() {
return (
<div className="space-y-4">
<h3 className="text-lg font-medium">Subscription Plan</h3>
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
<p
role="alert"
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
>
{error}
</p>
</div>
@@ -56,10 +79,40 @@ export function SubscriptionTierSection() {
if (!subscription) return null;
async function handleTierChange(tierKey: string) {
setTierError(null);
const err = await changeTier(tierKey);
if (err) setTierError(err);
const currentTier = subscription.tier;
if (currentTier === "ENTERPRISE") {
return (
<div className="space-y-4">
<h3 className="text-lg font-medium">Subscription Plan</h3>
<div className="rounded-lg border border-violet-500 bg-violet-50 p-4 dark:bg-violet-900/20">
<p className="font-semibold text-violet-700 dark:text-violet-200">
Enterprise Plan
</p>
<p className="mt-1 text-sm text-neutral-600 dark:text-neutral-400">
Your Enterprise plan is managed by your administrator. Contact your
account team for changes.
</p>
</div>
</div>
);
}
function handleTierChange(tierKey: string) {
const currentIdx = TIER_ORDER.indexOf(currentTier);
const targetIdx = TIER_ORDER.indexOf(tierKey);
if (targetIdx < currentIdx) {
setConfirmDowngradeTo(tierKey);
return;
}
changeTier(tierKey);
}
async function confirmDowngrade() {
if (!confirmDowngradeTo) return;
const tier = confirmDowngradeTo;
setConfirmDowngradeTo(null);
await changeTier(tier);
}
return (
@@ -67,24 +120,28 @@ export function SubscriptionTierSection() {
<h3 className="text-lg font-medium">Subscription Plan</h3>
{tierError && (
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
<p
role="alert"
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
>
{tierError}
</p>
)}
<div className="grid grid-cols-1 gap-3 sm:grid-cols-3">
{TIERS.map((tier) => {
const isCurrent = subscription.tier === tier.key;
const isCurrent = currentTier === tier.key;
const cost = subscription.tier_costs[tier.key] ?? 0;
const currentTierOrder = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
const currentIdx = currentTierOrder.indexOf(subscription.tier);
const targetIdx = currentTierOrder.indexOf(tier.key);
const currentIdx = TIER_ORDER.indexOf(currentTier);
const targetIdx = TIER_ORDER.indexOf(tier.key);
const isUpgrade = targetIdx > currentIdx;
const isDowngrade = targetIdx < currentIdx;
const isThisPending = pendingTier === tier.key;
return (
<div
key={tier.key}
aria-current={isCurrent ? "true" : undefined}
className={`rounded-lg border p-4 ${
isCurrent
? "border-violet-500 bg-violet-50 dark:bg-violet-900/20"
@@ -100,7 +157,9 @@ export function SubscriptionTierSection() {
)}
</div>
<p className="mb-1 text-2xl font-bold">{formatCost(cost)}</p>
<p className="mb-1 text-2xl font-bold">
{formatCost(cost, tier.key)}
</p>
<p className="mb-1 text-sm font-medium text-neutral-600 dark:text-neutral-400">
{tier.multiplier} rate limits
</p>
@@ -115,7 +174,7 @@ export function SubscriptionTierSection() {
disabled={isPending}
onClick={() => handleTierChange(tier.key)}
>
{isPending
{isThisPending
? "Updating..."
: isUpgrade
? `Upgrade to ${tier.label}`
@@ -129,12 +188,40 @@ export function SubscriptionTierSection() {
})}
</div>
{subscription.tier !== "FREE" && (
{currentTier !== "FREE" && (
<p className="text-sm text-neutral-500">
Your subscription is managed through Stripe. Changes take effect
immediately.
</p>
)}
<Dialog
open={!!confirmDowngradeTo}
onOpenChange={(open) => !open && setConfirmDowngradeTo(null)}
>
<DialogContent>
<DialogHeader>
<DialogTitle>Confirm Downgrade</DialogTitle>
<DialogDescription>
{confirmDowngradeTo === "FREE"
? "Downgrading to Free will cancel your current Stripe subscription immediately and remove your paid-tier rate limit increases."
: `Switching to ${confirmDowngradeTo} will take effect immediately.`}{" "}
Are you sure?
</DialogDescription>
</DialogHeader>
<DialogFooter>
<Button
variant="outline"
onClick={() => setConfirmDowngradeTo(null)}
>
Cancel
</Button>
<Button variant="destructive" onClick={confirmDowngrade}>
Confirm Downgrade
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
);
}

View File

@@ -1,13 +1,22 @@
import { useEffect, useRef, useState } from "react";
import { useSearchParams } from "next/navigation";
import {
useGetSubscriptionStatus,
useUpdateSubscriptionTier,
} from "@/app/api/__generated__/endpoints/credits/credits";
import type { SubscriptionStatusResponse } from "@/app/api/__generated__/models/subscriptionStatusResponse";
import type { SubscriptionTierRequestTier } from "@/app/api/__generated__/models/subscriptionTierRequestTier";
import { useToast } from "@/components/molecules/Toast/use-toast";
export type SubscriptionStatus = SubscriptionStatusResponse;
export function useSubscriptionTierSection() {
const searchParams = useSearchParams();
const subscriptionStatus = searchParams.get("subscription");
const { toast } = useToast();
const toastShownRef = useRef(false);
const [tierError, setTierError] = useState<string | null>(null);
const {
data: subscription,
isLoading,
@@ -17,11 +26,28 @@ export function useSubscriptionTierSection() {
query: { select: (data) => (data.status === 200 ? data.data : null) },
});
const error = queryError ? "Failed to load subscription info" : null;
const fetchError = queryError ? "Failed to load subscription info" : null;
const { mutateAsync: doUpdateTier, isPending } = useUpdateSubscriptionTier();
const {
mutateAsync: doUpdateTier,
isPending,
variables,
} = useUpdateSubscriptionTier();
async function changeTier(tier: string): Promise<string | null> {
useEffect(() => {
if (subscriptionStatus === "success" && !toastShownRef.current) {
toastShownRef.current = true;
refetch();
toast({
title: "Subscription upgraded",
description:
"Your plan has been updated. It may take a moment to reflect.",
});
}
}, [subscriptionStatus, refetch, toast]);
async function changeTier(tier: string) {
setTierError(null);
try {
const successUrl = `${window.location.origin}${window.location.pathname}?subscription=success`;
const cancelUrl = `${window.location.origin}${window.location.pathname}?subscription=cancelled`;
@@ -34,22 +60,26 @@ export function useSubscriptionTierSection() {
});
if (result.status === 200 && result.data.url) {
window.location.href = result.data.url;
return null;
return;
}
await refetch();
return null;
} catch (e: unknown) {
const msg =
e instanceof Error ? e.message : "Failed to change subscription tier";
return msg;
setTierError(msg);
}
}
const pendingTier =
isPending && variables?.data?.tier ? variables.data.tier : null;
return {
subscription: subscription ?? null,
isLoading,
error,
error: fetchError,
tierError,
isPending,
pendingTier,
changeTier,
};
}