mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(backend): query trialing subs in proration + guard empty invoice_id
- get_proration_credit_cents now queries both "active" and "trialing" subscriptions (same pattern as modify_stripe_subscription_for_tier) so trial users see accurate proration credit before upgrading. - handle_subscription_payment_failure returns early when invoice.id is missing — without an idempotency key, webhook retries would double- charge the user's credit balance on every retry cycle. - Add tests: test_get_proration_credit_cents_with_trialing_sub and test_handle_subscription_payment_failure_missing_invoice_id_skips.
This commit is contained in:
@@ -1378,10 +1378,13 @@ async def cancel_stripe_subscription(user_id: str) -> bool:
|
||||
async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int:
|
||||
"""Return the prorated credit (in cents) the user would receive if they upgraded now.
|
||||
|
||||
Fetches the user's active Stripe subscription to determine how many seconds
|
||||
remain in the current billing period, then calculates the unused portion of
|
||||
Fetches the user's active or trialing Stripe subscription to determine how many
|
||||
seconds remain in the current billing period, then calculates the unused portion of
|
||||
the monthly cost. Returns 0 for FREE/ENTERPRISE users or when no active sub
|
||||
is found.
|
||||
|
||||
Both ``active`` and ``trialing`` subscriptions are checked: a trialing user still
|
||||
accumulates a billing period and Stripe prorates the remaining trial value on upgrade.
|
||||
"""
|
||||
if monthly_cost_cents <= 0:
|
||||
return 0
|
||||
@@ -1393,20 +1396,25 @@ async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> i
|
||||
return 0
|
||||
try:
|
||||
customer_id = user.stripe_customer_id
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list, customer=customer_id, status="active", limit=1
|
||||
)
|
||||
if not subscriptions.data:
|
||||
return 0
|
||||
sub = subscriptions.data[0]
|
||||
period_start: int = sub["current_period_start"]
|
||||
period_end: int = sub["current_period_end"]
|
||||
now = int(time.time())
|
||||
total_seconds = period_end - period_start
|
||||
remaining_seconds = max(period_end - now, 0)
|
||||
if total_seconds <= 0:
|
||||
return 0
|
||||
return int(monthly_cost_cents * remaining_seconds / total_seconds)
|
||||
for status in ("active", "trialing"):
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list,
|
||||
customer=customer_id,
|
||||
status=status,
|
||||
limit=1,
|
||||
)
|
||||
if not subscriptions.data:
|
||||
continue
|
||||
sub = subscriptions.data[0]
|
||||
period_start: int = sub["current_period_start"]
|
||||
period_end: int = sub["current_period_end"]
|
||||
now = int(time.time())
|
||||
total_seconds = period_end - period_start
|
||||
remaining_seconds = max(period_end - now, 0)
|
||||
if total_seconds <= 0:
|
||||
return 0
|
||||
return int(monthly_cost_cents * remaining_seconds / total_seconds)
|
||||
return 0
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"get_proration_credit_cents: failed to compute proration for user %s",
|
||||
@@ -1773,6 +1781,18 @@ async def handle_subscription_payment_failure(invoice: dict) -> None:
|
||||
sub_id: str = invoice.get("subscription", "")
|
||||
invoice_id: str = invoice.get("id", "")
|
||||
|
||||
if not invoice_id:
|
||||
# Without an invoice ID we cannot set an idempotency key on the credit
|
||||
# deduction. Stripe webhook retries would then double-charge the user's
|
||||
# balance on every retry cycle. Bail out early — a real Stripe invoice
|
||||
# always carries an ID, so a missing one indicates a malformed payload.
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: invoice missing 'id' for"
|
||||
" customer %s; skipping to avoid non-idempotent balance deduction",
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
|
||||
if amount_due <= 0:
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: amount_due=%d for user %s;"
|
||||
|
||||
@@ -536,6 +536,44 @@ async def test_get_proration_credit_cents_with_active_sub():
|
||||
assert result < 2000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_proration_credit_cents_with_trialing_sub():
|
||||
"""Trialing subscriptions also have a billing period — proration must be non-zero."""
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
period_start = now - 5 * 24 * 3600 # 5 days ago
|
||||
period_end = now + 25 * 24 * 3600 # 25 days ahead
|
||||
mock_sub = {
|
||||
"id": "sub_trial_abc",
|
||||
"current_period_start": period_start,
|
||||
"current_period_end": period_end,
|
||||
}
|
||||
empty_subs = MagicMock()
|
||||
empty_subs.data = []
|
||||
trialing_subs = MagicMock()
|
||||
trialing_subs.data = [mock_sub]
|
||||
|
||||
def list_side_effect(*args, **kwargs):
|
||||
return trialing_subs if kwargs.get("status") == "trialing" else empty_subs
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_make_user_with_stripe("cus_123"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
side_effect=list_side_effect,
|
||||
),
|
||||
):
|
||||
result = await get_proration_credit_cents("user-1", monthly_cost_cents=2000)
|
||||
# Trialing sub with ~25 days remaining should yield a significant proration credit
|
||||
assert result > 0
|
||||
assert result < 2000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_returns_url():
|
||||
mock_session = MagicMock()
|
||||
@@ -1096,6 +1134,37 @@ async def test_handle_subscription_payment_failure_passes_invoice_id_as_transact
|
||||
assert kwargs.get("transaction_key") == "in_idempotency_test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_subscription_payment_failure_missing_invoice_id_skips():
|
||||
"""An invoice payload without an 'id' field must be skipped.
|
||||
|
||||
Without an invoice ID we cannot set an idempotency key on the credit deduction,
|
||||
so Stripe webhook retries would double-charge the user's balance. The function
|
||||
must return early before calling _add_transaction.
|
||||
"""
|
||||
mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
|
||||
invoice = {
|
||||
# No "id" field — malformed payload
|
||||
"customer": "cus_123",
|
||||
"subscription": "sub_abc123",
|
||||
"amount_due": 2000,
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.User.prisma",
|
||||
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.UserCredit._add_transaction",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_add_tx,
|
||||
):
|
||||
await handle_subscription_payment_failure(invoice)
|
||||
# Must not deduct credits when there is no invoice ID to use as an idempotency key
|
||||
mock_add_tx.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_stripe_subscription_for_tier_modifies_existing_sub():
|
||||
"""modify_stripe_subscription_for_tier calls Subscription.modify and returns True."""
|
||||
|
||||
@@ -249,10 +249,9 @@ export function SubscriptionTierSection() {
|
||||
{subscription &&
|
||||
subscription.proration_credit_cents > 0 &&
|
||||
`Your unused ${currentTier.charAt(0) + currentTier.slice(1).toLowerCase()} subscription ($${(subscription.proration_credit_cents / 100).toFixed(2)}) will be applied as a credit to your next Stripe invoice. `}
|
||||
You will be redirected to Stripe to complete your upgrade to{" "}
|
||||
{TIERS.find((t) => t.key === pendingUpgradeTier)?.label ??
|
||||
pendingUpgradeTier}
|
||||
.
|
||||
{currentTier === "FREE"
|
||||
? `You will be redirected to Stripe to complete your upgrade to ${TIERS.find((t) => t.key === pendingUpgradeTier)?.label ?? pendingUpgradeTier}.`
|
||||
: `Upgrading to ${TIERS.find((t) => t.key === pendingUpgradeTier)?.label ?? pendingUpgradeTier} will take effect immediately — Stripe will prorate your remaining balance.`}
|
||||
</p>
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
@@ -262,7 +261,9 @@ export function SubscriptionTierSection() {
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={() => void confirmUpgrade()}>
|
||||
Continue to Checkout
|
||||
{currentTier === "FREE"
|
||||
? "Continue to Checkout"
|
||||
: "Confirm Upgrade"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</Dialog.Content>
|
||||
|
||||
Reference in New Issue
Block a user