fix(backend): fix Stripe price ID LD flag lookup and subscription payment handling

- Use user_id="system" for global LD flag lookups (price IDs don't need user context)
- Skip Supabase lookup silently for non-UUID keys in _fetch_user_context_data
- Block paid tier changes when ENABLE_PLATFORM_PAYMENT is disabled
- Add invoice.payment_failed handler: deduct from balance or downgrade to FREE
- Hide upgrade/downgrade buttons in UI when payment flag is disabled
This commit is contained in:
majdyz
2026-04-14 22:41:07 +07:00
parent c477e7b92e
commit bfd1e6e793
7 changed files with 168 additions and 4 deletions

View File

@@ -565,3 +565,41 @@ def test_stripe_webhook_dispatches_subscription_events(
assert response.status_code == 200
sync_mock.assert_awaited_once_with(stripe_sub_obj)
def test_stripe_webhook_dispatches_invoice_payment_failed(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""POST /credits/stripe_webhook routes invoice.payment_failed to the failure handler."""
invoice_obj = {
"customer": "cus_test",
"subscription": "sub_test",
"amount_due": 1999,
}
event = {
"type": "invoice.payment_failed",
"data": {"object": invoice_obj},
}
mocker.patch(
"backend.api.features.v1.settings.secrets.stripe_webhook_secret",
new="whsec_test",
)
mocker.patch(
"backend.api.features.v1.stripe.Webhook.construct_event",
return_value=event,
)
failure_mock = mocker.patch(
"backend.api.features.v1.handle_subscription_payment_failure",
new_callable=AsyncMock,
)
response = client.post(
"/credits/stripe_webhook",
content=b"{}",
headers={"stripe-signature": "t=1,v1=abc"},
)
assert response.status_code == 200
failure_mock.assert_awaited_once_with(invoice_obj)

View File

@@ -57,6 +57,7 @@ from backend.data.credit import (
get_auto_top_up,
get_subscription_price_id,
get_user_credit_model,
handle_subscription_payment_failure,
set_auto_top_up,
set_subscription_tier,
sync_subscription_from_stripe,
@@ -973,6 +974,9 @@ async def stripe_webhook(request: Request):
):
await sync_subscription_from_stripe(data_object)
if event_type == "invoice.payment_failed":
await handle_subscription_payment_failure(data_object)
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
# StripeObject (a dict subclass) carrying that runtime shape, so we cast

View File

@@ -1368,7 +1368,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
flag = flag_map.get(tier)
if flag is None:
return None
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
return price_id if isinstance(price_id, str) and price_id else None
@@ -1599,6 +1599,98 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
await set_subscription_tier(user.id, tier)
async def handle_subscription_payment_failure(invoice: dict) -> None:
"""Handle a failed Stripe subscription payment.
Tries to cover the invoice amount from the user's credit balance.
Either way the Stripe subscription is cancelled so Stripe stops retrying.
- Balance sufficient → deduct, cancel Stripe sub, keep tier.
- Balance insufficient → cancel Stripe sub, downgrade to FREE immediately.
"""
customer_id = invoice.get("customer")
if not customer_id:
logger.warning(
"handle_subscription_payment_failure: missing customer in invoice; skipping"
)
return
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"handle_subscription_payment_failure: no user found for customer %s",
customer_id,
)
return
current_tier = user.subscriptionTier or SubscriptionTier.FREE
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"handle_subscription_payment_failure: skipping ENTERPRISE user %s"
" (customer %s) — tier is admin-managed",
user.id,
customer_id,
)
return
amount_due: int = invoice.get("amount_due", 0)
sub_id: str = invoice.get("subscription", "")
if amount_due <= 0:
logger.info(
"handle_subscription_payment_failure: amount_due=%d for user %s;"
" nothing to deduct",
amount_due,
user.id,
)
return
credit_model = UserCredit()
try:
await credit_model._add_transaction(
user_id=user.id,
amount=-amount_due,
transaction_type=CreditTransactionType.SUBSCRIPTION,
fail_insufficient_credits=True,
metadata=SafeJson(
{
"stripe_customer_id": customer_id,
"stripe_subscription_id": sub_id,
"reason": "subscription_payment_failure_covered_by_balance",
}
),
)
logger.info(
"handle_subscription_payment_failure: deducted %d cents from balance"
" for user %s; cancelling Stripe sub %s to prevent further retries",
amount_due,
user.id,
sub_id,
)
except InsufficientBalanceError:
logger.info(
"handle_subscription_payment_failure: insufficient balance for user %s;"
" downgrading to FREE and cancelling Stripe sub %s",
user.id,
sub_id,
)
await set_subscription_tier(user.id, SubscriptionTier.FREE)
# Cancel the Stripe subscription regardless — if balance covered it we don't
# want Stripe to retry next month; if balance was insufficient the user is
# already downgraded and the sub must go.
try:
await _cancel_customer_subscriptions(customer_id)
except stripe.StripeError:
logger.warning(
"handle_subscription_payment_failure: failed to cancel Stripe sub %s"
" for user %s (customer %s); Stripe may continue retrying",
sub_id,
user.id,
customer_id,
)
async def admin_get_user_history(
page: int = 1,
page_size: int = 20,

View File

@@ -1,6 +1,7 @@
import contextlib
import logging
import os
import uuid
from enum import Enum
from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar
@@ -101,6 +102,12 @@ async def _fetch_user_context_data(user_id: str) -> Context:
"""
builder = Context.builder(user_id).kind("user").anonymous(True)
try:
uuid.UUID(user_id)
except ValueError:
# Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context.
return builder.build()
try:
from backend.util.clients import get_supabase

View File

@@ -49,6 +49,7 @@ export function SubscriptionTierSection() {
tierError,
isPending,
pendingTier,
isPaymentEnabled,
changeTier,
handleTierChange,
} = useSubscriptionTierSection();
@@ -163,7 +164,7 @@ export function SubscriptionTierSection() {
{tier.description}
</p>
{!isCurrent && (
{!isCurrent && isPaymentEnabled && (
<Button
className="w-full"
variant={isUpgrade ? "default" : "outline"}
@@ -190,7 +191,7 @@ export function SubscriptionTierSection() {
})}
</div>
{currentTier !== "FREE" && (
{currentTier !== "FREE" && isPaymentEnabled && (
<p className="text-sm text-neutral-500">
Your subscription is managed through Stripe. Changes take effect
immediately.

View File

@@ -27,6 +27,13 @@ vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: () => ({ toast: mockToast }),
}));
// Mock feature flags — default to payment enabled so button tests work
let mockPaymentEnabled = true;
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: { ENABLE_PLATFORM_PAYMENT: "enable-platform-payment" },
useGetFlag: () => mockPaymentEnabled,
}));
// Mock generated API hooks
const mockUseGetSubscriptionStatus = vi.fn();
const mockUseUpdateSubscriptionTier = vi.fn();
@@ -105,8 +112,8 @@ afterEach(() => {
mockUseUpdateSubscriptionTier.mockReset();
mockToast.mockReset();
mockRouterReplace.mockReset();
// Reset search params
mockSearchParams.delete("subscription");
mockPaymentEnabled = true;
});
describe("SubscriptionTierSection", () => {
@@ -283,6 +290,18 @@ describe("SubscriptionTierSection", () => {
});
});
it("hides action buttons when payment flag is disabled", () => {
mockPaymentEnabled = false;
setupMocks({ subscription: makeSubscription({ tier: "FREE" }) });
render(<SubscriptionTierSection />);
// Tier cards still visible
expect(screen.getByText("Pro")).toBeDefined();
expect(screen.getByText("Business")).toBeDefined();
// No upgrade/downgrade buttons
expect(screen.queryByRole("button", { name: /upgrade/i })).toBeNull();
expect(screen.queryByRole("button", { name: /downgrade/i })).toBeNull();
});
it("shows ENTERPRISE message for ENTERPRISE tier users", () => {
setupMocks({ subscription: makeSubscription({ tier: "ENTERPRISE" }) });
render(<SubscriptionTierSection />);

View File

@@ -7,12 +7,14 @@ import {
import type { SubscriptionStatusResponse } from "@/app/api/__generated__/models/subscriptionStatusResponse";
import type { SubscriptionTierRequestTier } from "@/app/api/__generated__/models/subscriptionTierRequestTier";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
export type SubscriptionStatus = SubscriptionStatusResponse;
const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
export function useSubscriptionTierSection() {
const isPaymentEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
const searchParams = useSearchParams();
const subscriptionStatus = searchParams.get("subscription");
const router = useRouter();
@@ -108,6 +110,7 @@ export function useSubscriptionTierSection() {
tierError,
isPending,
pendingTier,
isPaymentEnabled,
changeTier,
handleTierChange,
};