mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(backend): fix Stripe price ID LD flag lookup and subscription payment handling
- Use user_id="system" for global LD flag lookups (price IDs don't need user context) - Skip Supabase lookup silently for non-UUID keys in _fetch_user_context_data - Block paid tier changes when ENABLE_PLATFORM_PAYMENT is disabled - Add invoice.payment_failed handler: deduct from balance or downgrade to FREE - Hide upgrade/downgrade buttons in UI when payment flag is disabled
This commit is contained in:
@@ -565,3 +565,41 @@ def test_stripe_webhook_dispatches_subscription_events(
|
||||
|
||||
assert response.status_code == 200
|
||||
sync_mock.assert_awaited_once_with(stripe_sub_obj)
|
||||
|
||||
|
||||
def test_stripe_webhook_dispatches_invoice_payment_failed(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""POST /credits/stripe_webhook routes invoice.payment_failed to the failure handler."""
|
||||
invoice_obj = {
|
||||
"customer": "cus_test",
|
||||
"subscription": "sub_test",
|
||||
"amount_due": 1999,
|
||||
}
|
||||
event = {
|
||||
"type": "invoice.payment_failed",
|
||||
"data": {"object": invoice_obj},
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.settings.secrets.stripe_webhook_secret",
|
||||
new="whsec_test",
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.stripe.Webhook.construct_event",
|
||||
return_value=event,
|
||||
)
|
||||
failure_mock = mocker.patch(
|
||||
"backend.api.features.v1.handle_subscription_payment_failure",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/stripe_webhook",
|
||||
content=b"{}",
|
||||
headers={"stripe-signature": "t=1,v1=abc"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
failure_mock.assert_awaited_once_with(invoice_obj)
|
||||
|
||||
@@ -57,6 +57,7 @@ from backend.data.credit import (
|
||||
get_auto_top_up,
|
||||
get_subscription_price_id,
|
||||
get_user_credit_model,
|
||||
handle_subscription_payment_failure,
|
||||
set_auto_top_up,
|
||||
set_subscription_tier,
|
||||
sync_subscription_from_stripe,
|
||||
@@ -973,6 +974,9 @@ async def stripe_webhook(request: Request):
|
||||
):
|
||||
await sync_subscription_from_stripe(data_object)
|
||||
|
||||
if event_type == "invoice.payment_failed":
|
||||
await handle_subscription_payment_failure(data_object)
|
||||
|
||||
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
|
||||
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
|
||||
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
|
||||
|
||||
@@ -1368,7 +1368,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
|
||||
flag = flag_map.get(tier)
|
||||
if flag is None:
|
||||
return None
|
||||
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
|
||||
price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
|
||||
return price_id if isinstance(price_id, str) and price_id else None
|
||||
|
||||
|
||||
@@ -1599,6 +1599,98 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
|
||||
await set_subscription_tier(user.id, tier)
|
||||
|
||||
|
||||
async def handle_subscription_payment_failure(invoice: dict) -> None:
|
||||
"""Handle a failed Stripe subscription payment.
|
||||
|
||||
Tries to cover the invoice amount from the user's credit balance.
|
||||
Either way the Stripe subscription is cancelled so Stripe stops retrying.
|
||||
|
||||
- Balance sufficient → deduct, cancel Stripe sub, keep tier.
|
||||
- Balance insufficient → cancel Stripe sub, downgrade to FREE immediately.
|
||||
"""
|
||||
customer_id = invoice.get("customer")
|
||||
if not customer_id:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: missing customer in invoice; skipping"
|
||||
)
|
||||
return
|
||||
|
||||
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
|
||||
if not user:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: no user found for customer %s",
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
|
||||
current_tier = user.subscriptionTier or SubscriptionTier.FREE
|
||||
if current_tier == SubscriptionTier.ENTERPRISE:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: skipping ENTERPRISE user %s"
|
||||
" (customer %s) — tier is admin-managed",
|
||||
user.id,
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
|
||||
amount_due: int = invoice.get("amount_due", 0)
|
||||
sub_id: str = invoice.get("subscription", "")
|
||||
|
||||
if amount_due <= 0:
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: amount_due=%d for user %s;"
|
||||
" nothing to deduct",
|
||||
amount_due,
|
||||
user.id,
|
||||
)
|
||||
return
|
||||
|
||||
credit_model = UserCredit()
|
||||
try:
|
||||
await credit_model._add_transaction(
|
||||
user_id=user.id,
|
||||
amount=-amount_due,
|
||||
transaction_type=CreditTransactionType.SUBSCRIPTION,
|
||||
fail_insufficient_credits=True,
|
||||
metadata=SafeJson(
|
||||
{
|
||||
"stripe_customer_id": customer_id,
|
||||
"stripe_subscription_id": sub_id,
|
||||
"reason": "subscription_payment_failure_covered_by_balance",
|
||||
}
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: deducted %d cents from balance"
|
||||
" for user %s; cancelling Stripe sub %s to prevent further retries",
|
||||
amount_due,
|
||||
user.id,
|
||||
sub_id,
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: insufficient balance for user %s;"
|
||||
" downgrading to FREE and cancelling Stripe sub %s",
|
||||
user.id,
|
||||
sub_id,
|
||||
)
|
||||
await set_subscription_tier(user.id, SubscriptionTier.FREE)
|
||||
|
||||
# Cancel the Stripe subscription regardless — if balance covered it we don't
|
||||
# want Stripe to retry next month; if balance was insufficient the user is
|
||||
# already downgraded and the sub must go.
|
||||
try:
|
||||
await _cancel_customer_subscriptions(customer_id)
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: failed to cancel Stripe sub %s"
|
||||
" for user %s (customer %s); Stripe may continue retrying",
|
||||
sub_id,
|
||||
user.id,
|
||||
customer_id,
|
||||
)
|
||||
|
||||
|
||||
async def admin_get_user_history(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
@@ -101,6 +102,12 @@ async def _fetch_user_context_data(user_id: str) -> Context:
|
||||
"""
|
||||
builder = Context.builder(user_id).kind("user").anonymous(True)
|
||||
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
# Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context.
|
||||
return builder.build()
|
||||
|
||||
try:
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ export function SubscriptionTierSection() {
|
||||
tierError,
|
||||
isPending,
|
||||
pendingTier,
|
||||
isPaymentEnabled,
|
||||
changeTier,
|
||||
handleTierChange,
|
||||
} = useSubscriptionTierSection();
|
||||
@@ -163,7 +164,7 @@ export function SubscriptionTierSection() {
|
||||
{tier.description}
|
||||
</p>
|
||||
|
||||
{!isCurrent && (
|
||||
{!isCurrent && isPaymentEnabled && (
|
||||
<Button
|
||||
className="w-full"
|
||||
variant={isUpgrade ? "default" : "outline"}
|
||||
@@ -190,7 +191,7 @@ export function SubscriptionTierSection() {
|
||||
})}
|
||||
</div>
|
||||
|
||||
{currentTier !== "FREE" && (
|
||||
{currentTier !== "FREE" && isPaymentEnabled && (
|
||||
<p className="text-sm text-neutral-500">
|
||||
Your subscription is managed through Stripe. Changes take effect
|
||||
immediately.
|
||||
|
||||
@@ -27,6 +27,13 @@ vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
useToast: () => ({ toast: mockToast }),
|
||||
}));
|
||||
|
||||
// Mock feature flags — default to payment enabled so button tests work
|
||||
let mockPaymentEnabled = true;
|
||||
vi.mock("@/services/feature-flags/use-get-flag", () => ({
|
||||
Flag: { ENABLE_PLATFORM_PAYMENT: "enable-platform-payment" },
|
||||
useGetFlag: () => mockPaymentEnabled,
|
||||
}));
|
||||
|
||||
// Mock generated API hooks
|
||||
const mockUseGetSubscriptionStatus = vi.fn();
|
||||
const mockUseUpdateSubscriptionTier = vi.fn();
|
||||
@@ -105,8 +112,8 @@ afterEach(() => {
|
||||
mockUseUpdateSubscriptionTier.mockReset();
|
||||
mockToast.mockReset();
|
||||
mockRouterReplace.mockReset();
|
||||
// Reset search params
|
||||
mockSearchParams.delete("subscription");
|
||||
mockPaymentEnabled = true;
|
||||
});
|
||||
|
||||
describe("SubscriptionTierSection", () => {
|
||||
@@ -283,6 +290,18 @@ describe("SubscriptionTierSection", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("hides action buttons when payment flag is disabled", () => {
|
||||
mockPaymentEnabled = false;
|
||||
setupMocks({ subscription: makeSubscription({ tier: "FREE" }) });
|
||||
render(<SubscriptionTierSection />);
|
||||
// Tier cards still visible
|
||||
expect(screen.getByText("Pro")).toBeDefined();
|
||||
expect(screen.getByText("Business")).toBeDefined();
|
||||
// No upgrade/downgrade buttons
|
||||
expect(screen.queryByRole("button", { name: /upgrade/i })).toBeNull();
|
||||
expect(screen.queryByRole("button", { name: /downgrade/i })).toBeNull();
|
||||
});
|
||||
|
||||
it("shows ENTERPRISE message for ENTERPRISE tier users", () => {
|
||||
setupMocks({ subscription: makeSubscription({ tier: "ENTERPRISE" }) });
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
@@ -7,12 +7,14 @@ import {
|
||||
import type { SubscriptionStatusResponse } from "@/app/api/__generated__/models/subscriptionStatusResponse";
|
||||
import type { SubscriptionTierRequestTier } from "@/app/api/__generated__/models/subscriptionTierRequestTier";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
|
||||
export type SubscriptionStatus = SubscriptionStatusResponse;
|
||||
|
||||
const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
|
||||
|
||||
export function useSubscriptionTierSection() {
|
||||
const isPaymentEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
|
||||
const searchParams = useSearchParams();
|
||||
const subscriptionStatus = searchParams.get("subscription");
|
||||
const router = useRouter();
|
||||
@@ -108,6 +110,7 @@ export function useSubscriptionTierSection() {
|
||||
tierError,
|
||||
isPending,
|
||||
pendingTier,
|
||||
isPaymentEnabled,
|
||||
changeTier,
|
||||
handleTierChange,
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user