This commit is contained in:
Siddharth Ganesan
2026-02-06 14:25:51 -08:00
parent f63ed61bc8
commit d1a2d661c9

View File

@@ -13,6 +13,7 @@ import {
import { db } from '@sim/db'
import { userStats } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { randomUUID } from 'node:crypto'
import { eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { authenticateApiKeyFromHeader, updateApiKeyLastUsed } from '@/lib/api-key/service'
@@ -107,17 +108,81 @@ function readHeader(headers: HeaderMap | undefined, name: string): string | unde
class NextResponseCapture {
private _status = 200
private _headers = new Headers()
private _chunks: Buffer[] = []
private _controller: ReadableStreamDefaultController<Uint8Array> | null = null
private _pendingChunks: Uint8Array[] = []
private _closeHandlers: Array<() => void> = []
private _errorHandlers: Array<(error: Error) => void> = []
private _headersWritten = false
private _ended = false
private _headersPromise: Promise<void>
private _resolveHeaders: (() => void) | null = null
private _endedPromise: Promise<void>
private _resolveEnded: (() => void) | null = null
readonly readable: ReadableStream<Uint8Array>
constructor() {
this._headersPromise = new Promise<void>((resolve) => {
this._resolveHeaders = resolve
})
this._endedPromise = new Promise<void>((resolve) => {
this._resolveEnded = resolve
})
this.readable = new ReadableStream<Uint8Array>({
start: (controller) => {
this._controller = controller
if (this._pendingChunks.length > 0) {
for (const chunk of this._pendingChunks) {
controller.enqueue(chunk)
}
this._pendingChunks = []
}
},
cancel: () => {
this._ended = true
this._resolveEnded?.()
this.triggerCloseHandlers()
},
})
}
private markHeadersWritten(): void {
if (this._headersWritten) return
this._headersWritten = true
this._resolveHeaders?.()
}
private triggerCloseHandlers(): void {
for (const handler of this._closeHandlers) {
try {
handler()
} catch (error) {
this.triggerErrorHandlers(error instanceof Error ? error : new Error(String(error)))
}
}
}
private triggerErrorHandlers(error: Error): void {
for (const errorHandler of this._errorHandlers) {
errorHandler(error)
}
}
private normalizeChunk(chunk: unknown): Uint8Array | null {
if (typeof chunk === 'string') {
return new TextEncoder().encode(chunk)
}
if (chunk instanceof Uint8Array) {
return chunk
}
if (chunk === undefined || chunk === null) {
return null
}
return new TextEncoder().encode(String(chunk))
}
writeHead(status: number, headers?: Record<string, string | number | string[]>): this {
@@ -133,52 +198,66 @@ class NextResponseCapture {
})
}
this.markHeadersWritten()
return this
}
flushHeaders(): this {
this.markHeadersWritten()
return this
}
write(chunk: unknown): boolean {
if (typeof chunk === 'string') {
this._chunks.push(Buffer.from(chunk))
return true
}
const normalized = this.normalizeChunk(chunk)
if (!normalized) return true
if (chunk instanceof Uint8Array) {
this._chunks.push(Buffer.from(chunk))
return true
}
this.markHeadersWritten()
if (chunk !== undefined && chunk !== null) {
this._chunks.push(Buffer.from(String(chunk)))
if (this._controller) {
try {
this._controller.enqueue(normalized)
} catch (error) {
this.triggerErrorHandlers(error instanceof Error ? error : new Error(String(error)))
}
} else {
this._pendingChunks.push(normalized)
}
return true
}
end(chunk?: unknown): this {
if (chunk !== undefined) {
this.write(chunk)
}
if (chunk !== undefined) this.write(chunk)
this.markHeadersWritten()
if (this._ended) return this
this._ended = true
this._resolveEnded?.()
this._closeHandlers.forEach((handler) => {
if (this._controller) {
try {
handler()
this._controller.close()
} catch (error) {
this._errorHandlers.forEach((errorHandler) => {
errorHandler(error instanceof Error ? error : new Error(String(error)))
})
this.triggerErrorHandlers(error instanceof Error ? error : new Error(String(error)))
}
})
}
this.triggerCloseHandlers()
return this
}
async waitForHeaders(timeoutMs = 30000): Promise<void> {
if (this._headersWritten) return
await Promise.race([
this._headersPromise,
new Promise<void>((resolve) => {
setTimeout(resolve, timeoutMs)
}),
])
}
async waitForEnd(timeoutMs = 30000): Promise<void> {
if (this._ended) return
@@ -203,15 +282,7 @@ class NextResponseCapture {
}
toNextResponse(): NextResponse {
if (this._chunks.length === 0) {
return new NextResponse(null, {
status: this._status,
headers: this._headers,
})
}
const body = Buffer.concat(this._chunks)
return new NextResponse(body, {
return new NextResponse(this.readable, {
status: this._status,
headers: this._headers,
})
@@ -320,7 +391,6 @@ async function handleMcpRequestWithSdk(
})
const responseCapture = new NextResponseCapture()
const requestAdapter = {
method: request.method,
headers: normalizeRequestHeaders(request),
@@ -330,6 +400,7 @@ async function handleMcpRequestWithSdk(
try {
await transport.handleRequest(requestAdapter as any, responseCapture as any, parsedBody)
await responseCapture.waitForHeaders()
await responseCapture.waitForEnd()
return responseCapture.toNextResponse()
} finally {
@@ -368,6 +439,11 @@ export async function POST(request: NextRequest) {
}
}
export async function DELETE(request: NextRequest) {
void request
return NextResponse.json(createError(0, -32000, 'Method not allowed.'), { status: 405 })
}
/**
* Increment MCP copilot call counter in userStats (fire-and-forget).
*/
@@ -412,7 +488,7 @@ async function handleDirectToolCall(
const execContext = await prepareExecutionContext(userId, (args.workflowId as string) || '')
const toolCall = {
id: crypto.randomUUID(),
id: randomUUID(),
name: toolDef.toolId,
status: 'pending' as const,
params: args as Record<string, any>,
@@ -480,7 +556,7 @@ async function handleBuildToolCall(
}
}
const chatId = crypto.randomUUID()
const chatId = randomUUID()
const requestPayload = {
message: requestText,
@@ -489,7 +565,7 @@ async function handleBuildToolCall(
model,
mode: 'agent',
commands: ['fast'],
messageId: crypto.randomUUID(),
messageId: randomUUID(),
version: SIM_AGENT_VERSION,
headless: true,
chatId,