mirror of https://github.com/simstudioai/sim.git

feat(rate-limiter): token bucket algorithm (#2270)

* fix(ratelimit): make deployed chat rate limited
* improvement(rate-limiter): use token bucket algo
* update docs
* fix
* fix type
* fix db rate limiter
* address greptile comments

parent 22abf98835
commit aea32d423f
@@ -27,14 +27,16 @@ All API responses include information about your workflow execution limits and u
"limits": {
"workflowExecutionRateLimit": {
"sync": {
"limit": 60, // Max sync workflow executions per minute
"remaining": 58, // Remaining sync workflow executions
"resetAt": "..." // When the window resets
"requestsPerMinute": 60, // Sustained rate limit per minute
"maxBurst": 120, // Maximum burst capacity
"remaining": 118, // Current tokens available (up to maxBurst)
"resetAt": "..." // When tokens next refill
},
"async": {
"limit": 60, // Max async workflow executions per minute
"remaining": 59, // Remaining async workflow executions
"resetAt": "..." // When the window resets
"requestsPerMinute": 200, // Sustained rate limit per minute
"maxBurst": 400, // Maximum burst capacity
"remaining": 398, // Current tokens available
"resetAt": "..." // When tokens next refill
}
},
"usage": {
@@ -46,7 +48,7 @@ All API responses include information about your workflow execution limits and u
}
```

**Note:** The rate limits in the response body are for workflow executions. The rate limits for calling this API endpoint are in the response headers (`X-RateLimit-*`).
**Note:** Rate limits use a token bucket algorithm. `remaining` can exceed `requestsPerMinute` up to `maxBurst` when you haven't used your full allowance recently, allowing for burst traffic. The rate limits in the response body are for workflow executions. The rate limits for calling this API endpoint are in the response headers (`X-RateLimit-*`).
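
To make the burst behaviour concrete, here is a small illustration of the refill arithmetic these fields imply (a sketch, not the service's implementation; the one-minute refill interval is an assumption):

```typescript
// Sketch of the arithmetic behind `remaining`, `requestsPerMinute`, and `maxBurst`.
// Assumes tokens refill once per minute; the service may use a different interval.
function availableTokens(
  tokensAfterLastRequest: number,
  idleMinutes: number,
  requestsPerMinute: number,
  maxBurst: number
): number {
  return Math.min(maxBurst, tokensAfterLastRequest + idleMinutes * requestsPerMinute)
}

// A sync caller (60/min, burst 120) who stopped at 58 tokens is back to
// min(120, 58 + 60) = 118 tokens after one idle minute.
availableTokens(58, 1, 60, 120) // 118
```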

### Query Logs

@@ -108,13 +110,15 @@ Query workflow execution logs with extensive filtering options.
"limits": {
"workflowExecutionRateLimit": {
"sync": {
"limit": 60,
"remaining": 58,
"requestsPerMinute": 60,
"maxBurst": 120,
"remaining": 118,
"resetAt": "2025-01-01T12:35:56.789Z"
},
"async": {
"limit": 60,
"remaining": 59,
"requestsPerMinute": 200,
"maxBurst": 400,
"remaining": 398,
"resetAt": "2025-01-01T12:35:56.789Z"
}
},
@@ -184,13 +188,15 @@ Retrieve detailed information about a specific log entry.
"limits": {
"workflowExecutionRateLimit": {
"sync": {
"limit": 60,
"remaining": 58,
"requestsPerMinute": 60,
"maxBurst": 120,
"remaining": 118,
"resetAt": "2025-01-01T12:35:56.789Z"
},
"async": {
"limit": 60,
"remaining": 59,
"requestsPerMinute": 200,
"maxBurst": 400,
"remaining": 398,
"resetAt": "2025-01-01T12:35:56.789Z"
}
},
@@ -467,17 +473,25 @@ Failed webhook deliveries are retried with exponential backoff and jitter:

## Rate Limiting

The API implements rate limiting to ensure fair usage:
The API uses a **token bucket algorithm** for rate limiting, providing fair usage while allowing burst traffic:

- **Free plan**: 10 requests per minute
- **Pro plan**: 30 requests per minute
- **Team plan**: 60 requests per minute
- **Enterprise plan**: Custom limits
| Plan | Requests/Minute | Burst Capacity |
|------|-----------------|----------------|
| Free | 10 | 20 |
| Pro | 30 | 60 |
| Team | 60 | 120 |
| Enterprise | 120 | 240 |

**How it works:**
- Tokens refill at `requestsPerMinute` rate
- You can accumulate up to `maxBurst` tokens when idle
- Each request consumes 1 token
- Burst capacity allows handling traffic spikes
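
Concretely, the consume-and-refill step described by these bullets might look like the following sketch (illustrative only; the bucket shape and the one-minute refill interval are assumptions, not the service's actual code):

```typescript
// Illustrative token-bucket consume step, mirroring the bullets above.
interface Bucket {
  tokens: number        // current tokens, never above maxBurst
  lastRefillAt: number  // ms timestamp of the last refill
}

function tryConsume(
  bucket: Bucket,
  now: number,
  requestsPerMinute: number,
  maxBurst: number
): boolean {
  // Refill: add `requestsPerMinute` tokens per elapsed minute, capped at `maxBurst`.
  const minutesElapsed = Math.floor((now - bucket.lastRefillAt) / 60_000)
  if (minutesElapsed > 0) {
    bucket.tokens = Math.min(maxBurst, bucket.tokens + minutesElapsed * requestsPerMinute)
    bucket.lastRefillAt += minutesElapsed * 60_000
  }
  // Each request consumes one token; deny when the bucket is empty.
  if (bucket.tokens < 1) return false
  bucket.tokens -= 1
  return true
}
```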

Rate limit information is included in response headers:
- `X-RateLimit-Limit`: Maximum requests per window
- `X-RateLimit-Remaining`: Requests remaining in current window
- `X-RateLimit-Reset`: ISO timestamp when the window resets
- `X-RateLimit-Limit`: Requests per minute (refill rate)
- `X-RateLimit-Remaining`: Current tokens available
- `X-RateLimit-Reset`: ISO timestamp when tokens next refill
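
A client can read these headers to back off cleanly when it receives a 429. The sketch below assumes the `X-API-Key` authentication header used elsewhere in these docs; the URL is a placeholder:

```typescript
// Sketch of a client that honours the rate-limit headers (URL and key are placeholders).
async function callWithBackoff(url: string, apiKey: string): Promise<Response> {
  const res = await fetch(url, { headers: { 'X-API-Key': apiKey } })
  if (res.status !== 429) return res

  // Prefer Retry-After (seconds); otherwise wait until the reset timestamp.
  const retryAfter = res.headers.get('Retry-After')
  const resetAt = res.headers.get('X-RateLimit-Reset')
  const waitMs = retryAfter
    ? Number(retryAfter) * 1000
    : Math.max(0, new Date(resetAt ?? Date.now()).getTime() - Date.now())

  await new Promise((resolve) => setTimeout(resolve, waitMs))
  return fetch(url, { headers: { 'X-API-Key': apiKey } })
}
```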

## Example: Polling for New Logs

@@ -143,8 +143,20 @@ curl -X GET -H "X-API-Key: YOUR_API_KEY" -H "Content-Type: application/json" htt
{
"success": true,
"rateLimit": {
"sync": { "isLimited": false, "limit": 10, "remaining": 10, "resetAt": "2025-09-08T22:51:55.999Z" },
"async": { "isLimited": false, "limit": 50, "remaining": 50, "resetAt": "2025-09-08T22:51:56.155Z" },
"sync": {
"isLimited": false,
"requestsPerMinute": 25,
"maxBurst": 50,
"remaining": 50,
"resetAt": "2025-09-08T22:51:55.999Z"
},
"async": {
"isLimited": false,
"requestsPerMinute": 200,
"maxBurst": 400,
"remaining": 400,
"resetAt": "2025-09-08T22:51:56.155Z"
},
"authType": "api"
},
"usage": {
@@ -155,6 +167,11 @@ curl -X GET -H "X-API-Key: YOUR_API_KEY" -H "Content-Type: application/json" htt
}
```

**Rate Limit Fields:**
- `requestsPerMinute`: Sustained rate limit (tokens refill at this rate)
- `maxBurst`: Maximum tokens you can accumulate (burst capacity)
- `remaining`: Current tokens available (can be up to `maxBurst`)

**Response Fields:**
- `currentPeriodCost` reflects usage in the current billing period
- `limit` is derived from individual limits (Free/Pro) or pooled organization limits (Team/Enterprise)
@@ -151,8 +151,8 @@ export async function POST(
triggerType: 'chat',
executionId,
requestId,
checkRateLimit: false, // Chat bypasses rate limits
checkDeployment: true, // Chat requires deployed workflows
checkRateLimit: true,
checkDeployment: true,
loggingSession,
})
@@ -18,7 +18,6 @@ export async function GET(request: NextRequest) {
}
const authenticatedUserId = auth.userId

// Rate limit info (sync + async), mirroring /users/me/rate-limit
const userSubscription = await getHighestPrioritySubscription(authenticatedUserId)
const rateLimiter = new RateLimiter()
const triggerType = auth.authType === 'api_key' ? 'api' : 'manual'
@@ -37,7 +36,6 @@ export async function GET(request: NextRequest) {
),
])

// Usage summary (current period cost + limit + plan)
const [usageCheck, effectiveCost, storageUsage, storageLimit] = await Promise.all([
checkServerSideUsageLimits(authenticatedUserId),
getEffectiveCurrentPeriodCost(authenticatedUserId),
@@ -52,13 +50,15 @@ export async function GET(request: NextRequest) {
rateLimit: {
sync: {
isLimited: syncStatus.remaining === 0,
limit: syncStatus.limit,
requestsPerMinute: syncStatus.requestsPerMinute,
maxBurst: syncStatus.maxBurst,
remaining: syncStatus.remaining,
resetAt: syncStatus.resetAt,
},
async: {
isLimited: asyncStatus.remaining === 0,
limit: asyncStatus.limit,
requestsPerMinute: asyncStatus.requestsPerMinute,
maxBurst: asyncStatus.maxBurst,
remaining: asyncStatus.remaining,
resetAt: asyncStatus.resetAt,
},
@@ -6,12 +6,14 @@ import { RateLimiter } from '@/lib/core/rate-limiter'
export interface UserLimits {
workflowExecutionRateLimit: {
sync: {
limit: number
requestsPerMinute: number
maxBurst: number
remaining: number
resetAt: string
}
async: {
limit: number
requestsPerMinute: number
maxBurst: number
remaining: number
resetAt: string
}
@@ -40,12 +42,14 @@ export async function getUserLimits(userId: string): Promise<UserLimits> {
return {
workflowExecutionRateLimit: {
sync: {
limit: syncStatus.limit,
requestsPerMinute: syncStatus.requestsPerMinute,
maxBurst: syncStatus.maxBurst,
remaining: syncStatus.remaining,
resetAt: syncStatus.resetAt.toISOString(),
},
async: {
limit: asyncStatus.limit,
requestsPerMinute: asyncStatus.requestsPerMinute,
maxBurst: asyncStatus.maxBurst,
remaining: asyncStatus.remaining,
resetAt: asyncStatus.resetAt.toISOString(),
},
@@ -1,6 +1,6 @@
import { type NextRequest, NextResponse } from 'next/server'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { RateLimiter } from '@/lib/core/rate-limiter/rate-limiter'
import { RateLimiter } from '@/lib/core/rate-limiter'
import { createLogger } from '@/lib/logs/console/logger'
import { authenticateV1Request } from '@/app/api/v1/auth'
@@ -12,6 +12,7 @@ export interface RateLimitResult {
remaining: number
resetAt: Date
limit: number
retryAfterMs?: number
userId?: string
error?: string
}
@@ -26,7 +27,7 @@ export async function checkRateLimit(
return {
allowed: false,
remaining: 0,
limit: 10, // Default to free tier limit
limit: 10,
resetAt: new Date(),
error: auth.error,
}
@@ -35,12 +36,11 @@ export async function checkRateLimit(
const userId = auth.userId!
const subscription = await getHighestPrioritySubscription(userId)

// Use api-endpoint trigger type for external API rate limiting
const result = await rateLimiter.checkRateLimitWithSubscription(
userId,
subscription,
'api-endpoint',
false // Not relevant for api-endpoint trigger type
false
)

if (!result.allowed) {
@@ -51,7 +51,6 @@ export async function checkRateLimit(
})
}

// Get the actual rate limit for this user's plan
const rateLimitStatus = await rateLimiter.getRateLimitStatusWithSubscription(
userId,
subscription,
@@ -60,8 +59,11 @@ export async function checkRateLimit(
)

return {
...result,
limit: rateLimitStatus.limit,
allowed: result.allowed,
remaining: result.remaining,
resetAt: result.resetAt,
limit: rateLimitStatus.requestsPerMinute,
retryAfterMs: result.retryAfterMs,
userId,
}
} catch (error) {
@@ -88,6 +90,10 @@ export function createRateLimitResponse(result: RateLimitResult): NextResponse {
}

if (!result.allowed) {
const retryAfterSeconds = result.retryAfterMs
? Math.ceil(result.retryAfterMs / 1000)
: Math.ceil((result.resetAt.getTime() - Date.now()) / 1000)

return NextResponse.json(
{
error: 'Rate limit exceeded',
@@ -98,7 +104,7 @@ export function createRateLimitResponse(result: RateLimitResult): NextResponse {
status: 429,
headers: {
...headers,
'Retry-After': Math.ceil((result.resetAt.getTime() - Date.now()) / 1000).toString(),
'Retry-After': retryAfterSeconds.toString(),
},
}
)
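
For context, a route that adopts this middleware would typically look something like the sketch below (hypothetical: the route path, handler body, and the exact `checkRateLimit` argument list are assumptions, since the diff truncates its signature):

```typescript
// Hypothetical consumer of the middleware above. `checkRateLimit` and
// `createRateLimitResponse` come from this diff; everything else is a placeholder.
import { type NextRequest, NextResponse } from 'next/server'
import { checkRateLimit, createRateLimitResponse } from '@/app/api/v1/middleware' // assumed path

export async function GET(request: NextRequest) {
  const rateLimit = await checkRateLimit(request) // argument list assumed
  if (!rateLimit.allowed) {
    // 429 with X-RateLimit-* headers and Retry-After derived from retryAfterMs/resetAt.
    return createRateLimitResponse(rateLimit)
  }
  // ...normal request handling...
  return NextResponse.json({ success: true })
}
```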
@@ -116,12 +116,14 @@ async function buildPayload(

payload.data.rateLimits = {
sync: {
limit: syncStatus.limit,
requestsPerMinute: syncStatus.requestsPerMinute,
maxBurst: syncStatus.maxBurst,
remaining: syncStatus.remaining,
resetAt: syncStatus.resetAt.toISOString(),
},
async: {
limit: asyncStatus.limit,
requestsPerMinute: asyncStatus.requestsPerMinute,
maxBurst: asyncStatus.maxBurst,
remaining: asyncStatus.remaining,
resetAt: asyncStatus.resetAt.toISOString(),
},
@@ -1,7 +1,5 @@
export { RateLimiter } from '@/lib/core/rate-limiter/rate-limiter'
export type {
RateLimitConfig,
SubscriptionPlan,
TriggerType,
} from '@/lib/core/rate-limiter/types'
export { RATE_LIMITS, RateLimitError } from '@/lib/core/rate-limiter/types'
export type { RateLimitResult, RateLimitStatus } from './rate-limiter'
export { RateLimiter } from './rate-limiter'
export type { RateLimitStorageAdapter, TokenBucketConfig } from './storage'
export type { RateLimitConfig, SubscriptionPlan, TriggerType } from './types'
export { RATE_LIMITS, RateLimitError } from './types'
@@ -1,37 +1,24 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { RateLimiter } from '@/lib/core/rate-limiter/rate-limiter'
|
||||
import { MANUAL_EXECUTION_LIMIT, RATE_LIMITS } from '@/lib/core/rate-limiter/types'
|
||||
import { RateLimiter } from './rate-limiter'
|
||||
import type { ConsumeResult, RateLimitStorageAdapter, TokenStatus } from './storage'
|
||||
import { MANUAL_EXECUTION_LIMIT, RATE_LIMITS } from './types'
|
||||
|
||||
vi.mock('@sim/db', () => ({
|
||||
db: {
|
||||
select: vi.fn(),
|
||||
insert: vi.fn(),
|
||||
update: vi.fn(),
|
||||
delete: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('drizzle-orm', () => ({
|
||||
eq: vi.fn((field, value) => ({ field, value })),
|
||||
sql: vi.fn((strings, ...values) => ({ sql: strings.join('?'), values })),
|
||||
and: vi.fn((...conditions) => ({ and: conditions })),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/core/config/redis', () => ({
|
||||
getRedisClient: vi.fn().mockReturnValue(null),
|
||||
}))
|
||||
|
||||
import { db } from '@sim/db'
|
||||
import { getRedisClient } from '@/lib/core/config/redis'
|
||||
const createMockAdapter = (): RateLimitStorageAdapter => ({
|
||||
consumeTokens: vi.fn(),
|
||||
getTokenStatus: vi.fn(),
|
||||
resetBucket: vi.fn(),
|
||||
})
|
||||
|
||||
describe('RateLimiter', () => {
|
||||
const rateLimiter = new RateLimiter()
|
||||
const testUserId = 'test-user-123'
|
||||
const freeSubscription = { plan: 'free', referenceId: testUserId }
|
||||
let mockAdapter: RateLimitStorageAdapter
|
||||
let rateLimiter: RateLimiter
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(getRedisClient).mockReturnValue(null)
|
||||
mockAdapter = createMockAdapter()
|
||||
rateLimiter = new RateLimiter(mockAdapter)
|
||||
})
|
||||
|
||||
describe('checkRateLimitWithSubscription', () => {
|
||||
@@ -46,32 +33,16 @@ describe('RateLimiter', () => {
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(result.resetAt).toBeInstanceOf(Date)
|
||||
expect(db.select).not.toHaveBeenCalled()
|
||||
expect(mockAdapter.consumeTokens).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should allow first API request for sync execution (DB fallback)', async () => {
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 1,
|
||||
asyncApiRequests: 0,
|
||||
apiEndpointRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
it('should consume tokens for API requests', async () => {
|
||||
const mockResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: RATE_LIMITS.free.sync.maxTokens - 1,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
vi.mocked(mockAdapter.consumeTokens).mockResolvedValue(mockResult)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
@@ -81,90 +52,61 @@ describe('RateLimiter', () => {
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 1)
|
||||
expect(result.resetAt).toBeInstanceOf(Date)
|
||||
expect(result.remaining).toBe(mockResult.tokensRemaining)
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalledWith(
|
||||
`${testUserId}:sync`,
|
||||
1,
|
||||
RATE_LIMITS.free.sync
|
||||
)
|
||||
})
|
||||
|
||||
it('should allow first API request for async execution (DB fallback)', async () => {
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
it('should use async bucket for async requests', async () => {
|
||||
const mockResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: RATE_LIMITS.free.async.maxTokens - 1,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
vi.mocked(mockAdapter.consumeTokens).mockResolvedValue(mockResult)
|
||||
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 0,
|
||||
asyncApiRequests: 1,
|
||||
apiEndpointRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
await rateLimiter.checkRateLimitWithSubscription(testUserId, freeSubscription, 'api', true)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalledWith(
|
||||
`${testUserId}:async`,
|
||||
1,
|
||||
RATE_LIMITS.free.async
|
||||
)
|
||||
})
|
||||
|
||||
it('should use api-endpoint bucket for api-endpoint trigger', async () => {
|
||||
const mockResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: RATE_LIMITS.free.apiEndpoint.maxTokens - 1,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
vi.mocked(mockAdapter.consumeTokens).mockResolvedValue(mockResult)
|
||||
|
||||
await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
true
|
||||
'api-endpoint',
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.asyncApiExecutionsPerMinute - 1)
|
||||
expect(result.resetAt).toBeInstanceOf(Date)
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalledWith(
|
||||
`${testUserId}:api-endpoint`,
|
||||
1,
|
||||
RATE_LIMITS.free.apiEndpoint
|
||||
)
|
||||
})
|
||||
|
||||
it('should work for all trigger types except manual (DB fallback)', async () => {
|
||||
const triggerTypes = ['api', 'webhook', 'schedule', 'chat'] as const
|
||||
|
||||
for (const triggerType of triggerTypes) {
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 1,
|
||||
asyncApiRequests: 0,
|
||||
apiEndpointRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
triggerType,
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 1)
|
||||
it('should deny requests when rate limit exceeded', async () => {
|
||||
const mockResult: ConsumeResult = {
|
||||
allowed: false,
|
||||
tokensRemaining: 0,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
retryAfterMs: 30000,
|
||||
}
|
||||
})
|
||||
|
||||
it('should use Redis when available', async () => {
|
||||
const mockRedis = {
|
||||
eval: vi.fn().mockResolvedValue(1), // Lua script returns count after INCR
|
||||
}
|
||||
vi.mocked(getRedisClient).mockReturnValue(mockRedis as any)
|
||||
vi.mocked(mockAdapter.consumeTokens).mockResolvedValue(mockResult)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
@@ -173,17 +115,55 @@ describe('RateLimiter', () => {
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(result.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 1)
|
||||
expect(mockRedis.eval).toHaveBeenCalled()
|
||||
expect(db.select).not.toHaveBeenCalled()
|
||||
expect(result.allowed).toBe(false)
|
||||
expect(result.remaining).toBe(0)
|
||||
expect(result.retryAfterMs).toBe(30000)
|
||||
})
|
||||
|
||||
it('should deny requests when Redis rate limit exceeded', async () => {
|
||||
const mockRedis = {
|
||||
eval: vi.fn().mockResolvedValue(RATE_LIMITS.free.syncApiExecutionsPerMinute + 1),
|
||||
it('should use organization key for team subscriptions', async () => {
|
||||
const orgId = 'org-123'
|
||||
const teamSubscription = { plan: 'team', referenceId: orgId }
|
||||
const mockResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: RATE_LIMITS.team.sync.maxTokens - 1,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
vi.mocked(getRedisClient).mockReturnValue(mockRedis as any)
|
||||
vi.mocked(mockAdapter.consumeTokens).mockResolvedValue(mockResult)
|
||||
|
||||
await rateLimiter.checkRateLimitWithSubscription(testUserId, teamSubscription, 'api', false)
|
||||
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalledWith(
|
||||
`${orgId}:sync`,
|
||||
1,
|
||||
RATE_LIMITS.team.sync
|
||||
)
|
||||
})
|
||||
|
||||
it('should use user key when team subscription referenceId matches userId', async () => {
|
||||
const directTeamSubscription = { plan: 'team', referenceId: testUserId }
|
||||
const mockResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: RATE_LIMITS.team.sync.maxTokens - 1,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
vi.mocked(mockAdapter.consumeTokens).mockResolvedValue(mockResult)
|
||||
|
||||
await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
directTeamSubscription,
|
||||
'api',
|
||||
false
|
||||
)
|
||||
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalledWith(
|
||||
`${testUserId}:sync`,
|
||||
1,
|
||||
RATE_LIMITS.team.sync
|
||||
)
|
||||
})
|
||||
|
||||
it('should deny on storage error (fail closed)', async () => {
|
||||
vi.mocked(mockAdapter.consumeTokens).mockRejectedValue(new Error('Storage error'))
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
@@ -196,49 +176,30 @@ describe('RateLimiter', () => {
|
||||
expect(result.remaining).toBe(0)
|
||||
})
|
||||
|
||||
it('should fall back to DB when Redis fails', async () => {
|
||||
const mockRedis = {
|
||||
eval: vi.fn().mockRejectedValue(new Error('Redis connection failed')),
|
||||
it('should work for all non-manual trigger types', async () => {
|
||||
const triggerTypes = ['api', 'webhook', 'schedule', 'chat'] as const
|
||||
const mockResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 10,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
vi.mocked(getRedisClient).mockReturnValue(mockRedis as any)
|
||||
vi.mocked(mockAdapter.consumeTokens).mockResolvedValue(mockResult)
|
||||
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
vi.mocked(db.insert).mockReturnValue({
|
||||
values: vi.fn().mockReturnValue({
|
||||
onConflictDoUpdate: vi.fn().mockReturnValue({
|
||||
returning: vi.fn().mockResolvedValue([
|
||||
{
|
||||
syncApiRequests: 1,
|
||||
asyncApiRequests: 0,
|
||||
apiEndpointRequests: 0,
|
||||
windowStart: new Date(),
|
||||
},
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
false
|
||||
)
|
||||
|
||||
expect(result.allowed).toBe(true)
|
||||
expect(db.select).toHaveBeenCalled()
|
||||
for (const triggerType of triggerTypes) {
|
||||
await rateLimiter.checkRateLimitWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
triggerType,
|
||||
false
|
||||
)
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalled()
|
||||
vi.mocked(mockAdapter.consumeTokens).mockClear()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('getRateLimitStatusWithSubscription', () => {
|
||||
it('should return unlimited for manual trigger type', async () => {
|
||||
it('should return unlimited status for manual trigger type', async () => {
|
||||
const status = await rateLimiter.getRateLimitStatusWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
@@ -246,39 +207,20 @@ describe('RateLimiter', () => {
|
||||
false
|
||||
)
|
||||
|
||||
expect(status.used).toBe(0)
|
||||
expect(status.limit).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(status.requestsPerMinute).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(status.maxBurst).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(status.remaining).toBe(MANUAL_EXECUTION_LIMIT)
|
||||
expect(status.resetAt).toBeInstanceOf(Date)
|
||||
expect(mockAdapter.getTokenStatus).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return sync API limits for API trigger type (DB fallback)', async () => {
|
||||
vi.mocked(db.select).mockReturnValue({
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const status = await rateLimiter.getRateLimitStatusWithSubscription(
|
||||
testUserId,
|
||||
freeSubscription,
|
||||
'api',
|
||||
false
|
||||
)
|
||||
|
||||
expect(status.used).toBe(0)
|
||||
expect(status.limit).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute)
|
||||
expect(status.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute)
|
||||
expect(status.resetAt).toBeInstanceOf(Date)
|
||||
})
|
||||
|
||||
it('should use Redis for status when available', async () => {
|
||||
const mockRedis = {
|
||||
get: vi.fn().mockResolvedValue('5'),
|
||||
it('should return status from storage for API requests', async () => {
|
||||
const mockStatus: TokenStatus = {
|
||||
tokensAvailable: 15,
|
||||
maxTokens: RATE_LIMITS.free.sync.maxTokens,
|
||||
lastRefillAt: new Date(),
|
||||
nextRefillAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
vi.mocked(getRedisClient).mockReturnValue(mockRedis as any)
|
||||
vi.mocked(mockAdapter.getTokenStatus).mockResolvedValue(mockStatus)
|
||||
|
||||
const status = await rateLimiter.getRateLimitStatusWithSubscription(
|
||||
testUserId,
|
||||
@@ -287,23 +229,26 @@ describe('RateLimiter', () => {
|
||||
false
|
||||
)
|
||||
|
||||
expect(status.used).toBe(5)
|
||||
expect(status.limit).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute)
|
||||
expect(status.remaining).toBe(RATE_LIMITS.free.syncApiExecutionsPerMinute - 5)
|
||||
expect(mockRedis.get).toHaveBeenCalled()
|
||||
expect(db.select).not.toHaveBeenCalled()
|
||||
expect(status.remaining).toBe(15)
|
||||
expect(status.requestsPerMinute).toBe(RATE_LIMITS.free.sync.refillRate)
|
||||
expect(status.maxBurst).toBe(RATE_LIMITS.free.sync.maxTokens)
|
||||
expect(mockAdapter.getTokenStatus).toHaveBeenCalledWith(
|
||||
`${testUserId}:sync`,
|
||||
RATE_LIMITS.free.sync
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('resetRateLimit', () => {
|
||||
it('should delete rate limit record for user', async () => {
|
||||
vi.mocked(db.delete).mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue({}),
|
||||
} as any)
|
||||
it('should reset all bucket types for a user', async () => {
|
||||
vi.mocked(mockAdapter.resetBucket).mockResolvedValue()
|
||||
|
||||
await rateLimiter.resetRateLimit(testUserId)
|
||||
|
||||
expect(db.delete).toHaveBeenCalled()
|
||||
expect(mockAdapter.resetBucket).toHaveBeenCalledTimes(3)
|
||||
expect(mockAdapter.resetBucket).toHaveBeenCalledWith(`${testUserId}:sync`)
|
||||
expect(mockAdapter.resetBucket).toHaveBeenCalledWith(`${testUserId}:async`)
|
||||
expect(mockAdapter.resetBucket).toHaveBeenCalledWith(`${testUserId}:api-endpoint`)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { db } from '@sim/db'
|
||||
import { userRateLimits } from '@sim/db/schema'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import type Redis from 'ioredis'
|
||||
import { getRedisClient } from '@/lib/core/config/redis'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import {
|
||||
createStorageAdapter,
|
||||
type RateLimitStorageAdapter,
|
||||
type TokenBucketConfig,
|
||||
} from './storage'
|
||||
import {
|
||||
MANUAL_EXECUTION_LIMIT,
|
||||
RATE_LIMIT_WINDOW_MS,
|
||||
@@ -10,8 +11,7 @@ import {
|
||||
type RateLimitCounterType,
|
||||
type SubscriptionPlan,
|
||||
type TriggerType,
|
||||
} from '@/lib/core/rate-limiter/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
} from './types'
|
||||
|
||||
const logger = createLogger('RateLimiter')
|
||||
|
||||
@@ -20,420 +20,181 @@ interface SubscriptionInfo {
|
||||
referenceId: string
|
||||
}
|
||||
|
||||
export interface RateLimitResult {
|
||||
allowed: boolean
|
||||
remaining: number
|
||||
resetAt: Date
|
||||
retryAfterMs?: number
|
||||
}
|
||||
|
||||
export interface RateLimitStatus {
|
||||
requestsPerMinute: number
|
||||
maxBurst: number
|
||||
remaining: number
|
||||
resetAt: Date
|
||||
}
|
||||
|
||||
export class RateLimiter {
|
||||
/**
|
||||
* Determine the rate limit key based on subscription
|
||||
* For team/enterprise plans via organization, use the organization ID
|
||||
* For direct user subscriptions (including direct team), use the user ID
|
||||
*/
|
||||
private storage: RateLimitStorageAdapter
|
||||
|
||||
constructor(storage?: RateLimitStorageAdapter) {
|
||||
this.storage = storage ?? createStorageAdapter()
|
||||
}
|
||||
|
||||
private getRateLimitKey(userId: string, subscription: SubscriptionInfo | null): string {
|
||||
if (!subscription) {
|
||||
return userId
|
||||
}
|
||||
if (!subscription) return userId
|
||||
|
||||
const plan = subscription.plan as SubscriptionPlan
|
||||
|
||||
// Check if this is an organization subscription (referenceId !== userId)
|
||||
// If referenceId === userId, it's a direct user subscription
|
||||
if ((plan === 'team' || plan === 'enterprise') && subscription.referenceId !== userId) {
|
||||
// This is an organization subscription
|
||||
// All organization members share the same rate limit pool
|
||||
return subscription.referenceId
|
||||
}
|
||||
|
||||
// For direct user subscriptions (free/pro/team/enterprise where referenceId === userId)
|
||||
return userId
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine which counter type to use based on trigger type and async flag
|
||||
*/
|
||||
private getCounterType(triggerType: TriggerType, isAsync: boolean): RateLimitCounterType {
|
||||
if (triggerType === 'api-endpoint') {
|
||||
return 'api-endpoint'
|
||||
}
|
||||
if (triggerType === 'api-endpoint') return 'api-endpoint'
|
||||
return isAsync ? 'async' : 'sync'
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the rate limit for a specific counter type
|
||||
*/
|
||||
private getRateLimitForCounter(
|
||||
config: (typeof RATE_LIMITS)[SubscriptionPlan],
|
||||
private getBucketConfig(
|
||||
plan: SubscriptionPlan,
|
||||
counterType: RateLimitCounterType
|
||||
): number {
|
||||
): TokenBucketConfig {
|
||||
const config = RATE_LIMITS[plan]
|
||||
switch (counterType) {
|
||||
case 'api-endpoint':
|
||||
return config.apiEndpointRequestsPerMinute
|
||||
return config.apiEndpoint
|
||||
case 'async':
|
||||
return config.asyncApiExecutionsPerMinute
|
||||
return config.async
|
||||
case 'sync':
|
||||
return config.syncApiExecutionsPerMinute
|
||||
return config.sync
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current count from a rate limit record for a specific counter type
|
||||
*/
|
||||
private getCountFromRecord(
|
||||
record: { syncApiRequests: number; asyncApiRequests: number; apiEndpointRequests: number },
|
||||
counterType: RateLimitCounterType
|
||||
): number {
|
||||
switch (counterType) {
|
||||
case 'api-endpoint':
|
||||
return record.apiEndpointRequests
|
||||
case 'async':
|
||||
return record.asyncApiRequests
|
||||
case 'sync':
|
||||
return record.syncApiRequests
|
||||
}
|
||||
private buildStorageKey(rateLimitKey: string, counterType: RateLimitCounterType): string {
|
||||
return `${rateLimitKey}:${counterType}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Check rate limit using Redis (faster, single atomic operation)
|
||||
* Uses fixed window algorithm with INCR + EXPIRE
|
||||
*/
|
||||
private async checkRateLimitRedis(
|
||||
redis: Redis,
|
||||
rateLimitKey: string,
|
||||
counterType: RateLimitCounterType,
|
||||
limit: number
|
||||
): Promise<{ allowed: boolean; remaining: number; resetAt: Date }> {
|
||||
const windowMs = RATE_LIMIT_WINDOW_MS
|
||||
const windowKey = Math.floor(Date.now() / windowMs)
|
||||
const key = `ratelimit:${rateLimitKey}:${counterType}:${windowKey}`
|
||||
const ttlSeconds = Math.ceil(windowMs / 1000)
|
||||
|
||||
// Atomic increment + expire
|
||||
const count = (await redis.eval(
|
||||
'local c = redis.call("INCR", KEYS[1]) if c == 1 then redis.call("EXPIRE", KEYS[1], ARGV[1]) end return c',
|
||||
1,
|
||||
key,
|
||||
ttlSeconds
|
||||
)) as number
|
||||
|
||||
const resetAt = new Date((windowKey + 1) * windowMs)
|
||||
|
||||
if (count > limit) {
|
||||
logger.info(`Rate limit exceeded (Redis) - request ${count} > limit ${limit}`, {
|
||||
rateLimitKey,
|
||||
counterType,
|
||||
limit,
|
||||
count,
|
||||
})
|
||||
return { allowed: false, remaining: 0, resetAt }
|
||||
}
|
||||
|
||||
return { allowed: true, remaining: limit - count, resetAt }
|
||||
}
|
||||
|
||||
/**
|
||||
* Get rate limit status using Redis (read-only, doesn't increment)
|
||||
*/
|
||||
private async getRateLimitStatusRedis(
|
||||
redis: Redis,
|
||||
rateLimitKey: string,
|
||||
counterType: RateLimitCounterType,
|
||||
limit: number
|
||||
): Promise<{ used: number; limit: number; remaining: number; resetAt: Date }> {
|
||||
const windowMs = RATE_LIMIT_WINDOW_MS
|
||||
const windowKey = Math.floor(Date.now() / windowMs)
|
||||
const key = `ratelimit:${rateLimitKey}:${counterType}:${windowKey}`
|
||||
|
||||
const countStr = await redis.get(key)
|
||||
const used = countStr ? Number.parseInt(countStr, 10) : 0
|
||||
const resetAt = new Date((windowKey + 1) * windowMs)
|
||||
|
||||
private createUnlimitedResult(): RateLimitResult {
|
||||
return {
|
||||
used,
|
||||
limit,
|
||||
remaining: Math.max(0, limit - used),
|
||||
resetAt,
|
||||
allowed: true,
|
||||
remaining: MANUAL_EXECUTION_LIMIT,
|
||||
resetAt: new Date(Date.now() + RATE_LIMIT_WINDOW_MS),
|
||||
}
|
||||
}
|
||||
|
||||
private createUnlimitedStatus(config: TokenBucketConfig): RateLimitStatus {
|
||||
return {
|
||||
requestsPerMinute: MANUAL_EXECUTION_LIMIT,
|
||||
maxBurst: MANUAL_EXECUTION_LIMIT,
|
||||
remaining: MANUAL_EXECUTION_LIMIT,
|
||||
resetAt: new Date(Date.now() + config.refillIntervalMs),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if user can execute a workflow with organization-aware rate limiting
|
||||
* Manual executions bypass rate limiting entirely
|
||||
*/
|
||||
async checkRateLimitWithSubscription(
|
||||
userId: string,
|
||||
subscription: SubscriptionInfo | null,
|
||||
triggerType: TriggerType = 'manual',
|
||||
isAsync = false
|
||||
): Promise<{ allowed: boolean; remaining: number; resetAt: Date }> {
|
||||
): Promise<RateLimitResult> {
|
||||
try {
|
||||
if (triggerType === 'manual') {
|
||||
return {
|
||||
allowed: true,
|
||||
remaining: MANUAL_EXECUTION_LIMIT,
|
||||
resetAt: new Date(Date.now() + RATE_LIMIT_WINDOW_MS),
|
||||
}
|
||||
return this.createUnlimitedResult()
|
||||
}
|
||||
|
||||
const subscriptionPlan = (subscription?.plan || 'free') as SubscriptionPlan
|
||||
const plan = (subscription?.plan || 'free') as SubscriptionPlan
|
||||
const rateLimitKey = this.getRateLimitKey(userId, subscription)
|
||||
const limit = RATE_LIMITS[subscriptionPlan]
|
||||
|
||||
const counterType = this.getCounterType(triggerType, isAsync)
|
||||
const execLimit = this.getRateLimitForCounter(limit, counterType)
|
||||
const config = this.getBucketConfig(plan, counterType)
|
||||
const storageKey = this.buildStorageKey(rateLimitKey, counterType)
|
||||
|
||||
// Try Redis first for faster rate limiting
|
||||
const redis = getRedisClient()
|
||||
if (redis) {
|
||||
try {
|
||||
return await this.checkRateLimitRedis(redis, rateLimitKey, counterType, execLimit)
|
||||
} catch (error) {
|
||||
logger.warn('Redis rate limit check failed, falling back to DB:', { error })
|
||||
// Fall through to DB implementation
|
||||
}
|
||||
}
|
||||
const result = await this.storage.consumeTokens(storageKey, 1, config)
|
||||
|
||||
// Fallback to DB implementation
|
||||
const now = new Date()
|
||||
const windowStart = new Date(now.getTime() - RATE_LIMIT_WINDOW_MS)
|
||||
|
||||
// Get or create rate limit record using the rate limit key
|
||||
const [rateLimitRecord] = await db
|
||||
.select()
|
||||
.from(userRateLimits)
|
||||
.where(eq(userRateLimits.referenceId, rateLimitKey))
|
||||
.limit(1)
|
||||
|
||||
if (!rateLimitRecord || new Date(rateLimitRecord.windowStart) < windowStart) {
|
||||
// Window expired - reset window with this request as the first one
|
||||
const result = await db
|
||||
.insert(userRateLimits)
|
||||
.values({
|
||||
referenceId: rateLimitKey,
|
||||
syncApiRequests: counterType === 'sync' ? 1 : 0,
|
||||
asyncApiRequests: counterType === 'async' ? 1 : 0,
|
||||
apiEndpointRequests: counterType === 'api-endpoint' ? 1 : 0,
|
||||
windowStart: now,
|
||||
lastRequestAt: now,
|
||||
isRateLimited: false,
|
||||
})
|
||||
.onConflictDoUpdate({
|
||||
target: userRateLimits.referenceId,
|
||||
set: {
|
||||
// Only reset if window is still expired (avoid race condition)
|
||||
syncApiRequests: sql`CASE WHEN ${userRateLimits.windowStart} < ${windowStart.toISOString()} THEN ${counterType === 'sync' ? 1 : 0} ELSE ${userRateLimits.syncApiRequests} + ${counterType === 'sync' ? 1 : 0} END`,
|
||||
asyncApiRequests: sql`CASE WHEN ${userRateLimits.windowStart} < ${windowStart.toISOString()} THEN ${counterType === 'async' ? 1 : 0} ELSE ${userRateLimits.asyncApiRequests} + ${counterType === 'async' ? 1 : 0} END`,
|
||||
apiEndpointRequests: sql`CASE WHEN ${userRateLimits.windowStart} < ${windowStart.toISOString()} THEN ${counterType === 'api-endpoint' ? 1 : 0} ELSE ${userRateLimits.apiEndpointRequests} + ${counterType === 'api-endpoint' ? 1 : 0} END`,
|
||||
windowStart: sql`CASE WHEN ${userRateLimits.windowStart} < ${windowStart.toISOString()} THEN ${now.toISOString()} ELSE ${userRateLimits.windowStart} END`,
|
||||
lastRequestAt: now,
|
||||
isRateLimited: false,
|
||||
rateLimitResetAt: null,
|
||||
},
|
||||
})
|
||||
.returning({
|
||||
syncApiRequests: userRateLimits.syncApiRequests,
|
||||
asyncApiRequests: userRateLimits.asyncApiRequests,
|
||||
apiEndpointRequests: userRateLimits.apiEndpointRequests,
|
||||
windowStart: userRateLimits.windowStart,
|
||||
})
|
||||
|
||||
const insertedRecord = result[0]
|
||||
const actualCount = this.getCountFromRecord(insertedRecord, counterType)
|
||||
|
||||
// Check if we exceeded the limit
|
||||
if (actualCount > execLimit) {
|
||||
const resetAt = new Date(
|
||||
new Date(insertedRecord.windowStart).getTime() + RATE_LIMIT_WINDOW_MS
|
||||
)
|
||||
|
||||
await db
|
||||
.update(userRateLimits)
|
||||
.set({
|
||||
isRateLimited: true,
|
||||
rateLimitResetAt: resetAt,
|
||||
})
|
||||
.where(eq(userRateLimits.referenceId, rateLimitKey))
|
||||
|
||||
logger.info(
|
||||
`Rate limit exceeded - request ${actualCount} > limit ${execLimit} for ${
|
||||
rateLimitKey === userId ? `user ${userId}` : `organization ${rateLimitKey}`
|
||||
}`,
|
||||
{
|
||||
execLimit,
|
||||
isAsync,
|
||||
actualCount,
|
||||
rateLimitKey,
|
||||
plan: subscriptionPlan,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
allowed: false,
|
||||
remaining: 0,
|
||||
resetAt,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
allowed: true,
|
||||
remaining: execLimit - actualCount,
|
||||
resetAt: new Date(new Date(insertedRecord.windowStart).getTime() + RATE_LIMIT_WINDOW_MS),
|
||||
}
|
||||
}
|
||||
|
||||
// Simple atomic increment - increment first, then check if over limit
|
||||
const updateResult = await db
|
||||
.update(userRateLimits)
|
||||
.set({
|
||||
...(counterType === 'api-endpoint'
|
||||
? { apiEndpointRequests: sql`${userRateLimits.apiEndpointRequests} + 1` }
|
||||
: counterType === 'async'
|
||||
? { asyncApiRequests: sql`${userRateLimits.asyncApiRequests} + 1` }
|
||||
: { syncApiRequests: sql`${userRateLimits.syncApiRequests} + 1` }),
|
||||
lastRequestAt: now,
|
||||
if (!result.allowed) {
|
||||
logger.info('Rate limit exceeded', {
|
||||
rateLimitKey,
|
||||
counterType,
|
||||
plan,
|
||||
tokensRemaining: result.tokensRemaining,
|
||||
})
|
||||
.where(eq(userRateLimits.referenceId, rateLimitKey))
|
||||
.returning({
|
||||
asyncApiRequests: userRateLimits.asyncApiRequests,
|
||||
syncApiRequests: userRateLimits.syncApiRequests,
|
||||
apiEndpointRequests: userRateLimits.apiEndpointRequests,
|
||||
})
|
||||
|
||||
const updatedRecord = updateResult[0]
|
||||
const actualNewRequests = this.getCountFromRecord(updatedRecord, counterType)
|
||||
|
||||
// Check if we exceeded the limit AFTER the atomic increment
|
||||
if (actualNewRequests > execLimit) {
|
||||
const resetAt = new Date(
|
||||
new Date(rateLimitRecord.windowStart).getTime() + RATE_LIMIT_WINDOW_MS
|
||||
)
|
||||
|
||||
logger.info(
|
||||
`Rate limit exceeded - request ${actualNewRequests} > limit ${execLimit} for ${
|
||||
rateLimitKey === userId ? `user ${userId}` : `organization ${rateLimitKey}`
|
||||
}`,
|
||||
{
|
||||
execLimit,
|
||||
isAsync,
|
||||
actualNewRequests,
|
||||
rateLimitKey,
|
||||
plan: subscriptionPlan,
|
||||
}
|
||||
)
|
||||
|
||||
// Update rate limited status
|
||||
await db
|
||||
.update(userRateLimits)
|
||||
.set({
|
||||
isRateLimited: true,
|
||||
rateLimitResetAt: resetAt,
|
||||
})
|
||||
.where(eq(userRateLimits.referenceId, rateLimitKey))
|
||||
|
||||
return {
|
||||
allowed: false,
|
||||
remaining: 0,
|
||||
resetAt,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
allowed: true,
|
||||
remaining: execLimit - actualNewRequests,
|
||||
resetAt: new Date(new Date(rateLimitRecord.windowStart).getTime() + RATE_LIMIT_WINDOW_MS),
|
||||
allowed: result.allowed,
|
||||
remaining: result.tokensRemaining,
|
||||
resetAt: result.resetAt,
|
||||
retryAfterMs: result.retryAfterMs,
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error checking rate limit:', error)
|
||||
// Allow execution on error to avoid blocking users
|
||||
logger.error('Rate limit storage error - failing closed (denying request)', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
userId,
|
||||
triggerType,
|
||||
isAsync,
|
||||
})
|
||||
return {
|
||||
allowed: true,
|
||||
allowed: false,
|
||||
remaining: 0,
|
||||
resetAt: new Date(Date.now() + RATE_LIMIT_WINDOW_MS),
|
||||
retryAfterMs: RATE_LIMIT_WINDOW_MS,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async getRateLimitStatusWithSubscription(
|
||||
userId: string,
|
||||
subscription: SubscriptionInfo | null,
|
||||
triggerType: TriggerType = 'manual',
|
||||
isAsync = false
|
||||
): Promise<RateLimitStatus> {
|
||||
try {
|
||||
const plan = (subscription?.plan || 'free') as SubscriptionPlan
|
||||
const counterType = this.getCounterType(triggerType, isAsync)
|
||||
const config = this.getBucketConfig(plan, counterType)
|
||||
|
||||
if (triggerType === 'manual') {
|
||||
return this.createUnlimitedStatus(config)
|
||||
}
|
||||
|
||||
const rateLimitKey = this.getRateLimitKey(userId, subscription)
|
||||
const storageKey = this.buildStorageKey(rateLimitKey, counterType)
|
||||
|
||||
const status = await this.storage.getTokenStatus(storageKey, config)
|
||||
|
||||
return {
|
||||
requestsPerMinute: config.refillRate,
|
||||
maxBurst: config.maxTokens,
|
||||
remaining: Math.floor(status.tokensAvailable),
|
||||
resetAt: status.nextRefillAt,
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error getting rate limit status - returning default config', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
userId,
|
||||
triggerType,
|
||||
isAsync,
|
||||
})
|
||||
const plan = (subscription?.plan || 'free') as SubscriptionPlan
|
||||
const counterType = this.getCounterType(triggerType, isAsync)
|
||||
const config = this.getBucketConfig(plan, counterType)
|
||||
return {
|
||||
requestsPerMinute: config.refillRate,
|
||||
maxBurst: config.maxTokens,
|
||||
remaining: 0,
|
||||
resetAt: new Date(Date.now() + RATE_LIMIT_WINDOW_MS),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get current rate limit status with organization awareness
|
||||
* Only applies to API executions
|
||||
*/
|
||||
async getRateLimitStatusWithSubscription(
|
||||
userId: string,
|
||||
subscription: SubscriptionInfo | null,
|
||||
triggerType: TriggerType = 'manual',
|
||||
isAsync = false
|
||||
): Promise<{ used: number; limit: number; remaining: number; resetAt: Date }> {
|
||||
try {
|
||||
if (triggerType === 'manual') {
|
||||
return {
|
||||
used: 0,
|
||||
limit: MANUAL_EXECUTION_LIMIT,
|
||||
remaining: MANUAL_EXECUTION_LIMIT,
|
||||
resetAt: new Date(Date.now() + RATE_LIMIT_WINDOW_MS),
|
||||
}
|
||||
}
|
||||
|
||||
const subscriptionPlan = (subscription?.plan || 'free') as SubscriptionPlan
|
||||
const rateLimitKey = this.getRateLimitKey(userId, subscription)
|
||||
const limit = RATE_LIMITS[subscriptionPlan]
|
||||
|
||||
const counterType = this.getCounterType(triggerType, isAsync)
|
||||
const execLimit = this.getRateLimitForCounter(limit, counterType)
|
||||
|
||||
// Try Redis first for faster status check
|
||||
const redis = getRedisClient()
|
||||
if (redis) {
|
||||
try {
|
||||
return await this.getRateLimitStatusRedis(redis, rateLimitKey, counterType, execLimit)
|
||||
} catch (error) {
|
||||
logger.warn('Redis rate limit status check failed, falling back to DB:', { error })
|
||||
// Fall through to DB implementation
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to DB implementation
|
||||
const now = new Date()
|
||||
const windowStart = new Date(now.getTime() - RATE_LIMIT_WINDOW_MS)
|
||||
|
||||
const [rateLimitRecord] = await db
|
||||
.select()
|
||||
.from(userRateLimits)
|
||||
.where(eq(userRateLimits.referenceId, rateLimitKey))
|
||||
.limit(1)
|
||||
|
||||
if (!rateLimitRecord || new Date(rateLimitRecord.windowStart) < windowStart) {
|
||||
return {
|
||||
used: 0,
|
||||
limit: execLimit,
|
||||
remaining: execLimit,
|
||||
resetAt: new Date(now.getTime() + RATE_LIMIT_WINDOW_MS),
|
||||
}
|
||||
}
|
||||
|
||||
const used = this.getCountFromRecord(rateLimitRecord, counterType)
|
||||
return {
|
||||
used,
|
||||
limit: execLimit,
|
||||
remaining: Math.max(0, execLimit - used),
|
||||
resetAt: new Date(new Date(rateLimitRecord.windowStart).getTime() + RATE_LIMIT_WINDOW_MS),
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error getting rate limit status:', error)
|
||||
const execLimit = isAsync
|
||||
? RATE_LIMITS[(subscription?.plan || 'free') as SubscriptionPlan]
|
||||
.asyncApiExecutionsPerMinute
|
||||
: RATE_LIMITS[(subscription?.plan || 'free') as SubscriptionPlan].syncApiExecutionsPerMinute
|
||||
return {
|
||||
used: 0,
|
||||
limit: execLimit,
|
||||
remaining: execLimit,
|
||||
resetAt: new Date(Date.now() + RATE_LIMIT_WINDOW_MS),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset rate limit for a user or organization
|
||||
*/
|
||||
async resetRateLimit(rateLimitKey: string): Promise<void> {
|
||||
try {
|
||||
await db.delete(userRateLimits).where(eq(userRateLimits.referenceId, rateLimitKey))
|
||||
await Promise.all([
|
||||
this.storage.resetBucket(`${rateLimitKey}:sync`),
|
||||
this.storage.resetBucket(`${rateLimitKey}:async`),
|
||||
this.storage.resetBucket(`${rateLimitKey}:api-endpoint`),
|
||||
])
|
||||
logger.info(`Reset rate limit for ${rateLimitKey}`)
|
||||
} catch (error) {
|
||||
logger.error('Error resetting rate limit:', error)
|
||||
|
||||
25 apps/sim/lib/core/rate-limiter/storage/adapter.ts (Normal file)
@@ -0,0 +1,25 @@
export interface TokenBucketConfig {
  maxTokens: number
  refillRate: number
  refillIntervalMs: number
}

export interface ConsumeResult {
  allowed: boolean
  tokensRemaining: number
  resetAt: Date
  retryAfterMs?: number
}

export interface TokenStatus {
  tokensAvailable: number
  maxTokens: number
  lastRefillAt: Date
  nextRefillAt: Date
}

export interface RateLimitStorageAdapter {
  consumeTokens(key: string, tokens: number, config: TokenBucketConfig): Promise<ConsumeResult>
  getTokenStatus(key: string, config: TokenBucketConfig): Promise<TokenStatus>
  resetBucket(key: string): Promise<void>
}
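
To illustrate the adapter contract, here is a hypothetical in-memory implementation (not part of this commit; something similar could serve as a test double alongside the Redis and Postgres adapters):

```typescript
// Hypothetical in-memory RateLimitStorageAdapter (illustration only).
import type {
  ConsumeResult,
  RateLimitStorageAdapter,
  TokenBucketConfig,
  TokenStatus,
} from './adapter'

export class InMemoryTokenBucket implements RateLimitStorageAdapter {
  private buckets = new Map<string, { tokens: number; lastRefillAt: number }>()

  // Apply any whole refill intervals that have elapsed, capped at maxTokens.
  private refill(key: string, config: TokenBucketConfig, now: number) {
    const bucket = this.buckets.get(key) ?? { tokens: config.maxTokens, lastRefillAt: now }
    const intervals = Math.floor((now - bucket.lastRefillAt) / config.refillIntervalMs)
    if (intervals > 0) {
      bucket.tokens = Math.min(config.maxTokens, bucket.tokens + intervals * config.refillRate)
      bucket.lastRefillAt += intervals * config.refillIntervalMs
    }
    this.buckets.set(key, bucket)
    return bucket
  }

  async consumeTokens(
    key: string,
    tokens: number,
    config: TokenBucketConfig
  ): Promise<ConsumeResult> {
    const now = Date.now()
    const bucket = this.refill(key, config, now)
    const allowed = bucket.tokens >= tokens
    if (allowed) bucket.tokens -= tokens
    const resetAt = new Date(bucket.lastRefillAt + config.refillIntervalMs)
    return {
      allowed,
      tokensRemaining: Math.max(0, bucket.tokens),
      resetAt,
      retryAfterMs: allowed ? undefined : Math.max(0, resetAt.getTime() - now),
    }
  }

  async getTokenStatus(key: string, config: TokenBucketConfig): Promise<TokenStatus> {
    const bucket = this.refill(key, config, Date.now())
    return {
      tokensAvailable: bucket.tokens,
      maxTokens: config.maxTokens,
      lastRefillAt: new Date(bucket.lastRefillAt),
      nextRefillAt: new Date(bucket.lastRefillAt + config.refillIntervalMs),
    }
  }

  async resetBucket(key: string): Promise<void> {
    this.buckets.delete(key)
  }
}
```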
136
apps/sim/lib/core/rate-limiter/storage/db-token-bucket.ts
Normal file
136
apps/sim/lib/core/rate-limiter/storage/db-token-bucket.ts
Normal file
@@ -0,0 +1,136 @@
|
||||
import { db } from '@sim/db'
|
||||
import { rateLimitBucket } from '@sim/db/schema'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import type {
|
||||
ConsumeResult,
|
||||
RateLimitStorageAdapter,
|
||||
TokenBucketConfig,
|
||||
TokenStatus,
|
||||
} from './adapter'
|
||||
|
||||
export class DbTokenBucket implements RateLimitStorageAdapter {
|
||||
async consumeTokens(
|
||||
key: string,
|
||||
requestedTokens: number,
|
||||
config: TokenBucketConfig
|
||||
): Promise<ConsumeResult> {
|
||||
const now = new Date()
|
||||
const nowMs = now.getTime()
|
||||
const nowIso = now.toISOString()
|
||||
|
||||
const result = await db
|
||||
.insert(rateLimitBucket)
|
||||
.values({
|
||||
key,
|
||||
tokens: (config.maxTokens - requestedTokens).toString(),
|
||||
lastRefillAt: now,
|
||||
updatedAt: now,
|
||||
})
|
||||
.onConflictDoUpdate({
|
||||
target: rateLimitBucket.key,
|
||||
set: {
|
||||
tokens: sql`
|
||||
CASE
|
||||
WHEN (
|
||||
LEAST(
|
||||
${config.maxTokens}::numeric,
|
||||
${rateLimitBucket.tokens}::numeric + (
|
||||
FLOOR(
|
||||
EXTRACT(EPOCH FROM (${nowIso}::timestamp - ${rateLimitBucket.lastRefillAt})) * 1000
|
||||
/ ${config.refillIntervalMs}
|
||||
) * ${config.refillRate}
|
||||
)::numeric
|
||||
)
|
||||
) >= ${requestedTokens}::numeric
|
||||
THEN LEAST(
|
||||
${config.maxTokens}::numeric,
|
||||
${rateLimitBucket.tokens}::numeric + (
|
||||
FLOOR(
|
||||
EXTRACT(EPOCH FROM (${nowIso}::timestamp - ${rateLimitBucket.lastRefillAt})) * 1000
|
||||
/ ${config.refillIntervalMs}
|
||||
) * ${config.refillRate}
|
||||
)::numeric
|
||||
) - ${requestedTokens}::numeric
|
||||
ELSE ${rateLimitBucket.tokens}::numeric
|
||||
END
|
||||
`,
|
||||
lastRefillAt: sql`
|
||||
CASE
|
||||
WHEN FLOOR(
|
||||
EXTRACT(EPOCH FROM (${nowIso}::timestamp - ${rateLimitBucket.lastRefillAt})) * 1000
|
||||
/ ${config.refillIntervalMs}
|
||||
) > 0
|
||||
THEN ${rateLimitBucket.lastRefillAt} + (
|
||||
FLOOR(
|
||||
EXTRACT(EPOCH FROM (${nowIso}::timestamp - ${rateLimitBucket.lastRefillAt})) * 1000
|
||||
/ ${config.refillIntervalMs}
|
||||
) * ${config.refillIntervalMs} * INTERVAL '1 millisecond'
|
||||
)
|
||||
ELSE ${rateLimitBucket.lastRefillAt}
|
||||
END
|
||||
`,
|
||||
updatedAt: now,
|
||||
},
|
||||
})
|
||||
.returning({
|
||||
tokens: rateLimitBucket.tokens,
|
||||
lastRefillAt: rateLimitBucket.lastRefillAt,
|
||||
})
|
||||
|
||||
const record = result[0]
|
||||
const tokens = Number.parseFloat(record.tokens)
|
||||
const lastRefillMs = record.lastRefillAt.getTime()
|
||||
const nextRefillAt = new Date(lastRefillMs + config.refillIntervalMs)
|
||||
|
||||
const allowed = tokens >= 0
|
||||
|
||||
return {
|
||||
allowed,
|
||||
tokensRemaining: Math.max(0, tokens),
|
||||
resetAt: nextRefillAt,
|
||||
retryAfterMs: allowed ? undefined : Math.max(0, nextRefillAt.getTime() - nowMs),
|
||||
}
|
||||
}
|
||||
|
||||
async getTokenStatus(key: string, config: TokenBucketConfig): Promise<TokenStatus> {
|
||||
const now = new Date()
|
||||
|
||||
const [record] = await db
|
||||
.select({
|
||||
tokens: rateLimitBucket.tokens,
|
||||
lastRefillAt: rateLimitBucket.lastRefillAt,
|
||||
})
|
||||
.from(rateLimitBucket)
|
||||
.where(eq(rateLimitBucket.key, key))
|
||||
.limit(1)
|
||||
|
||||
if (!record) {
|
||||
return {
|
||||
tokensAvailable: config.maxTokens,
|
||||
maxTokens: config.maxTokens,
|
||||
lastRefillAt: now,
|
||||
nextRefillAt: new Date(now.getTime() + config.refillIntervalMs),
|
||||
}
|
||||
}
|
||||
|
||||
const tokens = Number.parseFloat(record.tokens)
|
||||
const elapsed = now.getTime() - record.lastRefillAt.getTime()
|
||||
const intervalsElapsed = Math.floor(elapsed / config.refillIntervalMs)
|
||||
const refillAmount = intervalsElapsed * config.refillRate
|
||||
const tokensAvailable = Math.min(config.maxTokens, tokens + refillAmount)
|
||||
const lastRefillAt = new Date(
|
||||
record.lastRefillAt.getTime() + intervalsElapsed * config.refillIntervalMs
|
||||
)
|
||||
|
||||
return {
|
||||
tokensAvailable,
|
||||
maxTokens: config.maxTokens,
|
||||
lastRefillAt,
|
||||
nextRefillAt: new Date(lastRefillAt.getTime() + config.refillIntervalMs),
|
||||
}
|
||||
}
|
||||
|
||||
async resetBucket(key: string): Promise<void> {
|
||||
await db.delete(rateLimitBucket).where(eq(rateLimitBucket.key, key))
|
||||
}
|
||||
}
|
||||
34 apps/sim/lib/core/rate-limiter/storage/factory.ts (Normal file)
@@ -0,0 +1,34 @@
import { getRedisClient } from '@/lib/core/config/redis'
import { createLogger } from '@/lib/logs/console/logger'
import type { RateLimitStorageAdapter } from './adapter'
import { DbTokenBucket } from './db-token-bucket'
import { RedisTokenBucket } from './redis-token-bucket'

const logger = createLogger('RateLimitStorage')

let cachedAdapter: RateLimitStorageAdapter | null = null

export function createStorageAdapter(): RateLimitStorageAdapter {
  if (cachedAdapter) {
    return cachedAdapter
  }

  const redis = getRedisClient()
  if (redis) {
    logger.info('Using Redis for rate limiting')
    cachedAdapter = new RedisTokenBucket(redis)
  } else {
    logger.info('Using PostgreSQL for rate limiting')
    cachedAdapter = new DbTokenBucket()
  }

  return cachedAdapter
}

export function resetStorageAdapter(): void {
  cachedAdapter = null
}

export function setStorageAdapter(adapter: RateLimitStorageAdapter): void {
  cachedAdapter = adapter
}
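
Because the adapter is cached at module level, tests can swap storage without touching Redis or Postgres. A hypothetical setup using the factory's override hooks (the commit's own tests instead inject a mock via `new RateLimiter(adapter)`):

```typescript
// Hypothetical test setup; `InMemoryTokenBucket` is the illustrative adapter
// sketched earlier, not part of this commit.
import { beforeEach } from 'vitest'
import { resetStorageAdapter, setStorageAdapter } from './factory'
import { InMemoryTokenBucket } from './in-memory-token-bucket' // assumed path

beforeEach(() => {
  resetStorageAdapter()
  setStorageAdapter(new InMemoryTokenBucket())
})
```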
9 apps/sim/lib/core/rate-limiter/storage/index.ts (Normal file)
@@ -0,0 +1,9 @@
export type {
  ConsumeResult,
  RateLimitStorageAdapter,
  TokenBucketConfig,
  TokenStatus,
} from './adapter'
export { DbTokenBucket } from './db-token-bucket'
export { createStorageAdapter, resetStorageAdapter, setStorageAdapter } from './factory'
export { RedisTokenBucket } from './redis-token-bucket'
135
apps/sim/lib/core/rate-limiter/storage/redis-token-bucket.ts
Normal file
135
apps/sim/lib/core/rate-limiter/storage/redis-token-bucket.ts
Normal file
@@ -0,0 +1,135 @@
import type Redis from 'ioredis'
import type {
  ConsumeResult,
  RateLimitStorageAdapter,
  TokenBucketConfig,
  TokenStatus,
} from './adapter'

const CONSUME_SCRIPT = `
local key = KEYS[1]
local now = tonumber(ARGV[1])
local requested = tonumber(ARGV[2])
local maxTokens = tonumber(ARGV[3])
local refillRate = tonumber(ARGV[4])
local refillIntervalMs = tonumber(ARGV[5])
local ttl = tonumber(ARGV[6])

local bucket = redis.call('HMGET', key, 'tokens', 'lastRefillAt')
local tokens = tonumber(bucket[1])
local lastRefillAt = tonumber(bucket[2])

if tokens == nil then
  tokens = maxTokens
  lastRefillAt = now
end

local elapsed = now - lastRefillAt
local intervalsElapsed = math.floor(elapsed / refillIntervalMs)
if intervalsElapsed > 0 then
  tokens = math.min(maxTokens, tokens + (intervalsElapsed * refillRate))
  lastRefillAt = lastRefillAt + (intervalsElapsed * refillIntervalMs)
end

local allowed = 0
if tokens >= requested then
  tokens = tokens - requested
  allowed = 1
end

redis.call('HSET', key, 'tokens', tokens, 'lastRefillAt', lastRefillAt)
redis.call('EXPIRE', key, ttl)

local nextRefillAt = lastRefillAt + refillIntervalMs

return {allowed, tokens, lastRefillAt, nextRefillAt}
`

const STATUS_SCRIPT = `
local key = KEYS[1]
local now = tonumber(ARGV[1])
local maxTokens = tonumber(ARGV[2])
local refillRate = tonumber(ARGV[3])
local refillIntervalMs = tonumber(ARGV[4])

local bucket = redis.call('HMGET', key, 'tokens', 'lastRefillAt')
local tokens = tonumber(bucket[1])
local lastRefillAt = tonumber(bucket[2])

if tokens == nil then
  tokens = maxTokens
  lastRefillAt = now
end

local elapsed = now - lastRefillAt
local intervalsElapsed = math.floor(elapsed / refillIntervalMs)
if intervalsElapsed > 0 then
  tokens = math.min(maxTokens, tokens + (intervalsElapsed * refillRate))
  lastRefillAt = lastRefillAt + (intervalsElapsed * refillIntervalMs)
end

local nextRefillAt = lastRefillAt + refillIntervalMs

return {tokens, maxTokens, lastRefillAt, nextRefillAt}
`

export class RedisTokenBucket implements RateLimitStorageAdapter {
  constructor(private redis: Redis) {}

  async consumeTokens(
    key: string,
    tokens: number,
    config: TokenBucketConfig
  ): Promise<ConsumeResult> {
    const now = Date.now()
    const ttl = Math.ceil((config.refillIntervalMs * 2) / 1000)

    const result = (await this.redis.eval(
      CONSUME_SCRIPT,
      1,
      `ratelimit:tb:${key}`,
      now,
      tokens,
      config.maxTokens,
      config.refillRate,
      config.refillIntervalMs,
      ttl
    )) as [number, number, number, number]

    const [allowed, remaining, , nextRefill] = result

    return {
      allowed: allowed === 1,
      tokensRemaining: remaining,
      resetAt: new Date(nextRefill),
      retryAfterMs: allowed === 1 ? undefined : Math.max(0, nextRefill - now),
    }
  }

  async getTokenStatus(key: string, config: TokenBucketConfig): Promise<TokenStatus> {
    const now = Date.now()

    const result = (await this.redis.eval(
      STATUS_SCRIPT,
      1,
      `ratelimit:tb:${key}`,
      now,
      config.maxTokens,
      config.refillRate,
      config.refillIntervalMs
    )) as [number, number, number, number]

    const [tokensAvailable, maxTokens, lastRefillAt, nextRefillAt] = result

    return {
      tokensAvailable,
      maxTokens,
      lastRefillAt: new Date(lastRefillAt),
      nextRefillAt: new Date(nextRefillAt),
    }
  }

  async resetBucket(key: string): Promise<void> {
    await this.redis.del(`ratelimit:tb:${key}`)
  }
}
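Both Lua scripts execute atomically inside Redis, so the read, refill, and decrement steps cannot interleave across concurrent requests, and the EXPIRE of roughly two refill intervals lets idle buckets age out on their own. As a hedged sketch of how a caller might surface a ConsumeResult to clients (the helper is hypothetical; only the ConsumeResult fields are taken from this change):

import type { ConsumeResult } from '@/lib/core/rate-limiter/storage'

// Hypothetical helper: maps a consume attempt onto conventional rate-limit headers.
function rateLimitHeaders(result: ConsumeResult): Record<string, string> {
  const headers: Record<string, string> = {
    'X-RateLimit-Remaining': String(Math.floor(result.tokensRemaining)),
    'X-RateLimit-Reset': result.resetAt.toISOString(),
  }
  if (!result.allowed && result.retryAfterMs !== undefined) {
    headers['Retry-After'] = String(Math.ceil(result.retryAfterMs / 1000)) // whole seconds
  }
  return headers
}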
@@ -1,56 +1,53 @@
import type { userRateLimits } from '@sim/db/schema'
import type { InferSelectModel } from 'drizzle-orm'
import { env } from '@/lib/core/config/env'
import type { TokenBucketConfig } from './storage'

// Database types
export type UserRateLimit = InferSelectModel<typeof userRateLimits>

// Trigger types for rate limiting
export type TriggerType = 'api' | 'webhook' | 'schedule' | 'manual' | 'chat' | 'api-endpoint'

// Rate limit counter types - which counter to increment in the database
export type RateLimitCounterType = 'sync' | 'async' | 'api-endpoint'

// Subscription plan types
export type SubscriptionPlan = 'free' | 'pro' | 'team' | 'enterprise'

// Rate limit configuration (applies to all non-manual trigger types: api, webhook, schedule, chat, api-endpoint)
export interface RateLimitConfig {
  syncApiExecutionsPerMinute: number
  asyncApiExecutionsPerMinute: number
  apiEndpointRequestsPerMinute: number // For external API endpoints like /api/v1/logs
  sync: TokenBucketConfig
  async: TokenBucketConfig
  apiEndpoint: TokenBucketConfig
}

// Rate limit window duration in milliseconds
export const RATE_LIMIT_WINDOW_MS = Number.parseInt(env.RATE_LIMIT_WINDOW_MS) || 60000

// Manual execution bypass value (effectively unlimited)
export const MANUAL_EXECUTION_LIMIT = Number.parseInt(env.MANUAL_EXECUTION_LIMIT) || 999999

function createBucketConfig(ratePerMinute: number, burstMultiplier = 2): TokenBucketConfig {
  return {
    maxTokens: ratePerMinute * burstMultiplier,
    refillRate: ratePerMinute,
    refillIntervalMs: RATE_LIMIT_WINDOW_MS,
  }
}

export const RATE_LIMITS: Record<SubscriptionPlan, RateLimitConfig> = {
  free: {
    syncApiExecutionsPerMinute: Number.parseInt(env.RATE_LIMIT_FREE_SYNC) || 10,
    asyncApiExecutionsPerMinute: Number.parseInt(env.RATE_LIMIT_FREE_ASYNC) || 50,
    apiEndpointRequestsPerMinute: 10,
    sync: createBucketConfig(Number.parseInt(env.RATE_LIMIT_FREE_SYNC) || 10),
    async: createBucketConfig(Number.parseInt(env.RATE_LIMIT_FREE_ASYNC) || 50),
    apiEndpoint: createBucketConfig(10),
  },
  pro: {
    syncApiExecutionsPerMinute: Number.parseInt(env.RATE_LIMIT_PRO_SYNC) || 25,
    asyncApiExecutionsPerMinute: Number.parseInt(env.RATE_LIMIT_PRO_ASYNC) || 200,
    apiEndpointRequestsPerMinute: 30,
    sync: createBucketConfig(Number.parseInt(env.RATE_LIMIT_PRO_SYNC) || 25),
    async: createBucketConfig(Number.parseInt(env.RATE_LIMIT_PRO_ASYNC) || 200),
    apiEndpoint: createBucketConfig(30),
  },
  team: {
    syncApiExecutionsPerMinute: Number.parseInt(env.RATE_LIMIT_TEAM_SYNC) || 75,
    asyncApiExecutionsPerMinute: Number.parseInt(env.RATE_LIMIT_TEAM_ASYNC) || 500,
    apiEndpointRequestsPerMinute: 60,
    sync: createBucketConfig(Number.parseInt(env.RATE_LIMIT_TEAM_SYNC) || 75),
    async: createBucketConfig(Number.parseInt(env.RATE_LIMIT_TEAM_ASYNC) || 500),
    apiEndpoint: createBucketConfig(60),
  },
  enterprise: {
    syncApiExecutionsPerMinute: Number.parseInt(env.RATE_LIMIT_ENTERPRISE_SYNC) || 150,
    asyncApiExecutionsPerMinute: Number.parseInt(env.RATE_LIMIT_ENTERPRISE_ASYNC) || 1000,
    apiEndpointRequestsPerMinute: 120,
    sync: createBucketConfig(Number.parseInt(env.RATE_LIMIT_ENTERPRISE_SYNC) || 150),
    async: createBucketConfig(Number.parseInt(env.RATE_LIMIT_ENTERPRISE_ASYNC) || 1000),
    apiEndpoint: createBucketConfig(120),
  },
}

// Custom error for rate limits
export class RateLimitError extends Error {
  statusCode: number
  constructor(message: string, statusCode = 429) {
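A quick worked check on createBucketConfig with its default burst multiplier of 2, using the fallback values shown above (env overrides change the numbers, not the shape); the object literal below just spells out what createBucketConfig(10) evaluates to:

// Worked example, values assumed equal to the defaults above:
// ratePerMinute = 10, burstMultiplier = 2, RATE_LIMIT_WINDOW_MS = 60000
const freeSync = {
  maxTokens: 10 * 2, // 20-token burst ceiling
  refillRate: 10, // 10 tokens refilled per interval
  refillIntervalMs: 60_000, // one-minute refill window
}
console.log(freeSync) // matches createBucketConfig(10) in the config above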
packages/db/migrations/0119_far_lethal_legion.sql (new file, 8 lines)
@@ -0,0 +1,8 @@
CREATE TABLE "rate_limit_bucket" (
	"key" text PRIMARY KEY NOT NULL,
	"tokens" numeric NOT NULL,
	"last_refill_at" timestamp NOT NULL,
	"updated_at" timestamp DEFAULT now() NOT NULL
);
--> statement-breakpoint
DROP TABLE "user_rate_limits" CASCADE;
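Two details worth noting in this migration: "tokens" is declared numeric so fractional balances can carry over between refills, and the old fixed-window counter table is replaced by a single bucket row per key, which gives the PostgreSQL adapter the same tokens/lastRefillAt shape that the Redis adapter keeps in its hash.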
packages/db/migrations/meta/0119_snapshot.json (new file, 7759 lines)
File diff suppressed because it is too large
@@ -827,6 +827,13 @@
      "when": 1765231535125,
      "tag": "0118_tiresome_landau",
      "breakpoints": true
    },
    {
      "idx": 119,
      "version": "7",
      "when": 1765271011445,
      "tag": "0119_far_lethal_legion",
      "breakpoints": true
    }
  ]
}
@@ -718,15 +718,11 @@ export const subscription = pgTable(
  })
)

export const userRateLimits = pgTable('user_rate_limits', {
  referenceId: text('reference_id').primaryKey(), // Can be userId or organizationId for pooling
  syncApiRequests: integer('sync_api_requests').notNull().default(0), // Sync API requests counter
  asyncApiRequests: integer('async_api_requests').notNull().default(0), // Async API requests counter
  apiEndpointRequests: integer('api_endpoint_requests').notNull().default(0), // External API endpoint requests counter
  windowStart: timestamp('window_start').notNull().defaultNow(),
  lastRequestAt: timestamp('last_request_at').notNull().defaultNow(),
  isRateLimited: boolean('is_rate_limited').notNull().default(false),
  rateLimitResetAt: timestamp('rate_limit_reset_at'),
export const rateLimitBucket = pgTable('rate_limit_bucket', {
  key: text('key').primaryKey(),
  tokens: decimal('tokens').notNull(),
  lastRefillAt: timestamp('last_refill_at').notNull(),
  updatedAt: timestamp('updated_at').notNull().defaultNow(),
})

export const chat = pgTable(