fix(platform): propagate Stripe errors in cancel_stripe_subscription

- stripe.Subscription.list() is now wrapped in try-except; StripeError
  is logged and re-raised so callers know the listing failed.
- stripe.Subscription.cancel() StripeError is now re-raised (was swallowed),
  preventing set_subscription_tier from marking the user FREE when Stripe
  cancellation failed.
- update_subscription_tier catches StripeError from cancel and returns HTTP 502
  so DB tier is only updated if Stripe succeeds.
- Fix test patch path: use backend.data.credit.stripe.checkout.Session.create
  instead of bare stripe.checkout.Session.create for import-refactor safety.
- Add tests for raise-on-list-failure, raise-on-cancel-failure, and
  502 route response on cancel failure.

Addresses sentry[bot] comments 3061585490, 3061654688 on PR #12727.
This commit is contained in:
majdyz
2026-04-10 09:22:44 +07:00
parent cc1cef7da5
commit 4b3e47fe88
4 changed files with 89 additions and 16 deletions

View File

@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, Mock
import fastapi
import fastapi.testclient
import pytest_mock
import stripe
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma.enums import SubscriptionTier
@@ -292,3 +293,36 @@ def test_update_subscription_tier_free_with_payment_cancels_stripe(
mock_cancel.assert_awaited_once()
finally:
teardown_auth(app)
def test_update_subscription_tier_free_cancel_failure_returns_502(
mocker: pytest_mock.MockFixture,
) -> None:
"""Downgrading to FREE returns 502 if Stripe cancellation fails."""
setup_auth(app)
try:
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
async def mock_feature_enabled(*args, **kwargs):
return True
mocker.patch(
"backend.api.features.v1.cancel_stripe_subscription",
side_effect=stripe.StripeError("network error"),
)
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
response = client.post("/credits/subscription", json={"tier": "FREE"})
assert response.status_code == 502
finally:
teardown_auth(app)

View File

@@ -769,7 +769,13 @@ async def update_subscription_tier(
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
if tier == SubscriptionTier.FREE:
if payment_enabled:
await cancel_stripe_subscription(user_id)
try:
await cancel_stripe_subscription(user_id)
except stripe.StripeError as e:
raise HTTPException(
status_code=502,
detail=f"Failed to cancel Stripe subscription: {e}",
)
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")

View File

@@ -432,7 +432,7 @@ class UserCreditBase(ABC):
current_balance, _ = await self._get_credits(user_id)
if current_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
)
# Single unified atomic operation for all transaction types using UserBalance
@@ -571,7 +571,7 @@ class UserCreditBase(ABC):
if amount < 0 and fail_insufficient_credits:
current_balance, _ = await self._get_credits(user_id)
raise InsufficientBalanceError(
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
user_id=user_id,
balance=current_balance,
amount=amount,
@@ -582,7 +582,6 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
@@ -734,7 +733,7 @@ class UserCredit(UserCreditBase):
)
if request.amount <= 0 or request.amount > transaction.amount:
raise AssertionError(
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
)
balance, _ = await self._add_transaction(
@@ -788,12 +787,12 @@ class UserCredit(UserCreditBase):
# If the user has enough balance, just let them win the dispute.
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
dispute.close()
return
logger.warning(
f"Adding extra info for dispute from {user_id} for ${amount/100}"
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
)
# Retrieve recent transaction history to support our evidence.
# This provides a concise timeline that shows service usage and proper credit application.
@@ -1266,11 +1265,22 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
async def cancel_stripe_subscription(user_id: str) -> None:
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE).
Raises stripe.StripeError if any cancellation fails, so the caller can avoid
updating the DB tier when Stripe is inconsistent.
"""
customer_id = await get_stripe_customer_id(user_id)
subscriptions = stripe.Subscription.list(
customer=customer_id, status="active", limit=10
)
try:
subscriptions = stripe.Subscription.list(
customer=customer_id, status="active", limit=10
)
except stripe.StripeError:
logger.warning(
"cancel_stripe_subscription: failed to list subscriptions for user %s",
user_id,
)
raise
for sub in subscriptions.auto_paging_iter():
try:
stripe.Subscription.cancel(sub["id"])
@@ -1280,6 +1290,7 @@ async def cancel_stripe_subscription(user_id: str) -> None:
sub["id"],
user_id,
)
raise
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:

View File

@@ -5,6 +5,7 @@ Tests for Stripe-based subscription tier billing.
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import stripe
from prisma.enums import SubscriptionTier
from prisma.models import User
@@ -159,6 +160,24 @@ async def test_cancel_stripe_subscription_no_active():
mock_cancel.assert_not_called()
@pytest.mark.asyncio
async def test_cancel_stripe_subscription_raises_on_list_failure():
"""stripe.Subscription.list() failure propagates so DB tier is not updated."""
with (
patch(
"backend.data.credit.get_stripe_customer_id",
new_callable=AsyncMock,
return_value="cus_123",
),
patch(
"backend.data.credit.stripe.Subscription.list",
side_effect=stripe.StripeError("network error"),
),
):
with pytest.raises(stripe.StripeError):
await cancel_stripe_subscription("user-1")
@pytest.mark.asyncio
async def test_create_subscription_checkout_returns_url():
mock_session = MagicMock()
@@ -174,7 +193,10 @@ async def test_create_subscription_checkout_returns_url():
new_callable=AsyncMock,
return_value="cus_123",
),
patch("stripe.checkout.Session.create", return_value=mock_session),
patch(
"backend.data.credit.stripe.checkout.Session.create",
return_value=mock_session,
),
):
url = await create_subscription_checkout(
user_id="user-1",
@@ -333,8 +355,8 @@ async def test_get_subscription_price_id_empty_flag_returns_none():
@pytest.mark.asyncio
async def test_cancel_stripe_subscription_handles_stripe_error():
"""Stripe errors during cancellation should be logged, not raised."""
async def test_cancel_stripe_subscription_raises_on_cancel_error():
"""Stripe errors during cancellation are re-raised so the DB tier is not updated."""
import stripe as stripe_mod
mock_sub = {"id": "sub_abc123"}
@@ -356,5 +378,5 @@ async def test_cancel_stripe_subscription_handles_stripe_error():
side_effect=stripe_mod.StripeError("network error"),
),
):
# Should not raise — errors are logged as warnings
await cancel_stripe_subscription("user-1")
with pytest.raises(stripe_mod.StripeError):
await cancel_stripe_subscription("user-1")