mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
fix(billing): atomize usage_log and userStats writes via central recordUsage (#3767)
* fix(billing): atomize usage_log and userStats writes via central recordUsage() * fix(billing): address PR review — re-throw errors, guard reserved keys, handle zero-cost counters * chore(lint): fix formatting in hubspot list_lists.ts from staging * fix(billing): tighten early-return guard to handle empty additionalStats object * lint * chore(billing): remove implementation-decision comments
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
import { db } from '@sim/db'
|
||||
import { userStats } from '@sim/db/schema'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { logModelUsage } from '@/lib/billing/core/usage-log'
|
||||
import { recordUsage } from '@/lib/billing/core/usage-log'
|
||||
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
|
||||
import { checkInternalApiKey } from '@/lib/copilot/utils'
|
||||
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
|
||||
@@ -87,55 +85,41 @@ export async function POST(req: NextRequest) {
|
||||
source,
|
||||
})
|
||||
|
||||
// Check if user stats record exists (same as ExecutionLogger)
|
||||
const userStatsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))
|
||||
|
||||
if (userStatsRecords.length === 0) {
|
||||
logger.error(
|
||||
`[${requestId}] User stats record not found - should be created during onboarding`,
|
||||
{
|
||||
userId,
|
||||
}
|
||||
)
|
||||
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
|
||||
}
|
||||
|
||||
const totalTokens = inputTokens + outputTokens
|
||||
|
||||
const updateFields: Record<string, unknown> = {
|
||||
totalCost: sql`total_cost + ${cost}`,
|
||||
currentPeriodCost: sql`current_period_cost + ${cost}`,
|
||||
const additionalStats: Record<string, ReturnType<typeof sql>> = {
|
||||
totalCopilotCost: sql`total_copilot_cost + ${cost}`,
|
||||
currentPeriodCopilotCost: sql`current_period_copilot_cost + ${cost}`,
|
||||
totalCopilotCalls: sql`total_copilot_calls + 1`,
|
||||
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
|
||||
lastActive: new Date(),
|
||||
}
|
||||
|
||||
if (isMcp) {
|
||||
updateFields.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
|
||||
updateFields.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
|
||||
updateFields.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
|
||||
additionalStats.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
|
||||
additionalStats.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
|
||||
additionalStats.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
|
||||
}
|
||||
|
||||
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
|
||||
await recordUsage({
|
||||
userId,
|
||||
entries: [
|
||||
{
|
||||
category: 'model',
|
||||
source,
|
||||
description: model,
|
||||
cost,
|
||||
metadata: { inputTokens, outputTokens },
|
||||
},
|
||||
],
|
||||
additionalStats,
|
||||
})
|
||||
|
||||
logger.info(`[${requestId}] Updated user stats record`, {
|
||||
logger.info(`[${requestId}] Recorded usage`, {
|
||||
userId,
|
||||
addedCost: cost,
|
||||
source,
|
||||
})
|
||||
|
||||
// Log usage for complete audit trail with the original source for visibility
|
||||
await logModelUsage({
|
||||
userId,
|
||||
source,
|
||||
model,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cost,
|
||||
})
|
||||
|
||||
// Check if user has hit overage threshold and bill incrementally
|
||||
await checkAndBillOverageThreshold(userId)
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { db } from '@sim/db'
|
||||
import { userStats, workflow } from '@sim/db/schema'
|
||||
import { workflow } from '@sim/db/schema'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getBYOKKey } from '@/lib/api-key/byok'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { logModelUsage } from '@/lib/billing/core/usage-log'
|
||||
import { recordUsage } from '@/lib/billing/core/usage-log'
|
||||
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
|
||||
@@ -134,23 +134,20 @@ async function updateUserStatsForWand(
|
||||
costToStore = modelCost * costMultiplier
|
||||
}
|
||||
|
||||
await db
|
||||
.update(userStats)
|
||||
.set({
|
||||
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
|
||||
totalCost: sql`total_cost + ${costToStore}`,
|
||||
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
|
||||
lastActive: new Date(),
|
||||
})
|
||||
.where(eq(userStats.userId, userId))
|
||||
|
||||
await logModelUsage({
|
||||
await recordUsage({
|
||||
userId,
|
||||
source: 'wand',
|
||||
model: modelName,
|
||||
inputTokens: promptTokens,
|
||||
outputTokens: completionTokens,
|
||||
cost: costToStore,
|
||||
entries: [
|
||||
{
|
||||
category: 'model',
|
||||
source: 'wand',
|
||||
description: modelName,
|
||||
cost: costToStore,
|
||||
metadata: { inputTokens: promptTokens, outputTokens: completionTokens },
|
||||
},
|
||||
],
|
||||
additionalStats: {
|
||||
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
|
||||
},
|
||||
})
|
||||
|
||||
await checkAndBillOverageThreshold(userId)
|
||||
@@ -341,7 +338,7 @@ export async function POST(req: NextRequest) {
|
||||
let finalUsage: any = null
|
||||
let usageRecorded = false
|
||||
|
||||
const recordUsage = async () => {
|
||||
const flushUsage = async () => {
|
||||
if (usageRecorded || !finalUsage) {
|
||||
return
|
||||
}
|
||||
@@ -360,7 +357,7 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
if (done) {
|
||||
logger.info(`[${requestId}] Stream completed. Total chunks: ${chunkCount}`)
|
||||
await recordUsage()
|
||||
await flushUsage()
|
||||
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
|
||||
controller.close()
|
||||
break
|
||||
@@ -390,7 +387,7 @@ export async function POST(req: NextRequest) {
|
||||
if (data === '[DONE]') {
|
||||
logger.info(`[${requestId}] Received [DONE] signal`)
|
||||
|
||||
await recordUsage()
|
||||
await flushUsage()
|
||||
|
||||
controller.enqueue(
|
||||
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
|
||||
@@ -468,7 +465,7 @@ export async function POST(req: NextRequest) {
|
||||
})
|
||||
|
||||
try {
|
||||
await recordUsage()
|
||||
await flushUsage()
|
||||
} catch (usageError) {
|
||||
logger.warn(`[${requestId}] Failed to record usage after stream error`, usageError)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { db } from '@sim/db'
|
||||
import { usageLog } from '@sim/db/schema'
|
||||
import { usageLog, userStats } from '@sim/db/schema'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { and, desc, eq, gte, lte, sql } from 'drizzle-orm'
|
||||
import { and, desc, eq, gte, lte, type SQL, sql } from 'drizzle-orm'
|
||||
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
|
||||
|
||||
const logger = createLogger('UsageLog')
|
||||
@@ -32,224 +32,121 @@ export interface ModelUsageMetadata {
|
||||
}
|
||||
|
||||
/**
|
||||
* Metadata for 'fixed' category charges (e.g., tool cost breakdown)
|
||||
* Union type for all usage log metadata types
|
||||
*/
|
||||
export type FixedUsageMetadata = Record<string, unknown>
|
||||
export type UsageLogMetadata = ModelUsageMetadata | Record<string, unknown> | null
|
||||
|
||||
/**
|
||||
* Union type for all metadata types
|
||||
* A single usage entry to be recorded in the usage_log table.
|
||||
*/
|
||||
export type UsageLogMetadata = ModelUsageMetadata | FixedUsageMetadata | null
|
||||
|
||||
/**
|
||||
* Parameters for logging model usage (token-based charges)
|
||||
*/
|
||||
export interface LogModelUsageParams {
|
||||
userId: string
|
||||
source: UsageLogSource
|
||||
model: string
|
||||
inputTokens: number
|
||||
outputTokens: number
|
||||
cost: number
|
||||
toolCost?: number
|
||||
workspaceId?: string
|
||||
workflowId?: string
|
||||
executionId?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters for logging fixed charges (flat fees)
|
||||
*/
|
||||
export interface LogFixedUsageParams {
|
||||
userId: string
|
||||
export interface UsageEntry {
|
||||
category: UsageLogCategory
|
||||
source: UsageLogSource
|
||||
description: string
|
||||
cost: number
|
||||
workspaceId?: string
|
||||
workflowId?: string
|
||||
executionId?: string
|
||||
/** Optional metadata (e.g., tool cost breakdown from API) */
|
||||
metadata?: FixedUsageMetadata
|
||||
metadata?: UsageLogMetadata
|
||||
}
|
||||
|
||||
/**
|
||||
* Log a model usage charge (token-based)
|
||||
* Parameters for the central recordUsage function.
|
||||
* This is the single entry point for all billing mutations.
|
||||
*/
|
||||
export async function logModelUsage(params: LogModelUsageParams): Promise<void> {
|
||||
if (!isBillingEnabled || params.cost <= 0) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const metadata: ModelUsageMetadata = {
|
||||
inputTokens: params.inputTokens,
|
||||
outputTokens: params.outputTokens,
|
||||
...(params.toolCost != null && params.toolCost > 0 && { toolCost: params.toolCost }),
|
||||
}
|
||||
|
||||
await db.insert(usageLog).values({
|
||||
id: crypto.randomUUID(),
|
||||
userId: params.userId,
|
||||
category: 'model',
|
||||
source: params.source,
|
||||
description: params.model,
|
||||
metadata,
|
||||
cost: params.cost.toString(),
|
||||
workspaceId: params.workspaceId ?? null,
|
||||
workflowId: params.workflowId ?? null,
|
||||
executionId: params.executionId ?? null,
|
||||
})
|
||||
|
||||
logger.debug('Logged model usage', {
|
||||
userId: params.userId,
|
||||
source: params.source,
|
||||
model: params.model,
|
||||
cost: params.cost,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to log model usage', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
params,
|
||||
})
|
||||
// Don't throw - usage logging should not break the main flow
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Log a fixed charge (flat fee like base execution charge or search)
|
||||
*/
|
||||
export async function logFixedUsage(params: LogFixedUsageParams): Promise<void> {
|
||||
if (!isBillingEnabled || params.cost <= 0) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
await db.insert(usageLog).values({
|
||||
id: crypto.randomUUID(),
|
||||
userId: params.userId,
|
||||
category: 'fixed',
|
||||
source: params.source,
|
||||
description: params.description,
|
||||
metadata: params.metadata ?? null,
|
||||
cost: params.cost.toString(),
|
||||
workspaceId: params.workspaceId ?? null,
|
||||
workflowId: params.workflowId ?? null,
|
||||
executionId: params.executionId ?? null,
|
||||
})
|
||||
|
||||
logger.debug('Logged fixed usage', {
|
||||
userId: params.userId,
|
||||
source: params.source,
|
||||
description: params.description,
|
||||
cost: params.cost,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to log fixed usage', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
params,
|
||||
})
|
||||
// Don't throw - usage logging should not break the main flow
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters for batch logging workflow usage
|
||||
*/
|
||||
export interface LogWorkflowUsageBatchParams {
|
||||
export interface RecordUsageParams {
|
||||
/** The user being charged */
|
||||
userId: string
|
||||
/** One or more usage_log entries to record. Total cost is derived from these. */
|
||||
entries: UsageEntry[]
|
||||
/** Workspace context */
|
||||
workspaceId?: string
|
||||
workflowId: string
|
||||
/** Workflow context */
|
||||
workflowId?: string
|
||||
/** Execution context */
|
||||
executionId?: string
|
||||
baseExecutionCharge?: number
|
||||
models?: Record<
|
||||
string,
|
||||
{
|
||||
total: number
|
||||
tokens: { input: number; output: number }
|
||||
toolCost?: number
|
||||
}
|
||||
>
|
||||
/** Source-specific counter increments (e.g. totalCopilotCalls, totalManualExecutions) */
|
||||
additionalStats?: Record<string, SQL>
|
||||
}
|
||||
|
||||
/**
|
||||
* Log all workflow usage entries in a single batch insert (performance optimized)
|
||||
* Records usage in a single atomic transaction.
|
||||
*
|
||||
* Inserts all entries into usage_log and updates userStats counters
|
||||
* (totalCost, currentPeriodCost, lastActive) within one Postgres transaction.
|
||||
* The total cost added to userStats is derived from summing entry costs,
|
||||
* ensuring usage_log and currentPeriodCost can never drift apart.
|
||||
*
|
||||
* If billing is disabled, total cost is zero, or no entries have positive cost,
|
||||
* this function returns early without writing anything.
|
||||
*/
|
||||
export async function logWorkflowUsageBatch(params: LogWorkflowUsageBatchParams): Promise<void> {
|
||||
export async function recordUsage(params: RecordUsageParams): Promise<void> {
|
||||
if (!isBillingEnabled) {
|
||||
return
|
||||
}
|
||||
|
||||
const entries: Array<{
|
||||
id: string
|
||||
userId: string
|
||||
category: 'model' | 'fixed'
|
||||
source: 'workflow'
|
||||
description: string
|
||||
metadata: ModelUsageMetadata | null
|
||||
cost: string
|
||||
workspaceId: string | null
|
||||
workflowId: string | null
|
||||
executionId: string | null
|
||||
}> = []
|
||||
const { userId, entries, workspaceId, workflowId, executionId, additionalStats } = params
|
||||
|
||||
if (params.baseExecutionCharge && params.baseExecutionCharge > 0) {
|
||||
entries.push({
|
||||
id: crypto.randomUUID(),
|
||||
userId: params.userId,
|
||||
category: 'fixed',
|
||||
source: 'workflow',
|
||||
description: 'execution_fee',
|
||||
metadata: null,
|
||||
cost: params.baseExecutionCharge.toString(),
|
||||
workspaceId: params.workspaceId ?? null,
|
||||
workflowId: params.workflowId,
|
||||
executionId: params.executionId ?? null,
|
||||
})
|
||||
}
|
||||
const validEntries = entries.filter((e) => e.cost > 0)
|
||||
const totalCost = validEntries.reduce((sum, e) => sum + e.cost, 0)
|
||||
|
||||
if (params.models) {
|
||||
for (const [modelName, modelData] of Object.entries(params.models)) {
|
||||
if (modelData.total > 0) {
|
||||
entries.push({
|
||||
id: crypto.randomUUID(),
|
||||
userId: params.userId,
|
||||
category: 'model',
|
||||
source: 'workflow',
|
||||
description: modelName,
|
||||
metadata: {
|
||||
inputTokens: modelData.tokens.input,
|
||||
outputTokens: modelData.tokens.output,
|
||||
...(modelData.toolCost != null &&
|
||||
modelData.toolCost > 0 && { toolCost: modelData.toolCost }),
|
||||
},
|
||||
cost: modelData.total.toString(),
|
||||
workspaceId: params.workspaceId ?? null,
|
||||
workflowId: params.workflowId,
|
||||
executionId: params.executionId ?? null,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (entries.length === 0) {
|
||||
if (
|
||||
validEntries.length === 0 &&
|
||||
(!additionalStats || Object.keys(additionalStats).length === 0)
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
await db.insert(usageLog).values(entries)
|
||||
const RESERVED_KEYS = new Set(['totalCost', 'currentPeriodCost', 'lastActive'])
|
||||
const safeStats = additionalStats
|
||||
? Object.fromEntries(Object.entries(additionalStats).filter(([k]) => !RESERVED_KEYS.has(k)))
|
||||
: undefined
|
||||
|
||||
logger.debug('Logged workflow usage batch', {
|
||||
userId: params.userId,
|
||||
workflowId: params.workflowId,
|
||||
entryCount: entries.length,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to log workflow usage batch', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
params,
|
||||
})
|
||||
// Don't throw - usage logging should not break the main flow
|
||||
}
|
||||
await db.transaction(async (tx) => {
|
||||
if (validEntries.length > 0) {
|
||||
await tx.insert(usageLog).values(
|
||||
validEntries.map((entry) => ({
|
||||
id: crypto.randomUUID(),
|
||||
userId,
|
||||
category: entry.category,
|
||||
source: entry.source,
|
||||
description: entry.description,
|
||||
metadata: entry.metadata ?? null,
|
||||
cost: entry.cost.toString(),
|
||||
workspaceId: workspaceId ?? null,
|
||||
workflowId: workflowId ?? null,
|
||||
executionId: executionId ?? null,
|
||||
}))
|
||||
)
|
||||
}
|
||||
|
||||
const updateFields: Record<string, SQL | Date> = {
|
||||
lastActive: new Date(),
|
||||
...(totalCost > 0 && {
|
||||
totalCost: sql`total_cost + ${totalCost}`,
|
||||
currentPeriodCost: sql`current_period_cost + ${totalCost}`,
|
||||
}),
|
||||
...safeStats,
|
||||
}
|
||||
|
||||
const result = await tx
|
||||
.update(userStats)
|
||||
.set(updateFields)
|
||||
.where(eq(userStats.userId, userId))
|
||||
.returning({ userId: userStats.userId })
|
||||
|
||||
if (result.length === 0) {
|
||||
logger.warn('recordUsage: userStats row not found, transaction will roll back', {
|
||||
userId,
|
||||
totalCost,
|
||||
})
|
||||
throw new Error(`userStats row not found for userId: ${userId}`)
|
||||
}
|
||||
})
|
||||
|
||||
logger.debug('Recorded usage', {
|
||||
userId,
|
||||
totalCost,
|
||||
entryCount: validEntries.length,
|
||||
sources: [...new Set(validEntries.map((e) => e.source))],
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -29,7 +29,7 @@ vi.mock('@/lib/billing/core/usage', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/billing/core/usage-log', () => ({
|
||||
logWorkflowUsageBatch: vi.fn(() => Promise.resolve()),
|
||||
recordUsage: vi.fn(() => Promise.resolve()),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/billing/threshold-billing', () => ({
|
||||
|
||||
@@ -16,7 +16,7 @@ import {
|
||||
getOrgUsageLimit,
|
||||
maybeSendUsageThresholdEmail,
|
||||
} from '@/lib/billing/core/usage'
|
||||
import { logWorkflowUsageBatch } from '@/lib/billing/core/usage-log'
|
||||
import { type ModelUsageMetadata, recordUsage } from '@/lib/billing/core/usage-log'
|
||||
import { isOrgPlan } from '@/lib/billing/plan-helpers'
|
||||
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
|
||||
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
|
||||
@@ -37,6 +37,17 @@ import type {
|
||||
} from '@/lib/logs/types'
|
||||
import type { SerializableExecutionState } from '@/executor/execution/types'
|
||||
|
||||
/** Maps execution trigger types to their corresponding userStats counter columns */
|
||||
const TRIGGER_COUNTER_MAP: Record<string, { key: string; column: string }> = {
|
||||
manual: { key: 'totalManualExecutions', column: 'total_manual_executions' },
|
||||
api: { key: 'totalApiCalls', column: 'total_api_calls' },
|
||||
webhook: { key: 'totalWebhookTriggers', column: 'total_webhook_triggers' },
|
||||
schedule: { key: 'totalScheduledExecutions', column: 'total_scheduled_executions' },
|
||||
chat: { key: 'totalChatExecutions', column: 'total_chat_executions' },
|
||||
mcp: { key: 'totalMcpExecutions', column: 'total_mcp_executions' },
|
||||
a2a: { key: 'totalA2aExecutions', column: 'total_a2a_executions' },
|
||||
} as const
|
||||
|
||||
export interface ToolCall {
|
||||
name: string
|
||||
duration: number // in milliseconds
|
||||
@@ -634,66 +645,58 @@ export class ExecutionLogger implements IExecutionLoggerService {
|
||||
return
|
||||
}
|
||||
|
||||
const costToStore = costSummary.totalCost
|
||||
const entries: Array<{
|
||||
category: 'model' | 'fixed'
|
||||
source: 'workflow'
|
||||
description: string
|
||||
cost: number
|
||||
metadata?: ModelUsageMetadata | null
|
||||
}> = []
|
||||
|
||||
const existing = await db.select().from(userStats).where(eq(userStats.userId, userId))
|
||||
if (existing.length === 0) {
|
||||
logger.error('User stats record not found - should be created during onboarding', {
|
||||
userId,
|
||||
trigger,
|
||||
if (costSummary.baseExecutionCharge > 0) {
|
||||
entries.push({
|
||||
category: 'fixed',
|
||||
source: 'workflow',
|
||||
description: 'execution_fee',
|
||||
cost: costSummary.baseExecutionCharge,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// All costs go to currentPeriodCost - credits are applied at end of billing cycle
|
||||
const updateFields: any = {
|
||||
if (costSummary.models) {
|
||||
for (const [modelName, modelData] of Object.entries(costSummary.models)) {
|
||||
if (modelData.total > 0) {
|
||||
entries.push({
|
||||
category: 'model',
|
||||
source: 'workflow',
|
||||
description: modelName,
|
||||
cost: modelData.total,
|
||||
metadata: {
|
||||
inputTokens: modelData.tokens.input,
|
||||
outputTokens: modelData.tokens.output,
|
||||
...(modelData.toolCost != null &&
|
||||
modelData.toolCost > 0 && { toolCost: modelData.toolCost }),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const additionalStats: Record<string, ReturnType<typeof sql>> = {
|
||||
totalTokensUsed: sql`total_tokens_used + ${costSummary.totalTokens}`,
|
||||
totalCost: sql`total_cost + ${costToStore}`,
|
||||
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
|
||||
lastActive: new Date(),
|
||||
}
|
||||
|
||||
switch (trigger) {
|
||||
case 'manual':
|
||||
updateFields.totalManualExecutions = sql`total_manual_executions + 1`
|
||||
break
|
||||
case 'api':
|
||||
updateFields.totalApiCalls = sql`total_api_calls + 1`
|
||||
break
|
||||
case 'webhook':
|
||||
updateFields.totalWebhookTriggers = sql`total_webhook_triggers + 1`
|
||||
break
|
||||
case 'schedule':
|
||||
updateFields.totalScheduledExecutions = sql`total_scheduled_executions + 1`
|
||||
break
|
||||
case 'chat':
|
||||
updateFields.totalChatExecutions = sql`total_chat_executions + 1`
|
||||
break
|
||||
case 'mcp':
|
||||
updateFields.totalMcpExecutions = sql`total_mcp_executions + 1`
|
||||
break
|
||||
case 'a2a':
|
||||
updateFields.totalA2aExecutions = sql`total_a2a_executions + 1`
|
||||
break
|
||||
const triggerCounter = TRIGGER_COUNTER_MAP[trigger]
|
||||
if (triggerCounter) {
|
||||
additionalStats[triggerCounter.key] = sql`${sql.raw(triggerCounter.column)} + 1`
|
||||
}
|
||||
|
||||
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
|
||||
|
||||
logger.debug('Updated user stats record with cost data', {
|
||||
userId,
|
||||
trigger,
|
||||
addedCost: costToStore,
|
||||
addedTokens: costSummary.totalTokens,
|
||||
})
|
||||
|
||||
// Log usage entries for auditing (batch insert for performance)
|
||||
await logWorkflowUsageBatch({
|
||||
await recordUsage({
|
||||
userId,
|
||||
entries,
|
||||
workspaceId: workflowRecord.workspaceId ?? undefined,
|
||||
workflowId,
|
||||
executionId,
|
||||
baseExecutionCharge: costSummary.baseExecutionCharge,
|
||||
models: costSummary.models,
|
||||
additionalStats,
|
||||
})
|
||||
|
||||
// Check if user has hit overage threshold and bill incrementally
|
||||
|
||||
@@ -1401,7 +1401,6 @@ describe('prepareToolExecution', () => {
|
||||
workspaceId: 'ws-456',
|
||||
chatId: 'chat-789',
|
||||
userId: 'user-abc',
|
||||
skipFixedUsageLog: true,
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1147,7 +1147,6 @@ export function prepareToolExecution(
|
||||
? { isDeployedContext: request.isDeployedContext }
|
||||
: {}),
|
||||
...(request.callChain ? { callChain: request.callChain } : {}),
|
||||
skipFixedUsageLog: true,
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import type { HubSpotListListsParams, HubSpotListListsResponse } from '@/tools/hubspot/types'
|
||||
import { LISTS_ARRAY_OUTPUT, METADATA_OUTPUT_PROPERTIES, PAGING_OUTPUT } from '@/tools/hubspot/types'
|
||||
import {
|
||||
LISTS_ARRAY_OUTPUT,
|
||||
METADATA_OUTPUT_PROPERTIES,
|
||||
PAGING_OUTPUT,
|
||||
} from '@/tools/hubspot/types'
|
||||
import type { ToolConfig } from '@/tools/types'
|
||||
|
||||
const logger = createLogger('HubSpotListLists')
|
||||
@@ -99,7 +103,11 @@ export const hubspotListListsTool: ToolConfig<HubSpotListListsParams, HubSpotLis
|
||||
description: 'Response metadata',
|
||||
properties: {
|
||||
...METADATA_OUTPUT_PROPERTIES,
|
||||
total: { type: 'number', description: 'Total number of lists matching the query', optional: true },
|
||||
total: {
|
||||
type: 'number',
|
||||
description: 'Total number of lists matching the query',
|
||||
optional: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
success: { type: 'boolean', description: 'Operation success status' },
|
||||
|
||||
@@ -16,19 +16,16 @@ import {
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// Hoisted mock state - these are available to vi.mock factories
|
||||
const { mockIsHosted, mockEnv, mockGetBYOKKey, mockLogFixedUsage, mockRateLimiterFns } = vi.hoisted(
|
||||
() => ({
|
||||
mockIsHosted: { value: false },
|
||||
mockEnv: { NEXT_PUBLIC_APP_URL: 'http://localhost:3000' } as Record<string, string | undefined>,
|
||||
mockGetBYOKKey: vi.fn(),
|
||||
mockLogFixedUsage: vi.fn(),
|
||||
mockRateLimiterFns: {
|
||||
acquireKey: vi.fn(),
|
||||
preConsumeCapacity: vi.fn(),
|
||||
consumeCapacity: vi.fn(),
|
||||
},
|
||||
})
|
||||
)
|
||||
const { mockIsHosted, mockEnv, mockGetBYOKKey, mockRateLimiterFns } = vi.hoisted(() => ({
|
||||
mockIsHosted: { value: false },
|
||||
mockEnv: { NEXT_PUBLIC_APP_URL: 'http://localhost:3000' } as Record<string, string | undefined>,
|
||||
mockGetBYOKKey: vi.fn(),
|
||||
mockRateLimiterFns: {
|
||||
acquireKey: vi.fn(),
|
||||
preConsumeCapacity: vi.fn(),
|
||||
consumeCapacity: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock feature flags
|
||||
vi.mock('@/lib/core/config/feature-flags', () => ({
|
||||
@@ -55,10 +52,7 @@ vi.mock('@/lib/api-key/byok', () => ({
|
||||
getBYOKKey: (...args: unknown[]) => mockGetBYOKKey(...args),
|
||||
}))
|
||||
|
||||
// Mock logFixedUsage for billing
|
||||
vi.mock('@/lib/billing/core/usage-log', () => ({
|
||||
logFixedUsage: (...args: unknown[]) => mockLogFixedUsage(...args),
|
||||
}))
|
||||
vi.mock('@/lib/billing/core/usage-log', () => ({}))
|
||||
|
||||
vi.mock('@/lib/core/rate-limiter/hosted-key', () => ({
|
||||
getHostedKeyRateLimiter: () => mockRateLimiterFns,
|
||||
@@ -1364,7 +1358,6 @@ describe('Hosted Key Injection', () => {
|
||||
cleanupEnvVars = setupEnvVars({ NEXT_PUBLIC_APP_URL: 'http://localhost:3000' })
|
||||
vi.clearAllMocks()
|
||||
mockGetBYOKKey.mockReset()
|
||||
mockLogFixedUsage.mockReset()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
@@ -2022,7 +2015,6 @@ describe('Cost Field Handling', () => {
|
||||
mockIsHosted.value = true
|
||||
mockEnv.TEST_HOSTED_KEY = 'test-hosted-api-key'
|
||||
mockGetBYOKKey.mockResolvedValue(null)
|
||||
mockLogFixedUsage.mockResolvedValue(undefined)
|
||||
// Set up throttler mock defaults
|
||||
mockRateLimiterFns.acquireKey.mockResolvedValue({
|
||||
success: true,
|
||||
@@ -2097,14 +2089,6 @@ describe('Cost Field Handling', () => {
|
||||
// This test verifies the tool execution flow when hosted key IS available (by checking output structure).
|
||||
if (result.output.cost) {
|
||||
expect(result.output.cost.total).toBe(0.005)
|
||||
// Should have logged usage
|
||||
expect(mockLogFixedUsage).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
userId: 'user-123',
|
||||
cost: 0.005,
|
||||
description: 'tool:test_cost_per_request',
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
Object.assign(tools, originalTools)
|
||||
@@ -2169,8 +2153,6 @@ describe('Cost Field Handling', () => {
|
||||
expect(result.success).toBe(true)
|
||||
// Should not have cost since user provided their own key
|
||||
expect(result.output.cost).toBeUndefined()
|
||||
// Should not have logged usage
|
||||
expect(mockLogFixedUsage).not.toHaveBeenCalled()
|
||||
|
||||
Object.assign(tools, originalTools)
|
||||
})
|
||||
@@ -2243,14 +2225,6 @@ describe('Cost Field Handling', () => {
|
||||
// getCost should have been called with params and output
|
||||
expect(mockGetCost).toHaveBeenCalled()
|
||||
|
||||
// Should have logged usage with metadata
|
||||
expect(mockLogFixedUsage).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
cost: 0.015,
|
||||
metadata: { mode: 'advanced', results: 10 },
|
||||
})
|
||||
)
|
||||
|
||||
Object.assign(tools, originalTools)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { getBYOKKey } from '@/lib/api-key/byok'
|
||||
import { generateInternalToken } from '@/lib/auth/internal'
|
||||
import { logFixedUsage } from '@/lib/billing/core/usage-log'
|
||||
import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import { DEFAULT_EXECUTION_TIMEOUT_MS } from '@/lib/core/execution-limits'
|
||||
import { getHostedKeyRateLimiter } from '@/lib/core/rate-limiter'
|
||||
@@ -285,31 +284,10 @@ async function processHostedKeyCost(
|
||||
|
||||
if (!userId) return { cost, metadata }
|
||||
|
||||
const skipLog = !!ctx?.skipFixedUsageLog || !!tool.hosting?.skipFixedUsageLog
|
||||
if (!skipLog) {
|
||||
try {
|
||||
await logFixedUsage({
|
||||
userId,
|
||||
source: 'workflow',
|
||||
description: `tool:${tool.id}`,
|
||||
cost,
|
||||
workspaceId: wsId,
|
||||
workflowId: wfId,
|
||||
executionId: executionContext?.executionId,
|
||||
metadata,
|
||||
})
|
||||
logger.debug(
|
||||
`[${requestId}] Logged hosted key cost for ${tool.id}: $${cost}`,
|
||||
metadata ? { metadata } : {}
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Failed to log hosted key usage for ${tool.id}:`, error)
|
||||
}
|
||||
} else {
|
||||
logger.debug(
|
||||
`[${requestId}] Skipping fixed usage log for ${tool.id} (cost will be tracked via provider tool loop)`
|
||||
)
|
||||
}
|
||||
logger.debug(
|
||||
`[${requestId}] Hosted key cost for ${tool.id}: $${cost}`,
|
||||
metadata ? { metadata } : {}
|
||||
)
|
||||
|
||||
return { cost, metadata }
|
||||
}
|
||||
@@ -388,13 +366,6 @@ async function applyHostedKeyCostToResult(
|
||||
): Promise<void> {
|
||||
await reportCustomDimensionUsage(tool, params, finalResult.output, executionContext, requestId)
|
||||
|
||||
if (tool.hosting?.skipFixedUsageLog) {
|
||||
const ctx = params._context as Record<string, unknown> | undefined
|
||||
if (ctx) {
|
||||
ctx.skipFixedUsageLog = true
|
||||
}
|
||||
}
|
||||
|
||||
const { cost: hostedKeyCost, metadata } = await processHostedKeyCost(
|
||||
tool,
|
||||
params,
|
||||
|
||||
@@ -152,7 +152,6 @@ export const chatTool: ToolConfig<PerplexityChatParams, PerplexityChatResponse>
|
||||
mode: 'per_request',
|
||||
requestsPerMinute: 20,
|
||||
},
|
||||
skipFixedUsageLog: true,
|
||||
},
|
||||
|
||||
request: {
|
||||
|
||||
@@ -312,6 +312,4 @@ export interface ToolHostingConfig<P = Record<string, unknown>> {
|
||||
pricing: ToolHostingPricing<P>
|
||||
/** Hosted key rate limit configuration (required for hosted key distribution) */
|
||||
rateLimit: HostedKeyRateLimitConfig
|
||||
/** When true, skip the fixed usage log entry (useful for tools that log custom dimensions instead) */
|
||||
skipFixedUsageLog?: boolean
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user