mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-10 07:27:57 -05:00
feat(redis): added redis option for rate limiter, 10x speed improvement in rate limit checks & reduction of DB load (#2263)
* feat(redis): added redis option for rate limiter, 10x speed improvement in rate limit checks & reduction of DB load * ack PR comments * improvements
This commit is contained in:
@@ -4,9 +4,9 @@ import { checkServerSideUsageLimits } from '@/lib/billing'
|
||||
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
|
||||
import { getEffectiveCurrentPeriodCost } from '@/lib/billing/core/usage'
|
||||
import { getUserStorageLimit, getUserStorageUsage } from '@/lib/billing/storage'
|
||||
import { RateLimiter } from '@/lib/core/rate-limiter'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { createErrorResponse } from '@/app/api/workflows/utils'
|
||||
import { RateLimiter } from '@/services/queue'
|
||||
|
||||
const logger = createLogger('UsageLimitsAPI')
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { checkServerSideUsageLimits } from '@/lib/billing'
|
||||
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
|
||||
import { getEffectiveCurrentPeriodCost } from '@/lib/billing/core/usage'
|
||||
import { RateLimiter } from '@/services/queue'
|
||||
import { RateLimiter } from '@/lib/core/rate-limiter'
|
||||
|
||||
export interface UserLimits {
|
||||
workflowExecutionRateLimit: {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
|
||||
import { RateLimiter } from '@/lib/core/rate-limiter/rate-limiter'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { authenticateV1Request } from '@/app/api/v1/auth'
|
||||
import { RateLimiter } from '@/services/queue/RateLimiter'
|
||||
|
||||
const logger = createLogger('V1Middleware')
|
||||
const rateLimiter = new RateLimiter()
|
||||
|
||||
@@ -140,7 +140,7 @@ vi.mock('@/lib/workspaces/utils', async () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/services/queue', () => ({
|
||||
vi.mock('@/lib/core/rate-limiter', () => ({
|
||||
RateLimiter: vi.fn().mockImplementation(() => ({
|
||||
checkRateLimit: vi.fn().mockResolvedValue({
|
||||
allowed: true,
|
||||
|
||||
@@ -395,8 +395,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
triggerType: loggingTriggerType,
|
||||
executionId,
|
||||
requestId,
|
||||
checkRateLimit: false, // Manual executions bypass rate limits
|
||||
checkDeployment: !shouldUseDraftState, // Check deployment unless using draft
|
||||
checkDeployment: !shouldUseDraftState,
|
||||
loggingSession,
|
||||
})
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@ import { and, eq, isNull, lte, or, sql } from 'drizzle-orm'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
import { checkUsageStatus } from '@/lib/billing/calculations/usage-monitor'
|
||||
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
|
||||
import { RateLimiter } from '@/lib/core/rate-limiter'
|
||||
import { decryptSecret } from '@/lib/core/security/encryption'
|
||||
import { getBaseUrl } from '@/lib/core/utils/urls'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { TraceSpan, WorkflowExecutionLog } from '@/lib/logs/types'
|
||||
import { sendEmail } from '@/lib/messaging/email/mailer'
|
||||
import type { AlertConfig } from '@/lib/notifications/alert-rules'
|
||||
import { RateLimiter } from '@/services/queue'
|
||||
|
||||
const logger = createLogger('WorkspaceNotificationDelivery')
|
||||
|
||||
|
||||
7
apps/sim/lib/core/rate-limiter/index.ts
Normal file
7
apps/sim/lib/core/rate-limiter/index.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export { RateLimiter } from '@/lib/core/rate-limiter/rate-limiter'
|
||||
export type {
|
||||
RateLimitConfig,
|
||||
SubscriptionPlan,
|
||||
TriggerType,
|
||||
} from '@/lib/core/rate-limiter/types'
|
||||
export { RATE_LIMITS, RateLimitError } from '@/lib/core/rate-limiter/types'
|
||||
309
apps/sim/lib/core/rate-limiter/rate-limiter.test.ts
Normal file
309
apps/sim/lib/core/rate-limiter/rate-limiter.test.ts
Normal file
@@ -0,0 +1,309 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { RateLimiter } from '@/lib/core/rate-limiter/rate-limiter'
|
||||
import { MANUAL_EXECUTION_LIMIT, RATE_LIMITS } from '@/lib/core/rate-limiter/types'
|
||||
|
||||
vi.mock('@sim/db', () => ({
|
||||
db: {
|
||||
select: vi.fn(),
|
||||
insert: vi.fn(),
|
||||
update: vi.fn(),
|
||||
delete: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('drizzle-orm', () => ({
|
||||
eq: vi.fn((field, value) => ({ field, value })),
|
||||
sql: vi.fn((strings, ...values) => ({ sql: strings.join('?'), values })),
|
||||
and: vi.fn((...conditions) => ({ and: conditions })),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/core/config/redis', () => ({
|
||||
getRedisClient: vi.fn().mockReturnValue(null),
|
||||
}))
|
||||
|
||||
import { db } from '@sim/db'
|
||||
import { getRedisClient } from '@/lib/core/config/redis'
|
||||
|
||||
describe('RateLimiter', () => {
|
||||
const rateLimiter = new RateLimiter()
|
||||
const testUserId = 'test-user-123'
|
||||
const freeSubscription = { plan: 'free', referenceId: testUserId }
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(getRedisClient).mockReturnValue(null)
|
||||
})
|
||||
|
||||
describe('checkRateLimitWithSubscription', () => {
|
||||
it('should allow unlimited requests for manual trigger type', async () => {
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'manual',
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(result.resetAt).toBeInstanceOf(Date)
|
||||
expect(db.select).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should allow first API request for sync execution (DB fallback)', async () => {
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 1,
|
||||
asyncApiRequests: 0,
|
||||
apiEndpointRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 1)
|
||||
expect(result.resetAt).toBeInstanceOf(Date)
|
||||
})
|
||||
|
||||
it('should allow first API request for async execution (DB fallback)', async () => {
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 0,
|
||||
asyncApiRequests: 1,
|
||||
apiEndpointRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
true
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.asyncApiExecutionsPerMinute - 1)
|
||||
expect(result.resetAt).toBeInstanceOf(Date)
|
||||
})
|
||||
|
||||
it('should work for all trigger types except manual (DB fallback)', async () => {
|
||||
const triggerTypes = ['api', 'webhook', 'schedule', 'chat'] as const
|
||||
|
||||
for (const triggerType of triggerTypes) {
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 1,
|
||||
asyncApiRequests: 0,
|
||||
apiEndpointRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
triggerType,
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 1)
|
||||
}
|
||||
})
|
||||
|
||||
it('should use Redis when available', async () => {
|
||||
const mockRedis = {
|
||||
eval: vi.fn().mockResolvedValue(1), // Lua script returns count after INCR
|
||||
}
|
||||
vi.mocked(getRedisClient).mockReturnValue(mockRedis as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 1)
|
||||
expect(mockRedis.eval).toHaveBeenCalled()
|
||||
expect(db.select).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should deny requests when Redis rate limit exceeded', async () => {
|
||||
const mockRedis = {
|
||||
eval: vi.fn().mockResolvedValue(RATE_LIMITS.free.syncApiExecutionsPerMinute + 1),
|
||||
}
|
||||
vi.mocked(getRedisClient).mockReturnValue(mockRedis as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(false)
|
||||
expect(result.remaining).toBe(0)
|
||||
})
|
||||
|
||||
it('should fall back to DB when Redis fails', async () => {
|
||||
const mockRedis = {
|
||||
eval: vi.fn().mockRejectedValue(new Error('Redis connection failed')),
|
||||
}
|
||||
vi.mocked(getRedisClient).mockReturnValue(mockRedis as any)
|
||||
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 1,
|
||||
asyncApiRequests: 0,
|
||||
apiEndpointRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(db.select).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getRateLimitStatusWithSubscription', () => {
|
||||
it('should return unlimited for manual trigger type', async () => {
|
||||
const status = await rateLimiter.getRateLimitStatusWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'manual',
|
||||
false
|
||||
)
|
||||
|
||||
expect(status.used).toBe(0)
|
||||
expect(status.limit).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(status.remaining).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(status.resetAt).toBeInstanceOf(Date)
|
||||
})
|
||||
|
||||
it('should return sync API limits for API trigger type (DB fallback)', async () => {
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const status = await rateLimiter.getRateLimitStatusWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
false
|
||||
)
|
||||
|
||||
expect(status.used).toBe(0)
|
||||
expect(status.limit).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute)
|
||||
expect(status.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute)
|
||||
expect(status.resetAt).toBeInstanceOf(Date)
|
||||
})
|
||||
|
||||
it('should use Redis for status when available', async () => {
|
||||
const mockRedis = {
|
||||
get: vi.fn().mockResolvedValue('5'),
|
||||
}
|
||||
vi.mocked(getRedisClient).mockReturnValue(mockRedis as any)
|
||||
|
||||
const status = await rateLimiter.getRateLimitStatusWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
false
|
||||
)
|
||||
|
||||
expect(status.used).toBe(5)
|
||||
expect(status.limit).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute)
|
||||
expect(status.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 5)
|
||||
expect(mockRedis.get).toHaveBeenCalled()
|
||||
expect(db.select).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('resetRateLimit', () => {
|
||||
it('should delete rate limit record for user', async () => {
|
||||
vi.mocked(db.delete).mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue({}),
|
||||
} as any)
|
||||
|
||||
await rateLimiter.resetRateLimit(testUserId)
|
||||
|
||||
expect(db.delete).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,8 +1,8 @@
|
||||
import { db } from '@sim/db'
|
||||
import { userRateLimits } from '@sim/db/schema'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type Redis from 'ioredis'
|
||||
import { getRedisClient } from '@/lib/core/config/redis'
|
||||
import {
|
||||
MANUAL_EXECUTION_LIMIT,
|
||||
RATE_LIMIT_WINDOW_MS,
|
||||
@@ -10,7 +10,8 @@ import {
|
||||
type RateLimitCounterType,
|
||||
type SubscriptionPlan,
|
||||
type TriggerType,
|
||||
} from '@/services/queue/types'
|
||||
} from '@/lib/core/rate-limiter/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('RateLimiter')
|
||||
|
||||
@@ -88,6 +89,69 @@ export class RateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check rate limit using Redis (faster, single atomic operation)
|
||||
* Uses fixed window algorithm with INCR + EXPIRE
|
||||
*/
|
||||
private async checkRateLimitRedis(
|
||||
redis: Redis,
|
||||
rateLimitKey: string,
|
||||
counterType: RateLimitCounterType,
|
||||
limit: number
|
||||
): Promise<{ allowed: boolean; remaining: number; resetAt: Date }> {
|
||||
const windowMs = RATE_LIMIT_WINDOW_MS
|
||||
const windowKey = Math.floor(Date.now() / windowMs)
|
||||
const key = `ratelimit:${rateLimitKey}:${counterType}:${windowKey}`
|
||||
const ttlSeconds = Math.ceil(windowMs / 1000)
|
||||
|
||||
// Atomic increment + expire
|
||||
const count = (await redis.eval(
|
||||
'local c = redis.call("INCR", KEYS[1]) if c == 1 then redis.call("EXPIRE", KEYS[1], ARGV[1]) end return c',
|
||||
1,
|
||||
key,
|
||||
ttlSeconds
|
||||
)) as number
|
||||
|
||||
const resetAt = new Date((windowKey + 1) * windowMs)
|
||||
|
||||
if (count > limit) {
|
||||
logger.info(`Rate limit exceeded (Redis) - request ${count} > limit ${limit}`, {
|
||||
rateLimitKey,
|
||||
counterType,
|
||||
limit,
|
||||
count,
|
||||
})
|
||||
return { allowed: false, remaining: 0, resetAt }
|
||||
}
|
||||
|
||||
return { allowed: true, remaining: limit - count, resetAt }
|
||||
}
|
||||
|
||||
/**
|
||||
* Get rate limit status using Redis (read-only, doesn't increment)
|
||||
*/
|
||||
private async getRateLimitStatusRedis(
|
||||
redis: Redis,
|
||||
rateLimitKey: string,
|
||||
counterType: RateLimitCounterType,
|
||||
limit: number
|
||||
): Promise<{ used: number; limit: number; remaining: number; resetAt: Date }> {
|
||||
const windowMs = RATE_LIMIT_WINDOW_MS
|
||||
const windowKey = Math.floor(Date.now() / windowMs)
|
||||
const key = `ratelimit:${rateLimitKey}:${counterType}:${windowKey}`
|
||||
|
||||
const countStr = await redis.get(key)
|
||||
const used = countStr ? Number.parseInt(countStr, 10) : 0
|
||||
const resetAt = new Date((windowKey + 1) * windowMs)
|
||||
|
||||
return {
|
||||
used,
|
||||
limit,
|
||||
remaining: Math.max(0, limit - used),
|
||||
resetAt,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if user can execute a workflow with organization-aware rate limiting
|
||||
* Manual executions bypass rate limiting entirely
|
||||
@@ -114,6 +178,18 @@ export class RateLimiter {
|
||||
const counterType = this.getCounterType(triggerType, isAsync)
|
||||
const execLimit = this.getRateLimitForCounter(limit, counterType)
|
||||
|
||||
// Try Redis first for faster rate limiting
|
||||
const redis = getRedisClient()
|
||||
if (redis) {
|
||||
try {
|
||||
return await this.checkRateLimitRedis(redis, rateLimitKey, counterType, execLimit)
|
||||
} catch (error) {
|
||||
logger.warn('Redis rate limit check failed, falling back to DB:', { error })
|
||||
// Fall through to DB implementation
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to DB implementation
|
||||
const now = new Date()
|
||||
const windowStart = new Date(now.getTime() - RATE_LIMIT_WINDOW_MS)
|
||||
|
||||
@@ -273,21 +349,6 @@ export class RateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Legacy method - for backward compatibility
|
||||
* @deprecated Use checkRateLimitWithSubscription instead
|
||||
*/
|
||||
async checkRateLimit(
|
||||
userId: string,
|
||||
subscriptionPlan: SubscriptionPlan = 'free',
|
||||
triggerType: TriggerType = 'manual',
|
||||
isAsync = false
|
||||
): Promise<{ allowed: boolean; remaining: number; resetAt: Date }> {
|
||||
// For backward compatibility, fetch the subscription
|
||||
const subscription = await getHighestPrioritySubscription(userId)
|
||||
return this.checkRateLimitWithSubscription(userId, subscription, triggerType, isAsync)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get current rate limit status with organization awareness
|
||||
* Only applies to API executions
|
||||
@@ -315,6 +376,18 @@ export class RateLimiter {
|
||||
const counterType = this.getCounterType(triggerType, isAsync)
|
||||
const execLimit = this.getRateLimitForCounter(limit, counterType)
|
||||
|
||||
// Try Redis first for faster status check
|
||||
const redis = getRedisClient()
|
||||
if (redis) {
|
||||
try {
|
||||
return await this.getRateLimitStatusRedis(redis, rateLimitKey, counterType, execLimit)
|
||||
} catch (error) {
|
||||
logger.warn('Redis rate limit status check failed, falling back to DB:', { error })
|
||||
// Fall through to DB implementation
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to DB implementation
|
||||
const now = new Date()
|
||||
const windowStart = new Date(now.getTime() - RATE_LIMIT_WINDOW_MS)
|
||||
|
||||
@@ -355,21 +428,6 @@ export class RateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Legacy method - for backward compatibility
|
||||
* @deprecated Use getRateLimitStatusWithSubscription instead
|
||||
*/
|
||||
async getRateLimitStatus(
|
||||
userId: string,
|
||||
subscriptionPlan: SubscriptionPlan = 'free',
|
||||
triggerType: TriggerType = 'manual',
|
||||
isAsync = false
|
||||
): Promise<{ used: number; limit: number; remaining: number; resetAt: Date }> {
|
||||
// For backward compatibility, fetch the subscription
|
||||
const subscription = await getHighestPrioritySubscription(userId)
|
||||
return this.getRateLimitStatusWithSubscription(userId, subscription, triggerType, isAsync)
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset rate limit for a user or organization
|
||||
*/
|
||||
@@ -3,10 +3,10 @@ import { workflow } from '@sim/db/schema'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { checkServerSideUsageLimits } from '@/lib/billing/calculations/usage-monitor'
|
||||
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
|
||||
import { RateLimiter } from '@/lib/core/rate-limiter/rate-limiter'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { LoggingSession } from '@/lib/logs/execution/logging-session'
|
||||
import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils'
|
||||
import { RateLimiter } from '@/services/queue/RateLimiter'
|
||||
|
||||
const logger = createLogger('ExecutionPreprocessing')
|
||||
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { RateLimiter } from '@/services/queue/RateLimiter'
|
||||
import { MANUAL_EXECUTION_LIMIT, RATE_LIMITS } from '@/services/queue/types'
|
||||
|
||||
// Mock the database module
|
||||
vi.mock('@sim/db', () => ({
|
||||
db: {
|
||||
select: vi.fn(),
|
||||
insert: vi.fn(),
|
||||
update: vi.fn(),
|
||||
delete: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock drizzle-orm
|
||||
vi.mock('drizzle-orm', () => ({
|
||||
eq: vi.fn((field, value) => ({ field, value })),
|
||||
sql: vi.fn((strings, ...values) => ({ sql: strings.join('?'), values })),
|
||||
and: vi.fn((...conditions) => ({ and: conditions })),
|
||||
}))
|
||||
|
||||
// Mock getHighestPrioritySubscription
|
||||
vi.mock('@/lib/billing/core/subscription', () => ({
|
||||
getHighestPrioritySubscription: vi.fn().mockResolvedValue(null),
|
||||
}))
|
||||
|
||||
import { db } from '@sim/db'
|
||||
|
||||
describe('RateLimiter', () => {
|
||||
const rateLimiter = new RateLimiter()
|
||||
const testUserId = 'test-user-123'
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('checkRateLimit', () => {
|
||||
it('should allow unlimited requests for manual trigger type', async () => {
|
||||
const result = await rateLimiter.checkRateLimit(testUserId, 'free', 'manual', false)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(result.resetAt).toBeInstanceOf(Date)
|
||||
expect(db.select).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should allow first API request for sync execution', async () => {
|
||||
// Mock select to return empty array (no existing record)
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]), // No existing record
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
// Mock insert to return the expected structure
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 1,
|
||||
asyncApiRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimit(testUserId, 'free', 'api', false)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 1)
|
||||
expect(result.resetAt).toBeInstanceOf(Date)
|
||||
})
|
||||
|
||||
it('should allow first API request for async execution', async () => {
|
||||
// Mock select to return empty array (no existing record)
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]), // No existing record
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
// Mock insert to return the expected structure
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 0,
|
||||
asyncApiRequests: 1,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimit(testUserId, 'free', 'api', true)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.asyncApiExecutionsPerMinute - 1)
|
||||
expect(result.resetAt).toBeInstanceOf(Date)
|
||||
})
|
||||
|
||||
it('should work for all trigger types except manual', async () => {
|
||||
const triggerTypes = ['api', 'webhook', 'schedule', 'chat'] as const
|
||||
|
||||
for (const triggerType of triggerTypes) {
|
||||
// Mock select to return empty array (no existing record)
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]), // No existing record
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
// Mock insert to return the expected structure
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 1,
|
||||
asyncApiRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimit(testUserId, 'free', triggerType, false)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 1)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('getRateLimitStatus', () => {
|
||||
it('should return unlimited for manual trigger type', async () => {
|
||||
const status = await rateLimiter.getRateLimitStatus(testUserId, 'free', 'manual', false)
|
||||
|
||||
expect(status.used).toBe(0)
|
||||
expect(status.limit).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(status.remaining).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(status.resetAt).toBeInstanceOf(Date)
|
||||
})
|
||||
|
||||
it('should return sync API limits for API trigger type', async () => {
|
||||
const mockSelect = vi.fn().mockReturnThis()
|
||||
const mockFrom = vi.fn().mockReturnThis()
|
||||
const mockWhere = vi.fn().mockReturnThis()
|
||||
const mockLimit = vi.fn().mockResolvedValue([])
|
||||
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: mockFrom,
|
||||
where: mockWhere,
|
||||
limit: mockLimit,
|
||||
} as any)
|
||||
|
||||
const status = await rateLimiter.getRateLimitStatus(testUserId, 'free', 'api', false)
|
||||
|
||||
expect(status.used).toBe(0)
|
||||
expect(status.limit).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute)
|
||||
expect(status.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute)
|
||||
expect(status.resetAt).toBeInstanceOf(Date)
|
||||
})
|
||||
})
|
||||
|
||||
describe('resetRateLimit', () => {
|
||||
it('should delete rate limit record for user', async () => {
|
||||
vi.mocked(db.delete).mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue({}),
|
||||
} as any)
|
||||
|
||||
await rateLimiter.resetRateLimit(testUserId)
|
||||
|
||||
expect(db.delete).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,7 +0,0 @@
|
||||
export { RateLimiter } from '@/services/queue/RateLimiter'
|
||||
export type {
|
||||
RateLimitConfig,
|
||||
SubscriptionPlan,
|
||||
TriggerType,
|
||||
} from '@/services/queue/types'
|
||||
export { RATE_LIMITS, RateLimitError } from '@/services/queue/types'
|
||||
Reference in New Issue
Block a user