diff --git a/autogpt_platform/backend/backend/data/credit.py b/autogpt_platform/backend/backend/data/credit.py index ae730aa81c..ee21b907db 100644 --- a/autogpt_platform/backend/backend/data/credit.py +++ b/autogpt_platform/backend/backend/data/credit.py @@ -1716,6 +1716,9 @@ async def handle_subscription_payment_failure(invoice: dict) -> None: amount=-amount_due, transaction_type=CreditTransactionType.SUBSCRIPTION, fail_insufficient_credits=True, + # Use invoice_id as the idempotency key so that Stripe webhook retries + # (e.g. on a transient stripe.Invoice.pay failure) do not double-charge. + transaction_key=invoice_id or None, metadata=SafeJson( { "stripe_customer_id": customer_id, diff --git a/autogpt_platform/backend/backend/data/credit_subscription_test.py b/autogpt_platform/backend/backend/data/credit_subscription_test.py index 5192cda04e..8a9666b87f 100644 --- a/autogpt_platform/backend/backend/data/credit_subscription_test.py +++ b/autogpt_platform/backend/backend/data/credit_subscription_test.py @@ -988,3 +988,31 @@ async def test_handle_subscription_payment_failure_invoice_pay_error_does_not_ra ): # Must not raise — the pay failure is only logged as a warning await handle_subscription_payment_failure(invoice) + + +@pytest.mark.asyncio +async def test_handle_subscription_payment_failure_passes_invoice_id_as_transaction_key(): + """invoice_id is used as the idempotency key to prevent double-charging on webhook retries.""" + mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO) + invoice = { + "id": "in_idempotency_test", + "customer": "cus_123", + "subscription": "sub_abc123", + "amount_due": 2000, + } + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.UserCredit._add_transaction", + new_callable=AsyncMock, + ) as mock_add_tx, + patch("backend.data.credit.stripe.Invoice.pay"), + ): + await handle_subscription_payment_failure(invoice) + mock_add_tx.assert_called_once() + _, kwargs = mock_add_tx.call_args + assert kwargs.get("transaction_key") == "in_idempotency_test"