Compare commits

...

2 Commits

Author SHA1 Message Date
Vikhyath Mondreti
5ba322f8f2 more usage of block/unblock helpers 2026-02-01 12:00:23 -08:00
Vikhyath Mondreti
f11e1cf5de improvement(billing): improve against direct subscription creation bypasses 2026-02-01 11:54:53 -08:00
15 changed files with 272 additions and 93 deletions

View File

@@ -20,6 +20,7 @@ import { z } from 'zod'
import { getEmailSubject, renderInvitationEmail } from '@/components/emails'
import { getSession } from '@/lib/auth'
import { hasAccessControlAccess } from '@/lib/billing'
import { syncUsageLimitsFromSubscription } from '@/lib/billing/core/usage'
import { requireStripeClient } from '@/lib/billing/stripe-client'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { sendEmail } from '@/lib/messaging/email/mailer'
@@ -501,6 +502,18 @@ export async function PUT(
}
}
if (status === 'accepted') {
try {
await syncUsageLimitsFromSubscription(session.user.id)
} catch (syncError) {
logger.error('Failed to sync usage limits after joining org', {
userId: session.user.id,
organizationId,
error: syncError,
})
}
}
logger.info(`Organization invitation ${status}`, {
organizationId,
invitationId,

View File

@@ -5,6 +5,7 @@ import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { hasActiveSubscription } from '@/lib/billing'
const logger = createLogger('SubscriptionTransferAPI')
@@ -88,6 +89,14 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
)
}
// Check if org already has an active subscription (prevent duplicates)
if (await hasActiveSubscription(organizationId)) {
return NextResponse.json(
{ error: 'Organization already has an active subscription' },
{ status: 409 }
)
}
await db
.update(subscription)
.set({ referenceId: organizationId })

View File

@@ -203,6 +203,10 @@ export const PATCH = withAdminAuthParams<RouteParams>(async (request, context) =
}
updateData.billingBlocked = body.billingBlocked
// Clear the reason when unblocking
if (body.billingBlocked === false) {
updateData.billingBlockedReason = null
}
updated.push('billingBlocked')
}

View File

@@ -1,6 +1,4 @@
import { db, workflow as workflowTable } from '@sim/db'
import { createLogger } from '@sim/logger'
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { v4 as uuidv4 } from 'uuid'
import { z } from 'zod'
@@ -8,6 +6,7 @@ import { checkHybridAuth } from '@/lib/auth/hybrid'
import { generateRequestId } from '@/lib/core/utils/request'
import { SSE_HEADERS } from '@/lib/core/utils/sse'
import { markExecutionCancelled } from '@/lib/execution/cancellation'
import { preprocessExecution } from '@/lib/execution/preprocessing'
import { LoggingSession } from '@/lib/logs/execution/logging-session'
import { executeWorkflowCore } from '@/lib/workflows/executor/execution-core'
import { createSSECallbacks } from '@/lib/workflows/executor/execution-events'
@@ -75,12 +74,31 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
const { startBlockId, sourceSnapshot, input } = validation.data
const executionId = uuidv4()
const [workflowRecord] = await db
.select({ workspaceId: workflowTable.workspaceId, userId: workflowTable.userId })
.from(workflowTable)
.where(eq(workflowTable.id, workflowId))
.limit(1)
// Run preprocessing checks (billing, rate limits, usage limits)
const preprocessResult = await preprocessExecution({
workflowId,
userId,
triggerType: 'manual',
executionId,
requestId,
checkRateLimit: false, // Manual executions don't rate limit
checkDeployment: false, // Run-from-block doesn't require deployment
})
if (!preprocessResult.success) {
const { error } = preprocessResult
logger.warn(`[${requestId}] Preprocessing failed for run-from-block`, {
workflowId,
error: error?.message,
statusCode: error?.statusCode,
})
return NextResponse.json(
{ error: error?.message || 'Execution blocked' },
{ status: error?.statusCode || 500 }
)
}
const workflowRecord = preprocessResult.workflowRecord
if (!workflowRecord?.workspaceId) {
return NextResponse.json({ error: 'Workflow not found or has no workspace' }, { status: 404 })
}
@@ -92,6 +110,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
workflowId,
startBlockId,
executedBlocksCount: sourceSnapshot.executedBlocks.length,
billingActorUserId: preprocessResult.actorUserId,
})
const loggingSession = new LoggingSession(workflowId, executionId, 'manual', requestId)

View File

@@ -1,15 +1,32 @@
import { db } from '@sim/db'
import * as schema from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq } from 'drizzle-orm'
import { hasActiveSubscription } from '@/lib/billing'
const logger = createLogger('BillingAuthorization')
/**
* Check if a user is authorized to manage billing for a given reference ID
* Reference ID can be either a user ID (individual subscription) or organization ID (team subscription)
*
* This function also performs duplicate subscription validation:
* - Rejects if the referenceId already has an active subscription (prevents duplicates)
*/
export async function authorizeSubscriptionReference(
userId: string,
referenceId: string
): Promise<boolean> {
// Check for existing active subscriptions on this referenceId
// This prevents creating duplicate subscriptions for the same entity
if (await hasActiveSubscription(referenceId)) {
logger.warn('Blocking checkout - active subscription already exists for referenceId', {
userId,
referenceId,
})
return false
}
// User can always manage their own subscriptions
if (referenceId === userId) {
return true

View File

@@ -25,9 +25,11 @@ export function useSubscriptionUpgrade() {
}
let currentSubscriptionId: string | undefined
let allSubscriptions: any[] = []
try {
const listResult = await client.subscription.list()
const activePersonalSub = listResult.data?.find(
allSubscriptions = listResult.data || []
const activePersonalSub = allSubscriptions.find(
(sub: any) => sub.status === 'active' && sub.referenceId === userId
)
currentSubscriptionId = activePersonalSub?.id
@@ -50,6 +52,25 @@ export function useSubscriptionUpgrade() {
)
if (existingOrg) {
// Check if this org already has an active team subscription
const existingTeamSub = allSubscriptions.find(
(sub: any) =>
sub.status === 'active' &&
sub.referenceId === existingOrg.id &&
(sub.plan === 'team' || sub.plan === 'enterprise')
)
if (existingTeamSub) {
logger.warn('Organization already has an active team subscription', {
userId,
organizationId: existingOrg.id,
existingSubscriptionId: existingTeamSub.id,
})
throw new Error(
'This organization already has an active team subscription. Please manage it from the billing settings.'
)
}
logger.info('Using existing organization for team plan upgrade', {
userId,
organizationId: existingOrg.id,

View File

@@ -1,5 +1,5 @@
import { db } from '@sim/db'
import { member, subscription } from '@sim/db/schema'
import { member, organization, subscription } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, inArray } from 'drizzle-orm'
import { checkEnterprisePlan, checkProPlan, checkTeamPlan } from '@/lib/billing/subscriptions/utils'
@@ -26,10 +26,22 @@ export async function getHighestPrioritySubscription(userId: string) {
let orgSubs: typeof personalSubs = []
if (orgIds.length > 0) {
orgSubs = await db
.select()
.from(subscription)
.where(and(inArray(subscription.referenceId, orgIds), eq(subscription.status, 'active')))
// Verify orgs exist to filter out orphaned subscriptions
const existingOrgs = await db
.select({ id: organization.id })
.from(organization)
.where(inArray(organization.id, orgIds))
const validOrgIds = existingOrgs.map((o) => o.id)
if (validOrgIds.length > 0) {
orgSubs = await db
.select()
.from(subscription)
.where(
and(inArray(subscription.referenceId, validOrgIds), eq(subscription.status, 'active'))
)
}
}
const allSubs = [...personalSubs, ...orgSubs]

View File

@@ -25,6 +25,25 @@ const logger = createLogger('SubscriptionCore')
export { getHighestPrioritySubscription }
/**
* Check if a referenceId (user ID or org ID) has an active subscription
* Used for duplicate subscription prevention
*/
export async function hasActiveSubscription(referenceId: string): Promise<boolean> {
try {
const [activeSub] = await db
.select({ id: subscription.id })
.from(subscription)
.where(and(eq(subscription.referenceId, referenceId), eq(subscription.status, 'active')))
.limit(1)
return !!activeSub
} catch (error) {
logger.error('Error checking active subscription', { error, referenceId })
return false
}
}
/**
* Check if user is on Pro plan (direct or via organization)
*/

View File

@@ -11,6 +11,7 @@ export {
getHighestPrioritySubscription as getActiveSubscription,
getUserSubscriptionState as getSubscriptionState,
hasAccessControlAccess,
hasActiveSubscription,
hasCredentialSetsAccess,
hasSSOAccess,
isEnterpriseOrgAdminOrOwner,
@@ -32,6 +33,11 @@ export {
} from '@/lib/billing/core/usage'
export * from '@/lib/billing/credits/balance'
export * from '@/lib/billing/credits/purchase'
export {
blockOrgMembers,
getOrgMemberIds,
unblockOrgMembers,
} from '@/lib/billing/organizations/membership'
export * from '@/lib/billing/subscriptions/utils'
export { canEditUsageLimit as canEditLimit } from '@/lib/billing/subscriptions/utils'
export * from '@/lib/billing/types'

View File

@@ -8,6 +8,7 @@ import {
} from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq } from 'drizzle-orm'
import { hasActiveSubscription } from '@/lib/billing'
import { getPlanPricing } from '@/lib/billing/core/billing'
import { syncUsageLimitsFromSubscription } from '@/lib/billing/core/usage'
@@ -159,6 +160,16 @@ export async function ensureOrganizationForTeamSubscription(
if (existingMembership.length > 0) {
const membership = existingMembership[0]
if (membership.role === 'owner' || membership.role === 'admin') {
// Check if org already has an active subscription (prevent duplicates)
if (await hasActiveSubscription(membership.organizationId)) {
logger.error('Organization already has an active subscription', {
userId,
organizationId: membership.organizationId,
newSubscriptionId: subscription.id,
})
throw new Error('Organization already has an active subscription')
}
logger.info('User already owns/admins an org, using it', {
userId,
organizationId: membership.organizationId,

View File

@@ -15,13 +15,70 @@ import {
userStats,
} from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, sql } from 'drizzle-orm'
import { and, eq, inArray, sql } from 'drizzle-orm'
import { syncUsageLimitsFromSubscription } from '@/lib/billing/core/usage'
import { requireStripeClient } from '@/lib/billing/stripe-client'
import { validateSeatAvailability } from '@/lib/billing/validation/seat-management'
const logger = createLogger('OrganizationMembership')
export type BillingBlockReason = 'payment_failed' | 'dispute'
/**
* Get all member user IDs for an organization
*/
export async function getOrgMemberIds(organizationId: string): Promise<string[]> {
const members = await db
.select({ userId: member.userId })
.from(member)
.where(eq(member.organizationId, organizationId))
return members.map((m) => m.userId)
}
/**
* Block all members of an organization for billing reasons
*/
export async function blockOrgMembers(
organizationId: string,
reason: BillingBlockReason
): Promise<number> {
const memberIds = await getOrgMemberIds(organizationId)
if (memberIds.length === 0) {
return 0
}
await db
.update(userStats)
.set({ billingBlocked: true, billingBlockedReason: reason })
.where(inArray(userStats.userId, memberIds))
return memberIds.length
}
/**
* Unblock all members of an organization blocked for a specific reason
* Only unblocks members blocked for the specified reason (not other reasons)
*/
export async function unblockOrgMembers(
organizationId: string,
reason: BillingBlockReason
): Promise<number> {
const memberIds = await getOrgMemberIds(organizationId)
if (memberIds.length === 0) {
return 0
}
await db
.update(userStats)
.set({ billingBlocked: false, billingBlockedReason: null })
.where(and(inArray(userStats.userId, memberIds), eq(userStats.billingBlockedReason, reason)))
return memberIds.length
}
export interface RestoreProResult {
restored: boolean
usageRestored: boolean

View File

@@ -1,8 +1,9 @@
import { db } from '@sim/db'
import { member, subscription, user, userStats } from '@sim/db/schema'
import { subscription, user, userStats } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq } from 'drizzle-orm'
import { eq } from 'drizzle-orm'
import type Stripe from 'stripe'
import { blockOrgMembers, unblockOrgMembers } from '@/lib/billing'
import { requireStripeClient } from '@/lib/billing/stripe-client'
const logger = createLogger('DisputeWebhooks')
@@ -57,36 +58,34 @@ export async function handleChargeDispute(event: Stripe.Event): Promise<void> {
if (subs.length > 0) {
const orgId = subs[0].referenceId
const memberCount = await blockOrgMembers(orgId, 'dispute')
const owners = await db
.select({ userId: member.userId })
.from(member)
.where(and(eq(member.organizationId, orgId), eq(member.role, 'owner')))
.limit(1)
if (owners.length > 0) {
await db
.update(userStats)
.set({ billingBlocked: true, billingBlockedReason: 'dispute' })
.where(eq(userStats.userId, owners[0].userId))
logger.warn('Blocked org owner due to dispute', {
if (memberCount > 0) {
logger.warn('Blocked all org members due to dispute', {
disputeId: dispute.id,
ownerId: owners[0].userId,
organizationId: orgId,
memberCount,
})
}
}
}
/**
* Handles charge.dispute.closed - unblocks user if dispute was won
* Handles charge.dispute.closed - unblocks user if dispute was won or warning closed
*
* Status meanings:
* - 'won': Merchant won, customer's chargeback denied → unblock
* - 'lost': Customer won, money refunded → stay blocked (they owe us)
* - 'warning_closed': Pre-dispute inquiry closed without chargeback → unblock (false alarm)
*/
export async function handleDisputeClosed(event: Stripe.Event): Promise<void> {
const dispute = event.data.object as Stripe.Dispute
if (dispute.status !== 'won') {
logger.info('Dispute not won, user remains blocked', {
// Only unblock if we won or the warning was closed without a full dispute
const shouldUnblock = dispute.status === 'won' || dispute.status === 'warning_closed'
if (!shouldUnblock) {
logger.info('Dispute resolved against us, user remains blocked', {
disputeId: dispute.id,
status: dispute.status,
})
@@ -111,14 +110,15 @@ export async function handleDisputeClosed(event: Stripe.Event): Promise<void> {
.set({ billingBlocked: false, billingBlockedReason: null })
.where(eq(userStats.userId, users[0].id))
logger.info('Unblocked user after winning dispute', {
logger.info('Unblocked user after dispute resolved in our favor', {
disputeId: dispute.id,
userId: users[0].id,
status: dispute.status,
})
return
}
// Find and unblock org owner (Team/Enterprise)
// Find and unblock all org members (Team/Enterprise) - consistent with payment success
const subs = await db
.select({ referenceId: subscription.referenceId })
.from(subscription)
@@ -127,24 +127,13 @@ export async function handleDisputeClosed(event: Stripe.Event): Promise<void> {
if (subs.length > 0) {
const orgId = subs[0].referenceId
const memberCount = await unblockOrgMembers(orgId, 'dispute')
const owners = await db
.select({ userId: member.userId })
.from(member)
.where(and(eq(member.organizationId, orgId), eq(member.role, 'owner')))
.limit(1)
if (owners.length > 0) {
await db
.update(userStats)
.set({ billingBlocked: false, billingBlockedReason: null })
.where(eq(userStats.userId, owners[0].userId))
logger.info('Unblocked org owner after winning dispute', {
disputeId: dispute.id,
ownerId: owners[0].userId,
organizationId: orgId,
})
}
logger.info('Unblocked all org members after dispute resolved in our favor', {
disputeId: dispute.id,
organizationId: orgId,
memberCount,
status: dispute.status,
})
}
}

View File

@@ -14,6 +14,7 @@ import { getEmailSubject, PaymentFailedEmail, renderCreditPurchaseEmail } from '
import { calculateSubscriptionOverage } from '@/lib/billing/core/billing'
import { addCredits, getCreditBalance, removeCredits } from '@/lib/billing/credits/balance'
import { setUsageLimitForCredits } from '@/lib/billing/credits/purchase'
import { blockOrgMembers, unblockOrgMembers } from '@/lib/billing/organizations/membership'
import { requireStripeClient } from '@/lib/billing/stripe-client'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { sendEmail } from '@/lib/messaging/email/mailer'
@@ -502,24 +503,7 @@ export async function handleInvoicePaymentSucceeded(event: Stripe.Event) {
}
if (sub.plan === 'team' || sub.plan === 'enterprise') {
const members = await db
.select({ userId: member.userId })
.from(member)
.where(eq(member.organizationId, sub.referenceId))
const memberIds = members.map((m) => m.userId)
if (memberIds.length > 0) {
// Only unblock users blocked for payment_failed, not disputes
await db
.update(userStats)
.set({ billingBlocked: false, billingBlockedReason: null })
.where(
and(
inArray(userStats.userId, memberIds),
eq(userStats.billingBlockedReason, 'payment_failed')
)
)
}
await unblockOrgMembers(sub.referenceId, 'payment_failed')
} else {
// Only unblock users blocked for payment_failed, not disputes
await db
@@ -616,21 +600,10 @@ export async function handleInvoicePaymentFailed(event: Stripe.Event) {
if (records.length > 0) {
const sub = records[0]
if (sub.plan === 'team' || sub.plan === 'enterprise') {
const members = await db
.select({ userId: member.userId })
.from(member)
.where(eq(member.organizationId, sub.referenceId))
const memberIds = members.map((m) => m.userId)
if (memberIds.length > 0) {
await db
.update(userStats)
.set({ billingBlocked: true, billingBlockedReason: 'payment_failed' })
.where(inArray(userStats.userId, memberIds))
}
const memberCount = await blockOrgMembers(sub.referenceId, 'payment_failed')
logger.info('Blocked team/enterprise members due to payment failure', {
organizationId: sub.referenceId,
memberCount: members.length,
memberCount,
isOverageInvoice,
})
} else {

View File

@@ -3,6 +3,7 @@ import { member, organization, subscription } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, ne } from 'drizzle-orm'
import { calculateSubscriptionOverage } from '@/lib/billing/core/billing'
import { hasActiveSubscription } from '@/lib/billing/core/subscription'
import { syncUsageLimitsFromSubscription } from '@/lib/billing/core/usage'
import { restoreUserProSubscription } from '@/lib/billing/organizations/membership'
import { requireStripeClient } from '@/lib/billing/stripe-client'
@@ -52,14 +53,37 @@ async function restoreMemberProSubscriptions(organizationId: string): Promise<nu
/**
* Cleanup organization when team/enterprise subscription is deleted.
* - Checks if other active subscriptions point to this org (skip deletion if so)
* - Restores member Pro subscriptions
* - Deletes the organization
* - Deletes the organization (only if no other active subs)
* - Syncs usage limits for former members (resets to free or Pro tier)
*/
async function cleanupOrganizationSubscription(organizationId: string): Promise<{
restoredProCount: number
membersSynced: number
organizationDeleted: boolean
}> {
// Check if other active subscriptions still point to this org
// Note: The subscription being deleted is already marked as 'canceled' by better-auth
// before this handler runs, so we only find truly active ones
if (await hasActiveSubscription(organizationId)) {
logger.info('Skipping organization deletion - other active subscriptions exist', {
organizationId,
})
// Still sync limits for members since this subscription was deleted
const memberUserIds = await db
.select({ userId: member.userId })
.from(member)
.where(eq(member.organizationId, organizationId))
for (const m of memberUserIds) {
await syncUsageLimitsFromSubscription(m.userId)
}
return { restoredProCount: 0, membersSynced: memberUserIds.length, organizationDeleted: false }
}
// Get member userIds before deletion (needed for limit syncing after org deletion)
const memberUserIds = await db
.select({ userId: member.userId })
@@ -75,7 +99,7 @@ async function cleanupOrganizationSubscription(organizationId: string): Promise<
await syncUsageLimitsFromSubscription(m.userId)
}
return { restoredProCount, membersSynced: memberUserIds.length }
return { restoredProCount, membersSynced: memberUserIds.length, organizationDeleted: true }
}
/**
@@ -172,15 +196,14 @@ export async function handleSubscriptionDeleted(subscription: {
referenceId: subscription.referenceId,
})
const { restoredProCount, membersSynced } = await cleanupOrganizationSubscription(
subscription.referenceId
)
const { restoredProCount, membersSynced, organizationDeleted } =
await cleanupOrganizationSubscription(subscription.referenceId)
logger.info('Successfully processed enterprise subscription cancellation', {
subscriptionId: subscription.id,
stripeSubscriptionId,
restoredProCount,
organizationDeleted: true,
organizationDeleted,
membersSynced,
})
return
@@ -297,7 +320,7 @@ export async function handleSubscriptionDeleted(subscription: {
const cleanup = await cleanupOrganizationSubscription(subscription.referenceId)
restoredProCount = cleanup.restoredProCount
membersSynced = cleanup.membersSynced
organizationDeleted = true
organizationDeleted = cleanup.organizationDeleted
} else if (subscription.plan === 'pro') {
await syncUsageLimitsFromSubscription(subscription.referenceId)
membersSynced = 1

View File

@@ -33,6 +33,7 @@ import type {
WorkflowExecutionSnapshot,
WorkflowState,
} from '@/lib/logs/types'
import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils'
export interface ToolCall {
name: string
@@ -503,7 +504,7 @@ export class ExecutionLogger implements IExecutionLoggerService {
}
try {
// Get the workflow record to get the userId
// Get the workflow record to get workspace and fallback userId
const [workflowRecord] = await db
.select()
.from(workflow)
@@ -515,7 +516,12 @@ export class ExecutionLogger implements IExecutionLoggerService {
return
}
const userId = workflowRecord.userId
let billingUserId: string | null = null
if (workflowRecord.workspaceId) {
billingUserId = await getWorkspaceBilledAccountUserId(workflowRecord.workspaceId)
}
const userId = billingUserId || workflowRecord.userId
const costToStore = costSummary.totalCost
const existing = await db.select().from(userStats).where(eq(userStats.userId, userId))