fix(backend): query trialing subs in proration + guard empty invoice_id

- get_proration_credit_cents now queries both "active" and "trialing"
  subscriptions (same pattern as modify_stripe_subscription_for_tier)
  so trial users see accurate proration credit before upgrading.
- handle_subscription_payment_failure returns early when invoice.id is
  missing — without an idempotency key, webhook retries would double-
  charge the user's credit balance on every retry cycle.
- Add tests: test_get_proration_credit_cents_with_trialing_sub and
  test_handle_subscription_payment_failure_missing_invoice_id_skips.
This commit is contained in:
Zamil Majdy
2026-04-16 18:24:41 +07:00
parent fd75467ab0
commit 9b56f2f927
3 changed files with 111 additions and 21 deletions

View File

@@ -1378,10 +1378,13 @@ async def cancel_stripe_subscription(user_id: str) -> bool:
async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int:
"""Return the prorated credit (in cents) the user would receive if they upgraded now.
Fetches the user's active Stripe subscription to determine how many seconds
remain in the current billing period, then calculates the unused portion of
Fetches the user's active or trialing Stripe subscription to determine how many
seconds remain in the current billing period, then calculates the unused portion of
the monthly cost. Returns 0 for FREE/ENTERPRISE users or when no active sub
is found.
Both ``active`` and ``trialing`` subscriptions are checked: a trialing user still
accumulates a billing period and Stripe prorates the remaining trial value on upgrade.
"""
if monthly_cost_cents <= 0:
return 0
@@ -1393,20 +1396,25 @@ async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> i
return 0
try:
customer_id = user.stripe_customer_id
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status="active", limit=1
)
if not subscriptions.data:
return 0
sub = subscriptions.data[0]
period_start: int = sub["current_period_start"]
period_end: int = sub["current_period_end"]
now = int(time.time())
total_seconds = period_end - period_start
remaining_seconds = max(period_end - now, 0)
if total_seconds <= 0:
return 0
return int(monthly_cost_cents * remaining_seconds / total_seconds)
for status in ("active", "trialing"):
subscriptions = await run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status=status,
limit=1,
)
if not subscriptions.data:
continue
sub = subscriptions.data[0]
period_start: int = sub["current_period_start"]
period_end: int = sub["current_period_end"]
now = int(time.time())
total_seconds = period_end - period_start
remaining_seconds = max(period_end - now, 0)
if total_seconds <= 0:
return 0
return int(monthly_cost_cents * remaining_seconds / total_seconds)
return 0
except Exception:
logger.warning(
"get_proration_credit_cents: failed to compute proration for user %s",
@@ -1773,6 +1781,18 @@ async def handle_subscription_payment_failure(invoice: dict) -> None:
sub_id: str = invoice.get("subscription", "")
invoice_id: str = invoice.get("id", "")
if not invoice_id:
# Without an invoice ID we cannot set an idempotency key on the credit
# deduction. Stripe webhook retries would then double-charge the user's
# balance on every retry cycle. Bail out early — a real Stripe invoice
# always carries an ID, so a missing one indicates a malformed payload.
logger.warning(
"handle_subscription_payment_failure: invoice missing 'id' for"
" customer %s; skipping to avoid non-idempotent balance deduction",
customer_id,
)
return
if amount_due <= 0:
logger.info(
"handle_subscription_payment_failure: amount_due=%d for user %s;"

View File

@@ -536,6 +536,44 @@ async def test_get_proration_credit_cents_with_active_sub():
assert result < 2000
@pytest.mark.asyncio
async def test_get_proration_credit_cents_with_trialing_sub():
"""Trialing subscriptions also have a billing period — proration must be non-zero."""
import time
now = int(time.time())
period_start = now - 5 * 24 * 3600 # 5 days ago
period_end = now + 25 * 24 * 3600 # 25 days ahead
mock_sub = {
"id": "sub_trial_abc",
"current_period_start": period_start,
"current_period_end": period_end,
}
empty_subs = MagicMock()
empty_subs.data = []
trialing_subs = MagicMock()
trialing_subs.data = [mock_sub]
def list_side_effect(*args, **kwargs):
return trialing_subs if kwargs.get("status") == "trialing" else empty_subs
with (
patch(
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value=_make_user_with_stripe("cus_123"),
),
patch(
"backend.data.credit.stripe.Subscription.list",
side_effect=list_side_effect,
),
):
result = await get_proration_credit_cents("user-1", monthly_cost_cents=2000)
# Trialing sub with ~25 days remaining should yield a significant proration credit
assert result > 0
assert result < 2000
@pytest.mark.asyncio
async def test_create_subscription_checkout_returns_url():
mock_session = MagicMock()
@@ -1096,6 +1134,37 @@ async def test_handle_subscription_payment_failure_passes_invoice_id_as_transact
assert kwargs.get("transaction_key") == "in_idempotency_test"
@pytest.mark.asyncio
async def test_handle_subscription_payment_failure_missing_invoice_id_skips():
"""An invoice payload without an 'id' field must be skipped.
Without an invoice ID we cannot set an idempotency key on the credit deduction,
so Stripe webhook retries would double-charge the user's balance. The function
must return early before calling _add_transaction.
"""
mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
invoice = {
# No "id" field — malformed payload
"customer": "cus_123",
"subscription": "sub_abc123",
"amount_due": 2000,
}
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.UserCredit._add_transaction",
new_callable=AsyncMock,
) as mock_add_tx,
):
await handle_subscription_payment_failure(invoice)
# Must not deduct credits when there is no invoice ID to use as an idempotency key
mock_add_tx.assert_not_awaited()
@pytest.mark.asyncio
async def test_modify_stripe_subscription_for_tier_modifies_existing_sub():
"""modify_stripe_subscription_for_tier calls Subscription.modify and returns True."""

View File

@@ -249,10 +249,9 @@ export function SubscriptionTierSection() {
{subscription &&
subscription.proration_credit_cents > 0 &&
`Your unused ${currentTier.charAt(0) + currentTier.slice(1).toLowerCase()} subscription ($${(subscription.proration_credit_cents / 100).toFixed(2)}) will be applied as a credit to your next Stripe invoice. `}
You will be redirected to Stripe to complete your upgrade to{" "}
{TIERS.find((t) => t.key === pendingUpgradeTier)?.label ??
pendingUpgradeTier}
.
{currentTier === "FREE"
? `You will be redirected to Stripe to complete your upgrade to ${TIERS.find((t) => t.key === pendingUpgradeTier)?.label ?? pendingUpgradeTier}.`
: `Upgrading to ${TIERS.find((t) => t.key === pendingUpgradeTier)?.label ?? pendingUpgradeTier} will take effect immediately — Stripe will prorate your remaining balance.`}
</p>
<Dialog.Footer>
<Button
@@ -262,7 +261,9 @@ export function SubscriptionTierSection() {
Cancel
</Button>
<Button onClick={() => void confirmUpgrade()}>
Continue to Checkout
{currentTier === "FREE"
? "Continue to Checkout"
: "Confirm Upgrade"}
</Button>
</Dialog.Footer>
</Dialog.Content>