mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-07 22:24:06 -05:00
feat(billing): bill by threshold to prevent cancellation edge case (#1583)
* feat(billing): bill by threshold to prevent cancellation edge case * fix org billing * fix idempotency key issue * small optimization for team checks * remove console log * remove unused type * fix error handling
This commit is contained in:
committed by
waleed
parent
88d2e7b97b
commit
8ce5a1b7c0
@@ -166,6 +166,38 @@ Different subscription plans have different usage limits:
|
||||
| **Team** | $500 (pooled) | 50 sync, 100 async |
|
||||
| **Enterprise** | Custom | Custom |
|
||||
|
||||
## Billing Model
|
||||
|
||||
Sim uses a **base subscription + overage** billing model:
|
||||
|
||||
### How It Works
|
||||
|
||||
**Pro Plan ($20/month):**
|
||||
- Monthly subscription includes $20 of usage
|
||||
- Usage under $20 → No additional charges
|
||||
- Usage over $20 → Pay the overage at month end
|
||||
- Example: $35 usage = $20 (subscription) + $15 (overage)
|
||||
|
||||
**Team Plan ($40/seat/month):**
|
||||
- Pooled usage across all team members
|
||||
- Overage calculated from total team usage
|
||||
- Organization owner receives one bill
|
||||
|
||||
**Enterprise Plans:**
|
||||
- Fixed monthly price, no overages
|
||||
- Custom usage limits per agreement
|
||||
|
||||
### Threshold Billing
|
||||
|
||||
When unbilled overage reaches $50, Sim automatically bills the full unbilled amount.
|
||||
|
||||
**Example:**
|
||||
- Day 10: $70 overage → Bill $70 immediately
|
||||
- Day 15: Additional $35 usage ($105 total) → Already billed, no action
|
||||
- Day 20: Another $50 usage ($155 total, $85 unbilled) → Bill $85 immediately
|
||||
|
||||
This spreads large overage charges throughout the month instead of one large bill at period end.
|
||||
|
||||
## Cost Management Best Practices
|
||||
|
||||
1. **Monitor Regularly**: Check your usage dashboard frequently to avoid surprises
|
||||
|
||||
@@ -3,6 +3,7 @@ import { userStats } from '@sim/db/schema'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
|
||||
import { checkInternalApiKey } from '@/lib/copilot/utils'
|
||||
import { isBillingEnabled } from '@/lib/environment'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -148,6 +149,9 @@ export async function POST(req: NextRequest) {
|
||||
addedTokens: totalTokens,
|
||||
})
|
||||
|
||||
// Check if user has hit overage threshold and bill incrementally
|
||||
await checkAndBillOverageThreshold(userId)
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
|
||||
logger.info(`[${requestId}] Cost update completed successfully`, {
|
||||
|
||||
@@ -3,6 +3,7 @@ import { userStats, workflow } from '@sim/db/schema'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
|
||||
import { env } from '@/lib/env'
|
||||
import { getCostMultiplier, isBillingEnabled } from '@/lib/environment'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -133,6 +134,9 @@ async function updateUserStatsForWand(
|
||||
tokensUsed: totalTokens,
|
||||
costAdded: costToStore,
|
||||
})
|
||||
|
||||
// Check if user has hit overage threshold and bill incrementally
|
||||
await checkAndBillOverageThreshold(userId)
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Failed to update user stats for wand usage`, error)
|
||||
}
|
||||
|
||||
@@ -19,3 +19,9 @@ export const DEFAULT_ENTERPRISE_TIER_COST_LIMIT = 200
|
||||
* This charge is applied regardless of whether the workflow uses AI models
|
||||
*/
|
||||
export const BASE_EXECUTION_CHARGE = 0.001
|
||||
|
||||
/**
|
||||
* Default threshold (in dollars) for incremental overage billing
|
||||
* When unbilled overage reaches this amount, an invoice item is created
|
||||
*/
|
||||
export const DEFAULT_OVERAGE_THRESHOLD = 50
|
||||
|
||||
419
apps/sim/lib/billing/threshold-billing.ts
Normal file
419
apps/sim/lib/billing/threshold-billing.ts
Normal file
@@ -0,0 +1,419 @@
|
||||
import { db } from '@sim/db'
|
||||
import { member, subscription, userStats } from '@sim/db/schema'
|
||||
import { and, eq, inArray, sql } from 'drizzle-orm'
|
||||
import type Stripe from 'stripe'
|
||||
import { DEFAULT_OVERAGE_THRESHOLD } from '@/lib/billing/constants'
|
||||
import { calculateSubscriptionOverage, getPlanPricing } from '@/lib/billing/core/billing'
|
||||
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
|
||||
import { requireStripeClient } from '@/lib/billing/stripe-client'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('ThresholdBilling')
|
||||
|
||||
const OVERAGE_THRESHOLD = env.OVERAGE_THRESHOLD_DOLLARS || DEFAULT_OVERAGE_THRESHOLD
|
||||
|
||||
function parseDecimal(value: string | number | null | undefined): number {
|
||||
if (value === null || value === undefined) return 0
|
||||
return Number.parseFloat(value.toString())
|
||||
}
|
||||
|
||||
async function createAndFinalizeOverageInvoice(
|
||||
stripe: ReturnType<typeof requireStripeClient>,
|
||||
params: {
|
||||
customerId: string
|
||||
stripeSubscriptionId: string
|
||||
amountCents: number
|
||||
description: string
|
||||
itemDescription: string
|
||||
metadata: Record<string, string>
|
||||
idempotencyKey: string
|
||||
}
|
||||
): Promise<string> {
|
||||
const getPaymentMethodId = (
|
||||
pm: string | Stripe.PaymentMethod | null | undefined
|
||||
): string | undefined => (typeof pm === 'string' ? pm : pm?.id)
|
||||
|
||||
let defaultPaymentMethod: string | undefined
|
||||
try {
|
||||
const stripeSub = await stripe.subscriptions.retrieve(params.stripeSubscriptionId)
|
||||
const subDpm = getPaymentMethodId(stripeSub.default_payment_method)
|
||||
if (subDpm) {
|
||||
defaultPaymentMethod = subDpm
|
||||
} else {
|
||||
const custObj = await stripe.customers.retrieve(params.customerId)
|
||||
if (custObj && !('deleted' in custObj)) {
|
||||
const cust = custObj as Stripe.Customer
|
||||
const custDpm = getPaymentMethodId(cust.invoice_settings?.default_payment_method)
|
||||
if (custDpm) defaultPaymentMethod = custDpm
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error('Failed to retrieve subscription or customer', { error: e })
|
||||
}
|
||||
|
||||
const invoice = await stripe.invoices.create(
|
||||
{
|
||||
customer: params.customerId,
|
||||
collection_method: 'charge_automatically',
|
||||
auto_advance: false,
|
||||
description: params.description,
|
||||
metadata: params.metadata,
|
||||
...(defaultPaymentMethod ? { default_payment_method: defaultPaymentMethod } : {}),
|
||||
},
|
||||
{ idempotencyKey: `${params.idempotencyKey}-invoice` }
|
||||
)
|
||||
|
||||
await stripe.invoiceItems.create(
|
||||
{
|
||||
customer: params.customerId,
|
||||
invoice: invoice.id,
|
||||
amount: params.amountCents,
|
||||
currency: 'usd',
|
||||
description: params.itemDescription,
|
||||
metadata: params.metadata,
|
||||
},
|
||||
{ idempotencyKey: params.idempotencyKey }
|
||||
)
|
||||
|
||||
if (invoice.id) {
|
||||
const finalized = await stripe.invoices.finalizeInvoice(invoice.id)
|
||||
|
||||
if (finalized.status === 'open' && finalized.id) {
|
||||
try {
|
||||
await stripe.invoices.pay(finalized.id, {
|
||||
payment_method: defaultPaymentMethod,
|
||||
})
|
||||
} catch (payError) {
|
||||
logger.error('Failed to auto-pay threshold overage invoice', {
|
||||
error: payError,
|
||||
invoiceId: finalized.id,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return invoice.id || ''
|
||||
}
|
||||
|
||||
export async function checkAndBillOverageThreshold(userId: string): Promise<void> {
|
||||
try {
|
||||
const threshold = OVERAGE_THRESHOLD
|
||||
|
||||
const userSubscription = await getHighestPrioritySubscription(userId)
|
||||
|
||||
if (!userSubscription || userSubscription.status !== 'active') {
|
||||
logger.debug('No active subscription for threshold billing', { userId })
|
||||
return
|
||||
}
|
||||
|
||||
if (
|
||||
!userSubscription.plan ||
|
||||
userSubscription.plan === 'free' ||
|
||||
userSubscription.plan === 'enterprise'
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
if (userSubscription.plan === 'team') {
|
||||
logger.debug('Team plan detected - triggering org-level threshold billing', {
|
||||
userId,
|
||||
organizationId: userSubscription.referenceId,
|
||||
})
|
||||
await checkAndBillOrganizationOverageThreshold(userSubscription.referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
await db.transaction(async (tx) => {
|
||||
const statsRecords = await tx
|
||||
.select()
|
||||
.from(userStats)
|
||||
.where(eq(userStats.userId, userId))
|
||||
.for('update')
|
||||
.limit(1)
|
||||
|
||||
if (statsRecords.length === 0) {
|
||||
logger.warn('User stats not found for threshold billing', { userId })
|
||||
return
|
||||
}
|
||||
|
||||
const stats = statsRecords[0]
|
||||
|
||||
const currentOverage = await calculateSubscriptionOverage({
|
||||
id: userSubscription.id,
|
||||
plan: userSubscription.plan,
|
||||
referenceId: userSubscription.referenceId,
|
||||
seats: userSubscription.seats,
|
||||
})
|
||||
const billedOverageThisPeriod = parseDecimal(stats.billedOverageThisPeriod)
|
||||
const unbilledOverage = Math.max(0, currentOverage - billedOverageThisPeriod)
|
||||
|
||||
logger.debug('Threshold billing check', {
|
||||
userId,
|
||||
plan: userSubscription.plan,
|
||||
currentOverage,
|
||||
billedOverageThisPeriod,
|
||||
unbilledOverage,
|
||||
threshold,
|
||||
})
|
||||
|
||||
if (unbilledOverage < threshold) {
|
||||
return
|
||||
}
|
||||
|
||||
const amountToBill = unbilledOverage
|
||||
|
||||
const stripeSubscriptionId = userSubscription.stripeSubscriptionId
|
||||
if (!stripeSubscriptionId) {
|
||||
logger.error('No Stripe subscription ID found', { userId })
|
||||
return
|
||||
}
|
||||
|
||||
const stripe = requireStripeClient()
|
||||
const stripeSubscription = await stripe.subscriptions.retrieve(stripeSubscriptionId)
|
||||
const customerId =
|
||||
typeof stripeSubscription.customer === 'string'
|
||||
? stripeSubscription.customer
|
||||
: stripeSubscription.customer.id
|
||||
|
||||
const periodEnd = userSubscription.periodEnd
|
||||
? Math.floor(userSubscription.periodEnd.getTime() / 1000)
|
||||
: Math.floor(Date.now() / 1000)
|
||||
const billingPeriod = new Date(periodEnd * 1000).toISOString().slice(0, 7)
|
||||
|
||||
const amountCents = Math.round(amountToBill * 100)
|
||||
const totalOverageCents = Math.round(currentOverage * 100)
|
||||
const idempotencyKey = `threshold-overage:${customerId}:${stripeSubscriptionId}:${billingPeriod}:${totalOverageCents}:${amountCents}`
|
||||
|
||||
logger.info('Creating threshold overage invoice', {
|
||||
userId,
|
||||
plan: userSubscription.plan,
|
||||
amountToBill,
|
||||
billingPeriod,
|
||||
idempotencyKey,
|
||||
})
|
||||
|
||||
const cents = amountCents
|
||||
|
||||
const invoiceId = await createAndFinalizeOverageInvoice(stripe, {
|
||||
customerId,
|
||||
stripeSubscriptionId,
|
||||
amountCents: cents,
|
||||
description: `Threshold overage billing – ${billingPeriod}`,
|
||||
itemDescription: `Usage overage ($${amountToBill.toFixed(2)})`,
|
||||
metadata: {
|
||||
type: 'overage_threshold_billing',
|
||||
userId,
|
||||
subscriptionId: stripeSubscriptionId,
|
||||
billingPeriod,
|
||||
totalOverageAtTimeOfBilling: currentOverage.toFixed(2),
|
||||
},
|
||||
idempotencyKey,
|
||||
})
|
||||
|
||||
await tx
|
||||
.update(userStats)
|
||||
.set({
|
||||
billedOverageThisPeriod: sql`${userStats.billedOverageThisPeriod} + ${amountToBill}`,
|
||||
})
|
||||
.where(eq(userStats.userId, userId))
|
||||
|
||||
logger.info('Successfully created and finalized threshold overage invoice', {
|
||||
userId,
|
||||
amountBilled: amountToBill,
|
||||
invoiceId,
|
||||
newBilledTotal: billedOverageThisPeriod + amountToBill,
|
||||
})
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error in threshold billing check', {
|
||||
userId,
|
||||
error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export async function checkAndBillOrganizationOverageThreshold(
|
||||
organizationId: string
|
||||
): Promise<void> {
|
||||
logger.info('=== ENTERED checkAndBillOrganizationOverageThreshold ===', { organizationId })
|
||||
|
||||
try {
|
||||
const threshold = OVERAGE_THRESHOLD
|
||||
|
||||
logger.debug('Starting organization threshold billing check', { organizationId, threshold })
|
||||
|
||||
const orgSubscriptions = await db
|
||||
.select()
|
||||
.from(subscription)
|
||||
.where(and(eq(subscription.referenceId, organizationId), eq(subscription.status, 'active')))
|
||||
.limit(1)
|
||||
|
||||
if (orgSubscriptions.length === 0) {
|
||||
logger.debug('No active subscription for organization', { organizationId })
|
||||
return
|
||||
}
|
||||
|
||||
const orgSubscription = orgSubscriptions[0]
|
||||
logger.debug('Found organization subscription', {
|
||||
organizationId,
|
||||
plan: orgSubscription.plan,
|
||||
seats: orgSubscription.seats,
|
||||
stripeSubscriptionId: orgSubscription.stripeSubscriptionId,
|
||||
})
|
||||
|
||||
if (orgSubscription.plan !== 'team') {
|
||||
logger.debug('Organization plan is not team, skipping', {
|
||||
organizationId,
|
||||
plan: orgSubscription.plan,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const members = await db
|
||||
.select({ userId: member.userId, role: member.role })
|
||||
.from(member)
|
||||
.where(eq(member.organizationId, organizationId))
|
||||
|
||||
logger.debug('Found organization members', {
|
||||
organizationId,
|
||||
memberCount: members.length,
|
||||
members: members.map((m) => ({ userId: m.userId, role: m.role })),
|
||||
})
|
||||
|
||||
if (members.length === 0) {
|
||||
logger.warn('No members found for organization', { organizationId })
|
||||
return
|
||||
}
|
||||
|
||||
const owner = members.find((m) => m.role === 'owner')
|
||||
if (!owner) {
|
||||
logger.error('No owner found for organization', { organizationId })
|
||||
return
|
||||
}
|
||||
|
||||
logger.debug('Found organization owner, starting transaction', {
|
||||
organizationId,
|
||||
ownerId: owner.userId,
|
||||
})
|
||||
|
||||
await db.transaction(async (tx) => {
|
||||
const ownerStatsLock = await tx
|
||||
.select()
|
||||
.from(userStats)
|
||||
.where(eq(userStats.userId, owner.userId))
|
||||
.for('update')
|
||||
.limit(1)
|
||||
|
||||
if (ownerStatsLock.length === 0) {
|
||||
logger.error('Owner stats not found', { organizationId, ownerId: owner.userId })
|
||||
return
|
||||
}
|
||||
|
||||
let totalTeamUsage = parseDecimal(ownerStatsLock[0].currentPeriodCost)
|
||||
const totalBilledOverage = parseDecimal(ownerStatsLock[0].billedOverageThisPeriod)
|
||||
|
||||
const nonOwnerIds = members.filter((m) => m.userId !== owner.userId).map((m) => m.userId)
|
||||
|
||||
if (nonOwnerIds.length > 0) {
|
||||
const memberStatsRows = await tx
|
||||
.select({
|
||||
userId: userStats.userId,
|
||||
currentPeriodCost: userStats.currentPeriodCost,
|
||||
})
|
||||
.from(userStats)
|
||||
.where(inArray(userStats.userId, nonOwnerIds))
|
||||
|
||||
for (const stats of memberStatsRows) {
|
||||
totalTeamUsage += parseDecimal(stats.currentPeriodCost)
|
||||
}
|
||||
}
|
||||
|
||||
const { basePrice: basePricePerSeat } = getPlanPricing(orgSubscription.plan)
|
||||
const basePrice = basePricePerSeat * (orgSubscription.seats || 1)
|
||||
const currentOverage = Math.max(0, totalTeamUsage - basePrice)
|
||||
const unbilledOverage = Math.max(0, currentOverage - totalBilledOverage)
|
||||
|
||||
logger.debug('Organization threshold billing check', {
|
||||
organizationId,
|
||||
totalTeamUsage,
|
||||
basePrice,
|
||||
currentOverage,
|
||||
totalBilledOverage,
|
||||
unbilledOverage,
|
||||
threshold,
|
||||
})
|
||||
|
||||
if (unbilledOverage < threshold) {
|
||||
return
|
||||
}
|
||||
|
||||
const amountToBill = unbilledOverage
|
||||
|
||||
const stripeSubscriptionId = orgSubscription.stripeSubscriptionId
|
||||
if (!stripeSubscriptionId) {
|
||||
logger.error('No Stripe subscription ID for organization', { organizationId })
|
||||
return
|
||||
}
|
||||
|
||||
const stripe = requireStripeClient()
|
||||
const stripeSubscription = await stripe.subscriptions.retrieve(stripeSubscriptionId)
|
||||
const customerId =
|
||||
typeof stripeSubscription.customer === 'string'
|
||||
? stripeSubscription.customer
|
||||
: stripeSubscription.customer.id
|
||||
|
||||
const periodEnd = orgSubscription.periodEnd
|
||||
? Math.floor(orgSubscription.periodEnd.getTime() / 1000)
|
||||
: Math.floor(Date.now() / 1000)
|
||||
const billingPeriod = new Date(periodEnd * 1000).toISOString().slice(0, 7)
|
||||
const amountCents = Math.round(amountToBill * 100)
|
||||
const totalOverageCents = Math.round(currentOverage * 100)
|
||||
|
||||
const idempotencyKey = `threshold-overage-org:${customerId}:${stripeSubscriptionId}:${billingPeriod}:${totalOverageCents}:${amountCents}`
|
||||
|
||||
logger.info('Creating organization threshold overage invoice', {
|
||||
organizationId,
|
||||
amountToBill,
|
||||
billingPeriod,
|
||||
})
|
||||
|
||||
const cents = amountCents
|
||||
|
||||
const invoiceId = await createAndFinalizeOverageInvoice(stripe, {
|
||||
customerId,
|
||||
stripeSubscriptionId,
|
||||
amountCents: cents,
|
||||
description: `Team threshold overage billing – ${billingPeriod}`,
|
||||
itemDescription: `Team usage overage ($${amountToBill.toFixed(2)})`,
|
||||
metadata: {
|
||||
type: 'overage_threshold_billing_org',
|
||||
organizationId,
|
||||
subscriptionId: stripeSubscriptionId,
|
||||
billingPeriod,
|
||||
totalOverageAtTimeOfBilling: currentOverage.toFixed(2),
|
||||
},
|
||||
idempotencyKey,
|
||||
})
|
||||
|
||||
await tx
|
||||
.update(userStats)
|
||||
.set({
|
||||
billedOverageThisPeriod: sql`${userStats.billedOverageThisPeriod} + ${amountToBill}`,
|
||||
})
|
||||
.where(eq(userStats.userId, owner.userId))
|
||||
|
||||
logger.info('Successfully created and finalized organization threshold overage invoice', {
|
||||
organizationId,
|
||||
ownerId: owner.userId,
|
||||
amountBilled: amountToBill,
|
||||
invoiceId,
|
||||
})
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error in organization threshold billing', {
|
||||
organizationId,
|
||||
error,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import { db } from '@sim/db'
|
||||
import { member, subscription as subscriptionTable, userStats } from '@sim/db/schema'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { eq, inArray } from 'drizzle-orm'
|
||||
import type Stripe from 'stripe'
|
||||
import { calculateSubscriptionOverage } from '@/lib/billing/core/billing'
|
||||
import { requireStripeClient } from '@/lib/billing/stripe-client'
|
||||
@@ -8,6 +8,64 @@ import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('StripeInvoiceWebhooks')
|
||||
|
||||
const OVERAGE_INVOICE_TYPES = new Set<string>([
|
||||
'overage_billing',
|
||||
'overage_threshold_billing',
|
||||
'overage_threshold_billing_org',
|
||||
])
|
||||
|
||||
function parseDecimal(value: string | number | null | undefined): number {
|
||||
if (value === null || value === undefined) return 0
|
||||
return Number.parseFloat(value.toString())
|
||||
}
|
||||
|
||||
/**
|
||||
* Get total billed overage for a subscription, handling team vs individual plans
|
||||
* For team plans: sums billedOverageThisPeriod across all members
|
||||
* For other plans: gets billedOverageThisPeriod for the user
|
||||
*/
|
||||
export async function getBilledOverageForSubscription(sub: {
|
||||
plan: string | null
|
||||
referenceId: string
|
||||
}): Promise<number> {
|
||||
let billedOverage = 0
|
||||
|
||||
if (sub.plan === 'team') {
|
||||
const members = await db
|
||||
.select({ userId: member.userId })
|
||||
.from(member)
|
||||
.where(eq(member.organizationId, sub.referenceId))
|
||||
|
||||
const memberIds = members.map((m) => m.userId)
|
||||
|
||||
if (memberIds.length > 0) {
|
||||
const memberStatsRows = await db
|
||||
.select({
|
||||
userId: userStats.userId,
|
||||
billedOverageThisPeriod: userStats.billedOverageThisPeriod,
|
||||
})
|
||||
.from(userStats)
|
||||
.where(inArray(userStats.userId, memberIds))
|
||||
|
||||
for (const stats of memberStatsRows) {
|
||||
billedOverage += parseDecimal(stats.billedOverageThisPeriod)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const userStatsRecords = await db
|
||||
.select({ billedOverageThisPeriod: userStats.billedOverageThisPeriod })
|
||||
.from(userStats)
|
||||
.where(eq(userStats.userId, sub.referenceId))
|
||||
.limit(1)
|
||||
|
||||
if (userStatsRecords.length > 0) {
|
||||
billedOverage = parseDecimal(userStatsRecords[0].billedOverageThisPeriod)
|
||||
}
|
||||
}
|
||||
|
||||
return billedOverage
|
||||
}
|
||||
|
||||
export async function resetUsageForSubscription(sub: { plan: string | null; referenceId: string }) {
|
||||
if (sub.plan === 'team' || sub.plan === 'enterprise') {
|
||||
const membersRows = await db
|
||||
@@ -25,7 +83,11 @@ export async function resetUsageForSubscription(sub: { plan: string | null; refe
|
||||
const current = currentStats[0].current || '0'
|
||||
await db
|
||||
.update(userStats)
|
||||
.set({ lastPeriodCost: current, currentPeriodCost: '0' })
|
||||
.set({
|
||||
lastPeriodCost: current,
|
||||
currentPeriodCost: '0',
|
||||
billedOverageThisPeriod: '0',
|
||||
})
|
||||
.where(eq(userStats.userId, m.userId))
|
||||
}
|
||||
}
|
||||
@@ -50,6 +112,7 @@ export async function resetUsageForSubscription(sub: { plan: string | null; refe
|
||||
lastPeriodCost: totalLastPeriod,
|
||||
currentPeriodCost: '0',
|
||||
proPeriodCostSnapshot: '0', // Clear snapshot at period end
|
||||
billedOverageThisPeriod: '0', // Clear threshold billing tracker at period end
|
||||
})
|
||||
.where(eq(userStats.userId, sub.referenceId))
|
||||
}
|
||||
@@ -88,16 +151,14 @@ export async function handleInvoicePaymentSucceeded(event: Stripe.Event) {
|
||||
.select({ userId: member.userId })
|
||||
.from(member)
|
||||
.where(eq(member.organizationId, sub.referenceId))
|
||||
for (const m of membersRows) {
|
||||
const row = await db
|
||||
const memberIds = membersRows.map((m) => m.userId)
|
||||
if (memberIds.length > 0) {
|
||||
const blockedRows = await db
|
||||
.select({ blocked: userStats.billingBlocked })
|
||||
.from(userStats)
|
||||
.where(eq(userStats.userId, m.userId))
|
||||
.limit(1)
|
||||
if (row.length > 0 && row[0].blocked) {
|
||||
wasBlocked = true
|
||||
break
|
||||
}
|
||||
.where(inArray(userStats.userId, memberIds))
|
||||
|
||||
wasBlocked = blockedRows.some((row) => !!row.blocked)
|
||||
}
|
||||
} else {
|
||||
const row = await db
|
||||
@@ -113,11 +174,13 @@ export async function handleInvoicePaymentSucceeded(event: Stripe.Event) {
|
||||
.select({ userId: member.userId })
|
||||
.from(member)
|
||||
.where(eq(member.organizationId, sub.referenceId))
|
||||
for (const m of members) {
|
||||
const memberIds = members.map((m) => m.userId)
|
||||
|
||||
if (memberIds.length > 0) {
|
||||
await db
|
||||
.update(userStats)
|
||||
.set({ billingBlocked: false })
|
||||
.where(eq(userStats.userId, m.userId))
|
||||
.where(inArray(userStats.userId, memberIds))
|
||||
}
|
||||
} else {
|
||||
await db
|
||||
@@ -143,7 +206,8 @@ export async function handleInvoicePaymentFailed(event: Stripe.Event) {
|
||||
try {
|
||||
const invoice = event.data.object as Stripe.Invoice
|
||||
|
||||
const isOverageInvoice = invoice.metadata?.type === 'overage_billing'
|
||||
const invoiceType = invoice.metadata?.type
|
||||
const isOverageInvoice = !!(invoiceType && OVERAGE_INVOICE_TYPES.has(invoiceType))
|
||||
let stripeSubscriptionId: string | undefined
|
||||
|
||||
if (isOverageInvoice) {
|
||||
@@ -203,11 +267,13 @@ export async function handleInvoicePaymentFailed(event: Stripe.Event) {
|
||||
.select({ userId: member.userId })
|
||||
.from(member)
|
||||
.where(eq(member.organizationId, sub.referenceId))
|
||||
for (const m of members) {
|
||||
const memberIds = members.map((m) => m.userId)
|
||||
|
||||
if (memberIds.length > 0) {
|
||||
await db
|
||||
.update(userStats)
|
||||
.set({ billingBlocked: true })
|
||||
.where(eq(userStats.userId, m.userId))
|
||||
.where(inArray(userStats.userId, memberIds))
|
||||
}
|
||||
logger.info('Blocked team/enterprise members due to payment failure', {
|
||||
organizationId: sub.referenceId,
|
||||
@@ -281,9 +347,23 @@ export async function handleInvoiceFinalized(event: Stripe.Event) {
|
||||
// Compute overage (only for team and pro plans), before resetting usage
|
||||
const totalOverage = await calculateSubscriptionOverage(sub)
|
||||
|
||||
if (totalOverage > 0) {
|
||||
// Get already-billed overage from threshold billing
|
||||
const billedOverage = await getBilledOverageForSubscription(sub)
|
||||
|
||||
// Only bill the remaining unbilled overage
|
||||
const remainingOverage = Math.max(0, totalOverage - billedOverage)
|
||||
|
||||
logger.info('Invoice finalized overage calculation', {
|
||||
subscriptionId: sub.id,
|
||||
totalOverage,
|
||||
billedOverage,
|
||||
remainingOverage,
|
||||
billingPeriod,
|
||||
})
|
||||
|
||||
if (remainingOverage > 0) {
|
||||
const customerId = String(invoice.customer)
|
||||
const cents = Math.round(totalOverage * 100)
|
||||
const cents = Math.round(remainingOverage * 100)
|
||||
const itemIdemKey = `overage-item:${customerId}:${stripeSubscriptionId}:${billingPeriod}`
|
||||
const invoiceIdemKey = `overage-invoice:${customerId}:${stripeSubscriptionId}:${billingPeriod}`
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import { and, eq, ne } from 'drizzle-orm'
|
||||
import { calculateSubscriptionOverage } from '@/lib/billing/core/billing'
|
||||
import { requireStripeClient } from '@/lib/billing/stripe-client'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { resetUsageForSubscription } from './invoices'
|
||||
import { getBilledOverageForSubscription, resetUsageForSubscription } from './invoices'
|
||||
|
||||
const logger = createLogger('StripeSubscriptionWebhooks')
|
||||
|
||||
@@ -104,11 +104,24 @@ export async function handleSubscriptionDeleted(subscription: {
|
||||
return
|
||||
}
|
||||
|
||||
// Get already-billed overage from threshold billing
|
||||
const billedOverage = await getBilledOverageForSubscription(subscription)
|
||||
|
||||
// Only bill the remaining unbilled overage
|
||||
const remainingOverage = Math.max(0, totalOverage - billedOverage)
|
||||
|
||||
logger.info('Subscription deleted overage calculation', {
|
||||
subscriptionId: subscription.id,
|
||||
totalOverage,
|
||||
billedOverage,
|
||||
remainingOverage,
|
||||
})
|
||||
|
||||
// Create final overage invoice if needed
|
||||
if (totalOverage > 0 && stripeSubscriptionId) {
|
||||
if (remainingOverage > 0 && stripeSubscriptionId) {
|
||||
const stripeSubscription = await stripe.subscriptions.retrieve(stripeSubscriptionId)
|
||||
const customerId = stripeSubscription.customer as string
|
||||
const cents = Math.round(totalOverage * 100)
|
||||
const cents = Math.round(remainingOverage * 100)
|
||||
|
||||
// Use the subscription end date for the billing period
|
||||
const endedAt = stripeSubscription.ended_at || Math.floor(Date.now() / 1000)
|
||||
@@ -145,7 +158,9 @@ export async function handleSubscriptionDeleted(subscription: {
|
||||
description: `Usage overage for ${subscription.plan} plan (Final billing period)`,
|
||||
metadata: {
|
||||
type: 'final_usage_overage',
|
||||
usage: totalOverage.toFixed(2),
|
||||
usage: remainingOverage.toFixed(2),
|
||||
totalOverage: totalOverage.toFixed(2),
|
||||
billedOverage: billedOverage.toFixed(2),
|
||||
billingPeriod,
|
||||
},
|
||||
},
|
||||
@@ -161,7 +176,9 @@ export async function handleSubscriptionDeleted(subscription: {
|
||||
subscriptionId: subscription.id,
|
||||
stripeSubscriptionId,
|
||||
invoiceId: overageInvoice.id,
|
||||
overageAmount: totalOverage,
|
||||
totalOverage,
|
||||
billedOverage,
|
||||
remainingOverage,
|
||||
cents,
|
||||
billingPeriod,
|
||||
})
|
||||
@@ -169,7 +186,9 @@ export async function handleSubscriptionDeleted(subscription: {
|
||||
logger.error('Failed to create final overage invoice', {
|
||||
subscriptionId: subscription.id,
|
||||
stripeSubscriptionId,
|
||||
overageAmount: totalOverage,
|
||||
totalOverage,
|
||||
billedOverage,
|
||||
remainingOverage,
|
||||
error: invoiceError,
|
||||
})
|
||||
// Don't throw - we don't want to fail the webhook
|
||||
|
||||
@@ -49,6 +49,7 @@ export const env = createEnv({
|
||||
STRIPE_ENTERPRISE_PRICE_ID: z.string().min(1).optional(), // Stripe price ID for enterprise tier
|
||||
ENTERPRISE_TIER_COST_LIMIT: z.number().optional(), // Cost limit for enterprise tier users
|
||||
BILLING_ENABLED: z.boolean().optional(), // Enable billing enforcement and usage tracking
|
||||
OVERAGE_THRESHOLD_DOLLARS: z.number().optional().default(50), // Dollar threshold for incremental overage billing (default: $50)
|
||||
|
||||
// Email & Communication
|
||||
EMAIL_VERIFICATION_ENABLED: z.boolean().optional(), // Enable email verification for user registration and login (defaults to false)
|
||||
|
||||
@@ -11,6 +11,7 @@ import { eq, sql } from 'drizzle-orm'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
|
||||
import { checkUsageStatus, maybeSendUsageThresholdEmail } from '@/lib/billing/core/usage'
|
||||
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
|
||||
import { isBillingEnabled } from '@/lib/environment'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { emitWorkflowExecutionCompleted } from '@/lib/logs/events'
|
||||
@@ -441,6 +442,9 @@ export class ExecutionLogger implements IExecutionLoggerService {
|
||||
addedCost: costToStore,
|
||||
addedTokens: costSummary.totalTokens,
|
||||
})
|
||||
|
||||
// Check if user has hit overage threshold and bill incrementally
|
||||
await checkAndBillOverageThreshold(userId)
|
||||
} catch (error) {
|
||||
logger.error('Error updating user stats with cost information', {
|
||||
workflowId,
|
||||
|
||||
@@ -1,32 +1,14 @@
|
||||
import { drizzle, type PostgresJsDatabase } from 'drizzle-orm/postgres-js'
|
||||
import { drizzle } from 'drizzle-orm/postgres-js'
|
||||
import postgres from 'postgres'
|
||||
import * as schema from './schema'
|
||||
|
||||
export * from './schema'
|
||||
export type { PostgresJsDatabase }
|
||||
|
||||
const connectionString = process.env.DATABASE_URL!
|
||||
if (!connectionString) {
|
||||
throw new Error('Missing DATABASE_URL environment variable')
|
||||
}
|
||||
|
||||
console.log(
|
||||
'[DB Pool Init]',
|
||||
JSON.stringify({
|
||||
timestamp: new Date().toISOString(),
|
||||
nodeEnv: process.env.NODE_ENV,
|
||||
action: 'CREATING_CONNECTION_POOL',
|
||||
poolConfig: {
|
||||
max: 30,
|
||||
idle_timeout: 20,
|
||||
connect_timeout: 30,
|
||||
prepare: false,
|
||||
},
|
||||
pid: process.pid,
|
||||
isProduction: process.env.NODE_ENV === 'production',
|
||||
})
|
||||
)
|
||||
|
||||
const postgresClient = postgres(connectionString, {
|
||||
prepare: false,
|
||||
idle_timeout: 20,
|
||||
|
||||
1
packages/db/migrations/0097_dazzling_mephisto.sql
Normal file
1
packages/db/migrations/0097_dazzling_mephisto.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE "user_stats" ADD COLUMN "billed_overage_this_period" numeric DEFAULT '0' NOT NULL;
|
||||
6973
packages/db/migrations/meta/0097_snapshot.json
Normal file
6973
packages/db/migrations/meta/0097_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -673,6 +673,13 @@
|
||||
"when": 1759534968812,
|
||||
"tag": "0096_tranquil_arachne",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 97,
|
||||
"version": "7",
|
||||
"when": 1759963094548,
|
||||
"tag": "0097_dazzling_mephisto",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -564,6 +564,7 @@ export const userStats = pgTable('user_stats', {
|
||||
// Billing period tracking
|
||||
currentPeriodCost: decimal('current_period_cost').notNull().default('0'), // Usage in current billing period
|
||||
lastPeriodCost: decimal('last_period_cost').default('0'), // Usage from previous billing period
|
||||
billedOverageThisPeriod: decimal('billed_overage_this_period').notNull().default('0'), // Amount of overage already billed via threshold billing
|
||||
// Pro usage snapshot when joining a team (to prevent double-billing)
|
||||
proPeriodCostSnapshot: decimal('pro_period_cost_snapshot').default('0'), // Snapshot of Pro usage when joining team
|
||||
// Copilot usage tracking
|
||||
|
||||
Reference in New Issue
Block a user