mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(platform): propagate Stripe errors in cancel_stripe_subscription
- stripe.Subscription.list() is now wrapped in try-except; StripeError is logged and re-raised so callers know the listing failed. - stripe.Subscription.cancel() StripeError is now re-raised (was swallowed), preventing set_subscription_tier from marking the user FREE when Stripe cancellation failed. - update_subscription_tier catches StripeError from cancel and returns HTTP 502 so DB tier is only updated if Stripe succeeds. - Fix test patch path: use backend.data.credit.stripe.checkout.Session.create instead of bare stripe.checkout.Session.create for import-refactor safety. - Add tests for raise-on-list-failure, raise-on-cancel-failure, and 502 route response on cancel failure. Addresses sentry[bot] comments 3061585490, 3061654688 on PR #12727.
This commit is contained in:
@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, Mock
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest_mock
|
||||
import stripe
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from prisma.enums import SubscriptionTier
|
||||
|
||||
@@ -292,3 +293,36 @@ def test_update_subscription_tier_free_with_payment_cancels_stripe(
|
||||
mock_cancel.assert_awaited_once()
|
||||
finally:
|
||||
teardown_auth(app)
|
||||
|
||||
|
||||
def test_update_subscription_tier_free_cancel_failure_returns_502(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Downgrading to FREE returns 502 if Stripe cancellation fails."""
|
||||
setup_auth(app)
|
||||
try:
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
async def mock_feature_enabled(*args, **kwargs):
|
||||
return True
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.cancel_stripe_subscription",
|
||||
side_effect=stripe.StripeError("network error"),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "FREE"})
|
||||
|
||||
assert response.status_code == 502
|
||||
finally:
|
||||
teardown_auth(app)
|
||||
|
||||
@@ -769,7 +769,13 @@ async def update_subscription_tier(
|
||||
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
|
||||
if tier == SubscriptionTier.FREE:
|
||||
if payment_enabled:
|
||||
await cancel_stripe_subscription(user_id)
|
||||
try:
|
||||
await cancel_stripe_subscription(user_id)
|
||||
except stripe.StripeError as e:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Failed to cancel Stripe subscription: {e}",
|
||||
)
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
|
||||
@@ -432,7 +432,7 @@ class UserCreditBase(ABC):
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
if current_balance >= ceiling_balance:
|
||||
raise ValueError(
|
||||
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
|
||||
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
|
||||
)
|
||||
|
||||
# Single unified atomic operation for all transaction types using UserBalance
|
||||
@@ -571,7 +571,7 @@ class UserCreditBase(ABC):
|
||||
if amount < 0 and fail_insufficient_credits:
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
raise InsufficientBalanceError(
|
||||
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
|
||||
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
|
||||
user_id=user_id,
|
||||
balance=current_balance,
|
||||
amount=amount,
|
||||
@@ -582,7 +582,6 @@ class UserCreditBase(ABC):
|
||||
|
||||
|
||||
class UserCredit(UserCreditBase):
|
||||
|
||||
async def _send_refund_notification(
|
||||
self,
|
||||
notification_request: RefundRequestData,
|
||||
@@ -734,7 +733,7 @@ class UserCredit(UserCreditBase):
|
||||
)
|
||||
if request.amount <= 0 or request.amount > transaction.amount:
|
||||
raise AssertionError(
|
||||
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
|
||||
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
|
||||
)
|
||||
|
||||
balance, _ = await self._add_transaction(
|
||||
@@ -788,12 +787,12 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
# If the user has enough balance, just let them win the dispute.
|
||||
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
|
||||
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
|
||||
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
|
||||
dispute.close()
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
f"Adding extra info for dispute from {user_id} for ${amount/100}"
|
||||
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
|
||||
)
|
||||
# Retrieve recent transaction history to support our evidence.
|
||||
# This provides a concise timeline that shows service usage and proper credit application.
|
||||
@@ -1266,11 +1265,22 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
|
||||
|
||||
|
||||
async def cancel_stripe_subscription(user_id: str) -> None:
|
||||
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
|
||||
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE).
|
||||
|
||||
Raises stripe.StripeError if any cancellation fails, so the caller can avoid
|
||||
updating the DB tier when Stripe is inconsistent.
|
||||
"""
|
||||
customer_id = await get_stripe_customer_id(user_id)
|
||||
subscriptions = stripe.Subscription.list(
|
||||
customer=customer_id, status="active", limit=10
|
||||
)
|
||||
try:
|
||||
subscriptions = stripe.Subscription.list(
|
||||
customer=customer_id, status="active", limit=10
|
||||
)
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"cancel_stripe_subscription: failed to list subscriptions for user %s",
|
||||
user_id,
|
||||
)
|
||||
raise
|
||||
for sub in subscriptions.auto_paging_iter():
|
||||
try:
|
||||
stripe.Subscription.cancel(sub["id"])
|
||||
@@ -1280,6 +1290,7 @@ async def cancel_stripe_subscription(user_id: str) -> None:
|
||||
sub["id"],
|
||||
user_id,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
|
||||
@@ -5,6 +5,7 @@ Tests for Stripe-based subscription tier billing.
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from prisma.enums import SubscriptionTier
|
||||
from prisma.models import User
|
||||
|
||||
@@ -159,6 +160,24 @@ async def test_cancel_stripe_subscription_no_active():
|
||||
mock_cancel.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_stripe_subscription_raises_on_list_failure():
|
||||
"""stripe.Subscription.list() failure propagates so DB tier is not updated."""
|
||||
with (
|
||||
patch(
|
||||
"backend.data.credit.get_stripe_customer_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value="cus_123",
|
||||
),
|
||||
patch(
|
||||
"backend.data.credit.stripe.Subscription.list",
|
||||
side_effect=stripe.StripeError("network error"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(stripe.StripeError):
|
||||
await cancel_stripe_subscription("user-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_returns_url():
|
||||
mock_session = MagicMock()
|
||||
@@ -174,7 +193,10 @@ async def test_create_subscription_checkout_returns_url():
|
||||
new_callable=AsyncMock,
|
||||
return_value="cus_123",
|
||||
),
|
||||
patch("stripe.checkout.Session.create", return_value=mock_session),
|
||||
patch(
|
||||
"backend.data.credit.stripe.checkout.Session.create",
|
||||
return_value=mock_session,
|
||||
),
|
||||
):
|
||||
url = await create_subscription_checkout(
|
||||
user_id="user-1",
|
||||
@@ -333,8 +355,8 @@ async def test_get_subscription_price_id_empty_flag_returns_none():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_stripe_subscription_handles_stripe_error():
|
||||
"""Stripe errors during cancellation should be logged, not raised."""
|
||||
async def test_cancel_stripe_subscription_raises_on_cancel_error():
|
||||
"""Stripe errors during cancellation are re-raised so the DB tier is not updated."""
|
||||
import stripe as stripe_mod
|
||||
|
||||
mock_sub = {"id": "sub_abc123"}
|
||||
@@ -356,5 +378,5 @@ async def test_cancel_stripe_subscription_handles_stripe_error():
|
||||
side_effect=stripe_mod.StripeError("network error"),
|
||||
),
|
||||
):
|
||||
# Should not raise — errors are logged as warnings
|
||||
await cancel_stripe_subscription("user-1")
|
||||
with pytest.raises(stripe_mod.StripeError):
|
||||
await cancel_stripe_subscription("user-1")
|
||||
|
||||
Reference in New Issue
Block a user