Add copilot mcp tracking

This commit is contained in:
Siddharth Ganesan
2026-02-06 12:38:16 -08:00
parent 92efd817d2
commit a7341cdcd3
7 changed files with 11235 additions and 13 deletions

View File

@@ -18,6 +18,7 @@ const UpdateCostSchema = z.object({
model: z.string().min(1, 'Model is required'),
inputTokens: z.number().min(0).default(0),
outputTokens: z.number().min(0).default(0),
source: z.enum(['copilot', 'mcp_copilot']).default('copilot'),
})
/**
@@ -75,12 +76,14 @@ export async function POST(req: NextRequest) {
)
}
const { userId, cost, model, inputTokens, outputTokens } = validation.data
const { userId, cost, model, inputTokens, outputTokens, source } = validation.data
const isMcp = source === 'mcp_copilot'
logger.info(`[${requestId}] Processing cost update`, {
userId,
cost,
model,
source,
})
// Check if user stats record exists (same as ExecutionLogger)
@@ -96,7 +99,7 @@ export async function POST(req: NextRequest) {
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
}
const updateFields = {
const updateFields: Record<string, unknown> = {
totalCost: sql`total_cost + ${cost}`,
currentPeriodCost: sql`current_period_cost + ${cost}`,
totalCopilotCost: sql`total_copilot_cost + ${cost}`,
@@ -105,17 +108,24 @@ export async function POST(req: NextRequest) {
lastActive: new Date(),
}
// Also increment MCP-specific counters when source is mcp_copilot
if (isMcp) {
updateFields.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
updateFields.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
}
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
logger.info(`[${requestId}] Updated user stats record`, {
userId,
addedCost: cost,
source,
})
// Log usage for complete audit trail
await logModelUsage({
userId,
source: 'copilot',
source: isMcp ? 'mcp_copilot' : 'copilot',
model,
inputTokens,
outputTokens,

View File

@@ -10,9 +10,13 @@ import {
type ListToolsResult,
type RequestId,
} from '@modelcontextprotocol/sdk/types.js'
import { db } from '@sim/db'
import { userStats } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { checkHybridAuth } from '@/lib/auth/hybrid'
import { authenticateApiKeyFromHeader, updateApiKeyLastUsed } from '@/lib/api-key/service'
import { checkServerSideUsageLimits } from '@/lib/billing/calculations/usage-monitor'
import { getCopilotModel } from '@/lib/copilot/config'
import { SIM_AGENT_VERSION } from '@/lib/copilot/constants'
import { orchestrateCopilotStream } from '@/lib/copilot/orchestrator'
@@ -97,11 +101,28 @@ export async function GET() {
export async function POST(request: NextRequest) {
try {
const auth = await checkHybridAuth(request, { requireWorkflowId: false })
if (!auth.success || !auth.userId) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
// API-key-only auth — MCP clients must provide x-api-key header
const apiKeyHeader = request.headers.get('x-api-key')
if (!apiKeyHeader) {
return NextResponse.json(
createError(0, -32000, 'API key required. Set the x-api-key header with a valid Sim API key.'),
{ status: 401 }
)
}
const authResult = await authenticateApiKeyFromHeader(apiKeyHeader)
if (!authResult.success || !authResult.userId) {
return NextResponse.json(
createError(0, -32000, authResult.error || 'Invalid API key'),
{ status: 401 }
)
}
// Fire-and-forget last-used update
updateApiKeyLastUsed(authResult.keyId!)
const userId = authResult.userId
const body = (await request.json()) as JSONRPCMessage
if (isJSONRPCNotification(body)) {
@@ -117,6 +138,17 @@ export async function POST(request: NextRequest) {
const { id, method, params } = body
// Pre-flight usage limit check for tool calls
if (method === 'tools/call') {
const usageCheck = await checkServerSideUsageLimits(userId)
if (usageCheck.isExceeded) {
return NextResponse.json(
createError(id, -32000, `Usage limit exceeded: ${usageCheck.message || 'Upgrade your plan.'}`),
{ status: 402 }
)
}
}
switch (method) {
case 'initialize': {
const result: InitializeResult = {
@@ -131,12 +163,16 @@ export async function POST(request: NextRequest) {
return NextResponse.json(createResponse(id, {}))
case 'tools/list':
return handleToolsList(id)
case 'tools/call':
return handleToolsCall(
case 'tools/call': {
const response = await handleToolsCall(
id,
params as { name: string; arguments?: Record<string, unknown> },
auth.userId
userId
)
// Track MCP copilot call (fire-and-forget)
trackMcpCopilotCall(userId)
return response
}
default:
return NextResponse.json(
createError(id, ErrorCode.MethodNotFound, `Method not found: ${method}`),
@@ -151,6 +187,22 @@ export async function POST(request: NextRequest) {
}
}
/**
* Increment MCP copilot call counter in userStats (fire-and-forget).
*/
function trackMcpCopilotCall(userId: string): void {
db.update(userStats)
.set({
totalMcpCopilotCalls: sql`total_mcp_copilot_calls + 1`,
lastActive: new Date(),
})
.where(eq(userStats.userId, userId))
.then(() => {})
.catch((error) => {
logger.error('Failed to track MCP copilot call', { error, userId })
})
}
async function handleToolsList(id: RequestId): Promise<NextResponse> {
const directTools = DIRECT_TOOL_DEFS.map((tool) => ({
name: tool.name,
@@ -351,6 +403,7 @@ async function handleSubagentToolCall(
context,
model,
headless: true,
source: 'mcp_copilot',
},
{
userId,

View File

@@ -14,7 +14,7 @@ export type UsageLogCategory = 'model' | 'fixed'
/**
* Usage log source types
*/
export type UsageLogSource = 'workflow' | 'wand' | 'copilot'
export type UsageLogSource = 'workflow' | 'wand' | 'copilot' | 'mcp_copilot'
/**
* Metadata for 'model' category charges

View File

@@ -0,0 +1,4 @@
ALTER TYPE "public"."usage_log_source" ADD VALUE 'mcp_copilot';--> statement-breakpoint
ALTER TABLE "user_stats" ADD COLUMN "total_mcp_copilot_calls" integer DEFAULT 0 NOT NULL;--> statement-breakpoint
ALTER TABLE "user_stats" ADD COLUMN "total_mcp_copilot_cost" numeric DEFAULT '0' NOT NULL;--> statement-breakpoint
ALTER TABLE "user_stats" ADD COLUMN "current_period_mcp_copilot_cost" numeric DEFAULT '0' NOT NULL;

File diff suppressed because it is too large Load Diff

View File

@@ -1065,6 +1065,13 @@
"when": 1770336289511,
"tag": "0152_parallel_frog_thor",
"breakpoints": true
},
{
"idx": 153,
"version": "7",
"when": 1770410282842,
"tag": "0153_complete_arclight",
"breakpoints": true
}
]
}
}

View File

@@ -715,6 +715,10 @@ export const userStats = pgTable('user_stats', {
lastPeriodCopilotCost: decimal('last_period_copilot_cost').default('0'),
totalCopilotTokens: integer('total_copilot_tokens').notNull().default(0),
totalCopilotCalls: integer('total_copilot_calls').notNull().default(0),
// MCP Copilot usage tracking
totalMcpCopilotCalls: integer('total_mcp_copilot_calls').notNull().default(0),
totalMcpCopilotCost: decimal('total_mcp_copilot_cost').notNull().default('0'),
currentPeriodMcpCopilotCost: decimal('current_period_mcp_copilot_cost').notNull().default('0'),
// Storage tracking (for free/pro users)
storageUsedBytes: bigint('storage_used_bytes', { mode: 'number' }).notNull().default(0),
lastActive: timestamp('last_active').notNull().defaultNow(),
@@ -1968,7 +1972,12 @@ export const a2aPushNotificationConfig = pgTable(
)
export const usageLogCategoryEnum = pgEnum('usage_log_category', ['model', 'fixed'])
export const usageLogSourceEnum = pgEnum('usage_log_source', ['workflow', 'wand', 'copilot'])
export const usageLogSourceEnum = pgEnum('usage_log_source', [
'workflow',
'wand',
'copilot',
'mcp_copilot',
])
export const usageLog = pgTable(
'usage_log',