diff --git a/apps/sim/lib/billing/organizations/membership.ts b/apps/sim/lib/billing/organizations/membership.ts index fa9169cb2..5fee8bb5c 100644 --- a/apps/sim/lib/billing/organizations/membership.ts +++ b/apps/sim/lib/billing/organizations/membership.ts @@ -15,7 +15,7 @@ import { userStats, } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import { and, eq, inArray, sql } from 'drizzle-orm' +import { and, eq, inArray, isNull, ne, or, sql } from 'drizzle-orm' import { syncUsageLimitsFromSubscription } from '@/lib/billing/core/usage' import { requireStripeClient } from '@/lib/billing/stripe-client' import { validateSeatAvailability } from '@/lib/billing/validation/seat-management' @@ -38,6 +38,10 @@ export async function getOrgMemberIds(organizationId: string): Promise /** * Block all members of an organization for billing reasons + * Returns the number of members actually blocked + * + * Reason priority: dispute > payment_failed + * A payment_failed block won't overwrite an existing dispute block */ export async function blockOrgMembers( organizationId: string, @@ -49,17 +53,28 @@ export async function blockOrgMembers( return 0 } - await db + // Don't overwrite dispute blocks with payment_failed (dispute is higher priority) + const whereClause = + reason === 'payment_failed' + ? and( + inArray(userStats.userId, memberIds), + or(ne(userStats.billingBlockedReason, 'dispute'), isNull(userStats.billingBlockedReason)) + ) + : inArray(userStats.userId, memberIds) + + const result = await db .update(userStats) .set({ billingBlocked: true, billingBlockedReason: reason }) - .where(inArray(userStats.userId, memberIds)) + .where(whereClause) + .returning({ userId: userStats.userId }) - return memberIds.length + return result.length } /** * Unblock all members of an organization blocked for a specific reason * Only unblocks members blocked for the specified reason (not other reasons) + * Returns the number of members actually unblocked */ export async function unblockOrgMembers( organizationId: string, @@ -71,12 +86,13 @@ export async function unblockOrgMembers( return 0 } - await db + const result = await db .update(userStats) .set({ billingBlocked: false, billingBlockedReason: null }) .where(and(inArray(userStats.userId, memberIds), eq(userStats.billingBlockedReason, reason))) + .returning({ userId: userStats.userId }) - return memberIds.length + return result.length } export interface RestoreProResult { diff --git a/apps/sim/lib/billing/webhooks/disputes.ts b/apps/sim/lib/billing/webhooks/disputes.ts index 586c36a3f..647ad8a9c 100644 --- a/apps/sim/lib/billing/webhooks/disputes.ts +++ b/apps/sim/lib/billing/webhooks/disputes.ts @@ -1,7 +1,7 @@ import { db } from '@sim/db' import { subscription, user, userStats } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import { eq } from 'drizzle-orm' +import { and, eq } from 'drizzle-orm' import type Stripe from 'stripe' import { blockOrgMembers, unblockOrgMembers } from '@/lib/billing' import { requireStripeClient } from '@/lib/billing/stripe-client' @@ -97,7 +97,7 @@ export async function handleDisputeClosed(event: Stripe.Event): Promise { return } - // Find and unblock user (Pro plans) + // Find and unblock user (Pro plans) - only if blocked for dispute, not other reasons const users = await db .select({ id: user.id }) .from(user) @@ -108,7 +108,7 @@ export async function handleDisputeClosed(event: Stripe.Event): Promise { await db .update(userStats) .set({ billingBlocked: false, billingBlockedReason: null }) - .where(eq(userStats.userId, users[0].id)) + .where(and(eq(userStats.userId, users[0].id), eq(userStats.billingBlockedReason, 'dispute'))) logger.info('Unblocked user after dispute resolved in our favor', { disputeId: dispute.id, diff --git a/apps/sim/lib/billing/webhooks/invoices.ts b/apps/sim/lib/billing/webhooks/invoices.ts index d551ba645..2c1419631 100644 --- a/apps/sim/lib/billing/webhooks/invoices.ts +++ b/apps/sim/lib/billing/webhooks/invoices.ts @@ -8,7 +8,7 @@ import { userStats, } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import { and, eq, inArray } from 'drizzle-orm' +import { and, eq, inArray, isNull, ne, or } from 'drizzle-orm' import type Stripe from 'stripe' import { getEmailSubject, PaymentFailedEmail, renderCreditPurchaseEmail } from '@/components/emails' import { calculateSubscriptionOverage } from '@/lib/billing/core/billing' @@ -607,10 +607,19 @@ export async function handleInvoicePaymentFailed(event: Stripe.Event) { isOverageInvoice, }) } else { + // Don't overwrite dispute blocks (dispute > payment_failed priority) await db .update(userStats) .set({ billingBlocked: true, billingBlockedReason: 'payment_failed' }) - .where(eq(userStats.userId, sub.referenceId)) + .where( + and( + eq(userStats.userId, sub.referenceId), + or( + ne(userStats.billingBlockedReason, 'dispute'), + isNull(userStats.billingBlockedReason) + ) + ) + ) logger.info('Blocked user due to payment failure', { userId: sub.referenceId, isOverageInvoice,