mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
90 Commits
swiftyos/r
...
feat/subsc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50324f710a | ||
|
|
cc78a92a3f | ||
|
|
f92dc0fb02 | ||
|
|
b8b5291022 | ||
|
|
9b56f2f927 | ||
|
|
fd75467ab0 | ||
|
|
3e9f161856 | ||
|
|
f645e947d1 | ||
|
|
3a66eb2548 | ||
|
|
1e1caa8810 | ||
|
|
3e654228ac | ||
|
|
697ffa81f0 | ||
|
|
3e1f886503 | ||
|
|
e2e7c85a48 | ||
|
|
2b4727e8b2 | ||
|
|
0cd0a76305 | ||
|
|
bd2efed080 | ||
|
|
5fccd8a762 | ||
|
|
d27d22159d | ||
|
|
df205b5444 | ||
|
|
4efa1c4310 | ||
|
|
3324e7199b | ||
|
|
51532c4fd1 | ||
|
|
a73ceb2838 | ||
|
|
7672722996 | ||
|
|
2f75eff082 | ||
|
|
10b92fbaa2 | ||
|
|
2cdd164223 | ||
|
|
c421a66fa5 | ||
|
|
b435814826 | ||
|
|
10bf830b59 | ||
|
|
11a5ce99f4 | ||
|
|
354de5dc0f | ||
|
|
7648aacb89 | ||
|
|
5ba14e1152 | ||
|
|
fdfda78bc8 | ||
|
|
b681363969 | ||
|
|
2d22de7aa8 | ||
|
|
f174d75e8e | ||
|
|
607854375b | ||
|
|
9151755f00 | ||
|
|
b9da535cfd | ||
|
|
befd9df446 | ||
|
|
bfd1e6e793 | ||
|
|
c477e7b92e | ||
|
|
bcbe7f4525 | ||
|
|
a118ea564e | ||
|
|
cf89b58960 | ||
|
|
6b03b8d4d8 | ||
|
|
ec65fd5c84 | ||
|
|
14e1b47b5a | ||
|
|
be8d54b331 | ||
|
|
bb52c5b10d | ||
|
|
4bd79d8f6e | ||
|
|
bfe67b6e3d | ||
|
|
46434e7402 | ||
|
|
eaa833528c | ||
|
|
62a6175d2a | ||
|
|
929c8a316c | ||
|
|
557ff84196 | ||
|
|
8a2dd8f62a | ||
|
|
52d8e67135 | ||
|
|
48f022b506 | ||
|
|
bf7f674b2f | ||
|
|
69e0a66f5e | ||
|
|
a4006fa5a1 | ||
|
|
0251bfd664 | ||
|
|
2f24091c17 | ||
|
|
8b93cea4d4 | ||
|
|
693c616bf5 | ||
|
|
6f7bf90769 | ||
|
|
ce57601305 | ||
|
|
d81bbdb870 | ||
|
|
7f6163b180 | ||
|
|
2057b4597e | ||
|
|
5bb7027f89 | ||
|
|
329a034ebe | ||
|
|
62f3ed79be | ||
|
|
54450def6b | ||
|
|
8ad5bf03a7 | ||
|
|
16c38c4dfb | ||
|
|
945297b965 | ||
|
|
6b57dc0c7f | ||
|
|
c1aec96c0f | ||
|
|
52b0e2a9a6 | ||
|
|
3ef14e9657 | ||
|
|
3c49d3373d | ||
|
|
e7e6c8f4b4 | ||
|
|
4b3e47fe88 | ||
|
|
cc1cef7da5 |
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,8 @@ import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Literal, Sequence, get_args
|
||||
from typing import Annotated, Any, Literal, Sequence, cast, get_args
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
@@ -54,8 +55,11 @@ from backend.data.credit import (
|
||||
cancel_stripe_subscription,
|
||||
create_subscription_checkout,
|
||||
get_auto_top_up,
|
||||
get_proration_credit_cents,
|
||||
get_subscription_price_id,
|
||||
get_user_credit_model,
|
||||
handle_subscription_payment_failure,
|
||||
modify_stripe_subscription_for_tier,
|
||||
set_auto_top_up,
|
||||
set_subscription_tier,
|
||||
sync_subscription_from_stripe,
|
||||
@@ -699,9 +703,72 @@ class SubscriptionCheckoutResponse(BaseModel):
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
tier: str
|
||||
monthly_cost: int
|
||||
tier_costs: dict[str, int]
|
||||
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
|
||||
monthly_cost: int # amount in cents (Stripe convention)
|
||||
tier_costs: dict[str, int] # tier name -> amount in cents
|
||||
proration_credit_cents: int # unused portion of current sub to convert on upgrade
|
||||
|
||||
|
||||
def _validate_checkout_redirect_url(url: str) -> bool:
|
||||
"""Return True if `url` matches the configured frontend origin.
|
||||
|
||||
Prevents open-redirect: attackers must not be able to supply arbitrary
|
||||
success_url/cancel_url that Stripe will redirect users to after checkout.
|
||||
|
||||
Pre-parse rejection rules (applied before urlparse):
|
||||
- Backslashes (``\\``) are normalised differently across parsers/browsers.
|
||||
- Control characters (U+0000–U+001F) are not valid in URLs and may confuse
|
||||
some URL-parsing implementations.
|
||||
"""
|
||||
# Reject characters that can confuse URL parsers before any parsing.
|
||||
if "\\" in url:
|
||||
return False
|
||||
if any(ord(c) < 0x20 for c in url):
|
||||
return False
|
||||
|
||||
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
if not allowed:
|
||||
# No configured origin — refuse to validate rather than allow arbitrary URLs.
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
allowed_parsed = urlparse(allowed)
|
||||
except ValueError:
|
||||
return False
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
|
||||
# can trick browsers into connecting to a different host than displayed.
|
||||
# ``@`` in query/fragment is harmless and must be allowed.
|
||||
if "@" in parsed.netloc:
|
||||
return False
|
||||
return (
|
||||
parsed.scheme == allowed_parsed.scheme
|
||||
and parsed.netloc == allowed_parsed.netloc
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
|
||||
async def _get_stripe_price_amount(price_id: str) -> int | None:
|
||||
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
|
||||
|
||||
Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
|
||||
of caching the ``None`` sentinel so the next request retries Stripe instead
|
||||
of being served a stale "no price" for the rest of the TTL window. Callers
|
||||
should treat ``None`` as an unknown price and fall back to 0.
|
||||
|
||||
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
|
||||
every GET /credits/subscription page load and reduces quota consumption.
|
||||
"""
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
return price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"Failed to retrieve Stripe price %s — returning None (not cached)",
|
||||
price_id,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -722,21 +789,26 @@ async def get_subscription_status(
|
||||
*[get_subscription_price_id(t) for t in paid_tiers]
|
||||
)
|
||||
|
||||
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
|
||||
for t, price_id in zip(paid_tiers, price_ids):
|
||||
cost = 0
|
||||
if price_id:
|
||||
try:
|
||||
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
|
||||
cost = price.unit_amount or 0
|
||||
except stripe.StripeError:
|
||||
pass
|
||||
tier_costs: dict[str, int] = {
|
||||
SubscriptionTier.FREE.value: 0,
|
||||
SubscriptionTier.ENTERPRISE.value: 0,
|
||||
}
|
||||
|
||||
async def _cost(pid: str | None) -> int:
|
||||
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
|
||||
|
||||
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
|
||||
for t, cost in zip(paid_tiers, costs):
|
||||
tier_costs[t.value] = cost
|
||||
|
||||
current_monthly_cost = tier_costs.get(tier.value, 0)
|
||||
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
|
||||
|
||||
return SubscriptionStatusResponse(
|
||||
tier=tier.value,
|
||||
monthly_cost=tier_costs.get(tier.value, 0),
|
||||
monthly_cost=current_monthly_cost,
|
||||
tier_costs=tier_costs,
|
||||
proration_credit_cents=proration_credit,
|
||||
)
|
||||
|
||||
|
||||
@@ -766,24 +838,125 @@ async def update_subscription_tier(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
|
||||
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
|
||||
# keeps their tier for the time they already paid for. The DB tier is NOT
|
||||
# updated here when a subscription exists — the customer.subscription.deleted
|
||||
# webhook fires at period end and downgrades to FREE then.
|
||||
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
|
||||
# tier), cancel_stripe_subscription returns False and we update the DB tier
|
||||
# immediately since no webhook will ever fire.
|
||||
# When payment is disabled entirely, update the DB tier directly.
|
||||
if tier == SubscriptionTier.FREE:
|
||||
if payment_enabled:
|
||||
await cancel_stripe_subscription(user_id)
|
||||
try:
|
||||
had_subscription = await cancel_stripe_subscription(user_id)
|
||||
except stripe.StripeError as e:
|
||||
# Log full Stripe error server-side but return a generic message
|
||||
# to the client — raw Stripe errors can leak customer/sub IDs and
|
||||
# infrastructure config details.
|
||||
logger.exception(
|
||||
"Stripe error cancelling subscription for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to cancel your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
if not had_subscription:
|
||||
# No active Stripe subscription found — the user was on an
|
||||
# admin-granted tier. Update DB immediately since the
|
||||
# subscription.deleted webhook will never fire.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Beta users (payment not enabled) → update tier directly without Stripe.
|
||||
# Paid tier changes require payment to be enabled — block self-service upgrades
|
||||
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
|
||||
if not payment_enabled:
|
||||
await set_subscription_tier(user_id, tier)
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Subscription not available for tier {tier}",
|
||||
)
|
||||
|
||||
# No-op short-circuit: if the user is already on the requested paid tier,
|
||||
# do NOT create a new Checkout Session. Without this guard, a duplicate
|
||||
# request (double-click, retried POST, stale page) creates a second
|
||||
# subscription for the same price; the user would be charged for both
|
||||
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
|
||||
# which only fires after the second charge has cleared.
|
||||
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
|
||||
# Paid upgrade → create Stripe Checkout Session.
|
||||
# Paid→paid tier change: if the user already has a Stripe subscription,
|
||||
# modify it in-place with proration instead of creating a new Checkout
|
||||
# Session. This preserves remaining paid time and avoids double-charging.
|
||||
# The customer.subscription.updated webhook fires and updates the DB tier.
|
||||
current_tier = user.subscription_tier or SubscriptionTier.FREE
|
||||
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
|
||||
try:
|
||||
modified = await modify_stripe_subscription_for_tier(user_id, tier)
|
||||
if modified:
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
# modify_stripe_subscription_for_tier returns False when no active
|
||||
# Stripe subscription exists — i.e. the user has an admin-granted
|
||||
# paid tier with no Stripe record. In that case, update the DB
|
||||
# tier directly (same as the FREE-downgrade path for admin-granted
|
||||
# users) rather than sending them through a new Checkout Session.
|
||||
await set_subscription_tier(user_id, tier)
|
||||
return SubscriptionCheckoutResponse(url="")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error modifying subscription for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to update your subscription right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
# Paid upgrade from FREE → create Stripe Checkout Session.
|
||||
if not request.success_url or not request.cancel_url:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url are required for paid tier upgrades",
|
||||
)
|
||||
# Open-redirect protection: both URLs must point to the configured frontend
|
||||
# origin, otherwise an attacker could use our Stripe integration as a
|
||||
# redirector to arbitrary phishing sites.
|
||||
#
|
||||
# Fail early with a clear 503 if the server is misconfigured (neither
|
||||
# frontend_base_url nor platform_base_url set), so operators get an
|
||||
# actionable error instead of the misleading "must match the platform
|
||||
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
|
||||
# produce when `allowed` is empty.
|
||||
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
|
||||
logger.error(
|
||||
"update_subscription_tier: neither frontend_base_url nor "
|
||||
"platform_base_url is configured; cannot validate checkout redirect URLs"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Payment redirect URLs cannot be validated: "
|
||||
"frontend_base_url or platform_base_url must be set on the server."
|
||||
),
|
||||
)
|
||||
if not _validate_checkout_redirect_url(
|
||||
request.success_url
|
||||
) or not _validate_checkout_redirect_url(request.cancel_url):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="success_url and cancel_url must match the platform frontend origin",
|
||||
)
|
||||
try:
|
||||
url = await create_subscription_checkout(
|
||||
user_id=user_id,
|
||||
@@ -791,8 +964,19 @@ async def update_subscription_tier(
|
||||
success_url=request.success_url,
|
||||
cancel_url=request.cancel_url,
|
||||
)
|
||||
except (ValueError, stripe.StripeError) as e:
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except stripe.StripeError as e:
|
||||
logger.exception(
|
||||
"Stripe error creating checkout session for user %s: %s", user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Unable to start checkout right now. "
|
||||
"Please try again or contact support."
|
||||
),
|
||||
)
|
||||
|
||||
return SubscriptionCheckoutResponse(url=url)
|
||||
|
||||
@@ -801,44 +985,78 @@ async def update_subscription_tier(
|
||||
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
||||
)
|
||||
async def stripe_webhook(request: Request):
|
||||
webhook_secret = settings.secrets.stripe_webhook_secret
|
||||
if not webhook_secret:
|
||||
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
|
||||
# signature over the same empty key). Reject all webhook calls when unconfigured.
|
||||
logger.error(
|
||||
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
|
||||
"rejecting request to prevent signature bypass"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Webhook not configured")
|
||||
|
||||
# Get the raw request body
|
||||
payload = await request.body()
|
||||
# Get the signature header
|
||||
sig_header = request.headers.get("stripe-signature")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.secrets.stripe_webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
|
||||
except ValueError:
|
||||
# Invalid payload
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
|
||||
)
|
||||
except stripe.SignatureVerificationError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||
except stripe.SignatureVerificationError:
|
||||
# Invalid signature
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
# Defensive payload extraction. A malformed payload (missing/non-dict
|
||||
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
|
||||
# AFTER signature verification — which Stripe interprets as a delivery
|
||||
# failure and retries forever, while spamming Sentry with no useful info.
|
||||
# Acknowledge with 200 and a warning so Stripe stops retrying.
|
||||
event_type = event.get("type", "")
|
||||
event_data = event.get("data") or {}
|
||||
data_object = event_data.get("object") if isinstance(event_data, dict) else None
|
||||
if not isinstance(data_object, dict):
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing or non-dict data.object; ignoring",
|
||||
event_type,
|
||||
)
|
||||
return Response(status_code=200)
|
||||
|
||||
if (
|
||||
event["type"] == "checkout.session.completed"
|
||||
or event["type"] == "checkout.session.async_payment_succeeded"
|
||||
if event_type in (
|
||||
"checkout.session.completed",
|
||||
"checkout.session.async_payment_succeeded",
|
||||
):
|
||||
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
|
||||
session_id = data_object.get("id")
|
||||
if not session_id:
|
||||
logger.warning(
|
||||
"stripe_webhook: %s missing data.object.id; ignoring", event_type
|
||||
)
|
||||
return Response(status_code=200)
|
||||
await UserCredit().fulfill_checkout(session_id=session_id)
|
||||
|
||||
if event["type"] in (
|
||||
if event_type in (
|
||||
"customer.subscription.created",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
):
|
||||
await sync_subscription_from_stripe(event["data"]["object"])
|
||||
await sync_subscription_from_stripe(data_object)
|
||||
|
||||
if event["type"] == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(event["data"]["object"])
|
||||
if event_type == "invoice.payment_failed":
|
||||
await handle_subscription_payment_failure(data_object)
|
||||
|
||||
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(event["data"]["object"])
|
||||
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
|
||||
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
|
||||
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
|
||||
# to satisfy the type checker without changing runtime behaviour.
|
||||
if event_type == "charge.dispute.created":
|
||||
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
|
||||
|
||||
if event_type == "refund.created" or event_type == "charge.dispute.closed":
|
||||
await UserCredit().deduct_credits(
|
||||
cast("stripe.Refund | stripe.Dispute", data_object)
|
||||
)
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@@ -124,15 +124,6 @@ def create_copilot_queue_config() -> RabbitMQConfig:
|
||||
vhost="/",
|
||||
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
|
||||
queues=[run_queue, cancel_queue],
|
||||
# The consumer threads sit in pika's blocking ``start_consuming()`` for
|
||||
# the full lifetime of the process. If the TCP connection is dropped
|
||||
# (server restart, NAT timeout, laptop sleep) while pika's IO thread is
|
||||
# starved, the socket rots in CLOSE_WAIT and no message is ever
|
||||
# consumed — see zombie-consumer incident notes. A short heartbeat plus
|
||||
# kernel-level TCP keepalive makes both the app and the OS notice a
|
||||
# dead peer within a couple of minutes instead of hours.
|
||||
heartbeat=60,
|
||||
tcp_keepalive=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -174,18 +174,14 @@ sandbox so `bash_exec` can access it for further processing.
|
||||
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
|
||||
|
||||
### GitHub CLI (`gh`) and git
|
||||
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before running `connect_integration(provider="github")` which will ask the user to connect their GitHub regardless if it's already connected.
|
||||
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- **MANDATORY:** You MUST run `gh auth status` before EVER calling
|
||||
`connect_integration(provider="github")`. If it shows `Logged in`,
|
||||
proceed directly — no integration connection needed. Never skip this check.
|
||||
- If `gh auth status` shows NOT logged in, or `gh`/`git` fails with an
|
||||
authentication error (e.g. "authentication required", "could not read
|
||||
Username", or exit code 128), THEN call
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
|
||||
@@ -125,7 +125,12 @@ config = ChatConfig()
|
||||
|
||||
|
||||
class _SystemPromptPreset(SystemPromptPreset, total=False):
|
||||
"""Extends SystemPromptPreset with fields added in claude-agent-sdk 0.1.59."""
|
||||
"""Extends :class:`SystemPromptPreset` with ``exclude_dynamic_sections``.
|
||||
|
||||
The field was added to the upstream TypedDict in claude-agent-sdk 0.1.59.
|
||||
Until the package is pinned to that version we declare it locally so Pyright
|
||||
accepts the kwarg without a ``# type: ignore`` comment.
|
||||
"""
|
||||
|
||||
exclude_dynamic_sections: NotRequired[bool]
|
||||
|
||||
|
||||
@@ -880,6 +880,116 @@ class TestUploadCliSession:
|
||||
assert meta_content["mode"] == "baseline"
|
||||
assert meta_content["message_count"] == 4
|
||||
|
||||
def test_strips_session_before_upload_and_writes_back(self):
|
||||
"""strip_for_upload removes progress entries and returns smaller content."""
|
||||
import json
|
||||
|
||||
from .transcript import strip_for_upload
|
||||
|
||||
progress_entry = {
|
||||
"type": "progress",
|
||||
"uuid": "p1",
|
||||
"parentUuid": "u1",
|
||||
"data": {"type": "bash_progress", "stdout": "running..."},
|
||||
}
|
||||
user_entry = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "hello"},
|
||||
}
|
||||
asst_entry = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {"role": "assistant", "content": "world"},
|
||||
}
|
||||
raw_content = (
|
||||
json.dumps(progress_entry)
|
||||
+ "\n"
|
||||
+ json.dumps(user_entry)
|
||||
+ "\n"
|
||||
+ json.dumps(asst_entry)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
stripped = strip_for_upload(raw_content)
|
||||
|
||||
stored_lines = stripped.strip().split("\n")
|
||||
stored_types = [json.loads(line).get("type") for line in stored_lines]
|
||||
assert "progress" not in stored_types
|
||||
assert "user" in stored_types
|
||||
assert "assistant" in stored_types
|
||||
assert len(stripped.encode()) < len(raw_content.encode())
|
||||
|
||||
def test_strips_stale_thinking_blocks_before_upload(self):
|
||||
"""strip_for_upload removes thinking blocks from non-last assistant turns."""
|
||||
import json
|
||||
|
||||
from .transcript import strip_for_upload
|
||||
|
||||
u1 = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "q1"},
|
||||
}
|
||||
a1_with_thinking = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"id": "msg_a1",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "A" * 5000},
|
||||
{"type": "text", "text": "answer1"},
|
||||
],
|
||||
},
|
||||
}
|
||||
u2 = {
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1",
|
||||
"message": {"role": "user", "content": "q2"},
|
||||
}
|
||||
a2_no_thinking = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "u2",
|
||||
"message": {
|
||||
"id": "msg_a2",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "answer2"}],
|
||||
},
|
||||
}
|
||||
raw_content = (
|
||||
json.dumps(u1)
|
||||
+ "\n"
|
||||
+ json.dumps(a1_with_thinking)
|
||||
+ "\n"
|
||||
+ json.dumps(u2)
|
||||
+ "\n"
|
||||
+ json.dumps(a2_no_thinking)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
stripped = strip_for_upload(raw_content)
|
||||
|
||||
stored_lines = stripped.strip().split("\n")
|
||||
|
||||
# a1 should have its thinking block stripped (it's not the last assistant turn).
|
||||
a1_stored = json.loads(stored_lines[1])
|
||||
a1_content = a1_stored["message"]["content"]
|
||||
assert all(
|
||||
b["type"] != "thinking" for b in a1_content
|
||||
), "stale thinking block should be stripped from a1"
|
||||
assert any(
|
||||
b["type"] == "text" for b in a1_content
|
||||
), "text block should be kept in a1"
|
||||
|
||||
# a2 (last turn) should be unchanged.
|
||||
a2_stored = json.loads(stored_lines[3])
|
||||
assert a2_stored["message"]["content"] == [{"type": "text", "text": "answer2"}]
|
||||
|
||||
|
||||
class TestRestoreCliSession:
|
||||
def test_returns_none_when_file_not_found_in_storage(self):
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import stripe
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from prisma.enums import (
|
||||
CreditRefundRequestStatus,
|
||||
CreditTransactionType,
|
||||
@@ -31,6 +34,7 @@ from backend.data.model import (
|
||||
from backend.data.notifications import NotificationEventModel, RefundRequestData
|
||||
from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.util.cache import cached
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled
|
||||
from backend.util.json import SafeJson, dumps
|
||||
@@ -432,7 +436,7 @@ class UserCreditBase(ABC):
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
if current_balance >= ceiling_balance:
|
||||
raise ValueError(
|
||||
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
|
||||
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
|
||||
)
|
||||
|
||||
# Single unified atomic operation for all transaction types using UserBalance
|
||||
@@ -571,7 +575,7 @@ class UserCreditBase(ABC):
|
||||
if amount < 0 and fail_insufficient_credits:
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
raise InsufficientBalanceError(
|
||||
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
|
||||
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
|
||||
user_id=user_id,
|
||||
balance=current_balance,
|
||||
amount=amount,
|
||||
@@ -582,7 +586,6 @@ class UserCreditBase(ABC):
|
||||
|
||||
|
||||
class UserCredit(UserCreditBase):
|
||||
|
||||
async def _send_refund_notification(
|
||||
self,
|
||||
notification_request: RefundRequestData,
|
||||
@@ -734,7 +737,7 @@ class UserCredit(UserCreditBase):
|
||||
)
|
||||
if request.amount <= 0 or request.amount > transaction.amount:
|
||||
raise AssertionError(
|
||||
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
|
||||
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
|
||||
)
|
||||
|
||||
balance, _ = await self._add_transaction(
|
||||
@@ -788,12 +791,12 @@ class UserCredit(UserCreditBase):
|
||||
|
||||
# If the user has enough balance, just let them win the dispute.
|
||||
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
|
||||
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
|
||||
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
|
||||
dispute.close()
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
f"Adding extra info for dispute from {user_id} for ${amount/100}"
|
||||
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
|
||||
)
|
||||
# Retrieve recent transaction history to support our evidence.
|
||||
# This provides a concise timeline that shows service usage and proper credit application.
|
||||
@@ -1237,14 +1240,23 @@ async def get_stripe_customer_id(user_id: str) -> str:
|
||||
if user.stripe_customer_id:
|
||||
return user.stripe_customer_id
|
||||
|
||||
customer = stripe.Customer.create(
|
||||
# Race protection: two concurrent calls (e.g. user double-clicks "Upgrade",
|
||||
# or any retried request) would each pass the check above and create their
|
||||
# own Stripe Customer, leaving an orphaned billable customer in Stripe.
|
||||
# Pass an idempotency_key so Stripe collapses concurrent + retried calls
|
||||
# into the same Customer object server-side. The 24h Stripe idempotency
|
||||
# window comfortably covers any realistic in-flight retry scenario.
|
||||
customer = await run_in_threadpool(
|
||||
stripe.Customer.create,
|
||||
name=user.name or "",
|
||||
email=user.email,
|
||||
metadata={"user_id": user_id},
|
||||
idempotency_key=f"customer-create-{user_id}",
|
||||
)
|
||||
await User.prisma().update(
|
||||
where={"id": user_id}, data={"stripeCustomerId": customer.id}
|
||||
)
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
return customer.id
|
||||
|
||||
|
||||
@@ -1263,23 +1275,211 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
|
||||
data={"subscriptionTier": tier},
|
||||
)
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
# Also invalidate the rate-limit tier cache so CoPilot picks up the new
|
||||
# tier immediately rather than waiting up to 5 minutes for the TTL to expire.
|
||||
from backend.copilot.rate_limit import get_user_tier # local import avoids circular
|
||||
|
||||
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def cancel_stripe_subscription(user_id: str) -> None:
|
||||
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
|
||||
customer_id = await get_stripe_customer_id(user_id)
|
||||
subscriptions = stripe.Subscription.list(
|
||||
customer=customer_id, status="active", limit=10
|
||||
)
|
||||
for sub in subscriptions.auto_paging_iter():
|
||||
try:
|
||||
stripe.Subscription.cancel(sub["id"])
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"cancel_stripe_subscription: failed to cancel sub %s for user %s",
|
||||
sub["id"],
|
||||
user_id,
|
||||
async def _cancel_customer_subscriptions(
|
||||
customer_id: str,
|
||||
exclude_sub_id: str | None = None,
|
||||
at_period_end: bool = False,
|
||||
) -> int:
|
||||
"""Cancel all billable Stripe subscriptions for a customer, optionally excluding one.
|
||||
|
||||
Cancels both ``active`` and ``trialing`` subscriptions, since trialing subs will
|
||||
start billing once the trial ends and must be cleaned up on downgrade/upgrade to
|
||||
avoid double-charging or charging users who intended to cancel.
|
||||
|
||||
When ``at_period_end=True``, schedules cancellation at the end of the current
|
||||
billing period instead of cancelling immediately — the user keeps their tier
|
||||
until the period ends, then ``customer.subscription.deleted`` fires and the
|
||||
webhook downgrades them to FREE.
|
||||
|
||||
Wraps every synchronous Stripe SDK call with run_in_threadpool so the async event
|
||||
loop is never blocked. Raises stripe.StripeError on list/cancel failure so callers
|
||||
that need strict consistency can react; cleanup callers can catch and log instead.
|
||||
|
||||
Returns the number of subscriptions cancelled/scheduled for cancellation.
|
||||
"""
|
||||
# Query active and trialing separately; Stripe's list API accepts a single status
|
||||
# filter at a time (no OR), and we explicitly want to skip canceled/incomplete/
|
||||
# past_due subs rather than filter them out client-side via status="all".
|
||||
seen_ids: set[str] = set()
|
||||
for status in ("active", "trialing"):
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list, customer=customer_id, status=status, limit=10
|
||||
)
|
||||
# Iterate only the first page (up to 10); avoid auto_paging_iter which would
|
||||
# trigger additional sync HTTP calls inside the event loop.
|
||||
if subscriptions.has_more:
|
||||
logger.error(
|
||||
"_cancel_customer_subscriptions: customer %s has more than 10 %s"
|
||||
" subscriptions — only the first page was processed; remaining"
|
||||
" subscriptions were NOT cancelled",
|
||||
customer_id,
|
||||
status,
|
||||
)
|
||||
for sub in subscriptions.data:
|
||||
sub_id = sub["id"]
|
||||
if exclude_sub_id and sub_id == exclude_sub_id:
|
||||
continue
|
||||
if sub_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(sub_id)
|
||||
if at_period_end:
|
||||
await run_in_threadpool(
|
||||
stripe.Subscription.modify, sub_id, cancel_at_period_end=True
|
||||
)
|
||||
else:
|
||||
await run_in_threadpool(stripe.Subscription.cancel, sub_id)
|
||||
return len(seen_ids)
|
||||
|
||||
|
||||
async def cancel_stripe_subscription(user_id: str) -> bool:
|
||||
"""Schedule cancellation of all active/trialing Stripe subscriptions at period end.
|
||||
|
||||
The subscription stays active until the end of the billing period so the user
|
||||
keeps their tier for the time they already paid for. The ``customer.subscription.deleted``
|
||||
webhook fires at period end and downgrades the DB tier to FREE.
|
||||
|
||||
Returns True if at least one subscription was found and scheduled for cancellation,
|
||||
False if the customer had no active/trialing subscriptions (e.g., admin-granted tier
|
||||
with no associated Stripe subscription). When False, the caller should update the
|
||||
DB tier directly since no webhook will fire to do it.
|
||||
|
||||
Raises stripe.StripeError if any modification fails, so the caller can avoid
|
||||
updating the DB tier when Stripe is inconsistent.
|
||||
"""
|
||||
# Guard: only proceed if the user already has a Stripe customer ID. Calling
|
||||
# get_stripe_customer_id for a user who has never had a paid subscription would
|
||||
# create an orphaned, potentially-billable Stripe Customer object — we avoid that
|
||||
# by returning False early so the caller can downgrade the DB tier directly.
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user.stripe_customer_id:
|
||||
return False
|
||||
|
||||
customer_id = user.stripe_customer_id
|
||||
try:
|
||||
cancelled_count = await _cancel_customer_subscriptions(
|
||||
customer_id, at_period_end=True
|
||||
)
|
||||
return cancelled_count > 0
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"cancel_stripe_subscription: Stripe error while cancelling subs for user %s",
|
||||
user_id,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int:
|
||||
"""Return the prorated credit (in cents) the user would receive if they upgraded now.
|
||||
|
||||
Fetches the user's active or trialing Stripe subscription to determine how many
|
||||
seconds remain in the current billing period, then calculates the unused portion of
|
||||
the monthly cost. Returns 0 for FREE/ENTERPRISE users or when no active sub
|
||||
is found.
|
||||
|
||||
Both ``active`` and ``trialing`` subscriptions are checked: a trialing user still
|
||||
accumulates a billing period and Stripe prorates the remaining trial value on upgrade.
|
||||
"""
|
||||
if monthly_cost_cents <= 0:
|
||||
return 0
|
||||
# Guard: only query Stripe if the user already has a customer ID. Admin-granted
|
||||
# paid tiers have no Stripe record; calling get_stripe_customer_id would create an
|
||||
# orphaned customer on every billing-page load for those users.
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user.stripe_customer_id:
|
||||
return 0
|
||||
try:
|
||||
customer_id = user.stripe_customer_id
|
||||
for status in ("active", "trialing"):
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list,
|
||||
customer=customer_id,
|
||||
status=status,
|
||||
limit=1,
|
||||
)
|
||||
if not subscriptions.data:
|
||||
continue
|
||||
sub = subscriptions.data[0]
|
||||
period_start: int = sub["current_period_start"]
|
||||
period_end: int = sub["current_period_end"]
|
||||
now = int(time.time())
|
||||
total_seconds = period_end - period_start
|
||||
remaining_seconds = max(period_end - now, 0)
|
||||
if total_seconds <= 0:
|
||||
return 0
|
||||
return int(monthly_cost_cents * remaining_seconds / total_seconds)
|
||||
return 0
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"get_proration_credit_cents: failed to compute proration for user %s",
|
||||
user_id,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
async def modify_stripe_subscription_for_tier(
|
||||
user_id: str, tier: SubscriptionTier
|
||||
) -> bool:
|
||||
"""Modify an existing Stripe subscription to a new paid tier using proration.
|
||||
|
||||
For paid→paid tier changes (e.g. PRO↔BUSINESS), modifying the existing
|
||||
subscription is preferable to cancelling + creating a new one via Checkout:
|
||||
Stripe handles proration automatically, crediting unused time on the old plan
|
||||
and charging the pro-rated amount for the new plan in the same billing cycle.
|
||||
|
||||
Returns:
|
||||
True — a subscription was found and modified successfully.
|
||||
False — no active/trialing subscription exists (e.g. admin-granted tier or
|
||||
first-time paid signup); caller should fall back to Checkout.
|
||||
|
||||
Raises stripe.StripeError on API failures so callers can propagate a 502.
|
||||
Raises ValueError when no Stripe price ID is configured for the tier.
|
||||
"""
|
||||
price_id = await get_subscription_price_id(tier)
|
||||
if not price_id:
|
||||
raise ValueError(f"No Stripe price ID configured for tier {tier}")
|
||||
|
||||
# Guard: only proceed if the user already has a Stripe customer ID. Calling
|
||||
# get_stripe_customer_id for a user with no Stripe record (e.g. admin-granted tier)
|
||||
# would create an orphaned customer object if the subsequent Subscription.list call
|
||||
# fails. Return False early so the API layer falls back to Checkout instead.
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user.stripe_customer_id:
|
||||
return False
|
||||
|
||||
customer_id = user.stripe_customer_id
|
||||
for status in ("active", "trialing"):
|
||||
subscriptions = await run_in_threadpool(
|
||||
stripe.Subscription.list, customer=customer_id, status=status, limit=1
|
||||
)
|
||||
if not subscriptions.data:
|
||||
continue
|
||||
sub = subscriptions.data[0]
|
||||
sub_id = sub["id"]
|
||||
items = sub.get("items", {}).get("data", [])
|
||||
if not items:
|
||||
continue
|
||||
item_id = items[0]["id"]
|
||||
await run_in_threadpool(
|
||||
stripe.Subscription.modify,
|
||||
sub_id,
|
||||
items=[{"id": item_id, "price": price_id}],
|
||||
proration_behavior="create_prorations",
|
||||
)
|
||||
logger.info(
|
||||
"modify_stripe_subscription_for_tier: modified sub %s for user %s → %s",
|
||||
sub_id,
|
||||
user_id,
|
||||
tier,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
@@ -1291,8 +1491,19 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
return AutoTopUpConfig.model_validate(user.top_up_config)
|
||||
|
||||
|
||||
@cached(ttl_seconds=60, maxsize=8, cache_none=False)
|
||||
async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
|
||||
"""Return Stripe Price ID for a tier from LaunchDarkly. None = not configured."""
|
||||
"""Return Stripe Price ID for a tier from LaunchDarkly, cached for 60 seconds.
|
||||
|
||||
Price IDs are LaunchDarkly flag values that change only at deploy time.
|
||||
Caching for 60 seconds avoids hitting the LD SDK on every webhook delivery
|
||||
and every GET /credits/subscription page load (called 2x per request).
|
||||
|
||||
``cache_none=False`` prevents a transient LD failure from caching ``None``
|
||||
and blocking subscription upgrades for the full 60-second TTL window.
|
||||
A tier with no configured flag (FREE, ENTERPRISE) returns ``None`` from an
|
||||
O(1) dict lookup before hitting LD, so the extra LD call is never made.
|
||||
"""
|
||||
flag_map = {
|
||||
SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO,
|
||||
SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS,
|
||||
@@ -1300,7 +1511,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
|
||||
flag = flag_map.get(tier)
|
||||
if flag is None:
|
||||
return None
|
||||
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
|
||||
price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
|
||||
return price_id if isinstance(price_id, str) and price_id else None
|
||||
|
||||
|
||||
@@ -1315,7 +1526,8 @@ async def create_subscription_checkout(
|
||||
if not price_id:
|
||||
raise ValueError(f"Subscription not available for tier {tier.value}")
|
||||
customer_id = await get_stripe_customer_id(user_id)
|
||||
session = stripe.checkout.Session.create(
|
||||
session = await run_in_threadpool(
|
||||
stripe.checkout.Session.create,
|
||||
customer=customer_id,
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
@@ -1323,26 +1535,110 @@ async def create_subscription_checkout(
|
||||
cancel_url=cancel_url,
|
||||
subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
|
||||
)
|
||||
return session.url or ""
|
||||
if not session.url:
|
||||
# An empty checkout URL for a paid upgrade is always an error; surfacing it
|
||||
# as ValueError means the API handler returns 422 instead of silently
|
||||
# redirecting the client to an empty URL.
|
||||
raise ValueError("Stripe did not return a checkout session URL")
|
||||
return session.url
|
||||
|
||||
|
||||
async def _cleanup_stale_subscriptions(customer_id: str, new_sub_id: str) -> None:
|
||||
"""Best-effort cancel of any active subs for the customer other than new_sub_id.
|
||||
|
||||
Called from the webhook handler after a new subscription becomes active. Failures
|
||||
are logged but not raised so a transient Stripe error doesn't crash the webhook —
|
||||
a periodic reconciliation job is the intended backstop for persistent drift.
|
||||
|
||||
NOTE: until that reconcile job lands, a failure here means the user is silently
|
||||
billed for two simultaneous subscriptions. The error log below is intentionally
|
||||
`logger.exception` (not `logger.warning`) so it surfaces as an error in Sentry
|
||||
with the customer/sub IDs needed for manual reconciliation.
|
||||
TODO(#stripe-reconcile-job): replace this best-effort cleanup with a periodic
|
||||
reconciliation job that queries Stripe for customers with >1 active sub.
|
||||
"""
|
||||
try:
|
||||
await _cancel_customer_subscriptions(customer_id, exclude_sub_id=new_sub_id)
|
||||
except stripe.StripeError:
|
||||
# Use exception() (not warning) so this surfaces as an error in Sentry —
|
||||
# any failure here means a paid-to-paid upgrade may have left the user
|
||||
# with two simultaneous active subscriptions.
|
||||
logger.exception(
|
||||
"stripe_stale_subscription_cleanup_failed: customer=%s new_sub=%s —"
|
||||
" user may be billed for two simultaneous subscriptions; manual"
|
||||
" reconciliation required",
|
||||
customer_id,
|
||||
new_sub_id,
|
||||
)
|
||||
|
||||
|
||||
async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
|
||||
"""Update User.subscriptionTier from a Stripe subscription object."""
|
||||
customer_id = stripe_subscription["customer"]
|
||||
"""Update User.subscriptionTier from a Stripe subscription object.
|
||||
|
||||
Expected shape of stripe_subscription (subset of Stripe's Subscription object):
|
||||
customer: str — Stripe customer ID
|
||||
status: str — "active" | "trialing" | "canceled" | ...
|
||||
id: str — Stripe subscription ID
|
||||
items.data[].price.id: str — Stripe price ID identifying the tier
|
||||
"""
|
||||
customer_id = stripe_subscription.get("customer")
|
||||
if not customer_id:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: missing 'customer' field in event, "
|
||||
"skipping (keys: %s)",
|
||||
list(stripe_subscription.keys()),
|
||||
)
|
||||
return
|
||||
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
|
||||
if not user:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: no user for customer %s", customer_id
|
||||
)
|
||||
return
|
||||
# Cross-check: if the subscription carries a metadata.user_id (set during
|
||||
# Checkout Session creation), verify it matches the user we found via
|
||||
# stripeCustomerId. A mismatch indicates a customer↔user mapping
|
||||
# inconsistency — updating the wrong user's tier would be a data-corruption
|
||||
# bug, so we log loudly and bail out. Absence of metadata.user_id (e.g.
|
||||
# subscriptions created outside the Checkout flow) is not an error — we
|
||||
# simply skip the check and proceed with the customer-ID-based lookup.
|
||||
metadata = stripe_subscription.get("metadata") or {}
|
||||
metadata_user_id = metadata.get("user_id") if isinstance(metadata, dict) else None
|
||||
if metadata_user_id and metadata_user_id != user.id:
|
||||
logger.error(
|
||||
"sync_subscription_from_stripe: metadata.user_id=%s does not match"
|
||||
" user.id=%s found via stripeCustomerId=%s — refusing to update tier"
|
||||
" to avoid corrupting the wrong user's subscription state",
|
||||
metadata_user_id,
|
||||
user.id,
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
# ENTERPRISE tiers are admin-managed. Never let a Stripe webhook flip an
|
||||
# ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has
|
||||
# a self-service Stripe sub, it's a data-consistency issue for an operator,
|
||||
# not something the webhook should automatically "fix".
|
||||
current_tier = user.subscriptionTier or SubscriptionTier.FREE
|
||||
if current_tier == SubscriptionTier.ENTERPRISE:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier"
|
||||
" for user %s (customer %s); event status=%s",
|
||||
user.id,
|
||||
customer_id,
|
||||
stripe_subscription.get("status", ""),
|
||||
)
|
||||
return
|
||||
status = stripe_subscription.get("status", "")
|
||||
new_sub_id = stripe_subscription.get("id", "")
|
||||
if status in ("active", "trialing"):
|
||||
price_id = ""
|
||||
items = stripe_subscription.get("items", {}).get("data", [])
|
||||
if items:
|
||||
price_id = items[0].get("price", {}).get("id", "")
|
||||
pro_price = await get_subscription_price_id(SubscriptionTier.PRO)
|
||||
biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS)
|
||||
pro_price, biz_price = await asyncio.gather(
|
||||
get_subscription_price_id(SubscriptionTier.PRO),
|
||||
get_subscription_price_id(SubscriptionTier.BUSINESS),
|
||||
)
|
||||
if price_id and pro_price and price_id == pro_price:
|
||||
tier = SubscriptionTier.PRO
|
||||
elif price_id and biz_price and price_id == biz_price:
|
||||
@@ -1359,10 +1655,219 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
|
||||
)
|
||||
return
|
||||
else:
|
||||
# A subscription was cancelled or ended. DO NOT unconditionally downgrade
|
||||
# to FREE — Stripe does not guarantee webhook delivery order, so a
|
||||
# `customer.subscription.deleted` for the OLD sub can arrive after we've
|
||||
# already processed `customer.subscription.created` for a new paid sub.
|
||||
# Ask Stripe whether any OTHER active/trialing subs exist for this
|
||||
# customer; if they do, keep the user's current tier (the other sub's
|
||||
# own event will/has already set the correct tier).
|
||||
try:
|
||||
other_subs_active, other_subs_trialing = await asyncio.gather(
|
||||
run_in_threadpool(
|
||||
stripe.Subscription.list,
|
||||
customer=customer_id,
|
||||
status="active",
|
||||
limit=10,
|
||||
),
|
||||
run_in_threadpool(
|
||||
stripe.Subscription.list,
|
||||
customer=customer_id,
|
||||
status="trialing",
|
||||
limit=10,
|
||||
),
|
||||
)
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: could not verify other active"
|
||||
" subs for customer %s on cancel event %s; preserving current"
|
||||
" tier to avoid an unsafe downgrade",
|
||||
customer_id,
|
||||
new_sub_id,
|
||||
)
|
||||
return
|
||||
# Filter out the cancelled subscription to check if other active subs
|
||||
# exist. When new_sub_id is empty (malformed event with no 'id' field),
|
||||
# we cannot safely exclude any sub — preserve current tier to avoid
|
||||
# an unsafe downgrade on a malformed webhook payload.
|
||||
if not new_sub_id:
|
||||
logger.warning(
|
||||
"sync_subscription_from_stripe: cancel event missing 'id' field"
|
||||
" for customer %s; preserving current tier",
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
other_active_ids = {sub["id"] for sub in other_subs_active.data} - {new_sub_id}
|
||||
other_trialing_ids = {sub["id"] for sub in other_subs_trialing.data} - {
|
||||
new_sub_id
|
||||
}
|
||||
still_has_active_sub = bool(other_active_ids or other_trialing_ids)
|
||||
if still_has_active_sub:
|
||||
logger.info(
|
||||
"sync_subscription_from_stripe: sub %s cancelled but customer %s"
|
||||
" still has another active sub; keeping tier %s",
|
||||
new_sub_id,
|
||||
customer_id,
|
||||
current_tier.value,
|
||||
)
|
||||
return
|
||||
tier = SubscriptionTier.FREE
|
||||
# Idempotency: Stripe retries webhooks on delivery failure, and several event
|
||||
# types map to the same final tier. Skip the DB write + cache invalidation
|
||||
# when the tier is already correct to avoid redundant writes on replay.
|
||||
if current_tier == tier:
|
||||
return
|
||||
# When a new subscription becomes active (e.g. paid-to-paid tier upgrade
|
||||
# via a fresh Checkout Session), cancel any OTHER active subscriptions for
|
||||
# the same customer so the user isn't billed twice. We do this in the
|
||||
# webhook rather than the API handler so that abandoning the checkout
|
||||
# doesn't leave the user without a subscription.
|
||||
# IMPORTANT: this runs AFTER the idempotency check above so that webhook
|
||||
# replays for an already-applied event do NOT trigger another cleanup round
|
||||
# (which could otherwise cancel a legitimately new subscription the user
|
||||
# signed up for between the original event and its replay).
|
||||
if status in ("active", "trialing") and new_sub_id:
|
||||
# NOTE: paid-to-paid upgrade race (e.g. PRO → BUSINESS):
|
||||
# _cleanup_stale_subscriptions cancels the old PRO sub before
|
||||
# set_subscription_tier writes BUSINESS to the DB. If Stripe delivers
|
||||
# the PRO `customer.subscription.deleted` event concurrently and it
|
||||
# processes after the PRO cancel but before set_subscription_tier
|
||||
# commits, the user could momentarily appear as FREE in the DB.
|
||||
# This window is very short in practice (two sequential awaits),
|
||||
# but is a known limitation of the current webhook-driven approach.
|
||||
# A future improvement would be to write the new tier first, then
|
||||
# cancel the old sub.
|
||||
await _cleanup_stale_subscriptions(customer_id, new_sub_id)
|
||||
await set_subscription_tier(user.id, tier)
|
||||
|
||||
|
||||
async def handle_subscription_payment_failure(invoice: dict) -> None:
|
||||
"""Handle a failed Stripe subscription payment.
|
||||
|
||||
Tries to cover the invoice amount from the user's credit balance.
|
||||
|
||||
- Balance sufficient → deduct from balance, then pay the Stripe invoice so
|
||||
Stripe stops retrying it. The sub stays intact and the user keeps their tier.
|
||||
- Balance insufficient → cancel Stripe sub immediately, downgrade to FREE.
|
||||
Cancelling here avoids further Stripe retries on an invoice we cannot cover.
|
||||
"""
|
||||
customer_id = invoice.get("customer")
|
||||
if not customer_id:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: missing customer in invoice; skipping"
|
||||
)
|
||||
return
|
||||
|
||||
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
|
||||
if not user:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: no user found for customer %s",
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
|
||||
current_tier = user.subscriptionTier or SubscriptionTier.FREE
|
||||
if current_tier == SubscriptionTier.ENTERPRISE:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: skipping ENTERPRISE user %s"
|
||||
" (customer %s) — tier is admin-managed",
|
||||
user.id,
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
|
||||
amount_due: int = invoice.get("amount_due", 0)
|
||||
sub_id: str = invoice.get("subscription", "")
|
||||
invoice_id: str = invoice.get("id", "")
|
||||
|
||||
if not invoice_id:
|
||||
# Without an invoice ID we cannot set an idempotency key on the credit
|
||||
# deduction. Stripe webhook retries would then double-charge the user's
|
||||
# balance on every retry cycle. Bail out early — a real Stripe invoice
|
||||
# always carries an ID, so a missing one indicates a malformed payload.
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: invoice missing 'id' for"
|
||||
" customer %s; skipping to avoid non-idempotent balance deduction",
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
|
||||
if amount_due <= 0:
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: amount_due=%d for user %s;"
|
||||
" nothing to deduct",
|
||||
amount_due,
|
||||
user.id,
|
||||
)
|
||||
return
|
||||
|
||||
credit_model = UserCredit()
|
||||
try:
|
||||
await credit_model._add_transaction(
|
||||
user_id=user.id,
|
||||
amount=-amount_due,
|
||||
transaction_type=CreditTransactionType.SUBSCRIPTION,
|
||||
fail_insufficient_credits=True,
|
||||
# Use invoice_id as the idempotency key so that Stripe webhook retries
|
||||
# (e.g. on a transient stripe.Invoice.pay failure) do not double-charge.
|
||||
# invoice_id is guaranteed non-empty here (early-return guard above).
|
||||
transaction_key=invoice_id,
|
||||
metadata=SafeJson(
|
||||
{
|
||||
"stripe_customer_id": customer_id,
|
||||
"stripe_subscription_id": sub_id,
|
||||
"reason": "subscription_payment_failure_covered_by_balance",
|
||||
}
|
||||
),
|
||||
)
|
||||
# Balance covered the invoice. Pay the Stripe invoice so Stripe's dunning
|
||||
# system stops retrying it — without this call Stripe would retry automatically
|
||||
# and re-trigger this webhook, causing double-deductions each retry cycle.
|
||||
# invoice_id is guaranteed non-empty here (early-return guard above).
|
||||
try:
|
||||
await run_in_threadpool(stripe.Invoice.pay, invoice_id)
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: balance deducted for user"
|
||||
" %s but failed to mark invoice %s as paid; Stripe may retry",
|
||||
user.id,
|
||||
invoice_id,
|
||||
)
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: deducted %d cents from balance"
|
||||
" for user %s; Stripe invoice %s paid, sub %s intact, tier preserved",
|
||||
amount_due,
|
||||
user.id,
|
||||
invoice_id,
|
||||
sub_id,
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
# Balance insufficient — cancel Stripe subscription first, then downgrade DB.
|
||||
# Order matters: if we downgrade the DB first and the Stripe cancel fails, the
|
||||
# user is permanently stuck on FREE while Stripe continues billing them.
|
||||
# Cancelling Stripe first is safe: if the DB write then fails, the webhook
|
||||
# customer.subscription.deleted will fire and correct the tier eventually.
|
||||
logger.info(
|
||||
"handle_subscription_payment_failure: insufficient balance for user %s;"
|
||||
" cancelling Stripe sub %s then downgrading to FREE",
|
||||
user.id,
|
||||
sub_id,
|
||||
)
|
||||
try:
|
||||
await _cancel_customer_subscriptions(customer_id)
|
||||
except stripe.StripeError:
|
||||
logger.warning(
|
||||
"handle_subscription_payment_failure: failed to cancel Stripe sub %s"
|
||||
" for user %s (customer %s); skipping tier downgrade to avoid"
|
||||
" inconsistency — Stripe may continue retrying the invoice",
|
||||
sub_id,
|
||||
user.id,
|
||||
customer_id,
|
||||
)
|
||||
return
|
||||
await set_subscription_tier(user.id, SubscriptionTier.FREE)
|
||||
|
||||
|
||||
async def admin_get_user_history(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Optional
|
||||
@@ -43,39 +42,6 @@ CONNECTION_ATTEMPTS = 5
|
||||
# Use case: Faster reconnection for long-running executions that need to resume quickly
|
||||
RETRY_DELAY = 1
|
||||
|
||||
# DEFAULT_HEARTBEAT (300s = 5 min)
|
||||
# AMQP application-level heartbeat. Server drops the connection if no heartbeat
|
||||
# is seen within ~2x this interval. Consumers that sit in CLOSE_WAIT because
|
||||
# pika's IO loop was starved (e.g. laptop sleep, blocking main thread) recover
|
||||
# faster with a lower value. See `create_copilot_queue_config` for a case that
|
||||
# overrides this.
|
||||
DEFAULT_HEARTBEAT = 300
|
||||
|
||||
|
||||
def _tcp_keepalive_options() -> dict[str, int]:
|
||||
"""Platform-aware TCP keepalive socket options for pika.
|
||||
|
||||
pika enables ``SO_KEEPALIVE`` on every socket by default; this dict tunes
|
||||
how quickly the kernel declares a silent peer dead. Without these knobs,
|
||||
the OS default on Linux is ~2 hours of idle before the first probe — long
|
||||
enough for a half-closed socket to sit in CLOSE_WAIT forever while the
|
||||
consumer thread is blocked inside ``start_consuming()``.
|
||||
|
||||
pika passes each key through ``getattr(socket, key)`` at ``IPPROTO_TCP``
|
||||
level, so names must exist on the current platform. Linux has
|
||||
``TCP_KEEPIDLE``; macOS uses ``TCP_KEEPALIVE`` for the equivalent knob.
|
||||
"""
|
||||
opts: dict[str, int] = {}
|
||||
if hasattr(socket, "TCP_KEEPIDLE"):
|
||||
opts["TCP_KEEPIDLE"] = 60
|
||||
elif hasattr(socket, "TCP_KEEPALIVE"):
|
||||
opts["TCP_KEEPALIVE"] = 60
|
||||
if hasattr(socket, "TCP_KEEPINTVL"):
|
||||
opts["TCP_KEEPINTVL"] = 20
|
||||
if hasattr(socket, "TCP_KEEPCNT"):
|
||||
opts["TCP_KEEPCNT"] = 3
|
||||
return opts
|
||||
|
||||
|
||||
class ExchangeType(str, Enum):
|
||||
DIRECT = "direct"
|
||||
@@ -107,8 +73,6 @@ class RabbitMQConfig(BaseModel):
|
||||
vhost: str = "/"
|
||||
exchanges: list[Exchange]
|
||||
queues: list[Queue]
|
||||
heartbeat: int = DEFAULT_HEARTBEAT
|
||||
tcp_keepalive: bool = False
|
||||
|
||||
|
||||
class RabbitMQBase(ABC):
|
||||
@@ -177,8 +141,7 @@ class SyncRabbitMQ(RabbitMQBase):
|
||||
socket_timeout=SOCKET_TIMEOUT,
|
||||
connection_attempts=CONNECTION_ATTEMPTS,
|
||||
retry_delay=RETRY_DELAY,
|
||||
heartbeat=self.config.heartbeat,
|
||||
tcp_options=_tcp_keepalive_options() if self.config.tcp_keepalive else None,
|
||||
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
|
||||
)
|
||||
|
||||
self._connection = pika.BlockingConnection(parameters)
|
||||
@@ -297,7 +260,7 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
password=self.password,
|
||||
virtualhost=self.config.vhost.lstrip("/"),
|
||||
blocked_connection_timeout=BLOCKED_CONNECTION_TIMEOUT,
|
||||
heartbeat=self.config.heartbeat,
|
||||
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
|
||||
)
|
||||
self._channel = await self._connection.channel()
|
||||
await self._channel.set_qos(prefetch_count=1)
|
||||
|
||||
@@ -73,6 +73,31 @@ def _get_redis() -> Redis:
|
||||
return r
|
||||
|
||||
|
||||
class _MissingType:
|
||||
"""Singleton sentinel type — distinct from ``None`` (a valid cached value).
|
||||
|
||||
Using a dedicated class (instead of ``Any = object()``) lets mypy prove
|
||||
that comparisons ``result is _MISSING`` narrow the type correctly and
|
||||
prevents accidental use of the sentinel where a real value is expected.
|
||||
"""
|
||||
|
||||
_instance: "_MissingType | None" = None
|
||||
|
||||
def __new__(cls) -> "_MissingType":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<MISSING>"
|
||||
|
||||
|
||||
# Sentinel returned by ``_get_from_memory`` / ``_get_from_redis`` to mean
|
||||
# "no entry exists" — distinct from a cached ``None`` value, which is a
|
||||
# valid result for callers that opt into caching it.
|
||||
_MISSING = _MissingType()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedValue:
|
||||
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
|
||||
@@ -160,6 +185,7 @@ def cached(
|
||||
ttl_seconds: int,
|
||||
shared_cache: bool = False,
|
||||
refresh_ttl_on_get: bool = False,
|
||||
cache_none: bool = True,
|
||||
) -> Callable[[Callable[P, R]], CachedFunction[P, R]]:
|
||||
"""
|
||||
Thundering herd safe cache decorator for both sync and async functions.
|
||||
@@ -172,6 +198,10 @@ def cached(
|
||||
ttl_seconds: Time to live in seconds. Required - entries must expire.
|
||||
shared_cache: If True, use Redis for cross-process caching
|
||||
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
|
||||
cache_none: If True (default) ``None`` is cached like any other value.
|
||||
Set to ``False`` for functions that return ``None`` to signal a
|
||||
transient error and should be re-tried on the next call without
|
||||
poisoning the cache (e.g. external API calls that may fail).
|
||||
|
||||
Returns:
|
||||
Decorated function with caching capabilities
|
||||
@@ -184,6 +214,12 @@ def cached(
|
||||
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cached(ttl_seconds=300, cache_none=False)
|
||||
async def fetch_external(id: str) -> dict | None:
|
||||
# Returns None on transient error — won't be stored,
|
||||
# next call retries instead of returning the stale None.
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
|
||||
@@ -191,9 +227,14 @@ def cached(
|
||||
cache_storage: dict[tuple, CachedValue] = {}
|
||||
_event_loop_locks: dict[Any, asyncio.Lock] = {}
|
||||
|
||||
def _get_from_redis(redis_key: str) -> Any | None:
|
||||
def _get_from_redis(redis_key: str) -> Any:
|
||||
"""Get value from Redis, optionally refreshing TTL.
|
||||
|
||||
Returns the cached value (which may be ``None``) on a hit, or the
|
||||
module-level ``_MISSING`` sentinel on a miss / corrupt entry.
|
||||
Callers must compare with ``is _MISSING`` so cached ``None`` values
|
||||
are not mistaken for misses.
|
||||
|
||||
Values are expected to carry an HMAC-SHA256 prefix for integrity
|
||||
verification. Unsigned (legacy) or tampered entries are silently
|
||||
discarded and treated as cache misses, so the caller recomputes and
|
||||
@@ -213,11 +254,11 @@ def cached(
|
||||
f"for {func_name}, discarding entry: "
|
||||
"possible tampering or legacy unsigned value"
|
||||
)
|
||||
return None
|
||||
return _MISSING
|
||||
return pickle.loads(payload)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis error during cache check for {func_name}: {e}")
|
||||
return None
|
||||
return _MISSING
|
||||
|
||||
def _set_to_redis(redis_key: str, value: Any) -> None:
|
||||
"""Set HMAC-signed pickled value in Redis with TTL."""
|
||||
@@ -227,8 +268,13 @@ def cached(
|
||||
except Exception as e:
|
||||
logger.error(f"Redis error storing cache for {func_name}: {e}")
|
||||
|
||||
def _get_from_memory(key: tuple) -> Any | None:
|
||||
"""Get value from in-memory cache, checking TTL."""
|
||||
def _get_from_memory(key: tuple) -> Any:
|
||||
"""Get value from in-memory cache, checking TTL.
|
||||
|
||||
Returns the cached value (which may be ``None``) on a hit, or the
|
||||
``_MISSING`` sentinel on a miss / TTL expiry. See
|
||||
``_get_from_redis`` for the rationale.
|
||||
"""
|
||||
if key in cache_storage:
|
||||
cached_data = cache_storage[key]
|
||||
if time.time() - cached_data.timestamp < ttl_seconds:
|
||||
@@ -236,7 +282,7 @@ def cached(
|
||||
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
|
||||
)
|
||||
return cached_data.result
|
||||
return None
|
||||
return _MISSING
|
||||
|
||||
def _set_to_memory(key: tuple, value: Any) -> None:
|
||||
"""Set value in in-memory cache with timestamp."""
|
||||
@@ -270,11 +316,11 @@ def cached(
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
@@ -282,22 +328,24 @@ def cached(
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
# Store result (skip ``None`` if the caller opted out of
|
||||
# caching it — used for transient-error sentinels).
|
||||
if cache_none or result is not None:
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
@@ -315,11 +363,11 @@ def cached(
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
@@ -327,22 +375,24 @@ def cached(
|
||||
# Double-check: another thread might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
if result is not _MISSING:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
# Store result (skip ``None`` if the caller opted out of
|
||||
# caching it — used for transient-error sentinels).
|
||||
if cache_none or result is not None:
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -1223,3 +1223,123 @@ class TestCacheHMAC:
|
||||
assert call_count == 2
|
||||
|
||||
legacy_test_fn.cache_clear()
|
||||
|
||||
|
||||
class TestCacheNoneHandling:
|
||||
"""Tests for the ``cache_none`` parameter on the @cached decorator.
|
||||
|
||||
Sentry bug PRRT_kwDOJKSTjM56RTEu (HIGH): the cache previously could not
|
||||
distinguish "no entry" from "entry is None", so any function returning
|
||||
``None`` was effectively re-executed on every call. The fix is a
|
||||
sentinel-based check inside the wrappers, plus an opt-out
|
||||
``cache_none=False`` flag for callers that *want* errors to retry.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_none_is_cached_by_default(self):
|
||||
"""With ``cache_none=True`` (default), cached ``None`` is returned
|
||||
from the cache instead of triggering re-execution."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=300)
|
||||
async def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return None
|
||||
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# Second call should hit the cache, not re-execute.
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# Different argument is a different cache key — re-executes.
|
||||
assert await maybe_none(2) is None
|
||||
assert call_count == 2
|
||||
|
||||
def test_sync_none_is_cached_by_default(self):
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=300)
|
||||
def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return None
|
||||
|
||||
assert maybe_none(1) is None
|
||||
assert maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_none_false_skips_storing_none(self):
|
||||
"""``cache_none=False`` skips storing ``None`` so transient errors
|
||||
are retried on the next call instead of poisoning the cache."""
|
||||
call_count = 0
|
||||
results: list[int | None] = [None, None, 42]
|
||||
|
||||
@cached(ttl_seconds=300, cache_none=False)
|
||||
async def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
result = results[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
# First call: returns None, NOT stored.
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same key: re-executes (None wasn't cached).
|
||||
assert await maybe_none(1) is None
|
||||
assert call_count == 2
|
||||
|
||||
# Third call: returns 42, this time it IS stored.
|
||||
assert await maybe_none(1) == 42
|
||||
assert call_count == 3
|
||||
|
||||
# Fourth call: cache hit on the stored 42.
|
||||
assert await maybe_none(1) == 42
|
||||
assert call_count == 3
|
||||
|
||||
def test_sync_cache_none_false_skips_storing_none(self):
|
||||
call_count = 0
|
||||
results: list[int | None] = [None, 99]
|
||||
|
||||
@cached(ttl_seconds=300, cache_none=False)
|
||||
def maybe_none(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
result = results[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
assert maybe_none(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
# None was not stored — re-executes.
|
||||
assert maybe_none(1) == 99
|
||||
assert call_count == 2
|
||||
|
||||
# 99 IS stored — no re-execution.
|
||||
assert maybe_none(1) == 99
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_shared_cache_none_is_cached_by_default(self):
|
||||
"""Shared (Redis) cache also properly returns cached ``None`` values."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
async def maybe_none_redis(x: int) -> int | None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return None
|
||||
|
||||
maybe_none_redis.cache_clear()
|
||||
|
||||
assert await maybe_none_redis(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
assert await maybe_none_redis(1) is None
|
||||
assert call_count == 1
|
||||
|
||||
maybe_none_redis.cache_clear()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
@@ -101,6 +102,12 @@ async def _fetch_user_context_data(user_id: str) -> Context:
|
||||
"""
|
||||
builder = Context.builder(user_id).kind("user").anonymous(True)
|
||||
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
# Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context.
|
||||
return builder.build()
|
||||
|
||||
try:
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ export const Flow = () => {
|
||||
event.preventDefault();
|
||||
}}
|
||||
maxZoom={2}
|
||||
minZoom={0.1}
|
||||
minZoom={0.05}
|
||||
onDragOver={onDragOver}
|
||||
onDrop={onDrop}
|
||||
nodesDraggable={!isLocked}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"use client";
|
||||
import { useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { useSubscriptionTierSection } from "./useSubscriptionTierSection";
|
||||
|
||||
type TierInfo = {
|
||||
@@ -15,39 +17,70 @@ const TIERS: TierInfo[] = [
|
||||
key: "FREE",
|
||||
label: "Free",
|
||||
multiplier: "1x",
|
||||
description: "Base rate limits",
|
||||
description: "Base AutoPilot capacity with standard rate limits",
|
||||
},
|
||||
{
|
||||
key: "PRO",
|
||||
label: "Pro",
|
||||
multiplier: "5x",
|
||||
description: "5x more AutoPilot capacity",
|
||||
description: "5x AutoPilot capacity — run 5× more tasks per day/week",
|
||||
},
|
||||
{
|
||||
key: "BUSINESS",
|
||||
label: "Business",
|
||||
multiplier: "20x",
|
||||
description: "20x more AutoPilot capacity",
|
||||
description: "20x AutoPilot capacity — ideal for teams and heavy workloads",
|
||||
},
|
||||
];
|
||||
|
||||
function formatCost(cents: number): string {
|
||||
if (cents === 0) return "Free";
|
||||
const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
|
||||
|
||||
function formatCost(cents: number, tierKey: string): string {
|
||||
if (tierKey === "FREE") return "Free";
|
||||
if (cents === 0) return "Pricing available soon";
|
||||
return `$${(cents / 100).toFixed(2)}/mo`;
|
||||
}
|
||||
|
||||
export function SubscriptionTierSection() {
|
||||
const { subscription, isLoading, error, isPending, changeTier } =
|
||||
useSubscriptionTierSection();
|
||||
const [tierError, setTierError] = useState<string | null>(null);
|
||||
const {
|
||||
subscription,
|
||||
isLoading,
|
||||
error,
|
||||
tierError,
|
||||
isPending,
|
||||
pendingTier,
|
||||
pendingUpgradeTier,
|
||||
setPendingUpgradeTier,
|
||||
confirmUpgrade,
|
||||
isPaymentEnabled,
|
||||
changeTier,
|
||||
handleTierChange,
|
||||
} = useSubscriptionTierSection();
|
||||
const [confirmDowngradeTo, setConfirmDowngradeTo] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
if (isLoading) return null;
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<Skeleton className="h-6 w-48" />
|
||||
<div className="grid grid-cols-1 gap-3 sm:grid-cols-3">
|
||||
<Skeleton className="h-40 rounded-lg" />
|
||||
<Skeleton className="h-40 rounded-lg" />
|
||||
<Skeleton className="h-40 rounded-lg" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<h3 className="text-lg font-medium">Subscription Plan</h3>
|
||||
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
|
||||
<p
|
||||
role="alert"
|
||||
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
|
||||
>
|
||||
{error}
|
||||
</p>
|
||||
</div>
|
||||
@@ -56,10 +89,30 @@ export function SubscriptionTierSection() {
|
||||
|
||||
if (!subscription) return null;
|
||||
|
||||
async function handleTierChange(tierKey: string) {
|
||||
setTierError(null);
|
||||
const err = await changeTier(tierKey);
|
||||
if (err) setTierError(err);
|
||||
const currentTier = subscription.tier;
|
||||
|
||||
if (currentTier === "ENTERPRISE") {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<h3 className="text-lg font-medium">Subscription Plan</h3>
|
||||
<div className="rounded-lg border border-violet-500 bg-violet-50 p-4 dark:bg-violet-900/20">
|
||||
<p className="font-semibold text-violet-700 dark:text-violet-200">
|
||||
Enterprise Plan
|
||||
</p>
|
||||
<p className="mt-1 text-sm text-neutral-600 dark:text-neutral-400">
|
||||
Your Enterprise plan is managed by your administrator. Contact your
|
||||
account team for changes.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
async function confirmDowngrade() {
|
||||
if (!confirmDowngradeTo) return;
|
||||
const tier = confirmDowngradeTo;
|
||||
setConfirmDowngradeTo(null);
|
||||
await changeTier(tier);
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -67,24 +120,28 @@ export function SubscriptionTierSection() {
|
||||
<h3 className="text-lg font-medium">Subscription Plan</h3>
|
||||
|
||||
{tierError && (
|
||||
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
|
||||
<p
|
||||
role="alert"
|
||||
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
|
||||
>
|
||||
{tierError}
|
||||
</p>
|
||||
)}
|
||||
|
||||
<div className="grid grid-cols-1 gap-3 sm:grid-cols-3">
|
||||
{TIERS.map((tier) => {
|
||||
const isCurrent = subscription.tier === tier.key;
|
||||
const isCurrent = currentTier === tier.key;
|
||||
const cost = subscription.tier_costs[tier.key] ?? 0;
|
||||
const currentTierOrder = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
|
||||
const currentIdx = currentTierOrder.indexOf(subscription.tier);
|
||||
const targetIdx = currentTierOrder.indexOf(tier.key);
|
||||
const currentIdx = TIER_ORDER.indexOf(currentTier);
|
||||
const targetIdx = TIER_ORDER.indexOf(tier.key);
|
||||
const isUpgrade = targetIdx > currentIdx;
|
||||
const isDowngrade = targetIdx < currentIdx;
|
||||
const isThisPending = pendingTier === tier.key;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={tier.key}
|
||||
aria-current={isCurrent ? "true" : undefined}
|
||||
className={`rounded-lg border p-4 ${
|
||||
isCurrent
|
||||
? "border-violet-500 bg-violet-50 dark:bg-violet-900/20"
|
||||
@@ -100,7 +157,9 @@ export function SubscriptionTierSection() {
|
||||
)}
|
||||
</div>
|
||||
|
||||
<p className="mb-1 text-2xl font-bold">{formatCost(cost)}</p>
|
||||
<p className="mb-1 text-2xl font-bold">
|
||||
{formatCost(cost, tier.key)}
|
||||
</p>
|
||||
<p className="mb-1 text-sm font-medium text-neutral-600 dark:text-neutral-400">
|
||||
{tier.multiplier} rate limits
|
||||
</p>
|
||||
@@ -108,14 +167,20 @@ export function SubscriptionTierSection() {
|
||||
{tier.description}
|
||||
</p>
|
||||
|
||||
{!isCurrent && (
|
||||
{!isCurrent && isPaymentEnabled && (
|
||||
<Button
|
||||
className="w-full"
|
||||
variant={isUpgrade ? "default" : "outline"}
|
||||
disabled={isPending}
|
||||
onClick={() => handleTierChange(tier.key)}
|
||||
onClick={() =>
|
||||
handleTierChange(
|
||||
tier.key,
|
||||
currentTier,
|
||||
setConfirmDowngradeTo,
|
||||
)
|
||||
}
|
||||
>
|
||||
{isPending
|
||||
{isThisPending
|
||||
? "Updating..."
|
||||
: isUpgrade
|
||||
? `Upgrade to ${tier.label}`
|
||||
@@ -129,12 +194,80 @@ export function SubscriptionTierSection() {
|
||||
})}
|
||||
</div>
|
||||
|
||||
{subscription.tier !== "FREE" && (
|
||||
{currentTier !== "FREE" && isPaymentEnabled && (
|
||||
<p className="text-sm text-neutral-500">
|
||||
Your subscription is managed through Stripe. Changes take effect
|
||||
immediately.
|
||||
Your subscription is managed through Stripe. Upgrades and paid-tier
|
||||
changes take effect immediately; downgrades to Free are scheduled for
|
||||
the end of the current billing period.
|
||||
</p>
|
||||
)}
|
||||
|
||||
<Dialog
|
||||
title="Confirm Downgrade"
|
||||
controlled={{
|
||||
isOpen: !!confirmDowngradeTo,
|
||||
set: (open) => {
|
||||
if (!open) setConfirmDowngradeTo(null);
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<p className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
{confirmDowngradeTo === "FREE"
|
||||
? "Downgrading to Free will schedule your subscription to cancel at the end of your current billing period. You keep your current plan until then."
|
||||
: `Switching to ${TIERS.find((t) => t.key === confirmDowngradeTo)?.label ?? confirmDowngradeTo} will take effect immediately.`}{" "}
|
||||
Are you sure?
|
||||
</p>
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setConfirmDowngradeTo(null)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="destructive"
|
||||
onClick={() => void confirmDowngrade()}
|
||||
>
|
||||
Confirm Downgrade
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
|
||||
<Dialog
|
||||
title="Confirm Upgrade"
|
||||
controlled={{
|
||||
isOpen: !!pendingUpgradeTier,
|
||||
set: (open) => {
|
||||
if (!open) setPendingUpgradeTier(null);
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<p className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
{subscription &&
|
||||
subscription.proration_credit_cents > 0 &&
|
||||
`Your unused ${currentTier.charAt(0) + currentTier.slice(1).toLowerCase()} subscription ($${(subscription.proration_credit_cents / 100).toFixed(2)}) will be applied as a credit to your next Stripe invoice. `}
|
||||
{currentTier === "FREE"
|
||||
? `You will be redirected to Stripe to complete your upgrade to ${TIERS.find((t) => t.key === pendingUpgradeTier)?.label ?? pendingUpgradeTier}.`
|
||||
: `Upgrading to ${TIERS.find((t) => t.key === pendingUpgradeTier)?.label ?? pendingUpgradeTier} will take effect immediately — Stripe will prorate your remaining balance.`}
|
||||
</p>
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setPendingUpgradeTier(null)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={() => void confirmUpgrade()}>
|
||||
{currentTier === "FREE"
|
||||
? "Continue to Checkout"
|
||||
: "Confirm Upgrade"}
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,379 @@
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
cleanup,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { SubscriptionTierSection } from "../SubscriptionTierSection";
|
||||
|
||||
// Mock next/navigation
|
||||
const mockSearchParams = new URLSearchParams();
|
||||
const mockRouterReplace = vi.fn();
|
||||
vi.mock("next/navigation", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("next/navigation")>();
|
||||
return {
|
||||
...actual,
|
||||
useSearchParams: () => mockSearchParams,
|
||||
useRouter: () => ({ push: vi.fn(), replace: mockRouterReplace }),
|
||||
usePathname: () => "/profile/credits",
|
||||
};
|
||||
});
|
||||
|
||||
// Mock toast
|
||||
const mockToast = vi.fn();
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
useToast: () => ({ toast: mockToast }),
|
||||
}));
|
||||
|
||||
// Mock feature flags — default to payment enabled so button tests work
|
||||
let mockPaymentEnabled = true;
|
||||
vi.mock("@/services/feature-flags/use-get-flag", () => ({
|
||||
Flag: { ENABLE_PLATFORM_PAYMENT: "enable-platform-payment" },
|
||||
useGetFlag: () => mockPaymentEnabled,
|
||||
}));
|
||||
|
||||
// Mock generated API hooks
|
||||
const mockUseGetSubscriptionStatus = vi.fn();
|
||||
const mockUseUpdateSubscriptionTier = vi.fn();
|
||||
vi.mock("@/app/api/__generated__/endpoints/credits/credits", () => ({
|
||||
useGetSubscriptionStatus: (opts: unknown) =>
|
||||
mockUseGetSubscriptionStatus(opts),
|
||||
useUpdateSubscriptionTier: () => mockUseUpdateSubscriptionTier(),
|
||||
}));
|
||||
|
||||
// Mock Dialog (Radix portals don't work in happy-dom)
|
||||
const MockDialogContent = ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
);
|
||||
const MockDialogFooter = ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
);
|
||||
function MockDialog({
|
||||
controlled,
|
||||
children,
|
||||
}: {
|
||||
controlled?: { isOpen: boolean; set: (open: boolean) => void };
|
||||
children: React.ReactNode;
|
||||
[key: string]: unknown;
|
||||
}) {
|
||||
return controlled?.isOpen ? <div role="dialog">{children}</div> : null;
|
||||
}
|
||||
MockDialog.Content = MockDialogContent;
|
||||
MockDialog.Footer = MockDialogFooter;
|
||||
vi.mock("@/components/molecules/Dialog/Dialog", () => ({
|
||||
Dialog: MockDialog,
|
||||
}));
|
||||
|
||||
function makeSubscription({
|
||||
tier = "FREE",
|
||||
monthlyCost = 0,
|
||||
tierCosts = { FREE: 0, PRO: 1999, BUSINESS: 4999, ENTERPRISE: 0 },
|
||||
prorationCreditCents = 0,
|
||||
}: {
|
||||
tier?: string;
|
||||
monthlyCost?: number;
|
||||
tierCosts?: Record<string, number>;
|
||||
prorationCreditCents?: number;
|
||||
} = {}) {
|
||||
return {
|
||||
tier,
|
||||
monthly_cost: monthlyCost,
|
||||
tier_costs: tierCosts,
|
||||
proration_credit_cents: prorationCreditCents,
|
||||
};
|
||||
}
|
||||
|
||||
function setupMocks({
|
||||
subscription = makeSubscription(),
|
||||
isLoading = false,
|
||||
queryError = null as Error | null,
|
||||
mutateFn = vi.fn().mockResolvedValue({ status: 200, data: { url: "" } }),
|
||||
isPending = false,
|
||||
variables = undefined as { data?: { tier?: string } } | undefined,
|
||||
} = {}) {
|
||||
// The hook uses select: (data) => (data.status === 200 ? data.data : null)
|
||||
// so the data value returned by the hook is already the transformed subscription object.
|
||||
// We simulate that by returning the subscription directly as data.
|
||||
mockUseGetSubscriptionStatus.mockReturnValue({
|
||||
data: subscription,
|
||||
isLoading,
|
||||
error: queryError,
|
||||
refetch: vi.fn(),
|
||||
});
|
||||
mockUseUpdateSubscriptionTier.mockReturnValue({
|
||||
mutateAsync: mutateFn,
|
||||
isPending,
|
||||
variables,
|
||||
});
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
mockUseGetSubscriptionStatus.mockReset();
|
||||
mockUseUpdateSubscriptionTier.mockReset();
|
||||
mockToast.mockReset();
|
||||
mockRouterReplace.mockReset();
|
||||
mockSearchParams.delete("subscription");
|
||||
mockPaymentEnabled = true;
|
||||
});
|
||||
|
||||
describe("SubscriptionTierSection", () => {
|
||||
it("renders skeleton cards while loading", () => {
|
||||
setupMocks({ isLoading: true });
|
||||
render(<SubscriptionTierSection />);
|
||||
// Just verify we're rendering something (not null) and no tier cards
|
||||
expect(screen.queryByText("Pro")).toBeNull();
|
||||
expect(screen.queryByText("Business")).toBeNull();
|
||||
});
|
||||
|
||||
it("renders error message when subscription fetch fails", () => {
|
||||
setupMocks({
|
||||
queryError: new Error("Network error"),
|
||||
subscription: makeSubscription(),
|
||||
});
|
||||
// Override the data to simulate failed state
|
||||
mockUseGetSubscriptionStatus.mockReturnValue({
|
||||
data: null,
|
||||
isLoading: false,
|
||||
error: new Error("Network error"),
|
||||
refetch: vi.fn(),
|
||||
});
|
||||
render(<SubscriptionTierSection />);
|
||||
expect(screen.getByRole("alert")).toBeDefined();
|
||||
expect(screen.getByText(/failed to load subscription info/i)).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders all three tier cards for FREE user", () => {
|
||||
setupMocks();
|
||||
render(<SubscriptionTierSection />);
|
||||
// Use getAllByText to account for the tier label AND cost display both containing "Free"
|
||||
expect(screen.getAllByText("Free").length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("Pro")).toBeDefined();
|
||||
expect(screen.getByText("Business")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows Current badge on the active tier", () => {
|
||||
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
|
||||
render(<SubscriptionTierSection />);
|
||||
expect(screen.getByText("Current")).toBeDefined();
|
||||
// Upgrade to PRO button should NOT exist; Upgrade to BUSINESS and Downgrade to Free should
|
||||
expect(
|
||||
screen.queryByRole("button", { name: /upgrade to pro/i }),
|
||||
).toBeNull();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /upgrade to business/i }),
|
||||
).toBeDefined();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /downgrade to free/i }),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("displays tier costs from the API", () => {
|
||||
setupMocks({
|
||||
subscription: makeSubscription({
|
||||
tier: "FREE",
|
||||
tierCosts: { FREE: 0, PRO: 1999, BUSINESS: 4999, ENTERPRISE: 0 },
|
||||
}),
|
||||
});
|
||||
render(<SubscriptionTierSection />);
|
||||
expect(screen.getByText("$19.99/mo")).toBeDefined();
|
||||
expect(screen.getByText("$49.99/mo")).toBeDefined();
|
||||
// FREE tier label should still be visible (there may be multiple "Free" elements)
|
||||
expect(screen.getAllByText("Free").length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("shows 'Pricing available soon' when tier cost is 0 for a paid tier", () => {
|
||||
setupMocks({
|
||||
subscription: makeSubscription({
|
||||
tier: "FREE",
|
||||
tierCosts: { FREE: 0, PRO: 0, BUSINESS: 0, ENTERPRISE: 0 },
|
||||
}),
|
||||
});
|
||||
render(<SubscriptionTierSection />);
|
||||
// PRO and BUSINESS with cost=0 should show "Pricing available soon"
|
||||
expect(screen.getAllByText("Pricing available soon")).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("calls changeTier on upgrade click after confirmation dialog", async () => {
|
||||
const mutateFn = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ status: 200, data: { url: "" } });
|
||||
setupMocks({ mutateFn });
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
// Clicking upgrade opens the confirmation dialog first
|
||||
fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i }));
|
||||
// Confirm via the dialog's "Continue to Checkout" button
|
||||
fireEvent.click(
|
||||
screen.getByRole("button", { name: /continue to checkout/i }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mutateFn).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
data: expect.objectContaining({ tier: "PRO" }),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("shows confirmation dialog on downgrade click", () => {
|
||||
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i }));
|
||||
|
||||
expect(screen.getByRole("dialog")).toBeDefined();
|
||||
// The dialog title text appears in both a div and a button — just check the dialog is open
|
||||
expect(screen.getAllByText(/confirm downgrade/i).length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("calls changeTier after downgrade confirmation", async () => {
|
||||
const mutateFn = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ status: 200, data: { url: "" } });
|
||||
setupMocks({
|
||||
subscription: makeSubscription({ tier: "PRO" }),
|
||||
mutateFn,
|
||||
});
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i }));
|
||||
fireEvent.click(screen.getByRole("button", { name: /confirm downgrade/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mutateFn).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
data: expect.objectContaining({ tier: "FREE" }),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("dismisses dialog when Cancel is clicked", () => {
|
||||
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i }));
|
||||
expect(screen.getByRole("dialog")).toBeDefined();
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /^cancel$/i }));
|
||||
expect(screen.queryByRole("dialog")).toBeNull();
|
||||
});
|
||||
|
||||
it("redirects to Stripe when checkout URL is returned", async () => {
|
||||
// Replace window.location with a plain object so assigning .href doesn't
|
||||
// trigger jsdom navigation (which would throw or reload the test page).
|
||||
const mockLocation = { href: "" };
|
||||
vi.stubGlobal("location", mockLocation);
|
||||
|
||||
const mutateFn = vi.fn().mockResolvedValue({
|
||||
status: 200,
|
||||
data: { url: "https://checkout.stripe.com/pay/cs_test" },
|
||||
});
|
||||
setupMocks({ mutateFn });
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
// Upgrade opens confirmation dialog first — confirm via "Continue to Checkout"
|
||||
fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i }));
|
||||
fireEvent.click(
|
||||
screen.getByRole("button", { name: /continue to checkout/i }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockLocation.href).toBe("https://checkout.stripe.com/pay/cs_test");
|
||||
});
|
||||
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("shows an error alert when tier change fails", async () => {
|
||||
const mutateFn = vi.fn().mockRejectedValue(new Error("Stripe unavailable"));
|
||||
setupMocks({ mutateFn });
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
// Upgrade opens confirmation dialog first — confirm to trigger the mutation
|
||||
fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i }));
|
||||
fireEvent.click(
|
||||
screen.getByRole("button", { name: /continue to checkout/i }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole("alert")).toBeDefined();
|
||||
expect(screen.getByText(/stripe unavailable/i)).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
it("hides action buttons when payment flag is disabled", () => {
|
||||
mockPaymentEnabled = false;
|
||||
setupMocks({ subscription: makeSubscription({ tier: "FREE" }) });
|
||||
render(<SubscriptionTierSection />);
|
||||
// Tier cards still visible
|
||||
expect(screen.getByText("Pro")).toBeDefined();
|
||||
expect(screen.getByText("Business")).toBeDefined();
|
||||
// No upgrade/downgrade buttons
|
||||
expect(screen.queryByRole("button", { name: /upgrade/i })).toBeNull();
|
||||
expect(screen.queryByRole("button", { name: /downgrade/i })).toBeNull();
|
||||
});
|
||||
|
||||
it("shows ENTERPRISE message for ENTERPRISE tier users", () => {
|
||||
setupMocks({ subscription: makeSubscription({ tier: "ENTERPRISE" }) });
|
||||
render(<SubscriptionTierSection />);
|
||||
// Enterprise heading text appears in a <p> (may match multiple), just verify it exists
|
||||
expect(screen.getAllByText(/enterprise plan/i).length).toBeGreaterThan(0);
|
||||
expect(screen.getByText(/managed by your administrator/i)).toBeDefined();
|
||||
// No standard tier cards should be rendered
|
||||
expect(screen.queryByText("Pro")).toBeNull();
|
||||
expect(screen.queryByText("Business")).toBeNull();
|
||||
});
|
||||
|
||||
it("shows success toast and clears URL param when ?subscription=success is present", async () => {
|
||||
mockSearchParams.set("subscription", "success");
|
||||
setupMocks();
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToast).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ title: "Subscription upgraded" }),
|
||||
);
|
||||
});
|
||||
// URL param must be stripped so a page refresh doesn't re-trigger the toast
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith("/profile/credits");
|
||||
});
|
||||
|
||||
it("clears URL param but shows no toast when ?subscription=cancelled is present", async () => {
|
||||
mockSearchParams.set("subscription", "cancelled");
|
||||
setupMocks();
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
// The cancelled param must be stripped from the URL (same hygiene as success)
|
||||
await waitFor(() => {
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith("/profile/credits");
|
||||
});
|
||||
// No toast should fire — the user simply abandoned checkout
|
||||
expect(mockToast).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("shows 'Confirm Upgrade' button (not 'Continue to Checkout') for paid→paid tier change", () => {
|
||||
// PRO user upgrading to BUSINESS — modify in-place, no Stripe redirect
|
||||
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
|
||||
render(<SubscriptionTierSection />);
|
||||
|
||||
fireEvent.click(
|
||||
screen.getByRole("button", { name: /upgrade to business/i }),
|
||||
);
|
||||
|
||||
// For paid→paid, the button should say "Confirm Upgrade" not "Continue to Checkout"
|
||||
expect(
|
||||
screen.getByRole("button", { name: /confirm upgrade/i }),
|
||||
).toBeDefined();
|
||||
expect(
|
||||
screen.queryByRole("button", { name: /continue to checkout/i }),
|
||||
).toBeNull();
|
||||
// Dialog body should mention "take effect immediately" not "redirected to Stripe"
|
||||
// Two elements match: the static info paragraph and the dialog body — verify at least 2
|
||||
expect(screen.getAllByText(/take effect immediately/i).length).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
});
|
||||
@@ -1,13 +1,30 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import {
|
||||
useGetSubscriptionStatus,
|
||||
useUpdateSubscriptionTier,
|
||||
} from "@/app/api/__generated__/endpoints/credits/credits";
|
||||
import type { SubscriptionStatusResponse } from "@/app/api/__generated__/models/subscriptionStatusResponse";
|
||||
import type { SubscriptionTierRequestTier } from "@/app/api/__generated__/models/subscriptionTierRequestTier";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
|
||||
export type SubscriptionStatus = SubscriptionStatusResponse;
|
||||
|
||||
const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
|
||||
|
||||
export function useSubscriptionTierSection() {
|
||||
const isPaymentEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
|
||||
const searchParams = useSearchParams();
|
||||
const subscriptionStatus = searchParams.get("subscription");
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
const { toast } = useToast();
|
||||
const [tierError, setTierError] = useState<string | null>(null);
|
||||
const [pendingUpgradeTier, setPendingUpgradeTier] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
const {
|
||||
data: subscription,
|
||||
isLoading,
|
||||
@@ -17,11 +34,39 @@ export function useSubscriptionTierSection() {
|
||||
query: { select: (data) => (data.status === 200 ? data.data : null) },
|
||||
});
|
||||
|
||||
const error = queryError ? "Failed to load subscription info" : null;
|
||||
const fetchError = queryError ? "Failed to load subscription info" : null;
|
||||
|
||||
const { mutateAsync: doUpdateTier, isPending } = useUpdateSubscriptionTier();
|
||||
const {
|
||||
mutateAsync: doUpdateTier,
|
||||
isPending,
|
||||
variables,
|
||||
} = useUpdateSubscriptionTier();
|
||||
|
||||
async function changeTier(tier: string): Promise<string | null> {
|
||||
useEffect(() => {
|
||||
if (subscriptionStatus === "success") {
|
||||
refetch();
|
||||
toast({
|
||||
title: "Subscription upgraded",
|
||||
description:
|
||||
"Your plan has been updated. It may take a moment to reflect.",
|
||||
});
|
||||
}
|
||||
// Strip ?subscription=success|cancelled from the URL so a page refresh
|
||||
// does not re-trigger side-effects, and so a second checkout in the same
|
||||
// session correctly fires the toast again.
|
||||
if (
|
||||
subscriptionStatus === "success" ||
|
||||
subscriptionStatus === "cancelled"
|
||||
) {
|
||||
router.replace(pathname);
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps -- refetch and toast
|
||||
// are new references each render but are stable in practice; the effect must
|
||||
// only re-run when subscriptionStatus/pathname changes.
|
||||
}, [subscriptionStatus, refetch, toast, router, pathname]);
|
||||
|
||||
async function changeTier(tier: string) {
|
||||
setTierError(null);
|
||||
try {
|
||||
const successUrl = `${window.location.origin}${window.location.pathname}?subscription=success`;
|
||||
const cancelUrl = `${window.location.origin}${window.location.pathname}?subscription=cancelled`;
|
||||
@@ -34,22 +79,59 @@ export function useSubscriptionTierSection() {
|
||||
});
|
||||
if (result.status === 200 && result.data.url) {
|
||||
window.location.href = result.data.url;
|
||||
return null;
|
||||
return;
|
||||
}
|
||||
await refetch();
|
||||
return null;
|
||||
toast({
|
||||
title: "Subscription updated",
|
||||
description:
|
||||
tier === "FREE"
|
||||
? "Your plan will be downgraded to Free at the end of your current billing period."
|
||||
: "Your subscription has been updated.",
|
||||
});
|
||||
} catch (e: unknown) {
|
||||
const msg =
|
||||
e instanceof Error ? e.message : "Failed to change subscription tier";
|
||||
return msg;
|
||||
setTierError(msg);
|
||||
}
|
||||
}
|
||||
|
||||
function handleTierChange(
|
||||
targetTierKey: string,
|
||||
currentTier: string,
|
||||
onConfirmDowngrade: (tier: string) => void,
|
||||
) {
|
||||
const currentIdx = TIER_ORDER.indexOf(currentTier);
|
||||
const targetIdx = TIER_ORDER.indexOf(targetTierKey);
|
||||
if (targetIdx < currentIdx) {
|
||||
onConfirmDowngrade(targetTierKey);
|
||||
return;
|
||||
}
|
||||
setPendingUpgradeTier(targetTierKey);
|
||||
}
|
||||
|
||||
async function confirmUpgrade() {
|
||||
if (!pendingUpgradeTier) return;
|
||||
const tier = pendingUpgradeTier;
|
||||
setPendingUpgradeTier(null);
|
||||
await changeTier(tier);
|
||||
}
|
||||
|
||||
const pendingTier =
|
||||
isPending && variables?.data?.tier ? variables.data.tier : null;
|
||||
|
||||
return {
|
||||
subscription: subscription ?? null,
|
||||
isLoading,
|
||||
error,
|
||||
error: fetchError,
|
||||
tierError,
|
||||
isPending,
|
||||
pendingTier,
|
||||
pendingUpgradeTier,
|
||||
setPendingUpgradeTier,
|
||||
confirmUpgrade,
|
||||
isPaymentEnabled,
|
||||
changeTier,
|
||||
handleTierChange,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -14122,16 +14122,29 @@
|
||||
},
|
||||
"SubscriptionStatusResponse": {
|
||||
"properties": {
|
||||
"tier": { "type": "string", "title": "Tier" },
|
||||
"tier": {
|
||||
"type": "string",
|
||||
"enum": ["FREE", "PRO", "BUSINESS", "ENTERPRISE"],
|
||||
"title": "Tier"
|
||||
},
|
||||
"monthly_cost": { "type": "integer", "title": "Monthly Cost" },
|
||||
"tier_costs": {
|
||||
"additionalProperties": { "type": "integer" },
|
||||
"type": "object",
|
||||
"title": "Tier Costs"
|
||||
},
|
||||
"proration_credit_cents": {
|
||||
"type": "integer",
|
||||
"title": "Proration Credit Cents"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["tier", "monthly_cost", "tier_costs"],
|
||||
"required": [
|
||||
"tier",
|
||||
"monthly_cost",
|
||||
"tier_costs",
|
||||
"proration_credit_cents"
|
||||
],
|
||||
"title": "SubscriptionStatusResponse"
|
||||
},
|
||||
"SubscriptionTier": {
|
||||
|
||||
@@ -194,26 +194,6 @@ export default class BackendAPI {
|
||||
return this._request("PATCH", "/credits");
|
||||
}
|
||||
|
||||
getSubscription(): Promise<{
|
||||
tier: string;
|
||||
monthly_cost: number;
|
||||
tier_costs: Record<string, number>;
|
||||
}> {
|
||||
return this._get("/credits/subscription");
|
||||
}
|
||||
|
||||
setSubscriptionTier(
|
||||
tier: string,
|
||||
successUrl?: string,
|
||||
cancelUrl?: string,
|
||||
): Promise<{ url: string }> {
|
||||
return this._request("POST", "/credits/subscription", {
|
||||
tier,
|
||||
success_url: successUrl ?? "",
|
||||
cancel_url: cancelUrl ?? "",
|
||||
});
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
//////////////// GRAPHS ////////////////
|
||||
////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user