fix(ratelimits): enterprise and team checks should be pooled limit (#1255)

* fix(ratelimits): enterprise and team checks should be pooled limit"

* fix

* fix dynamic imports

* fix tests"
;
This commit is contained in:
Vikhyath Mondreti
2025-09-04 21:44:56 -07:00
committed by GitHub
parent 8668622d66
commit 864622c1dc
11 changed files with 6192 additions and 81 deletions

View File

@@ -4,6 +4,7 @@ import { NextResponse } from 'next/server'
import { v4 as uuidv4 } from 'uuid'
import { z } from 'zod'
import { checkServerSideUsageLimits } from '@/lib/billing'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { getPersonalAndWorkspaceEnv } from '@/lib/environment/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { LoggingSession } from '@/lib/logs/execution/logging-session'
@@ -18,7 +19,7 @@ import { decryptSecret } from '@/lib/utils'
import { loadWorkflowFromNormalizedTables } from '@/lib/workflows/db-helpers'
import { updateWorkflowRunCounts } from '@/lib/workflows/utils'
import { db } from '@/db'
import { subscription, userStats, workflow, workflowSchedule } from '@/db/schema'
import { userStats, workflow, workflowSchedule } from '@/db/schema'
import { Executor } from '@/executor'
import { Serializer } from '@/serializer'
import { RateLimiter } from '@/services/queue'
@@ -108,19 +109,15 @@ export async function GET() {
continue
}
// Check rate limits for scheduled execution
const [subscriptionRecord] = await db
.select({ plan: subscription.plan })
.from(subscription)
.where(eq(subscription.referenceId, workflowRecord.userId))
.limit(1)
// Check rate limits for scheduled execution (checks both personal and org subscriptions)
const userSubscription = await getHighestPrioritySubscription(workflowRecord.userId)
const subscriptionPlan = (subscriptionRecord?.plan || 'free') as SubscriptionPlan
const subscriptionPlan = (userSubscription?.plan || 'free') as SubscriptionPlan
const rateLimiter = new RateLimiter()
const rateLimitCheck = await rateLimiter.checkRateLimit(
const rateLimitCheck = await rateLimiter.checkRateLimitWithSubscription(
workflowRecord.userId,
subscriptionPlan,
userSubscription,
'schedule',
false // schedules are always sync
)

View File

@@ -1,10 +1,11 @@
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { createLogger } from '@/lib/logs/console/logger'
import { createErrorResponse } from '@/app/api/workflows/utils'
import { db } from '@/db'
import { apiKey as apiKeyTable, subscription } from '@/db/schema'
import { apiKey as apiKeyTable } from '@/db/schema'
import { RateLimiter } from '@/services/queue'
const logger = createLogger('RateLimitAPI')
@@ -33,31 +34,22 @@ export async function GET(request: NextRequest) {
return createErrorResponse('Authentication required', 401)
}
const [subscriptionRecord] = await db
.select({ plan: subscription.plan })
.from(subscription)
.where(eq(subscription.referenceId, authenticatedUserId))
.limit(1)
const subscriptionPlan = (subscriptionRecord?.plan || 'free') as
| 'free'
| 'pro'
| 'team'
| 'enterprise'
// Get user subscription (checks both personal and org subscriptions)
const userSubscription = await getHighestPrioritySubscription(authenticatedUserId)
const rateLimiter = new RateLimiter()
const isApiAuth = !session?.user?.id
const triggerType = isApiAuth ? 'api' : 'manual'
const syncStatus = await rateLimiter.getRateLimitStatus(
const syncStatus = await rateLimiter.getRateLimitStatusWithSubscription(
authenticatedUserId,
subscriptionPlan,
userSubscription,
triggerType,
false
)
const asyncStatus = await rateLimiter.getRateLimitStatus(
const asyncStatus = await rateLimiter.getRateLimitStatusWithSubscription(
authenticatedUserId,
subscriptionPlan,
userSubscription,
triggerType,
true
)

View File

@@ -2,6 +2,7 @@ import { tasks } from '@trigger.dev/sdk'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { checkServerSideUsageLimits } from '@/lib/billing'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { env, isTruthy } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import {
@@ -11,7 +12,7 @@ import {
} from '@/lib/webhooks/utils'
import { executeWebhookJob } from '@/background/webhook-execution'
import { db } from '@/db'
import { subscription, webhook, workflow } from '@/db/schema'
import { webhook, workflow } from '@/db/schema'
import { RateLimiter } from '@/services/queue'
import type { SubscriptionPlan } from '@/services/queue/types'
@@ -249,21 +250,17 @@ export async function POST(
// --- PHASE 3: Rate limiting for webhook execution ---
let isEnterprise = false
try {
// Get user subscription for rate limiting
const [subscriptionRecord] = await db
.select({ plan: subscription.plan })
.from(subscription)
.where(eq(subscription.referenceId, foundWorkflow.userId))
.limit(1)
// Get user subscription for rate limiting (checks both personal and org subscriptions)
const userSubscription = await getHighestPrioritySubscription(foundWorkflow.userId)
const subscriptionPlan = (subscriptionRecord?.plan || 'free') as SubscriptionPlan
const subscriptionPlan = (userSubscription?.plan || 'free') as SubscriptionPlan
isEnterprise = subscriptionPlan === 'enterprise'
// Check async rate limits (webhooks are processed asynchronously)
const rateLimiter = new RateLimiter()
const rateLimitCheck = await rateLimiter.checkRateLimit(
const rateLimitCheck = await rateLimiter.checkRateLimitWithSubscription(
foundWorkflow.userId,
subscriptionPlan,
userSubscription,
'webhook',
true // isAsync = true for webhook execution
)

View File

@@ -46,6 +46,11 @@ describe('Workflow Execution API Route', () => {
remaining: 10,
resetAt: new Date(),
}),
checkRateLimitWithSubscription: vi.fn().mockResolvedValue({
allowed: true,
remaining: 10,
resetAt: new Date(),
}),
})),
RateLimitError: class RateLimitError extends Error {
constructor(
@@ -66,6 +71,13 @@ describe('Workflow Execution API Route', () => {
}),
}))
vi.doMock('@/lib/billing/core/subscription', () => ({
getHighestPrioritySubscription: vi.fn().mockResolvedValue({
plan: 'free',
referenceId: 'user-id',
}),
}))
vi.doMock('@/db/schema', () => ({
subscription: {
plan: 'plan',

View File

@@ -5,6 +5,7 @@ import { v4 as uuidv4 } from 'uuid'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { checkServerSideUsageLimits } from '@/lib/billing'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { getPersonalAndWorkspaceEnv } from '@/lib/environment/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { LoggingSession } from '@/lib/logs/execution/logging-session'
@@ -19,7 +20,7 @@ import {
import { validateWorkflowAccess } from '@/app/api/workflows/middleware'
import { createErrorResponse, createSuccessResponse } from '@/app/api/workflows/utils'
import { db } from '@/db'
import { subscription, userStats } from '@/db/schema'
import { userStats } from '@/db/schema'
import { Executor } from '@/executor'
import { Serializer } from '@/serializer'
import {
@@ -374,19 +375,15 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
try {
// Check rate limits BEFORE entering queue for GET requests
if (triggerType === 'api') {
// Get user subscription
const [subscriptionRecord] = await db
.select({ plan: subscription.plan })
.from(subscription)
.where(eq(subscription.referenceId, validation.workflow.userId))
.limit(1)
// Get user subscription (checks both personal and org subscriptions)
const userSubscription = await getHighestPrioritySubscription(validation.workflow.userId)
const subscriptionPlan = (subscriptionRecord?.plan || 'free') as SubscriptionPlan
const subscriptionPlan = (userSubscription?.plan || 'free') as SubscriptionPlan
const rateLimiter = new RateLimiter()
const rateLimitCheck = await rateLimiter.checkRateLimit(
const rateLimitCheck = await rateLimiter.checkRateLimitWithSubscription(
validation.workflow.userId,
subscriptionPlan,
userSubscription,
triggerType,
false // isAsync = false for sync calls
)
@@ -505,20 +502,17 @@ export async function POST(
return createErrorResponse('Authentication required', 401)
}
const [subscriptionRecord] = await db
.select({ plan: subscription.plan })
.from(subscription)
.where(eq(subscription.referenceId, authenticatedUserId))
.limit(1)
// Get user subscription (checks both personal and org subscriptions)
const userSubscription = await getHighestPrioritySubscription(authenticatedUserId)
const subscriptionPlan = (subscriptionRecord?.plan || 'free') as SubscriptionPlan
const subscriptionPlan = (userSubscription?.plan || 'free') as SubscriptionPlan
if (isAsync) {
try {
const rateLimiter = new RateLimiter()
const rateLimitCheck = await rateLimiter.checkRateLimit(
const rateLimitCheck = await rateLimiter.checkRateLimitWithSubscription(
authenticatedUserId,
subscriptionPlan,
userSubscription,
'api',
true // isAsync = true
)
@@ -580,9 +574,9 @@ export async function POST(
try {
const rateLimiter = new RateLimiter()
const rateLimitCheck = await rateLimiter.checkRateLimit(
const rateLimitCheck = await rateLimiter.checkRateLimitWithSubscription(
authenticatedUserId,
subscriptionPlan,
userSubscription,
triggerType,
false // isAsync = false for sync calls
)

View File

@@ -0,0 +1,2 @@
ALTER TABLE "user_rate_limits" RENAME COLUMN "user_id" TO "reference_id";--> statement-breakpoint
ALTER TABLE "user_rate_limits" DROP CONSTRAINT "user_rate_limits_user_id_user_id_fk";

File diff suppressed because it is too large Load Diff

View File

@@ -582,6 +582,13 @@
"when": 1756768177306,
"tag": "0083_ambiguous_dreadnoughts",
"breakpoints": true
},
{
"idx": 84,
"version": "7",
"when": 1757046301281,
"tag": "0084_even_lockheed",
"breakpoints": true
}
]
}

View File

@@ -531,9 +531,7 @@ export const subscription = pgTable(
)
export const userRateLimits = pgTable('user_rate_limits', {
userId: text('user_id')
.primaryKey()
.references(() => user.id, { onDelete: 'cascade' }),
referenceId: text('reference_id').primaryKey(), // Can be userId or organizationId for pooling
syncApiRequests: integer('sync_api_requests').notNull().default(0), // Sync API requests counter
asyncApiRequests: integer('async_api_requests').notNull().default(0), // Async API requests counter
windowStart: timestamp('window_start').notNull().defaultNow(),

View File

@@ -19,6 +19,11 @@ vi.mock('drizzle-orm', () => ({
and: vi.fn((...conditions) => ({ and: conditions })),
}))
// Mock getHighestPrioritySubscription
vi.mock('@/lib/billing/core/subscription', () => ({
getHighestPrioritySubscription: vi.fn().mockResolvedValue(null),
}))
import { db } from '@/db'
describe('RateLimiter', () => {

View File

@@ -1,4 +1,5 @@
import { eq, sql } from 'drizzle-orm'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { userRateLimits } from '@/db/schema'
@@ -12,14 +13,43 @@ import {
const logger = createLogger('RateLimiter')
interface SubscriptionInfo {
plan: string
referenceId: string
}
export class RateLimiter {
/**
* Check if user can execute a workflow
* Determine the rate limit key based on subscription
* For team/enterprise plans via organization, use the organization ID
* For direct user subscriptions (including direct team), use the user ID
*/
private getRateLimitKey(userId: string, subscription: SubscriptionInfo | null): string {
if (!subscription) {
return userId
}
const plan = subscription.plan as SubscriptionPlan
// Check if this is an organization subscription (referenceId !== userId)
// If referenceId === userId, it's a direct user subscription
if ((plan === 'team' || plan === 'enterprise') && subscription.referenceId !== userId) {
// This is an organization subscription
// All organization members share the same rate limit pool
return subscription.referenceId
}
// For direct user subscriptions (free/pro/team/enterprise where referenceId === userId)
return userId
}
/**
* Check if user can execute a workflow with organization-aware rate limiting
* Manual executions bypass rate limiting entirely
*/
async checkRateLimit(
async checkRateLimitWithSubscription(
userId: string,
subscriptionPlan: SubscriptionPlan = 'free',
subscription: SubscriptionInfo | null,
triggerType: TriggerType = 'manual',
isAsync = false
): Promise<{ allowed: boolean; remaining: number; resetAt: Date }> {
@@ -32,6 +62,9 @@ export class RateLimiter {
}
}
const subscriptionPlan = (subscription?.plan || 'free') as SubscriptionPlan
const rateLimitKey = this.getRateLimitKey(userId, subscription)
const limit = RATE_LIMITS[subscriptionPlan]
const execLimit = isAsync
? limit.asyncApiExecutionsPerMinute
@@ -40,11 +73,11 @@ export class RateLimiter {
const now = new Date()
const windowStart = new Date(now.getTime() - RATE_LIMIT_WINDOW_MS)
// Get or create rate limit record
// Get or create rate limit record using the rate limit key
const [rateLimitRecord] = await db
.select()
.from(userRateLimits)
.where(eq(userRateLimits.userId, userId))
.where(eq(userRateLimits.referenceId, rateLimitKey))
.limit(1)
if (!rateLimitRecord || new Date(rateLimitRecord.windowStart) < windowStart) {
@@ -52,7 +85,7 @@ export class RateLimiter {
const result = await db
.insert(userRateLimits)
.values({
userId,
referenceId: rateLimitKey,
syncApiRequests: isAsync ? 0 : 1,
asyncApiRequests: isAsync ? 1 : 0,
windowStart: now,
@@ -60,7 +93,7 @@ export class RateLimiter {
isRateLimited: false,
})
.onConflictDoUpdate({
target: userRateLimits.userId,
target: userRateLimits.referenceId,
set: {
// Only reset if window is still expired (avoid race condition)
syncApiRequests: sql`CASE WHEN ${userRateLimits.windowStart} < ${windowStart.toISOString()} THEN ${isAsync ? 0 : 1} ELSE ${userRateLimits.syncApiRequests} + ${isAsync ? 0 : 1} END`,
@@ -94,7 +127,20 @@ export class RateLimiter {
isRateLimited: true,
rateLimitResetAt: resetAt,
})
.where(eq(userRateLimits.userId, userId))
.where(eq(userRateLimits.referenceId, rateLimitKey))
logger.info(
`Rate limit exceeded - request ${actualCount} > limit ${execLimit} for ${
rateLimitKey === userId ? `user ${userId}` : `organization ${rateLimitKey}`
}`,
{
execLimit,
isAsync,
actualCount,
rateLimitKey,
plan: subscriptionPlan,
}
)
return {
allowed: false,
@@ -119,7 +165,7 @@ export class RateLimiter {
: { syncApiRequests: sql`${userRateLimits.syncApiRequests} + 1` }),
lastRequestAt: now,
})
.where(eq(userRateLimits.userId, userId))
.where(eq(userRateLimits.referenceId, rateLimitKey))
.returning({
asyncApiRequests: userRateLimits.asyncApiRequests,
syncApiRequests: userRateLimits.syncApiRequests,
@@ -137,11 +183,15 @@ export class RateLimiter {
)
logger.info(
`Rate limit exceeded - request ${actualNewRequests} > limit ${execLimit} for user ${userId}`,
`Rate limit exceeded - request ${actualNewRequests} > limit ${execLimit} for ${
rateLimitKey === userId ? `user ${userId}` : `organization ${rateLimitKey}`
}`,
{
execLimit,
isAsync,
actualNewRequests,
rateLimitKey,
plan: subscriptionPlan,
}
)
@@ -152,7 +202,7 @@ export class RateLimiter {
isRateLimited: true,
rateLimitResetAt: resetAt,
})
.where(eq(userRateLimits.userId, userId))
.where(eq(userRateLimits.referenceId, rateLimitKey))
return {
allowed: false,
@@ -178,14 +228,29 @@ export class RateLimiter {
}
/**
* Get current rate limit status for user
* Only applies to API executions
* Legacy method - for backward compatibility
* @deprecated Use checkRateLimitWithSubscription instead
*/
async getRateLimitStatus(
async checkRateLimit(
userId: string,
subscriptionPlan: SubscriptionPlan = 'free',
triggerType: TriggerType = 'manual',
isAsync = false
): Promise<{ allowed: boolean; remaining: number; resetAt: Date }> {
// For backward compatibility, fetch the subscription
const subscription = await getHighestPrioritySubscription(userId)
return this.checkRateLimitWithSubscription(userId, subscription, triggerType, isAsync)
}
/**
* Get current rate limit status with organization awareness
* Only applies to API executions
*/
async getRateLimitStatusWithSubscription(
userId: string,
subscription: SubscriptionInfo | null,
triggerType: TriggerType = 'manual',
isAsync = false
): Promise<{ used: number; limit: number; remaining: number; resetAt: Date }> {
try {
if (triggerType === 'manual') {
@@ -197,6 +262,9 @@ export class RateLimiter {
}
}
const subscriptionPlan = (subscription?.plan || 'free') as SubscriptionPlan
const rateLimitKey = this.getRateLimitKey(userId, subscription)
const limit = RATE_LIMITS[subscriptionPlan]
const execLimit = isAsync
? limit.asyncApiExecutionsPerMinute
@@ -207,7 +275,7 @@ export class RateLimiter {
const [rateLimitRecord] = await db
.select()
.from(userRateLimits)
.where(eq(userRateLimits.userId, userId))
.where(eq(userRateLimits.referenceId, rateLimitKey))
.limit(1)
if (!rateLimitRecord || new Date(rateLimitRecord.windowStart) < windowStart) {
@@ -229,8 +297,9 @@ export class RateLimiter {
} catch (error) {
logger.error('Error getting rate limit status:', error)
const execLimit = isAsync
? RATE_LIMITS[subscriptionPlan].asyncApiExecutionsPerMinute
: RATE_LIMITS[subscriptionPlan].syncApiExecutionsPerMinute
? RATE_LIMITS[(subscription?.plan || 'free') as SubscriptionPlan]
.asyncApiExecutionsPerMinute
: RATE_LIMITS[(subscription?.plan || 'free') as SubscriptionPlan].syncApiExecutionsPerMinute
return {
used: 0,
limit: execLimit,
@@ -241,13 +310,27 @@ export class RateLimiter {
}
/**
* Reset rate limit for user (admin action)
* Legacy method - for backward compatibility
* @deprecated Use getRateLimitStatusWithSubscription instead
*/
async resetRateLimit(userId: string): Promise<void> {
try {
await db.delete(userRateLimits).where(eq(userRateLimits.userId, userId))
async getRateLimitStatus(
userId: string,
subscriptionPlan: SubscriptionPlan = 'free',
triggerType: TriggerType = 'manual',
isAsync = false
): Promise<{ used: number; limit: number; remaining: number; resetAt: Date }> {
// For backward compatibility, fetch the subscription
const subscription = await getHighestPrioritySubscription(userId)
return this.getRateLimitStatusWithSubscription(userId, subscription, triggerType, isAsync)
}
logger.info(`Reset rate limit for user ${userId}`)
/**
* Reset rate limit for a user or organization
*/
async resetRateLimit(rateLimitKey: string): Promise<void> {
try {
await db.delete(userRateLimits).where(eq(userRateLimits.referenceId, rateLimitKey))
logger.info(`Reset rate limit for ${rateLimitKey}`)
} catch (error) {
logger.error('Error resetting rate limit:', error)
throw error