mirror of
https://github.com/simstudioai/sim.git
synced 2026-02-06 20:55:23 -05:00
Add copilot mcp tracking
This commit is contained in:
@@ -18,6 +18,7 @@ const UpdateCostSchema = z.object({
|
||||
model: z.string().min(1, 'Model is required'),
|
||||
inputTokens: z.number().min(0).default(0),
|
||||
outputTokens: z.number().min(0).default(0),
|
||||
source: z.enum(['copilot', 'mcp_copilot']).default('copilot'),
|
||||
})
|
||||
|
||||
/**
|
||||
@@ -75,12 +76,14 @@ export async function POST(req: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
const { userId, cost, model, inputTokens, outputTokens } = validation.data
|
||||
const { userId, cost, model, inputTokens, outputTokens, source } = validation.data
|
||||
const isMcp = source === 'mcp_copilot'
|
||||
|
||||
logger.info(`[${requestId}] Processing cost update`, {
|
||||
userId,
|
||||
cost,
|
||||
model,
|
||||
source,
|
||||
})
|
||||
|
||||
// Check if user stats record exists (same as ExecutionLogger)
|
||||
@@ -96,7 +99,7 @@ export async function POST(req: NextRequest) {
|
||||
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
|
||||
}
|
||||
|
||||
const updateFields = {
|
||||
const updateFields: Record<string, unknown> = {
|
||||
totalCost: sql`total_cost + ${cost}`,
|
||||
currentPeriodCost: sql`current_period_cost + ${cost}`,
|
||||
totalCopilotCost: sql`total_copilot_cost + ${cost}`,
|
||||
@@ -105,17 +108,24 @@ export async function POST(req: NextRequest) {
|
||||
lastActive: new Date(),
|
||||
}
|
||||
|
||||
// Also increment MCP-specific counters when source is mcp_copilot
|
||||
if (isMcp) {
|
||||
updateFields.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
|
||||
updateFields.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
|
||||
}
|
||||
|
||||
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
|
||||
|
||||
logger.info(`[${requestId}] Updated user stats record`, {
|
||||
userId,
|
||||
addedCost: cost,
|
||||
source,
|
||||
})
|
||||
|
||||
// Log usage for complete audit trail
|
||||
await logModelUsage({
|
||||
userId,
|
||||
source: 'copilot',
|
||||
source: isMcp ? 'mcp_copilot' : 'copilot',
|
||||
model,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
|
||||
@@ -10,9 +10,13 @@ import {
|
||||
type ListToolsResult,
|
||||
type RequestId,
|
||||
} from '@modelcontextprotocol/sdk/types.js'
|
||||
import { db } from '@sim/db'
|
||||
import { userStats } from '@sim/db/schema'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { checkHybridAuth } from '@/lib/auth/hybrid'
|
||||
import { authenticateApiKeyFromHeader, updateApiKeyLastUsed } from '@/lib/api-key/service'
|
||||
import { checkServerSideUsageLimits } from '@/lib/billing/calculations/usage-monitor'
|
||||
import { getCopilotModel } from '@/lib/copilot/config'
|
||||
import { SIM_AGENT_VERSION } from '@/lib/copilot/constants'
|
||||
import { orchestrateCopilotStream } from '@/lib/copilot/orchestrator'
|
||||
@@ -97,11 +101,28 @@ export async function GET() {
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const auth = await checkHybridAuth(request, { requireWorkflowId: false })
|
||||
if (!auth.success || !auth.userId) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
// API-key-only auth — MCP clients must provide x-api-key header
|
||||
const apiKeyHeader = request.headers.get('x-api-key')
|
||||
if (!apiKeyHeader) {
|
||||
return NextResponse.json(
|
||||
createError(0, -32000, 'API key required. Set the x-api-key header with a valid Sim API key.'),
|
||||
{ status: 401 }
|
||||
)
|
||||
}
|
||||
|
||||
const authResult = await authenticateApiKeyFromHeader(apiKeyHeader)
|
||||
if (!authResult.success || !authResult.userId) {
|
||||
return NextResponse.json(
|
||||
createError(0, -32000, authResult.error || 'Invalid API key'),
|
||||
{ status: 401 }
|
||||
)
|
||||
}
|
||||
|
||||
// Fire-and-forget last-used update
|
||||
updateApiKeyLastUsed(authResult.keyId!)
|
||||
|
||||
const userId = authResult.userId
|
||||
|
||||
const body = (await request.json()) as JSONRPCMessage
|
||||
|
||||
if (isJSONRPCNotification(body)) {
|
||||
@@ -117,6 +138,17 @@ export async function POST(request: NextRequest) {
|
||||
|
||||
const { id, method, params } = body
|
||||
|
||||
// Pre-flight usage limit check for tool calls
|
||||
if (method === 'tools/call') {
|
||||
const usageCheck = await checkServerSideUsageLimits(userId)
|
||||
if (usageCheck.isExceeded) {
|
||||
return NextResponse.json(
|
||||
createError(id, -32000, `Usage limit exceeded: ${usageCheck.message || 'Upgrade your plan.'}`),
|
||||
{ status: 402 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch (method) {
|
||||
case 'initialize': {
|
||||
const result: InitializeResult = {
|
||||
@@ -131,12 +163,16 @@ export async function POST(request: NextRequest) {
|
||||
return NextResponse.json(createResponse(id, {}))
|
||||
case 'tools/list':
|
||||
return handleToolsList(id)
|
||||
case 'tools/call':
|
||||
return handleToolsCall(
|
||||
case 'tools/call': {
|
||||
const response = await handleToolsCall(
|
||||
id,
|
||||
params as { name: string; arguments?: Record<string, unknown> },
|
||||
auth.userId
|
||||
userId
|
||||
)
|
||||
// Track MCP copilot call (fire-and-forget)
|
||||
trackMcpCopilotCall(userId)
|
||||
return response
|
||||
}
|
||||
default:
|
||||
return NextResponse.json(
|
||||
createError(id, ErrorCode.MethodNotFound, `Method not found: ${method}`),
|
||||
@@ -151,6 +187,22 @@ export async function POST(request: NextRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment MCP copilot call counter in userStats (fire-and-forget).
|
||||
*/
|
||||
function trackMcpCopilotCall(userId: string): void {
|
||||
db.update(userStats)
|
||||
.set({
|
||||
totalMcpCopilotCalls: sql`total_mcp_copilot_calls + 1`,
|
||||
lastActive: new Date(),
|
||||
})
|
||||
.where(eq(userStats.userId, userId))
|
||||
.then(() => {})
|
||||
.catch((error) => {
|
||||
logger.error('Failed to track MCP copilot call', { error, userId })
|
||||
})
|
||||
}
|
||||
|
||||
async function handleToolsList(id: RequestId): Promise<NextResponse> {
|
||||
const directTools = DIRECT_TOOL_DEFS.map((tool) => ({
|
||||
name: tool.name,
|
||||
@@ -351,6 +403,7 @@ async function handleSubagentToolCall(
|
||||
context,
|
||||
model,
|
||||
headless: true,
|
||||
source: 'mcp_copilot',
|
||||
},
|
||||
{
|
||||
userId,
|
||||
|
||||
@@ -14,7 +14,7 @@ export type UsageLogCategory = 'model' | 'fixed'
|
||||
/**
|
||||
* Usage log source types
|
||||
*/
|
||||
export type UsageLogSource = 'workflow' | 'wand' | 'copilot'
|
||||
export type UsageLogSource = 'workflow' | 'wand' | 'copilot' | 'mcp_copilot'
|
||||
|
||||
/**
|
||||
* Metadata for 'model' category charges
|
||||
|
||||
4
packages/db/migrations/0153_complete_arclight.sql
Normal file
4
packages/db/migrations/0153_complete_arclight.sql
Normal file
@@ -0,0 +1,4 @@
|
||||
ALTER TYPE "public"."usage_log_source" ADD VALUE 'mcp_copilot';--> statement-breakpoint
|
||||
ALTER TABLE "user_stats" ADD COLUMN "total_mcp_copilot_calls" integer DEFAULT 0 NOT NULL;--> statement-breakpoint
|
||||
ALTER TABLE "user_stats" ADD COLUMN "total_mcp_copilot_cost" numeric DEFAULT '0' NOT NULL;--> statement-breakpoint
|
||||
ALTER TABLE "user_stats" ADD COLUMN "current_period_mcp_copilot_cost" numeric DEFAULT '0' NOT NULL;
|
||||
11139
packages/db/migrations/meta/0153_snapshot.json
Normal file
11139
packages/db/migrations/meta/0153_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1065,6 +1065,13 @@
|
||||
"when": 1770336289511,
|
||||
"tag": "0152_parallel_frog_thor",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 153,
|
||||
"version": "7",
|
||||
"when": 1770410282842,
|
||||
"tag": "0153_complete_arclight",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -715,6 +715,10 @@ export const userStats = pgTable('user_stats', {
|
||||
lastPeriodCopilotCost: decimal('last_period_copilot_cost').default('0'),
|
||||
totalCopilotTokens: integer('total_copilot_tokens').notNull().default(0),
|
||||
totalCopilotCalls: integer('total_copilot_calls').notNull().default(0),
|
||||
// MCP Copilot usage tracking
|
||||
totalMcpCopilotCalls: integer('total_mcp_copilot_calls').notNull().default(0),
|
||||
totalMcpCopilotCost: decimal('total_mcp_copilot_cost').notNull().default('0'),
|
||||
currentPeriodMcpCopilotCost: decimal('current_period_mcp_copilot_cost').notNull().default('0'),
|
||||
// Storage tracking (for free/pro users)
|
||||
storageUsedBytes: bigint('storage_used_bytes', { mode: 'number' }).notNull().default(0),
|
||||
lastActive: timestamp('last_active').notNull().defaultNow(),
|
||||
@@ -1968,7 +1972,12 @@ export const a2aPushNotificationConfig = pgTable(
|
||||
)
|
||||
|
||||
export const usageLogCategoryEnum = pgEnum('usage_log_category', ['model', 'fixed'])
|
||||
export const usageLogSourceEnum = pgEnum('usage_log_source', ['workflow', 'wand', 'copilot'])
|
||||
export const usageLogSourceEnum = pgEnum('usage_log_source', [
|
||||
'workflow',
|
||||
'wand',
|
||||
'copilot',
|
||||
'mcp_copilot',
|
||||
])
|
||||
|
||||
export const usageLog = pgTable(
|
||||
'usage_log',
|
||||
|
||||
Reference in New Issue
Block a user