mirror of
https://github.com/simstudioai/sim.git
synced 2026-03-15 03:00:33 -04:00
feat(hosted key): Add exa hosted key (#3221)
* feat(hosted keys): Implement serper hosted key * Handle required fields correctly for hosted keys * Add rate limiting (3 tries, exponential backoff) * Add custom pricing, switch to exa as first hosted key * Add telemetry * Consolidate byok type definitions * Add warning comment if default calculation is used * Record usage to user stats table * Fix unit tests, use cost property * Include more metadata in cost output * Fix disabled tests * Fix spacing * Fix lint * Move knowledge cost restructuring away from generic block handler * Migrate knowledge unit tests * Lint * Fix broken tests * Add user based hosted key throttling * Refactor hosted key handling. Add optimistic handling of throttling for custom throttle rules. * Remove research as hosted key. Recommend BYOK if throtttling occurs * Make adding api keys adjustable via env vars * Remove vestigial fields from research * Make billing actor id required for throttling * Switch to round robin for api key distribution * Add helper method for adding hosted key cost * Strip leading double underscores to avoid breaking change * Lint fix * Remove falsy check in favor for explicit null check * Add more detailed metrics for different throttling types * Fix _costDollars field * Handle hosted agent tool calls * Fail loudly if cost field isn't found * Remove any type * Fix type error * Fix lint * Fix usage log double logging data * Fix test --------- Co-authored-by: Theodore Li <teddy@zenobiapay.com>
This commit is contained in:
@@ -13,7 +13,7 @@ import { getUserEntityPermissions, getWorkspaceById } from '@/lib/workspaces/per
|
||||
|
||||
const logger = createLogger('WorkspaceBYOKKeysAPI')
|
||||
|
||||
const VALID_PROVIDERS = ['openai', 'anthropic', 'google', 'mistral'] as const
|
||||
const VALID_PROVIDERS = ['openai', 'anthropic', 'google', 'mistral', 'exa'] as const
|
||||
|
||||
const UpsertKeySchema = z.object({
|
||||
providerId: z.enum(VALID_PROVIDERS),
|
||||
|
||||
@@ -3,6 +3,7 @@ import {
|
||||
buildCanonicalIndex,
|
||||
evaluateSubBlockCondition,
|
||||
isSubBlockFeatureEnabled,
|
||||
isSubBlockHiddenByHostedKey,
|
||||
isSubBlockVisibleForMode,
|
||||
} from '@/lib/workflows/subblocks/visibility'
|
||||
import type { BlockConfig, SubBlockConfig, SubBlockType } from '@/blocks/types'
|
||||
@@ -108,6 +109,9 @@ export function useEditorSubblockLayout(
|
||||
// Check required feature if specified - declarative feature gating
|
||||
if (!isSubBlockFeatureEnabled(block)) return false
|
||||
|
||||
// Hide tool API key fields when hosted
|
||||
if (isSubBlockHiddenByHostedKey(block)) return false
|
||||
|
||||
// Special handling for trigger-config type (legacy trigger configuration UI)
|
||||
if (block.type === ('trigger-config' as SubBlockType)) {
|
||||
const isPureTriggerBlock = config?.triggers?.enabled && config.category === 'triggers'
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
evaluateSubBlockCondition,
|
||||
hasAdvancedValues,
|
||||
isSubBlockFeatureEnabled,
|
||||
isSubBlockHiddenByHostedKey,
|
||||
isSubBlockVisibleForMode,
|
||||
resolveDependencyValue,
|
||||
} from '@/lib/workflows/subblocks/visibility'
|
||||
@@ -977,6 +978,7 @@ export const WorkflowBlock = memo(function WorkflowBlock({
|
||||
if (block.hidden) return false
|
||||
if (block.hideFromPreview) return false
|
||||
if (!isSubBlockFeatureEnabled(block)) return false
|
||||
if (isSubBlockHiddenByHostedKey(block)) return false
|
||||
|
||||
const isPureTriggerBlock = config?.triggers?.enabled && config.category === 'triggers'
|
||||
|
||||
|
||||
@@ -13,15 +13,15 @@ import {
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
} from '@/components/emcn'
|
||||
import { AnthropicIcon, GeminiIcon, MistralIcon, OpenAIIcon } from '@/components/icons'
|
||||
import { AnthropicIcon, ExaAIIcon, GeminiIcon, MistralIcon, OpenAIIcon } from '@/components/icons'
|
||||
import { Skeleton } from '@/components/ui'
|
||||
import {
|
||||
type BYOKKey,
|
||||
type BYOKProviderId,
|
||||
useBYOKKeys,
|
||||
useDeleteBYOKKey,
|
||||
useUpsertBYOKKey,
|
||||
} from '@/hooks/queries/byok-keys'
|
||||
import type { BYOKProviderId } from '@/tools/types'
|
||||
|
||||
const logger = createLogger('BYOKSettings')
|
||||
|
||||
@@ -60,6 +60,13 @@ const PROVIDERS: {
|
||||
description: 'LLM calls and Knowledge Base OCR',
|
||||
placeholder: 'Enter your API key',
|
||||
},
|
||||
{
|
||||
id: 'exa',
|
||||
name: 'Exa',
|
||||
icon: ExaAIIcon,
|
||||
description: 'AI-powered search and research',
|
||||
placeholder: 'Enter your Exa API key',
|
||||
},
|
||||
]
|
||||
|
||||
function BYOKKeySkeleton() {
|
||||
|
||||
@@ -309,7 +309,7 @@ export const ExaBlock: BlockConfig<ExaResponse> = {
|
||||
value: () => 'exa-research',
|
||||
condition: { field: 'operation', value: 'exa_research' },
|
||||
},
|
||||
// API Key (common)
|
||||
// API Key — hidden when hosted for operations with hosted key support
|
||||
{
|
||||
id: 'apiKey',
|
||||
title: 'API Key',
|
||||
@@ -317,6 +317,18 @@ export const ExaBlock: BlockConfig<ExaResponse> = {
|
||||
placeholder: 'Enter your Exa API key',
|
||||
password: true,
|
||||
required: true,
|
||||
hideWhenHosted: true,
|
||||
condition: { field: 'operation', value: 'exa_research', not: true },
|
||||
},
|
||||
// API Key — always visible for research (no hosted key support)
|
||||
{
|
||||
id: 'apiKey',
|
||||
title: 'API Key',
|
||||
type: 'short-input',
|
||||
placeholder: 'Enter your Exa API key',
|
||||
password: true,
|
||||
required: true,
|
||||
condition: { field: 'operation', value: 'exa_research' },
|
||||
},
|
||||
],
|
||||
tools: {
|
||||
|
||||
@@ -253,6 +253,7 @@ export interface SubBlockConfig {
|
||||
hidden?: boolean
|
||||
hideFromPreview?: boolean // Hide this subblock from the workflow block preview
|
||||
requiresFeature?: string // Environment variable name that must be truthy for this subblock to be visible
|
||||
hideWhenHosted?: boolean // Hide this subblock when running on hosted sim
|
||||
description?: string
|
||||
tooltip?: string // Tooltip text displayed via info icon next to the title
|
||||
value?: (params: Record<string, any>) => string
|
||||
|
||||
@@ -147,219 +147,4 @@ describe('GenericBlockHandler', () => {
|
||||
'Block execution of Some Custom Tool failed with no error message'
|
||||
)
|
||||
})
|
||||
|
||||
describe('Knowledge block cost tracking', () => {
|
||||
beforeEach(() => {
|
||||
// Set up knowledge block mock
|
||||
mockBlock = {
|
||||
...mockBlock,
|
||||
config: { tool: 'knowledge_search', params: {} },
|
||||
}
|
||||
|
||||
mockTool = {
|
||||
...mockTool,
|
||||
id: 'knowledge_search',
|
||||
name: 'Knowledge Search',
|
||||
}
|
||||
|
||||
mockGetTool.mockImplementation((toolId) => {
|
||||
if (toolId === 'knowledge_search') {
|
||||
return mockTool
|
||||
}
|
||||
return undefined
|
||||
})
|
||||
})
|
||||
|
||||
it.concurrent(
|
||||
'should extract and restructure cost information from knowledge tools',
|
||||
async () => {
|
||||
const inputs = { query: 'test query' }
|
||||
const mockToolResponse = {
|
||||
success: true,
|
||||
output: {
|
||||
results: [],
|
||||
query: 'test query',
|
||||
totalResults: 0,
|
||||
cost: {
|
||||
input: 0.00001042,
|
||||
output: 0,
|
||||
total: 0.00001042,
|
||||
tokens: {
|
||||
input: 521,
|
||||
output: 0,
|
||||
total: 521,
|
||||
},
|
||||
model: 'text-embedding-3-small',
|
||||
pricing: {
|
||||
input: 0.02,
|
||||
output: 0,
|
||||
updatedAt: '2025-07-10',
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockExecuteTool.mockResolvedValue(mockToolResponse)
|
||||
|
||||
const result = await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
// Verify cost information is restructured correctly for enhanced logging
|
||||
expect(result).toEqual({
|
||||
results: [],
|
||||
query: 'test query',
|
||||
totalResults: 0,
|
||||
cost: {
|
||||
input: 0.00001042,
|
||||
output: 0,
|
||||
total: 0.00001042,
|
||||
},
|
||||
tokens: {
|
||||
input: 521,
|
||||
output: 0,
|
||||
total: 521,
|
||||
},
|
||||
model: 'text-embedding-3-small',
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
it.concurrent('should handle knowledge_upload_chunk cost information', async () => {
|
||||
// Update to upload_chunk tool
|
||||
mockBlock.config.tool = 'knowledge_upload_chunk'
|
||||
mockTool.id = 'knowledge_upload_chunk'
|
||||
mockTool.name = 'Knowledge Upload Chunk'
|
||||
|
||||
mockGetTool.mockImplementation((toolId) => {
|
||||
if (toolId === 'knowledge_upload_chunk') {
|
||||
return mockTool
|
||||
}
|
||||
return undefined
|
||||
})
|
||||
|
||||
const inputs = { content: 'test content' }
|
||||
const mockToolResponse = {
|
||||
success: true,
|
||||
output: {
|
||||
data: {
|
||||
id: 'chunk-123',
|
||||
content: 'test content',
|
||||
chunkIndex: 0,
|
||||
},
|
||||
message: 'Successfully uploaded chunk',
|
||||
documentId: 'doc-123',
|
||||
cost: {
|
||||
input: 0.00000521,
|
||||
output: 0,
|
||||
total: 0.00000521,
|
||||
tokens: {
|
||||
input: 260,
|
||||
output: 0,
|
||||
total: 260,
|
||||
},
|
||||
model: 'text-embedding-3-small',
|
||||
pricing: {
|
||||
input: 0.02,
|
||||
output: 0,
|
||||
updatedAt: '2025-07-10',
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockExecuteTool.mockResolvedValue(mockToolResponse)
|
||||
|
||||
const result = await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
// Verify cost information is restructured correctly
|
||||
expect(result).toEqual({
|
||||
data: {
|
||||
id: 'chunk-123',
|
||||
content: 'test content',
|
||||
chunkIndex: 0,
|
||||
},
|
||||
message: 'Successfully uploaded chunk',
|
||||
documentId: 'doc-123',
|
||||
cost: {
|
||||
input: 0.00000521,
|
||||
output: 0,
|
||||
total: 0.00000521,
|
||||
},
|
||||
tokens: {
|
||||
input: 260,
|
||||
output: 0,
|
||||
total: 260,
|
||||
},
|
||||
model: 'text-embedding-3-small',
|
||||
})
|
||||
})
|
||||
|
||||
it('should pass through output unchanged for knowledge tools without cost info', async () => {
|
||||
const inputs = { query: 'test query' }
|
||||
const mockToolResponse = {
|
||||
success: true,
|
||||
output: {
|
||||
results: [],
|
||||
query: 'test query',
|
||||
totalResults: 0,
|
||||
// No cost information
|
||||
},
|
||||
}
|
||||
|
||||
mockExecuteTool.mockResolvedValue(mockToolResponse)
|
||||
|
||||
const result = await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
// Should return original output without cost transformation
|
||||
expect(result).toEqual({
|
||||
results: [],
|
||||
query: 'test query',
|
||||
totalResults: 0,
|
||||
})
|
||||
})
|
||||
|
||||
it.concurrent(
|
||||
'should process cost info for all tools (universal cost extraction)',
|
||||
async () => {
|
||||
mockBlock.config.tool = 'some_other_tool'
|
||||
mockTool.id = 'some_other_tool'
|
||||
|
||||
mockGetTool.mockImplementation((toolId) => {
|
||||
if (toolId === 'some_other_tool') {
|
||||
return mockTool
|
||||
}
|
||||
return undefined
|
||||
})
|
||||
|
||||
const inputs = { param: 'value' }
|
||||
const mockToolResponse = {
|
||||
success: true,
|
||||
output: {
|
||||
result: 'success',
|
||||
cost: {
|
||||
input: 0.001,
|
||||
output: 0.002,
|
||||
total: 0.003,
|
||||
tokens: { input: 100, output: 50, total: 150 },
|
||||
model: 'some-model',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockExecuteTool.mockResolvedValue(mockToolResponse)
|
||||
|
||||
const result = await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
expect(result).toEqual({
|
||||
result: 'success',
|
||||
cost: {
|
||||
input: 0.001,
|
||||
output: 0.002,
|
||||
total: 0.003,
|
||||
},
|
||||
tokens: { input: 100, output: 50, total: 150 },
|
||||
model: 'some-model',
|
||||
})
|
||||
}
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -98,27 +98,7 @@ export class GenericBlockHandler implements BlockHandler {
|
||||
throw error
|
||||
}
|
||||
|
||||
const output = result.output
|
||||
let cost = null
|
||||
|
||||
if (output?.cost) {
|
||||
cost = output.cost
|
||||
}
|
||||
|
||||
if (cost) {
|
||||
return {
|
||||
...output,
|
||||
cost: {
|
||||
input: cost.input,
|
||||
output: cost.output,
|
||||
total: cost.total,
|
||||
},
|
||||
tokens: cost.tokens,
|
||||
model: cost.model,
|
||||
}
|
||||
}
|
||||
|
||||
return output
|
||||
return result.output
|
||||
} catch (error: any) {
|
||||
if (!error.message || error.message === 'undefined (undefined)') {
|
||||
let errorMessage = `Block execution of ${tool?.name || block.config.tool} failed`
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { keepPreviousData, useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { API_ENDPOINTS } from '@/stores/constants'
|
||||
import type { BYOKProviderId } from '@/tools/types'
|
||||
|
||||
const logger = createLogger('BYOKKeysQueries')
|
||||
|
||||
export type BYOKProviderId = 'openai' | 'anthropic' | 'google' | 'mistral'
|
||||
|
||||
export interface BYOKKey {
|
||||
id: string
|
||||
providerId: BYOKProviderId
|
||||
|
||||
@@ -7,11 +7,10 @@ import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import { decryptSecret } from '@/lib/core/security/encryption'
|
||||
import { getHostedModels } from '@/providers/models'
|
||||
import { useProvidersStore } from '@/stores/providers/store'
|
||||
import type { BYOKProviderId } from '@/tools/types'
|
||||
|
||||
const logger = createLogger('BYOKKeys')
|
||||
|
||||
export type BYOKProviderId = 'openai' | 'anthropic' | 'google' | 'mistral'
|
||||
|
||||
export interface BYOKKeyResult {
|
||||
apiKey: string
|
||||
isBYOK: true
|
||||
|
||||
@@ -22,12 +22,13 @@ export type UsageLogSource = 'workflow' | 'wand' | 'copilot' | 'mcp_copilot'
|
||||
export interface ModelUsageMetadata {
|
||||
inputTokens: number
|
||||
outputTokens: number
|
||||
toolCost?: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Metadata for 'fixed' category charges (currently empty, extensible)
|
||||
* Metadata for 'fixed' category charges (e.g., tool cost breakdown)
|
||||
*/
|
||||
export type FixedUsageMetadata = Record<string, never>
|
||||
export type FixedUsageMetadata = Record<string, unknown>
|
||||
|
||||
/**
|
||||
* Union type for all metadata types
|
||||
@@ -44,6 +45,7 @@ export interface LogModelUsageParams {
|
||||
inputTokens: number
|
||||
outputTokens: number
|
||||
cost: number
|
||||
toolCost?: number
|
||||
workspaceId?: string
|
||||
workflowId?: string
|
||||
executionId?: string
|
||||
@@ -60,6 +62,8 @@ export interface LogFixedUsageParams {
|
||||
workspaceId?: string
|
||||
workflowId?: string
|
||||
executionId?: string
|
||||
/** Optional metadata (e.g., tool cost breakdown from API) */
|
||||
metadata?: FixedUsageMetadata
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -74,6 +78,7 @@ export async function logModelUsage(params: LogModelUsageParams): Promise<void>
|
||||
const metadata: ModelUsageMetadata = {
|
||||
inputTokens: params.inputTokens,
|
||||
outputTokens: params.outputTokens,
|
||||
...(params.toolCost != null && params.toolCost > 0 && { toolCost: params.toolCost }),
|
||||
}
|
||||
|
||||
await db.insert(usageLog).values({
|
||||
@@ -119,7 +124,7 @@ export async function logFixedUsage(params: LogFixedUsageParams): Promise<void>
|
||||
category: 'fixed',
|
||||
source: params.source,
|
||||
description: params.description,
|
||||
metadata: null,
|
||||
metadata: params.metadata ?? null,
|
||||
cost: params.cost.toString(),
|
||||
workspaceId: params.workspaceId ?? null,
|
||||
workflowId: params.workflowId ?? null,
|
||||
@@ -155,6 +160,7 @@ export interface LogWorkflowUsageBatchParams {
|
||||
{
|
||||
total: number
|
||||
tokens: { input: number; output: number }
|
||||
toolCost?: number
|
||||
}
|
||||
>
|
||||
}
|
||||
@@ -207,6 +213,8 @@ export async function logWorkflowUsageBatch(params: LogWorkflowUsageBatchParams)
|
||||
metadata: {
|
||||
inputTokens: modelData.tokens.input,
|
||||
outputTokens: modelData.tokens.output,
|
||||
...(modelData.toolCost != null &&
|
||||
modelData.toolCost > 0 && { toolCost: modelData.toolCost }),
|
||||
},
|
||||
cost: modelData.total.toString(),
|
||||
workspaceId: params.workspaceId ?? null,
|
||||
|
||||
@@ -6,19 +6,18 @@ import type {
|
||||
ToolCallResult,
|
||||
ToolCallState,
|
||||
} from '@/lib/copilot/orchestrator/types'
|
||||
import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { getEffectiveDecryptedEnv } from '@/lib/environment/utils'
|
||||
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
|
||||
import { resolveEnvVarReferences } from '@/executor/utils/reference-validation'
|
||||
import { executeTool } from '@/tools'
|
||||
import type { ToolConfig } from '@/tools/types'
|
||||
import { resolveToolId } from '@/tools/utils'
|
||||
|
||||
export async function executeIntegrationToolDirect(
|
||||
toolCall: ToolCallState,
|
||||
toolConfig: {
|
||||
oauth?: { required?: boolean; provider?: string }
|
||||
params?: { apiKey?: { required?: boolean } }
|
||||
},
|
||||
toolConfig: ToolConfig,
|
||||
context: ExecutionContext
|
||||
): Promise<ToolCallResult> {
|
||||
const { userId, workflowId } = context
|
||||
@@ -74,7 +73,8 @@ export async function executeIntegrationToolDirect(
|
||||
executionParams.accessToken = accessToken
|
||||
}
|
||||
|
||||
if (toolConfig.params?.apiKey?.required && !executionParams.apiKey) {
|
||||
const hasHostedKeySupport = isHosted && !!toolConfig.hosting
|
||||
if (toolConfig.params?.apiKey?.required && !executionParams.apiKey && !hasHostedKeySupport) {
|
||||
return {
|
||||
success: false,
|
||||
error: `API key not provided for ${toolName}. Use {{YOUR_API_KEY_ENV_VAR}} to reference your environment variable.`,
|
||||
@@ -83,6 +83,7 @@ export async function executeIntegrationToolDirect(
|
||||
|
||||
executionParams._context = {
|
||||
workflowId,
|
||||
workspaceId,
|
||||
userId,
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,521 @@
|
||||
import { loggerMock } from '@sim/testing'
|
||||
import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest'
|
||||
import type {
|
||||
ConsumeResult,
|
||||
RateLimitStorageAdapter,
|
||||
TokenStatus,
|
||||
} from '@/lib/core/rate-limiter/storage'
|
||||
import { HostedKeyRateLimiter } from './hosted-key-rate-limiter'
|
||||
import type { CustomRateLimit, PerRequestRateLimit } from './types'
|
||||
|
||||
vi.mock('@sim/logger', () => loggerMock)
|
||||
|
||||
interface MockAdapter {
|
||||
consumeTokens: Mock
|
||||
getTokenStatus: Mock
|
||||
resetBucket: Mock
|
||||
}
|
||||
|
||||
const createMockAdapter = (): MockAdapter => ({
|
||||
consumeTokens: vi.fn(),
|
||||
getTokenStatus: vi.fn(),
|
||||
resetBucket: vi.fn(),
|
||||
})
|
||||
|
||||
describe('HostedKeyRateLimiter', () => {
|
||||
const testProvider = 'exa'
|
||||
const envKeyPrefix = 'EXA_API_KEY'
|
||||
let mockAdapter: MockAdapter
|
||||
let rateLimiter: HostedKeyRateLimiter
|
||||
let originalEnv: NodeJS.ProcessEnv
|
||||
|
||||
const perRequestRateLimit: PerRequestRateLimit = {
|
||||
mode: 'per_request',
|
||||
requestsPerMinute: 10,
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockAdapter = createMockAdapter()
|
||||
rateLimiter = new HostedKeyRateLimiter(mockAdapter as RateLimitStorageAdapter)
|
||||
|
||||
originalEnv = { ...process.env }
|
||||
process.env.EXA_API_KEY_COUNT = '3'
|
||||
process.env.EXA_API_KEY_1 = 'test-key-1'
|
||||
process.env.EXA_API_KEY_2 = 'test-key-2'
|
||||
process.env.EXA_API_KEY_3 = 'test-key-3'
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv
|
||||
})
|
||||
|
||||
describe('acquireKey', () => {
|
||||
it('should return error when no keys are configured', async () => {
|
||||
const allowedResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 9,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(allowedResult)
|
||||
|
||||
process.env.EXA_API_KEY_COUNT = undefined
|
||||
process.env.EXA_API_KEY_1 = undefined
|
||||
process.env.EXA_API_KEY_2 = undefined
|
||||
process.env.EXA_API_KEY_3 = undefined
|
||||
|
||||
const result = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
perRequestRateLimit,
|
||||
'workspace-1'
|
||||
)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.error).toContain('No hosted keys configured')
|
||||
})
|
||||
|
||||
it('should rate limit billing actor when they exceed their limit', async () => {
|
||||
const rateLimitedResult: ConsumeResult = {
|
||||
allowed: false,
|
||||
tokensRemaining: 0,
|
||||
resetAt: new Date(Date.now() + 30000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(rateLimitedResult)
|
||||
|
||||
const result = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
perRequestRateLimit,
|
||||
'workspace-123'
|
||||
)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.billingActorRateLimited).toBe(true)
|
||||
expect(result.retryAfterMs).toBeDefined()
|
||||
expect(result.error).toContain('Rate limit exceeded')
|
||||
})
|
||||
|
||||
it('should allow billing actor within their rate limit', async () => {
|
||||
const allowedResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 9,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(allowedResult)
|
||||
|
||||
const result = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
perRequestRateLimit,
|
||||
'workspace-123'
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.billingActorRateLimited).toBeUndefined()
|
||||
expect(result.key).toBe('test-key-1')
|
||||
})
|
||||
|
||||
it('should distribute requests across keys round-robin style', async () => {
|
||||
const allowedResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 9,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(allowedResult)
|
||||
|
||||
const r1 = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
perRequestRateLimit,
|
||||
'workspace-1'
|
||||
)
|
||||
const r2 = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
perRequestRateLimit,
|
||||
'workspace-2'
|
||||
)
|
||||
const r3 = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
perRequestRateLimit,
|
||||
'workspace-3'
|
||||
)
|
||||
const r4 = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
perRequestRateLimit,
|
||||
'workspace-4'
|
||||
)
|
||||
|
||||
expect(r1.keyIndex).toBe(0)
|
||||
expect(r2.keyIndex).toBe(1)
|
||||
expect(r3.keyIndex).toBe(2)
|
||||
expect(r4.keyIndex).toBe(0) // Wraps back
|
||||
})
|
||||
|
||||
it('should handle partial key availability', async () => {
|
||||
const allowedResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 9,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(allowedResult)
|
||||
|
||||
process.env.EXA_API_KEY_2 = undefined
|
||||
|
||||
const result = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
perRequestRateLimit,
|
||||
'workspace-1'
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.key).toBe('test-key-1')
|
||||
expect(result.envVarName).toBe('EXA_API_KEY_1')
|
||||
|
||||
const r2 = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
perRequestRateLimit,
|
||||
'workspace-2'
|
||||
)
|
||||
expect(r2.keyIndex).toBe(2) // Skips missing key 1
|
||||
expect(r2.envVarName).toBe('EXA_API_KEY_3')
|
||||
})
|
||||
})
|
||||
|
||||
describe('acquireKey with custom rate limit', () => {
|
||||
const customRateLimit: CustomRateLimit = {
|
||||
mode: 'custom',
|
||||
requestsPerMinute: 5,
|
||||
dimensions: [
|
||||
{
|
||||
name: 'tokens',
|
||||
limitPerMinute: 1000,
|
||||
extractUsage: (_params, response) => (response.tokenCount as number) ?? 0,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
it('should enforce requestsPerMinute for custom mode', async () => {
|
||||
const rateLimitedResult: ConsumeResult = {
|
||||
allowed: false,
|
||||
tokensRemaining: 0,
|
||||
resetAt: new Date(Date.now() + 30000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(rateLimitedResult)
|
||||
|
||||
const result = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
customRateLimit,
|
||||
'workspace-1'
|
||||
)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.billingActorRateLimited).toBe(true)
|
||||
expect(result.error).toContain('Rate limit exceeded')
|
||||
})
|
||||
|
||||
it('should allow request when actor request limit and dimensions have budget', async () => {
|
||||
const allowedConsume: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 4,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(allowedConsume)
|
||||
|
||||
const budgetAvailable: TokenStatus = {
|
||||
tokensAvailable: 500,
|
||||
maxTokens: 2000,
|
||||
lastRefillAt: new Date(),
|
||||
nextRefillAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.getTokenStatus.mockResolvedValue(budgetAvailable)
|
||||
|
||||
const result = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
customRateLimit,
|
||||
'workspace-1'
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.key).toBe('test-key-1')
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalledTimes(1)
|
||||
expect(mockAdapter.getTokenStatus).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should block request when a dimension is depleted', async () => {
|
||||
const allowedConsume: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 4,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(allowedConsume)
|
||||
|
||||
const depleted: TokenStatus = {
|
||||
tokensAvailable: 0,
|
||||
maxTokens: 2000,
|
||||
lastRefillAt: new Date(),
|
||||
nextRefillAt: new Date(Date.now() + 45000),
|
||||
}
|
||||
mockAdapter.getTokenStatus.mockResolvedValue(depleted)
|
||||
|
||||
const result = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
customRateLimit,
|
||||
'workspace-1'
|
||||
)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.billingActorRateLimited).toBe(true)
|
||||
expect(result.error).toContain('tokens')
|
||||
})
|
||||
|
||||
it('should pre-check all dimensions and block on first depleted one', async () => {
|
||||
const multiDimensionConfig: CustomRateLimit = {
|
||||
mode: 'custom',
|
||||
requestsPerMinute: 10,
|
||||
dimensions: [
|
||||
{
|
||||
name: 'tokens',
|
||||
limitPerMinute: 1000,
|
||||
extractUsage: (_p, r) => (r.tokenCount as number) ?? 0,
|
||||
},
|
||||
{
|
||||
name: 'search_units',
|
||||
limitPerMinute: 50,
|
||||
extractUsage: (_p, r) => (r.searchUnits as number) ?? 0,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const allowedConsume: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 9,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(allowedConsume)
|
||||
|
||||
const tokensBudget: TokenStatus = {
|
||||
tokensAvailable: 500,
|
||||
maxTokens: 2000,
|
||||
lastRefillAt: new Date(),
|
||||
nextRefillAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
const searchUnitsDepleted: TokenStatus = {
|
||||
tokensAvailable: 0,
|
||||
maxTokens: 100,
|
||||
lastRefillAt: new Date(),
|
||||
nextRefillAt: new Date(Date.now() + 30000),
|
||||
}
|
||||
mockAdapter.getTokenStatus
|
||||
.mockResolvedValueOnce(tokensBudget)
|
||||
.mockResolvedValueOnce(searchUnitsDepleted)
|
||||
|
||||
const result = await rateLimiter.acquireKey(
|
||||
testProvider,
|
||||
envKeyPrefix,
|
||||
multiDimensionConfig,
|
||||
'workspace-1'
|
||||
)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.billingActorRateLimited).toBe(true)
|
||||
expect(result.error).toContain('search_units')
|
||||
})
|
||||
})
|
||||
|
||||
describe('reportUsage', () => {
|
||||
const customConfig: CustomRateLimit = {
|
||||
mode: 'custom',
|
||||
requestsPerMinute: 5,
|
||||
dimensions: [
|
||||
{
|
||||
name: 'tokens',
|
||||
limitPerMinute: 1000,
|
||||
extractUsage: (_params, response) => (response.tokenCount as number) ?? 0,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
it('should consume actual tokens from dimension bucket after execution', async () => {
|
||||
const consumeResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 850,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(consumeResult)
|
||||
|
||||
const result = await rateLimiter.reportUsage(
|
||||
testProvider,
|
||||
'workspace-1',
|
||||
customConfig,
|
||||
{},
|
||||
{ tokenCount: 150 }
|
||||
)
|
||||
|
||||
expect(result.dimensions).toHaveLength(1)
|
||||
expect(result.dimensions[0].name).toBe('tokens')
|
||||
expect(result.dimensions[0].consumed).toBe(150)
|
||||
expect(result.dimensions[0].allowed).toBe(true)
|
||||
expect(result.dimensions[0].tokensRemaining).toBe(850)
|
||||
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalledWith(
|
||||
'hosted:exa:actor:workspace-1:tokens',
|
||||
150,
|
||||
expect.objectContaining({ maxTokens: 2000, refillRate: 1000 })
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle overdrawn bucket gracefully (optimistic concurrency)', async () => {
|
||||
const overdrawnResult: ConsumeResult = {
|
||||
allowed: false,
|
||||
tokensRemaining: 0,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(overdrawnResult)
|
||||
|
||||
const result = await rateLimiter.reportUsage(
|
||||
testProvider,
|
||||
'workspace-1',
|
||||
customConfig,
|
||||
{},
|
||||
{ tokenCount: 500 }
|
||||
)
|
||||
|
||||
expect(result.dimensions[0].allowed).toBe(false)
|
||||
expect(result.dimensions[0].consumed).toBe(500)
|
||||
})
|
||||
|
||||
it('should skip consumption when extractUsage returns 0', async () => {
|
||||
const result = await rateLimiter.reportUsage(
|
||||
testProvider,
|
||||
'workspace-1',
|
||||
customConfig,
|
||||
{},
|
||||
{ tokenCount: 0 }
|
||||
)
|
||||
|
||||
expect(result.dimensions).toHaveLength(1)
|
||||
expect(result.dimensions[0].consumed).toBe(0)
|
||||
expect(mockAdapter.consumeTokens).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle multiple dimensions independently', async () => {
|
||||
const multiConfig: CustomRateLimit = {
|
||||
mode: 'custom',
|
||||
requestsPerMinute: 10,
|
||||
dimensions: [
|
||||
{
|
||||
name: 'tokens',
|
||||
limitPerMinute: 1000,
|
||||
extractUsage: (_p, r) => (r.tokenCount as number) ?? 0,
|
||||
},
|
||||
{
|
||||
name: 'search_units',
|
||||
limitPerMinute: 50,
|
||||
extractUsage: (_p, r) => (r.searchUnits as number) ?? 0,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const tokensConsumed: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 800,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
const searchConsumed: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 47,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens
|
||||
.mockResolvedValueOnce(tokensConsumed)
|
||||
.mockResolvedValueOnce(searchConsumed)
|
||||
|
||||
const result = await rateLimiter.reportUsage(
|
||||
testProvider,
|
||||
'workspace-1',
|
||||
multiConfig,
|
||||
{},
|
||||
{ tokenCount: 200, searchUnits: 3 }
|
||||
)
|
||||
|
||||
expect(result.dimensions).toHaveLength(2)
|
||||
expect(result.dimensions[0]).toEqual({
|
||||
name: 'tokens',
|
||||
consumed: 200,
|
||||
allowed: true,
|
||||
tokensRemaining: 800,
|
||||
})
|
||||
expect(result.dimensions[1]).toEqual({
|
||||
name: 'search_units',
|
||||
consumed: 3,
|
||||
allowed: true,
|
||||
tokensRemaining: 47,
|
||||
})
|
||||
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('should continue with remaining dimensions if extractUsage throws', async () => {
|
||||
const throwingConfig: CustomRateLimit = {
|
||||
mode: 'custom',
|
||||
requestsPerMinute: 10,
|
||||
dimensions: [
|
||||
{
|
||||
name: 'broken',
|
||||
limitPerMinute: 100,
|
||||
extractUsage: () => {
|
||||
throw new Error('extraction failed')
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'tokens',
|
||||
limitPerMinute: 1000,
|
||||
extractUsage: (_p, r) => (r.tokenCount as number) ?? 0,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const consumeResult: ConsumeResult = {
|
||||
allowed: true,
|
||||
tokensRemaining: 900,
|
||||
resetAt: new Date(Date.now() + 60000),
|
||||
}
|
||||
mockAdapter.consumeTokens.mockResolvedValue(consumeResult)
|
||||
|
||||
const result = await rateLimiter.reportUsage(
|
||||
testProvider,
|
||||
'workspace-1',
|
||||
throwingConfig,
|
||||
{},
|
||||
{ tokenCount: 100 }
|
||||
)
|
||||
|
||||
expect(result.dimensions).toHaveLength(1)
|
||||
expect(result.dimensions[0].name).toBe('tokens')
|
||||
expect(mockAdapter.consumeTokens).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle storage errors gracefully', async () => {
|
||||
mockAdapter.consumeTokens.mockRejectedValue(new Error('db connection lost'))
|
||||
|
||||
const result = await rateLimiter.reportUsage(
|
||||
testProvider,
|
||||
'workspace-1',
|
||||
customConfig,
|
||||
{},
|
||||
{ tokenCount: 100 }
|
||||
)
|
||||
|
||||
expect(result.dimensions).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,349 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import {
|
||||
createStorageAdapter,
|
||||
type RateLimitStorageAdapter,
|
||||
type TokenBucketConfig,
|
||||
} from '@/lib/core/rate-limiter/storage'
|
||||
import {
|
||||
type AcquireKeyResult,
|
||||
type CustomRateLimit,
|
||||
DEFAULT_BURST_MULTIPLIER,
|
||||
DEFAULT_WINDOW_MS,
|
||||
type HostedKeyRateLimitConfig,
|
||||
type ReportUsageResult,
|
||||
toTokenBucketConfig,
|
||||
} from './types'
|
||||
|
||||
const logger = createLogger('HostedKeyRateLimiter')
|
||||
|
||||
/**
|
||||
* Resolves env var names for a numbered key prefix using a `{PREFIX}_COUNT` env var.
|
||||
* E.g. with `EXA_API_KEY_COUNT=5`, returns `['EXA_API_KEY_1', ..., 'EXA_API_KEY_5']`.
|
||||
*/
|
||||
function resolveEnvKeys(prefix: string): string[] {
|
||||
const count = Number.parseInt(process.env[`${prefix}_COUNT`] || '0', 10)
|
||||
const names: string[] = []
|
||||
for (let i = 1; i <= count; i++) {
|
||||
names.push(`${prefix}_${i}`)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
/**
 * Dimension name used in storage keys for the per-billing-actor request
 * rate limit bucket (as opposed to custom usage dimensions, which use
 * their own configured names).
 */
const ACTOR_REQUESTS_DIMENSION = 'actor_requests'

/**
 * Information about an available hosted key: the resolved secret value plus
 * where it came from (position in the resolved env key list, and the env var
 * name, which is useful for logging without exposing the key itself).
 */
interface AvailableKey {
  key: string
  keyIndex: number
  envVarName: string
}
|
||||
|
||||
/**
|
||||
* HostedKeyRateLimiter provides:
|
||||
* 1. Per-billing-actor rate limiting (enforced - blocks actors who exceed their limit)
|
||||
* 2. Round-robin key selection (distributes requests evenly across keys)
|
||||
* 3. Post-execution dimension usage tracking for custom rate limits
|
||||
*
|
||||
* The billing actor is typically a workspace ID, meaning rate limits are shared
|
||||
* across all users within the same workspace.
|
||||
*/
|
||||
export class HostedKeyRateLimiter {
|
||||
private storage: RateLimitStorageAdapter
|
||||
/** Round-robin counter per provider for even key distribution */
|
||||
private roundRobinCounters = new Map<string, number>()
|
||||
|
||||
constructor(storage?: RateLimitStorageAdapter) {
|
||||
this.storage = storage ?? createStorageAdapter()
|
||||
}
|
||||
|
||||
private buildActorStorageKey(provider: string, billingActorId: string): string {
|
||||
return `hosted:${provider}:actor:${billingActorId}:${ACTOR_REQUESTS_DIMENSION}`
|
||||
}
|
||||
|
||||
private buildDimensionStorageKey(
|
||||
provider: string,
|
||||
billingActorId: string,
|
||||
dimensionName: string
|
||||
): string {
|
||||
return `hosted:${provider}:actor:${billingActorId}:${dimensionName}`
|
||||
}
|
||||
|
||||
private getAvailableKeys(envKeys: string[]): AvailableKey[] {
|
||||
const keys: AvailableKey[] = []
|
||||
for (let i = 0; i < envKeys.length; i++) {
|
||||
const envVarName = envKeys[i]
|
||||
const key = process.env[envVarName]
|
||||
if (key) {
|
||||
keys.push({ key, keyIndex: i, envVarName })
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a token bucket config for the per-billing-actor request rate limit.
|
||||
* Works for both `per_request` and `custom` modes since both define `requestsPerMinute`.
|
||||
*/
|
||||
private getActorRateLimitConfig(config: HostedKeyRateLimitConfig): TokenBucketConfig | null {
|
||||
if (!config.requestsPerMinute) return null
|
||||
return toTokenBucketConfig(
|
||||
config.requestsPerMinute,
|
||||
config.burstMultiplier ?? DEFAULT_BURST_MULTIPLIER,
|
||||
DEFAULT_WINDOW_MS
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check and consume billing actor request rate limit. Returns null if allowed, or retry info if blocked.
|
||||
*/
|
||||
private async checkActorRateLimit(
|
||||
provider: string,
|
||||
billingActorId: string,
|
||||
config: HostedKeyRateLimitConfig
|
||||
): Promise<{ rateLimited: true; retryAfterMs: number } | null> {
|
||||
const bucketConfig = this.getActorRateLimitConfig(config)
|
||||
if (!bucketConfig) return null
|
||||
|
||||
const storageKey = this.buildActorStorageKey(provider, billingActorId)
|
||||
|
||||
try {
|
||||
const result = await this.storage.consumeTokens(storageKey, 1, bucketConfig)
|
||||
if (!result.allowed) {
|
||||
const retryAfterMs = Math.max(0, result.resetAt.getTime() - Date.now())
|
||||
logger.info(`Billing actor ${billingActorId} rate limited for ${provider}`, {
|
||||
provider,
|
||||
billingActorId,
|
||||
retryAfterMs,
|
||||
tokensRemaining: result.tokensRemaining,
|
||||
})
|
||||
return { rateLimited: true, retryAfterMs }
|
||||
}
|
||||
return null
|
||||
} catch (error) {
|
||||
logger.error(`Error checking billing actor rate limit for ${provider}`, {
|
||||
error,
|
||||
billingActorId,
|
||||
})
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pre-check that the billing actor has available budget in all custom dimensions.
|
||||
* Does NOT consume tokens -- just verifies the actor isn't already depleted.
|
||||
* Returns retry info for the most restrictive exhausted dimension, or null if all pass.
|
||||
*/
|
||||
private async preCheckDimensions(
|
||||
provider: string,
|
||||
billingActorId: string,
|
||||
config: CustomRateLimit
|
||||
): Promise<{ rateLimited: true; retryAfterMs: number; dimension: string } | null> {
|
||||
for (const dimension of config.dimensions) {
|
||||
const storageKey = this.buildDimensionStorageKey(provider, billingActorId, dimension.name)
|
||||
const bucketConfig = toTokenBucketConfig(
|
||||
dimension.limitPerMinute,
|
||||
dimension.burstMultiplier ?? DEFAULT_BURST_MULTIPLIER,
|
||||
DEFAULT_WINDOW_MS
|
||||
)
|
||||
|
||||
try {
|
||||
const status = await this.storage.getTokenStatus(storageKey, bucketConfig)
|
||||
if (status.tokensAvailable < 1) {
|
||||
const retryAfterMs = Math.max(0, status.nextRefillAt.getTime() - Date.now())
|
||||
logger.info(
|
||||
`Billing actor ${billingActorId} exhausted dimension ${dimension.name} for ${provider}`,
|
||||
{
|
||||
provider,
|
||||
billingActorId,
|
||||
dimension: dimension.name,
|
||||
tokensAvailable: status.tokensAvailable,
|
||||
retryAfterMs,
|
||||
}
|
||||
)
|
||||
return { rateLimited: true, retryAfterMs, dimension: dimension.name }
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error pre-checking dimension ${dimension.name} for ${provider}`, {
|
||||
error,
|
||||
billingActorId,
|
||||
})
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Acquire an available key via round-robin selection.
|
||||
*
|
||||
* For both modes:
|
||||
* 1. Per-billing-actor request rate limiting (enforced): blocks actors who exceed their request limit
|
||||
* 2. Round-robin key selection: cycles through available keys for even distribution
|
||||
*
|
||||
* For `custom` mode additionally:
|
||||
* 3. Pre-checks dimension budgets: blocks if any dimension is already depleted
|
||||
*
|
||||
* @param envKeyPrefix - Env var prefix (e.g. 'EXA_API_KEY'). Keys resolved via `{prefix}_COUNT`.
|
||||
* @param billingActorId - The billing actor (typically workspace ID) to rate limit against
|
||||
*/
|
||||
async acquireKey(
|
||||
provider: string,
|
||||
envKeyPrefix: string,
|
||||
config: HostedKeyRateLimitConfig,
|
||||
billingActorId: string
|
||||
): Promise<AcquireKeyResult> {
|
||||
if (config.requestsPerMinute) {
|
||||
const rateLimitResult = await this.checkActorRateLimit(provider, billingActorId, config)
|
||||
if (rateLimitResult) {
|
||||
return {
|
||||
success: false,
|
||||
billingActorRateLimited: true,
|
||||
retryAfterMs: rateLimitResult.retryAfterMs,
|
||||
error: `Rate limit exceeded. Please wait ${Math.ceil(rateLimitResult.retryAfterMs / 1000)} seconds. If you're getting throttled frequently, consider adding your own API key under Settings > BYOK to avoid shared rate limits.`,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (config.mode === 'custom' && config.dimensions.length > 0) {
|
||||
const dimensionResult = await this.preCheckDimensions(provider, billingActorId, config)
|
||||
if (dimensionResult) {
|
||||
return {
|
||||
success: false,
|
||||
billingActorRateLimited: true,
|
||||
retryAfterMs: dimensionResult.retryAfterMs,
|
||||
error: `Rate limit exceeded for ${dimensionResult.dimension}. Please wait ${Math.ceil(dimensionResult.retryAfterMs / 1000)} seconds. If you're getting throttled frequently, consider adding your own API key under Settings > BYOK to avoid shared rate limits.`,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const envKeys = resolveEnvKeys(envKeyPrefix)
|
||||
const availableKeys = this.getAvailableKeys(envKeys)
|
||||
|
||||
if (availableKeys.length === 0) {
|
||||
logger.warn(`No hosted keys configured for provider ${provider}`)
|
||||
return {
|
||||
success: false,
|
||||
error: `No hosted keys configured for ${provider}`,
|
||||
}
|
||||
}
|
||||
|
||||
const counter = this.roundRobinCounters.get(provider) ?? 0
|
||||
const selected = availableKeys[counter % availableKeys.length]
|
||||
this.roundRobinCounters.set(provider, counter + 1)
|
||||
|
||||
logger.debug(`Selected hosted key for ${provider}`, {
|
||||
provider,
|
||||
keyIndex: selected.keyIndex,
|
||||
envVarName: selected.envVarName,
|
||||
})
|
||||
|
||||
return {
|
||||
success: true,
|
||||
key: selected.key,
|
||||
keyIndex: selected.keyIndex,
|
||||
envVarName: selected.envVarName,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Report actual usage after successful tool execution (custom mode only).
|
||||
* Calls `extractUsage` on each dimension and consumes the actual token count.
|
||||
* This is the "post-execution" phase of the optimistic two-phase approach.
|
||||
*/
|
||||
async reportUsage(
|
||||
provider: string,
|
||||
billingActorId: string,
|
||||
config: CustomRateLimit,
|
||||
params: Record<string, unknown>,
|
||||
response: Record<string, unknown>
|
||||
): Promise<ReportUsageResult> {
|
||||
const results: ReportUsageResult['dimensions'] = []
|
||||
|
||||
for (const dimension of config.dimensions) {
|
||||
let usage: number
|
||||
try {
|
||||
usage = dimension.extractUsage(params, response)
|
||||
} catch (error) {
|
||||
logger.error(`Failed to extract usage for dimension ${dimension.name}`, {
|
||||
provider,
|
||||
billingActorId,
|
||||
error,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if (usage <= 0) {
|
||||
results.push({
|
||||
name: dimension.name,
|
||||
consumed: 0,
|
||||
allowed: true,
|
||||
tokensRemaining: 0,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
const storageKey = this.buildDimensionStorageKey(provider, billingActorId, dimension.name)
|
||||
const bucketConfig = toTokenBucketConfig(
|
||||
dimension.limitPerMinute,
|
||||
dimension.burstMultiplier ?? DEFAULT_BURST_MULTIPLIER,
|
||||
DEFAULT_WINDOW_MS
|
||||
)
|
||||
|
||||
try {
|
||||
const consumeResult = await this.storage.consumeTokens(storageKey, usage, bucketConfig)
|
||||
|
||||
results.push({
|
||||
name: dimension.name,
|
||||
consumed: usage,
|
||||
allowed: consumeResult.allowed,
|
||||
tokensRemaining: consumeResult.tokensRemaining,
|
||||
})
|
||||
|
||||
if (!consumeResult.allowed) {
|
||||
logger.warn(
|
||||
`Dimension ${dimension.name} overdrawn for ${provider} (optimistic concurrency)`,
|
||||
{ provider, billingActorId, usage, tokensRemaining: consumeResult.tokensRemaining }
|
||||
)
|
||||
}
|
||||
|
||||
logger.debug(`Consumed ${usage} from dimension ${dimension.name} for ${provider}`, {
|
||||
provider,
|
||||
billingActorId,
|
||||
usage,
|
||||
allowed: consumeResult.allowed,
|
||||
tokensRemaining: consumeResult.tokensRemaining,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`Failed to consume tokens for dimension ${dimension.name}`, {
|
||||
provider,
|
||||
billingActorId,
|
||||
usage,
|
||||
error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return { dimensions: results }
|
||||
}
|
||||
}
|
||||
|
||||
let cachedInstance: HostedKeyRateLimiter | null = null
|
||||
|
||||
/**
|
||||
* Get the singleton HostedKeyRateLimiter instance
|
||||
*/
|
||||
export function getHostedKeyRateLimiter(): HostedKeyRateLimiter {
|
||||
if (!cachedInstance) {
|
||||
cachedInstance = new HostedKeyRateLimiter()
|
||||
}
|
||||
return cachedInstance
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset the cached rate limiter (for testing)
|
||||
*/
|
||||
export function resetHostedKeyRateLimiter(): void {
|
||||
cachedInstance = null
|
||||
}
|
||||
17
apps/sim/lib/core/rate-limiter/hosted-key/index.ts
Normal file
17
apps/sim/lib/core/rate-limiter/hosted-key/index.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
 * Public entry point for the hosted-key rate limiting module.
 * Re-exports the rate limiter implementation (class, singleton accessor,
 * and test-only reset) plus its configuration types and helpers.
 */
export {
  getHostedKeyRateLimiter,
  HostedKeyRateLimiter,
  resetHostedKeyRateLimiter,
} from './hosted-key-rate-limiter'
export {
  type AcquireKeyResult,
  type CustomRateLimit,
  DEFAULT_BURST_MULTIPLIER,
  DEFAULT_WINDOW_MS,
  type HostedKeyRateLimitConfig,
  type HostedKeyRateLimitMode,
  type PerRequestRateLimit,
  type RateLimitDimension,
  type ReportUsageResult,
  toTokenBucketConfig,
} from './types'
|
||||
108
apps/sim/lib/core/rate-limiter/hosted-key/types.ts
Normal file
108
apps/sim/lib/core/rate-limiter/hosted-key/types.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
import type { TokenBucketConfig } from '@/lib/core/rate-limiter/storage'
|
||||
|
||||
export type HostedKeyRateLimitMode = 'per_request' | 'custom'
|
||||
|
||||
/**
|
||||
* Simple per-request rate limit configuration.
|
||||
* Enforces per-billing-actor rate limiting and distributes requests across keys.
|
||||
*/
|
||||
export interface PerRequestRateLimit {
|
||||
mode: 'per_request'
|
||||
/** Maximum requests per minute per billing actor (enforced - blocks if exceeded) */
|
||||
requestsPerMinute: number
|
||||
/** Burst multiplier for token bucket max capacity. Default: 2 */
|
||||
burstMultiplier?: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom rate limit with multiple dimensions (e.g., tokens, search units).
|
||||
* Allows tracking different usage metrics independently.
|
||||
*/
|
||||
export interface CustomRateLimit {
|
||||
mode: 'custom'
|
||||
/** Maximum requests per minute per billing actor (enforced - blocks if exceeded) */
|
||||
requestsPerMinute: number
|
||||
/** Multiple dimensions to track */
|
||||
dimensions: RateLimitDimension[]
|
||||
/** Burst multiplier for token bucket max capacity. Default: 2 */
|
||||
burstMultiplier?: number
|
||||
}
|
||||
|
||||
/**
|
||||
* A single dimension for custom rate limiting.
|
||||
* Each dimension has its own token bucket.
|
||||
*/
|
||||
export interface RateLimitDimension {
|
||||
/** Dimension name (e.g., 'tokens', 'search_units') - used in storage key */
|
||||
name: string
|
||||
/** Limit per minute for this dimension */
|
||||
limitPerMinute: number
|
||||
/** Burst multiplier for token bucket max capacity. Default: 2 */
|
||||
burstMultiplier?: number
|
||||
/**
|
||||
* Extract usage amount from request params and response.
|
||||
* Called after successful execution to consume the actual usage.
|
||||
*/
|
||||
extractUsage: (params: Record<string, unknown>, response: Record<string, unknown>) => number
|
||||
}
|
||||
|
||||
/** Union of all hosted key rate limit configuration types */
|
||||
export type HostedKeyRateLimitConfig = PerRequestRateLimit | CustomRateLimit
|
||||
|
||||
/**
|
||||
* Result from acquiring a key from the hosted key rate limiter
|
||||
*/
|
||||
export interface AcquireKeyResult {
|
||||
/** Whether a key was successfully acquired */
|
||||
success: boolean
|
||||
/** The API key value (if success=true) */
|
||||
key?: string
|
||||
/** Index of the key in the envKeys array */
|
||||
keyIndex?: number
|
||||
/** Environment variable name of the selected key */
|
||||
envVarName?: string
|
||||
/** Error message if no key available */
|
||||
error?: string
|
||||
/** Whether the billing actor was rate limited (exceeded their limit) */
|
||||
billingActorRateLimited?: boolean
|
||||
/** Milliseconds until the billing actor's rate limit resets (if billingActorRateLimited=true) */
|
||||
retryAfterMs?: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Result from reporting post-execution usage for custom dimensions
|
||||
*/
|
||||
export interface ReportUsageResult {
|
||||
/** Per-dimension consumption results */
|
||||
dimensions: {
|
||||
name: string
|
||||
consumed: number
|
||||
allowed: boolean
|
||||
tokensRemaining: number
|
||||
}[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert rate limit config to token bucket config for a dimension
|
||||
*/
|
||||
export function toTokenBucketConfig(
|
||||
limitPerMinute: number,
|
||||
burstMultiplier = 2,
|
||||
windowMs = 60000
|
||||
): TokenBucketConfig {
|
||||
return {
|
||||
maxTokens: limitPerMinute * burstMultiplier,
|
||||
refillRate: limitPerMinute,
|
||||
refillIntervalMs: windowMs,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Default rate limit window in milliseconds (1 minute)
|
||||
*/
|
||||
export const DEFAULT_WINDOW_MS = 60000
|
||||
|
||||
/**
|
||||
* Default burst multiplier
|
||||
*/
|
||||
export const DEFAULT_BURST_MULTIPLIER = 2
|
||||
@@ -1,3 +1,18 @@
|
||||
export {
|
||||
type AcquireKeyResult,
|
||||
type CustomRateLimit,
|
||||
DEFAULT_BURST_MULTIPLIER,
|
||||
DEFAULT_WINDOW_MS,
|
||||
getHostedKeyRateLimiter,
|
||||
type HostedKeyRateLimitConfig,
|
||||
HostedKeyRateLimiter,
|
||||
type HostedKeyRateLimitMode,
|
||||
type PerRequestRateLimit,
|
||||
type RateLimitDimension,
|
||||
type ReportUsageResult,
|
||||
resetHostedKeyRateLimiter,
|
||||
toTokenBucketConfig,
|
||||
} from './hosted-key'
|
||||
export type { RateLimitResult, RateLimitStatus } from './rate-limiter'
|
||||
export { RateLimiter } from './rate-limiter'
|
||||
export type { RateLimitStorageAdapter, TokenBucketConfig } from './storage'
|
||||
|
||||
@@ -51,7 +51,7 @@ export class DbTokenBucket implements RateLimitStorageAdapter {
|
||||
) * ${config.refillRate}
|
||||
)::numeric
|
||||
) - ${requestedTokens}::numeric
|
||||
ELSE ${rateLimitBucket.tokens}::numeric
|
||||
ELSE -1
|
||||
END
|
||||
`,
|
||||
lastRefillAt: sql`
|
||||
|
||||
@@ -934,6 +934,55 @@ export const PlatformEvents = {
|
||||
})
|
||||
},
|
||||
|
||||
/**
|
||||
* Track when a rate limit error is surfaced to the end user (not retried/absorbed).
|
||||
* Fires for both billing-actor limits and exhausted upstream retries.
|
||||
*/
|
||||
userThrottled: (attrs: {
|
||||
toolId: string
|
||||
reason: 'billing_actor_limit' | 'upstream_retries_exhausted'
|
||||
provider?: string
|
||||
retryAfterMs?: number
|
||||
userId?: string
|
||||
workspaceId?: string
|
||||
workflowId?: string
|
||||
}) => {
|
||||
trackPlatformEvent('platform.user.throttled', {
|
||||
'tool.id': attrs.toolId,
|
||||
'throttle.reason': attrs.reason,
|
||||
...(attrs.provider && { 'provider.id': attrs.provider }),
|
||||
...(attrs.retryAfterMs != null && { 'rate_limit.retry_after_ms': attrs.retryAfterMs }),
|
||||
...(attrs.userId && { 'user.id': attrs.userId }),
|
||||
...(attrs.workspaceId && { 'workspace.id': attrs.workspaceId }),
|
||||
...(attrs.workflowId && { 'workflow.id': attrs.workflowId }),
|
||||
})
|
||||
},
|
||||
|
||||
/**
|
||||
* Track hosted key rate limited by upstream provider (429 from the external API)
|
||||
*/
|
||||
hostedKeyRateLimited: (attrs: {
|
||||
toolId: string
|
||||
envVarName: string
|
||||
attempt: number
|
||||
maxRetries: number
|
||||
delayMs: number
|
||||
userId?: string
|
||||
workspaceId?: string
|
||||
workflowId?: string
|
||||
}) => {
|
||||
trackPlatformEvent('platform.hosted_key.rate_limited', {
|
||||
'tool.id': attrs.toolId,
|
||||
'hosted_key.env_var': attrs.envVarName,
|
||||
'rate_limit.attempt': attrs.attempt,
|
||||
'rate_limit.max_retries': attrs.maxRetries,
|
||||
'rate_limit.delay_ms': attrs.delayMs,
|
||||
...(attrs.userId && { 'user.id': attrs.userId }),
|
||||
...(attrs.workspaceId && { 'workspace.id': attrs.workspaceId }),
|
||||
...(attrs.workflowId && { 'workflow.id': attrs.workflowId }),
|
||||
})
|
||||
},
|
||||
|
||||
/**
|
||||
* Track chat deployed (workflow deployed as chat interface)
|
||||
*/
|
||||
|
||||
@@ -181,6 +181,7 @@ export class ExecutionLogger implements IExecutionLoggerService {
|
||||
input: number
|
||||
output: number
|
||||
total: number
|
||||
toolCost?: number
|
||||
tokens: { input: number; output: number; total: number }
|
||||
}
|
||||
>
|
||||
@@ -507,6 +508,7 @@ export class ExecutionLogger implements IExecutionLoggerService {
|
||||
input: number
|
||||
output: number
|
||||
total: number
|
||||
toolCost?: number
|
||||
tokens: { input: number; output: number; total: number }
|
||||
}
|
||||
>
|
||||
|
||||
@@ -95,6 +95,7 @@ export function calculateCostSummary(traceSpans: any[]): {
|
||||
input: number
|
||||
output: number
|
||||
total: number
|
||||
toolCost?: number
|
||||
tokens: { input: number; output: number; total: number }
|
||||
}
|
||||
>
|
||||
@@ -143,6 +144,7 @@ export function calculateCostSummary(traceSpans: any[]): {
|
||||
input: number
|
||||
output: number
|
||||
total: number
|
||||
toolCost?: number
|
||||
tokens: { input: number; output: number; total: number }
|
||||
}
|
||||
> = {}
|
||||
@@ -171,6 +173,10 @@ export function calculateCostSummary(traceSpans: any[]): {
|
||||
models[model].tokens.input += span.tokens?.input ?? span.tokens?.prompt ?? 0
|
||||
models[model].tokens.output += span.tokens?.output ?? span.tokens?.completion ?? 0
|
||||
models[model].tokens.total += span.tokens?.total || 0
|
||||
|
||||
if (span.cost.toolCost) {
|
||||
models[model].toolCost = (models[model].toolCost || 0) + span.cost.toolCost
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { getEnv, isTruthy } from '@/lib/core/config/env'
|
||||
import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import type { SubBlockConfig } from '@/blocks/types'
|
||||
|
||||
export type CanonicalMode = 'basic' | 'advanced'
|
||||
@@ -287,3 +288,12 @@ export function isSubBlockFeatureEnabled(subBlock: SubBlockConfig): boolean {
|
||||
if (!subBlock.requiresFeature) return true
|
||||
return isTruthy(getEnv(subBlock.requiresFeature))
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a subblock should be hidden because we're running on hosted Sim.
|
||||
* Used for tool API key fields that should be hidden when Sim provides hosted keys.
|
||||
*/
|
||||
export function isSubBlockHiddenByHostedKey(subBlock: SubBlockConfig): boolean {
|
||||
if (!subBlock.hideWhenHosted) return false
|
||||
return isHosted
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
@@ -490,7 +491,7 @@ export async function executeAnthropicProviderRequest(
|
||||
}
|
||||
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
const currentMessages = [...messages]
|
||||
let iterationCount = 0
|
||||
let hasUsedForcedTool = false
|
||||
@@ -609,7 +610,7 @@ export async function executeAnthropicProviderRequest(
|
||||
})
|
||||
|
||||
let resultContent: unknown
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
@@ -783,10 +784,12 @@ export async function executeAnthropicProviderRequest(
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(request.model, usage.input_tokens, usage.output_tokens)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
@@ -829,6 +832,7 @@ export async function executeAnthropicProviderRequest(
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
toolCost: undefined as number | undefined,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
@@ -901,7 +905,7 @@ export async function executeAnthropicProviderRequest(
|
||||
}
|
||||
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
const currentMessages = [...messages]
|
||||
let iterationCount = 0
|
||||
let hasUsedForcedTool = false
|
||||
@@ -1022,7 +1026,7 @@ export async function executeAnthropicProviderRequest(
|
||||
})
|
||||
|
||||
let resultContent: unknown
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
@@ -1208,10 +1212,12 @@ export async function executeAnthropicProviderRequest(
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(request.model, usage.input_tokens, usage.output_tokens)
|
||||
const tc2 = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: cost.input + streamCost.input,
|
||||
output: cost.output + streamCost.output,
|
||||
total: cost.total + streamCost.total,
|
||||
toolCost: tc2 || undefined,
|
||||
total: cost.total + streamCost.total + tc2,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
@@ -1254,6 +1260,7 @@ export async function executeAnthropicProviderRequest(
|
||||
cost: {
|
||||
input: cost.input,
|
||||
output: cost.output,
|
||||
toolCost: undefined as number | undefined,
|
||||
total: cost.total,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -35,6 +35,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
@@ -499,10 +500,12 @@ async function executeChatCompletionsRequest(
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
@@ -815,10 +816,12 @@ export const bedrockProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(request.model, usage.inputTokens, usage.outputTokens)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: cost.input + streamCost.input,
|
||||
output: cost.output + streamCost.output,
|
||||
total: cost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: cost.total + streamCost.total + tc,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
@@ -861,6 +864,7 @@ export const bedrockProvider: ProviderConfig = {
|
||||
cost: {
|
||||
input: cost.input,
|
||||
output: cost.output,
|
||||
toolCost: undefined as number | undefined,
|
||||
total: cost.total,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -195,7 +196,7 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
total: currentResponse.usage?.total_tokens || 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
@@ -313,7 +314,7 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
duration: duration,
|
||||
})
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
@@ -472,10 +473,12 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
@@ -508,6 +511,7 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
toolCost: undefined as number | undefined,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -205,7 +206,7 @@ export const deepseekProvider: ProviderConfig = {
|
||||
total: currentResponse.usage?.total_tokens || 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
let hasUsedForcedTool = false
|
||||
@@ -325,7 +326,7 @@ export const deepseekProvider: ProviderConfig = {
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
@@ -471,10 +472,12 @@ export const deepseekProvider: ProviderConfig = {
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
}
|
||||
),
|
||||
@@ -508,6 +511,7 @@ export const deepseekProvider: ProviderConfig = {
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
toolCost: undefined as number | undefined,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -31,6 +31,7 @@ import {
|
||||
isDeepResearchModel,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
import type { ExecutionState, GeminiProviderType, GeminiUsage } from './types'
|
||||
@@ -1163,10 +1164,12 @@ export async function executeGeminiRequest(
|
||||
usage.promptTokenCount,
|
||||
usage.candidatesTokenCount
|
||||
)
|
||||
const tc = sumToolCosts(state.toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
pricing: streamCost.pricing,
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -201,7 +202,7 @@ export const groqProvider: ProviderConfig = {
|
||||
total: currentResponse.usage?.total_tokens || 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
let modelTime = firstResponseTime
|
||||
@@ -303,7 +304,7 @@ export const groqProvider: ProviderConfig = {
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
@@ -426,10 +427,12 @@ export const groqProvider: ProviderConfig = {
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
@@ -462,6 +465,7 @@ export const groqProvider: ProviderConfig = {
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
toolCost: undefined as number | undefined,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
calculateCost,
|
||||
generateStructuredOutputInstructions,
|
||||
shouldBillModelUsage,
|
||||
sumToolCosts,
|
||||
supportsReasoningEffort,
|
||||
supportsTemperature,
|
||||
supportsThinking,
|
||||
@@ -162,5 +163,11 @@ export async function executeProviderRequest(
|
||||
}
|
||||
}
|
||||
|
||||
const toolCost = sumToolCosts(response.toolResults)
|
||||
if (toolCost > 0 && response.cost) {
|
||||
response.cost.toolCost = toolCost
|
||||
response.cost.total += toolCost
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -258,7 +259,7 @@ export const mistralProvider: ProviderConfig = {
|
||||
total: currentResponse.usage?.total_tokens || 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
@@ -366,7 +367,7 @@ export const mistralProvider: ProviderConfig = {
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
@@ -482,10 +483,12 @@ export const mistralProvider: ProviderConfig = {
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
|
||||
@@ -13,7 +13,7 @@ import type {
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import { ProviderError } from '@/providers/types'
|
||||
import { calculateCost, prepareToolExecution } from '@/providers/utils'
|
||||
import { calculateCost, prepareToolExecution, sumToolCosts } from '@/providers/utils'
|
||||
import { useProvidersStore } from '@/stores/providers'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
@@ -271,7 +271,7 @@ export const ollamaProvider: ProviderConfig = {
|
||||
total: currentResponse.usage?.total_tokens || 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
@@ -377,7 +377,7 @@ export const ollamaProvider: ProviderConfig = {
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
@@ -486,10 +486,12 @@ export const ollamaProvider: ProviderConfig = {
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -405,7 +406,7 @@ export async function executeResponsesProviderRequest(
|
||||
}
|
||||
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
let iterationCount = 0
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
@@ -512,7 +513,7 @@ export async function executeResponsesProviderRequest(
|
||||
})
|
||||
|
||||
let resultContent: Record<string, unknown>
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output as Record<string, unknown>
|
||||
} else {
|
||||
@@ -728,10 +729,12 @@ export async function executeResponsesProviderRequest(
|
||||
usage?.promptTokens || 0,
|
||||
usage?.completionTokens || 0
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
generateSchemaInstructions,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
@@ -478,10 +479,12 @@ export const openRouterProvider: ProviderConfig = {
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
|
||||
@@ -79,7 +79,7 @@ export interface ProviderResponse {
|
||||
total?: number
|
||||
}
|
||||
toolCalls?: FunctionCallResponse[]
|
||||
toolResults?: any[]
|
||||
toolResults?: Record<string, unknown>[]
|
||||
timing?: {
|
||||
startTime: string
|
||||
endTime: string
|
||||
@@ -93,6 +93,7 @@ export interface ProviderResponse {
|
||||
cost?: {
|
||||
input: number
|
||||
output: number
|
||||
toolCost?: number
|
||||
total: number
|
||||
pricing: ModelPricing
|
||||
}
|
||||
|
||||
@@ -1405,6 +1405,7 @@ describe('prepareToolExecution', () => {
|
||||
workspaceId: 'ws-456',
|
||||
chatId: 'chat-789',
|
||||
userId: 'user-abc',
|
||||
skipFixedUsageLog: true,
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -650,6 +650,20 @@ export function calculateCost(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sums the `cost.total` from each tool result returned during a provider tool loop.
|
||||
* Tool results may carry a `cost` object injected by `applyHostedKeyCostToResult`.
|
||||
*/
|
||||
export function sumToolCosts(toolResults?: Record<string, unknown>[]): number {
|
||||
if (!toolResults?.length) return 0
|
||||
let total = 0
|
||||
for (const tr of toolResults) {
|
||||
const cost = tr?.cost as Record<string, unknown> | undefined
|
||||
if (cost?.total && typeof cost.total === 'number') total += cost.total
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
export function getModelPricing(modelId: string): any {
|
||||
const embeddingPricing = getEmbeddingModelPricing(modelId)
|
||||
if (embeddingPricing) {
|
||||
@@ -1140,6 +1154,7 @@ export function prepareToolExecution(
|
||||
? { isDeployedContext: request.isDeployedContext }
|
||||
: {}),
|
||||
...(request.callChain ? { callChain: request.callChain } : {}),
|
||||
skipFixedUsageLog: true,
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { createReadableStreamFromVLLMStream } from '@/providers/vllm/utils'
|
||||
@@ -315,7 +316,7 @@ export const vllmProvider: ProviderConfig = {
|
||||
total: currentResponse.usage?.total_tokens || 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
@@ -428,7 +429,7 @@ export const vllmProvider: ProviderConfig = {
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
@@ -553,10 +554,12 @@ export const vllmProvider: ProviderConfig = {
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sumToolCosts,
|
||||
} from '@/providers/utils'
|
||||
import {
|
||||
checkForForcedToolUsage,
|
||||
@@ -215,7 +216,7 @@ export const xAIProvider: ProviderConfig = {
|
||||
total: currentResponse.usage?.total_tokens || 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const toolResults: Record<string, unknown>[] = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
@@ -331,7 +332,7 @@ export const xAIProvider: ProviderConfig = {
|
||||
duration: duration,
|
||||
})
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
if (result.success && result.output) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
@@ -509,10 +510,12 @@ export const xAIProvider: ProviderConfig = {
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
const tc = sumToolCosts(toolResults)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
toolCost: tc || undefined,
|
||||
total: accumulatedCost.total + streamCost.total + tc,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
@@ -545,6 +548,7 @@ export const xAIProvider: ProviderConfig = {
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
toolCost: undefined as number | undefined,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
isCanonicalPair,
|
||||
isNonEmptyValue,
|
||||
isSubBlockFeatureEnabled,
|
||||
isSubBlockHiddenByHostedKey,
|
||||
resolveCanonicalMode,
|
||||
} from '@/lib/workflows/subblocks/visibility'
|
||||
import { getBlock } from '@/blocks'
|
||||
@@ -48,6 +49,7 @@ function shouldSerializeSubBlock(
|
||||
canonicalModeOverrides?: CanonicalModeOverrides
|
||||
): boolean {
|
||||
if (!isSubBlockFeatureEnabled(subBlockConfig)) return false
|
||||
if (isSubBlockHiddenByHostedKey(subBlockConfig)) return false
|
||||
|
||||
if (subBlockConfig.mode === 'trigger') {
|
||||
if (!isTriggerContext && !isTriggerCategory) return false
|
||||
|
||||
@@ -27,6 +27,25 @@ export const answerTool: ToolConfig<ExaAnswerParams, ExaAnswerResponse> = {
|
||||
description: 'Exa AI API Key',
|
||||
},
|
||||
},
|
||||
hosting: {
|
||||
envKeyPrefix: 'EXA_API_KEY',
|
||||
apiKeyParam: 'apiKey',
|
||||
byokProviderId: 'exa',
|
||||
pricing: {
|
||||
type: 'custom',
|
||||
getCost: (_params, output) => {
|
||||
const costDollars = output.__costDollars as { total?: number } | undefined
|
||||
if (costDollars?.total == null) {
|
||||
throw new Error('Exa answer response missing costDollars field')
|
||||
}
|
||||
return { cost: costDollars.total, metadata: { costDollars } }
|
||||
},
|
||||
},
|
||||
rateLimit: {
|
||||
mode: 'per_request',
|
||||
requestsPerMinute: 5,
|
||||
},
|
||||
},
|
||||
|
||||
request: {
|
||||
url: 'https://api.exa.ai/answer',
|
||||
@@ -61,6 +80,7 @@ export const answerTool: ToolConfig<ExaAnswerParams, ExaAnswerResponse> = {
|
||||
url: citation.url,
|
||||
text: citation.text || '',
|
||||
})) || [],
|
||||
__costDollars: data.costDollars,
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
@@ -76,6 +76,25 @@ export const findSimilarLinksTool: ToolConfig<
|
||||
description: 'Exa AI API Key',
|
||||
},
|
||||
},
|
||||
hosting: {
|
||||
envKeyPrefix: 'EXA_API_KEY',
|
||||
apiKeyParam: 'apiKey',
|
||||
byokProviderId: 'exa',
|
||||
pricing: {
|
||||
type: 'custom',
|
||||
getCost: (_params, output) => {
|
||||
const costDollars = output.__costDollars as { total?: number } | undefined
|
||||
if (costDollars?.total == null) {
|
||||
throw new Error('Exa find_similar_links response missing costDollars field')
|
||||
}
|
||||
return { cost: costDollars.total, metadata: { costDollars } }
|
||||
},
|
||||
},
|
||||
rateLimit: {
|
||||
mode: 'per_request',
|
||||
requestsPerMinute: 10,
|
||||
},
|
||||
},
|
||||
|
||||
request: {
|
||||
url: 'https://api.exa.ai/findSimilar',
|
||||
@@ -140,6 +159,7 @@ export const findSimilarLinksTool: ToolConfig<
|
||||
highlights: result.highlights,
|
||||
score: result.score || 0,
|
||||
})),
|
||||
__costDollars: data.costDollars,
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
@@ -61,6 +61,25 @@ export const getContentsTool: ToolConfig<ExaGetContentsParams, ExaGetContentsRes
|
||||
description: 'Exa AI API Key',
|
||||
},
|
||||
},
|
||||
hosting: {
|
||||
envKeyPrefix: 'EXA_API_KEY',
|
||||
apiKeyParam: 'apiKey',
|
||||
byokProviderId: 'exa',
|
||||
pricing: {
|
||||
type: 'custom',
|
||||
getCost: (_params, output) => {
|
||||
const costDollars = output.__costDollars as { total?: number } | undefined
|
||||
if (costDollars?.total == null) {
|
||||
throw new Error('Exa get_contents response missing costDollars field')
|
||||
}
|
||||
return { cost: costDollars.total, metadata: { costDollars } }
|
||||
},
|
||||
},
|
||||
rateLimit: {
|
||||
mode: 'per_request',
|
||||
requestsPerMinute: 10,
|
||||
},
|
||||
},
|
||||
|
||||
request: {
|
||||
url: 'https://api.exa.ai/contents',
|
||||
@@ -132,6 +151,7 @@ export const getContentsTool: ToolConfig<ExaGetContentsParams, ExaGetContentsRes
|
||||
summary: result.summary || '',
|
||||
highlights: result.highlights,
|
||||
})),
|
||||
__costDollars: data.costDollars,
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
@@ -86,6 +86,25 @@ export const searchTool: ToolConfig<ExaSearchParams, ExaSearchResponse> = {
|
||||
description: 'Exa AI API Key',
|
||||
},
|
||||
},
|
||||
hosting: {
|
||||
envKeyPrefix: 'EXA_API_KEY',
|
||||
apiKeyParam: 'apiKey',
|
||||
byokProviderId: 'exa',
|
||||
pricing: {
|
||||
type: 'custom',
|
||||
getCost: (_params, output) => {
|
||||
const costDollars = output.__costDollars as { total?: number } | undefined
|
||||
if (costDollars?.total == null) {
|
||||
throw new Error('Exa search response missing costDollars field')
|
||||
}
|
||||
return { cost: costDollars.total, metadata: { costDollars } }
|
||||
},
|
||||
},
|
||||
rateLimit: {
|
||||
mode: 'per_request',
|
||||
requestsPerMinute: 5,
|
||||
},
|
||||
},
|
||||
|
||||
request: {
|
||||
url: 'https://api.exa.ai/search',
|
||||
@@ -167,6 +186,7 @@ export const searchTool: ToolConfig<ExaSearchParams, ExaSearchResponse> = {
|
||||
highlights: result.highlights,
|
||||
score: result.score,
|
||||
})),
|
||||
__costDollars: data.costDollars,
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
@@ -6,6 +6,11 @@ export interface ExaBaseParams {
|
||||
apiKey: string
|
||||
}
|
||||
|
||||
/** Cost breakdown returned by Exa API responses */
|
||||
export interface ExaCostDollars {
|
||||
total: number
|
||||
}
|
||||
|
||||
// Search tool types
|
||||
export interface ExaSearchParams extends ExaBaseParams {
|
||||
query: string
|
||||
@@ -50,6 +55,7 @@ export interface ExaSearchResult {
|
||||
export interface ExaSearchResponse extends ToolResponse {
|
||||
output: {
|
||||
results: ExaSearchResult[]
|
||||
__costDollars?: ExaCostDollars
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,6 +84,7 @@ export interface ExaGetContentsResult {
|
||||
export interface ExaGetContentsResponse extends ToolResponse {
|
||||
output: {
|
||||
results: ExaGetContentsResult[]
|
||||
__costDollars?: ExaCostDollars
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,6 +127,7 @@ export interface ExaSimilarLink {
|
||||
export interface ExaFindSimilarLinksResponse extends ToolResponse {
|
||||
output: {
|
||||
similarLinks: ExaSimilarLink[]
|
||||
__costDollars?: ExaCostDollars
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,6 +145,7 @@ export interface ExaAnswerResponse extends ToolResponse {
|
||||
url: string
|
||||
text: string
|
||||
}[]
|
||||
__costDollars?: ExaCostDollars
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,52 +15,85 @@ import {
|
||||
} from '@sim/testing'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// Mock custom tools query - must be hoisted before imports
|
||||
vi.mock('@/hooks/queries/custom-tools', () => ({
|
||||
getCustomTool: (toolId: string) => {
|
||||
if (toolId === 'custom-tool-123') {
|
||||
return {
|
||||
id: 'custom-tool-123',
|
||||
title: 'Custom Weather Tool',
|
||||
code: 'return { result: "Weather data" }',
|
||||
schema: {
|
||||
function: {
|
||||
description: 'Get weather information',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: { type: 'string', description: 'City name' },
|
||||
unit: { type: 'string', description: 'Unit (metric/imperial)' },
|
||||
},
|
||||
required: ['location'],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
// Hoisted mock state - these are available to vi.mock factories
|
||||
const { mockIsHosted, mockEnv, mockGetBYOKKey, mockLogFixedUsage, mockRateLimiterFns } = vi.hoisted(
|
||||
() => ({
|
||||
mockIsHosted: { value: false },
|
||||
mockEnv: { NEXT_PUBLIC_APP_URL: 'http://localhost:3000' } as Record<string, string | undefined>,
|
||||
mockGetBYOKKey: vi.fn(),
|
||||
mockLogFixedUsage: vi.fn(),
|
||||
mockRateLimiterFns: {
|
||||
acquireKey: vi.fn(),
|
||||
preConsumeCapacity: vi.fn(),
|
||||
consumeCapacity: vi.fn(),
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
// Mock feature flags
|
||||
vi.mock('@/lib/core/config/feature-flags', () => ({
|
||||
get isHosted() {
|
||||
return mockIsHosted.value
|
||||
},
|
||||
getCustomTools: () => [
|
||||
{
|
||||
id: 'custom-tool-123',
|
||||
title: 'Custom Weather Tool',
|
||||
code: 'return { result: "Weather data" }',
|
||||
schema: {
|
||||
function: {
|
||||
description: 'Get weather information',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: { type: 'string', description: 'City name' },
|
||||
unit: { type: 'string', description: 'Unit (metric/imperial)' },
|
||||
},
|
||||
required: ['location'],
|
||||
isProd: false,
|
||||
isDev: true,
|
||||
isTest: true,
|
||||
}))
|
||||
|
||||
// Mock env config to control hosted key availability
|
||||
vi.mock('@/lib/core/config/env', () => ({
|
||||
env: new Proxy({} as Record<string, string | undefined>, {
|
||||
get: (_target, prop: string) => mockEnv[prop],
|
||||
}),
|
||||
getEnv: (key: string) => mockEnv[key],
|
||||
isTruthy: (val: unknown) => val === true || val === 'true' || val === '1',
|
||||
isFalsy: (val: unknown) => val === false || val === 'false' || val === '0',
|
||||
}))
|
||||
|
||||
// Mock getBYOKKey
|
||||
vi.mock('@/lib/api-key/byok', () => ({
|
||||
getBYOKKey: (...args: unknown[]) => mockGetBYOKKey(...args),
|
||||
}))
|
||||
|
||||
// Mock logFixedUsage for billing
|
||||
vi.mock('@/lib/billing/core/usage-log', () => ({
|
||||
logFixedUsage: (...args: unknown[]) => mockLogFixedUsage(...args),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/core/rate-limiter/hosted-key', () => ({
|
||||
getHostedKeyRateLimiter: () => mockRateLimiterFns,
|
||||
}))
|
||||
|
||||
// Mock custom tools - define mock data inside factory function
|
||||
vi.mock('@/hooks/queries/custom-tools', () => {
|
||||
const mockCustomTool = {
|
||||
id: 'custom-tool-123',
|
||||
title: 'Custom Weather Tool',
|
||||
code: 'return { result: "Weather data" }',
|
||||
schema: {
|
||||
function: {
|
||||
description: 'Get weather information',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: { type: 'string', description: 'City name' },
|
||||
unit: { type: 'string', description: 'Unit (metric/imperial)' },
|
||||
},
|
||||
required: ['location'],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}))
|
||||
}
|
||||
return {
|
||||
getCustomTool: (toolId: string) => {
|
||||
if (toolId === 'custom-tool-123') {
|
||||
return mockCustomTool
|
||||
}
|
||||
return undefined
|
||||
},
|
||||
getCustomTools: () => [mockCustomTool],
|
||||
}
|
||||
})
|
||||
|
||||
import { executeTool } from '@/tools/index'
|
||||
import { tools } from '@/tools/registry'
|
||||
@@ -1186,3 +1219,712 @@ describe('MCP Tool Execution', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Hosted Key Injection', () => {
|
||||
let cleanupEnvVars: () => void
|
||||
|
||||
beforeEach(() => {
|
||||
process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000'
|
||||
cleanupEnvVars = setupEnvVars({ NEXT_PUBLIC_APP_URL: 'http://localhost:3000' })
|
||||
vi.clearAllMocks()
|
||||
mockGetBYOKKey.mockReset()
|
||||
mockLogFixedUsage.mockReset()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks()
|
||||
cleanupEnvVars()
|
||||
})
|
||||
|
||||
it('should not inject hosted key when tool has no hosting config', async () => {
|
||||
const mockTool = {
|
||||
id: 'test_no_hosting',
|
||||
name: 'Test No Hosting',
|
||||
description: 'A test tool without hosting config',
|
||||
version: '1.0.0',
|
||||
params: {},
|
||||
request: {
|
||||
url: '/api/test/endpoint',
|
||||
method: 'POST' as const,
|
||||
headers: () => ({ 'Content-Type': 'application/json' }),
|
||||
},
|
||||
transformResponse: vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
output: { result: 'success' },
|
||||
}),
|
||||
}
|
||||
|
||||
const originalTools = { ...tools }
|
||||
;(tools as any).test_no_hosting = mockTool
|
||||
|
||||
global.fetch = Object.assign(
|
||||
vi.fn().mockImplementation(async () => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
headers: new Headers(),
|
||||
json: () => Promise.resolve({ success: true }),
|
||||
})),
|
||||
{ preconnect: vi.fn() }
|
||||
) as typeof fetch
|
||||
|
||||
const mockContext = createToolExecutionContext()
|
||||
await executeTool('test_no_hosting', {}, false, mockContext)
|
||||
|
||||
// BYOK should not be called since there's no hosting config
|
||||
expect(mockGetBYOKKey).not.toHaveBeenCalled()
|
||||
|
||||
Object.assign(tools, originalTools)
|
||||
})
|
||||
|
||||
it('should check BYOK key first when tool has hosting config', async () => {
|
||||
// Note: isHosted is mocked to false by default, so hosted key injection won't happen
|
||||
// This test verifies the flow when isHosted would be true
|
||||
const mockTool = {
|
||||
id: 'test_with_hosting',
|
||||
name: 'Test With Hosting',
|
||||
description: 'A test tool with hosting config',
|
||||
version: '1.0.0',
|
||||
params: {
|
||||
apiKey: { type: 'string', required: true },
|
||||
},
|
||||
hosting: {
|
||||
envKeyPrefix: 'TEST_API',
|
||||
apiKeyParam: 'apiKey',
|
||||
byokProviderId: 'exa',
|
||||
pricing: {
|
||||
type: 'per_request' as const,
|
||||
cost: 0.005,
|
||||
},
|
||||
rateLimit: {
|
||||
mode: 'per_request' as const,
|
||||
requestsPerMinute: 100,
|
||||
},
|
||||
},
|
||||
request: {
|
||||
url: '/api/test/endpoint',
|
||||
method: 'POST' as const,
|
||||
headers: (params: any) => ({
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': params.apiKey,
|
||||
}),
|
||||
},
|
||||
transformResponse: vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
output: { result: 'success' },
|
||||
}),
|
||||
}
|
||||
|
||||
const originalTools = { ...tools }
|
||||
;(tools as any).test_with_hosting = mockTool
|
||||
|
||||
// Mock BYOK returning a key
|
||||
mockGetBYOKKey.mockResolvedValue({ apiKey: 'byok-test-key', isBYOK: true })
|
||||
|
||||
global.fetch = Object.assign(
|
||||
vi.fn().mockImplementation(async () => ({
|
||||
ok: true,
|
||||
status: 200,
|
||||
headers: new Headers(),
|
||||
json: () => Promise.resolve({ success: true }),
|
||||
})),
|
||||
{ preconnect: vi.fn() }
|
||||
) as typeof fetch
|
||||
|
||||
const mockContext = createToolExecutionContext()
|
||||
await executeTool('test_with_hosting', {}, false, mockContext)
|
||||
|
||||
// With isHosted=false, BYOK won't be called - this is expected behavior
|
||||
// The test documents the current behavior
|
||||
Object.assign(tools, originalTools)
|
||||
})
|
||||
|
||||
it('should use per_request pricing model correctly', async () => {
|
||||
const mockTool = {
|
||||
id: 'test_per_request_pricing',
|
||||
name: 'Test Per Request Pricing',
|
||||
description: 'A test tool with per_request pricing',
|
||||
version: '1.0.0',
|
||||
params: {
|
||||
apiKey: { type: 'string', required: true },
|
||||
},
|
||||
hosting: {
|
||||
envKeyPrefix: 'TEST_API',
|
||||
apiKeyParam: 'apiKey',
|
||||
byokProviderId: 'exa',
|
||||
pricing: {
|
||||
type: 'per_request' as const,
|
||||
cost: 0.005,
|
||||
},
|
||||
rateLimit: {
|
||||
mode: 'per_request' as const,
|
||||
requestsPerMinute: 100,
|
||||
},
|
||||
},
|
||||
request: {
|
||||
url: '/api/test/endpoint',
|
||||
method: 'POST' as const,
|
||||
headers: (params: any) => ({
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': params.apiKey,
|
||||
}),
|
||||
},
|
||||
transformResponse: vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
output: { result: 'success' },
|
||||
}),
|
||||
}
|
||||
|
||||
// Verify pricing config structure
|
||||
expect(mockTool.hosting.pricing.type).toBe('per_request')
|
||||
expect(mockTool.hosting.pricing.cost).toBe(0.005)
|
||||
})
|
||||
|
||||
it('should use custom pricing model correctly', async () => {
|
||||
const mockGetCost = vi.fn().mockReturnValue({ cost: 0.01, metadata: { breakdown: 'test' } })
|
||||
|
||||
const mockTool = {
|
||||
id: 'test_custom_pricing',
|
||||
name: 'Test Custom Pricing',
|
||||
description: 'A test tool with custom pricing',
|
||||
version: '1.0.0',
|
||||
params: {
|
||||
apiKey: { type: 'string', required: true },
|
||||
},
|
||||
hosting: {
|
||||
envKeyPrefix: 'TEST_API',
|
||||
apiKeyParam: 'apiKey',
|
||||
byokProviderId: 'exa',
|
||||
pricing: {
|
||||
type: 'custom' as const,
|
||||
getCost: mockGetCost,
|
||||
},
|
||||
rateLimit: {
|
||||
mode: 'per_request' as const,
|
||||
requestsPerMinute: 100,
|
||||
},
|
||||
},
|
||||
request: {
|
||||
url: '/api/test/endpoint',
|
||||
method: 'POST' as const,
|
||||
headers: (params: any) => ({
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': params.apiKey,
|
||||
}),
|
||||
},
|
||||
transformResponse: vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
output: { result: 'success', costDollars: { total: 0.01 } },
|
||||
}),
|
||||
}
|
||||
|
||||
// Verify pricing config structure
|
||||
expect(mockTool.hosting.pricing.type).toBe('custom')
|
||||
expect(typeof mockTool.hosting.pricing.getCost).toBe('function')
|
||||
|
||||
// Test getCost returns expected value
|
||||
const result = mockTool.hosting.pricing.getCost({}, { costDollars: { total: 0.01 } })
|
||||
expect(result).toEqual({ cost: 0.01, metadata: { breakdown: 'test' } })
|
||||
})
|
||||
|
||||
it('should handle custom pricing returning a number', async () => {
  // getCost may return either a bare number or a { cost, metadata } object;
  // this case covers the bare-number form.
  const mockGetCost = vi.fn().mockReturnValue(0.005)

  const mockTool = {
    id: 'test_custom_pricing_number',
    name: 'Test Custom Pricing Number',
    description: 'A test tool with custom pricing returning number',
    version: '1.0.0',
    params: {
      apiKey: { type: 'string', required: true },
    },
    hosting: {
      envKeyPrefix: 'TEST_API',
      apiKeyParam: 'apiKey',
      byokProviderId: 'exa',
      pricing: {
        type: 'custom' as const,
        getCost: mockGetCost,
      },
      rateLimit: {
        mode: 'per_request' as const,
        requestsPerMinute: 100,
      },
    },
    request: {
      url: '/api/test/endpoint',
      method: 'POST' as const,
      headers: (params: any) => ({
        'Content-Type': 'application/json',
        'x-api-key': params.apiKey,
      }),
    },
  }

  // Test getCost returns a number
  // NOTE(review): this only verifies the mock's own return value, not that the
  // executor normalizes a numeric getCost result into { cost } — that path is
  // covered by calculateToolCost in the implementation.
  const result = mockTool.hosting.pricing.getCost({}, {})
  expect(result).toBe(0.005)
})
|
||||
})
|
||||
|
||||
// Exercises executeTool's hosted-key retry wrapper: 429 responses are retried
// with exponential backoff, persistent 429s exhaust retries and fail, and
// non-throttling errors (e.g. 400) fail immediately without retrying.
describe('Rate Limiting and Retry Logic', () => {
  let cleanupEnvVars: () => void

  beforeEach(() => {
    process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000'
    cleanupEnvVars = setupEnvVars({
      NEXT_PUBLIC_APP_URL: 'http://localhost:3000',
    })
    vi.clearAllMocks()
    // Hosted mode + a hosted key in env so the retry path is reachable.
    mockIsHosted.value = true
    mockEnv.TEST_HOSTED_KEY = 'test-hosted-api-key'
    // No BYOK key → force fallback to the hosted key pool.
    mockGetBYOKKey.mockResolvedValue(null)
    // Set up throttler mock defaults
    mockRateLimiterFns.acquireKey.mockResolvedValue({
      success: true,
      key: 'mock-hosted-key',
      keyIndex: 0,
      envVarName: 'TEST_HOSTED_KEY',
    })
    mockRateLimiterFns.preConsumeCapacity.mockResolvedValue(true)
    mockRateLimiterFns.consumeCapacity.mockResolvedValue(undefined)
  })

  afterEach(() => {
    vi.resetAllMocks()
    cleanupEnvVars()
    mockIsHosted.value = false
    mockEnv.TEST_HOSTED_KEY = undefined
  })

  it('should retry on 429 rate limit errors with exponential backoff', async () => {
    let attemptCount = 0

    const mockTool = {
      id: 'test_rate_limit',
      name: 'Test Rate Limit',
      description: 'A test tool for rate limiting',
      version: '1.0.0',
      params: {
        apiKey: { type: 'string', required: false },
      },
      hosting: {
        envKeyPrefix: 'TEST_HOSTED_KEY',
        apiKeyParam: 'apiKey',
        pricing: {
          type: 'per_request' as const,
          cost: 0.001,
        },
        rateLimit: {
          mode: 'per_request' as const,
          requestsPerMinute: 100,
        },
      },
      request: {
        url: '/api/test/rate-limit',
        method: 'POST' as const,
        headers: () => ({ 'Content-Type': 'application/json' }),
      },
      transformResponse: vi.fn().mockResolvedValue({
        success: true,
        output: { result: 'success' },
      }),
    }

    // Register the mock tool in the live registry; restored at the end of the test.
    const originalTools = { ...tools }
    ;(tools as any).test_rate_limit = mockTool

    // Fail twice with 429, then succeed — expects exactly 3 fetch calls.
    global.fetch = Object.assign(
      vi.fn().mockImplementation(async () => {
        attemptCount++
        if (attemptCount < 3) {
          // Return a proper 429 response - the code extracts error, attaches status, and throws
          return {
            ok: false,
            status: 429,
            statusText: 'Too Many Requests',
            headers: new Headers(),
            json: () => Promise.resolve({ error: 'Rate limited' }),
            text: () => Promise.resolve('Rate limited'),
          }
        }
        return {
          ok: true,
          status: 200,
          headers: new Headers(),
          json: () => Promise.resolve({ success: true }),
        }
      }),
      { preconnect: vi.fn() }
    ) as typeof fetch

    const mockContext = createToolExecutionContext()
    const result = await executeTool('test_rate_limit', {}, false, mockContext)

    // Should succeed after retries
    expect(result.success).toBe(true)
    // Should have made 3 attempts (2 failures + 1 success)
    expect(attemptCount).toBe(3)

    Object.assign(tools, originalTools)
  })

  it('should fail after max retries on persistent rate limiting', async () => {
    const mockTool = {
      id: 'test_persistent_rate_limit',
      name: 'Test Persistent Rate Limit',
      description: 'A test tool for persistent rate limiting',
      version: '1.0.0',
      params: {
        apiKey: { type: 'string', required: false },
      },
      hosting: {
        envKeyPrefix: 'TEST_HOSTED_KEY',
        apiKeyParam: 'apiKey',
        pricing: {
          type: 'per_request' as const,
          cost: 0.001,
        },
        rateLimit: {
          mode: 'per_request' as const,
          requestsPerMinute: 100,
        },
      },
      request: {
        url: '/api/test/persistent-rate-limit',
        method: 'POST' as const,
        headers: () => ({ 'Content-Type': 'application/json' }),
      },
    }

    const originalTools = { ...tools }
    ;(tools as any).test_persistent_rate_limit = mockTool

    global.fetch = Object.assign(
      vi.fn().mockImplementation(async () => {
        // Always return 429 to test max retries exhaustion
        return {
          ok: false,
          status: 429,
          statusText: 'Too Many Requests',
          headers: new Headers(),
          json: () => Promise.resolve({ error: 'Rate limited' }),
          text: () => Promise.resolve('Rate limited'),
        }
      }),
      { preconnect: vi.fn() }
    ) as typeof fetch

    const mockContext = createToolExecutionContext()
    const result = await executeTool('test_persistent_rate_limit', {}, false, mockContext)

    // Should fail after all retries exhausted
    expect(result.success).toBe(false)
    expect(result.error).toContain('Rate limited')

    Object.assign(tools, originalTools)
  })

  it('should not retry on non-rate-limit errors', async () => {
    let attemptCount = 0

    const mockTool = {
      id: 'test_no_retry',
      name: 'Test No Retry',
      description: 'A test tool that should not retry',
      version: '1.0.0',
      params: {
        apiKey: { type: 'string', required: false },
      },
      hosting: {
        envKeyPrefix: 'TEST_HOSTED_KEY',
        apiKeyParam: 'apiKey',
        pricing: {
          type: 'per_request' as const,
          cost: 0.001,
        },
        rateLimit: {
          mode: 'per_request' as const,
          requestsPerMinute: 100,
        },
      },
      request: {
        url: '/api/test/no-retry',
        method: 'POST' as const,
        headers: () => ({ 'Content-Type': 'application/json' }),
      },
    }

    const originalTools = { ...tools }
    ;(tools as any).test_no_retry = mockTool

    global.fetch = Object.assign(
      vi.fn().mockImplementation(async () => {
        attemptCount++
        // Return a 400 response - should not trigger retry logic
        return {
          ok: false,
          status: 400,
          statusText: 'Bad Request',
          headers: new Headers(),
          json: () => Promise.resolve({ error: 'Bad request' }),
          text: () => Promise.resolve('Bad request'),
        }
      }),
      { preconnect: vi.fn() }
    ) as typeof fetch

    const mockContext = createToolExecutionContext()
    const result = await executeTool('test_no_retry', {}, false, mockContext)

    // Should fail immediately without retries
    expect(result.success).toBe(false)
    expect(attemptCount).toBe(1)

    Object.assign(tools, originalTools)
  })
})
|
||||
|
||||
// Verifies that hosted-key executions attach a cost object to tool output and
// record it via logFixedUsage, and that BYOK/user-key executions bill nothing.
describe('Cost Field Handling', () => {
  let cleanupEnvVars: () => void

  beforeEach(() => {
    process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000'
    cleanupEnvVars = setupEnvVars({
      NEXT_PUBLIC_APP_URL: 'http://localhost:3000',
    })
    vi.clearAllMocks()
    mockIsHosted.value = true
    mockEnv.TEST_HOSTED_KEY = 'test-hosted-api-key'
    mockGetBYOKKey.mockResolvedValue(null)
    mockLogFixedUsage.mockResolvedValue(undefined)
    // Set up throttler mock defaults
    mockRateLimiterFns.acquireKey.mockResolvedValue({
      success: true,
      key: 'mock-hosted-key',
      keyIndex: 0,
      envVarName: 'TEST_HOSTED_KEY',
    })
    mockRateLimiterFns.preConsumeCapacity.mockResolvedValue(true)
    mockRateLimiterFns.consumeCapacity.mockResolvedValue(undefined)
  })

  afterEach(() => {
    vi.resetAllMocks()
    cleanupEnvVars()
    mockIsHosted.value = false
    mockEnv.TEST_HOSTED_KEY = undefined
  })

  it('should add cost to output when using hosted key with per_request pricing', async () => {
    const mockTool = {
      id: 'test_cost_per_request',
      name: 'Test Cost Per Request',
      description: 'A test tool with per_request pricing',
      version: '1.0.0',
      params: {
        apiKey: { type: 'string', required: false },
      },
      hosting: {
        envKeyPrefix: 'TEST_HOSTED_KEY',
        apiKeyParam: 'apiKey',
        pricing: {
          type: 'per_request' as const,
          cost: 0.005,
        },
        rateLimit: {
          mode: 'per_request' as const,
          requestsPerMinute: 100,
        },
      },
      request: {
        url: '/api/test/cost',
        method: 'POST' as const,
        headers: () => ({ 'Content-Type': 'application/json' }),
      },
      transformResponse: vi.fn().mockResolvedValue({
        success: true,
        output: { result: 'success' },
      }),
    }

    const originalTools = { ...tools }
    ;(tools as any).test_cost_per_request = mockTool

    global.fetch = Object.assign(
      vi.fn().mockImplementation(async () => ({
        ok: true,
        status: 200,
        headers: new Headers(),
        json: () => Promise.resolve({ success: true }),
      })),
      { preconnect: vi.fn() }
    ) as typeof fetch

    const mockContext = createToolExecutionContext({
      userId: 'user-123',
    } as any)
    const result = await executeTool('test_cost_per_request', {}, false, mockContext)

    expect(result.success).toBe(true)
    // Note: In test environment, hosted key injection may not work due to env mocking complexity.
    // The cost calculation logic is tested via the pricing model tests above.
    // This test verifies the tool execution flow when hosted key IS available (by checking output structure).
    // NOTE(review): because the assertions are guarded by `if (result.output.cost)`,
    // this test passes vacuously when key injection fails — consider asserting the
    // branch is actually taken.
    if (result.output.cost) {
      expect(result.output.cost.total).toBe(0.005)
      // Should have logged usage
      expect(mockLogFixedUsage).toHaveBeenCalledWith(
        expect.objectContaining({
          userId: 'user-123',
          cost: 0.005,
          description: 'tool:test_cost_per_request',
        })
      )
    }

    Object.assign(tools, originalTools)
  })

  it('should not add cost when not using hosted key', async () => {
    // Self-hosted mode: hosted key injection must be skipped entirely.
    mockIsHosted.value = false

    const mockTool = {
      id: 'test_no_hosted_cost',
      name: 'Test No Hosted Cost',
      description: 'A test tool without hosted key',
      version: '1.0.0',
      params: {
        apiKey: { type: 'string', required: true },
      },
      hosting: {
        envKeyPrefix: 'TEST_HOSTED_KEY',
        apiKeyParam: 'apiKey',
        pricing: {
          type: 'per_request' as const,
          cost: 0.005,
        },
        rateLimit: {
          mode: 'per_request' as const,
          requestsPerMinute: 100,
        },
      },
      request: {
        url: '/api/test/no-hosted',
        method: 'POST' as const,
        headers: () => ({ 'Content-Type': 'application/json' }),
      },
      transformResponse: vi.fn().mockResolvedValue({
        success: true,
        output: { result: 'success' },
      }),
    }

    const originalTools = { ...tools }
    ;(tools as any).test_no_hosted_cost = mockTool

    global.fetch = Object.assign(
      vi.fn().mockImplementation(async () => ({
        ok: true,
        status: 200,
        headers: new Headers(),
        json: () => Promise.resolve({ success: true }),
      })),
      { preconnect: vi.fn() }
    ) as typeof fetch

    const mockContext = createToolExecutionContext()
    // Pass user's own API key
    const result = await executeTool(
      'test_no_hosted_cost',
      { apiKey: 'user-api-key' },
      false,
      mockContext
    )

    expect(result.success).toBe(true)
    // Should not have cost since user provided their own key
    expect(result.output.cost).toBeUndefined()
    // Should not have logged usage
    expect(mockLogFixedUsage).not.toHaveBeenCalled()

    Object.assign(tools, originalTools)
  })

  it('should use custom pricing getCost function', async () => {
    const mockGetCost = vi.fn().mockReturnValue({
      cost: 0.015,
      metadata: { mode: 'advanced', results: 10 },
    })

    const mockTool = {
      id: 'test_custom_pricing_cost',
      name: 'Test Custom Pricing Cost',
      description: 'A test tool with custom pricing',
      version: '1.0.0',
      params: {
        apiKey: { type: 'string', required: false },
        mode: { type: 'string', required: false },
      },
      hosting: {
        envKeyPrefix: 'TEST_HOSTED_KEY',
        apiKeyParam: 'apiKey',
        pricing: {
          type: 'custom' as const,
          getCost: mockGetCost,
        },
        rateLimit: {
          mode: 'per_request' as const,
          requestsPerMinute: 100,
        },
      },
      request: {
        url: '/api/test/custom-pricing',
        method: 'POST' as const,
        headers: () => ({ 'Content-Type': 'application/json' }),
      },
      transformResponse: vi.fn().mockResolvedValue({
        success: true,
        output: { result: 'success', results: 10 },
      }),
    }

    const originalTools = { ...tools }
    ;(tools as any).test_custom_pricing_cost = mockTool

    global.fetch = Object.assign(
      vi.fn().mockImplementation(async () => ({
        ok: true,
        status: 200,
        headers: new Headers(),
        json: () => Promise.resolve({ success: true }),
      })),
      { preconnect: vi.fn() }
    ) as typeof fetch

    const mockContext = createToolExecutionContext({
      userId: 'user-123',
    } as any)
    const result = await executeTool(
      'test_custom_pricing_cost',
      { mode: 'advanced' },
      false,
      mockContext
    )

    expect(result.success).toBe(true)
    expect(result.output.cost).toBeDefined()
    expect(result.output.cost.total).toBe(0.015)

    // getCost should have been called with params and output
    expect(mockGetCost).toHaveBeenCalled()

    // Should have logged usage with metadata
    expect(mockLogFixedUsage).toHaveBeenCalledWith(
      expect.objectContaining({
        cost: 0.015,
        metadata: { mode: 'advanced', results: 10 },
      })
    )

    Object.assign(tools, originalTools)
  })
})
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { getBYOKKey } from '@/lib/api-key/byok'
|
||||
import { generateInternalToken } from '@/lib/auth/internal'
|
||||
import { logFixedUsage } from '@/lib/billing/core/usage-log'
|
||||
import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import { DEFAULT_EXECUTION_TIMEOUT_MS } from '@/lib/core/execution-limits'
|
||||
import { getHostedKeyRateLimiter } from '@/lib/core/rate-limiter'
|
||||
import {
|
||||
secureFetchWithPinnedIP,
|
||||
validateUrlWithDNS,
|
||||
} from '@/lib/core/security/input-validation.server'
|
||||
import { PlatformEvents } from '@/lib/core/telemetry'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { getBaseUrl, getInternalApiBaseUrl } from '@/lib/core/utils/urls'
|
||||
import { SIM_VIA_HEADER, serializeCallChain } from '@/lib/execution/call-chain'
|
||||
@@ -14,7 +19,14 @@ import { resolveSkillContent } from '@/executor/handlers/agent/skills-resolver'
|
||||
import type { ExecutionContext } from '@/executor/types'
|
||||
import type { ErrorInfo } from '@/tools/error-extractors'
|
||||
import { extractErrorMessage } from '@/tools/error-extractors'
|
||||
import type { OAuthTokenPayload, ToolConfig, ToolResponse, ToolRetryConfig } from '@/tools/types'
|
||||
import type {
|
||||
BYOKProviderId,
|
||||
OAuthTokenPayload,
|
||||
ToolConfig,
|
||||
ToolHostingPricing,
|
||||
ToolResponse,
|
||||
ToolRetryConfig,
|
||||
} from '@/tools/types'
|
||||
import {
|
||||
formatRequestParams,
|
||||
getTool,
|
||||
@@ -24,6 +36,365 @@ import {
|
||||
|
||||
const logger = createLogger('Tools')
|
||||
|
||||
/** Result from hosted key injection */
interface HostedKeyInjectionResult {
  /** True only when a platform-owned (billable) key was injected. */
  isUsingHostedKey: boolean
  /** Env var the key came from; set only when isUsingHostedKey is true. */
  envVarName?: string
}

/**
 * Inject hosted API key if tool supports it and user didn't provide one.
 * Checks BYOK workspace keys first, then uses the HostedKeyRateLimiter for round-robin key selection.
 * Returns whether a hosted (billable) key was injected and which env var it came from.
 *
 * Mutates `params[apiKeyParam]` in place. Throws an Error with `status: 429`
 * (and `retryAfterMs`) when the billing actor is throttled, or `status: 503`
 * when no hosted keys are configured.
 *
 * NOTE(review): this function does not appear to check whether the caller
 * already set `params[apiKeyParam]` before overwriting it with a hosted key —
 * confirm an upstream guard exists, otherwise a user-provided key on the
 * hosted platform would be replaced and billed.
 */
async function injectHostedKeyIfNeeded(
  tool: ToolConfig,
  params: Record<string, unknown>,
  executionContext: ExecutionContext | undefined,
  requestId: string
): Promise<HostedKeyInjectionResult> {
  // Only relevant for tools that declare hosting config, and only on the hosted platform.
  if (!tool.hosting) return { isUsingHostedKey: false }
  if (!isHosted) return { isUsingHostedKey: false }

  const { envKeyPrefix, apiKeyParam, byokProviderId, rateLimit } = tool.hosting

  // Derive workspace/user/workflow IDs from executionContext or params._context
  const ctx = params._context as Record<string, unknown> | undefined
  const workspaceId = executionContext?.workspaceId || (ctx?.workspaceId as string | undefined)
  const userId = executionContext?.userId || (ctx?.userId as string | undefined)
  const workflowId = executionContext?.workflowId || (ctx?.workflowId as string | undefined)

  // Check BYOK workspace key first
  if (byokProviderId && workspaceId) {
    try {
      const byokResult = await getBYOKKey(workspaceId, byokProviderId as BYOKProviderId)
      if (byokResult) {
        params[apiKeyParam] = byokResult.apiKey
        logger.info(`[${requestId}] Using BYOK key for ${tool.id}`)
        return { isUsingHostedKey: false } // Don't bill - user's own key
      }
    } catch (error) {
      logger.error(`[${requestId}] Failed to get BYOK key for ${tool.id}:`, error)
      // Fall through to hosted key
    }
  }

  const rateLimiter = getHostedKeyRateLimiter()
  const provider = byokProviderId || tool.id
  // The workspace is the billing actor; without one we cannot throttle or bill.
  const billingActorId = workspaceId

  if (!billingActorId) {
    logger.error(`[${requestId}] No workspace ID available for hosted key rate limiting`)
    return { isUsingHostedKey: false }
  }

  const acquireResult = await rateLimiter.acquireKey(
    provider,
    envKeyPrefix,
    rateLimit,
    billingActorId
  )

  // Billing-actor throttled: surface as a 429 so callers/retry logic treat it as rate limiting.
  if (!acquireResult.success && acquireResult.billingActorRateLimited) {
    logger.warn(`[${requestId}] Billing actor ${billingActorId} rate limited for ${tool.id}`, {
      provider,
      retryAfterMs: acquireResult.retryAfterMs,
    })

    PlatformEvents.userThrottled({
      toolId: tool.id,
      reason: 'billing_actor_limit',
      provider,
      retryAfterMs: acquireResult.retryAfterMs ?? 0,
      userId,
      workspaceId,
      workflowId,
    })

    const error = new Error(acquireResult.error || `Rate limit exceeded for ${tool.id}`)
    ;(error as any).status = 429
    ;(error as any).retryAfterMs = acquireResult.retryAfterMs
    throw error
  }

  // Handle no keys configured (503)
  if (!acquireResult.success) {
    logger.error(`[${requestId}] No hosted keys configured for ${tool.id}: ${acquireResult.error}`)
    const error = new Error(acquireResult.error || `No hosted keys configured for ${tool.id}`)
    ;(error as any).status = 503
    throw error
  }

  params[apiKeyParam] = acquireResult.key
  logger.info(`[${requestId}] Using hosted key for ${tool.id} (${acquireResult.envVarName})`, {
    keyIndex: acquireResult.keyIndex,
    provider,
  })

  return {
    isUsingHostedKey: true,
    envVarName: acquireResult.envVarName,
  }
}
|
||||
|
||||
/**
|
||||
* Check if an error is a rate limit (throttling) error
|
||||
*/
|
||||
function isRateLimitError(error: unknown): boolean {
|
||||
if (error && typeof error === 'object') {
|
||||
const status = (error as { status?: number }).status
|
||||
// 429 = Too Many Requests, 503 = Service Unavailable (sometimes used for rate limiting)
|
||||
if (status === 429 || status === 503) return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/** Context for retry with rate limit tracking */
interface RetryContext {
  requestId: string
  toolId: string
  /** Env var the hosted key was drawn from; used in telemetry/log lines. */
  envVarName: string
  executionContext?: ExecutionContext
}

/**
 * Execute a function with exponential backoff retry for rate limiting errors.
 * Only used for hosted key requests. Tracks rate limit events via telemetry.
 *
 * Makes up to `maxRetries + 1` total calls (initial attempt plus retries),
 * with delays of baseDelayMs * 2^attempt between them. Non-rate-limit errors
 * are rethrown immediately; exhausting retries on a rate-limit error emits a
 * `userThrottled` telemetry event before rethrowing.
 */
async function executeWithRetry<T>(
  fn: () => Promise<T>,
  context: RetryContext,
  maxRetries = 3,
  baseDelayMs = 1000
): Promise<T> {
  const { requestId, toolId, envVarName, executionContext } = context
  let lastError: unknown

  for (let attempt = 0; attempt <= maxRetries; attempt++) {
    try {
      return await fn()
    } catch (error) {
      lastError = error

      // Give up on non-throttling errors, or when the retry budget is spent.
      if (!isRateLimitError(error) || attempt === maxRetries) {
        // Distinguish "retries exhausted while throttled" for metrics.
        if (isRateLimitError(error) && attempt === maxRetries) {
          PlatformEvents.userThrottled({
            toolId,
            reason: 'upstream_retries_exhausted',
            userId: executionContext?.userId,
            workspaceId: executionContext?.workspaceId,
            workflowId: executionContext?.workflowId,
          })
        }
        throw error
      }

      // Exponential backoff: base * 2^attempt.
      const delayMs = baseDelayMs * 2 ** attempt

      // Track throttling event via telemetry
      PlatformEvents.hostedKeyRateLimited({
        toolId,
        envVarName,
        attempt: attempt + 1,
        maxRetries,
        delayMs,
        userId: executionContext?.userId,
        workspaceId: executionContext?.workspaceId,
        workflowId: executionContext?.workflowId,
      })

      logger.warn(
        `[${requestId}] Rate limited for ${toolId} (${envVarName}), retrying in ${delayMs}ms (attempt ${attempt + 1}/${maxRetries})`
      )
      await new Promise((resolve) => setTimeout(resolve, delayMs))
    }
  }

  // Unreachable in practice (the loop always returns or throws); satisfies the compiler.
  throw lastError
}
|
||||
|
||||
/** Result from cost calculation */
|
||||
interface ToolCostResult {
|
||||
cost: number
|
||||
metadata?: Record<string, unknown>
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate cost based on pricing model
|
||||
*/
|
||||
function calculateToolCost(
|
||||
pricing: ToolHostingPricing,
|
||||
params: Record<string, unknown>,
|
||||
response: Record<string, unknown>
|
||||
): ToolCostResult {
|
||||
switch (pricing.type) {
|
||||
case 'per_request':
|
||||
return { cost: pricing.cost }
|
||||
|
||||
case 'custom': {
|
||||
const result = pricing.getCost(params, response)
|
||||
if (typeof result === 'number') {
|
||||
return { cost: result }
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
default: {
|
||||
const exhaustiveCheck: never = pricing
|
||||
throw new Error(`Unknown pricing type: ${(exhaustiveCheck as ToolHostingPricing).type}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Cost computed for a hosted-key execution, plus optional breakdown metadata. */
interface HostedKeyCostResult {
  cost: number
  metadata?: Record<string, unknown>
}

/**
 * Calculate and log hosted key cost for a tool execution.
 * Logs to usageLog for audit trail and returns cost + metadata for output.
 *
 * Returns `{ cost: 0 }` (metadata dropped) when the tool has no pricing or
 * the computed cost is non-positive. Logging failures are swallowed after an
 * error log so billing issues never fail an already-successful execution.
 */
async function processHostedKeyCost(
  tool: ToolConfig,
  params: Record<string, unknown>,
  response: Record<string, unknown>,
  executionContext: ExecutionContext | undefined,
  requestId: string
): Promise<HostedKeyCostResult> {
  if (!tool.hosting?.pricing) {
    return { cost: 0 }
  }

  const { cost, metadata } = calculateToolCost(tool.hosting.pricing, params, response)

  if (cost <= 0) return { cost: 0 }

  // IDs may come from the execution context or the legacy params._context bag.
  const ctx = params._context as Record<string, unknown> | undefined
  const userId = executionContext?.userId || (ctx?.userId as string | undefined)
  const wsId = executionContext?.workspaceId || (ctx?.workspaceId as string | undefined)
  const wfId = executionContext?.workflowId || (ctx?.workflowId as string | undefined)

  // Without a user we cannot attribute usage — return the cost but skip the audit log.
  if (!userId) return { cost, metadata }

  // Callers (e.g. the provider tool loop) may take over usage logging to avoid double entries.
  const skipLog = !!ctx?.skipFixedUsageLog
  if (!skipLog) {
    try {
      await logFixedUsage({
        userId,
        source: 'workflow',
        description: `tool:${tool.id}`,
        cost,
        workspaceId: wsId,
        workflowId: wfId,
        executionId: executionContext?.executionId,
        metadata,
      })
      logger.debug(
        `[${requestId}] Logged hosted key cost for ${tool.id}: $${cost}`,
        metadata ? { metadata } : {}
      )
    } catch (error) {
      logger.error(`[${requestId}] Failed to log hosted key usage for ${tool.id}:`, error)
    }
  } else {
    logger.debug(
      `[${requestId}] Skipping fixed usage log for ${tool.id} (cost will be tracked via provider tool loop)`
    )
  }

  return { cost, metadata }
}
|
||||
|
||||
/**
 * Report custom dimension usage after successful hosted-key tool execution.
 * Only applies to tools with `custom` rate limit mode. Fires and logs;
 * failures here do not block the response since execution already succeeded.
 *
 * Overdrawn dimensions are only warned about here — the next acquireKey call
 * is where they actually gate execution (optimistic accounting).
 */
async function reportCustomDimensionUsage(
  tool: ToolConfig,
  params: Record<string, unknown>,
  response: Record<string, unknown>,
  executionContext: ExecutionContext | undefined,
  requestId: string
): Promise<void> {
  if (tool.hosting?.rateLimit.mode !== 'custom') return
  const ctx = params._context as Record<string, unknown> | undefined
  const billingActorId = executionContext?.workspaceId || (ctx?.workspaceId as string | undefined)
  if (!billingActorId) return

  const rateLimiter = getHostedKeyRateLimiter()
  const provider = tool.hosting.byokProviderId || tool.id

  try {
    const result = await rateLimiter.reportUsage(
      provider,
      billingActorId,
      tool.hosting.rateLimit,
      params,
      response
    )

    for (const dim of result.dimensions) {
      if (!dim.allowed) {
        logger.warn(`[${requestId}] Dimension ${dim.name} overdrawn after ${tool.id} execution`, {
          consumed: dim.consumed,
          tokensRemaining: dim.tokensRemaining,
        })
      }
    }
  } catch (error) {
    logger.error(`[${requestId}] Failed to report custom dimension usage for ${tool.id}:`, error)
  }
}
|
||||
|
||||
/**
|
||||
* Strips internal fields (keys starting with `__`) from tool output before
|
||||
* returning to users. The double-underscore prefix is reserved for transient
|
||||
* data (e.g. `__costDollars`) and will never collide with legitimate API
|
||||
* fields like `_id`.
|
||||
*/
|
||||
function stripInternalFields(output: Record<string, unknown>): Record<string, unknown> {
|
||||
const result: Record<string, unknown> = {}
|
||||
for (const [key, value] of Object.entries(output)) {
|
||||
if (!key.startsWith('__')) {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/**
 * Apply post-execution hosted-key cost tracking to a successful tool result.
 * Reports custom dimension usage, calculates cost, and merges it into the output.
 *
 * Mutates `finalResult.output` in place: when a positive cost was computed,
 * a `cost` object is merged in with the pricing metadata spread first and
 * `total` last (so metadata cannot clobber the total).
 */
async function applyHostedKeyCostToResult(
  finalResult: ToolResponse,
  tool: ToolConfig,
  params: Record<string, unknown>,
  executionContext: ExecutionContext | undefined,
  requestId: string
): Promise<void> {
  await reportCustomDimensionUsage(tool, params, finalResult.output, executionContext, requestId)

  const { cost: hostedKeyCost, metadata } = await processHostedKeyCost(
    tool,
    params,
    finalResult.output,
    executionContext,
    requestId
  )
  if (hostedKeyCost > 0) {
    finalResult.output = {
      ...finalResult.output,
      cost: {
        ...metadata,
        total: hostedKeyCost,
      },
    }
  }
}
|
||||
|
||||
/**
|
||||
* Normalizes a tool ID by stripping resource ID suffix (UUID/tableId).
|
||||
* Workflow tools: 'workflow_executor_<uuid>' -> 'workflow_executor'
|
||||
@@ -299,6 +670,15 @@ export async function executeTool(
|
||||
throw new Error(`Tool not found: ${toolId}`)
|
||||
}
|
||||
|
||||
// Inject hosted API key if tool supports it and user didn't provide one
|
||||
const hostedKeyInfo = await injectHostedKeyIfNeeded(
|
||||
tool,
|
||||
contextParams,
|
||||
executionContext,
|
||||
requestId
|
||||
)
|
||||
|
||||
// If we have a credential parameter, fetch the access token
|
||||
if (contextParams.oauthCredential) {
|
||||
contextParams.credential = contextParams.oauthCredential
|
||||
}
|
||||
@@ -419,8 +799,22 @@ export async function executeTool(
|
||||
const endTime = new Date()
|
||||
const endTimeISO = endTime.toISOString()
|
||||
const duration = endTime.getTime() - startTime.getTime()
|
||||
|
||||
if (hostedKeyInfo.isUsingHostedKey && finalResult.success) {
|
||||
await applyHostedKeyCostToResult(
|
||||
finalResult,
|
||||
tool,
|
||||
contextParams,
|
||||
executionContext,
|
||||
requestId
|
||||
)
|
||||
}
|
||||
|
||||
const strippedOutput = stripInternalFields(finalResult.output || {})
|
||||
|
||||
return {
|
||||
...finalResult,
|
||||
output: strippedOutput,
|
||||
timing: {
|
||||
startTime: startTimeISO,
|
||||
endTime: endTimeISO,
|
||||
@@ -430,7 +824,15 @@ export async function executeTool(
|
||||
}
|
||||
|
||||
// Execute the tool request directly (internal routes use regular fetch, external use SSRF-protected fetch)
|
||||
const result = await executeToolRequest(toolId, tool, contextParams)
|
||||
// Wrap with retry logic for hosted keys to handle rate limiting due to higher usage
|
||||
const result = hostedKeyInfo.isUsingHostedKey
|
||||
? await executeWithRetry(() => executeToolRequest(toolId, tool, contextParams), {
|
||||
requestId,
|
||||
toolId,
|
||||
envVarName: hostedKeyInfo.envVarName!,
|
||||
executionContext,
|
||||
})
|
||||
: await executeToolRequest(toolId, tool, contextParams)
|
||||
|
||||
// Apply post-processing if available and not skipped
|
||||
let finalResult = result
|
||||
@@ -452,8 +854,22 @@ export async function executeTool(
|
||||
const endTime = new Date()
|
||||
const endTimeISO = endTime.toISOString()
|
||||
const duration = endTime.getTime() - startTime.getTime()
|
||||
|
||||
if (hostedKeyInfo.isUsingHostedKey && finalResult.success) {
|
||||
await applyHostedKeyCostToResult(
|
||||
finalResult,
|
||||
tool,
|
||||
contextParams,
|
||||
executionContext,
|
||||
requestId
|
||||
)
|
||||
}
|
||||
|
||||
const strippedOutput = stripInternalFields(finalResult.output || {})
|
||||
|
||||
return {
|
||||
...finalResult,
|
||||
output: strippedOutput,
|
||||
timing: {
|
||||
startTime: startTimeISO,
|
||||
endTime: endTimeISO,
|
||||
|
||||
202
apps/sim/tools/knowledge/knowledge.test.ts
Normal file
202
apps/sim/tools/knowledge/knowledge.test.ts
Normal file
@@ -0,0 +1,202 @@
|
||||
/**
|
||||
* @vitest-environment node
|
||||
*
|
||||
* Knowledge Tools Unit Tests
|
||||
*
|
||||
* Tests for knowledge_search and knowledge_upload_chunk tools,
|
||||
* specifically the cost restructuring in transformResponse.
|
||||
*/
|
||||
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { knowledgeSearchTool } from '@/tools/knowledge/search'
|
||||
import { knowledgeUploadChunkTool } from '@/tools/knowledge/upload_chunk'
|
||||
|
||||
/**
|
||||
* Creates a mock Response object for testing transformResponse
|
||||
*/
|
||||
function createMockResponse(data: unknown): Response {
|
||||
return {
|
||||
json: async () => data,
|
||||
ok: true,
|
||||
status: 200,
|
||||
} as Response
|
||||
}
|
||||
|
||||
describe('Knowledge Tools', () => {
|
||||
describe('knowledgeSearchTool', () => {
|
||||
describe('transformResponse', () => {
|
||||
it('should restructure cost information for logging', async () => {
|
||||
const apiResponse = {
|
||||
data: {
|
||||
results: [{ content: 'test result', similarity: 0.95 }],
|
||||
query: 'test query',
|
||||
totalResults: 1,
|
||||
cost: {
|
||||
input: 0.00001042,
|
||||
output: 0,
|
||||
total: 0.00001042,
|
||||
tokens: {
|
||||
prompt: 521,
|
||||
completion: 0,
|
||||
total: 521,
|
||||
},
|
||||
model: 'text-embedding-3-small',
|
||||
pricing: {
|
||||
input: 0.02,
|
||||
output: 0,
|
||||
updatedAt: '2025-07-10',
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const result = await knowledgeSearchTool.transformResponse!(createMockResponse(apiResponse))
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.output).toEqual({
|
||||
results: [{ content: 'test result', similarity: 0.95 }],
|
||||
query: 'test query',
|
||||
totalResults: 1,
|
||||
cost: {
|
||||
input: 0.00001042,
|
||||
output: 0,
|
||||
total: 0.00001042,
|
||||
},
|
||||
tokens: {
|
||||
prompt: 521,
|
||||
completion: 0,
|
||||
total: 521,
|
||||
},
|
||||
model: 'text-embedding-3-small',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle response without cost information', async () => {
|
||||
const apiResponse = {
|
||||
data: {
|
||||
results: [],
|
||||
query: 'test query',
|
||||
totalResults: 0,
|
||||
},
|
||||
}
|
||||
|
||||
const result = await knowledgeSearchTool.transformResponse!(createMockResponse(apiResponse))
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.output).toEqual({
|
||||
results: [],
|
||||
query: 'test query',
|
||||
totalResults: 0,
|
||||
})
|
||||
expect(result.output.cost).toBeUndefined()
|
||||
expect(result.output.tokens).toBeUndefined()
|
||||
expect(result.output.model).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle response with partial cost information', async () => {
|
||||
const apiResponse = {
|
||||
data: {
|
||||
results: [],
|
||||
query: 'test query',
|
||||
totalResults: 0,
|
||||
cost: {
|
||||
input: 0.001,
|
||||
output: 0,
|
||||
total: 0.001,
|
||||
// No tokens or model
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const result = await knowledgeSearchTool.transformResponse!(createMockResponse(apiResponse))
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.output.cost).toEqual({
|
||||
input: 0.001,
|
||||
output: 0,
|
||||
total: 0.001,
|
||||
})
|
||||
expect(result.output.tokens).toBeUndefined()
|
||||
expect(result.output.model).toBeUndefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('knowledgeUploadChunkTool', () => {
|
||||
describe('transformResponse', () => {
|
||||
it('should restructure cost information for logging', async () => {
|
||||
const apiResponse = {
|
||||
data: {
|
||||
id: 'chunk-123',
|
||||
chunkIndex: 0,
|
||||
content: 'test content',
|
||||
contentLength: 12,
|
||||
tokenCount: 3,
|
||||
enabled: true,
|
||||
documentId: 'doc-456',
|
||||
documentName: 'Test Document',
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
updatedAt: '2025-01-01T00:00:00Z',
|
||||
cost: {
|
||||
input: 0.00000521,
|
||||
output: 0,
|
||||
total: 0.00000521,
|
||||
tokens: {
|
||||
prompt: 260,
|
||||
completion: 0,
|
||||
total: 260,
|
||||
},
|
||||
model: 'text-embedding-3-small',
|
||||
pricing: {
|
||||
input: 0.02,
|
||||
output: 0,
|
||||
updatedAt: '2025-07-10',
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const result = await knowledgeUploadChunkTool.transformResponse!(
|
||||
createMockResponse(apiResponse)
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.output.cost).toEqual({
|
||||
input: 0.00000521,
|
||||
output: 0,
|
||||
total: 0.00000521,
|
||||
})
|
||||
expect(result.output.tokens).toEqual({
|
||||
prompt: 260,
|
||||
completion: 0,
|
||||
total: 260,
|
||||
})
|
||||
expect(result.output.model).toBe('text-embedding-3-small')
|
||||
expect(result.output.data.chunkId).toBe('chunk-123')
|
||||
expect(result.output.documentId).toBe('doc-456')
|
||||
})
|
||||
|
||||
it('should handle response without cost information', async () => {
|
||||
const apiResponse = {
|
||||
data: {
|
||||
id: 'chunk-123',
|
||||
chunkIndex: 0,
|
||||
content: 'test content',
|
||||
documentId: 'doc-456',
|
||||
documentName: 'Test Document',
|
||||
},
|
||||
}
|
||||
|
||||
const result = await knowledgeUploadChunkTool.transformResponse!(
|
||||
createMockResponse(apiResponse)
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.output.cost).toBeUndefined()
|
||||
expect(result.output.tokens).toBeUndefined()
|
||||
expect(result.output.model).toBeUndefined()
|
||||
expect(result.output.data.chunkId).toBe('chunk-123')
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -80,13 +80,24 @@ export const knowledgeSearchTool: ToolConfig<any, KnowledgeSearchResponse> = {
|
||||
const result = await response.json()
|
||||
const data = result.data || result
|
||||
|
||||
// Restructure cost: extract tokens/model to top level for logging
|
||||
let costFields: Record<string, unknown> = {}
|
||||
if (data.cost && typeof data.cost === 'object') {
|
||||
const { tokens, model, input, output: outputCost, total } = data.cost
|
||||
costFields = {
|
||||
cost: { input, output: outputCost, total },
|
||||
...(tokens && { tokens }),
|
||||
...(model && { model }),
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
success: true,
|
||||
output: {
|
||||
results: data.results || [],
|
||||
query: data.query,
|
||||
totalResults: data.totalResults || 0,
|
||||
cost: data.cost,
|
||||
...costFields,
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
@@ -52,6 +52,17 @@ export const knowledgeUploadChunkTool: ToolConfig<any, KnowledgeUploadChunkRespo
|
||||
const result = await response.json()
|
||||
const data = result.data || result
|
||||
|
||||
// Restructure cost: extract tokens/model to top level for logging
|
||||
let costFields: Record<string, unknown> = {}
|
||||
if (data.cost && typeof data.cost === 'object') {
|
||||
const { tokens, model, input, output: outputCost, total } = data.cost
|
||||
costFields = {
|
||||
cost: { input, output: outputCost, total },
|
||||
...(tokens && { tokens }),
|
||||
...(model && { model }),
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
success: true,
|
||||
output: {
|
||||
@@ -68,7 +79,7 @@ export const knowledgeUploadChunkTool: ToolConfig<any, KnowledgeUploadChunkRespo
|
||||
},
|
||||
documentId: data.documentId,
|
||||
documentName: data.documentName,
|
||||
cost: data.cost,
|
||||
...costFields,
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
type CanonicalModeOverrides,
|
||||
evaluateSubBlockCondition,
|
||||
isCanonicalPair,
|
||||
isSubBlockHiddenByHostedKey,
|
||||
resolveCanonicalMode,
|
||||
type SubBlockCondition,
|
||||
} from '@/lib/workflows/subblocks/visibility'
|
||||
@@ -319,6 +320,10 @@ export function getToolParametersConfig(
|
||||
)
|
||||
|
||||
if (subBlock) {
|
||||
if (isSubBlockHiddenByHostedKey(subBlock)) {
|
||||
toolParam.visibility = 'hidden'
|
||||
}
|
||||
|
||||
toolParam.uiComponent = {
|
||||
type: subBlock.type,
|
||||
options: subBlock.options as Option[] | undefined,
|
||||
@@ -933,6 +938,9 @@ export function getSubBlocksForToolInput(
|
||||
// Skip trigger-mode-only subblocks
|
||||
if (sb.mode === 'trigger') continue
|
||||
|
||||
// Hide tool API key fields when running on hosted Sim
|
||||
if (isSubBlockHiddenByHostedKey(sb)) continue
|
||||
|
||||
// Determine the effective param ID (canonical or subblock id)
|
||||
const effectiveParamId = sb.canonicalParamId || sb.id
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import type { HostedKeyRateLimitConfig } from '@/lib/core/rate-limiter'
|
||||
import type { OAuthService } from '@/lib/oauth'
|
||||
|
||||
export type BYOKProviderId = 'openai' | 'anthropic' | 'google' | 'mistral' | 'exa'
|
||||
|
||||
export type HttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | 'HEAD'
|
||||
|
||||
/**
|
||||
@@ -147,12 +150,18 @@ export interface ToolConfig<P = any, R = any> {
|
||||
* Maps param IDs to their enrichment configuration.
|
||||
*/
|
||||
schemaEnrichment?: Record<string, SchemaEnrichmentConfig>
|
||||
|
||||
/**
|
||||
* Optional tool-level enrichment that modifies description and all parameters.
|
||||
* Use when multiple params depend on a single runtime value.
|
||||
*/
|
||||
toolEnrichment?: ToolEnrichmentConfig
|
||||
|
||||
/**
|
||||
* Hosted API key configuration for this tool.
|
||||
* When configured, the tool can use Sim's hosted API keys if user doesn't provide their own.
|
||||
* Usage is billed according to the pricing config.
|
||||
*/
|
||||
hosting?: ToolHostingConfig<P>
|
||||
}
|
||||
|
||||
export interface TableRow {
|
||||
@@ -222,3 +231,72 @@ export interface ToolEnrichmentConfig {
|
||||
}
|
||||
} | null>
|
||||
}
|
||||
|
||||
/**
|
||||
* Pricing models for hosted API key usage
|
||||
*/
|
||||
/** Flat fee per API call (e.g., Serper search) */
|
||||
export interface PerRequestPricing {
|
||||
type: 'per_request'
|
||||
/** Cost per request in dollars */
|
||||
cost: number
|
||||
}
|
||||
|
||||
/** Result from custom pricing calculation */
|
||||
export interface CustomPricingResult {
|
||||
/** Cost in dollars */
|
||||
cost: number
|
||||
/** Optional metadata about the cost calculation (e.g., breakdown from API) */
|
||||
metadata?: Record<string, unknown>
|
||||
}
|
||||
|
||||
/** Custom pricing calculated from params and response (e.g., Exa with different modes/result counts) */
|
||||
export interface CustomPricing<P = Record<string, unknown>> {
|
||||
type: 'custom'
|
||||
/** Calculate cost based on request params and response output. Fields starting with _ are internal. */
|
||||
getCost: (params: P, output: Record<string, unknown>) => number | CustomPricingResult
|
||||
}
|
||||
|
||||
/** Union of all pricing models */
|
||||
export type ToolHostingPricing<P = Record<string, unknown>> = PerRequestPricing | CustomPricing<P>
|
||||
|
||||
/**
|
||||
* Configuration for hosted API key support.
|
||||
* When configured, the tool can use Sim's hosted API keys if user doesn't provide their own.
|
||||
*
|
||||
* ### Hosted key env var convention
|
||||
*
|
||||
* Keys follow a numbered naming convention driven by a count env var:
|
||||
*
|
||||
* 1. Set `{envKeyPrefix}_COUNT` to the number of keys available.
|
||||
* 2. Provide each key as `{envKeyPrefix}_1`, `{envKeyPrefix}_2`, ..., `{envKeyPrefix}_N`.
|
||||
*
|
||||
* **Example** — for `envKeyPrefix: 'EXA_API_KEY'` with 5 keys:
|
||||
* ```
|
||||
* EXA_API_KEY_COUNT=5
|
||||
* EXA_API_KEY_1=sk-...
|
||||
* EXA_API_KEY_2=sk-...
|
||||
* EXA_API_KEY_3=sk-...
|
||||
* EXA_API_KEY_4=sk-...
|
||||
* EXA_API_KEY_5=sk-...
|
||||
* ```
|
||||
*
|
||||
* Adding more keys only requires updating the count and adding the new env var —
|
||||
* no code changes needed.
|
||||
*/
|
||||
export interface ToolHostingConfig<P = Record<string, unknown>> {
|
||||
/**
|
||||
* Env var name prefix for hosted keys.
|
||||
* At runtime, `{envKeyPrefix}_COUNT` is read to determine how many keys exist,
|
||||
* then `{envKeyPrefix}_1` through `{envKeyPrefix}_N` are resolved.
|
||||
*/
|
||||
envKeyPrefix: string
|
||||
/** The parameter name that receives the API key */
|
||||
apiKeyParam: string
|
||||
/** BYOK provider ID for workspace key lookup */
|
||||
byokProviderId?: BYOKProviderId
|
||||
/** Pricing when using hosted key */
|
||||
pricing: ToolHostingPricing<P>
|
||||
/** Hosted key rate limit configuration (required for hosted key distribution) */
|
||||
rateLimit: HostedKeyRateLimitConfig
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user