mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(backend): guard get_proration_credit_cents against creating orphaned Stripe customers
get_proration_credit_cents now checks user.stripe_customer_id before calling into Stripe, matching the same pattern applied to cancel_stripe_subscription. Admin-granted paid-tier users without a Stripe record previously triggered customer creation on every billing page load; they now get 0 immediately. Also adds tests: no_customer_id fast-path for cancel, and three proration scenarios (zero cost, no customer id, active subscription with proration calculation).
This commit is contained in:
@@ -1385,8 +1385,14 @@ async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> i
|
||||
"""
|
||||
if monthly_cost_cents <= 0:
|
||||
return 0
|
||||
# Guard: only query Stripe if the user already has a customer ID. Admin-granted
|
||||
# paid tiers have no Stripe record; calling get_stripe_customer_id would create an
|
||||
# orphaned customer on every billing-page load for those users.
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user.stripe_customer_id:
|
||||
return 0
|
||||
try:
|
||||
customer_id = await get_stripe_customer_id(user_id)
|
||||
customer_id = user.stripe_customer_id
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list, customer=customer_id, status="active", limit=1
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from prisma.models import User
|
||||
from backend.data.credit import (
|
||||
cancel_stripe_subscription,
|
||||
create_subscription_checkout,
|
||||
get_proration_credit_cents,
|
||||
handle_subscription_payment_failure,
|
||||
modify_stripe_subscription_for_tier,
|
||||
set_subscription_tier,
|
||||
@@ -299,6 +300,13 @@ async def test_sync_subscription_from_stripe_unknown_customer():
|
||||
await sync_subscription_from_stripe(stripe_sub)
|
||||
|
||||
|
||||
def _make_user_with_stripe(stripe_customer_id: str | None = "cus_123") -> MagicMock:
|
||||
"""Return a mock model.User with the given stripe_customer_id."""
|
||||
mock_user = MagicMock()
|
||||
mock_user.stripe_customer_id = stripe_customer_id
|
||||
return mock_user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_stripe_subscription_cancels_active():
|
||||
mock_subscriptions = MagicMock()
|
||||
@@ -307,9 +315,9 @@ async def test_cancel_stripe_subscription_cancels_active():
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_stripe_customer_id",
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value="cus_123",
|
||||
return_value=_make_user_with_stripe("cus_123"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
@@ -321,6 +329,19 @@ async def test_cancel_stripe_subscription_cancels_active():
|
||||
mock_modify.assert_called_once_with("sub_abc123", cancel_at_period_end=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_stripe_subscription_no_customer_id_returns_false():
|
||||
"""Users with no stripe_customer_id return False without creating a Stripe customer."""
|
||||
result = False
|
||||
with patch(
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_make_user_with_stripe(stripe_customer_id=None),
|
||||
):
|
||||
result = await cancel_stripe_subscription("user-1")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_stripe_subscription_multi_partial_failure():
|
||||
"""First modify raises → error propagates and subsequent subs are not scheduled."""
|
||||
@@ -330,9 +351,9 @@ async def test_cancel_stripe_subscription_multi_partial_failure():
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_stripe_customer_id",
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value="cus_123",
|
||||
return_value=_make_user_with_stripe("cus_123"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
@@ -368,9 +389,9 @@ async def test_cancel_stripe_subscription_no_active():
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_stripe_customer_id",
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value="cus_123",
|
||||
return_value=_make_user_with_stripe("cus_123"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
@@ -387,9 +408,9 @@ async def test_cancel_stripe_subscription_raises_on_list_failure():
|
||||
"""stripe.Subscription.list() failure propagates so DB tier is not updated."""
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_stripe_customer_id",
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value="cus_123",
|
||||
return_value=_make_user_with_stripe("cus_123"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
@@ -415,9 +436,9 @@ async def test_cancel_stripe_subscription_cancels_trialing():
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_stripe_customer_id",
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value="cus_123",
|
||||
return_value=_make_user_with_stripe("cus_123"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
@@ -444,9 +465,9 @@ async def test_cancel_stripe_subscription_cancels_active_and_trialing():
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_stripe_customer_id",
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value="cus_123",
|
||||
return_value=_make_user_with_stripe("cus_123"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
@@ -459,6 +480,62 @@ async def test_cancel_stripe_subscription_cancels_active_and_trialing():
|
||||
assert modified_ids == {"sub_active_1", "sub_trial_2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_proration_credit_cents_no_stripe_customer_returns_zero():
|
||||
"""Admin-granted tier users without stripe_customer_id get 0 without creating a customer."""
|
||||
with patch(
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_make_user_with_stripe(stripe_customer_id=None),
|
||||
) as mock_user:
|
||||
result = await get_proration_credit_cents("user-1", monthly_cost_cents=2000)
|
||||
assert result == 0
|
||||
mock_user.assert_awaited_once_with("user-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_proration_credit_cents_zero_cost_returns_zero():
|
||||
"""FREE tier users (cost=0) return 0 without calling get_user_by_id."""
|
||||
with patch(
|
||||
"backend.data.credit.get_user_by_id", new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
result = await get_proration_credit_cents("user-1", monthly_cost_cents=0)
|
||||
assert result == 0
|
||||
mock_get_user.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_proration_credit_cents_with_active_sub():
|
||||
"""User with active sub returns prorated credit based on remaining billing period."""
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
period_start = now - 15 * 24 * 3600 # 15 days ago
|
||||
period_end = now + 15 * 24 * 3600 # 15 days ahead
|
||||
mock_sub = {
|
||||
"id": "sub_abc",
|
||||
"current_period_start": period_start,
|
||||
"current_period_end": period_end,
|
||||
}
|
||||
mock_subs = MagicMock()
|
||||
mock_subs.data = [mock_sub]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_make_user_with_stripe("cus_123"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
return_value=mock_subs,
|
||||
),
|
||||
):
|
||||
result = await get_proration_credit_cents("user-1", monthly_cost_cents=2000)
|
||||
assert result > 0
|
||||
assert result < 2000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_returns_url():
|
||||
mock_session = MagicMock()
|
||||
@@ -806,9 +883,9 @@ async def test_cancel_stripe_subscription_raises_on_cancel_error():
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_stripe_customer_id",
|
||||
"backend.data.credit.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value="cus_123",
|
||||
return_value=_make_user_with_stripe("cus_123"),
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
|
||||
Reference in New Issue
Block a user