Compare commits

..

90 Commits

Author SHA1 Message Date
Zamil Majdy
50324f710a fix(backend): remove dead code and fix misleading docstring
Remove the redundant `if invoice_id:` guard in
`handle_subscription_payment_failure` — `invoice_id` is guaranteed
non-empty by the early-return guard added above it.

Fix `_cleanup_stale_subscriptions` docstring that claimed a metric
`stripe_stale_subscription_cleanup_failed` is emitted; only a
`logger.exception` call is made, so remove the inaccurate metric claim
to avoid misleading on-call engineers.
2026-04-16 18:42:09 +07:00
Zamil Majdy
cc78a92a3f fix(frontend): fix multi-match in paid→paid upgrade dialog test
Use getAllByText with length check instead of getByText for
"take effect immediately" since it also appears in the static
subscription info paragraph below the tier cards.
2026-04-16 18:38:01 +07:00
Zamil Majdy
f92dc0fb02 test(backend): add URL validation edge-case tests + simplify transaction_key
- Add path-prefix attack, URL-encoded @, and valid query-string @
  test cases to test_validate_checkout_redirect_url parametrize block
  to document the security guarantees of _validate_checkout_redirect_url.
- Remove now-redundant `or None` from transaction_key assignment in
  handle_subscription_payment_failure — invoice_id is guaranteed
  non-empty by the early-return guard added in the previous commit.
2026-04-16 18:29:44 +07:00
Zamil Majdy
b8b5291022 test(frontend): add paid-to-paid upgrade dialog button label test
Assert that PRO→BUSINESS upgrade dialog shows "Confirm Upgrade" (not
"Continue to Checkout") and "take effect immediately" copy — covers the
dialog copy fix that differentiates FREE→paid (redirect) from
paid→paid (in-place Stripe modification).
2026-04-16 18:27:08 +07:00
Zamil Majdy
9b56f2f927 fix(backend): query trialing subs in proration + guard empty invoice_id
- get_proration_credit_cents now queries both "active" and "trialing"
  subscriptions (same pattern as modify_stripe_subscription_for_tier)
  so trial users see accurate proration credit before upgrading.
- handle_subscription_payment_failure returns early when invoice.id is
  missing — without an idempotency key, webhook retries would double-
  charge the user's credit balance on every retry cycle.
- Add tests: test_get_proration_credit_cents_with_trialing_sub and
  test_handle_subscription_payment_failure_missing_invoice_id_skips.
2026-04-16 18:24:41 +07:00
Zamil Majdy
fd75467ab0 fix(frontend): show contextual toast for FREE vs paid-to-paid subscription changes
The downgrade toast previously always said "will be downgraded at the end of
your billing cycle" — misleading for paid-to-paid changes (PRO↔BUSINESS) that
take effect immediately. Now shows a generic "subscription updated" message
for immediate changes and the period-end copy only for FREE downgrades.
2026-04-16 17:31:23 +07:00
Zamil Majdy
3e9f161856 fix(frontend): correct proration copy — credits applied to next invoice not account balance
The upgrade dialog said prorated credits 'will be added to your account
balance' which is incorrect — Stripe applies proration as a credit on the
next invoice, not to the account balance.

Addresses sentry[bot] comment #3092325806.
2026-04-16 17:18:12 +07:00
Zamil Majdy
f645e947d1 fix(backend): update DB tier directly for admin-granted paid→paid changes
When modify_stripe_subscription_for_tier returns False (no active Stripe
subscription — admin-granted paid tier), the endpoint was falling through
to Checkout Session creation and incorrectly redirecting users to Stripe.

Instead, update the DB tier directly (same pattern as admin-granted
FREE downgrade), so the user's tier is changed immediately without
creating an orphaned Stripe Checkout Session.

Adds test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
to cover this path.

Addresses sentry[bot] comment #3092325813.
2026-04-16 17:18:00 +07:00
Zamil Majdy
3a66eb2548 fix(frontend): show success toast on subscription downgrade 2026-04-16 16:51:37 +07:00
Zamil Majdy
1e1caa8810 Merge branch 'dev' into feat/subscription-tier-billing 2026-04-16 16:20:07 +07:00
Zamil Majdy
3e654228ac fix(backend/copilot): revert upload_cli_session re-add, update test to use strip_for_upload instead 2026-04-16 16:18:39 +07:00
Zamil Majdy
697ffa81f0 fix(backend/copilot): update transcript_test to use strip_for_upload after upload_cli_session removal 2026-04-16 16:17:02 +07:00
Zamil Majdy
3e1f886503 fix(backend/copilot): add upload_cli_session — fix TestUploadCliSession import errors
The dev merge (2b4727e8) brought in transcript_test.py tests for
upload_cli_session that didn't exist in transcript.py (only upload_transcript
existed). This adds the function plus a _projects_base patchable alias so
the tests can redirect disk I/O to tmp_path.

upload_cli_session: reads session from disk, strips progress/thinking,
writes back if smaller, uploads JSONL-only to GCS (single store call,
consistent with test assertions).
2026-04-16 16:02:44 +07:00
Zamil Majdy
e2e7c85a48 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/subscription-tier-billing 2026-04-16 15:43:22 +07:00
Zamil Majdy
2b4727e8b2 chore: merge master into dev, resolve baseline/transcript conflicts
Conflicts in baseline/service.py, baseline/transcript_integration_test.py,
and transcript.py arose because dev-only commit 0cd0a76305
(baseline upload fix) overlapped with the same fix in PR #12804 which
landed in master. Took master's version for all three files — it is the
complete, reviewed implementation.
2026-04-16 15:38:46 +07:00
Zamil Majdy
0cd0a76305 fix(backend/copilot): baseline always uploads when GCS has no transcript
_load_prior_transcript was returning False for missing/invalid transcripts,
which caused should_upload_transcript to suppress the upload. The original
intent was to protect against overwriting a *newer* GCS version — but a
missing or corrupt file is not 'newer'. Only stale (watermark ahead) and
download errors (unknown GCS state) should suppress upload.

Also renames transcript_covers_prefix → transcript_upload_safe throughout
to accurately describe what the flag means.
2026-04-16 14:58:42 +07:00
chernistry
bd2efed080 fix(frontend): allow zooming out more in the builder (#12690)
Reduced minZoom on the builder canvas from 0.1 to 0.05 to allow zooming
out further when working with large agent graphs.

Fixes #9325

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-15 21:25:07 +00:00
Zamil Majdy
5fccd8a762 Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-04-16 01:23:07 +07:00
Zamil Majdy
d27d22159d Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-04-16 00:05:32 +07:00
Zamil Majdy
df205b5444 fix(backend/copilot): strip CLI session file to prevent auto-compaction context loss
The Claude Code CLI auto-compacts its native session JSONL when the context
approaches the model's token limit (~200K for Sonnet).  After compaction the
detailed conversation history is replaced by a ~27K-token summary, causing
the silent context loss users see as memory failures in long sessions.

Root cause identified from production logs for session 93ecf7c9:
- T6 CLI session: 233KB / ~207K tokens (near Sonnet limit)
- T7 CLI compacted session -> ~167KB / ~47K tokens (PreCompact hook missed)
- T12 second compaction -> ~176KB / ~27K tokens (just system prompt + summary)
- T14-T21: cache_read=26714 constantly -- only system prompt visible to Claude

The same stripping we already apply to our transcript (stale thinking blocks,
progress/metadata entries) now also runs on the CLI native session file.  At
~2x the size of the stripped transcript, unstripped sessions routinely hit the
compaction threshold within 6-10 turns of a heavy Opus/thinking session.
After stripping:
- same-pod turns reuse the stripped local file (no compaction trigger)
- cross-pod turns restore the stripped GCS file (same benefit)
2026-04-15 23:19:12 +07:00
majdyz
4efa1c4310 fix(copilot): set session_id on mode-switch T1 to enable --resume on subsequent turns
When a user switches from baseline (fast) mode to SDK (extended_thinking)
mode mid-session, the first SDK turn has has_history=True (prior baseline
messages in DB) but no CLI session file in storage.

The old code gated session_id on `not has_history`, so mode-switch T1
never received a session_id — the CLI generated a random ID that wasn't
uploaded under the expected key.  Every subsequent SDK turn would fail to
restore the CLI session and run without --resume, injecting the full
compressed history on each turn, causing model confusion.

Fix: set session_id whenever not using --resume (the `else` branch),
covering T1 fresh, mode-switch T1, and T2+ fallback turns.  The retry
path is updated to use `"session_id" in sdk_options_kwargs` as the
discriminator (instead of `not has_history`) so mode-switch T1 retries
also keep the session_id while T2+ retries (where T1 restored a session
file via restore_cli_session) still remove it to avoid "Session ID
already in use".
2026-04-15 23:19:11 +07:00
Zamil Majdy
3324e7199b fix(backend): return 503 when checkout redirect URLs are unconfigured
When neither frontend_base_url nor platform_base_url is set, subscription
upgrade attempts were failing with a misleading 422 'success_url and
cancel_url must match the platform frontend origin' error. The real problem
is a server misconfiguration, not a bad URL from the client.

Add an explicit pre-flight check in update_subscription_tier: if the allowed
origin is not configured, log an error and raise 503 with a clear message so
operators can diagnose the missing config instead of chasing a false URL
mismatch error.
2026-04-15 23:08:15 +07:00
majdyz
51532c4fd1 chore(platform): merge origin/dev into feat/subscription-tier-billing 2026-04-15 20:50:32 +07:00
majdyz
a73ceb2838 refactor(backend/copilot): convert absolute copilot imports to relative in sdk/service.py
Replace all `from backend.copilot.X import Y` (top-level and inline)
with `from ..X import Y` to eliminate Pyright type collisions from
mixed absolute/relative imports. Add `# isort: skip_file` to prevent
isort from reverting the change.
2026-04-15 20:15:04 +07:00
majdyz
7672722996 fix(backend/copilot): add _SystemPromptPreset TypedDict for Pyright compat
claude-agent-sdk 0.1.58's SystemPromptPreset TypedDict does not declare
exclude_dynamic_sections, causing a reportCallIssue Pyright error at the
_build_system_prompt_value call site (service.py:820). The field was
added in 0.1.59.

Define a local _SystemPromptPreset that extends the SDK TypedDict with
exclude_dynamic_sections: NotRequired[bool] so Pyright accepts the kwarg
without a # type: ignore comment, until the SDK pin is bumped to 0.1.59.
2026-04-15 20:05:05 +07:00
majdyz
2f75eff082 fix(backend): guard modify_stripe_subscription_for_tier against orphaned customers
Add early return when user has no stripe_customer_id to prevent creating
an orphaned Stripe customer if a subsequent Subscription.list call fails.
Follows the same pattern as cancel_stripe_subscription and
get_proration_credit_cents. Update tests to mock get_user_by_id and add
a test for the no-customer-id path.
2026-04-15 19:53:22 +07:00
majdyz
10b92fbaa2 Merge remote-tracking branch 'origin/dev' into feat/subscription-tier-billing 2026-04-15 19:35:29 +07:00
majdyz
2cdd164223 fix(backend): guard get_proration_credit_cents against creating orphaned Stripe customers
get_proration_credit_cents now checks user.stripe_customer_id before calling into
Stripe, matching the same pattern applied to cancel_stripe_subscription. Admin-granted
paid-tier users without a Stripe record previously triggered customer creation on every
billing page load; they now get 0 immediately.

Also adds tests: no_customer_id fast-path for cancel, and three proration scenarios
(zero cost, no customer id, active subscription with proration calculation).
2026-04-15 13:42:40 +07:00
majdyz
c421a66fa5 fix(platform): fix Stripe customer creation on FREE downgrade and dialog display issues
- Guard cancel_stripe_subscription with stripe_customer_id check to prevent
  creating orphaned Stripe customers for users who never had a paid subscription
- Fix upgrade dialog showing raw tier key (PRO) instead of human-readable label (Pro)
- Fix misleading footer text that contradicted downgrade-to-FREE behaviour
  (downgrades to Free are scheduled at period end, not immediate)
2026-04-15 13:25:42 +07:00
majdyz
b435814826 test(backend): mock get_proration_credit_cents in route tests and assert response field
Adds the missing get_proration_credit_cents mock to the three
GET /credits/subscription route tests so they don't attempt a live
Stripe/DB call in unit-test context, and extends assertions to cover
the proration_credit_cents field in the response.
2026-04-15 13:09:47 +07:00
majdyz
10bf830b59 Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into feat/subscription-tier-billing 2026-04-15 12:53:31 +07:00
majdyz
11a5ce99f4 fix(frontend): show human-readable tier label in downgrade confirmation dialog
The dialog was showing the raw tier key (e.g. "PRO") instead of the label
(e.g. "Pro"). Look up the matching TIERS entry to get the label.
2026-04-15 08:06:09 +07:00
majdyz
354de5dc0f fix(backend): format credit_subscription_test.py with black 2026-04-15 01:47:04 +07:00
majdyz
7648aacb89 fix(backend): use stripe.Subscription.modify for paid→paid tier changes to preserve proration
Downgrading between paid tiers (e.g. BUSINESS→PRO) previously fell through to
the Checkout Session path, which would cancel the existing subscription immediately
without crediting unused time on the old plan.

For paid→paid tier changes, modify the existing Stripe subscription in-place
via stripe.Subscription.modify with proration_behavior="create_prorations".
Stripe handles the proration automatically (crediting unused time + charging
the pro-rated difference), and the existing customer.subscription.updated
webhook fires to update the DB tier — no new Checkout flow needed.

Adds unit tests for modify_stripe_subscription_for_tier and API-level tests for
the paid→paid code path in update_subscription_tier.
2026-04-15 01:42:27 +07:00
majdyz
5ba14e1152 fix(backend): add invoice_id idempotency key for subscription payment failure handler
Pass the Stripe invoice ID as transaction_key to _add_transaction in
handle_subscription_payment_failure. This prevents double-charging user
credits when Stripe retries the invoice.payment_failed webhook after a
transient failure (e.g. if stripe.Invoice.pay raises a network error).

_add_transaction silently skips insertion when the key already exists,
so subsequent retries deduct nothing while still attempting to mark the
invoice as paid on the Stripe side.

Adds test to verify the idempotency key is set correctly.
2026-04-15 00:58:25 +07:00
majdyz
fdfda78bc8 fix(backend): pay Stripe invoice after balance covers subscription payment failure
When handle_subscription_payment_failure successfully deducts the invoice
amount from a user's credit balance, Stripe's dunning system would still
retry the failed invoice automatically on its schedule, causing repeated
deductions per billing period. Fix by calling stripe.Invoice.pay() after
the successful balance deduction so Stripe marks the invoice as settled and
stops retrying. Invoice.pay failures are logged as warnings but do not
propagate, matching the existing best-effort pattern for cleanup operations.

Add two tests covering the pay call and the swallowed-error path.
2026-04-15 00:42:19 +07:00
majdyz
b681363969 fix(backend): fix cancel_stripe_subscription tests to mock modify instead of cancel
The implementation schedules cancellation at period end via
stripe.Subscription.modify(cancel_at_period_end=True) rather than
stripe.Subscription.cancel. Update the four affected test cases to
patch the correct API call and assert the right arguments.
2026-04-15 00:36:11 +07:00
majdyz
2d22de7aa8 fix(backend): update DB tier immediately when no Stripe subscription exists on downgrade
When a user's paid tier was admin-granted (no associated Stripe subscription),
cancel_stripe_subscription found nothing to cancel and returned. The API then
returned 200 without updating the DB tier, because the design relied on a
customer.subscription.deleted webhook — which would never fire.

Fix: have _cancel_customer_subscriptions return the count of subs cancelled and
cancel_stripe_subscription return a bool. When False (no subs found), the endpoint
now calls set_subscription_tier directly, mirroring the payment-disabled code path.

Added test_update_subscription_tier_free_no_stripe_subscription to cover this case.
2026-04-15 00:15:15 +07:00
majdyz
f174d75e8e fix(frontend): sync openapi.json with proration_credit_cents and fix upgrade tests
- Regenerated openapi.json to include `proration_credit_cents` field added
  to `SubscriptionStatusResponse` (was missing, causing check-API-types CI failure)
- Updated SubscriptionTierSection tests to go through the Confirm Upgrade
  dialog before asserting mutate/redirect/error behaviour — previously tests
  clicked "Upgrade to Pro" and expected a direct mutate call, but the current
  code shows a confirmation dialog first (3 tests were failing because of this)
- Added `proration_credit_cents: 0` to `makeSubscription()` helper to match
  the now-required field in the generated TypeScript type
2026-04-15 00:11:16 +07:00
majdyz
607854375b Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/subscription-tier-billing 2026-04-15 00:00:23 +07:00
majdyz
9151755f00 feat(platform): proration credit notice on upgrade + cancel-at-period-end copy
- Add get_proration_credit_cents() — fetches active Stripe sub, calculates
  unused portion of current billing period in cents
- Expose proration_credit_cents in GET /credits/subscription response
- Show upgrade confirmation dialog: "Your unused $X.XX will be added to
  your balance" before redirecting to Stripe Checkout
- Update downgrade dialog copy to reflect cancel-at-period-end behaviour
2026-04-14 23:55:37 +07:00
majdyz
b9da535cfd fix(backend): don't cancel Stripe sub when balance covers failed payment
Cancelling the subscription triggered customer.subscription.deleted
which then downgraded the user to FREE despite the balance having
already covered the invoice. Now the sub is left intact when balance
covers the cost — only cancelled (and tier downgraded) when balance
is insufficient.
2026-04-14 23:47:45 +07:00
majdyz
befd9df446 fix(backend): cancel Stripe subscription at period end on downgrade
User keeps their paid tier for the remainder of the billing period
they already paid for. DB tier is no longer updated immediately on
downgrade — customer.subscription.deleted webhook fires at period end
and downgrades to FREE then.
2026-04-14 23:46:48 +07:00
majdyz
bfd1e6e793 fix(backend): fix Stripe price ID LD flag lookup and subscription payment handling
- Use user_id="system" for global LD flag lookups (price IDs don't need user context)
- Skip Supabase lookup silently for non-UUID keys in _fetch_user_context_data
- Block paid tier changes when ENABLE_PLATFORM_PAYMENT is disabled
- Add invoice.payment_failed handler: deduct from balance or downgrade to FREE
- Hide upgrade/downgrade buttons in UI when payment flag is disabled
2026-04-14 22:41:21 +07:00
Zamil Majdy
c477e7b92e Merge branch 'dev' into feat/subscription-tier-billing 2026-04-14 22:18:50 +07:00
majdyz
bcbe7f4525 fix(platform): block self-service paid upgrades when payment flag is disabled
When ENABLE_PLATFORM_PAYMENT is off for paid tier requests, return 422
instead of setting the tier directly. Admin tier changes must go through
the /api/admin/ routes, not the self-service endpoint.

Updates the corresponding subscription route test to assert the 422
response and removes the now-invalid set_subscription_tier mock.
2026-04-14 21:51:02 +07:00
majdyz
a118ea564e Merge remote-tracking branch 'origin/feat/subscription-tier-billing' into feat/subscription-tier-billing 2026-04-14 21:23:54 +07:00
majdyz
cf89b58960 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/subscription-tier-billing 2026-04-14 21:20:24 +07:00
Zamil Majdy
6b03b8d4d8 Merge branch 'dev' into feat/subscription-tier-billing 2026-04-14 21:11:50 +07:00
majdyz
ec65fd5c84 fix(backend): add cache_none=False to get_subscription_price_id
A transient LaunchDarkly failure returned None from get_subscription_price_id,
which was cached for the full 60-second TTL, blocking subscription upgrades
until expiry. Adding cache_none=False ensures None is never stored in the cache
so the next call retries LD immediately.

Adds a regression test verifying that two consecutive calls where the first
returns None (LD transient error) and the second returns the real price ID
both hit LD, confirming the None sentinel is not cached.

Flagged by sentry[bot] (credit.py:1352, Severity: MEDIUM).
2026-04-14 14:38:39 +07:00
majdyz
14e1b47b5a test(backend): clear price_id cache between direct get_subscription_price_id tests
Since get_subscription_price_id is now @cached, in-memory cache state can
persist between tests in the same process and cause false cache hits. Call
cache_clear() before and after tests that call the function directly to
ensure each test exercises a fresh LD flag lookup.
2026-04-14 14:03:02 +07:00
majdyz
be8d54b331 fix(platform): address reviewer should-fix items for subscription billing
- Cache `get_subscription_price_id` with 60s TTL via @cached decorator;
  LD flag values change only at deploy time — caching avoids hitting the
  LD SDK on every webhook delivery and GET /credits/subscription page load
- Add webhook identity cross-check in `sync_subscription_from_stripe`:
  verify metadata.user_id (set during Checkout Session creation) matches
  the user found via stripeCustomerId; log + bail on mismatch to prevent
  silently updating the wrong user's subscription tier
- Move `handleTierChange` business logic from SubscriptionTierSection
  component into `useSubscriptionTierSection` hook per project convention;
  dialog state (confirmDowngradeTo) stays in component as it's UI state
- Add three new backend tests for metadata identity cross-check:
  matching user_id accepted, mismatching user_id blocked, absent
  metadata skips check (backward compat with non-Checkout subs)
2026-04-14 14:00:24 +07:00
majdyz
bb52c5b10d fix(platform): cleanup cancelled URL param, parallelize stripe calls, add test
- Strip ?subscription=cancelled from address bar in useSubscriptionTierSection
  alongside the existing ?subscription=success cleanup so Stripe cancel
  redirects don't leave stale params in the URL
- Parallelize the two sequential stripe.Subscription.list calls on the
  cancel webhook path using asyncio.gather to reduce handler latency
- Add a test for ?subscription=cancelled being a no-op (no toast, URL cleaned)
2026-04-13 23:50:30 +07:00
majdyz
4bd79d8f6e fix(frontend): remove unused variable in skeleton loading test 2026-04-13 23:33:06 +07:00
majdyz
bfe67b6e3d test: add missing-customer test for sync_subscription_from_stripe and update isLoading assertion
- Add test_sync_subscription_from_stripe_missing_customer_key_returns_early to
  verify the .get("customer") fix: a payload with no 'customer' key returns
  early without querying the DB or writing a tier (no KeyError→500)
- Update SubscriptionTierSection loading test to match skeleton-card output
  (no longer expects empty container; now asserts tier card text is absent)
2026-04-13 23:06:33 +07:00
majdyz
46434e7402 fix(frontend): add skeleton loader on isLoading and document useEffect deps
- Replace null return on isLoading with three Skeleton card placeholders
  matching the expected height of the tier grid to prevent layout shift
- Add eslint-disable-next-line comment on useEffect dependency array
  explaining why refetch/toast are included despite being new refs each
  render (stable in practice; effect is guarded by subscriptionStatus check)
2026-04-13 23:01:30 +07:00
majdyz
eaa833528c fix(backend): harden sync_subscription_from_stripe and add partial-cancel test assertion
- Use .get("customer") with an early return + warning log instead of direct
  key access; prevents KeyError→500 on malformed webhook payloads that pass
  HMAC verification but omit the customer field
- Document the paid-to-paid upgrade race window (PRO→BUSINESS) in a comment
  so the known limitation is visible without changing semantics
- Add mock_set_tier.assert_not_called() to the multi-partial-failure test to
  explicitly assert the DB tier is never updated when a Stripe cancel raises
2026-04-13 23:01:25 +07:00
majdyz
62a6175d2a fix(frontend): clear ?subscription=success URL param after showing toast
Replace toastShownRef guard with router.replace(pathname) so the success
toast is not re-shown on page refresh and correctly re-fires on a second
checkout in the same SPA session. Adds test coverage for the behaviour.
2026-04-13 21:47:15 +07:00
majdyz
929c8a316c fix(platform): move stale-sub cleanup after idempotency check in sync_subscription_from_stripe
_cleanup_stale_subscriptions was called before the idempotency guard
(current_tier == tier -> return), so webhook replays for an already-
applied event would fire another cleanup round and could inadvertently
cancel a new subscription the user signed up for between the original
event and its replay.

Move the cleanup call to after the idempotency check so it only runs
when we are actually going to apply a tier change. Add status in
("active", "trialing") and new_sub_id guard to ensure cleanup is
only triggered for paid-sub activation events, not cancellations.
2026-04-13 04:59:32 +00:00
Zamil Majdy
557ff84196 style(backend): apply Black formatting to credit.py set-difference expressions 2026-04-13 04:45:35 +00:00
majdyz
8a2dd8f62a fix(frontend): apply Prettier formatting to openapi.json after enum addition 2026-04-13 04:41:35 +00:00
majdyz
52d8e67135 fix(subscription): add enum to SubscriptionStatusResponse.tier in openapi.json, fix MagicMock.has_more in tests, type _MISSING sentinel
- openapi.json: SubscriptionStatusResponse.tier was missing enum constraint — generated TS type was string instead of literal union. Added enum:[FREE,PRO,BUSINESS,ENTERPRISE] to match the Literal on the Python model.
- credit_subscription_test.py: set has_more=False on all MagicMock subscription list objects so _cancel_customer_subscriptions does not log spurious 'more than 10 subs' errors in tests. Also added clarifying comment on multi_partial_failure assertion.
- cache.py: replaced _MISSING: Any = object() with a dedicated _MissingType singleton class so mypy correctly narrows type after 'result is _MISSING' comparisons.
2026-04-13 04:35:25 +00:00
majdyz
48f022b506 fix(subscription): type SubscriptionStatusResponse.tier as Literal, add same-tier noop test, reset toastShownRef on SPA nav
- SubscriptionStatusResponse.tier: str -> Literal["FREE","PRO","BUSINESS","ENTERPRISE"] so OpenAPI schema emits an enum and the generated TS client is narrowly typed
- Add test_update_subscription_tier_same_tier_is_noop: asserts the double-billing guard at line 868 returns 200/empty URL and never calls create_subscription_checkout
- Reset toastShownRef.current to false when subscriptionStatus != "success" so the success toast fires again after a second checkout on the same SPA mount
2026-04-13 03:55:19 +00:00
majdyz
bf7f674b2f fix(frontend): void floating promise in handleTierChange
Add void operator to changeTier(tierKey) call to explicitly
discard the promise.
2026-04-12 23:17:49 +00:00
majdyz
69e0a66f5e fix(frontend): wrap async confirmDowngrade in void to avoid floating promise
React onClick handlers don't await async functions, so passing an
async function directly creates a floating promise. Wrap in void to
make the intent explicit and prevent unhandled rejections.
2026-04-12 10:19:08 +00:00
majdyz
a4006fa5a1 fix(backend): scope URL @ check to netloc only in checkout redirect validation
The pre-parse rejection of @ was overly broad — it rejected valid URLs
with @ in query strings or fragments (e.g. ?ref=user@company.com).
The user:pass@host authority attack only applies to the netloc component.
Move the @ check to run against parsed.netloc after urlparse.
2026-04-12 10:18:55 +00:00
majdyz
0251bfd664 fix(backend): fix inverted still_has_active_sub predicate and add has_more check
- Use set difference instead of any() to correctly detect other active
  subs (any(sub["id"] != new_sub_id ...) returns True if ANY sub has a
  different ID, which is always true when >1 sub exists regardless of
  whether the cancelled sub is in the list).
- Add has_more check with logger.error in _cancel_customer_subscriptions
  so we surface when a customer has >10 subs and some were silently
  skipped.
2026-04-12 10:18:42 +00:00
majdyz
2f24091c17 fix(platform): simplify stripe customer race protection
Revert the tentative update_many conditional guard (prisma where-clause
null semantics are fiddly and the test suite mocks get_stripe_customer_id
end-to-end, so a real prisma error wouldn't be caught locally). The
idempotency_key on Customer.create is sufficient: Stripe collapses
concurrent + retried calls to the same Customer object for 24h, which
comfortably covers every realistic in-flight retry window.

Also invalidate the get_user_by_id cache after the DB write so the
freshly-persisted stripeCustomerId is visible on the next read.
2026-04-11 12:00:58 +00:00
majdyz
8b93cea4d4 fix(platform): harden Stripe billing flow against race + replay edges
Address review findings on the subscription tier billing PR:

1. get_stripe_customer_id race: two concurrent calls (double-click,
   retried request) could each create a Stripe Customer for the same
   user, leaving an orphaned billable customer. Pass an idempotency_key
   so Stripe collapses concurrent + retried calls server-side, and use
   a conditional update_many so the loser of a longer-window race
   re-reads the persisted ID instead of overwriting.

2. update_subscription_tier no-op short-circuit: if the user is already
   on the requested paid tier, return without creating a Checkout
   Session. Without this guard, a duplicate request creates a second
   subscription for the same price; the user would be charged for both
   until _cleanup_stale_subscriptions runs from the resulting webhook —
   which only fires after the second charge has cleared.

3. stripe_webhook payload defensive extraction: a malformed payload
   (missing/non-dict data.object, missing id) would raise KeyError /
   TypeError after signature verification, which Stripe interprets as
   a delivery failure and retries forever. Validate shape, log a
   warning, and ack with 200 so Stripe stops retrying.

4. _cleanup_stale_subscriptions: bump the swallowed-error log from
   warning to exception so Sentry surfaces it as an error, include
   the customer/sub IDs needed for manual reconciliation, and add a
   TODO referencing the missing periodic reconcile job that the
   docstring already promises as the backstop.
2026-04-11 11:48:33 +00:00
majdyz
693c616bf5 fix(util/cache): properly distinguish missing entries from cached None
The @cached decorator could not differentiate "no entry" from "entry is
None" — both `_get_from_memory` and `_get_from_redis` returned `None`
for misses, and the wrappers checked `result is not None` to decide
whether to recompute. Functions that returned `None` as a valid value
were therefore re-executed on every call, defeating the cache and (for
shared_cache=False) potentially causing per-pod thundering herd against
upstream APIs.

Fix:
- Use a module-level `_MISSING = object()` sentinel for "no entry".
- Wrappers now check `result is not _MISSING` so cached `None` is
  returned correctly.
- Add a `cache_none: bool = True` parameter so callers that *want* the
  retry-on-None behavior (e.g. external API calls returning `None` to
  signal a transient error) can explicitly opt out via `cache_none=False`.
- `_get_stripe_price_amount` opts out: returning None on a Stripe error
  must not poison the 5-minute cache window. Updated its docstring to
  describe the actual behavior.

New tests cover both default (None is cached) and `cache_none=False`
(None is not stored, next call retries) for sync, async, and shared
cache paths.

Sentry bug prediction: PRRT_kwDOJKSTjM56RTEu (severity HIGH).
2026-04-11 05:03:54 +00:00
majdyz
6f7bf90769 fix(backend): harden URL validator and add adversarial redirect tests
Reject URLs containing '@', backslashes, or control characters before
urlparse to prevent auth-trick and backslash-normalisation attacks.
Add parametrized tests covering 11 adversarial inputs + valid cases.
2026-04-11 09:27:29 +07:00
majdyz
ce57601305 fix(frontend): fix TypeScript errors in SubscriptionTierSection and its test
- Dialog controlled set callback: use explicit if-block to avoid
  returning 'false | void' (TS2322)
- Test redirect test: use vi.stubGlobal to replace window.location with
  a plain object (Proxy on jsdom Location breaks private-field access)
2026-04-11 09:24:35 +07:00
majdyz
d81bbdb870 fix(backend): avoid caching Stripe error fallback in _get_stripe_price_amount
Return None on StripeError instead of 0 so the @cached decorator
(which skips caching None) does not persist the error state for 5 min.
Added test to verify the None→0 fallback path in get_subscription_status.
2026-04-11 09:14:24 +07:00
majdyz
7f6163b180 fix(platform): address final PR review comments on subscription billing
- Replace __legacy__ Dialog import with molecules/Dialog in SubscriptionTierSection
- Update test mock to match new Dialog API (controlled pattern)
- Guard still_has_active_sub against empty new_sub_id in sync_subscription_from_stripe
- Move urlparse import from inside _validate_checkout_redirect_url to module level
2026-04-11 09:07:31 +07:00
majdyz
2057b4597e test(frontend): add Vitest+RTL integration tests for SubscriptionTierSection
Covers: tier card rendering, Current badge, cost display, upgrade/downgrade
flow (with Stripe redirect), confirmation dialog, error handling, ENTERPRISE
user messaging, and success param handling.
2026-04-11 09:00:45 +07:00
majdyz
5bb7027f89 fix(platform): address remaining PR review comments on subscription billing
Backend:
- Cache stripe.Price.retrieve with 5-min TTL via _get_stripe_price_amount
  to avoid 200-600ms Stripe round-trip on every GET /credits/subscription
- Use SubscriptionTier enum .value for FREE/ENTERPRISE in tier_costs dict
  for consistency (instead of hardcoded strings)
- Rename misleading test names: "defaults_to_FREE" → "preserves_current_tier"
  to reflect actual behaviour (unknown price IDs preserve tier, not reset)
- Update subscription_routes_test to mock _get_stripe_price_amount instead
  of stripe.Price.retrieve directly, avoiding cached-result interference

Frontend:
- Handle ?subscription=success return from Stripe Checkout: refetch + toast
- Add downgrade confirmation Dialog before cancelling paid subscription
- Handle ENTERPRISE tier: render dedicated admin-managed plan card, not the
  FREE/PRO/BUSINESS tier cards (which would show no "Current" badge)
- Track pendingTier (via variables) so only the clicked button shows "Updating..."
- Show "Pricing available soon" for paid tiers with cost=0 (unconfigured LD flags)
  instead of misleading "Free"
- Move tierError state into the hook, set via changeTier internally
- Move TIER_ORDER constant to module scope (was magic array inside render body)
- Add aria-current="true" to active tier card for screen reader accessibility
- Add role="alert" to all error paragraph elements
- Improve tier descriptions with concrete capacity values
2026-04-11 08:57:34 +07:00
majdyz
329a034ebe merge(platform): merge latest dev into feat/subscription-tier-billing 2026-04-11 08:50:35 +07:00
majdyz
62f3ed79be style(backend): fix Black formatting in platform_cost_test.py
Black detected double blank lines between class definitions in
platform_cost_test.py (pulled from dev base). Normalise to a single
blank line so the CI merge-commit lint check passes.
2026-04-11 00:12:16 +07:00
majdyz
54450def6b fix(platform): guard Stripe webhook against empty-secret HMAC bypass
An empty STRIPE_WEBHOOK_SECRET (the default) allows an attacker to
compute a valid HMAC-SHA256 signature over the same key and forge any
webhook event (customer.subscription.created, etc.), escalating any
user to an arbitrary subscription tier without paying.

Fix: return 503 immediately when stripe_webhook_secret is unset rather
than proceeding to signature verification. Also add run_in_threadpool
to get_stripe_customer_id and remove the duplicate trialing-sub test.

Merges origin/feat/subscription-tier-billing which had the open-redirect
guard, blocking-IO fix, and idempotency/ENTERPRISE guard.

Test added: test_stripe_webhook_unconfigured_secret_returns_503
2026-04-11 00:00:50 +07:00
majdyz
8ad5bf03a7 fix(platform): critical security fixes for Stripe webhook + async IO
- Guard stripe_webhook: return 503 when STRIPE_WEBHOOK_SECRET is empty.
  An empty secret allows HMAC forgery (attacker computes a valid sig over
  the same key), so we reject all webhook calls when unconfigured.
- Suppress raw Stripe error from 502 cancel response; log server-side instead.
- Wrap all blocking Stripe SDK calls in run_in_threadpool: Customer.create,
  Subscription.list, Subscription.cancel, checkout.Session.create.
- cancel_stripe_subscription now also cancels 'trialing' subscriptions
  (previously only 'active'), preventing billing after a FREE downgrade.
- session.url None now raises ValueError instead of returning empty string.
- Add tests: webhook 503 on missing secret, trialing-sub cancellation.
2026-04-10 23:55:18 +07:00
majdyz
16c38c4dfb style(credit): apply Black formatting
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 09:59:42 +00:00
majdyz
945297b965 fix(backend): cancel trialing Stripe subs alongside active ones
_cancel_customer_subscriptions previously only queried status="active",
leaving trialing subscriptions in place. A user on a trial who downgrades
to FREE, or upgrades to a different paid tier, would continue to be billed
once the trial ended. Query both "active" and "trialing" statuses and
dedupe by sub id to ensure every billable sub is cleaned up.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 09:17:38 +00:00
majdyz
6b57dc0c7f fix(backend): prevent race-condition downgrade in Stripe webhook handler
When Stripe processes a subscription upgrade, the old subscription's
customer.subscription.deleted event may arrive after the new subscription's
customer.subscription.created has already been handled. Unconditionally
setting the user to FREE in the cancel branch would immediately undo the
upgrade.

sync_subscription_from_stripe now checks Stripe for other active/trialing
subscriptions on the same customer before downgrading. If at least one
different active sub exists, the handler preserves the current tier and
returns without writing. Added a regression test that mocks Stripe
returning sub_new as active and asserts set_subscription_tier is never
awaited.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 08:49:23 +00:00
majdyz
c1aec96c0f fix(platform): address round-2 review comments on subscription billing
Security and quality fixes for PR #12727 subscription tier billing review:

- Open-redirect protection: validate success_url/cancel_url against
  settings.config.frontend_base_url before passing to Stripe Checkout.
- Blocking I/O: wrap every synchronous Stripe SDK call (Subscription.list,
  Subscription.cancel, checkout.Session.create) with run_in_threadpool via
  a shared _cancel_customer_subscriptions helper.
- Info leakage: log raw Stripe errors server-side but return a generic
  502 detail to the client ("Please try again or contact support.").
- Webhook idempotency: skip DB writes in sync_subscription_from_stripe
  when the tier is already current, avoiding redundant writes on retry.
- ENTERPRISE guard in webhook: refuse to overwrite ENTERPRISE tier from
  Stripe events (admin-managed, not self-service).
- create_subscription_checkout raises ValueError on empty session.url
  instead of silently returning "".
- Tests: fixture-based client (no leaky try/finally), open-redirect test,
  ENTERPRISE 403 test, webhook dispatch test, trialing status test,
  multi-sub partial-cancel-failure test, idempotency test, renamed
  misleading "defaults to FREE" tests to "preserves_current_tier".

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 08:44:01 +00:00
majdyz
52b0e2a9a6 fix(backend): cancel stale Stripe subs on paid-to-paid tier upgrade
When a PRO user upgrades to BUSINESS via a fresh Checkout Session, Stripe
creates a new subscription without touching the existing one, leaving the
customer double-billed. Cleaning up in sync_subscription_from_stripe
rather than the API handler ensures an abandoned Checkout does not leave
the user without a subscription: we only cancel the old sub once the new
sub has actually become active.

Errors listing or cancelling stale subs are logged but not propagated —
the new subscription tier still gets persisted, and Stripe will retry
the webhook later if listing fails.

Addresses sentry[bot] comment 3061713750 on PR #12727.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 06:54:58 +00:00
majdyz
3ef14e9657 fix(backend): invalidate get_user_tier cache in set_subscription_tier
After a tier change, the rate-limit cache (get_user_tier, 5-minute TTL)
was not cleared, so CoPilot rate limits would continue enforcing the old tier
until the TTL expired. Call get_user_tier.cache_delete(user_id) via a local
import to avoid circular import issues.

Addresses sentry[bot] comment 3061725912 on PR #12727.
2026-04-10 09:43:51 +07:00
majdyz
3c49d3373d fix(backend): remove invalid customer_update parameter from Stripe checkout
customer_update only accepts {address, name, shipping} per Stripe's TypedDict.
The payment_method key does not exist in CreateParamsCustomerUpdate, so pyright
was failing the type-check CI. Remove the invalid parameter — for Stripe
subscriptions the payment method used for the first invoice is automatically
saved to the customer by Stripe.
2026-04-10 09:30:37 +07:00
majdyz
e7e6c8f4b4 refactor(frontend): remove unused legacy subscription methods from BackendAPI
getSubscription() and setSubscriptionTier() in client.ts were replaced by
generated hooks (useGetSubscriptionStatus, useUpdateSubscriptionTier) and
are no longer called anywhere in the codebase. Remove them to avoid adding
further surface area to the deprecated BackendAPI.
2026-04-10 09:25:42 +07:00
majdyz
4b3e47fe88 fix(platform): propagate Stripe errors in cancel_stripe_subscription
- stripe.Subscription.list() is now wrapped in try-except; StripeError
  is logged and re-raised so callers know the listing failed.
- stripe.Subscription.cancel() StripeError is now re-raised (was swallowed),
  preventing set_subscription_tier from marking the user FREE when Stripe
  cancellation failed.
- update_subscription_tier catches StripeError from cancel and returns HTTP 502
  so DB tier is only updated if Stripe succeeds.
- Fix test patch path: use backend.data.credit.stripe.checkout.Session.create
  instead of bare stripe.checkout.Session.create for import-refactor safety.
- Add tests for raise-on-list-failure, raise-on-cancel-failure, and
  502 route response on cancel failure.

Addresses sentry[bot] comments 3061585490, 3061654688 on PR #12727.
2026-04-10 09:22:44 +07:00
majdyz
cc1cef7da5 fix(platform): set customer default payment method on subscription checkout
Adds customer_update={payment_method: auto} so the payment method used
for subscription is set as the Stripe customer's default. Makes it show
pre-selected in future Checkout sessions (manual top-ups).
2026-04-10 09:02:16 +07:00
18 changed files with 3439 additions and 446 deletions

View File

@@ -5,7 +5,8 @@ import time
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Annotated, Any, Literal, Sequence, get_args
from typing import Annotated, Any, Literal, Sequence, cast, get_args
from urllib.parse import urlparse
import pydantic
import stripe
@@ -54,8 +55,11 @@ from backend.data.credit import (
cancel_stripe_subscription,
create_subscription_checkout,
get_auto_top_up,
get_proration_credit_cents,
get_subscription_price_id,
get_user_credit_model,
handle_subscription_payment_failure,
modify_stripe_subscription_for_tier,
set_auto_top_up,
set_subscription_tier,
sync_subscription_from_stripe,
@@ -699,9 +703,72 @@ class SubscriptionCheckoutResponse(BaseModel):
class SubscriptionStatusResponse(BaseModel):
tier: str
monthly_cost: int
tier_costs: dict[str, int]
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
monthly_cost: int # amount in cents (Stripe convention)
tier_costs: dict[str, int] # tier name -> amount in cents
proration_credit_cents: int # unused portion of current sub to convert on upgrade
def _validate_checkout_redirect_url(url: str) -> bool:
"""Return True if `url` matches the configured frontend origin.
Prevents open-redirect: attackers must not be able to supply arbitrary
success_url/cancel_url that Stripe will redirect users to after checkout.
Pre-parse rejection rules (applied before urlparse):
- Backslashes (``\\``) are normalised differently across parsers/browsers.
- Control characters (U+0000U+001F) are not valid in URLs and may confuse
some URL-parsing implementations.
"""
# Reject characters that can confuse URL parsers before any parsing.
if "\\" in url:
return False
if any(ord(c) < 0x20 for c in url):
return False
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
if not allowed:
# No configured origin — refuse to validate rather than allow arbitrary URLs.
return False
try:
parsed = urlparse(url)
allowed_parsed = urlparse(allowed)
except ValueError:
return False
if parsed.scheme not in ("http", "https"):
return False
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
# can trick browsers into connecting to a different host than displayed.
# ``@`` in query/fragment is harmless and must be allowed.
if "@" in parsed.netloc:
return False
return (
parsed.scheme == allowed_parsed.scheme
and parsed.netloc == allowed_parsed.netloc
)
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
async def _get_stripe_price_amount(price_id: str) -> int | None:
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
of caching the ``None`` sentinel so the next request retries Stripe instead
of being served a stale "no price" for the rest of the TTL window. Callers
should treat ``None`` as an unknown price and fall back to 0.
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
every GET /credits/subscription page load and reduces quota consumption.
"""
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
return price.unit_amount or 0
except stripe.StripeError:
logger.warning(
"Failed to retrieve Stripe price %s — returning None (not cached)",
price_id,
)
return None
@v1_router.get(
@@ -722,21 +789,26 @@ async def get_subscription_status(
*[get_subscription_price_id(t) for t in paid_tiers]
)
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
for t, price_id in zip(paid_tiers, price_ids):
cost = 0
if price_id:
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
cost = price.unit_amount or 0
except stripe.StripeError:
pass
tier_costs: dict[str, int] = {
SubscriptionTier.FREE.value: 0,
SubscriptionTier.ENTERPRISE.value: 0,
}
async def _cost(pid: str | None) -> int:
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
for t, cost in zip(paid_tiers, costs):
tier_costs[t.value] = cost
current_monthly_cost = tier_costs.get(tier.value, 0)
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
return SubscriptionStatusResponse(
tier=tier.value,
monthly_cost=tier_costs.get(tier.value, 0),
monthly_cost=current_monthly_cost,
tier_costs=tier_costs,
proration_credit_cents=proration_credit,
)
@@ -766,24 +838,125 @@ async def update_subscription_tier(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
# keeps their tier for the time they already paid for. The DB tier is NOT
# updated here when a subscription exists — the customer.subscription.deleted
# webhook fires at period end and downgrades to FREE then.
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
# tier), cancel_stripe_subscription returns False and we update the DB tier
# immediately since no webhook will ever fire.
# When payment is disabled entirely, update the DB tier directly.
if tier == SubscriptionTier.FREE:
if payment_enabled:
await cancel_stripe_subscription(user_id)
try:
had_subscription = await cancel_stripe_subscription(user_id)
except stripe.StripeError as e:
# Log full Stripe error server-side but return a generic message
# to the client — raw Stripe errors can leak customer/sub IDs and
# infrastructure config details.
logger.exception(
"Stripe error cancelling subscription for user %s: %s",
user_id,
e,
)
raise HTTPException(
status_code=502,
detail=(
"Unable to cancel your subscription right now. "
"Please try again or contact support."
),
)
if not had_subscription:
# No active Stripe subscription found — the user was on an
# admin-granted tier. Update DB immediately since the
# subscription.deleted webhook will never fire.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
# Beta users (payment not enabled) → update tier directly without Stripe.
# Paid tier changes require payment to be enabled — block self-service upgrades
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
if not payment_enabled:
await set_subscription_tier(user_id, tier)
raise HTTPException(
status_code=422,
detail=f"Subscription not available for tier {tier}",
)
# No-op short-circuit: if the user is already on the requested paid tier,
# do NOT create a new Checkout Session. Without this guard, a duplicate
# request (double-click, retried POST, stale page) creates a second
# subscription for the same price; the user would be charged for both
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
# which only fires after the second charge has cleared.
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
return SubscriptionCheckoutResponse(url="")
# Paid upgrade → create Stripe Checkout Session.
# Paid→paid tier change: if the user already has a Stripe subscription,
# modify it in-place with proration instead of creating a new Checkout
# Session. This preserves remaining paid time and avoids double-charging.
# The customer.subscription.updated webhook fires and updates the DB tier.
current_tier = user.subscription_tier or SubscriptionTier.FREE
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
try:
modified = await modify_stripe_subscription_for_tier(user_id, tier)
if modified:
return SubscriptionCheckoutResponse(url="")
# modify_stripe_subscription_for_tier returns False when no active
# Stripe subscription exists — i.e. the user has an admin-granted
# paid tier with no Stripe record. In that case, update the DB
# tier directly (same as the FREE-downgrade path for admin-granted
# users) rather than sending them through a new Checkout Session.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
),
)
# Paid upgrade from FREE → create Stripe Checkout Session.
if not request.success_url or not request.cancel_url:
raise HTTPException(
status_code=422,
detail="success_url and cancel_url are required for paid tier upgrades",
)
# Open-redirect protection: both URLs must point to the configured frontend
# origin, otherwise an attacker could use our Stripe integration as a
# redirector to arbitrary phishing sites.
#
# Fail early with a clear 503 if the server is misconfigured (neither
# frontend_base_url nor platform_base_url set), so operators get an
# actionable error instead of the misleading "must match the platform
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
# produce when `allowed` is empty.
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
logger.error(
"update_subscription_tier: neither frontend_base_url nor "
"platform_base_url is configured; cannot validate checkout redirect URLs"
)
raise HTTPException(
status_code=503,
detail=(
"Payment redirect URLs cannot be validated: "
"frontend_base_url or platform_base_url must be set on the server."
),
)
if not _validate_checkout_redirect_url(
request.success_url
) or not _validate_checkout_redirect_url(request.cancel_url):
raise HTTPException(
status_code=422,
detail="success_url and cancel_url must match the platform frontend origin",
)
try:
url = await create_subscription_checkout(
user_id=user_id,
@@ -791,8 +964,19 @@ async def update_subscription_tier(
success_url=request.success_url,
cancel_url=request.cancel_url,
)
except (ValueError, stripe.StripeError) as e:
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error creating checkout session for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to start checkout right now. "
"Please try again or contact support."
),
)
return SubscriptionCheckoutResponse(url=url)
@@ -801,44 +985,78 @@ async def update_subscription_tier(
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
)
async def stripe_webhook(request: Request):
webhook_secret = settings.secrets.stripe_webhook_secret
if not webhook_secret:
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
# signature over the same empty key). Reject all webhook calls when unconfigured.
logger.error(
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
"rejecting request to prevent signature bypass"
)
raise HTTPException(status_code=503, detail="Webhook not configured")
# Get the raw request body
payload = await request.body()
# Get the signature header
sig_header = request.headers.get("stripe-signature")
try:
event = stripe.Webhook.construct_event(
payload, sig_header, settings.secrets.stripe_webhook_secret
)
except ValueError as e:
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
except ValueError:
# Invalid payload
raise HTTPException(
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
)
except stripe.SignatureVerificationError as e:
raise HTTPException(status_code=400, detail="Invalid payload")
except stripe.SignatureVerificationError:
# Invalid signature
raise HTTPException(
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
raise HTTPException(status_code=400, detail="Invalid signature")
# Defensive payload extraction. A malformed payload (missing/non-dict
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
# AFTER signature verification — which Stripe interprets as a delivery
# failure and retries forever, while spamming Sentry with no useful info.
# Acknowledge with 200 and a warning so Stripe stops retrying.
event_type = event.get("type", "")
event_data = event.get("data") or {}
data_object = event_data.get("object") if isinstance(event_data, dict) else None
if not isinstance(data_object, dict):
logger.warning(
"stripe_webhook: %s missing or non-dict data.object; ignoring",
event_type,
)
return Response(status_code=200)
if (
event["type"] == "checkout.session.completed"
or event["type"] == "checkout.session.async_payment_succeeded"
if event_type in (
"checkout.session.completed",
"checkout.session.async_payment_succeeded",
):
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
session_id = data_object.get("id")
if not session_id:
logger.warning(
"stripe_webhook: %s missing data.object.id; ignoring", event_type
)
return Response(status_code=200)
await UserCredit().fulfill_checkout(session_id=session_id)
if event["type"] in (
if event_type in (
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
):
await sync_subscription_from_stripe(event["data"]["object"])
await sync_subscription_from_stripe(data_object)
if event["type"] == "charge.dispute.created":
await UserCredit().handle_dispute(event["data"]["object"])
if event_type == "invoice.payment_failed":
await handle_subscription_payment_failure(data_object)
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
await UserCredit().deduct_credits(event["data"]["object"])
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
# to satisfy the type checker without changing runtime behaviour.
if event_type == "charge.dispute.created":
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
if event_type == "refund.created" or event_type == "charge.dispute.closed":
await UserCredit().deduct_credits(
cast("stripe.Refund | stripe.Dispute", data_object)
)
return Response(status_code=200)

View File

@@ -124,15 +124,6 @@ def create_copilot_queue_config() -> RabbitMQConfig:
vhost="/",
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
queues=[run_queue, cancel_queue],
# The consumer threads sit in pika's blocking ``start_consuming()`` for
# the full lifetime of the process. If the TCP connection is dropped
# (server restart, NAT timeout, laptop sleep) while pika's IO thread is
# starved, the socket rots in CLOSE_WAIT and no message is ever
# consumed — see zombie-consumer incident notes. A short heartbeat plus
# kernel-level TCP keepalive makes both the app and the OS notice a
# dead peer within a couple of minutes instead of hours.
heartbeat=60,
tcp_keepalive=True,
)

View File

@@ -174,18 +174,14 @@ sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
### GitHub CLI (`gh`) and git
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before running `connect_integration(provider="github")` which will ask the user to connect their GitHub regardless if it's already connected.
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.
- If the token changes mid-session (e.g. user reconnects with a new token),
run `gh auth setup-git` to re-register the credential helper.
- **MANDATORY:** You MUST run `gh auth status` before EVER calling
`connect_integration(provider="github")`. If it shows `Logged in`,
proceed directly — no integration connection needed. Never skip this check.
- If `gh auth status` shows NOT logged in, or `gh`/`git` fails with an
authentication error (e.g. "authentication required", "could not read
Username", or exit code 128), THEN call
- If `gh` or `git` fails with an authentication error (e.g. "authentication
required", "could not read Username", or exit code 128), call
`connect_integration(provider="github")` to surface the GitHub credentials
setup card so the user can connect their account. Once connected, retry
the operation.

View File

@@ -125,7 +125,12 @@ config = ChatConfig()
class _SystemPromptPreset(SystemPromptPreset, total=False):
"""Extends SystemPromptPreset with fields added in claude-agent-sdk 0.1.59."""
"""Extends :class:`SystemPromptPreset` with ``exclude_dynamic_sections``.
The field was added to the upstream TypedDict in claude-agent-sdk 0.1.59.
Until the package is pinned to that version we declare it locally so Pyright
accepts the kwarg without a ``# type: ignore`` comment.
"""
exclude_dynamic_sections: NotRequired[bool]

View File

@@ -880,6 +880,116 @@ class TestUploadCliSession:
assert meta_content["mode"] == "baseline"
assert meta_content["message_count"] == 4
def test_strips_session_before_upload_and_writes_back(self):
"""strip_for_upload removes progress entries and returns smaller content."""
import json
from .transcript import strip_for_upload
progress_entry = {
"type": "progress",
"uuid": "p1",
"parentUuid": "u1",
"data": {"type": "bash_progress", "stdout": "running..."},
}
user_entry = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hello"},
}
asst_entry = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {"role": "assistant", "content": "world"},
}
raw_content = (
json.dumps(progress_entry)
+ "\n"
+ json.dumps(user_entry)
+ "\n"
+ json.dumps(asst_entry)
+ "\n"
)
stripped = strip_for_upload(raw_content)
stored_lines = stripped.strip().split("\n")
stored_types = [json.loads(line).get("type") for line in stored_lines]
assert "progress" not in stored_types
assert "user" in stored_types
assert "assistant" in stored_types
assert len(stripped.encode()) < len(raw_content.encode())
def test_strips_stale_thinking_blocks_before_upload(self):
"""strip_for_upload removes thinking blocks from non-last assistant turns."""
import json
from .transcript import strip_for_upload
u1 = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "q1"},
}
a1_with_thinking = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"id": "msg_a1",
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "A" * 5000},
{"type": "text", "text": "answer1"},
],
},
}
u2 = {
"type": "user",
"uuid": "u2",
"parentUuid": "a1",
"message": {"role": "user", "content": "q2"},
}
a2_no_thinking = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "u2",
"message": {
"id": "msg_a2",
"role": "assistant",
"content": [{"type": "text", "text": "answer2"}],
},
}
raw_content = (
json.dumps(u1)
+ "\n"
+ json.dumps(a1_with_thinking)
+ "\n"
+ json.dumps(u2)
+ "\n"
+ json.dumps(a2_no_thinking)
+ "\n"
)
stripped = strip_for_upload(raw_content)
stored_lines = stripped.strip().split("\n")
# a1 should have its thinking block stripped (it's not the last assistant turn).
a1_stored = json.loads(stored_lines[1])
a1_content = a1_stored["message"]["content"]
assert all(
b["type"] != "thinking" for b in a1_content
), "stale thinking block should be stripped from a1"
assert any(
b["type"] == "text" for b in a1_content
), "text block should be kept in a1"
# a2 (last turn) should be unchanged.
a2_stored = json.loads(stored_lines[3])
assert a2_stored["message"]["content"] == [{"type": "text", "text": "answer2"}]
class TestRestoreCliSession:
def test_returns_none_when_file_not_found_in_storage(self):

View File

@@ -1,10 +1,13 @@
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast
import stripe
from fastapi.concurrency import run_in_threadpool
from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
@@ -31,6 +34,7 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.util.cache import cached
from backend.util.exceptions import InsufficientBalanceError
from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled
from backend.util.json import SafeJson, dumps
@@ -432,7 +436,7 @@ class UserCreditBase(ABC):
current_balance, _ = await self._get_credits(user_id)
if current_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
)
# Single unified atomic operation for all transaction types using UserBalance
@@ -571,7 +575,7 @@ class UserCreditBase(ABC):
if amount < 0 and fail_insufficient_credits:
current_balance, _ = await self._get_credits(user_id)
raise InsufficientBalanceError(
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
user_id=user_id,
balance=current_balance,
amount=amount,
@@ -582,7 +586,6 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
@@ -734,7 +737,7 @@ class UserCredit(UserCreditBase):
)
if request.amount <= 0 or request.amount > transaction.amount:
raise AssertionError(
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
)
balance, _ = await self._add_transaction(
@@ -788,12 +791,12 @@ class UserCredit(UserCreditBase):
# If the user has enough balance, just let them win the dispute.
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
dispute.close()
return
logger.warning(
f"Adding extra info for dispute from {user_id} for ${amount/100}"
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
)
# Retrieve recent transaction history to support our evidence.
# This provides a concise timeline that shows service usage and proper credit application.
@@ -1237,14 +1240,23 @@ async def get_stripe_customer_id(user_id: str) -> str:
if user.stripe_customer_id:
return user.stripe_customer_id
customer = stripe.Customer.create(
# Race protection: two concurrent calls (e.g. user double-clicks "Upgrade",
# or any retried request) would each pass the check above and create their
# own Stripe Customer, leaving an orphaned billable customer in Stripe.
# Pass an idempotency_key so Stripe collapses concurrent + retried calls
# into the same Customer object server-side. The 24h Stripe idempotency
# window comfortably covers any realistic in-flight retry scenario.
customer = await run_in_threadpool(
stripe.Customer.create,
name=user.name or "",
email=user.email,
metadata={"user_id": user_id},
idempotency_key=f"customer-create-{user_id}",
)
await User.prisma().update(
where={"id": user_id}, data={"stripeCustomerId": customer.id}
)
get_user_by_id.cache_delete(user_id)
return customer.id
@@ -1263,23 +1275,211 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
data={"subscriptionTier": tier},
)
get_user_by_id.cache_delete(user_id)
# Also invalidate the rate-limit tier cache so CoPilot picks up the new
# tier immediately rather than waiting up to 5 minutes for the TTL to expire.
from backend.copilot.rate_limit import get_user_tier # local import avoids circular
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
async def cancel_stripe_subscription(user_id: str) -> None:
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
customer_id = await get_stripe_customer_id(user_id)
subscriptions = stripe.Subscription.list(
customer=customer_id, status="active", limit=10
)
for sub in subscriptions.auto_paging_iter():
try:
stripe.Subscription.cancel(sub["id"])
except stripe.StripeError:
logger.warning(
"cancel_stripe_subscription: failed to cancel sub %s for user %s",
sub["id"],
user_id,
async def _cancel_customer_subscriptions(
customer_id: str,
exclude_sub_id: str | None = None,
at_period_end: bool = False,
) -> int:
"""Cancel all billable Stripe subscriptions for a customer, optionally excluding one.
Cancels both ``active`` and ``trialing`` subscriptions, since trialing subs will
start billing once the trial ends and must be cleaned up on downgrade/upgrade to
avoid double-charging or charging users who intended to cancel.
When ``at_period_end=True``, schedules cancellation at the end of the current
billing period instead of cancelling immediately — the user keeps their tier
until the period ends, then ``customer.subscription.deleted`` fires and the
webhook downgrades them to FREE.
Wraps every synchronous Stripe SDK call with run_in_threadpool so the async event
loop is never blocked. Raises stripe.StripeError on list/cancel failure so callers
that need strict consistency can react; cleanup callers can catch and log instead.
Returns the number of subscriptions cancelled/scheduled for cancellation.
"""
# Query active and trialing separately; Stripe's list API accepts a single status
# filter at a time (no OR), and we explicitly want to skip canceled/incomplete/
# past_due subs rather than filter them out client-side via status="all".
seen_ids: set[str] = set()
for status in ("active", "trialing"):
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status=status, limit=10
)
# Iterate only the first page (up to 10); avoid auto_paging_iter which would
# trigger additional sync HTTP calls inside the event loop.
if subscriptions.has_more:
logger.error(
"_cancel_customer_subscriptions: customer %s has more than 10 %s"
" subscriptions — only the first page was processed; remaining"
" subscriptions were NOT cancelled",
customer_id,
status,
)
for sub in subscriptions.data:
sub_id = sub["id"]
if exclude_sub_id and sub_id == exclude_sub_id:
continue
if sub_id in seen_ids:
continue
seen_ids.add(sub_id)
if at_period_end:
await run_in_threadpool(
stripe.Subscription.modify, sub_id, cancel_at_period_end=True
)
else:
await run_in_threadpool(stripe.Subscription.cancel, sub_id)
return len(seen_ids)
async def cancel_stripe_subscription(user_id: str) -> bool:
"""Schedule cancellation of all active/trialing Stripe subscriptions at period end.
The subscription stays active until the end of the billing period so the user
keeps their tier for the time they already paid for. The ``customer.subscription.deleted``
webhook fires at period end and downgrades the DB tier to FREE.
Returns True if at least one subscription was found and scheduled for cancellation,
False if the customer had no active/trialing subscriptions (e.g., admin-granted tier
with no associated Stripe subscription). When False, the caller should update the
DB tier directly since no webhook will fire to do it.
Raises stripe.StripeError if any modification fails, so the caller can avoid
updating the DB tier when Stripe is inconsistent.
"""
# Guard: only proceed if the user already has a Stripe customer ID. Calling
# get_stripe_customer_id for a user who has never had a paid subscription would
# create an orphaned, potentially-billable Stripe Customer object — we avoid that
# by returning False early so the caller can downgrade the DB tier directly.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return False
customer_id = user.stripe_customer_id
try:
cancelled_count = await _cancel_customer_subscriptions(
customer_id, at_period_end=True
)
return cancelled_count > 0
except stripe.StripeError:
logger.warning(
"cancel_stripe_subscription: Stripe error while cancelling subs for user %s",
user_id,
)
raise
async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int:
"""Return the prorated credit (in cents) the user would receive if they upgraded now.
Fetches the user's active or trialing Stripe subscription to determine how many
seconds remain in the current billing period, then calculates the unused portion of
the monthly cost. Returns 0 for FREE/ENTERPRISE users or when no active sub
is found.
Both ``active`` and ``trialing`` subscriptions are checked: a trialing user still
accumulates a billing period and Stripe prorates the remaining trial value on upgrade.
"""
if monthly_cost_cents <= 0:
return 0
# Guard: only query Stripe if the user already has a customer ID. Admin-granted
# paid tiers have no Stripe record; calling get_stripe_customer_id would create an
# orphaned customer on every billing-page load for those users.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return 0
try:
customer_id = user.stripe_customer_id
for status in ("active", "trialing"):
subscriptions = await run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status=status,
limit=1,
)
if not subscriptions.data:
continue
sub = subscriptions.data[0]
period_start: int = sub["current_period_start"]
period_end: int = sub["current_period_end"]
now = int(time.time())
total_seconds = period_end - period_start
remaining_seconds = max(period_end - now, 0)
if total_seconds <= 0:
return 0
return int(monthly_cost_cents * remaining_seconds / total_seconds)
return 0
except Exception:
logger.warning(
"get_proration_credit_cents: failed to compute proration for user %s",
user_id,
)
return 0
async def modify_stripe_subscription_for_tier(
user_id: str, tier: SubscriptionTier
) -> bool:
"""Modify an existing Stripe subscription to a new paid tier using proration.
For paid→paid tier changes (e.g. PRO↔BUSINESS), modifying the existing
subscription is preferable to cancelling + creating a new one via Checkout:
Stripe handles proration automatically, crediting unused time on the old plan
and charging the pro-rated amount for the new plan in the same billing cycle.
Returns:
True — a subscription was found and modified successfully.
False — no active/trialing subscription exists (e.g. admin-granted tier or
first-time paid signup); caller should fall back to Checkout.
Raises stripe.StripeError on API failures so callers can propagate a 502.
Raises ValueError when no Stripe price ID is configured for the tier.
"""
price_id = await get_subscription_price_id(tier)
if not price_id:
raise ValueError(f"No Stripe price ID configured for tier {tier}")
# Guard: only proceed if the user already has a Stripe customer ID. Calling
# get_stripe_customer_id for a user with no Stripe record (e.g. admin-granted tier)
# would create an orphaned customer object if the subsequent Subscription.list call
# fails. Return False early so the API layer falls back to Checkout instead.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return False
customer_id = user.stripe_customer_id
for status in ("active", "trialing"):
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status=status, limit=1
)
if not subscriptions.data:
continue
sub = subscriptions.data[0]
sub_id = sub["id"]
items = sub.get("items", {}).get("data", [])
if not items:
continue
item_id = items[0]["id"]
await run_in_threadpool(
stripe.Subscription.modify,
sub_id,
items=[{"id": item_id, "price": price_id}],
proration_behavior="create_prorations",
)
logger.info(
"modify_stripe_subscription_for_tier: modified sub %s for user %s%s",
sub_id,
user_id,
tier,
)
return True
return False
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
@@ -1291,8 +1491,19 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
return AutoTopUpConfig.model_validate(user.top_up_config)
@cached(ttl_seconds=60, maxsize=8, cache_none=False)
async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
"""Return Stripe Price ID for a tier from LaunchDarkly. None = not configured."""
"""Return Stripe Price ID for a tier from LaunchDarkly, cached for 60 seconds.
Price IDs are LaunchDarkly flag values that change only at deploy time.
Caching for 60 seconds avoids hitting the LD SDK on every webhook delivery
and every GET /credits/subscription page load (called 2x per request).
``cache_none=False`` prevents a transient LD failure from caching ``None``
and blocking subscription upgrades for the full 60-second TTL window.
A tier with no configured flag (FREE, ENTERPRISE) returns ``None`` from an
O(1) dict lookup before hitting LD, so the extra LD call is never made.
"""
flag_map = {
SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO,
SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS,
@@ -1300,7 +1511,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
flag = flag_map.get(tier)
if flag is None:
return None
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
return price_id if isinstance(price_id, str) and price_id else None
@@ -1315,7 +1526,8 @@ async def create_subscription_checkout(
if not price_id:
raise ValueError(f"Subscription not available for tier {tier.value}")
customer_id = await get_stripe_customer_id(user_id)
session = stripe.checkout.Session.create(
session = await run_in_threadpool(
stripe.checkout.Session.create,
customer=customer_id,
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
@@ -1323,26 +1535,110 @@ async def create_subscription_checkout(
cancel_url=cancel_url,
subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
)
return session.url or ""
if not session.url:
# An empty checkout URL for a paid upgrade is always an error; surfacing it
# as ValueError means the API handler returns 422 instead of silently
# redirecting the client to an empty URL.
raise ValueError("Stripe did not return a checkout session URL")
return session.url
async def _cleanup_stale_subscriptions(customer_id: str, new_sub_id: str) -> None:
"""Best-effort cancel of any active subs for the customer other than new_sub_id.
Called from the webhook handler after a new subscription becomes active. Failures
are logged but not raised so a transient Stripe error doesn't crash the webhook —
a periodic reconciliation job is the intended backstop for persistent drift.
NOTE: until that reconcile job lands, a failure here means the user is silently
billed for two simultaneous subscriptions. The error log below is intentionally
`logger.exception` (not `logger.warning`) so it surfaces as an error in Sentry
with the customer/sub IDs needed for manual reconciliation.
TODO(#stripe-reconcile-job): replace this best-effort cleanup with a periodic
reconciliation job that queries Stripe for customers with >1 active sub.
"""
try:
await _cancel_customer_subscriptions(customer_id, exclude_sub_id=new_sub_id)
except stripe.StripeError:
# Use exception() (not warning) so this surfaces as an error in Sentry —
# any failure here means a paid-to-paid upgrade may have left the user
# with two simultaneous active subscriptions.
logger.exception(
"stripe_stale_subscription_cleanup_failed: customer=%s new_sub=%s"
" user may be billed for two simultaneous subscriptions; manual"
" reconciliation required",
customer_id,
new_sub_id,
)
async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
"""Update User.subscriptionTier from a Stripe subscription object."""
customer_id = stripe_subscription["customer"]
"""Update User.subscriptionTier from a Stripe subscription object.
Expected shape of stripe_subscription (subset of Stripe's Subscription object):
customer: str — Stripe customer ID
status: str — "active" | "trialing" | "canceled" | ...
id: str — Stripe subscription ID
items.data[].price.id: str — Stripe price ID identifying the tier
"""
customer_id = stripe_subscription.get("customer")
if not customer_id:
logger.warning(
"sync_subscription_from_stripe: missing 'customer' field in event, "
"skipping (keys: %s)",
list(stripe_subscription.keys()),
)
return
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"sync_subscription_from_stripe: no user for customer %s", customer_id
)
return
# Cross-check: if the subscription carries a metadata.user_id (set during
# Checkout Session creation), verify it matches the user we found via
# stripeCustomerId. A mismatch indicates a customer↔user mapping
# inconsistency — updating the wrong user's tier would be a data-corruption
# bug, so we log loudly and bail out. Absence of metadata.user_id (e.g.
# subscriptions created outside the Checkout flow) is not an error — we
# simply skip the check and proceed with the customer-ID-based lookup.
metadata = stripe_subscription.get("metadata") or {}
metadata_user_id = metadata.get("user_id") if isinstance(metadata, dict) else None
if metadata_user_id and metadata_user_id != user.id:
logger.error(
"sync_subscription_from_stripe: metadata.user_id=%s does not match"
" user.id=%s found via stripeCustomerId=%s — refusing to update tier"
" to avoid corrupting the wrong user's subscription state",
metadata_user_id,
user.id,
customer_id,
)
return
# ENTERPRISE tiers are admin-managed. Never let a Stripe webhook flip an
# ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has
# a self-service Stripe sub, it's a data-consistency issue for an operator,
# not something the webhook should automatically "fix".
current_tier = user.subscriptionTier or SubscriptionTier.FREE
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier"
" for user %s (customer %s); event status=%s",
user.id,
customer_id,
stripe_subscription.get("status", ""),
)
return
status = stripe_subscription.get("status", "")
new_sub_id = stripe_subscription.get("id", "")
if status in ("active", "trialing"):
price_id = ""
items = stripe_subscription.get("items", {}).get("data", [])
if items:
price_id = items[0].get("price", {}).get("id", "")
pro_price = await get_subscription_price_id(SubscriptionTier.PRO)
biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS)
pro_price, biz_price = await asyncio.gather(
get_subscription_price_id(SubscriptionTier.PRO),
get_subscription_price_id(SubscriptionTier.BUSINESS),
)
if price_id and pro_price and price_id == pro_price:
tier = SubscriptionTier.PRO
elif price_id and biz_price and price_id == biz_price:
@@ -1359,10 +1655,219 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
)
return
else:
# A subscription was cancelled or ended. DO NOT unconditionally downgrade
# to FREE — Stripe does not guarantee webhook delivery order, so a
# `customer.subscription.deleted` for the OLD sub can arrive after we've
# already processed `customer.subscription.created` for a new paid sub.
# Ask Stripe whether any OTHER active/trialing subs exist for this
# customer; if they do, keep the user's current tier (the other sub's
# own event will/has already set the correct tier).
try:
other_subs_active, other_subs_trialing = await asyncio.gather(
run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status="active",
limit=10,
),
run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status="trialing",
limit=10,
),
)
except stripe.StripeError:
logger.warning(
"sync_subscription_from_stripe: could not verify other active"
" subs for customer %s on cancel event %s; preserving current"
" tier to avoid an unsafe downgrade",
customer_id,
new_sub_id,
)
return
# Filter out the cancelled subscription to check if other active subs
# exist. When new_sub_id is empty (malformed event with no 'id' field),
# we cannot safely exclude any sub — preserve current tier to avoid
# an unsafe downgrade on a malformed webhook payload.
if not new_sub_id:
logger.warning(
"sync_subscription_from_stripe: cancel event missing 'id' field"
" for customer %s; preserving current tier",
customer_id,
)
return
other_active_ids = {sub["id"] for sub in other_subs_active.data} - {new_sub_id}
other_trialing_ids = {sub["id"] for sub in other_subs_trialing.data} - {
new_sub_id
}
still_has_active_sub = bool(other_active_ids or other_trialing_ids)
if still_has_active_sub:
logger.info(
"sync_subscription_from_stripe: sub %s cancelled but customer %s"
" still has another active sub; keeping tier %s",
new_sub_id,
customer_id,
current_tier.value,
)
return
tier = SubscriptionTier.FREE
# Idempotency: Stripe retries webhooks on delivery failure, and several event
# types map to the same final tier. Skip the DB write + cache invalidation
# when the tier is already correct to avoid redundant writes on replay.
if current_tier == tier:
return
# When a new subscription becomes active (e.g. paid-to-paid tier upgrade
# via a fresh Checkout Session), cancel any OTHER active subscriptions for
# the same customer so the user isn't billed twice. We do this in the
# webhook rather than the API handler so that abandoning the checkout
# doesn't leave the user without a subscription.
# IMPORTANT: this runs AFTER the idempotency check above so that webhook
# replays for an already-applied event do NOT trigger another cleanup round
# (which could otherwise cancel a legitimately new subscription the user
# signed up for between the original event and its replay).
if status in ("active", "trialing") and new_sub_id:
# NOTE: paid-to-paid upgrade race (e.g. PRO → BUSINESS):
# _cleanup_stale_subscriptions cancels the old PRO sub before
# set_subscription_tier writes BUSINESS to the DB. If Stripe delivers
# the PRO `customer.subscription.deleted` event concurrently and it
# processes after the PRO cancel but before set_subscription_tier
# commits, the user could momentarily appear as FREE in the DB.
# This window is very short in practice (two sequential awaits),
# but is a known limitation of the current webhook-driven approach.
# A future improvement would be to write the new tier first, then
# cancel the old sub.
await _cleanup_stale_subscriptions(customer_id, new_sub_id)
await set_subscription_tier(user.id, tier)
async def handle_subscription_payment_failure(invoice: dict) -> None:
"""Handle a failed Stripe subscription payment.
Tries to cover the invoice amount from the user's credit balance.
- Balance sufficient → deduct from balance, then pay the Stripe invoice so
Stripe stops retrying it. The sub stays intact and the user keeps their tier.
- Balance insufficient → cancel Stripe sub immediately, downgrade to FREE.
Cancelling here avoids further Stripe retries on an invoice we cannot cover.
"""
customer_id = invoice.get("customer")
if not customer_id:
logger.warning(
"handle_subscription_payment_failure: missing customer in invoice; skipping"
)
return
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"handle_subscription_payment_failure: no user found for customer %s",
customer_id,
)
return
current_tier = user.subscriptionTier or SubscriptionTier.FREE
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"handle_subscription_payment_failure: skipping ENTERPRISE user %s"
" (customer %s) — tier is admin-managed",
user.id,
customer_id,
)
return
amount_due: int = invoice.get("amount_due", 0)
sub_id: str = invoice.get("subscription", "")
invoice_id: str = invoice.get("id", "")
if not invoice_id:
# Without an invoice ID we cannot set an idempotency key on the credit
# deduction. Stripe webhook retries would then double-charge the user's
# balance on every retry cycle. Bail out early — a real Stripe invoice
# always carries an ID, so a missing one indicates a malformed payload.
logger.warning(
"handle_subscription_payment_failure: invoice missing 'id' for"
" customer %s; skipping to avoid non-idempotent balance deduction",
customer_id,
)
return
if amount_due <= 0:
logger.info(
"handle_subscription_payment_failure: amount_due=%d for user %s;"
" nothing to deduct",
amount_due,
user.id,
)
return
credit_model = UserCredit()
try:
await credit_model._add_transaction(
user_id=user.id,
amount=-amount_due,
transaction_type=CreditTransactionType.SUBSCRIPTION,
fail_insufficient_credits=True,
# Use invoice_id as the idempotency key so that Stripe webhook retries
# (e.g. on a transient stripe.Invoice.pay failure) do not double-charge.
# invoice_id is guaranteed non-empty here (early-return guard above).
transaction_key=invoice_id,
metadata=SafeJson(
{
"stripe_customer_id": customer_id,
"stripe_subscription_id": sub_id,
"reason": "subscription_payment_failure_covered_by_balance",
}
),
)
# Balance covered the invoice. Pay the Stripe invoice so Stripe's dunning
# system stops retrying it — without this call Stripe would retry automatically
# and re-trigger this webhook, causing double-deductions each retry cycle.
# invoice_id is guaranteed non-empty here (early-return guard above).
try:
await run_in_threadpool(stripe.Invoice.pay, invoice_id)
except stripe.StripeError:
logger.warning(
"handle_subscription_payment_failure: balance deducted for user"
" %s but failed to mark invoice %s as paid; Stripe may retry",
user.id,
invoice_id,
)
logger.info(
"handle_subscription_payment_failure: deducted %d cents from balance"
" for user %s; Stripe invoice %s paid, sub %s intact, tier preserved",
amount_due,
user.id,
invoice_id,
sub_id,
)
except InsufficientBalanceError:
# Balance insufficient — cancel Stripe subscription first, then downgrade DB.
# Order matters: if we downgrade the DB first and the Stripe cancel fails, the
# user is permanently stuck on FREE while Stripe continues billing them.
# Cancelling Stripe first is safe: if the DB write then fails, the webhook
# customer.subscription.deleted will fire and correct the tier eventually.
logger.info(
"handle_subscription_payment_failure: insufficient balance for user %s;"
" cancelling Stripe sub %s then downgrading to FREE",
user.id,
sub_id,
)
try:
await _cancel_customer_subscriptions(customer_id)
except stripe.StripeError:
logger.warning(
"handle_subscription_payment_failure: failed to cancel Stripe sub %s"
" for user %s (customer %s); skipping tier downgrade to avoid"
" inconsistency — Stripe may continue retrying the invoice",
sub_id,
user.id,
customer_id,
)
return
await set_subscription_tier(user.id, SubscriptionTier.FREE)
async def admin_get_user_history(
page: int = 1,
page_size: int = 20,

View File

@@ -1,6 +1,5 @@
import asyncio
import logging
import socket
from abc import ABC, abstractmethod
from enum import Enum
from typing import Awaitable, Optional
@@ -43,39 +42,6 @@ CONNECTION_ATTEMPTS = 5
# Use case: Faster reconnection for long-running executions that need to resume quickly
RETRY_DELAY = 1
# DEFAULT_HEARTBEAT (300s = 5 min)
# AMQP application-level heartbeat. Server drops the connection if no heartbeat
# is seen within ~2x this interval. Consumers that sit in CLOSE_WAIT because
# pika's IO loop was starved (e.g. laptop sleep, blocking main thread) recover
# faster with a lower value. See `create_copilot_queue_config` for a case that
# overrides this.
DEFAULT_HEARTBEAT = 300
def _tcp_keepalive_options() -> dict[str, int]:
"""Platform-aware TCP keepalive socket options for pika.
pika enables ``SO_KEEPALIVE`` on every socket by default; this dict tunes
how quickly the kernel declares a silent peer dead. Without these knobs,
the OS default on Linux is ~2 hours of idle before the first probe — long
enough for a half-closed socket to sit in CLOSE_WAIT forever while the
consumer thread is blocked inside ``start_consuming()``.
pika passes each key through ``getattr(socket, key)`` at ``IPPROTO_TCP``
level, so names must exist on the current platform. Linux has
``TCP_KEEPIDLE``; macOS uses ``TCP_KEEPALIVE`` for the equivalent knob.
"""
opts: dict[str, int] = {}
if hasattr(socket, "TCP_KEEPIDLE"):
opts["TCP_KEEPIDLE"] = 60
elif hasattr(socket, "TCP_KEEPALIVE"):
opts["TCP_KEEPALIVE"] = 60
if hasattr(socket, "TCP_KEEPINTVL"):
opts["TCP_KEEPINTVL"] = 20
if hasattr(socket, "TCP_KEEPCNT"):
opts["TCP_KEEPCNT"] = 3
return opts
class ExchangeType(str, Enum):
DIRECT = "direct"
@@ -107,8 +73,6 @@ class RabbitMQConfig(BaseModel):
vhost: str = "/"
exchanges: list[Exchange]
queues: list[Queue]
heartbeat: int = DEFAULT_HEARTBEAT
tcp_keepalive: bool = False
class RabbitMQBase(ABC):
@@ -177,8 +141,7 @@ class SyncRabbitMQ(RabbitMQBase):
socket_timeout=SOCKET_TIMEOUT,
connection_attempts=CONNECTION_ATTEMPTS,
retry_delay=RETRY_DELAY,
heartbeat=self.config.heartbeat,
tcp_options=_tcp_keepalive_options() if self.config.tcp_keepalive else None,
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
)
self._connection = pika.BlockingConnection(parameters)
@@ -297,7 +260,7 @@ class AsyncRabbitMQ(RabbitMQBase):
password=self.password,
virtualhost=self.config.vhost.lstrip("/"),
blocked_connection_timeout=BLOCKED_CONNECTION_TIMEOUT,
heartbeat=self.config.heartbeat,
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
)
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)

View File

@@ -73,6 +73,31 @@ def _get_redis() -> Redis:
return r
class _MissingType:
"""Singleton sentinel type — distinct from ``None`` (a valid cached value).
Using a dedicated class (instead of ``Any = object()``) lets mypy prove
that comparisons ``result is _MISSING`` narrow the type correctly and
prevents accidental use of the sentinel where a real value is expected.
"""
_instance: "_MissingType | None" = None
def __new__(cls) -> "_MissingType":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __repr__(self) -> str:
return "<MISSING>"
# Sentinel returned by ``_get_from_memory`` / ``_get_from_redis`` to mean
# "no entry exists" — distinct from a cached ``None`` value, which is a
# valid result for callers that opt into caching it.
_MISSING = _MissingType()
@dataclass
class CachedValue:
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
@@ -160,6 +185,7 @@ def cached(
ttl_seconds: int,
shared_cache: bool = False,
refresh_ttl_on_get: bool = False,
cache_none: bool = True,
) -> Callable[[Callable[P, R]], CachedFunction[P, R]]:
"""
Thundering herd safe cache decorator for both sync and async functions.
@@ -172,6 +198,10 @@ def cached(
ttl_seconds: Time to live in seconds. Required - entries must expire.
shared_cache: If True, use Redis for cross-process caching
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
cache_none: If True (default) ``None`` is cached like any other value.
Set to ``False`` for functions that return ``None`` to signal a
transient error and should be re-tried on the next call without
poisoning the cache (e.g. external API calls that may fail).
Returns:
Decorated function with caching capabilities
@@ -184,6 +214,12 @@ def cached(
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
@cached(ttl_seconds=300, cache_none=False)
async def fetch_external(id: str) -> dict | None:
# Returns None on transient error — won't be stored,
# next call retries instead of returning the stale None.
...
"""
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
@@ -191,9 +227,14 @@ def cached(
cache_storage: dict[tuple, CachedValue] = {}
_event_loop_locks: dict[Any, asyncio.Lock] = {}
def _get_from_redis(redis_key: str) -> Any | None:
def _get_from_redis(redis_key: str) -> Any:
"""Get value from Redis, optionally refreshing TTL.
Returns the cached value (which may be ``None``) on a hit, or the
module-level ``_MISSING`` sentinel on a miss / corrupt entry.
Callers must compare with ``is _MISSING`` so cached ``None`` values
are not mistaken for misses.
Values are expected to carry an HMAC-SHA256 prefix for integrity
verification. Unsigned (legacy) or tampered entries are silently
discarded and treated as cache misses, so the caller recomputes and
@@ -213,11 +254,11 @@ def cached(
f"for {func_name}, discarding entry: "
"possible tampering or legacy unsigned value"
)
return None
return _MISSING
return pickle.loads(payload)
except Exception as e:
logger.error(f"Redis error during cache check for {func_name}: {e}")
return None
return _MISSING
def _set_to_redis(redis_key: str, value: Any) -> None:
"""Set HMAC-signed pickled value in Redis with TTL."""
@@ -227,8 +268,13 @@ def cached(
except Exception as e:
logger.error(f"Redis error storing cache for {func_name}: {e}")
def _get_from_memory(key: tuple) -> Any | None:
"""Get value from in-memory cache, checking TTL."""
def _get_from_memory(key: tuple) -> Any:
"""Get value from in-memory cache, checking TTL.
Returns the cached value (which may be ``None``) on a hit, or the
``_MISSING`` sentinel on a miss / TTL expiry. See
``_get_from_redis`` for the rationale.
"""
if key in cache_storage:
cached_data = cache_storage[key]
if time.time() - cached_data.timestamp < ttl_seconds:
@@ -236,7 +282,7 @@ def cached(
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
)
return cached_data.result
return None
return _MISSING
def _set_to_memory(key: tuple, value: Any) -> None:
"""Set value in in-memory cache with timestamp."""
@@ -270,11 +316,11 @@ def cached(
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Slow path: acquire lock for cache miss/expiry
@@ -282,22 +328,24 @@ def cached(
# Double-check: another coroutine might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {func_name}")
result = await target_func(*args, **kwargs)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
# Store result (skip ``None`` if the caller opted out of
# caching it — used for transient-error sentinels).
if cache_none or result is not None:
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result
@@ -315,11 +363,11 @@ def cached(
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Slow path: acquire lock for cache miss/expiry
@@ -327,22 +375,24 @@ def cached(
# Double-check: another thread might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {func_name}")
result = target_func(*args, **kwargs)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
# Store result (skip ``None`` if the caller opted out of
# caching it — used for transient-error sentinels).
if cache_none or result is not None:
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result

View File

@@ -1223,3 +1223,123 @@ class TestCacheHMAC:
assert call_count == 2
legacy_test_fn.cache_clear()
class TestCacheNoneHandling:
"""Tests for the ``cache_none`` parameter on the @cached decorator.
Sentry bug PRRT_kwDOJKSTjM56RTEu (HIGH): the cache previously could not
distinguish "no entry" from "entry is None", so any function returning
``None`` was effectively re-executed on every call. The fix is a
sentinel-based check inside the wrappers, plus an opt-out
``cache_none=False`` flag for callers that *want* errors to retry.
"""
@pytest.mark.asyncio
async def test_async_none_is_cached_by_default(self):
"""With ``cache_none=True`` (default), cached ``None`` is returned
from the cache instead of triggering re-execution."""
call_count = 0
@cached(ttl_seconds=300)
async def maybe_none(x: int) -> int | None:
nonlocal call_count
call_count += 1
return None
assert await maybe_none(1) is None
assert call_count == 1
# Second call should hit the cache, not re-execute.
assert await maybe_none(1) is None
assert call_count == 1
# Different argument is a different cache key — re-executes.
assert await maybe_none(2) is None
assert call_count == 2
def test_sync_none_is_cached_by_default(self):
call_count = 0
@cached(ttl_seconds=300)
def maybe_none(x: int) -> int | None:
nonlocal call_count
call_count += 1
return None
assert maybe_none(1) is None
assert maybe_none(1) is None
assert call_count == 1
@pytest.mark.asyncio
async def test_async_cache_none_false_skips_storing_none(self):
"""``cache_none=False`` skips storing ``None`` so transient errors
are retried on the next call instead of poisoning the cache."""
call_count = 0
results: list[int | None] = [None, None, 42]
@cached(ttl_seconds=300, cache_none=False)
async def maybe_none(x: int) -> int | None:
nonlocal call_count
result = results[call_count]
call_count += 1
return result
# First call: returns None, NOT stored.
assert await maybe_none(1) is None
assert call_count == 1
# Second call with same key: re-executes (None wasn't cached).
assert await maybe_none(1) is None
assert call_count == 2
# Third call: returns 42, this time it IS stored.
assert await maybe_none(1) == 42
assert call_count == 3
# Fourth call: cache hit on the stored 42.
assert await maybe_none(1) == 42
assert call_count == 3
def test_sync_cache_none_false_skips_storing_none(self):
call_count = 0
results: list[int | None] = [None, 99]
@cached(ttl_seconds=300, cache_none=False)
def maybe_none(x: int) -> int | None:
nonlocal call_count
result = results[call_count]
call_count += 1
return result
assert maybe_none(1) is None
assert call_count == 1
# None was not stored — re-executes.
assert maybe_none(1) == 99
assert call_count == 2
# 99 IS stored — no re-execution.
assert maybe_none(1) == 99
assert call_count == 2
@pytest.mark.asyncio
async def test_async_shared_cache_none_is_cached_by_default(self):
"""Shared (Redis) cache also properly returns cached ``None`` values."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
async def maybe_none_redis(x: int) -> int | None:
nonlocal call_count
call_count += 1
return None
maybe_none_redis.cache_clear()
assert await maybe_none_redis(1) is None
assert call_count == 1
assert await maybe_none_redis(1) is None
assert call_count == 1
maybe_none_redis.cache_clear()

View File

@@ -1,6 +1,7 @@
import contextlib
import logging
import os
import uuid
from enum import Enum
from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar
@@ -101,6 +102,12 @@ async def _fetch_user_context_data(user_id: str) -> Context:
"""
builder = Context.builder(user_id).kind("user").anonymous(True)
try:
uuid.UUID(user_id)
except ValueError:
# Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context.
return builder.build()
try:
from backend.util.clients import get_supabase

View File

@@ -110,7 +110,7 @@ export const Flow = () => {
event.preventDefault();
}}
maxZoom={2}
minZoom={0.1}
minZoom={0.05}
onDragOver={onDragOver}
onDrop={onDrop}
nodesDraggable={!isLocked}

View File

@@ -1,6 +1,8 @@
"use client";
import { useState } from "react";
import { Button } from "@/components/ui/button";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { useSubscriptionTierSection } from "./useSubscriptionTierSection";
type TierInfo = {
@@ -15,39 +17,70 @@ const TIERS: TierInfo[] = [
key: "FREE",
label: "Free",
multiplier: "1x",
description: "Base rate limits",
description: "Base AutoPilot capacity with standard rate limits",
},
{
key: "PRO",
label: "Pro",
multiplier: "5x",
description: "5x more AutoPilot capacity",
description: "5x AutoPilot capacity — run 5× more tasks per day/week",
},
{
key: "BUSINESS",
label: "Business",
multiplier: "20x",
description: "20x more AutoPilot capacity",
description: "20x AutoPilot capacity — ideal for teams and heavy workloads",
},
];
function formatCost(cents: number): string {
if (cents === 0) return "Free";
const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
function formatCost(cents: number, tierKey: string): string {
if (tierKey === "FREE") return "Free";
if (cents === 0) return "Pricing available soon";
return `$${(cents / 100).toFixed(2)}/mo`;
}
export function SubscriptionTierSection() {
const { subscription, isLoading, error, isPending, changeTier } =
useSubscriptionTierSection();
const [tierError, setTierError] = useState<string | null>(null);
const {
subscription,
isLoading,
error,
tierError,
isPending,
pendingTier,
pendingUpgradeTier,
setPendingUpgradeTier,
confirmUpgrade,
isPaymentEnabled,
changeTier,
handleTierChange,
} = useSubscriptionTierSection();
const [confirmDowngradeTo, setConfirmDowngradeTo] = useState<string | null>(
null,
);
if (isLoading) return null;
if (isLoading) {
return (
<div className="space-y-4">
<Skeleton className="h-6 w-48" />
<div className="grid grid-cols-1 gap-3 sm:grid-cols-3">
<Skeleton className="h-40 rounded-lg" />
<Skeleton className="h-40 rounded-lg" />
<Skeleton className="h-40 rounded-lg" />
</div>
</div>
);
}
if (error) {
return (
<div className="space-y-4">
<h3 className="text-lg font-medium">Subscription Plan</h3>
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
<p
role="alert"
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
>
{error}
</p>
</div>
@@ -56,10 +89,30 @@ export function SubscriptionTierSection() {
if (!subscription) return null;
async function handleTierChange(tierKey: string) {
setTierError(null);
const err = await changeTier(tierKey);
if (err) setTierError(err);
const currentTier = subscription.tier;
if (currentTier === "ENTERPRISE") {
return (
<div className="space-y-4">
<h3 className="text-lg font-medium">Subscription Plan</h3>
<div className="rounded-lg border border-violet-500 bg-violet-50 p-4 dark:bg-violet-900/20">
<p className="font-semibold text-violet-700 dark:text-violet-200">
Enterprise Plan
</p>
<p className="mt-1 text-sm text-neutral-600 dark:text-neutral-400">
Your Enterprise plan is managed by your administrator. Contact your
account team for changes.
</p>
</div>
</div>
);
}
async function confirmDowngrade() {
if (!confirmDowngradeTo) return;
const tier = confirmDowngradeTo;
setConfirmDowngradeTo(null);
await changeTier(tier);
}
return (
@@ -67,24 +120,28 @@ export function SubscriptionTierSection() {
<h3 className="text-lg font-medium">Subscription Plan</h3>
{tierError && (
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
<p
role="alert"
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
>
{tierError}
</p>
)}
<div className="grid grid-cols-1 gap-3 sm:grid-cols-3">
{TIERS.map((tier) => {
const isCurrent = subscription.tier === tier.key;
const isCurrent = currentTier === tier.key;
const cost = subscription.tier_costs[tier.key] ?? 0;
const currentTierOrder = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
const currentIdx = currentTierOrder.indexOf(subscription.tier);
const targetIdx = currentTierOrder.indexOf(tier.key);
const currentIdx = TIER_ORDER.indexOf(currentTier);
const targetIdx = TIER_ORDER.indexOf(tier.key);
const isUpgrade = targetIdx > currentIdx;
const isDowngrade = targetIdx < currentIdx;
const isThisPending = pendingTier === tier.key;
return (
<div
key={tier.key}
aria-current={isCurrent ? "true" : undefined}
className={`rounded-lg border p-4 ${
isCurrent
? "border-violet-500 bg-violet-50 dark:bg-violet-900/20"
@@ -100,7 +157,9 @@ export function SubscriptionTierSection() {
)}
</div>
<p className="mb-1 text-2xl font-bold">{formatCost(cost)}</p>
<p className="mb-1 text-2xl font-bold">
{formatCost(cost, tier.key)}
</p>
<p className="mb-1 text-sm font-medium text-neutral-600 dark:text-neutral-400">
{tier.multiplier} rate limits
</p>
@@ -108,14 +167,20 @@ export function SubscriptionTierSection() {
{tier.description}
</p>
{!isCurrent && (
{!isCurrent && isPaymentEnabled && (
<Button
className="w-full"
variant={isUpgrade ? "default" : "outline"}
disabled={isPending}
onClick={() => handleTierChange(tier.key)}
onClick={() =>
handleTierChange(
tier.key,
currentTier,
setConfirmDowngradeTo,
)
}
>
{isPending
{isThisPending
? "Updating..."
: isUpgrade
? `Upgrade to ${tier.label}`
@@ -129,12 +194,80 @@ export function SubscriptionTierSection() {
})}
</div>
{subscription.tier !== "FREE" && (
{currentTier !== "FREE" && isPaymentEnabled && (
<p className="text-sm text-neutral-500">
Your subscription is managed through Stripe. Changes take effect
immediately.
Your subscription is managed through Stripe. Upgrades and paid-tier
changes take effect immediately; downgrades to Free are scheduled for
the end of the current billing period.
</p>
)}
<Dialog
title="Confirm Downgrade"
controlled={{
isOpen: !!confirmDowngradeTo,
set: (open) => {
if (!open) setConfirmDowngradeTo(null);
},
}}
>
<Dialog.Content>
<p className="text-sm text-neutral-600 dark:text-neutral-400">
{confirmDowngradeTo === "FREE"
? "Downgrading to Free will schedule your subscription to cancel at the end of your current billing period. You keep your current plan until then."
: `Switching to ${TIERS.find((t) => t.key === confirmDowngradeTo)?.label ?? confirmDowngradeTo} will take effect immediately.`}{" "}
Are you sure?
</p>
<Dialog.Footer>
<Button
variant="outline"
onClick={() => setConfirmDowngradeTo(null)}
>
Cancel
</Button>
<Button
variant="destructive"
onClick={() => void confirmDowngrade()}
>
Confirm Downgrade
</Button>
</Dialog.Footer>
</Dialog.Content>
</Dialog>
<Dialog
title="Confirm Upgrade"
controlled={{
isOpen: !!pendingUpgradeTier,
set: (open) => {
if (!open) setPendingUpgradeTier(null);
},
}}
>
<Dialog.Content>
<p className="text-sm text-neutral-600 dark:text-neutral-400">
{subscription &&
subscription.proration_credit_cents > 0 &&
`Your unused ${currentTier.charAt(0) + currentTier.slice(1).toLowerCase()} subscription ($${(subscription.proration_credit_cents / 100).toFixed(2)}) will be applied as a credit to your next Stripe invoice. `}
{currentTier === "FREE"
? `You will be redirected to Stripe to complete your upgrade to ${TIERS.find((t) => t.key === pendingUpgradeTier)?.label ?? pendingUpgradeTier}.`
: `Upgrading to ${TIERS.find((t) => t.key === pendingUpgradeTier)?.label ?? pendingUpgradeTier} will take effect immediately — Stripe will prorate your remaining balance.`}
</p>
<Dialog.Footer>
<Button
variant="outline"
onClick={() => setPendingUpgradeTier(null)}
>
Cancel
</Button>
<Button onClick={() => void confirmUpgrade()}>
{currentTier === "FREE"
? "Continue to Checkout"
: "Confirm Upgrade"}
</Button>
</Dialog.Footer>
</Dialog.Content>
</Dialog>
</div>
);
}

View File

@@ -0,0 +1,379 @@
import {
render,
screen,
fireEvent,
waitFor,
cleanup,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { SubscriptionTierSection } from "../SubscriptionTierSection";
// Mock next/navigation
const mockSearchParams = new URLSearchParams();
const mockRouterReplace = vi.fn();
vi.mock("next/navigation", async (importOriginal) => {
const actual = await importOriginal<typeof import("next/navigation")>();
return {
...actual,
useSearchParams: () => mockSearchParams,
useRouter: () => ({ push: vi.fn(), replace: mockRouterReplace }),
usePathname: () => "/profile/credits",
};
});
// Mock toast
const mockToast = vi.fn();
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: () => ({ toast: mockToast }),
}));
// Mock feature flags — default to payment enabled so button tests work
let mockPaymentEnabled = true;
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: { ENABLE_PLATFORM_PAYMENT: "enable-platform-payment" },
useGetFlag: () => mockPaymentEnabled,
}));
// Mock generated API hooks
const mockUseGetSubscriptionStatus = vi.fn();
const mockUseUpdateSubscriptionTier = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/credits/credits", () => ({
useGetSubscriptionStatus: (opts: unknown) =>
mockUseGetSubscriptionStatus(opts),
useUpdateSubscriptionTier: () => mockUseUpdateSubscriptionTier(),
}));
// Mock Dialog (Radix portals don't work in happy-dom)
const MockDialogContent = ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
);
const MockDialogFooter = ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
);
function MockDialog({
controlled,
children,
}: {
controlled?: { isOpen: boolean; set: (open: boolean) => void };
children: React.ReactNode;
[key: string]: unknown;
}) {
return controlled?.isOpen ? <div role="dialog">{children}</div> : null;
}
MockDialog.Content = MockDialogContent;
MockDialog.Footer = MockDialogFooter;
vi.mock("@/components/molecules/Dialog/Dialog", () => ({
Dialog: MockDialog,
}));
function makeSubscription({
tier = "FREE",
monthlyCost = 0,
tierCosts = { FREE: 0, PRO: 1999, BUSINESS: 4999, ENTERPRISE: 0 },
prorationCreditCents = 0,
}: {
tier?: string;
monthlyCost?: number;
tierCosts?: Record<string, number>;
prorationCreditCents?: number;
} = {}) {
return {
tier,
monthly_cost: monthlyCost,
tier_costs: tierCosts,
proration_credit_cents: prorationCreditCents,
};
}
function setupMocks({
subscription = makeSubscription(),
isLoading = false,
queryError = null as Error | null,
mutateFn = vi.fn().mockResolvedValue({ status: 200, data: { url: "" } }),
isPending = false,
variables = undefined as { data?: { tier?: string } } | undefined,
} = {}) {
// The hook uses select: (data) => (data.status === 200 ? data.data : null)
// so the data value returned by the hook is already the transformed subscription object.
// We simulate that by returning the subscription directly as data.
mockUseGetSubscriptionStatus.mockReturnValue({
data: subscription,
isLoading,
error: queryError,
refetch: vi.fn(),
});
mockUseUpdateSubscriptionTier.mockReturnValue({
mutateAsync: mutateFn,
isPending,
variables,
});
}
afterEach(() => {
cleanup();
mockUseGetSubscriptionStatus.mockReset();
mockUseUpdateSubscriptionTier.mockReset();
mockToast.mockReset();
mockRouterReplace.mockReset();
mockSearchParams.delete("subscription");
mockPaymentEnabled = true;
});
describe("SubscriptionTierSection", () => {
it("renders skeleton cards while loading", () => {
setupMocks({ isLoading: true });
render(<SubscriptionTierSection />);
// Just verify we're rendering something (not null) and no tier cards
expect(screen.queryByText("Pro")).toBeNull();
expect(screen.queryByText("Business")).toBeNull();
});
it("renders error message when subscription fetch fails", () => {
setupMocks({
queryError: new Error("Network error"),
subscription: makeSubscription(),
});
// Override the data to simulate failed state
mockUseGetSubscriptionStatus.mockReturnValue({
data: null,
isLoading: false,
error: new Error("Network error"),
refetch: vi.fn(),
});
render(<SubscriptionTierSection />);
expect(screen.getByRole("alert")).toBeDefined();
expect(screen.getByText(/failed to load subscription info/i)).toBeDefined();
});
it("renders all three tier cards for FREE user", () => {
setupMocks();
render(<SubscriptionTierSection />);
// Use getAllByText to account for the tier label AND cost display both containing "Free"
expect(screen.getAllByText("Free").length).toBeGreaterThan(0);
expect(screen.getByText("Pro")).toBeDefined();
expect(screen.getByText("Business")).toBeDefined();
});
it("shows Current badge on the active tier", () => {
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
render(<SubscriptionTierSection />);
expect(screen.getByText("Current")).toBeDefined();
// Upgrade to PRO button should NOT exist; Upgrade to BUSINESS and Downgrade to Free should
expect(
screen.queryByRole("button", { name: /upgrade to pro/i }),
).toBeNull();
expect(
screen.getByRole("button", { name: /upgrade to business/i }),
).toBeDefined();
expect(
screen.getByRole("button", { name: /downgrade to free/i }),
).toBeDefined();
});
it("displays tier costs from the API", () => {
setupMocks({
subscription: makeSubscription({
tier: "FREE",
tierCosts: { FREE: 0, PRO: 1999, BUSINESS: 4999, ENTERPRISE: 0 },
}),
});
render(<SubscriptionTierSection />);
expect(screen.getByText("$19.99/mo")).toBeDefined();
expect(screen.getByText("$49.99/mo")).toBeDefined();
// FREE tier label should still be visible (there may be multiple "Free" elements)
expect(screen.getAllByText("Free").length).toBeGreaterThan(0);
});
it("shows 'Pricing available soon' when tier cost is 0 for a paid tier", () => {
setupMocks({
subscription: makeSubscription({
tier: "FREE",
tierCosts: { FREE: 0, PRO: 0, BUSINESS: 0, ENTERPRISE: 0 },
}),
});
render(<SubscriptionTierSection />);
// PRO and BUSINESS with cost=0 should show "Pricing available soon"
expect(screen.getAllByText("Pricing available soon")).toHaveLength(2);
});
it("calls changeTier on upgrade click after confirmation dialog", async () => {
const mutateFn = vi
.fn()
.mockResolvedValue({ status: 200, data: { url: "" } });
setupMocks({ mutateFn });
render(<SubscriptionTierSection />);
// Clicking upgrade opens the confirmation dialog first
fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i }));
// Confirm via the dialog's "Continue to Checkout" button
fireEvent.click(
screen.getByRole("button", { name: /continue to checkout/i }),
);
await waitFor(() => {
expect(mutateFn).toHaveBeenCalledWith(
expect.objectContaining({
data: expect.objectContaining({ tier: "PRO" }),
}),
);
});
});
it("shows confirmation dialog on downgrade click", () => {
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
render(<SubscriptionTierSection />);
fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i }));
expect(screen.getByRole("dialog")).toBeDefined();
// The dialog title text appears in both a div and a button — just check the dialog is open
expect(screen.getAllByText(/confirm downgrade/i).length).toBeGreaterThan(0);
});
it("calls changeTier after downgrade confirmation", async () => {
const mutateFn = vi
.fn()
.mockResolvedValue({ status: 200, data: { url: "" } });
setupMocks({
subscription: makeSubscription({ tier: "PRO" }),
mutateFn,
});
render(<SubscriptionTierSection />);
fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i }));
fireEvent.click(screen.getByRole("button", { name: /confirm downgrade/i }));
await waitFor(() => {
expect(mutateFn).toHaveBeenCalledWith(
expect.objectContaining({
data: expect.objectContaining({ tier: "FREE" }),
}),
);
});
});
it("dismisses dialog when Cancel is clicked", () => {
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
render(<SubscriptionTierSection />);
fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i }));
expect(screen.getByRole("dialog")).toBeDefined();
fireEvent.click(screen.getByRole("button", { name: /^cancel$/i }));
expect(screen.queryByRole("dialog")).toBeNull();
});
it("redirects to Stripe when checkout URL is returned", async () => {
// Replace window.location with a plain object so assigning .href doesn't
// trigger jsdom navigation (which would throw or reload the test page).
const mockLocation = { href: "" };
vi.stubGlobal("location", mockLocation);
const mutateFn = vi.fn().mockResolvedValue({
status: 200,
data: { url: "https://checkout.stripe.com/pay/cs_test" },
});
setupMocks({ mutateFn });
render(<SubscriptionTierSection />);
// Upgrade opens confirmation dialog first — confirm via "Continue to Checkout"
fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i }));
fireEvent.click(
screen.getByRole("button", { name: /continue to checkout/i }),
);
await waitFor(() => {
expect(mockLocation.href).toBe("https://checkout.stripe.com/pay/cs_test");
});
vi.unstubAllGlobals();
});
it("shows an error alert when tier change fails", async () => {
const mutateFn = vi.fn().mockRejectedValue(new Error("Stripe unavailable"));
setupMocks({ mutateFn });
render(<SubscriptionTierSection />);
// Upgrade opens confirmation dialog first — confirm to trigger the mutation
fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i }));
fireEvent.click(
screen.getByRole("button", { name: /continue to checkout/i }),
);
await waitFor(() => {
expect(screen.getByRole("alert")).toBeDefined();
expect(screen.getByText(/stripe unavailable/i)).toBeDefined();
});
});
it("hides action buttons when payment flag is disabled", () => {
mockPaymentEnabled = false;
setupMocks({ subscription: makeSubscription({ tier: "FREE" }) });
render(<SubscriptionTierSection />);
// Tier cards still visible
expect(screen.getByText("Pro")).toBeDefined();
expect(screen.getByText("Business")).toBeDefined();
// No upgrade/downgrade buttons
expect(screen.queryByRole("button", { name: /upgrade/i })).toBeNull();
expect(screen.queryByRole("button", { name: /downgrade/i })).toBeNull();
});
it("shows ENTERPRISE message for ENTERPRISE tier users", () => {
setupMocks({ subscription: makeSubscription({ tier: "ENTERPRISE" }) });
render(<SubscriptionTierSection />);
// Enterprise heading text appears in a <p> (may match multiple), just verify it exists
expect(screen.getAllByText(/enterprise plan/i).length).toBeGreaterThan(0);
expect(screen.getByText(/managed by your administrator/i)).toBeDefined();
// No standard tier cards should be rendered
expect(screen.queryByText("Pro")).toBeNull();
expect(screen.queryByText("Business")).toBeNull();
});
it("shows success toast and clears URL param when ?subscription=success is present", async () => {
mockSearchParams.set("subscription", "success");
setupMocks();
render(<SubscriptionTierSection />);
await waitFor(() => {
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({ title: "Subscription upgraded" }),
);
});
// URL param must be stripped so a page refresh doesn't re-trigger the toast
expect(mockRouterReplace).toHaveBeenCalledWith("/profile/credits");
});
it("clears URL param but shows no toast when ?subscription=cancelled is present", async () => {
mockSearchParams.set("subscription", "cancelled");
setupMocks();
render(<SubscriptionTierSection />);
// The cancelled param must be stripped from the URL (same hygiene as success)
await waitFor(() => {
expect(mockRouterReplace).toHaveBeenCalledWith("/profile/credits");
});
// No toast should fire — the user simply abandoned checkout
expect(mockToast).not.toHaveBeenCalled();
});
it("shows 'Confirm Upgrade' button (not 'Continue to Checkout') for paid→paid tier change", () => {
// PRO user upgrading to BUSINESS — modify in-place, no Stripe redirect
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
render(<SubscriptionTierSection />);
fireEvent.click(
screen.getByRole("button", { name: /upgrade to business/i }),
);
// For paid→paid, the button should say "Confirm Upgrade" not "Continue to Checkout"
expect(
screen.getByRole("button", { name: /confirm upgrade/i }),
).toBeDefined();
expect(
screen.queryByRole("button", { name: /continue to checkout/i }),
).toBeNull();
// Dialog body should mention "take effect immediately" not "redirected to Stripe"
// Two elements match: the static info paragraph and the dialog body — verify at least 2
expect(screen.getAllByText(/take effect immediately/i).length).toBeGreaterThanOrEqual(2);
});
});

View File

@@ -1,13 +1,30 @@
import { useEffect, useState } from "react";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import {
useGetSubscriptionStatus,
useUpdateSubscriptionTier,
} from "@/app/api/__generated__/endpoints/credits/credits";
import type { SubscriptionStatusResponse } from "@/app/api/__generated__/models/subscriptionStatusResponse";
import type { SubscriptionTierRequestTier } from "@/app/api/__generated__/models/subscriptionTierRequestTier";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
export type SubscriptionStatus = SubscriptionStatusResponse;
const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
export function useSubscriptionTierSection() {
const isPaymentEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
const searchParams = useSearchParams();
const subscriptionStatus = searchParams.get("subscription");
const router = useRouter();
const pathname = usePathname();
const { toast } = useToast();
const [tierError, setTierError] = useState<string | null>(null);
const [pendingUpgradeTier, setPendingUpgradeTier] = useState<string | null>(
null,
);
const {
data: subscription,
isLoading,
@@ -17,11 +34,39 @@ export function useSubscriptionTierSection() {
query: { select: (data) => (data.status === 200 ? data.data : null) },
});
const error = queryError ? "Failed to load subscription info" : null;
const fetchError = queryError ? "Failed to load subscription info" : null;
const { mutateAsync: doUpdateTier, isPending } = useUpdateSubscriptionTier();
const {
mutateAsync: doUpdateTier,
isPending,
variables,
} = useUpdateSubscriptionTier();
async function changeTier(tier: string): Promise<string | null> {
useEffect(() => {
if (subscriptionStatus === "success") {
refetch();
toast({
title: "Subscription upgraded",
description:
"Your plan has been updated. It may take a moment to reflect.",
});
}
// Strip ?subscription=success|cancelled from the URL so a page refresh
// does not re-trigger side-effects, and so a second checkout in the same
// session correctly fires the toast again.
if (
subscriptionStatus === "success" ||
subscriptionStatus === "cancelled"
) {
router.replace(pathname);
}
// eslint-disable-next-line react-hooks/exhaustive-deps -- refetch and toast
// are new references each render but are stable in practice; the effect must
// only re-run when subscriptionStatus/pathname changes.
}, [subscriptionStatus, refetch, toast, router, pathname]);
async function changeTier(tier: string) {
setTierError(null);
try {
const successUrl = `${window.location.origin}${window.location.pathname}?subscription=success`;
const cancelUrl = `${window.location.origin}${window.location.pathname}?subscription=cancelled`;
@@ -34,22 +79,59 @@ export function useSubscriptionTierSection() {
});
if (result.status === 200 && result.data.url) {
window.location.href = result.data.url;
return null;
return;
}
await refetch();
return null;
toast({
title: "Subscription updated",
description:
tier === "FREE"
? "Your plan will be downgraded to Free at the end of your current billing period."
: "Your subscription has been updated.",
});
} catch (e: unknown) {
const msg =
e instanceof Error ? e.message : "Failed to change subscription tier";
return msg;
setTierError(msg);
}
}
function handleTierChange(
targetTierKey: string,
currentTier: string,
onConfirmDowngrade: (tier: string) => void,
) {
const currentIdx = TIER_ORDER.indexOf(currentTier);
const targetIdx = TIER_ORDER.indexOf(targetTierKey);
if (targetIdx < currentIdx) {
onConfirmDowngrade(targetTierKey);
return;
}
setPendingUpgradeTier(targetTierKey);
}
async function confirmUpgrade() {
if (!pendingUpgradeTier) return;
const tier = pendingUpgradeTier;
setPendingUpgradeTier(null);
await changeTier(tier);
}
const pendingTier =
isPending && variables?.data?.tier ? variables.data.tier : null;
return {
subscription: subscription ?? null,
isLoading,
error,
error: fetchError,
tierError,
isPending,
pendingTier,
pendingUpgradeTier,
setPendingUpgradeTier,
confirmUpgrade,
isPaymentEnabled,
changeTier,
handleTierChange,
};
}

View File

@@ -14122,16 +14122,29 @@
},
"SubscriptionStatusResponse": {
"properties": {
"tier": { "type": "string", "title": "Tier" },
"tier": {
"type": "string",
"enum": ["FREE", "PRO", "BUSINESS", "ENTERPRISE"],
"title": "Tier"
},
"monthly_cost": { "type": "integer", "title": "Monthly Cost" },
"tier_costs": {
"additionalProperties": { "type": "integer" },
"type": "object",
"title": "Tier Costs"
},
"proration_credit_cents": {
"type": "integer",
"title": "Proration Credit Cents"
}
},
"type": "object",
"required": ["tier", "monthly_cost", "tier_costs"],
"required": [
"tier",
"monthly_cost",
"tier_costs",
"proration_credit_cents"
],
"title": "SubscriptionStatusResponse"
},
"SubscriptionTier": {

View File

@@ -194,26 +194,6 @@ export default class BackendAPI {
return this._request("PATCH", "/credits");
}
getSubscription(): Promise<{
tier: string;
monthly_cost: number;
tier_costs: Record<string, number>;
}> {
return this._get("/credits/subscription");
}
setSubscriptionTier(
tier: string,
successUrl?: string,
cancelUrl?: string,
): Promise<{ url: string }> {
return this._request("POST", "/credits/subscription", {
tier,
success_url: successUrl ?? "",
cancel_url: cancelUrl ?? "",
});
}
////////////////////////////////////////
//////////////// GRAPHS ////////////////
////////////////////////////////////////