mirror of
https://github.com/simstudioai/sim.git
synced 2026-02-07 13:15:06 -05:00
Add copilot mcp tracking
This commit is contained in:
@@ -18,6 +18,7 @@ const UpdateCostSchema = z.object({
|
||||
model: z.string().min(1, 'Model is required'),
|
||||
inputTokens: z.number().min(0).default(0),
|
||||
outputTokens: z.number().min(0).default(0),
|
||||
source: z.enum(['copilot', 'mcp_copilot']).default('copilot'),
|
||||
})
|
||||
|
||||
/**
|
||||
@@ -75,12 +76,14 @@ export async function POST(req: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
const { userId, cost, model, inputTokens, outputTokens } = validation.data
|
||||
const { userId, cost, model, inputTokens, outputTokens, source } = validation.data
|
||||
const isMcp = source === 'mcp_copilot'
|
||||
|
||||
logger.info(`[${requestId}] Processing cost update`, {
|
||||
userId,
|
||||
cost,
|
||||
model,
|
||||
source,
|
||||
})
|
||||
|
||||
// Check if user stats record exists (same as ExecutionLogger)
|
||||
@@ -96,7 +99,7 @@ export async function POST(req: NextRequest) {
|
||||
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
|
||||
}
|
||||
|
||||
const updateFields = {
|
||||
const updateFields: Record<string, unknown> = {
|
||||
totalCost: sql`total_cost + ${cost}`,
|
||||
currentPeriodCost: sql`current_period_cost + ${cost}`,
|
||||
totalCopilotCost: sql`total_copilot_cost + ${cost}`,
|
||||
@@ -105,17 +108,24 @@ export async function POST(req: NextRequest) {
|
||||
lastActive: new Date(),
|
||||
}
|
||||
|
||||
// Also increment MCP-specific counters when source is mcp_copilot
|
||||
if (isMcp) {
|
||||
updateFields.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
|
||||
updateFields.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
|
||||
}
|
||||
|
||||
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
|
||||
|
||||
logger.info(`[${requestId}] Updated user stats record`, {
|
||||
userId,
|
||||
addedCost: cost,
|
||||
source,
|
||||
})
|
||||
|
||||
// Log usage for complete audit trail
|
||||
await logModelUsage({
|
||||
userId,
|
||||
source: 'copilot',
|
||||
source: isMcp ? 'mcp_copilot' : 'copilot',
|
||||
model,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
|
||||
@@ -10,9 +10,13 @@ import {
|
||||
type ListToolsResult,
|
||||
type RequestId,
|
||||
} from '@modelcontextprotocol/sdk/types.js'
|
||||
import { db } from '@sim/db'
|
||||
import { userStats } from '@sim/db/schema'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { checkHybridAuth } from '@/lib/auth/hybrid'
|
||||
import { authenticateApiKeyFromHeader, updateApiKeyLastUsed } from '@/lib/api-key/service'
|
||||
import { checkServerSideUsageLimits } from '@/lib/billing/calculations/usage-monitor'
|
||||
import { getCopilotModel } from '@/lib/copilot/config'
|
||||
import { SIM_AGENT_VERSION } from '@/lib/copilot/constants'
|
||||
import { orchestrateCopilotStream } from '@/lib/copilot/orchestrator'
|
||||
@@ -97,11 +101,28 @@ export async function GET() {
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const auth = await checkHybridAuth(request, { requireWorkflowId: false })
|
||||
if (!auth.success || !auth.userId) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
// API-key-only auth — MCP clients must provide x-api-key header
|
||||
const apiKeyHeader = request.headers.get('x-api-key')
|
||||
if (!apiKeyHeader) {
|
||||
return NextResponse.json(
|
||||
createError(0, -32000, 'API key required. Set the x-api-key header with a valid Sim API key.'),
|
||||
{ status: 401 }
|
||||
)
|
||||
}
|
||||
|
||||
const authResult = await authenticateApiKeyFromHeader(apiKeyHeader)
|
||||
if (!authResult.success || !authResult.userId) {
|
||||
return NextResponse.json(
|
||||
createError(0, -32000, authResult.error || 'Invalid API key'),
|
||||
{ status: 401 }
|
||||
)
|
||||
}
|
||||
|
||||
// Fire-and-forget last-used update
|
||||
updateApiKeyLastUsed(authResult.keyId!)
|
||||
|
||||
const userId = authResult.userId
|
||||
|
||||
const body = (await request.json()) as JSONRPCMessage
|
||||
|
||||
if (isJSONRPCNotification(body)) {
|
||||
@@ -117,6 +138,17 @@ export async function POST(request: NextRequest) {
|
||||
|
||||
const { id, method, params } = body
|
||||
|
||||
// Pre-flight usage limit check for tool calls
|
||||
if (method === 'tools/call') {
|
||||
const usageCheck = await checkServerSideUsageLimits(userId)
|
||||
if (usageCheck.isExceeded) {
|
||||
return NextResponse.json(
|
||||
createError(id, -32000, `Usage limit exceeded: ${usageCheck.message || 'Upgrade your plan.'}`),
|
||||
{ status: 402 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch (method) {
|
||||
case 'initialize': {
|
||||
const result: InitializeResult = {
|
||||
@@ -131,12 +163,16 @@ export async function POST(request: NextRequest) {
|
||||
return NextResponse.json(createResponse(id, {}))
|
||||
case 'tools/list':
|
||||
return handleToolsList(id)
|
||||
case 'tools/call':
|
||||
return handleToolsCall(
|
||||
case 'tools/call': {
|
||||
const response = await handleToolsCall(
|
||||
id,
|
||||
params as { name: string; arguments?: Record<string, unknown> },
|
||||
auth.userId
|
||||
userId
|
||||
)
|
||||
// Track MCP copilot call (fire-and-forget)
|
||||
trackMcpCopilotCall(userId)
|
||||
return response
|
||||
}
|
||||
default:
|
||||
return NextResponse.json(
|
||||
createError(id, ErrorCode.MethodNotFound, `Method not found: ${method}`),
|
||||
@@ -151,6 +187,22 @@ export async function POST(request: NextRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment MCP copilot call counter in userStats (fire-and-forget).
|
||||
*/
|
||||
function trackMcpCopilotCall(userId: string): void {
|
||||
db.update(userStats)
|
||||
.set({
|
||||
totalMcpCopilotCalls: sql`total_mcp_copilot_calls + 1`,
|
||||
lastActive: new Date(),
|
||||
})
|
||||
.where(eq(userStats.userId, userId))
|
||||
.then(() => {})
|
||||
.catch((error) => {
|
||||
logger.error('Failed to track MCP copilot call', { error, userId })
|
||||
})
|
||||
}
|
||||
|
||||
async function handleToolsList(id: RequestId): Promise<NextResponse> {
|
||||
const directTools = DIRECT_TOOL_DEFS.map((tool) => ({
|
||||
name: tool.name,
|
||||
@@ -351,6 +403,7 @@ async function handleSubagentToolCall(
|
||||
context,
|
||||
model,
|
||||
headless: true,
|
||||
source: 'mcp_copilot',
|
||||
},
|
||||
{
|
||||
userId,
|
||||
|
||||
@@ -14,7 +14,7 @@ export type UsageLogCategory = 'model' | 'fixed'
|
||||
/**
|
||||
* Usage log source types
|
||||
*/
|
||||
export type UsageLogSource = 'workflow' | 'wand' | 'copilot'
|
||||
export type UsageLogSource = 'workflow' | 'wand' | 'copilot' | 'mcp_copilot'
|
||||
|
||||
/**
|
||||
* Metadata for 'model' category charges
|
||||
|
||||
Reference in New Issue
Block a user