fix(billing): atomize usage_log and userStats writes via central recordUsage (#3767)

* fix(billing): atomize usage_log and userStats writes via central recordUsage()

* fix(billing): address PR review — re-throw errors, guard reserved keys, handle zero-cost counters

* chore(lint): fix formatting in hubspot list_lists.ts from staging

* fix(billing): tighten early-return guard to handle empty additionalStats object

* lint

* chore(billing): remove implementation-decision comments
This commit is contained in:
Waleed
2026-03-25 13:41:27 -07:00
committed by GitHub
parent 54a862d5b0
commit f94be08950
12 changed files with 205 additions and 376 deletions

View File

@@ -1,10 +1,8 @@
import { db } from '@sim/db'
import { userStats } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { eq, sql } from 'drizzle-orm'
import { sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { logModelUsage } from '@/lib/billing/core/usage-log'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { checkInternalApiKey } from '@/lib/copilot/utils'
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
@@ -87,55 +85,41 @@ export async function POST(req: NextRequest) {
source,
})
// Check if user stats record exists (same as ExecutionLogger)
const userStatsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))
if (userStatsRecords.length === 0) {
logger.error(
`[${requestId}] User stats record not found - should be created during onboarding`,
{
userId,
}
)
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
}
const totalTokens = inputTokens + outputTokens
const updateFields: Record<string, unknown> = {
totalCost: sql`total_cost + ${cost}`,
currentPeriodCost: sql`current_period_cost + ${cost}`,
const additionalStats: Record<string, ReturnType<typeof sql>> = {
totalCopilotCost: sql`total_copilot_cost + ${cost}`,
currentPeriodCopilotCost: sql`current_period_copilot_cost + ${cost}`,
totalCopilotCalls: sql`total_copilot_calls + 1`,
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
lastActive: new Date(),
}
if (isMcp) {
updateFields.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
updateFields.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
updateFields.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
additionalStats.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
additionalStats.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
additionalStats.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
}
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
await recordUsage({
userId,
entries: [
{
category: 'model',
source,
description: model,
cost,
metadata: { inputTokens, outputTokens },
},
],
additionalStats,
})
logger.info(`[${requestId}] Updated user stats record`, {
logger.info(`[${requestId}] Recorded usage`, {
userId,
addedCost: cost,
source,
})
// Log usage for complete audit trail with the original source for visibility
await logModelUsage({
userId,
source,
model,
inputTokens,
outputTokens,
cost,
})
// Check if user has hit overage threshold and bill incrementally
await checkAndBillOverageThreshold(userId)

View File

@@ -1,11 +1,11 @@
import { db } from '@sim/db'
import { userStats, workflow } from '@sim/db/schema'
import { workflow } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getBYOKKey } from '@/lib/api-key/byok'
import { getSession } from '@/lib/auth'
import { logModelUsage } from '@/lib/billing/core/usage-log'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { env } from '@/lib/core/config/env'
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
@@ -134,23 +134,20 @@ async function updateUserStatsForWand(
costToStore = modelCost * costMultiplier
}
await db
.update(userStats)
.set({
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
totalCost: sql`total_cost + ${costToStore}`,
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
lastActive: new Date(),
})
.where(eq(userStats.userId, userId))
await logModelUsage({
await recordUsage({
userId,
source: 'wand',
model: modelName,
inputTokens: promptTokens,
outputTokens: completionTokens,
cost: costToStore,
entries: [
{
category: 'model',
source: 'wand',
description: modelName,
cost: costToStore,
metadata: { inputTokens: promptTokens, outputTokens: completionTokens },
},
],
additionalStats: {
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
},
})
await checkAndBillOverageThreshold(userId)
@@ -341,7 +338,7 @@ export async function POST(req: NextRequest) {
let finalUsage: any = null
let usageRecorded = false
const recordUsage = async () => {
const flushUsage = async () => {
if (usageRecorded || !finalUsage) {
return
}
@@ -360,7 +357,7 @@ export async function POST(req: NextRequest) {
if (done) {
logger.info(`[${requestId}] Stream completed. Total chunks: ${chunkCount}`)
await recordUsage()
await flushUsage()
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
controller.close()
break
@@ -390,7 +387,7 @@ export async function POST(req: NextRequest) {
if (data === '[DONE]') {
logger.info(`[${requestId}] Received [DONE] signal`)
await recordUsage()
await flushUsage()
controller.enqueue(
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
@@ -468,7 +465,7 @@ export async function POST(req: NextRequest) {
})
try {
await recordUsage()
await flushUsage()
} catch (usageError) {
logger.warn(`[${requestId}] Failed to record usage after stream error`, usageError)
}

View File

@@ -1,7 +1,7 @@
import { db } from '@sim/db'
import { usageLog } from '@sim/db/schema'
import { usageLog, userStats } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, desc, eq, gte, lte, sql } from 'drizzle-orm'
import { and, desc, eq, gte, lte, type SQL, sql } from 'drizzle-orm'
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
const logger = createLogger('UsageLog')
@@ -32,224 +32,121 @@ export interface ModelUsageMetadata {
}
/**
* Metadata for 'fixed' category charges (e.g., tool cost breakdown)
* Union type for all usage log metadata types
*/
export type FixedUsageMetadata = Record<string, unknown>
export type UsageLogMetadata = ModelUsageMetadata | Record<string, unknown> | null
/**
* Union type for all metadata types
* A single usage entry to be recorded in the usage_log table.
*/
export type UsageLogMetadata = ModelUsageMetadata | FixedUsageMetadata | null
/**
* Parameters for logging model usage (token-based charges)
*/
export interface LogModelUsageParams {
userId: string
source: UsageLogSource
model: string
inputTokens: number
outputTokens: number
cost: number
toolCost?: number
workspaceId?: string
workflowId?: string
executionId?: string
}
/**
* Parameters for logging fixed charges (flat fees)
*/
export interface LogFixedUsageParams {
userId: string
export interface UsageEntry {
category: UsageLogCategory
source: UsageLogSource
description: string
cost: number
workspaceId?: string
workflowId?: string
executionId?: string
/** Optional metadata (e.g., tool cost breakdown from API) */
metadata?: FixedUsageMetadata
metadata?: UsageLogMetadata
}
/**
* Log a model usage charge (token-based)
* Parameters for the central recordUsage function.
* This is the single entry point for all billing mutations.
*/
export async function logModelUsage(params: LogModelUsageParams): Promise<void> {
if (!isBillingEnabled || params.cost <= 0) {
return
}
try {
const metadata: ModelUsageMetadata = {
inputTokens: params.inputTokens,
outputTokens: params.outputTokens,
...(params.toolCost != null && params.toolCost > 0 && { toolCost: params.toolCost }),
}
await db.insert(usageLog).values({
id: crypto.randomUUID(),
userId: params.userId,
category: 'model',
source: params.source,
description: params.model,
metadata,
cost: params.cost.toString(),
workspaceId: params.workspaceId ?? null,
workflowId: params.workflowId ?? null,
executionId: params.executionId ?? null,
})
logger.debug('Logged model usage', {
userId: params.userId,
source: params.source,
model: params.model,
cost: params.cost,
})
} catch (error) {
logger.error('Failed to log model usage', {
error: error instanceof Error ? error.message : String(error),
params,
})
// Don't throw - usage logging should not break the main flow
}
}
/**
* Log a fixed charge (flat fee like base execution charge or search)
*/
export async function logFixedUsage(params: LogFixedUsageParams): Promise<void> {
if (!isBillingEnabled || params.cost <= 0) {
return
}
try {
await db.insert(usageLog).values({
id: crypto.randomUUID(),
userId: params.userId,
category: 'fixed',
source: params.source,
description: params.description,
metadata: params.metadata ?? null,
cost: params.cost.toString(),
workspaceId: params.workspaceId ?? null,
workflowId: params.workflowId ?? null,
executionId: params.executionId ?? null,
})
logger.debug('Logged fixed usage', {
userId: params.userId,
source: params.source,
description: params.description,
cost: params.cost,
})
} catch (error) {
logger.error('Failed to log fixed usage', {
error: error instanceof Error ? error.message : String(error),
params,
})
// Don't throw - usage logging should not break the main flow
}
}
/**
* Parameters for batch logging workflow usage
*/
export interface LogWorkflowUsageBatchParams {
export interface RecordUsageParams {
/** The user being charged */
userId: string
/** One or more usage_log entries to record. Total cost is derived from these. */
entries: UsageEntry[]
/** Workspace context */
workspaceId?: string
workflowId: string
/** Workflow context */
workflowId?: string
/** Execution context */
executionId?: string
baseExecutionCharge?: number
models?: Record<
string,
{
total: number
tokens: { input: number; output: number }
toolCost?: number
}
>
/** Source-specific counter increments (e.g. totalCopilotCalls, totalManualExecutions) */
additionalStats?: Record<string, SQL>
}
/**
* Log all workflow usage entries in a single batch insert (performance optimized)
* Records usage in a single atomic transaction.
*
* Inserts all entries into usage_log and updates userStats counters
* (totalCost, currentPeriodCost, lastActive) within one Postgres transaction.
* The total cost added to userStats is derived from summing entry costs,
* ensuring usage_log and currentPeriodCost can never drift apart.
*
* If billing is disabled, total cost is zero, or no entries have positive cost,
* this function returns early without writing anything.
*/
export async function logWorkflowUsageBatch(params: LogWorkflowUsageBatchParams): Promise<void> {
export async function recordUsage(params: RecordUsageParams): Promise<void> {
if (!isBillingEnabled) {
return
}
const entries: Array<{
id: string
userId: string
category: 'model' | 'fixed'
source: 'workflow'
description: string
metadata: ModelUsageMetadata | null
cost: string
workspaceId: string | null
workflowId: string | null
executionId: string | null
}> = []
const { userId, entries, workspaceId, workflowId, executionId, additionalStats } = params
if (params.baseExecutionCharge && params.baseExecutionCharge > 0) {
entries.push({
id: crypto.randomUUID(),
userId: params.userId,
category: 'fixed',
source: 'workflow',
description: 'execution_fee',
metadata: null,
cost: params.baseExecutionCharge.toString(),
workspaceId: params.workspaceId ?? null,
workflowId: params.workflowId,
executionId: params.executionId ?? null,
})
}
const validEntries = entries.filter((e) => e.cost > 0)
const totalCost = validEntries.reduce((sum, e) => sum + e.cost, 0)
if (params.models) {
for (const [modelName, modelData] of Object.entries(params.models)) {
if (modelData.total > 0) {
entries.push({
id: crypto.randomUUID(),
userId: params.userId,
category: 'model',
source: 'workflow',
description: modelName,
metadata: {
inputTokens: modelData.tokens.input,
outputTokens: modelData.tokens.output,
...(modelData.toolCost != null &&
modelData.toolCost > 0 && { toolCost: modelData.toolCost }),
},
cost: modelData.total.toString(),
workspaceId: params.workspaceId ?? null,
workflowId: params.workflowId,
executionId: params.executionId ?? null,
})
}
}
}
if (entries.length === 0) {
if (
validEntries.length === 0 &&
(!additionalStats || Object.keys(additionalStats).length === 0)
) {
return
}
try {
await db.insert(usageLog).values(entries)
const RESERVED_KEYS = new Set(['totalCost', 'currentPeriodCost', 'lastActive'])
const safeStats = additionalStats
? Object.fromEntries(Object.entries(additionalStats).filter(([k]) => !RESERVED_KEYS.has(k)))
: undefined
logger.debug('Logged workflow usage batch', {
userId: params.userId,
workflowId: params.workflowId,
entryCount: entries.length,
})
} catch (error) {
logger.error('Failed to log workflow usage batch', {
error: error instanceof Error ? error.message : String(error),
params,
})
// Don't throw - usage logging should not break the main flow
}
await db.transaction(async (tx) => {
if (validEntries.length > 0) {
await tx.insert(usageLog).values(
validEntries.map((entry) => ({
id: crypto.randomUUID(),
userId,
category: entry.category,
source: entry.source,
description: entry.description,
metadata: entry.metadata ?? null,
cost: entry.cost.toString(),
workspaceId: workspaceId ?? null,
workflowId: workflowId ?? null,
executionId: executionId ?? null,
}))
)
}
const updateFields: Record<string, SQL | Date> = {
lastActive: new Date(),
...(totalCost > 0 && {
totalCost: sql`total_cost + ${totalCost}`,
currentPeriodCost: sql`current_period_cost + ${totalCost}`,
}),
...safeStats,
}
const result = await tx
.update(userStats)
.set(updateFields)
.where(eq(userStats.userId, userId))
.returning({ userId: userStats.userId })
if (result.length === 0) {
logger.warn('recordUsage: userStats row not found, transaction will roll back', {
userId,
totalCost,
})
throw new Error(`userStats row not found for userId: ${userId}`)
}
})
logger.debug('Recorded usage', {
userId,
totalCost,
entryCount: validEntries.length,
sources: [...new Set(validEntries.map((e) => e.source))],
})
}
/**

View File

@@ -29,7 +29,7 @@ vi.mock('@/lib/billing/core/usage', () => ({
}))
vi.mock('@/lib/billing/core/usage-log', () => ({
logWorkflowUsageBatch: vi.fn(() => Promise.resolve()),
recordUsage: vi.fn(() => Promise.resolve()),
}))
vi.mock('@/lib/billing/threshold-billing', () => ({

View File

@@ -16,7 +16,7 @@ import {
getOrgUsageLimit,
maybeSendUsageThresholdEmail,
} from '@/lib/billing/core/usage'
import { logWorkflowUsageBatch } from '@/lib/billing/core/usage-log'
import { type ModelUsageMetadata, recordUsage } from '@/lib/billing/core/usage-log'
import { isOrgPlan } from '@/lib/billing/plan-helpers'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
@@ -37,6 +37,17 @@ import type {
} from '@/lib/logs/types'
import type { SerializableExecutionState } from '@/executor/execution/types'
/**
 * Maps execution trigger types to their corresponding userStats counter columns.
 *
 * `key` is the camelCase field written into `additionalStats` passed to recordUsage();
 * `column` is the raw snake_case Postgres column name interpolated via sql.raw() for
 * the `<column> + 1` increment expression.
 *
 * Deliberately annotated as Record<string, …> (not `as const` / `satisfies`): lookups
 * use an arbitrary runtime `trigger` string, and unknown triggers must resolve to
 * `undefined` rather than fail type-checking. The previous trailing `as const` was a
 * no-op — the explicit Record annotation widened the literal before the assertion
 * could take effect — so it has been removed.
 */
const TRIGGER_COUNTER_MAP: Record<string, { key: string; column: string }> = {
  manual: { key: 'totalManualExecutions', column: 'total_manual_executions' },
  api: { key: 'totalApiCalls', column: 'total_api_calls' },
  webhook: { key: 'totalWebhookTriggers', column: 'total_webhook_triggers' },
  schedule: { key: 'totalScheduledExecutions', column: 'total_scheduled_executions' },
  chat: { key: 'totalChatExecutions', column: 'total_chat_executions' },
  mcp: { key: 'totalMcpExecutions', column: 'total_mcp_executions' },
  a2a: { key: 'totalA2aExecutions', column: 'total_a2a_executions' },
}
export interface ToolCall {
name: string
duration: number // in milliseconds
@@ -634,66 +645,58 @@ export class ExecutionLogger implements IExecutionLoggerService {
return
}
const costToStore = costSummary.totalCost
const entries: Array<{
category: 'model' | 'fixed'
source: 'workflow'
description: string
cost: number
metadata?: ModelUsageMetadata | null
}> = []
const existing = await db.select().from(userStats).where(eq(userStats.userId, userId))
if (existing.length === 0) {
logger.error('User stats record not found - should be created during onboarding', {
userId,
trigger,
if (costSummary.baseExecutionCharge > 0) {
entries.push({
category: 'fixed',
source: 'workflow',
description: 'execution_fee',
cost: costSummary.baseExecutionCharge,
})
return
}
// All costs go to currentPeriodCost - credits are applied at end of billing cycle
const updateFields: any = {
if (costSummary.models) {
for (const [modelName, modelData] of Object.entries(costSummary.models)) {
if (modelData.total > 0) {
entries.push({
category: 'model',
source: 'workflow',
description: modelName,
cost: modelData.total,
metadata: {
inputTokens: modelData.tokens.input,
outputTokens: modelData.tokens.output,
...(modelData.toolCost != null &&
modelData.toolCost > 0 && { toolCost: modelData.toolCost }),
},
})
}
}
}
const additionalStats: Record<string, ReturnType<typeof sql>> = {
totalTokensUsed: sql`total_tokens_used + ${costSummary.totalTokens}`,
totalCost: sql`total_cost + ${costToStore}`,
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
lastActive: new Date(),
}
switch (trigger) {
case 'manual':
updateFields.totalManualExecutions = sql`total_manual_executions + 1`
break
case 'api':
updateFields.totalApiCalls = sql`total_api_calls + 1`
break
case 'webhook':
updateFields.totalWebhookTriggers = sql`total_webhook_triggers + 1`
break
case 'schedule':
updateFields.totalScheduledExecutions = sql`total_scheduled_executions + 1`
break
case 'chat':
updateFields.totalChatExecutions = sql`total_chat_executions + 1`
break
case 'mcp':
updateFields.totalMcpExecutions = sql`total_mcp_executions + 1`
break
case 'a2a':
updateFields.totalA2aExecutions = sql`total_a2a_executions + 1`
break
const triggerCounter = TRIGGER_COUNTER_MAP[trigger]
if (triggerCounter) {
additionalStats[triggerCounter.key] = sql`${sql.raw(triggerCounter.column)} + 1`
}
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
logger.debug('Updated user stats record with cost data', {
userId,
trigger,
addedCost: costToStore,
addedTokens: costSummary.totalTokens,
})
// Log usage entries for auditing (batch insert for performance)
await logWorkflowUsageBatch({
await recordUsage({
userId,
entries,
workspaceId: workflowRecord.workspaceId ?? undefined,
workflowId,
executionId,
baseExecutionCharge: costSummary.baseExecutionCharge,
models: costSummary.models,
additionalStats,
})
// Check if user has hit overage threshold and bill incrementally

View File

@@ -1401,7 +1401,6 @@ describe('prepareToolExecution', () => {
workspaceId: 'ws-456',
chatId: 'chat-789',
userId: 'user-abc',
skipFixedUsageLog: true,
})
})

View File

@@ -1147,7 +1147,6 @@ export function prepareToolExecution(
? { isDeployedContext: request.isDeployedContext }
: {}),
...(request.callChain ? { callChain: request.callChain } : {}),
skipFixedUsageLog: true,
},
}
: {}),

View File

@@ -1,6 +1,10 @@
import { createLogger } from '@sim/logger'
import type { HubSpotListListsParams, HubSpotListListsResponse } from '@/tools/hubspot/types'
import { LISTS_ARRAY_OUTPUT, METADATA_OUTPUT_PROPERTIES, PAGING_OUTPUT } from '@/tools/hubspot/types'
import {
LISTS_ARRAY_OUTPUT,
METADATA_OUTPUT_PROPERTIES,
PAGING_OUTPUT,
} from '@/tools/hubspot/types'
import type { ToolConfig } from '@/tools/types'
const logger = createLogger('HubSpotListLists')
@@ -99,7 +103,11 @@ export const hubspotListListsTool: ToolConfig<HubSpotListListsParams, HubSpotLis
description: 'Response metadata',
properties: {
...METADATA_OUTPUT_PROPERTIES,
total: { type: 'number', description: 'Total number of lists matching the query', optional: true },
total: {
type: 'number',
description: 'Total number of lists matching the query',
optional: true,
},
},
},
success: { type: 'boolean', description: 'Operation success status' },

View File

@@ -16,19 +16,16 @@ import {
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
// Hoisted mock state - these are available to vi.mock factories
const { mockIsHosted, mockEnv, mockGetBYOKKey, mockLogFixedUsage, mockRateLimiterFns } = vi.hoisted(
() => ({
mockIsHosted: { value: false },
mockEnv: { NEXT_PUBLIC_APP_URL: 'http://localhost:3000' } as Record<string, string | undefined>,
mockGetBYOKKey: vi.fn(),
mockLogFixedUsage: vi.fn(),
mockRateLimiterFns: {
acquireKey: vi.fn(),
preConsumeCapacity: vi.fn(),
consumeCapacity: vi.fn(),
},
})
)
const { mockIsHosted, mockEnv, mockGetBYOKKey, mockRateLimiterFns } = vi.hoisted(() => ({
mockIsHosted: { value: false },
mockEnv: { NEXT_PUBLIC_APP_URL: 'http://localhost:3000' } as Record<string, string | undefined>,
mockGetBYOKKey: vi.fn(),
mockRateLimiterFns: {
acquireKey: vi.fn(),
preConsumeCapacity: vi.fn(),
consumeCapacity: vi.fn(),
},
}))
// Mock feature flags
vi.mock('@/lib/core/config/feature-flags', () => ({
@@ -55,10 +52,7 @@ vi.mock('@/lib/api-key/byok', () => ({
getBYOKKey: (...args: unknown[]) => mockGetBYOKKey(...args),
}))
// Mock logFixedUsage for billing
vi.mock('@/lib/billing/core/usage-log', () => ({
logFixedUsage: (...args: unknown[]) => mockLogFixedUsage(...args),
}))
vi.mock('@/lib/billing/core/usage-log', () => ({}))
vi.mock('@/lib/core/rate-limiter/hosted-key', () => ({
getHostedKeyRateLimiter: () => mockRateLimiterFns,
@@ -1364,7 +1358,6 @@ describe('Hosted Key Injection', () => {
cleanupEnvVars = setupEnvVars({ NEXT_PUBLIC_APP_URL: 'http://localhost:3000' })
vi.clearAllMocks()
mockGetBYOKKey.mockReset()
mockLogFixedUsage.mockReset()
})
afterEach(() => {
@@ -2022,7 +2015,6 @@ describe('Cost Field Handling', () => {
mockIsHosted.value = true
mockEnv.TEST_HOSTED_KEY = 'test-hosted-api-key'
mockGetBYOKKey.mockResolvedValue(null)
mockLogFixedUsage.mockResolvedValue(undefined)
// Set up throttler mock defaults
mockRateLimiterFns.acquireKey.mockResolvedValue({
success: true,
@@ -2097,14 +2089,6 @@ describe('Cost Field Handling', () => {
// This test verifies the tool execution flow when hosted key IS available (by checking output structure).
if (result.output.cost) {
expect(result.output.cost.total).toBe(0.005)
// Should have logged usage
expect(mockLogFixedUsage).toHaveBeenCalledWith(
expect.objectContaining({
userId: 'user-123',
cost: 0.005,
description: 'tool:test_cost_per_request',
})
)
}
Object.assign(tools, originalTools)
@@ -2169,8 +2153,6 @@ describe('Cost Field Handling', () => {
expect(result.success).toBe(true)
// Should not have cost since user provided their own key
expect(result.output.cost).toBeUndefined()
// Should not have logged usage
expect(mockLogFixedUsage).not.toHaveBeenCalled()
Object.assign(tools, originalTools)
})
@@ -2243,14 +2225,6 @@ describe('Cost Field Handling', () => {
// getCost should have been called with params and output
expect(mockGetCost).toHaveBeenCalled()
// Should have logged usage with metadata
expect(mockLogFixedUsage).toHaveBeenCalledWith(
expect.objectContaining({
cost: 0.015,
metadata: { mode: 'advanced', results: 10 },
})
)
Object.assign(tools, originalTools)
})
})

View File

@@ -1,7 +1,6 @@
import { createLogger } from '@sim/logger'
import { getBYOKKey } from '@/lib/api-key/byok'
import { generateInternalToken } from '@/lib/auth/internal'
import { logFixedUsage } from '@/lib/billing/core/usage-log'
import { isHosted } from '@/lib/core/config/feature-flags'
import { DEFAULT_EXECUTION_TIMEOUT_MS } from '@/lib/core/execution-limits'
import { getHostedKeyRateLimiter } from '@/lib/core/rate-limiter'
@@ -285,31 +284,10 @@ async function processHostedKeyCost(
if (!userId) return { cost, metadata }
const skipLog = !!ctx?.skipFixedUsageLog || !!tool.hosting?.skipFixedUsageLog
if (!skipLog) {
try {
await logFixedUsage({
userId,
source: 'workflow',
description: `tool:${tool.id}`,
cost,
workspaceId: wsId,
workflowId: wfId,
executionId: executionContext?.executionId,
metadata,
})
logger.debug(
`[${requestId}] Logged hosted key cost for ${tool.id}: $${cost}`,
metadata ? { metadata } : {}
)
} catch (error) {
logger.error(`[${requestId}] Failed to log hosted key usage for ${tool.id}:`, error)
}
} else {
logger.debug(
`[${requestId}] Skipping fixed usage log for ${tool.id} (cost will be tracked via provider tool loop)`
)
}
logger.debug(
`[${requestId}] Hosted key cost for ${tool.id}: $${cost}`,
metadata ? { metadata } : {}
)
return { cost, metadata }
}
@@ -388,13 +366,6 @@ async function applyHostedKeyCostToResult(
): Promise<void> {
await reportCustomDimensionUsage(tool, params, finalResult.output, executionContext, requestId)
if (tool.hosting?.skipFixedUsageLog) {
const ctx = params._context as Record<string, unknown> | undefined
if (ctx) {
ctx.skipFixedUsageLog = true
}
}
const { cost: hostedKeyCost, metadata } = await processHostedKeyCost(
tool,
params,

View File

@@ -152,7 +152,6 @@ export const chatTool: ToolConfig<PerplexityChatParams, PerplexityChatResponse>
mode: 'per_request',
requestsPerMinute: 20,
},
skipFixedUsageLog: true,
},
request: {

View File

@@ -312,6 +312,4 @@ export interface ToolHostingConfig<P = Record<string, unknown>> {
pricing: ToolHostingPricing<P>
/** Hosted key rate limit configuration (required for hosted key distribution) */
rateLimit: HostedKeyRateLimitConfig
/** When true, skip the fixed usage log entry (useful for tools that log custom dimensions instead) */
skipFixedUsageLog?: boolean
}