fix(backend): guard get_proration_credit_cents against creating orphaned Stripe customers

get_proration_credit_cents now checks user.stripe_customer_id before calling into
Stripe, matching the same pattern applied to cancel_stripe_subscription. Admin-granted
paid-tier users without a Stripe record previously triggered customer creation on every
billing page load; they now get 0 immediately.

Also adds tests: no_customer_id fast-path for cancel, and three proration scenarios
(zero cost, no customer id, active subscription with proration calculation).
This commit is contained in:
majdyz
2026-04-15 13:42:40 +07:00
parent c421a66fa5
commit 2cdd164223
2 changed files with 98 additions and 15 deletions

View File

@@ -1385,8 +1385,14 @@ async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> i
"""
if monthly_cost_cents <= 0:
return 0
# Guard: only query Stripe if the user already has a customer ID. Admin-granted
# paid tiers have no Stripe record; calling get_stripe_customer_id would create an
# orphaned customer on every billing-page load for those users.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return 0
try:
customer_id = await get_stripe_customer_id(user_id)
customer_id = user.stripe_customer_id
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status="active", limit=1
)

View File

@@ -12,6 +12,7 @@ from prisma.models import User
from backend.data.credit import (
cancel_stripe_subscription,
create_subscription_checkout,
get_proration_credit_cents,
handle_subscription_payment_failure,
modify_stripe_subscription_for_tier,
set_subscription_tier,
@@ -299,6 +300,13 @@ async def test_sync_subscription_from_stripe_unknown_customer():
await sync_subscription_from_stripe(stripe_sub)
def _make_user_with_stripe(stripe_customer_id: str | None = "cus_123") -> MagicMock:
"""Return a mock model.User with the given stripe_customer_id."""
mock_user = MagicMock()
mock_user.stripe_customer_id = stripe_customer_id
return mock_user
@pytest.mark.asyncio
async def test_cancel_stripe_subscription_cancels_active():
mock_subscriptions = MagicMock()
@@ -307,9 +315,9 @@ async def test_cancel_stripe_subscription_cancels_active():
with (
patch(
"backend.data.credit.get_stripe_customer_id",
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value="cus_123",
return_value=_make_user_with_stripe("cus_123"),
),
patch(
"backend.data.credit.stripe.Subscription.list",
@@ -321,6 +329,19 @@ async def test_cancel_stripe_subscription_cancels_active():
mock_modify.assert_called_once_with("sub_abc123", cancel_at_period_end=True)
@pytest.mark.asyncio
async def test_cancel_stripe_subscription_no_customer_id_returns_false():
"""Users with no stripe_customer_id return False without creating a Stripe customer."""
result = False
with patch(
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value=_make_user_with_stripe(stripe_customer_id=None),
):
result = await cancel_stripe_subscription("user-1")
assert result is False
@pytest.mark.asyncio
async def test_cancel_stripe_subscription_multi_partial_failure():
"""First modify raises → error propagates and subsequent subs are not scheduled."""
@@ -330,9 +351,9 @@ async def test_cancel_stripe_subscription_multi_partial_failure():
with (
patch(
"backend.data.credit.get_stripe_customer_id",
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value="cus_123",
return_value=_make_user_with_stripe("cus_123"),
),
patch(
"backend.data.credit.stripe.Subscription.list",
@@ -368,9 +389,9 @@ async def test_cancel_stripe_subscription_no_active():
with (
patch(
"backend.data.credit.get_stripe_customer_id",
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value="cus_123",
return_value=_make_user_with_stripe("cus_123"),
),
patch(
"backend.data.credit.stripe.Subscription.list",
@@ -387,9 +408,9 @@ async def test_cancel_stripe_subscription_raises_on_list_failure():
"""stripe.Subscription.list() failure propagates so DB tier is not updated."""
with (
patch(
"backend.data.credit.get_stripe_customer_id",
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value="cus_123",
return_value=_make_user_with_stripe("cus_123"),
),
patch(
"backend.data.credit.stripe.Subscription.list",
@@ -415,9 +436,9 @@ async def test_cancel_stripe_subscription_cancels_trialing():
with (
patch(
"backend.data.credit.get_stripe_customer_id",
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value="cus_123",
return_value=_make_user_with_stripe("cus_123"),
),
patch(
"backend.data.credit.stripe.Subscription.list",
@@ -444,9 +465,9 @@ async def test_cancel_stripe_subscription_cancels_active_and_trialing():
with (
patch(
"backend.data.credit.get_stripe_customer_id",
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value="cus_123",
return_value=_make_user_with_stripe("cus_123"),
),
patch(
"backend.data.credit.stripe.Subscription.list",
@@ -459,6 +480,62 @@ async def test_cancel_stripe_subscription_cancels_active_and_trialing():
assert modified_ids == {"sub_active_1", "sub_trial_2"}
@pytest.mark.asyncio
async def test_get_proration_credit_cents_no_stripe_customer_returns_zero():
"""Admin-granted tier users without stripe_customer_id get 0 without creating a customer."""
with patch(
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value=_make_user_with_stripe(stripe_customer_id=None),
) as mock_user:
result = await get_proration_credit_cents("user-1", monthly_cost_cents=2000)
assert result == 0
mock_user.assert_awaited_once_with("user-1")
@pytest.mark.asyncio
async def test_get_proration_credit_cents_zero_cost_returns_zero():
"""FREE tier users (cost=0) return 0 without calling get_user_by_id."""
with patch(
"backend.data.credit.get_user_by_id", new_callable=AsyncMock
) as mock_get_user:
result = await get_proration_credit_cents("user-1", monthly_cost_cents=0)
assert result == 0
mock_get_user.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_proration_credit_cents_with_active_sub():
"""User with active sub returns prorated credit based on remaining billing period."""
import time
now = int(time.time())
period_start = now - 15 * 24 * 3600 # 15 days ago
period_end = now + 15 * 24 * 3600 # 15 days ahead
mock_sub = {
"id": "sub_abc",
"current_period_start": period_start,
"current_period_end": period_end,
}
mock_subs = MagicMock()
mock_subs.data = [mock_sub]
with (
patch(
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value=_make_user_with_stripe("cus_123"),
),
patch(
"backend.data.credit.stripe.Subscription.list",
return_value=mock_subs,
),
):
result = await get_proration_credit_cents("user-1", monthly_cost_cents=2000)
assert result > 0
assert result < 2000
@pytest.mark.asyncio
async def test_create_subscription_checkout_returns_url():
mock_session = MagicMock()
@@ -806,9 +883,9 @@ async def test_cancel_stripe_subscription_raises_on_cancel_error():
with (
patch(
"backend.data.credit.get_stripe_customer_id",
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value="cus_123",
return_value=_make_user_with_stripe("cus_123"),
),
patch(
"backend.data.credit.stripe.Subscription.list",