mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-21 04:48:00 -05:00
Compare commits
7 Commits
v0.5.64
...
improvemen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4007c7e7e | ||
|
|
71c92788c5 | ||
|
|
4afb245fa2 | ||
|
|
8344d68ca8 | ||
|
|
a26a1a9737 | ||
|
|
689037a300 | ||
|
|
07f0c01dc4 |
364
apps/sim/app/api/copilot/headless/route.ts
Normal file
364
apps/sim/app/api/copilot/headless/route.ts
Normal file
@@ -0,0 +1,364 @@
|
||||
import { db } from '@sim/db'
|
||||
import { copilotChats, workflow as workflowTable } from '@sim/db/schema'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { authenticateApiKeyFromHeader, updateApiKeyLastUsed } from '@/lib/api-key/service'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getCopilotModel } from '@/lib/copilot/config'
|
||||
import { SIM_AGENT_API_URL_DEFAULT, SIM_AGENT_VERSION } from '@/lib/copilot/constants'
|
||||
import { COPILOT_MODEL_IDS } from '@/lib/copilot/models'
|
||||
import {
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/request-helpers'
|
||||
import {
|
||||
createStream,
|
||||
completeStream,
|
||||
errorStream,
|
||||
updateStreamStatus,
|
||||
} from '@/lib/copilot/stream-persistence'
|
||||
import { executeToolServerSide, isServerExecutableTool } from '@/lib/copilot/tools/server/executor'
|
||||
import { getCredentialsServerTool } from '@/lib/copilot/tools/server/user/get-credentials'
|
||||
import { loadWorkflowFromNormalizedTables } from '@/lib/workflows/persistence/utils'
|
||||
import { sanitizeForCopilot } from '@/lib/workflows/sanitization/json-sanitizer'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { tools } from '@/tools/registry'
|
||||
import { getLatestVersionTools, stripVersionSuffix } from '@/tools/utils'
|
||||
|
||||
const logger = createLogger('HeadlessCopilotAPI')
|
||||
|
||||
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
|
||||
|
||||
const HeadlessRequestSchema = z.object({
|
||||
message: z.string().min(1, 'Message is required'),
|
||||
workflowId: z.string().min(1, 'Workflow ID is required'),
|
||||
chatId: z.string().optional(),
|
||||
model: z.enum(COPILOT_MODEL_IDS).optional(),
|
||||
mode: z.enum(['agent', 'build', 'chat']).optional().default('agent'),
|
||||
timeout: z.number().optional().default(300000), // 5 minute default
|
||||
persistChanges: z.boolean().optional().default(true),
|
||||
createNewChat: z.boolean().optional().default(false),
|
||||
})
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
export const fetchCache = 'force-no-store'
|
||||
export const runtime = 'nodejs'
|
||||
|
||||
/**
|
||||
* POST /api/copilot/headless
|
||||
*
|
||||
* Execute copilot completely server-side without any client connection.
|
||||
* All tool calls are executed server-side and results are persisted directly.
|
||||
*
|
||||
* Returns the final result after all processing is complete.
|
||||
*/
|
||||
export async function POST(req: NextRequest) {
|
||||
const tracker = createRequestTracker()
|
||||
const startTime = Date.now()
|
||||
|
||||
try {
|
||||
// Authenticate via session or API key
|
||||
let userId: string | null = null
|
||||
|
||||
const session = await getSession()
|
||||
if (session?.user?.id) {
|
||||
userId = session.user.id
|
||||
} else {
|
||||
// Try API key authentication from header
|
||||
const apiKey = req.headers.get('x-api-key')
|
||||
if (apiKey) {
|
||||
const authResult = await authenticateApiKeyFromHeader(apiKey)
|
||||
if (authResult.success && authResult.userId) {
|
||||
userId = authResult.userId
|
||||
// Update last used timestamp in background
|
||||
if (authResult.keyId) {
|
||||
updateApiKeyLastUsed(authResult.keyId).catch(() => {})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
const { message, workflowId, chatId, model, mode, timeout, persistChanges, createNewChat } =
|
||||
HeadlessRequestSchema.parse(body)
|
||||
|
||||
logger.info(`[${tracker.requestId}] Headless copilot request`, {
|
||||
userId,
|
||||
workflowId,
|
||||
messageLength: message.length,
|
||||
mode,
|
||||
})
|
||||
|
||||
// Verify user has access to workflow
|
||||
const [wf] = await db
|
||||
.select({ userId: workflowTable.userId, workspaceId: workflowTable.workspaceId })
|
||||
.from(workflowTable)
|
||||
.where(eq(workflowTable.id, workflowId))
|
||||
.limit(1)
|
||||
|
||||
if (!wf) {
|
||||
return NextResponse.json({ error: 'Workflow not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
// TODO: Add proper workspace access check
|
||||
if (wf.userId !== userId) {
|
||||
return NextResponse.json({ error: 'Access denied' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Load current workflow state from database
|
||||
const workflowData = await loadWorkflowFromNormalizedTables(workflowId)
|
||||
if (!workflowData) {
|
||||
return NextResponse.json({ error: 'Workflow data not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
const sanitizedWorkflow = sanitizeForCopilot({
|
||||
blocks: workflowData.blocks,
|
||||
edges: workflowData.edges,
|
||||
loops: workflowData.loops,
|
||||
parallels: workflowData.parallels,
|
||||
})
|
||||
|
||||
// Create a stream for tracking (even in headless mode)
|
||||
const streamId = crypto.randomUUID()
|
||||
const userMessageId = crypto.randomUUID()
|
||||
const assistantMessageId = crypto.randomUUID()
|
||||
|
||||
await createStream({
|
||||
streamId,
|
||||
chatId: chatId || '',
|
||||
userId,
|
||||
workflowId,
|
||||
userMessageId,
|
||||
isClientSession: false, // Key: this is headless
|
||||
})
|
||||
|
||||
await updateStreamStatus(streamId, 'streaming')
|
||||
|
||||
// Handle chat persistence
|
||||
let actualChatId = chatId
|
||||
if (createNewChat && !chatId) {
|
||||
const { provider, model: defaultModel } = getCopilotModel('chat')
|
||||
const [newChat] = await db
|
||||
.insert(copilotChats)
|
||||
.values({
|
||||
userId,
|
||||
workflowId,
|
||||
title: null,
|
||||
model: model || defaultModel,
|
||||
messages: [],
|
||||
})
|
||||
.returning()
|
||||
|
||||
if (newChat) {
|
||||
actualChatId = newChat.id
|
||||
}
|
||||
}
|
||||
|
||||
// Get credentials for tools
|
||||
let credentials: {
|
||||
oauth: Record<string, { accessToken: string; accountId: string; name: string }>
|
||||
apiKeys: string[]
|
||||
} | null = null
|
||||
|
||||
try {
|
||||
const rawCredentials = await getCredentialsServerTool.execute({ workflowId }, { userId })
|
||||
const oauthMap: Record<string, { accessToken: string; accountId: string; name: string }> = {}
|
||||
|
||||
for (const cred of rawCredentials?.oauth?.connected?.credentials || []) {
|
||||
if (cred.accessToken) {
|
||||
oauthMap[cred.provider] = {
|
||||
accessToken: cred.accessToken,
|
||||
accountId: cred.id,
|
||||
name: cred.name,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
credentials = {
|
||||
oauth: oauthMap,
|
||||
apiKeys: rawCredentials?.environment?.variableNames || [],
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`[${tracker.requestId}] Failed to fetch credentials`, { error })
|
||||
}
|
||||
|
||||
// Build tool definitions
|
||||
const { createUserToolSchema } = await import('@/tools/params')
|
||||
const latestTools = getLatestVersionTools(tools)
|
||||
const integrationTools = Object.entries(latestTools).map(([toolId, toolConfig]) => {
|
||||
const userSchema = createUserToolSchema(toolConfig)
|
||||
const strippedName = stripVersionSuffix(toolId)
|
||||
return {
|
||||
name: strippedName,
|
||||
description: toolConfig.description || toolConfig.name || strippedName,
|
||||
input_schema: userSchema,
|
||||
defer_loading: true,
|
||||
}
|
||||
})
|
||||
|
||||
// Build request payload
|
||||
const defaults = getCopilotModel('chat')
|
||||
const selectedModel = model || defaults.model
|
||||
const effectiveMode = mode === 'agent' ? 'build' : mode
|
||||
|
||||
const requestPayload = {
|
||||
message,
|
||||
workflowId,
|
||||
userId,
|
||||
stream: false, // Non-streaming for headless
|
||||
model: selectedModel,
|
||||
mode: effectiveMode,
|
||||
version: SIM_AGENT_VERSION,
|
||||
messageId: userMessageId,
|
||||
...(actualChatId && { chatId: actualChatId }),
|
||||
...(integrationTools.length > 0 && { tools: integrationTools }),
|
||||
...(credentials && { credentials }),
|
||||
}
|
||||
|
||||
// Call sim agent (non-streaming)
|
||||
const controller = new AbortController()
|
||||
const timeoutId = setTimeout(() => controller.abort(), timeout)
|
||||
|
||||
try {
|
||||
const response = await fetch(`${SIM_AGENT_API_URL}/api/chat-completion`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
|
||||
},
|
||||
body: JSON.stringify(requestPayload),
|
||||
signal: controller.signal,
|
||||
})
|
||||
|
||||
clearTimeout(timeoutId)
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
logger.error(`[${tracker.requestId}] Sim agent error`, {
|
||||
status: response.status,
|
||||
error: errorText,
|
||||
})
|
||||
await errorStream(streamId, `Agent error: ${response.statusText}`)
|
||||
return NextResponse.json(
|
||||
{ error: `Agent error: ${response.statusText}` },
|
||||
{ status: response.status }
|
||||
)
|
||||
}
|
||||
|
||||
const result = await response.json()
|
||||
|
||||
// Execute tool calls server-side
|
||||
const toolResults: Record<string, { success: boolean; result?: unknown; error?: string }> = {}
|
||||
|
||||
if (result.toolCalls && Array.isArray(result.toolCalls)) {
|
||||
for (const toolCall of result.toolCalls) {
|
||||
const toolName = toolCall.name
|
||||
const toolArgs = toolCall.arguments || toolCall.input || {}
|
||||
|
||||
logger.info(`[${tracker.requestId}] Executing tool server-side`, {
|
||||
toolName,
|
||||
toolCallId: toolCall.id,
|
||||
})
|
||||
|
||||
if (!isServerExecutableTool(toolName)) {
|
||||
logger.warn(`[${tracker.requestId}] Tool not executable server-side`, { toolName })
|
||||
toolResults[toolCall.id] = {
|
||||
success: false,
|
||||
error: `Tool ${toolName} requires client-side execution`,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
const toolResult = await executeToolServerSide(
|
||||
{ name: toolName, args: toolArgs },
|
||||
{ workflowId, userId, persistChanges }
|
||||
)
|
||||
|
||||
toolResults[toolCall.id] = toolResult
|
||||
}
|
||||
}
|
||||
|
||||
// Mark stream complete
|
||||
await completeStream(streamId, { content: result.content, toolResults })
|
||||
|
||||
// Save to chat history
|
||||
if (actualChatId && persistChanges) {
|
||||
const [chat] = await db
|
||||
.select()
|
||||
.from(copilotChats)
|
||||
.where(eq(copilotChats.id, actualChatId))
|
||||
.limit(1)
|
||||
|
||||
const existingMessages = chat ? (Array.isArray(chat.messages) ? chat.messages : []) : []
|
||||
|
||||
const newMessages = [
|
||||
...existingMessages,
|
||||
{
|
||||
id: userMessageId,
|
||||
role: 'user',
|
||||
content: message,
|
||||
timestamp: new Date().toISOString(),
|
||||
},
|
||||
{
|
||||
id: assistantMessageId,
|
||||
role: 'assistant',
|
||||
content: result.content,
|
||||
timestamp: new Date().toISOString(),
|
||||
toolCalls: Object.entries(toolResults).map(([id, r]) => ({
|
||||
id,
|
||||
success: r.success,
|
||||
})),
|
||||
},
|
||||
]
|
||||
|
||||
await db
|
||||
.update(copilotChats)
|
||||
.set({ messages: newMessages, updatedAt: new Date() })
|
||||
.where(eq(copilotChats.id, actualChatId))
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`[${tracker.requestId}] Headless copilot complete`, {
|
||||
duration,
|
||||
contentLength: result.content?.length || 0,
|
||||
toolCallsExecuted: Object.keys(toolResults).length,
|
||||
})
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
streamId,
|
||||
chatId: actualChatId,
|
||||
content: result.content,
|
||||
toolResults,
|
||||
duration,
|
||||
})
|
||||
} catch (error) {
|
||||
clearTimeout(timeoutId)
|
||||
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
await errorStream(streamId, 'Request timed out')
|
||||
return NextResponse.json({ error: 'Request timed out' }, { status: 504 })
|
||||
}
|
||||
|
||||
throw error
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`[${tracker.requestId}] Headless copilot error`, { error })
|
||||
|
||||
if (error instanceof z.ZodError) {
|
||||
return NextResponse.json({ error: 'Invalid request', details: error.errors }, { status: 400 })
|
||||
}
|
||||
|
||||
return NextResponse.json(
|
||||
{ error: error instanceof Error ? error.message : 'Internal error' },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
237
apps/sim/app/api/copilot/stream/[streamId]/route.ts
Normal file
237
apps/sim/app/api/copilot/stream/[streamId]/route.ts
Normal file
@@ -0,0 +1,237 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import {
|
||||
getStreamMetadata,
|
||||
getStreamEvents,
|
||||
getStreamEventCount,
|
||||
getToolCallStates,
|
||||
refreshStreamTTL,
|
||||
checkAbortSignal,
|
||||
abortStream,
|
||||
} from '@/lib/copilot/stream-persistence'
|
||||
|
||||
const logger = createLogger('StreamResumeAPI')
|
||||
|
||||
interface RouteParams {
|
||||
streamId: string
|
||||
}
|
||||
|
||||
/**
|
||||
* GET /api/copilot/stream/{streamId}
|
||||
* Subscribe to or resume a stream
|
||||
*
|
||||
* Query params:
|
||||
* - offset: Start from this event index (for resumption)
|
||||
* - mode: 'sse' (default) or 'poll'
|
||||
*/
|
||||
export async function GET(req: NextRequest, { params }: { params: Promise<RouteParams> }) {
|
||||
const { streamId } = await params
|
||||
const session = await getSession()
|
||||
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const metadata = await getStreamMetadata(streamId)
|
||||
if (!metadata) {
|
||||
return NextResponse.json({ error: 'Stream not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
// Verify user owns this stream
|
||||
if (metadata.userId !== session.user.id) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
const offset = parseInt(req.nextUrl.searchParams.get('offset') || '0', 10)
|
||||
const mode = req.nextUrl.searchParams.get('mode') || 'sse'
|
||||
|
||||
// Refresh TTL since someone is actively consuming
|
||||
await refreshStreamTTL(streamId)
|
||||
|
||||
// Poll mode: return current state as JSON
|
||||
if (mode === 'poll') {
|
||||
const events = await getStreamEvents(streamId, offset)
|
||||
const toolCalls = await getToolCallStates(streamId)
|
||||
const eventCount = await getStreamEventCount(streamId)
|
||||
|
||||
return NextResponse.json({
|
||||
metadata,
|
||||
events,
|
||||
toolCalls,
|
||||
totalEvents: eventCount,
|
||||
nextOffset: offset + events.length,
|
||||
})
|
||||
}
|
||||
|
||||
// SSE mode: stream events
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
const readable = new ReadableStream({
|
||||
async start(controller) {
|
||||
let closed = false
|
||||
|
||||
const safeEnqueue = (data: string) => {
|
||||
if (closed) return
|
||||
try {
|
||||
controller.enqueue(encoder.encode(data))
|
||||
} catch {
|
||||
closed = true
|
||||
}
|
||||
}
|
||||
|
||||
const safeClose = () => {
|
||||
if (closed) return
|
||||
closed = true
|
||||
try {
|
||||
controller.close()
|
||||
} catch {
|
||||
// Already closed
|
||||
}
|
||||
}
|
||||
|
||||
// Send initial connection event
|
||||
safeEnqueue(`: connected\n\n`)
|
||||
|
||||
// Send metadata
|
||||
safeEnqueue(`event: metadata\ndata: ${JSON.stringify(metadata)}\n\n`)
|
||||
|
||||
// Send tool call states
|
||||
const toolCalls = await getToolCallStates(streamId)
|
||||
if (Object.keys(toolCalls).length > 0) {
|
||||
safeEnqueue(`event: tool_states\ndata: ${JSON.stringify(toolCalls)}\n\n`)
|
||||
}
|
||||
|
||||
// Replay missed events
|
||||
const missedEvents = await getStreamEvents(streamId, offset)
|
||||
for (const event of missedEvents) {
|
||||
safeEnqueue(event)
|
||||
}
|
||||
|
||||
// If stream is complete, send done and close
|
||||
if (metadata.status === 'complete' || metadata.status === 'error' || metadata.status === 'aborted') {
|
||||
safeEnqueue(
|
||||
`event: stream_status\ndata: ${JSON.stringify({
|
||||
status: metadata.status,
|
||||
error: metadata.error,
|
||||
})}\n\n`
|
||||
)
|
||||
safeClose()
|
||||
return
|
||||
}
|
||||
|
||||
// Stream is still active - poll for new events
|
||||
let lastOffset = offset + missedEvents.length
|
||||
const pollInterval = 100 // 100ms
|
||||
const maxPollTime = 5 * 60 * 1000 // 5 minutes max
|
||||
const startTime = Date.now()
|
||||
|
||||
const poll = async () => {
|
||||
if (closed) return
|
||||
|
||||
try {
|
||||
// Check for timeout
|
||||
if (Date.now() - startTime > maxPollTime) {
|
||||
logger.info('Stream poll timeout', { streamId })
|
||||
safeEnqueue(
|
||||
`event: stream_status\ndata: ${JSON.stringify({ status: 'timeout' })}\n\n`
|
||||
)
|
||||
safeClose()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if client disconnected
|
||||
if (await checkAbortSignal(streamId)) {
|
||||
safeEnqueue(
|
||||
`event: stream_status\ndata: ${JSON.stringify({ status: 'aborted' })}\n\n`
|
||||
)
|
||||
safeClose()
|
||||
return
|
||||
}
|
||||
|
||||
// Get current metadata to check status
|
||||
const currentMeta = await getStreamMetadata(streamId)
|
||||
if (!currentMeta) {
|
||||
safeClose()
|
||||
return
|
||||
}
|
||||
|
||||
// Get new events
|
||||
const newEvents = await getStreamEvents(streamId, lastOffset)
|
||||
for (const event of newEvents) {
|
||||
safeEnqueue(event)
|
||||
}
|
||||
lastOffset += newEvents.length
|
||||
|
||||
// Refresh TTL
|
||||
await refreshStreamTTL(streamId)
|
||||
|
||||
// If complete, send status and close
|
||||
if (
|
||||
currentMeta.status === 'complete' ||
|
||||
currentMeta.status === 'error' ||
|
||||
currentMeta.status === 'aborted'
|
||||
) {
|
||||
safeEnqueue(
|
||||
`event: stream_status\ndata: ${JSON.stringify({
|
||||
status: currentMeta.status,
|
||||
error: currentMeta.error,
|
||||
})}\n\n`
|
||||
)
|
||||
safeClose()
|
||||
return
|
||||
}
|
||||
|
||||
// Continue polling
|
||||
setTimeout(poll, pollInterval)
|
||||
} catch (error) {
|
||||
logger.error('Stream poll error', { streamId, error })
|
||||
safeClose()
|
||||
}
|
||||
}
|
||||
|
||||
// Start polling
|
||||
setTimeout(poll, pollInterval)
|
||||
},
|
||||
})
|
||||
|
||||
return new Response(readable, {
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream; charset=utf-8',
|
||||
'Cache-Control': 'no-cache, no-transform',
|
||||
Connection: 'keep-alive',
|
||||
'X-Accel-Buffering': 'no',
|
||||
'X-Stream-Id': streamId,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* DELETE /api/copilot/stream/{streamId}
|
||||
* Abort a stream
|
||||
*/
|
||||
export async function DELETE(req: NextRequest, { params }: { params: Promise<RouteParams> }) {
|
||||
const { streamId } = await params
|
||||
const session = await getSession()
|
||||
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const metadata = await getStreamMetadata(streamId)
|
||||
if (!metadata) {
|
||||
return NextResponse.json({ error: 'Stream not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
// Verify user owns this stream
|
||||
if (metadata.userId !== session.user.id) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
await abortStream(streamId)
|
||||
|
||||
logger.info('Stream aborted by user', { streamId, userId: session.user.id })
|
||||
|
||||
return NextResponse.json({ success: true, streamId })
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ const BlockDataSchema = z.object({
|
||||
doWhileCondition: z.string().optional(),
|
||||
parallelType: z.enum(['collection', 'count']).optional(),
|
||||
type: z.string().optional(),
|
||||
canonicalModes: z.record(z.enum(['basic', 'advanced'])).optional(),
|
||||
})
|
||||
|
||||
const SubBlockStateSchema = z.object({
|
||||
|
||||
@@ -1510,7 +1510,8 @@ export function ToolCall({
|
||||
toolCall.name === 'user_memory' ||
|
||||
toolCall.name === 'edit_respond' ||
|
||||
toolCall.name === 'debug_respond' ||
|
||||
toolCall.name === 'plan_respond'
|
||||
toolCall.name === 'plan_respond' ||
|
||||
toolCall.name === 'deploy_respond'
|
||||
)
|
||||
return null
|
||||
|
||||
|
||||
@@ -57,6 +57,12 @@ export const BrowserUseBlock: BlockConfig<BrowserUseResponse> = {
|
||||
type: 'switch',
|
||||
placeholder: 'Save browser data',
|
||||
},
|
||||
{
|
||||
id: 'profile_id',
|
||||
title: 'Profile ID',
|
||||
type: 'short-input',
|
||||
placeholder: 'Enter browser profile ID (optional)',
|
||||
},
|
||||
{
|
||||
id: 'apiKey',
|
||||
title: 'API Key',
|
||||
@@ -75,6 +81,7 @@ export const BrowserUseBlock: BlockConfig<BrowserUseResponse> = {
|
||||
variables: { type: 'json', description: 'Task variables' },
|
||||
model: { type: 'string', description: 'AI model to use' },
|
||||
save_browser_data: { type: 'boolean', description: 'Save browser data' },
|
||||
profile_id: { type: 'string', description: 'Browser profile ID for persistent sessions' },
|
||||
},
|
||||
outputs: {
|
||||
id: { type: 'string', description: 'Task execution identifier' },
|
||||
|
||||
599
apps/sim/executor/execution/engine.test.ts
Normal file
599
apps/sim/executor/execution/engine.test.ts
Normal file
@@ -0,0 +1,599 @@
|
||||
/**
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { loggerMock } from '@sim/testing'
|
||||
import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest'
|
||||
|
||||
vi.mock('@sim/logger', () => loggerMock)
|
||||
|
||||
vi.mock('@/lib/execution/cancellation', () => ({
|
||||
isExecutionCancelled: vi.fn(),
|
||||
isRedisCancellationEnabled: vi.fn(),
|
||||
}))
|
||||
|
||||
import { isExecutionCancelled, isRedisCancellationEnabled } from '@/lib/execution/cancellation'
|
||||
import type { DAG, DAGNode } from '@/executor/dag/builder'
|
||||
import type { EdgeManager } from '@/executor/execution/edge-manager'
|
||||
import type { NodeExecutionOrchestrator } from '@/executor/orchestrators/node'
|
||||
import type { ExecutionContext } from '@/executor/types'
|
||||
import type { SerializedBlock } from '@/serializer/types'
|
||||
import { ExecutionEngine } from './engine'
|
||||
|
||||
function createMockBlock(id: string): SerializedBlock {
|
||||
return {
|
||||
id,
|
||||
metadata: { id: 'test', name: 'Test Block' },
|
||||
position: { x: 0, y: 0 },
|
||||
config: { tool: '', params: {} },
|
||||
inputs: {},
|
||||
outputs: {},
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
function createMockNode(id: string, blockType = 'test'): DAGNode {
|
||||
return {
|
||||
id,
|
||||
block: {
|
||||
...createMockBlock(id),
|
||||
metadata: { id: blockType, name: `Block ${id}` },
|
||||
},
|
||||
outgoingEdges: new Map(),
|
||||
incomingEdges: new Set(),
|
||||
metadata: {},
|
||||
}
|
||||
}
|
||||
|
||||
function createMockContext(overrides: Partial<ExecutionContext> = {}): ExecutionContext {
|
||||
return {
|
||||
workflowId: 'test-workflow',
|
||||
workspaceId: 'test-workspace',
|
||||
executionId: 'test-execution',
|
||||
userId: 'test-user',
|
||||
blockStates: new Map(),
|
||||
executedBlocks: new Set(),
|
||||
blockLogs: [],
|
||||
loopExecutions: new Map(),
|
||||
parallelExecutions: new Map(),
|
||||
completedLoops: new Set(),
|
||||
activeExecutionPath: new Set(),
|
||||
metadata: {
|
||||
executionId: 'test-execution',
|
||||
startTime: new Date().toISOString(),
|
||||
pendingBlocks: [],
|
||||
},
|
||||
envVars: {},
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
function createMockDAG(nodes: DAGNode[]): DAG {
|
||||
const nodeMap = new Map<string, DAGNode>()
|
||||
nodes.forEach((node) => nodeMap.set(node.id, node))
|
||||
return {
|
||||
nodes: nodeMap,
|
||||
loopConfigs: new Map(),
|
||||
parallelConfigs: new Map(),
|
||||
}
|
||||
}
|
||||
|
||||
interface MockEdgeManager extends EdgeManager {
|
||||
processOutgoingEdges: ReturnType<typeof vi.fn>
|
||||
}
|
||||
|
||||
function createMockEdgeManager(
|
||||
processOutgoingEdgesImpl?: (node: DAGNode) => string[]
|
||||
): MockEdgeManager {
|
||||
const mockFn = vi.fn().mockImplementation(processOutgoingEdgesImpl || (() => []))
|
||||
return {
|
||||
processOutgoingEdges: mockFn,
|
||||
isNodeReady: vi.fn().mockReturnValue(true),
|
||||
deactivateEdgeAndDescendants: vi.fn(),
|
||||
restoreIncomingEdge: vi.fn(),
|
||||
clearDeactivatedEdges: vi.fn(),
|
||||
clearDeactivatedEdgesForNodes: vi.fn(),
|
||||
} as unknown as MockEdgeManager
|
||||
}
|
||||
|
||||
interface MockNodeOrchestrator extends NodeExecutionOrchestrator {
|
||||
executionCount: number
|
||||
}
|
||||
|
||||
function createMockNodeOrchestrator(executeDelay = 0): MockNodeOrchestrator {
|
||||
const mock = {
|
||||
executionCount: 0,
|
||||
executeNode: vi.fn().mockImplementation(async () => {
|
||||
mock.executionCount++
|
||||
if (executeDelay > 0) {
|
||||
await new Promise((resolve) => setTimeout(resolve, executeDelay))
|
||||
}
|
||||
return { nodeId: 'test', output: {}, isFinalOutput: false }
|
||||
}),
|
||||
handleNodeCompletion: vi.fn(),
|
||||
}
|
||||
return mock as unknown as MockNodeOrchestrator
|
||||
}
|
||||
|
||||
describe('ExecutionEngine', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
;(isExecutionCancelled as Mock).mockResolvedValue(false)
|
||||
;(isRedisCancellationEnabled as Mock).mockReturnValue(false)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
describe('Normal execution', () => {
|
||||
it('should execute a simple linear workflow', async () => {
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const endNode = createMockNode('end', 'function')
|
||||
startNode.outgoingEdges.set('edge1', { target: 'end' })
|
||||
endNode.incomingEdges.add('start')
|
||||
|
||||
const dag = createMockDAG([startNode, endNode])
|
||||
const context = createMockContext()
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
if (node.id === 'start') return ['end']
|
||||
return []
|
||||
})
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('start')
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(nodeOrchestrator.executionCount).toBe(2)
|
||||
})
|
||||
|
||||
it('should mark execution as successful when completed without cancellation', async () => {
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const dag = createMockDAG([startNode])
|
||||
const context = createMockContext()
|
||||
const edgeManager = createMockEdgeManager()
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('start')
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.status).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should execute all nodes in a multi-node workflow', async () => {
|
||||
const nodes = [
|
||||
createMockNode('start', 'starter'),
|
||||
createMockNode('middle1', 'function'),
|
||||
createMockNode('middle2', 'function'),
|
||||
createMockNode('end', 'function'),
|
||||
]
|
||||
|
||||
nodes[0].outgoingEdges.set('e1', { target: 'middle1' })
|
||||
nodes[1].outgoingEdges.set('e2', { target: 'middle2' })
|
||||
nodes[2].outgoingEdges.set('e3', { target: 'end' })
|
||||
|
||||
const dag = createMockDAG(nodes)
|
||||
const context = createMockContext()
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
if (node.id === 'start') return ['middle1']
|
||||
if (node.id === 'middle1') return ['middle2']
|
||||
if (node.id === 'middle2') return ['end']
|
||||
return []
|
||||
})
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('start')
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(nodeOrchestrator.executionCount).toBe(4)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Cancellation via AbortSignal', () => {
|
||||
it('should stop execution immediately when aborted before start', async () => {
|
||||
const abortController = new AbortController()
|
||||
abortController.abort()
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const dag = createMockDAG([startNode])
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
const edgeManager = createMockEdgeManager()
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('start')
|
||||
|
||||
expect(result.status).toBe('cancelled')
|
||||
expect(nodeOrchestrator.executionCount).toBe(0)
|
||||
})
|
||||
|
||||
it('should stop execution when aborted mid-workflow', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const nodes = Array.from({ length: 5 }, (_, i) => createMockNode(`node${i}`, 'function'))
|
||||
for (let i = 0; i < nodes.length - 1; i++) {
|
||||
nodes[i].outgoingEdges.set(`e${i}`, { target: `node${i + 1}` })
|
||||
}
|
||||
|
||||
const dag = createMockDAG(nodes)
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
|
||||
let callCount = 0
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
callCount++
|
||||
if (callCount === 2) abortController.abort()
|
||||
const idx = Number.parseInt(node.id.replace('node', ''))
|
||||
if (idx < 4) return [`node${idx + 1}`]
|
||||
return []
|
||||
})
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('node0')
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.status).toBe('cancelled')
|
||||
expect(nodeOrchestrator.executionCount).toBeLessThan(5)
|
||||
})
|
||||
|
||||
it('should not wait for slow executions when cancelled', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const slowNode = createMockNode('slow', 'function')
|
||||
startNode.outgoingEdges.set('edge1', { target: 'slow' })
|
||||
|
||||
const dag = createMockDAG([startNode, slowNode])
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
if (node.id === 'start') return ['slow']
|
||||
return []
|
||||
})
|
||||
const nodeOrchestrator = createMockNodeOrchestrator(500)
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
|
||||
const executionPromise = engine.run('start')
|
||||
setTimeout(() => abortController.abort(), 50)
|
||||
|
||||
const startTime = Date.now()
|
||||
const result = await executionPromise
|
||||
const duration = Date.now() - startTime
|
||||
|
||||
expect(result.status).toBe('cancelled')
|
||||
expect(duration).toBeLessThan(400)
|
||||
})
|
||||
|
||||
it('should return cancelled status even if error thrown during cancellation', async () => {
|
||||
const abortController = new AbortController()
|
||||
abortController.abort()
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const dag = createMockDAG([startNode])
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
const edgeManager = createMockEdgeManager()
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('start')
|
||||
|
||||
expect(result.status).toBe('cancelled')
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Cancellation via Redis', () => {
|
||||
it('should check Redis for cancellation when enabled', async () => {
|
||||
;(isRedisCancellationEnabled as Mock).mockReturnValue(true)
|
||||
;(isExecutionCancelled as Mock).mockResolvedValue(false)
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const dag = createMockDAG([startNode])
|
||||
const context = createMockContext()
|
||||
const edgeManager = createMockEdgeManager()
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
await engine.run('start')
|
||||
|
||||
expect(isExecutionCancelled as Mock).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should stop execution when Redis reports cancellation', async () => {
|
||||
;(isRedisCancellationEnabled as Mock).mockReturnValue(true)
|
||||
|
||||
let checkCount = 0
|
||||
;(isExecutionCancelled as Mock).mockImplementation(async () => {
|
||||
checkCount++
|
||||
return checkCount > 1
|
||||
})
|
||||
|
||||
const nodes = Array.from({ length: 5 }, (_, i) => createMockNode(`node${i}`, 'function'))
|
||||
for (let i = 0; i < nodes.length - 1; i++) {
|
||||
nodes[i].outgoingEdges.set(`e${i}`, { target: `node${i + 1}` })
|
||||
}
|
||||
|
||||
const dag = createMockDAG(nodes)
|
||||
const context = createMockContext()
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
const idx = Number.parseInt(node.id.replace('node', ''))
|
||||
if (idx < 4) return [`node${idx + 1}`]
|
||||
return []
|
||||
})
|
||||
const nodeOrchestrator = createMockNodeOrchestrator(150)
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('node0')
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(result.status).toBe('cancelled')
|
||||
})
|
||||
|
||||
it('should respect cancellation check interval', async () => {
|
||||
;(isRedisCancellationEnabled as Mock).mockReturnValue(true)
|
||||
;(isExecutionCancelled as Mock).mockResolvedValue(false)
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const dag = createMockDAG([startNode])
|
||||
const context = createMockContext()
|
||||
const edgeManager = createMockEdgeManager()
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
await engine.run('start')
|
||||
|
||||
expect((isExecutionCancelled as Mock).mock.calls.length).toBeGreaterThanOrEqual(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Loop execution with cancellation', () => {
|
||||
it('should break out of loop when cancelled mid-iteration', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const loopStartNode = createMockNode('loop-start', 'loop_sentinel')
|
||||
loopStartNode.metadata = { isSentinel: true, sentinelType: 'start', loopId: 'loop1' }
|
||||
|
||||
const loopBodyNode = createMockNode('loop-body', 'function')
|
||||
loopBodyNode.metadata = { isLoopNode: true, loopId: 'loop1' }
|
||||
|
||||
const loopEndNode = createMockNode('loop-end', 'loop_sentinel')
|
||||
loopEndNode.metadata = { isSentinel: true, sentinelType: 'end', loopId: 'loop1' }
|
||||
|
||||
loopStartNode.outgoingEdges.set('edge1', { target: 'loop-body' })
|
||||
loopBodyNode.outgoingEdges.set('edge2', { target: 'loop-end' })
|
||||
loopEndNode.outgoingEdges.set('loop_continue', {
|
||||
target: 'loop-start',
|
||||
sourceHandle: 'loop_continue',
|
||||
})
|
||||
|
||||
const dag = createMockDAG([loopStartNode, loopBodyNode, loopEndNode])
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
|
||||
let iterationCount = 0
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
if (node.id === 'loop-start') return ['loop-body']
|
||||
if (node.id === 'loop-body') return ['loop-end']
|
||||
if (node.id === 'loop-end') {
|
||||
iterationCount++
|
||||
if (iterationCount === 3) abortController.abort()
|
||||
return ['loop-start']
|
||||
}
|
||||
return []
|
||||
})
|
||||
const nodeOrchestrator = createMockNodeOrchestrator(5)
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('loop-start')
|
||||
|
||||
expect(result.status).toBe('cancelled')
|
||||
expect(iterationCount).toBeLessThan(100)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Parallel execution with cancellation', () => {
|
||||
it('should stop queueing parallel branches when cancelled', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const parallelNodes = Array.from({ length: 10 }, (_, i) =>
|
||||
createMockNode(`parallel${i}`, 'function')
|
||||
)
|
||||
|
||||
parallelNodes.forEach((_, i) => {
|
||||
startNode.outgoingEdges.set(`edge${i}`, { target: `parallel${i}` })
|
||||
})
|
||||
|
||||
const dag = createMockDAG([startNode, ...parallelNodes])
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
if (node.id === 'start') {
|
||||
return parallelNodes.map((_, i) => `parallel${i}`)
|
||||
}
|
||||
return []
|
||||
})
|
||||
const nodeOrchestrator = createMockNodeOrchestrator(50)
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
|
||||
const executionPromise = engine.run('start')
|
||||
setTimeout(() => abortController.abort(), 30)
|
||||
|
||||
const result = await executionPromise
|
||||
|
||||
expect(result.status).toBe('cancelled')
|
||||
expect(nodeOrchestrator.executionCount).toBeLessThan(11)
|
||||
})
|
||||
|
||||
it('should not wait for all parallel branches when cancelled', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const slowNodes = Array.from({ length: 5 }, (_, i) => createMockNode(`slow${i}`, 'function'))
|
||||
|
||||
slowNodes.forEach((_, i) => {
|
||||
startNode.outgoingEdges.set(`edge${i}`, { target: `slow${i}` })
|
||||
})
|
||||
|
||||
const dag = createMockDAG([startNode, ...slowNodes])
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
if (node.id === 'start') return slowNodes.map((_, i) => `slow${i}`)
|
||||
return []
|
||||
})
|
||||
const nodeOrchestrator = createMockNodeOrchestrator(200)
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
|
||||
const executionPromise = engine.run('start')
|
||||
setTimeout(() => abortController.abort(), 50)
|
||||
|
||||
const startTime = Date.now()
|
||||
const result = await executionPromise
|
||||
const duration = Date.now() - startTime
|
||||
|
||||
expect(result.status).toBe('cancelled')
|
||||
expect(duration).toBeLessThan(500)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should handle empty DAG gracefully', async () => {
|
||||
const dag = createMockDAG([])
|
||||
const context = createMockContext()
|
||||
const edgeManager = createMockEdgeManager()
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run()
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(nodeOrchestrator.executionCount).toBe(0)
|
||||
})
|
||||
|
||||
it('should preserve partial output when cancelled', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const endNode = createMockNode('end', 'function')
|
||||
endNode.outgoingEdges = new Map()
|
||||
|
||||
startNode.outgoingEdges.set('edge1', { target: 'end' })
|
||||
|
||||
const dag = createMockDAG([startNode, endNode])
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
if (node.id === 'start') return ['end']
|
||||
return []
|
||||
})
|
||||
|
||||
const nodeOrchestrator = {
|
||||
executionCount: 0,
|
||||
executeNode: vi.fn().mockImplementation(async (_ctx: ExecutionContext, nodeId: string) => {
|
||||
if (nodeId === 'start') {
|
||||
return { nodeId: 'start', output: { startData: 'value' }, isFinalOutput: false }
|
||||
}
|
||||
abortController.abort()
|
||||
return { nodeId: 'end', output: { endData: 'value' }, isFinalOutput: true }
|
||||
}),
|
||||
handleNodeCompletion: vi.fn(),
|
||||
} as unknown as MockNodeOrchestrator
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('start')
|
||||
|
||||
expect(result.status).toBe('cancelled')
|
||||
expect(result.output).toBeDefined()
|
||||
})
|
||||
|
||||
it('should populate metadata on cancellation', async () => {
|
||||
const abortController = new AbortController()
|
||||
abortController.abort()
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const dag = createMockDAG([startNode])
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
const edgeManager = createMockEdgeManager()
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('start')
|
||||
|
||||
expect(result.metadata).toBeDefined()
|
||||
expect(result.metadata.endTime).toBeDefined()
|
||||
expect(result.metadata.duration).toBeDefined()
|
||||
})
|
||||
|
||||
it('should return logs even when cancelled', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const startNode = createMockNode('start', 'starter')
|
||||
const dag = createMockDAG([startNode])
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
context.blockLogs.push({
|
||||
blockId: 'test',
|
||||
blockName: 'Test',
|
||||
blockType: 'test',
|
||||
startedAt: '',
|
||||
endedAt: '',
|
||||
durationMs: 0,
|
||||
success: true,
|
||||
})
|
||||
|
||||
const edgeManager = createMockEdgeManager()
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
abortController.abort()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('start')
|
||||
|
||||
expect(result.logs).toBeDefined()
|
||||
expect(result.logs.length).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Cancellation flag behavior', () => {
|
||||
it('should set cancelledFlag when abort signal fires', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const nodes = Array.from({ length: 3 }, (_, i) => createMockNode(`node${i}`, 'function'))
|
||||
for (let i = 0; i < nodes.length - 1; i++) {
|
||||
nodes[i].outgoingEdges.set(`e${i}`, { target: `node${i + 1}` })
|
||||
}
|
||||
|
||||
const dag = createMockDAG(nodes)
|
||||
const context = createMockContext({ abortSignal: abortController.signal })
|
||||
const edgeManager = createMockEdgeManager((node) => {
|
||||
if (node.id === 'node0') {
|
||||
abortController.abort()
|
||||
return ['node1']
|
||||
}
|
||||
return node.id === 'node1' ? ['node2'] : []
|
||||
})
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
const result = await engine.run('node0')
|
||||
|
||||
expect(result.status).toBe('cancelled')
|
||||
})
|
||||
|
||||
it('should cache Redis cancellation result', async () => {
|
||||
;(isRedisCancellationEnabled as Mock).mockReturnValue(true)
|
||||
;(isExecutionCancelled as Mock).mockResolvedValue(true)
|
||||
|
||||
const nodes = Array.from({ length: 5 }, (_, i) => createMockNode(`node${i}`, 'function'))
|
||||
const dag = createMockDAG(nodes)
|
||||
const context = createMockContext()
|
||||
const edgeManager = createMockEdgeManager()
|
||||
const nodeOrchestrator = createMockNodeOrchestrator()
|
||||
|
||||
const engine = new ExecutionEngine(context, dag, edgeManager, nodeOrchestrator)
|
||||
await engine.run('node0')
|
||||
|
||||
expect((isExecutionCancelled as Mock).mock.calls.length).toBeLessThanOrEqual(3)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -28,6 +28,8 @@ export class ExecutionEngine {
|
||||
private lastCancellationCheck = 0
|
||||
private readonly useRedisCancellation: boolean
|
||||
private readonly CANCELLATION_CHECK_INTERVAL_MS = 500
|
||||
private abortPromise: Promise<void> | null = null
|
||||
private abortResolve: (() => void) | null = null
|
||||
|
||||
constructor(
|
||||
private context: ExecutionContext,
|
||||
@@ -37,6 +39,34 @@ export class ExecutionEngine {
|
||||
) {
|
||||
this.allowResumeTriggers = this.context.metadata.resumeFromSnapshot === true
|
||||
this.useRedisCancellation = isRedisCancellationEnabled() && !!this.context.executionId
|
||||
this.initializeAbortHandler()
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up a single abort promise that can be reused throughout execution.
|
||||
* This avoids creating multiple event listeners and potential memory leaks.
|
||||
*/
|
||||
private initializeAbortHandler(): void {
|
||||
if (!this.context.abortSignal) return
|
||||
|
||||
if (this.context.abortSignal.aborted) {
|
||||
this.cancelledFlag = true
|
||||
this.abortPromise = Promise.resolve()
|
||||
return
|
||||
}
|
||||
|
||||
this.abortPromise = new Promise<void>((resolve) => {
|
||||
this.abortResolve = resolve
|
||||
})
|
||||
|
||||
this.context.abortSignal.addEventListener(
|
||||
'abort',
|
||||
() => {
|
||||
this.cancelledFlag = true
|
||||
this.abortResolve?.()
|
||||
},
|
||||
{ once: true }
|
||||
)
|
||||
}
|
||||
|
||||
private async checkCancellation(): Promise<boolean> {
|
||||
@@ -73,12 +103,15 @@ export class ExecutionEngine {
|
||||
this.initializeQueue(triggerBlockId)
|
||||
|
||||
while (this.hasWork()) {
|
||||
if ((await this.checkCancellation()) && this.executing.size === 0) {
|
||||
if (await this.checkCancellation()) {
|
||||
break
|
||||
}
|
||||
await this.processQueue()
|
||||
}
|
||||
await this.waitForAllExecutions()
|
||||
|
||||
if (!this.cancelledFlag) {
|
||||
await this.waitForAllExecutions()
|
||||
}
|
||||
|
||||
if (this.pausedBlocks.size > 0) {
|
||||
return this.buildPausedResult(startTime)
|
||||
@@ -164,11 +197,7 @@ export class ExecutionEngine {
|
||||
|
||||
private trackExecution(promise: Promise<void>): void {
|
||||
this.executing.add(promise)
|
||||
// Attach error handler to prevent unhandled rejection warnings
|
||||
// The actual error handling happens in waitForAllExecutions/waitForAnyExecution
|
||||
promise.catch(() => {
|
||||
// Error will be properly handled by Promise.all/Promise.race in wait methods
|
||||
})
|
||||
promise.catch(() => {})
|
||||
promise.finally(() => {
|
||||
this.executing.delete(promise)
|
||||
})
|
||||
@@ -176,12 +205,30 @@ export class ExecutionEngine {
|
||||
|
||||
private async waitForAnyExecution(): Promise<void> {
|
||||
if (this.executing.size > 0) {
|
||||
await Promise.race(this.executing)
|
||||
const abortPromise = this.getAbortPromise()
|
||||
if (abortPromise) {
|
||||
await Promise.race([...this.executing, abortPromise])
|
||||
} else {
|
||||
await Promise.race(this.executing)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async waitForAllExecutions(): Promise<void> {
|
||||
await Promise.all(Array.from(this.executing))
|
||||
const abortPromise = this.getAbortPromise()
|
||||
if (abortPromise) {
|
||||
await Promise.race([Promise.all(this.executing), abortPromise])
|
||||
} else {
|
||||
await Promise.all(this.executing)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the cached abort promise. This is safe to call multiple times
|
||||
* as it reuses the same promise instance created during initialization.
|
||||
*/
|
||||
private getAbortPromise(): Promise<void> | null {
|
||||
return this.abortPromise
|
||||
}
|
||||
|
||||
private async withQueueLock<T>(fn: () => Promise<T> | T): Promise<T> {
|
||||
@@ -277,7 +324,7 @@ export class ExecutionEngine {
|
||||
this.trackExecution(promise)
|
||||
}
|
||||
|
||||
if (this.executing.size > 0) {
|
||||
if (this.executing.size > 0 && !this.cancelledFlag) {
|
||||
await this.waitForAnyExecution()
|
||||
}
|
||||
}
|
||||
@@ -336,7 +383,6 @@ export class ExecutionEngine {
|
||||
|
||||
this.addMultipleToQueue(readyNodes)
|
||||
|
||||
// Check for dynamically added nodes (e.g., from parallel expansion)
|
||||
if (this.context.pendingDynamicNodes && this.context.pendingDynamicNodes.length > 0) {
|
||||
const dynamicNodes = this.context.pendingDynamicNodes
|
||||
this.context.pendingDynamicNodes = []
|
||||
|
||||
@@ -377,10 +377,7 @@ function buildManualTriggerOutput(
|
||||
return mergeFilesIntoOutput(output, workflowInput)
|
||||
}
|
||||
|
||||
function buildIntegrationTriggerOutput(
|
||||
_finalInput: unknown,
|
||||
workflowInput: unknown
|
||||
): NormalizedBlockOutput {
|
||||
function buildIntegrationTriggerOutput(workflowInput: unknown): NormalizedBlockOutput {
|
||||
return isPlainObject(workflowInput) ? (workflowInput as NormalizedBlockOutput) : {}
|
||||
}
|
||||
|
||||
@@ -430,7 +427,7 @@ export function buildStartBlockOutput(options: StartBlockOutputOptions): Normali
|
||||
return buildManualTriggerOutput(finalInput, workflowInput)
|
||||
|
||||
case StartBlockPath.EXTERNAL_TRIGGER:
|
||||
return buildIntegrationTriggerOutput(finalInput, workflowInput)
|
||||
return buildIntegrationTriggerOutput(workflowInput)
|
||||
|
||||
case StartBlockPath.LEGACY_STARTER:
|
||||
return buildLegacyStarterOutput(
|
||||
|
||||
@@ -897,6 +897,17 @@ export function useCollaborativeWorkflow() {
|
||||
// Collect all edge IDs to remove
|
||||
const edgeIdsToRemove = updates.flatMap((u) => u.affectedEdges.map((e) => e.id))
|
||||
if (edgeIdsToRemove.length > 0) {
|
||||
const edgeOperationId = crypto.randomUUID()
|
||||
addToQueue({
|
||||
id: edgeOperationId,
|
||||
operation: {
|
||||
operation: EDGES_OPERATIONS.BATCH_REMOVE_EDGES,
|
||||
target: OPERATION_TARGETS.EDGES,
|
||||
payload: { ids: edgeIdsToRemove },
|
||||
},
|
||||
workflowId: activeWorkflowId || '',
|
||||
userId: session?.user?.id || 'unknown',
|
||||
})
|
||||
useWorkflowStore.getState().batchRemoveEdges(edgeIdsToRemove)
|
||||
}
|
||||
|
||||
|
||||
@@ -98,6 +98,8 @@ export interface ApiResponse {
|
||||
*/
|
||||
export interface StreamingResponse extends ApiResponse {
|
||||
stream?: ReadableStream
|
||||
streamId?: string
|
||||
chatId?: string
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -163,9 +165,15 @@ export async function sendStreamingMessage(
|
||||
}
|
||||
}
|
||||
|
||||
// Extract stream and chat IDs from headers for resumption support
|
||||
const streamId = response.headers.get('X-Stream-Id') || undefined
|
||||
const chatId = response.headers.get('X-Chat-Id') || undefined
|
||||
|
||||
return {
|
||||
success: true,
|
||||
stream: response.body,
|
||||
streamId,
|
||||
chatId,
|
||||
}
|
||||
} catch (error) {
|
||||
// Handle AbortError gracefully - this is expected when user aborts
|
||||
|
||||
324
apps/sim/lib/copilot/render-events.ts
Normal file
324
apps/sim/lib/copilot/render-events.ts
Normal file
@@ -0,0 +1,324 @@
|
||||
/**
|
||||
* Render events are the normalized event types sent to clients.
|
||||
* These are independent of the sim agent's internal event format.
|
||||
*/
|
||||
|
||||
export type RenderEventType =
|
||||
| 'text_delta'
|
||||
| 'text_complete'
|
||||
| 'tool_pending'
|
||||
| 'tool_executing'
|
||||
| 'tool_success'
|
||||
| 'tool_error'
|
||||
| 'tool_result'
|
||||
| 'subagent_start'
|
||||
| 'subagent_text'
|
||||
| 'subagent_tool_call'
|
||||
| 'subagent_end'
|
||||
| 'thinking_start'
|
||||
| 'thinking_delta'
|
||||
| 'thinking_end'
|
||||
| 'message_start'
|
||||
| 'message_complete'
|
||||
| 'chat_id'
|
||||
| 'conversation_id'
|
||||
| 'error'
|
||||
| 'stream_status'
|
||||
|
||||
export interface BaseRenderEvent {
|
||||
type: RenderEventType
|
||||
timestamp?: number
|
||||
}
|
||||
|
||||
export interface TextDeltaEvent extends BaseRenderEvent {
|
||||
type: 'text_delta'
|
||||
content: string
|
||||
}
|
||||
|
||||
export interface TextCompleteEvent extends BaseRenderEvent {
|
||||
type: 'text_complete'
|
||||
content: string
|
||||
}
|
||||
|
||||
export interface ToolPendingEvent extends BaseRenderEvent {
|
||||
type: 'tool_pending'
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args?: Record<string, unknown>
|
||||
display?: {
|
||||
label: string
|
||||
icon?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface ToolExecutingEvent extends BaseRenderEvent {
|
||||
type: 'tool_executing'
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
}
|
||||
|
||||
export interface ToolSuccessEvent extends BaseRenderEvent {
|
||||
type: 'tool_success'
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
result?: unknown
|
||||
display?: {
|
||||
label: string
|
||||
icon?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface ToolErrorEvent extends BaseRenderEvent {
|
||||
type: 'tool_error'
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
error: string
|
||||
display?: {
|
||||
label: string
|
||||
icon?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface ToolResultEvent extends BaseRenderEvent {
|
||||
type: 'tool_result'
|
||||
toolCallId: string
|
||||
success: boolean
|
||||
result?: unknown
|
||||
error?: string
|
||||
failedDependency?: boolean
|
||||
skipped?: boolean
|
||||
}
|
||||
|
||||
export interface SubagentStartEvent extends BaseRenderEvent {
|
||||
type: 'subagent_start'
|
||||
parentToolCallId: string
|
||||
subagentName: string
|
||||
}
|
||||
|
||||
export interface SubagentTextEvent extends BaseRenderEvent {
|
||||
type: 'subagent_text'
|
||||
parentToolCallId: string
|
||||
content: string
|
||||
}
|
||||
|
||||
export interface SubagentToolCallEvent extends BaseRenderEvent {
|
||||
type: 'subagent_tool_call'
|
||||
parentToolCallId: string
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args?: Record<string, unknown>
|
||||
state: 'pending' | 'executing' | 'success' | 'error'
|
||||
result?: unknown
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface SubagentEndEvent extends BaseRenderEvent {
|
||||
type: 'subagent_end'
|
||||
parentToolCallId: string
|
||||
}
|
||||
|
||||
export interface ThinkingStartEvent extends BaseRenderEvent {
|
||||
type: 'thinking_start'
|
||||
}
|
||||
|
||||
export interface ThinkingDeltaEvent extends BaseRenderEvent {
|
||||
type: 'thinking_delta'
|
||||
content: string
|
||||
}
|
||||
|
||||
export interface ThinkingEndEvent extends BaseRenderEvent {
|
||||
type: 'thinking_end'
|
||||
}
|
||||
|
||||
export interface MessageStartEvent extends BaseRenderEvent {
|
||||
type: 'message_start'
|
||||
messageId: string
|
||||
}
|
||||
|
||||
export interface MessageCompleteEvent extends BaseRenderEvent {
|
||||
type: 'message_complete'
|
||||
messageId: string
|
||||
content?: string
|
||||
}
|
||||
|
||||
export interface ChatIdEvent extends BaseRenderEvent {
|
||||
type: 'chat_id'
|
||||
chatId: string
|
||||
}
|
||||
|
||||
export interface ConversationIdEvent extends BaseRenderEvent {
|
||||
type: 'conversation_id'
|
||||
conversationId: string
|
||||
}
|
||||
|
||||
export interface ErrorEvent extends BaseRenderEvent {
|
||||
type: 'error'
|
||||
error: string
|
||||
code?: string
|
||||
}
|
||||
|
||||
export interface StreamStatusEvent extends BaseRenderEvent {
|
||||
type: 'stream_status'
|
||||
status: 'streaming' | 'complete' | 'error' | 'aborted'
|
||||
error?: string
|
||||
}
|
||||
|
||||
export type RenderEvent =
|
||||
| TextDeltaEvent
|
||||
| TextCompleteEvent
|
||||
| ToolPendingEvent
|
||||
| ToolExecutingEvent
|
||||
| ToolSuccessEvent
|
||||
| ToolErrorEvent
|
||||
| ToolResultEvent
|
||||
| SubagentStartEvent
|
||||
| SubagentTextEvent
|
||||
| SubagentToolCallEvent
|
||||
| SubagentEndEvent
|
||||
| ThinkingStartEvent
|
||||
| ThinkingDeltaEvent
|
||||
| ThinkingEndEvent
|
||||
| MessageStartEvent
|
||||
| MessageCompleteEvent
|
||||
| ChatIdEvent
|
||||
| ConversationIdEvent
|
||||
| ErrorEvent
|
||||
| StreamStatusEvent
|
||||
|
||||
/**
|
||||
* Serialize a render event to SSE format
|
||||
*/
|
||||
export function serializeRenderEvent(event: RenderEvent): string {
|
||||
const eventWithTimestamp = {
|
||||
...event,
|
||||
timestamp: event.timestamp || Date.now(),
|
||||
}
|
||||
return `event: ${event.type}\ndata: ${JSON.stringify(eventWithTimestamp)}\n\n`
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse an SSE chunk into a render event
|
||||
*/
|
||||
export function parseRenderEvent(chunk: string): RenderEvent | null {
|
||||
// SSE format: "event: <type>\ndata: <json>\n\n"
|
||||
const lines = chunk.trim().split('\n')
|
||||
let eventType: string | null = null
|
||||
let data: string | null = null
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('event: ')) {
|
||||
eventType = line.slice(7)
|
||||
} else if (line.startsWith('data: ')) {
|
||||
data = line.slice(6)
|
||||
}
|
||||
}
|
||||
|
||||
if (!data) return null
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data)
|
||||
// If we extracted an event type from SSE, use it; otherwise use from data
|
||||
if (eventType && !parsed.type) {
|
||||
parsed.type = eventType
|
||||
}
|
||||
return parsed as RenderEvent
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a text delta event
|
||||
*/
|
||||
export function createTextDelta(content: string): TextDeltaEvent {
|
||||
return { type: 'text_delta', content, timestamp: Date.now() }
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a tool pending event
|
||||
*/
|
||||
export function createToolPending(
|
||||
toolCallId: string,
|
||||
toolName: string,
|
||||
args?: Record<string, unknown>,
|
||||
display?: { label: string; icon?: string }
|
||||
): ToolPendingEvent {
|
||||
return {
|
||||
type: 'tool_pending',
|
||||
toolCallId,
|
||||
toolName,
|
||||
args,
|
||||
display,
|
||||
timestamp: Date.now(),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a tool executing event
|
||||
*/
|
||||
export function createToolExecuting(toolCallId: string, toolName: string): ToolExecutingEvent {
|
||||
return { type: 'tool_executing', toolCallId, toolName, timestamp: Date.now() }
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a tool success event
|
||||
*/
|
||||
export function createToolSuccess(
|
||||
toolCallId: string,
|
||||
toolName: string,
|
||||
result?: unknown,
|
||||
display?: { label: string; icon?: string }
|
||||
): ToolSuccessEvent {
|
||||
return {
|
||||
type: 'tool_success',
|
||||
toolCallId,
|
||||
toolName,
|
||||
result,
|
||||
display,
|
||||
timestamp: Date.now(),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a tool error event
|
||||
*/
|
||||
export function createToolError(
|
||||
toolCallId: string,
|
||||
toolName: string,
|
||||
error: string,
|
||||
display?: { label: string; icon?: string }
|
||||
): ToolErrorEvent {
|
||||
return {
|
||||
type: 'tool_error',
|
||||
toolCallId,
|
||||
toolName,
|
||||
error,
|
||||
display,
|
||||
timestamp: Date.now(),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a message complete event
|
||||
*/
|
||||
export function createMessageComplete(messageId: string, content?: string): MessageCompleteEvent {
|
||||
return { type: 'message_complete', messageId, content, timestamp: Date.now() }
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a stream status event
|
||||
*/
|
||||
export function createStreamStatus(
|
||||
status: 'streaming' | 'complete' | 'error' | 'aborted',
|
||||
error?: string
|
||||
): StreamStatusEvent {
|
||||
return { type: 'stream_status', status, error, timestamp: Date.now() }
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an error event
|
||||
*/
|
||||
export function createError(error: string, code?: string): ErrorEvent {
|
||||
return { type: 'error', error, code, timestamp: Date.now() }
|
||||
}
|
||||
|
||||
299
apps/sim/lib/copilot/stream-client.ts
Normal file
299
apps/sim/lib/copilot/stream-client.ts
Normal file
@@ -0,0 +1,299 @@
|
||||
'use client'
|
||||
|
||||
import { createLogger } from '@sim/logger'
|
||||
|
||||
const logger = createLogger('StreamClient')
|
||||
|
||||
export interface StreamMetadata {
|
||||
streamId: string
|
||||
chatId: string
|
||||
userId: string
|
||||
workflowId: string
|
||||
userMessageId: string
|
||||
assistantMessageId?: string
|
||||
status: 'pending' | 'streaming' | 'complete' | 'error' | 'aborted'
|
||||
isClientSession: boolean
|
||||
createdAt: number
|
||||
updatedAt: number
|
||||
completedAt?: number
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface StreamResumeResponse {
|
||||
metadata: StreamMetadata
|
||||
events: string[]
|
||||
toolCalls: Record<string, unknown>
|
||||
totalEvents: number
|
||||
nextOffset: number
|
||||
}
|
||||
|
||||
const STREAM_ID_STORAGE_KEY = 'copilot:activeStream'
|
||||
const RECONNECT_DELAY_MS = 1000
|
||||
const MAX_RECONNECT_ATTEMPTS = 5
|
||||
|
||||
/**
|
||||
* Store active stream info for potential resumption
|
||||
*/
|
||||
export function storeActiveStream(
|
||||
chatId: string,
|
||||
streamId: string,
|
||||
messageId: string
|
||||
): void {
|
||||
try {
|
||||
const data = { chatId, streamId, messageId, storedAt: Date.now() }
|
||||
sessionStorage.setItem(STREAM_ID_STORAGE_KEY, JSON.stringify(data))
|
||||
logger.info('Stored active stream for potential resumption', { streamId, chatId })
|
||||
} catch {
|
||||
// Session storage not available
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get stored active stream if one exists
|
||||
*/
|
||||
export function getStoredActiveStream(): {
|
||||
chatId: string
|
||||
streamId: string
|
||||
messageId: string
|
||||
storedAt: number
|
||||
} | null {
|
||||
try {
|
||||
const data = sessionStorage.getItem(STREAM_ID_STORAGE_KEY)
|
||||
if (!data) return null
|
||||
return JSON.parse(data)
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear stored active stream
|
||||
*/
|
||||
export function clearStoredActiveStream(): void {
|
||||
try {
|
||||
sessionStorage.removeItem(STREAM_ID_STORAGE_KEY)
|
||||
} catch {
|
||||
// Session storage not available
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a stream is still active
|
||||
*/
|
||||
export async function checkStreamStatus(streamId: string): Promise<StreamMetadata | null> {
|
||||
try {
|
||||
const response = await fetch(`/api/copilot/stream/${streamId}?mode=poll&offset=0`)
|
||||
if (!response.ok) {
|
||||
if (response.status === 404) {
|
||||
// Stream not found or expired
|
||||
return null
|
||||
}
|
||||
throw new Error(`Failed to check stream status: ${response.statusText}`)
|
||||
}
|
||||
const data: StreamResumeResponse = await response.json()
|
||||
return data.metadata
|
||||
} catch (error) {
|
||||
logger.error('Failed to check stream status', { streamId, error })
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resume a stream from a given offset using SSE
|
||||
*/
|
||||
export async function resumeStream(
|
||||
streamId: string,
|
||||
offset: number = 0
|
||||
): Promise<ReadableStream<Uint8Array> | null> {
|
||||
try {
|
||||
const response = await fetch(`/api/copilot/stream/${streamId}?mode=sse&offset=${offset}`)
|
||||
if (!response.ok || !response.body) {
|
||||
if (response.status === 404) {
|
||||
logger.info('Stream not found for resumption', { streamId })
|
||||
clearStoredActiveStream()
|
||||
return null
|
||||
}
|
||||
throw new Error(`Failed to resume stream: ${response.statusText}`)
|
||||
}
|
||||
|
||||
logger.info('Stream resumption started', { streamId, offset })
|
||||
return response.body
|
||||
} catch (error) {
|
||||
logger.error('Failed to resume stream', { streamId, error })
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Abort a stream
|
||||
*/
|
||||
export async function abortStream(streamId: string): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch(`/api/copilot/stream/${streamId}`, {
|
||||
method: 'DELETE',
|
||||
})
|
||||
if (!response.ok && response.status !== 404) {
|
||||
throw new Error(`Failed to abort stream: ${response.statusText}`)
|
||||
}
|
||||
clearStoredActiveStream()
|
||||
return true
|
||||
} catch (error) {
|
||||
logger.error('Failed to abort stream', { streamId, error })
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export interface StreamSubscription {
|
||||
unsubscribe: () => void
|
||||
getStreamId: () => string
|
||||
}
|
||||
|
||||
export interface StreamEventHandler {
|
||||
onEvent: (event: { type: string; data: Record<string, unknown> }) => void
|
||||
onError?: (error: Error) => void
|
||||
onComplete?: () => void
|
||||
}
|
||||
|
||||
/**
|
||||
* Subscribe to a stream (new or resumed) and process events
|
||||
* This provides a unified interface for both initial streams and resumed streams
|
||||
*/
|
||||
export function subscribeToStream(
|
||||
streamBody: ReadableStream<Uint8Array>,
|
||||
handlers: StreamEventHandler
|
||||
): StreamSubscription {
|
||||
const reader = streamBody.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let cancelled = false
|
||||
let buffer = ''
|
||||
let streamId = ''
|
||||
|
||||
const processEvents = async () => {
|
||||
try {
|
||||
while (!cancelled) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done || cancelled) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
|
||||
// Process complete SSE messages
|
||||
const messages = buffer.split('\n\n')
|
||||
buffer = messages.pop() || ''
|
||||
|
||||
for (const message of messages) {
|
||||
if (!message.trim()) continue
|
||||
if (message.startsWith(':')) continue // SSE comment (ping)
|
||||
|
||||
// Parse SSE format
|
||||
const lines = message.split('\n')
|
||||
let eventType = 'message'
|
||||
let data: Record<string, unknown> = {}
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('event: ')) {
|
||||
eventType = line.slice(7)
|
||||
} else if (line.startsWith('data: ')) {
|
||||
try {
|
||||
data = JSON.parse(line.slice(6))
|
||||
} catch {
|
||||
data = { raw: line.slice(6) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track stream ID if provided in metadata
|
||||
if (eventType === 'metadata' && data.streamId) {
|
||||
streamId = data.streamId as string
|
||||
}
|
||||
|
||||
handlers.onEvent({ type: eventType, data })
|
||||
|
||||
// Check for terminal events
|
||||
if (eventType === 'stream_status') {
|
||||
const status = data.status as string
|
||||
if (status === 'complete' || status === 'error' || status === 'aborted') {
|
||||
if (status === 'error' && handlers.onError) {
|
||||
handlers.onError(new Error(data.error as string || 'Stream error'))
|
||||
}
|
||||
if (handlers.onComplete) {
|
||||
handlers.onComplete()
|
||||
}
|
||||
clearStoredActiveStream()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream ended without explicit status
|
||||
if (handlers.onComplete) {
|
||||
handlers.onComplete()
|
||||
}
|
||||
} catch (error) {
|
||||
if (!cancelled && handlers.onError) {
|
||||
handlers.onError(error instanceof Error ? error : new Error(String(error)))
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
// Start processing
|
||||
processEvents()
|
||||
|
||||
return {
|
||||
unsubscribe: () => {
|
||||
cancelled = true
|
||||
reader.cancel().catch(() => {})
|
||||
clearStoredActiveStream()
|
||||
},
|
||||
getStreamId: () => streamId,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempt to resume any active stream from session storage
|
||||
* Returns handlers if resumption is possible, null otherwise
|
||||
*/
|
||||
export async function attemptStreamResumption(): Promise<{
|
||||
stream: ReadableStream<Uint8Array>
|
||||
metadata: StreamMetadata
|
||||
offset: number
|
||||
} | null> {
|
||||
const stored = getStoredActiveStream()
|
||||
if (!stored) return null
|
||||
|
||||
// Check if stream is still valid (not too old)
|
||||
const maxAge = 5 * 60 * 1000 // 5 minutes
|
||||
if (Date.now() - stored.storedAt > maxAge) {
|
||||
clearStoredActiveStream()
|
||||
return null
|
||||
}
|
||||
|
||||
// Check stream status
|
||||
const metadata = await checkStreamStatus(stored.streamId)
|
||||
if (!metadata) {
|
||||
clearStoredActiveStream()
|
||||
return null
|
||||
}
|
||||
|
||||
// Only resume if stream is still active
|
||||
if (metadata.status !== 'streaming' && metadata.status !== 'pending') {
|
||||
clearStoredActiveStream()
|
||||
return null
|
||||
}
|
||||
|
||||
// Get the stream
|
||||
const stream = await resumeStream(stored.streamId, 0)
|
||||
if (!stream) {
|
||||
return null
|
||||
}
|
||||
|
||||
logger.info('Stream resumption possible', {
|
||||
streamId: stored.streamId,
|
||||
status: metadata.status,
|
||||
})
|
||||
|
||||
return { stream, metadata, offset: 0 }
|
||||
}
|
||||
|
||||
327
apps/sim/lib/copilot/stream-persistence.ts
Normal file
327
apps/sim/lib/copilot/stream-persistence.ts
Normal file
@@ -0,0 +1,327 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { getRedisClient } from '@/lib/core/config/redis'
|
||||
|
||||
const logger = createLogger('StreamPersistence')
|
||||
|
||||
const STREAM_PREFIX = 'copilot:stream:'
|
||||
const STREAM_TTL = 60 * 60 * 24 // 24 hours
|
||||
|
||||
export type StreamStatus = 'pending' | 'streaming' | 'complete' | 'error' | 'aborted'
|
||||
|
||||
export interface StreamMetadata {
|
||||
streamId: string
|
||||
chatId: string
|
||||
userId: string
|
||||
workflowId: string
|
||||
userMessageId: string
|
||||
assistantMessageId?: string
|
||||
status: StreamStatus
|
||||
isClientSession: boolean
|
||||
createdAt: number
|
||||
updatedAt: number
|
||||
completedAt?: number
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface ToolCallState {
|
||||
id: string
|
||||
name: string
|
||||
args: Record<string, unknown>
|
||||
state: 'pending' | 'executing' | 'success' | 'error'
|
||||
result?: unknown
|
||||
error?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize a new stream in Redis
|
||||
*/
|
||||
export async function createStream(params: {
|
||||
streamId: string
|
||||
chatId: string
|
||||
userId: string
|
||||
workflowId: string
|
||||
userMessageId: string
|
||||
isClientSession: boolean
|
||||
}): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) {
|
||||
logger.warn('Redis not available, stream will not be resumable')
|
||||
return
|
||||
}
|
||||
|
||||
const metadata: StreamMetadata = {
|
||||
...params,
|
||||
status: 'pending',
|
||||
createdAt: Date.now(),
|
||||
updatedAt: Date.now(),
|
||||
}
|
||||
|
||||
const key = `${STREAM_PREFIX}${params.streamId}:meta`
|
||||
await redis.set(key, JSON.stringify(metadata), 'EX', STREAM_TTL)
|
||||
|
||||
logger.info('Stream created', { streamId: params.streamId })
|
||||
}
|
||||
|
||||
/**
|
||||
* Update stream status
|
||||
*/
|
||||
export async function updateStreamStatus(
|
||||
streamId: string,
|
||||
status: StreamStatus,
|
||||
error?: string
|
||||
): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:meta`
|
||||
const data = await redis.get(key)
|
||||
if (!data) return
|
||||
|
||||
const metadata: StreamMetadata = JSON.parse(data)
|
||||
metadata.status = status
|
||||
metadata.updatedAt = Date.now()
|
||||
if (status === 'complete' || status === 'error') {
|
||||
metadata.completedAt = Date.now()
|
||||
}
|
||||
if (error) {
|
||||
metadata.error = error
|
||||
}
|
||||
|
||||
await redis.set(key, JSON.stringify(metadata), 'EX', STREAM_TTL)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update stream metadata with additional fields
|
||||
*/
|
||||
export async function updateStreamMetadata(
|
||||
streamId: string,
|
||||
updates: Partial<StreamMetadata>
|
||||
): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:meta`
|
||||
const data = await redis.get(key)
|
||||
if (!data) return
|
||||
|
||||
const metadata: StreamMetadata = JSON.parse(data)
|
||||
Object.assign(metadata, updates, { updatedAt: Date.now() })
|
||||
|
||||
await redis.set(key, JSON.stringify(metadata), 'EX', STREAM_TTL)
|
||||
}
|
||||
|
||||
/**
|
||||
* Append a serialized SSE event chunk to the stream
|
||||
*/
|
||||
export async function appendChunk(streamId: string, chunk: string): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:events`
|
||||
await redis.rpush(key, chunk)
|
||||
await redis.expire(key, STREAM_TTL)
|
||||
}
|
||||
|
||||
/**
|
||||
* Append text content (for quick content retrieval without parsing events)
|
||||
*/
|
||||
export async function appendContent(streamId: string, content: string): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:content`
|
||||
await redis.append(key, content)
|
||||
await redis.expire(key, STREAM_TTL)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update tool call state
|
||||
*/
|
||||
export async function updateToolCall(
|
||||
streamId: string,
|
||||
toolCallId: string,
|
||||
update: Partial<ToolCallState>
|
||||
): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:tools`
|
||||
const existing = await redis.hget(key, toolCallId)
|
||||
const current: ToolCallState = existing
|
||||
? JSON.parse(existing)
|
||||
: { id: toolCallId, name: '', args: {}, state: 'pending' }
|
||||
|
||||
const updated = { ...current, ...update }
|
||||
await redis.hset(key, toolCallId, JSON.stringify(updated))
|
||||
await redis.expire(key, STREAM_TTL)
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark stream as complete
|
||||
*/
|
||||
export async function completeStream(streamId: string, result?: unknown): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return
|
||||
|
||||
await updateStreamStatus(streamId, 'complete')
|
||||
|
||||
if (result !== undefined) {
|
||||
const key = `${STREAM_PREFIX}${streamId}:result`
|
||||
await redis.set(key, JSON.stringify(result), 'EX', STREAM_TTL)
|
||||
}
|
||||
|
||||
logger.info('Stream completed', { streamId })
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark stream as errored
|
||||
*/
|
||||
export async function errorStream(streamId: string, error: string): Promise<void> {
|
||||
await updateStreamStatus(streamId, 'error', error)
|
||||
logger.error('Stream errored', { streamId, error })
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if stream was aborted (client requested abort)
|
||||
*/
|
||||
export async function checkAbortSignal(streamId: string): Promise<boolean> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return false
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:abort`
|
||||
const aborted = await redis.exists(key)
|
||||
return aborted === 1
|
||||
}
|
||||
|
||||
/**
|
||||
* Signal stream abort
|
||||
*/
|
||||
export async function abortStream(streamId: string): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return
|
||||
|
||||
await redis.set(`${STREAM_PREFIX}${streamId}:abort`, '1', 'EX', STREAM_TTL)
|
||||
await updateStreamStatus(streamId, 'aborted')
|
||||
logger.info('Stream aborted', { streamId })
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh TTL on all stream keys
|
||||
*/
|
||||
export async function refreshStreamTTL(streamId: string): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return
|
||||
|
||||
const keys = [
|
||||
`${STREAM_PREFIX}${streamId}:meta`,
|
||||
`${STREAM_PREFIX}${streamId}:events`,
|
||||
`${STREAM_PREFIX}${streamId}:content`,
|
||||
`${STREAM_PREFIX}${streamId}:tools`,
|
||||
`${STREAM_PREFIX}${streamId}:result`,
|
||||
]
|
||||
|
||||
for (const key of keys) {
|
||||
await redis.expire(key, STREAM_TTL)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get stream metadata
|
||||
*/
|
||||
export async function getStreamMetadata(streamId: string): Promise<StreamMetadata | null> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return null
|
||||
|
||||
const data = await redis.get(`${STREAM_PREFIX}${streamId}:meta`)
|
||||
return data ? JSON.parse(data) : null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get stream events from offset (for resumption)
|
||||
*/
|
||||
export async function getStreamEvents(streamId: string, fromOffset: number = 0): Promise<string[]> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return []
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:events`
|
||||
return redis.lrange(key, fromOffset, -1)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get current event count (for client to know where it is)
|
||||
*/
|
||||
export async function getStreamEventCount(streamId: string): Promise<number> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return 0
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:events`
|
||||
return redis.llen(key)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all tool call states
|
||||
*/
|
||||
export async function getToolCallStates(streamId: string): Promise<Record<string, ToolCallState>> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return {}
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:tools`
|
||||
const data = await redis.hgetall(key)
|
||||
|
||||
const result: Record<string, ToolCallState> = {}
|
||||
for (const [id, json] of Object.entries(data)) {
|
||||
result[id] = JSON.parse(json)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Get accumulated content
|
||||
*/
|
||||
export async function getStreamContent(streamId: string): Promise<string> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return ''
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:content`
|
||||
return (await redis.get(key)) || ''
|
||||
}
|
||||
|
||||
/**
|
||||
* Get final result (if complete)
|
||||
*/
|
||||
export async function getStreamResult(streamId: string): Promise<unknown | null> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return null
|
||||
|
||||
const key = `${STREAM_PREFIX}${streamId}:result`
|
||||
const data = await redis.get(key)
|
||||
return data ? JSON.parse(data) : null
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if Redis is available for stream persistence
|
||||
*/
|
||||
export function isStreamPersistenceEnabled(): boolean {
|
||||
return getRedisClient() !== null
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete all stream data (cleanup)
|
||||
*/
|
||||
export async function deleteStream(streamId: string): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) return
|
||||
|
||||
const keys = [
|
||||
`${STREAM_PREFIX}${streamId}:meta`,
|
||||
`${STREAM_PREFIX}${streamId}:events`,
|
||||
`${STREAM_PREFIX}${streamId}:content`,
|
||||
`${STREAM_PREFIX}${streamId}:tools`,
|
||||
`${STREAM_PREFIX}${streamId}:result`,
|
||||
`${STREAM_PREFIX}${streamId}:abort`,
|
||||
]
|
||||
|
||||
await redis.del(...keys)
|
||||
logger.info('Stream deleted', { streamId })
|
||||
}
|
||||
|
||||
419
apps/sim/lib/copilot/stream-transformer.ts
Normal file
419
apps/sim/lib/copilot/stream-transformer.ts
Normal file
@@ -0,0 +1,419 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import type { RenderEvent } from './render-events'
|
||||
|
||||
const logger = createLogger('StreamTransformer')
|
||||
|
||||
export interface TransformStreamContext {
|
||||
streamId: string
|
||||
chatId: string
|
||||
userId: string
|
||||
workflowId: string
|
||||
userMessageId: string
|
||||
assistantMessageId: string
|
||||
|
||||
/** Callback for each render event - handles both client delivery and persistence */
|
||||
onRenderEvent: (event: RenderEvent) => Promise<void>
|
||||
|
||||
/** Callback for persistence operations */
|
||||
onPersist?: (data: { type: string; [key: string]: unknown }) => Promise<void>
|
||||
|
||||
/** Check if stream should be aborted */
|
||||
isAborted: () => boolean | Promise<boolean>
|
||||
}
|
||||
|
||||
interface SimAgentEvent {
|
||||
type?: string
|
||||
event?: string
|
||||
data?: unknown
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform a sim agent SSE stream into normalized render events.
|
||||
* This function consumes the entire stream and emits events via callbacks.
|
||||
*/
|
||||
export async function transformStream(
|
||||
body: ReadableStream<Uint8Array>,
|
||||
context: TransformStreamContext
|
||||
): Promise<void> {
|
||||
const { onRenderEvent, onPersist, isAborted } = context
|
||||
|
||||
const reader = body.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
// Check abort signal
|
||||
const shouldAbort = await Promise.resolve(isAborted())
|
||||
if (shouldAbort) {
|
||||
logger.info('Stream aborted by signal', { streamId: context.streamId })
|
||||
break
|
||||
}
|
||||
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
|
||||
// Process complete SSE messages (separated by double newlines)
|
||||
const messages = buffer.split('\n\n')
|
||||
buffer = messages.pop() || '' // Keep incomplete message in buffer
|
||||
|
||||
for (const message of messages) {
|
||||
if (!message.trim()) continue
|
||||
|
||||
const events = parseSimAgentMessage(message)
|
||||
for (const simEvent of events) {
|
||||
const renderEvents = transformSimAgentEvent(simEvent, context)
|
||||
for (const renderEvent of renderEvents) {
|
||||
await onRenderEvent(renderEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process any remaining buffer content
|
||||
if (buffer.trim()) {
|
||||
const events = parseSimAgentMessage(buffer)
|
||||
for (const simEvent of events) {
|
||||
const renderEvents = transformSimAgentEvent(simEvent, context)
|
||||
for (const renderEvent of renderEvents) {
|
||||
await onRenderEvent(renderEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Emit message complete
|
||||
await onRenderEvent({
|
||||
type: 'message_complete',
|
||||
messageId: context.assistantMessageId,
|
||||
timestamp: Date.now(),
|
||||
})
|
||||
|
||||
// Notify persistence layer
|
||||
if (onPersist) {
|
||||
await onPersist({ type: 'message_complete', messageId: context.assistantMessageId })
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Stream transform error', { streamId: context.streamId, error })
|
||||
|
||||
await onRenderEvent({
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error.message : 'Stream processing error',
|
||||
timestamp: Date.now(),
|
||||
})
|
||||
|
||||
throw error
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a raw SSE message into sim agent events
|
||||
*/
|
||||
function parseSimAgentMessage(message: string): SimAgentEvent[] {
|
||||
const events: SimAgentEvent[] = []
|
||||
const lines = message.split('\n')
|
||||
|
||||
let currentEvent: string | null = null
|
||||
let currentData: string[] = []
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('event: ')) {
|
||||
// If we have accumulated data, emit previous event
|
||||
if (currentData.length > 0) {
|
||||
const dataStr = currentData.join('\n')
|
||||
const parsed = tryParseJson(dataStr)
|
||||
if (parsed) {
|
||||
events.push({ ...parsed, event: currentEvent || undefined })
|
||||
}
|
||||
currentData = []
|
||||
}
|
||||
currentEvent = line.slice(7)
|
||||
} else if (line.startsWith('data: ')) {
|
||||
currentData.push(line.slice(6))
|
||||
} else if (line === '' && currentData.length > 0) {
|
||||
// Empty line signals end of event
|
||||
const dataStr = currentData.join('\n')
|
||||
const parsed = tryParseJson(dataStr)
|
||||
if (parsed) {
|
||||
events.push({ ...parsed, event: currentEvent || undefined })
|
||||
}
|
||||
currentEvent = null
|
||||
currentData = []
|
||||
}
|
||||
}
|
||||
|
||||
// Handle remaining data
|
||||
if (currentData.length > 0) {
|
||||
const dataStr = currentData.join('\n')
|
||||
const parsed = tryParseJson(dataStr)
|
||||
if (parsed) {
|
||||
events.push({ ...parsed, event: currentEvent || undefined })
|
||||
}
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
function tryParseJson(str: string): Record<string, unknown> | null {
|
||||
if (str === '[DONE]') return null
|
||||
try {
|
||||
return JSON.parse(str)
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform a sim agent event into one or more render events
|
||||
*/
|
||||
function transformSimAgentEvent(
|
||||
simEvent: SimAgentEvent,
|
||||
context: TransformStreamContext
|
||||
): RenderEvent[] {
|
||||
const eventType = simEvent.type || simEvent.event
|
||||
const events: RenderEvent[] = []
|
||||
const timestamp = Date.now()
|
||||
|
||||
switch (eventType) {
|
||||
// Text content events
|
||||
case 'content_block_delta':
|
||||
case 'text_delta':
|
||||
case 'delta': {
|
||||
const delta = (simEvent.delta as Record<string, unknown>) || simEvent
|
||||
const text = (delta.text as string) || (delta.content as string) || (simEvent.text as string)
|
||||
if (text) {
|
||||
events.push({ type: 'text_delta', content: text, timestamp })
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'content_block_stop':
|
||||
case 'text_complete': {
|
||||
events.push({
|
||||
type: 'text_complete',
|
||||
content: (simEvent.content as string) || '',
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
// Tool call events
|
||||
case 'tool_call':
|
||||
case 'tool_use': {
|
||||
const data = (simEvent.data as Record<string, unknown>) || simEvent
|
||||
const toolCallId = (data.id as string) || (simEvent.id as string)
|
||||
const toolName = (data.name as string) || (simEvent.name as string)
|
||||
const args = (data.arguments as Record<string, unknown>) || (data.input as Record<string, unknown>)
|
||||
|
||||
if (toolCallId && toolName) {
|
||||
events.push({
|
||||
type: 'tool_pending',
|
||||
toolCallId,
|
||||
toolName,
|
||||
args,
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'tool_executing': {
|
||||
const toolCallId = (simEvent.toolCallId as string) || (simEvent.id as string)
|
||||
const toolName = (simEvent.toolName as string) || (simEvent.name as string) || ''
|
||||
|
||||
if (toolCallId) {
|
||||
events.push({
|
||||
type: 'tool_executing',
|
||||
toolCallId,
|
||||
toolName,
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'tool_result': {
|
||||
const toolCallId = (simEvent.toolCallId as string) || (simEvent.id as string)
|
||||
const success = simEvent.success as boolean
|
||||
const result = simEvent.result
|
||||
const error = simEvent.error as string | undefined
|
||||
|
||||
if (toolCallId) {
|
||||
events.push({
|
||||
type: 'tool_result',
|
||||
toolCallId,
|
||||
success: success !== false,
|
||||
result,
|
||||
error,
|
||||
failedDependency: simEvent.failedDependency as boolean | undefined,
|
||||
skipped: (simEvent.result as Record<string, unknown>)?.skipped as boolean | undefined,
|
||||
timestamp,
|
||||
})
|
||||
|
||||
// Also emit success/error event for UI
|
||||
if (success !== false) {
|
||||
events.push({
|
||||
type: 'tool_success',
|
||||
toolCallId,
|
||||
toolName: (simEvent.toolName as string) || '',
|
||||
result,
|
||||
timestamp,
|
||||
})
|
||||
} else {
|
||||
events.push({
|
||||
type: 'tool_error',
|
||||
toolCallId,
|
||||
toolName: (simEvent.toolName as string) || '',
|
||||
error: error || 'Tool execution failed',
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Subagent events
|
||||
case 'subagent_start': {
|
||||
events.push({
|
||||
type: 'subagent_start',
|
||||
parentToolCallId: simEvent.parentToolCallId as string,
|
||||
subagentName: simEvent.subagentName as string,
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'subagent_text':
|
||||
case 'subagent_delta': {
|
||||
events.push({
|
||||
type: 'subagent_text',
|
||||
parentToolCallId: simEvent.parentToolCallId as string,
|
||||
content: (simEvent.content as string) || (simEvent.text as string) || '',
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'subagent_tool_call': {
|
||||
events.push({
|
||||
type: 'subagent_tool_call',
|
||||
parentToolCallId: simEvent.parentToolCallId as string,
|
||||
toolCallId: simEvent.toolCallId as string,
|
||||
toolName: simEvent.toolName as string,
|
||||
args: simEvent.args as Record<string, unknown> | undefined,
|
||||
state: (simEvent.state as 'pending' | 'executing' | 'success' | 'error') || 'pending',
|
||||
result: simEvent.result,
|
||||
error: simEvent.error as string | undefined,
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'subagent_end': {
|
||||
events.push({
|
||||
type: 'subagent_end',
|
||||
parentToolCallId: simEvent.parentToolCallId as string,
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
// Thinking events (for extended thinking models)
|
||||
case 'thinking_start':
|
||||
case 'thinking': {
|
||||
if (simEvent.type === 'thinking_start' || !simEvent.content) {
|
||||
events.push({ type: 'thinking_start', timestamp })
|
||||
}
|
||||
if (simEvent.content) {
|
||||
events.push({
|
||||
type: 'thinking_delta',
|
||||
content: simEvent.content as string,
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'thinking_delta': {
|
||||
events.push({
|
||||
type: 'thinking_delta',
|
||||
content: (simEvent.content as string) || '',
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'thinking_end':
|
||||
case 'thinking_complete': {
|
||||
events.push({ type: 'thinking_end', timestamp })
|
||||
break
|
||||
}
|
||||
|
||||
// Message lifecycle events
|
||||
case 'message_start': {
|
||||
events.push({
|
||||
type: 'message_start',
|
||||
messageId: (simEvent.messageId as string) || context.assistantMessageId,
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'message_stop':
|
||||
case 'message_complete':
|
||||
case 'message_delta': {
|
||||
if (eventType === 'message_complete' || eventType === 'message_stop') {
|
||||
events.push({
|
||||
type: 'message_complete',
|
||||
messageId: (simEvent.messageId as string) || context.assistantMessageId,
|
||||
content: simEvent.content as string | undefined,
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Metadata events
|
||||
case 'chat_id': {
|
||||
events.push({
|
||||
type: 'chat_id',
|
||||
chatId: simEvent.chatId as string,
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'conversation_id': {
|
||||
events.push({
|
||||
type: 'conversation_id',
|
||||
conversationId: simEvent.conversationId as string,
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
// Error events
|
||||
case 'error': {
|
||||
events.push({
|
||||
type: 'error',
|
||||
error: (simEvent.error as string) || (simEvent.message as string) || 'Unknown error',
|
||||
code: simEvent.code as string | undefined,
|
||||
timestamp,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
default: {
|
||||
// Log unhandled event types for debugging
|
||||
if (eventType && eventType !== 'ping') {
|
||||
logger.debug('Unhandled sim agent event type', { eventType, streamId: context.streamId })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
255
apps/sim/lib/copilot/tools/server/executor.ts
Normal file
255
apps/sim/lib/copilot/tools/server/executor.ts
Normal file
@@ -0,0 +1,255 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { routeExecution } from './router'
|
||||
import { saveWorkflowToNormalizedTables } from '@/lib/workflows/persistence/utils'
|
||||
|
||||
const logger = createLogger('ServerToolExecutor')
|
||||
|
||||
export interface ServerToolContext {
|
||||
workflowId: string
|
||||
userId: string
|
||||
persistChanges?: boolean
|
||||
}
|
||||
|
||||
export interface ServerToolResult {
|
||||
success: boolean
|
||||
result?: unknown
|
||||
error?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute any copilot tool completely server-side.
|
||||
* This is the central dispatcher for headless/API operation.
|
||||
*/
|
||||
export async function executeToolServerSide(
|
||||
toolCall: { name: string; args: Record<string, unknown> },
|
||||
context: ServerToolContext
|
||||
): Promise<ServerToolResult> {
|
||||
const { name, args } = toolCall
|
||||
const { workflowId, userId, persistChanges = true } = context
|
||||
|
||||
logger.info('Executing tool server-side', { name, workflowId, userId })
|
||||
|
||||
try {
|
||||
const result = await executeToolInternal(name, args, context)
|
||||
return { success: true, result }
|
||||
} catch (error) {
|
||||
logger.error('Server-side tool execution failed', {
|
||||
name,
|
||||
workflowId,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
return {
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function executeToolInternal(
|
||||
name: string,
|
||||
args: Record<string, unknown>,
|
||||
context: ServerToolContext
|
||||
): Promise<unknown> {
|
||||
const { workflowId, userId, persistChanges = true } = context
|
||||
|
||||
switch (name) {
|
||||
case 'edit_workflow': {
|
||||
// Execute edit_workflow with direct persistence
|
||||
const result = await routeExecution(
|
||||
'edit_workflow',
|
||||
{
|
||||
...args,
|
||||
workflowId,
|
||||
// Don't require currentUserWorkflow - server tool will load from DB
|
||||
},
|
||||
{ userId }
|
||||
)
|
||||
|
||||
// Persist directly to database if enabled
|
||||
if (persistChanges && result.workflowState) {
|
||||
try {
|
||||
await saveWorkflowToNormalizedTables(workflowId, result.workflowState)
|
||||
logger.info('Workflow changes persisted directly', { workflowId })
|
||||
} catch (error) {
|
||||
logger.error('Failed to persist workflow changes', { error, workflowId })
|
||||
// Don't throw - return the result anyway
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
case 'run_workflow': {
|
||||
// Import dynamically to avoid circular dependencies
|
||||
const { executeWorkflow } = await import('@/lib/workflows/executor/execute-workflow')
|
||||
|
||||
const result = await executeWorkflow({
|
||||
workflowId,
|
||||
input: (args.workflow_input as Record<string, unknown>) || {},
|
||||
isClientSession: false,
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
case 'deploy_api':
|
||||
case 'deploy_chat':
|
||||
case 'deploy_mcp': {
|
||||
// Import dynamically
|
||||
const { deployWorkflow } = await import('@/lib/workflows/persistence/utils')
|
||||
|
||||
const deployType = name.replace('deploy_', '')
|
||||
const result = await deployWorkflow({
|
||||
workflowId,
|
||||
deployedBy: userId,
|
||||
})
|
||||
|
||||
return { ...result, deployType }
|
||||
}
|
||||
|
||||
case 'redeploy': {
|
||||
const { deployWorkflow } = await import('@/lib/workflows/persistence/utils')
|
||||
|
||||
const result = await deployWorkflow({
|
||||
workflowId,
|
||||
deployedBy: userId,
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Server tools that already exist in the router
|
||||
case 'get_blocks_and_tools':
|
||||
case 'get_blocks_metadata':
|
||||
case 'get_block_options':
|
||||
case 'get_block_config':
|
||||
case 'get_trigger_blocks':
|
||||
case 'get_workflow_console':
|
||||
case 'search_documentation':
|
||||
case 'search_online':
|
||||
case 'set_environment_variables':
|
||||
case 'get_credentials':
|
||||
case 'make_api_request':
|
||||
case 'knowledge_base': {
|
||||
return routeExecution(name, args, { userId })
|
||||
}
|
||||
|
||||
// Tools that just need workflowId context
|
||||
case 'get_user_workflow':
|
||||
case 'get_workflow_data': {
|
||||
const { loadWorkflowFromNormalizedTables } = await import(
|
||||
'@/lib/workflows/persistence/utils'
|
||||
)
|
||||
const { sanitizeForCopilot } = await import('@/lib/workflows/sanitization/json-sanitizer')
|
||||
|
||||
const workflowData = await loadWorkflowFromNormalizedTables(workflowId)
|
||||
if (!workflowData) {
|
||||
throw new Error('Workflow not found')
|
||||
}
|
||||
|
||||
const sanitized = sanitizeForCopilot({
|
||||
blocks: workflowData.blocks,
|
||||
edges: workflowData.edges,
|
||||
loops: workflowData.loops,
|
||||
parallels: workflowData.parallels,
|
||||
})
|
||||
|
||||
return { workflow: JSON.stringify(sanitized, null, 2) }
|
||||
}
|
||||
|
||||
case 'list_user_workflows': {
|
||||
const { db } = await import('@sim/db')
|
||||
const { workflow: workflowTable } = await import('@sim/db/schema')
|
||||
const { eq } = await import('drizzle-orm')
|
||||
|
||||
const workflows = await db
|
||||
.select({
|
||||
id: workflowTable.id,
|
||||
name: workflowTable.name,
|
||||
description: workflowTable.description,
|
||||
isDeployed: workflowTable.isDeployed,
|
||||
createdAt: workflowTable.createdAt,
|
||||
updatedAt: workflowTable.updatedAt,
|
||||
})
|
||||
.from(workflowTable)
|
||||
.where(eq(workflowTable.userId, userId))
|
||||
|
||||
return { workflows }
|
||||
}
|
||||
|
||||
case 'check_deployment_status': {
|
||||
const { db } = await import('@sim/db')
|
||||
const { workflow: workflowTable } = await import('@sim/db/schema')
|
||||
const { eq } = await import('drizzle-orm')
|
||||
|
||||
const [wf] = await db
|
||||
.select({
|
||||
isDeployed: workflowTable.isDeployed,
|
||||
deployedAt: workflowTable.deployedAt,
|
||||
})
|
||||
.from(workflowTable)
|
||||
.where(eq(workflowTable.id, workflowId))
|
||||
.limit(1)
|
||||
|
||||
return {
|
||||
isDeployed: wf?.isDeployed || false,
|
||||
deployedAt: wf?.deployedAt || null,
|
||||
}
|
||||
}
|
||||
|
||||
default: {
|
||||
logger.warn('Unknown tool for server-side execution', { name })
|
||||
throw new Error(`Tool ${name} is not available for server-side execution`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a tool can be executed server-side
|
||||
*/
|
||||
export function isServerExecutableTool(toolName: string): boolean {
|
||||
const serverExecutableTools = new Set([
|
||||
// Core editing tools
|
||||
'edit_workflow',
|
||||
'run_workflow',
|
||||
|
||||
// Deployment tools
|
||||
'deploy_api',
|
||||
'deploy_chat',
|
||||
'deploy_mcp',
|
||||
'redeploy',
|
||||
'check_deployment_status',
|
||||
|
||||
// Existing server tools
|
||||
'get_blocks_and_tools',
|
||||
'get_blocks_metadata',
|
||||
'get_block_options',
|
||||
'get_block_config',
|
||||
'get_trigger_blocks',
|
||||
'get_workflow_console',
|
||||
'search_documentation',
|
||||
'search_online',
|
||||
'set_environment_variables',
|
||||
'get_credentials',
|
||||
'make_api_request',
|
||||
'knowledge_base',
|
||||
|
||||
// Workflow info tools
|
||||
'get_user_workflow',
|
||||
'get_workflow_data',
|
||||
'list_user_workflows',
|
||||
])
|
||||
|
||||
return serverExecutableTools.has(toolName)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get list of tools that require client-side execution
|
||||
*/
|
||||
export function getClientOnlyTools(): string[] {
|
||||
return [
|
||||
'navigate_ui', // Requires DOM
|
||||
'oauth_request_access', // Requires browser auth flow
|
||||
]
|
||||
}
|
||||
|
||||
@@ -54,6 +54,17 @@ export interface SimplifiedImapEmail {
|
||||
}
|
||||
|
||||
export interface ImapWebhookPayload {
|
||||
messageId: string
|
||||
subject: string
|
||||
from: string
|
||||
to: string
|
||||
cc: string
|
||||
date: string | null
|
||||
bodyText: string
|
||||
bodyHtml: string
|
||||
mailbox: string
|
||||
hasAttachments: boolean
|
||||
attachments: ImapAttachment[]
|
||||
email: SimplifiedImapEmail
|
||||
timestamp: string
|
||||
}
|
||||
@@ -613,6 +624,17 @@ async function processEmails(
|
||||
}
|
||||
|
||||
const payload: ImapWebhookPayload = {
|
||||
messageId: simplifiedEmail.messageId,
|
||||
subject: simplifiedEmail.subject,
|
||||
from: simplifiedEmail.from,
|
||||
to: simplifiedEmail.to,
|
||||
cc: simplifiedEmail.cc,
|
||||
date: simplifiedEmail.date,
|
||||
bodyText: simplifiedEmail.bodyText,
|
||||
bodyHtml: simplifiedEmail.bodyHtml,
|
||||
mailbox: simplifiedEmail.mailbox,
|
||||
hasAttachments: simplifiedEmail.hasAttachments,
|
||||
attachments: simplifiedEmail.attachments,
|
||||
email: simplifiedEmail,
|
||||
timestamp: new Date().toISOString(),
|
||||
}
|
||||
|
||||
@@ -48,6 +48,9 @@ interface RssFeed {
|
||||
}
|
||||
|
||||
export interface RssWebhookPayload {
|
||||
title?: string
|
||||
link?: string
|
||||
pubDate?: string
|
||||
item: RssItem
|
||||
feed: {
|
||||
title?: string
|
||||
@@ -349,6 +352,9 @@ async function processRssItems(
|
||||
`${webhookData.id}:${itemGuid}`,
|
||||
async () => {
|
||||
const payload: RssWebhookPayload = {
|
||||
title: item.title,
|
||||
link: item.link,
|
||||
pubDate: item.pubDate,
|
||||
item: {
|
||||
title: item.title,
|
||||
link: item.link,
|
||||
|
||||
@@ -686,6 +686,9 @@ export async function formatWebhookInput(
|
||||
if (foundWebhook.provider === 'rss') {
|
||||
if (body && typeof body === 'object' && 'item' in body) {
|
||||
return {
|
||||
title: body.title,
|
||||
link: body.link,
|
||||
pubDate: body.pubDate,
|
||||
item: body.item,
|
||||
feed: body.feed,
|
||||
timestamp: body.timestamp,
|
||||
@@ -697,6 +700,17 @@ export async function formatWebhookInput(
|
||||
if (foundWebhook.provider === 'imap') {
|
||||
if (body && typeof body === 'object' && 'email' in body) {
|
||||
return {
|
||||
messageId: body.messageId,
|
||||
subject: body.subject,
|
||||
from: body.from,
|
||||
to: body.to,
|
||||
cc: body.cc,
|
||||
date: body.date,
|
||||
bodyText: body.bodyText,
|
||||
bodyHtml: body.bodyHtml,
|
||||
mailbox: body.mailbox,
|
||||
hasAttachments: body.hasAttachments,
|
||||
attachments: body.attachments,
|
||||
email: body.email,
|
||||
timestamp: body.timestamp,
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
convertToGeminiFormat,
|
||||
convertUsageMetadata,
|
||||
createReadableStreamFromGeminiStream,
|
||||
ensureStructResponse,
|
||||
extractFunctionCallPart,
|
||||
extractTextContent,
|
||||
mapToThinkingLevel,
|
||||
@@ -104,7 +105,7 @@ async function executeToolCall(
|
||||
const duration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
const resultContent: Record<string, unknown> = result.success
|
||||
? (result.output as Record<string, unknown>)
|
||||
? ensureStructResponse(result.output)
|
||||
: { error: true, message: result.error || 'Tool execution failed', tool: toolName }
|
||||
|
||||
const toolCall: FunctionCallResponse = {
|
||||
|
||||
453
apps/sim/providers/google/utils.test.ts
Normal file
453
apps/sim/providers/google/utils.test.ts
Normal file
@@ -0,0 +1,453 @@
|
||||
/**
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { convertToGeminiFormat, ensureStructResponse } from '@/providers/google/utils'
|
||||
import type { ProviderRequest } from '@/providers/types'
|
||||
|
||||
describe('ensureStructResponse', () => {
|
||||
describe('should return objects unchanged', () => {
|
||||
it('should return plain object unchanged', () => {
|
||||
const input = { key: 'value', nested: { a: 1 } }
|
||||
const result = ensureStructResponse(input)
|
||||
expect(result).toBe(input) // Same reference
|
||||
expect(result).toEqual({ key: 'value', nested: { a: 1 } })
|
||||
})
|
||||
|
||||
it('should return empty object unchanged', () => {
|
||||
const input = {}
|
||||
const result = ensureStructResponse(input)
|
||||
expect(result).toBe(input)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('should wrap primitive values in { value: ... }', () => {
|
||||
it('should wrap boolean true', () => {
|
||||
const result = ensureStructResponse(true)
|
||||
expect(result).toEqual({ value: true })
|
||||
expect(typeof result).toBe('object')
|
||||
})
|
||||
|
||||
it('should wrap boolean false', () => {
|
||||
const result = ensureStructResponse(false)
|
||||
expect(result).toEqual({ value: false })
|
||||
expect(typeof result).toBe('object')
|
||||
})
|
||||
|
||||
it('should wrap string', () => {
|
||||
const result = ensureStructResponse('success')
|
||||
expect(result).toEqual({ value: 'success' })
|
||||
expect(typeof result).toBe('object')
|
||||
})
|
||||
|
||||
it('should wrap empty string', () => {
|
||||
const result = ensureStructResponse('')
|
||||
expect(result).toEqual({ value: '' })
|
||||
expect(typeof result).toBe('object')
|
||||
})
|
||||
|
||||
it('should wrap number', () => {
|
||||
const result = ensureStructResponse(42)
|
||||
expect(result).toEqual({ value: 42 })
|
||||
expect(typeof result).toBe('object')
|
||||
})
|
||||
|
||||
it('should wrap zero', () => {
|
||||
const result = ensureStructResponse(0)
|
||||
expect(result).toEqual({ value: 0 })
|
||||
expect(typeof result).toBe('object')
|
||||
})
|
||||
|
||||
it('should wrap null', () => {
|
||||
const result = ensureStructResponse(null)
|
||||
expect(result).toEqual({ value: null })
|
||||
expect(typeof result).toBe('object')
|
||||
})
|
||||
|
||||
it('should wrap undefined', () => {
|
||||
const result = ensureStructResponse(undefined)
|
||||
expect(result).toEqual({ value: undefined })
|
||||
expect(typeof result).toBe('object')
|
||||
})
|
||||
})
|
||||
|
||||
describe('should wrap arrays in { value: ... }', () => {
|
||||
it('should wrap array of strings', () => {
|
||||
const result = ensureStructResponse(['a', 'b', 'c'])
|
||||
expect(result).toEqual({ value: ['a', 'b', 'c'] })
|
||||
expect(typeof result).toBe('object')
|
||||
expect(Array.isArray(result)).toBe(false)
|
||||
})
|
||||
|
||||
it('should wrap array of objects', () => {
|
||||
const result = ensureStructResponse([{ id: 1 }, { id: 2 }])
|
||||
expect(result).toEqual({ value: [{ id: 1 }, { id: 2 }] })
|
||||
expect(typeof result).toBe('object')
|
||||
expect(Array.isArray(result)).toBe(false)
|
||||
})
|
||||
|
||||
it('should wrap empty array', () => {
|
||||
const result = ensureStructResponse([])
|
||||
expect(result).toEqual({ value: [] })
|
||||
expect(typeof result).toBe('object')
|
||||
expect(Array.isArray(result)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle nested objects correctly', () => {
|
||||
const input = { a: { b: { c: 1 } }, d: [1, 2, 3] }
|
||||
const result = ensureStructResponse(input)
|
||||
expect(result).toBe(input) // Same reference, unchanged
|
||||
})
|
||||
|
||||
it('should handle object with array property correctly', () => {
|
||||
const input = { items: ['a', 'b'], count: 2 }
|
||||
const result = ensureStructResponse(input)
|
||||
expect(result).toBe(input) // Same reference, unchanged
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('convertToGeminiFormat', () => {
|
||||
describe('tool message handling', () => {
|
||||
it('should convert tool message with object response correctly', () => {
|
||||
const request: ProviderRequest = {
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Hello' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_123',
|
||||
type: 'function',
|
||||
function: { name: 'get_weather', arguments: '{"city": "London"}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
name: 'get_weather',
|
||||
tool_call_id: 'call_123',
|
||||
content: '{"temperature": 20, "condition": "sunny"}',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertToGeminiFormat(request)
|
||||
|
||||
const toolResponseContent = result.contents.find(
|
||||
(c) => c.parts?.[0] && 'functionResponse' in c.parts[0]
|
||||
)
|
||||
expect(toolResponseContent).toBeDefined()
|
||||
|
||||
const functionResponse = (toolResponseContent?.parts?.[0] as { functionResponse?: unknown })
|
||||
?.functionResponse as { response?: unknown }
|
||||
expect(functionResponse?.response).toEqual({ temperature: 20, condition: 'sunny' })
|
||||
expect(typeof functionResponse?.response).toBe('object')
|
||||
})
|
||||
|
||||
it('should wrap boolean true response in an object for Gemini compatibility', () => {
|
||||
const request: ProviderRequest = {
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Check if user exists' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_456',
|
||||
type: 'function',
|
||||
function: { name: 'user_exists', arguments: '{"userId": "123"}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
name: 'user_exists',
|
||||
tool_call_id: 'call_456',
|
||||
content: 'true', // Boolean true as JSON string
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertToGeminiFormat(request)
|
||||
|
||||
const toolResponseContent = result.contents.find(
|
||||
(c) => c.parts?.[0] && 'functionResponse' in c.parts[0]
|
||||
)
|
||||
expect(toolResponseContent).toBeDefined()
|
||||
|
||||
const functionResponse = (toolResponseContent?.parts?.[0] as { functionResponse?: unknown })
|
||||
?.functionResponse as { response?: unknown }
|
||||
|
||||
expect(typeof functionResponse?.response).toBe('object')
|
||||
expect(functionResponse?.response).not.toBe(true)
|
||||
expect(functionResponse?.response).toEqual({ value: true })
|
||||
})
|
||||
|
||||
it('should wrap boolean false response in an object for Gemini compatibility', () => {
|
||||
const request: ProviderRequest = {
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Check if user exists' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_789',
|
||||
type: 'function',
|
||||
function: { name: 'user_exists', arguments: '{"userId": "999"}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
name: 'user_exists',
|
||||
tool_call_id: 'call_789',
|
||||
content: 'false', // Boolean false as JSON string
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertToGeminiFormat(request)
|
||||
|
||||
const toolResponseContent = result.contents.find(
|
||||
(c) => c.parts?.[0] && 'functionResponse' in c.parts[0]
|
||||
)
|
||||
const functionResponse = (toolResponseContent?.parts?.[0] as { functionResponse?: unknown })
|
||||
?.functionResponse as { response?: unknown }
|
||||
|
||||
expect(typeof functionResponse?.response).toBe('object')
|
||||
expect(functionResponse?.response).toEqual({ value: false })
|
||||
})
|
||||
|
||||
it('should wrap string response in an object for Gemini compatibility', () => {
|
||||
const request: ProviderRequest = {
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Get status' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_str',
|
||||
type: 'function',
|
||||
function: { name: 'get_status', arguments: '{}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
name: 'get_status',
|
||||
tool_call_id: 'call_str',
|
||||
content: '"success"', // String as JSON
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertToGeminiFormat(request)
|
||||
|
||||
const toolResponseContent = result.contents.find(
|
||||
(c) => c.parts?.[0] && 'functionResponse' in c.parts[0]
|
||||
)
|
||||
const functionResponse = (toolResponseContent?.parts?.[0] as { functionResponse?: unknown })
|
||||
?.functionResponse as { response?: unknown }
|
||||
|
||||
expect(typeof functionResponse?.response).toBe('object')
|
||||
expect(functionResponse?.response).toEqual({ value: 'success' })
|
||||
})
|
||||
|
||||
it('should wrap number response in an object for Gemini compatibility', () => {
|
||||
const request: ProviderRequest = {
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Get count' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_num',
|
||||
type: 'function',
|
||||
function: { name: 'get_count', arguments: '{}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
name: 'get_count',
|
||||
tool_call_id: 'call_num',
|
||||
content: '42', // Number as JSON
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertToGeminiFormat(request)
|
||||
|
||||
const toolResponseContent = result.contents.find(
|
||||
(c) => c.parts?.[0] && 'functionResponse' in c.parts[0]
|
||||
)
|
||||
const functionResponse = (toolResponseContent?.parts?.[0] as { functionResponse?: unknown })
|
||||
?.functionResponse as { response?: unknown }
|
||||
|
||||
expect(typeof functionResponse?.response).toBe('object')
|
||||
expect(functionResponse?.response).toEqual({ value: 42 })
|
||||
})
|
||||
|
||||
it('should wrap null response in an object for Gemini compatibility', () => {
|
||||
const request: ProviderRequest = {
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Get data' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_null',
|
||||
type: 'function',
|
||||
function: { name: 'get_data', arguments: '{}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
name: 'get_data',
|
||||
tool_call_id: 'call_null',
|
||||
content: 'null', // null as JSON
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertToGeminiFormat(request)
|
||||
|
||||
const toolResponseContent = result.contents.find(
|
||||
(c) => c.parts?.[0] && 'functionResponse' in c.parts[0]
|
||||
)
|
||||
const functionResponse = (toolResponseContent?.parts?.[0] as { functionResponse?: unknown })
|
||||
?.functionResponse as { response?: unknown }
|
||||
|
||||
expect(typeof functionResponse?.response).toBe('object')
|
||||
expect(functionResponse?.response).toEqual({ value: null })
|
||||
})
|
||||
|
||||
it('should keep array response as-is since arrays are valid Struct values', () => {
|
||||
const request: ProviderRequest = {
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Get items' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_arr',
|
||||
type: 'function',
|
||||
function: { name: 'get_items', arguments: '{}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
name: 'get_items',
|
||||
tool_call_id: 'call_arr',
|
||||
content: '["item1", "item2"]', // Array as JSON
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertToGeminiFormat(request)
|
||||
|
||||
const toolResponseContent = result.contents.find(
|
||||
(c) => c.parts?.[0] && 'functionResponse' in c.parts[0]
|
||||
)
|
||||
const functionResponse = (toolResponseContent?.parts?.[0] as { functionResponse?: unknown })
|
||||
?.functionResponse as { response?: unknown }
|
||||
|
||||
expect(typeof functionResponse?.response).toBe('object')
|
||||
expect(functionResponse?.response).toEqual({ value: ['item1', 'item2'] })
|
||||
})
|
||||
|
||||
it('should handle invalid JSON by wrapping in output object', () => {
|
||||
const request: ProviderRequest = {
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Get data' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_invalid',
|
||||
type: 'function',
|
||||
function: { name: 'get_data', arguments: '{}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
name: 'get_data',
|
||||
tool_call_id: 'call_invalid',
|
||||
content: 'not valid json {',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertToGeminiFormat(request)
|
||||
|
||||
const toolResponseContent = result.contents.find(
|
||||
(c) => c.parts?.[0] && 'functionResponse' in c.parts[0]
|
||||
)
|
||||
const functionResponse = (toolResponseContent?.parts?.[0] as { functionResponse?: unknown })
|
||||
?.functionResponse as { response?: unknown }
|
||||
|
||||
expect(typeof functionResponse?.response).toBe('object')
|
||||
expect(functionResponse?.response).toEqual({ output: 'not valid json {' })
|
||||
})
|
||||
|
||||
it('should handle empty content by wrapping in output object', () => {
|
||||
const request: ProviderRequest = {
|
||||
model: 'gemini-2.5-flash',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Do something' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_empty',
|
||||
type: 'function',
|
||||
function: { name: 'do_action', arguments: '{}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
name: 'do_action',
|
||||
tool_call_id: 'call_empty',
|
||||
content: '', // Empty content - falls back to default '{}'
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertToGeminiFormat(request)
|
||||
|
||||
const toolResponseContent = result.contents.find(
|
||||
(c) => c.parts?.[0] && 'functionResponse' in c.parts[0]
|
||||
)
|
||||
const functionResponse = (toolResponseContent?.parts?.[0] as { functionResponse?: unknown })
|
||||
?.functionResponse as { response?: unknown }
|
||||
|
||||
expect(typeof functionResponse?.response).toBe('object')
|
||||
// Empty string is not valid JSON, so it falls back to { output: "" }
|
||||
expect(functionResponse?.response).toEqual({ output: '' })
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -18,6 +18,22 @@ import { trackForcedToolUsage } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('GoogleUtils')
|
||||
|
||||
/**
|
||||
* Ensures a value is a valid object for Gemini's functionResponse.response field.
|
||||
* Gemini's API requires functionResponse.response to be a google.protobuf.Struct,
|
||||
* which must be an object with string keys. Primitive values (boolean, string,
|
||||
* number, null) and arrays are wrapped in { value: ... }.
|
||||
*
|
||||
* @param value - The value to ensure is a Struct-compatible object
|
||||
* @returns A Record<string, unknown> suitable for functionResponse.response
|
||||
*/
|
||||
export function ensureStructResponse(value: unknown): Record<string, unknown> {
|
||||
if (typeof value === 'object' && value !== null && !Array.isArray(value)) {
|
||||
return value as Record<string, unknown>
|
||||
}
|
||||
return { value }
|
||||
}
|
||||
|
||||
/**
|
||||
* Usage metadata for Google Gemini responses
|
||||
*/
|
||||
@@ -180,7 +196,8 @@ export function convertToGeminiFormat(request: ProviderRequest): {
|
||||
}
|
||||
let responseData: Record<string, unknown>
|
||||
try {
|
||||
responseData = JSON.parse(message.content ?? '{}')
|
||||
const parsed = JSON.parse(message.content ?? '{}')
|
||||
responseData = ensureStructResponse(parsed)
|
||||
} catch {
|
||||
responseData = { output: message.content }
|
||||
}
|
||||
|
||||
@@ -337,10 +337,11 @@ async function handleBlockOperationTx(
|
||||
const currentData = currentBlock?.data || {}
|
||||
|
||||
// Update data with parentId and extent
|
||||
const { parentId: _removedParentId, extent: _removedExtent, ...restData } = currentData
|
||||
const updatedData = isRemovingFromParent
|
||||
? {} // Clear data entirely when removing from parent
|
||||
? restData
|
||||
: {
|
||||
...currentData,
|
||||
...restData,
|
||||
...(payload.parentId ? { parentId: payload.parentId } : {}),
|
||||
...(payload.extent ? { extent: payload.extent } : {}),
|
||||
}
|
||||
@@ -828,10 +829,11 @@ async function handleBlocksOperationTx(
|
||||
|
||||
const currentData = currentBlock?.data || {}
|
||||
|
||||
const { parentId: _removedParentId, extent: _removedExtent, ...restData } = currentData
|
||||
const updatedData = isRemovingFromParent
|
||||
? {}
|
||||
? restData
|
||||
: {
|
||||
...currentData,
|
||||
...restData,
|
||||
...(parentId ? { parentId, extent: 'parent' } : {}),
|
||||
}
|
||||
|
||||
|
||||
@@ -2708,6 +2708,16 @@ export const useCopilotStore = create<CopilotStore>()(
|
||||
})
|
||||
|
||||
if (result.success && result.stream) {
|
||||
// Store stream ID for potential resumption on disconnect
|
||||
if (result.streamId) {
|
||||
const { storeActiveStream } = await import('@/lib/copilot/stream-client')
|
||||
storeActiveStream(
|
||||
result.chatId || currentChat?.id || '',
|
||||
result.streamId,
|
||||
streamingMessage.id
|
||||
)
|
||||
}
|
||||
|
||||
await get().handleStreamingResponse(
|
||||
result.stream,
|
||||
streamingMessage.id,
|
||||
@@ -2715,6 +2725,12 @@ export const useCopilotStore = create<CopilotStore>()(
|
||||
userMessage.id
|
||||
)
|
||||
set({ chatsLastLoadedAt: null, chatsLoadedForWorkflow: null })
|
||||
|
||||
// Clear stream storage on successful completion
|
||||
if (result.streamId) {
|
||||
const { clearStoredActiveStream } = await import('@/lib/copilot/stream-client')
|
||||
clearStoredActiveStream()
|
||||
}
|
||||
} else {
|
||||
if (result.error === 'Request was aborted') {
|
||||
return
|
||||
@@ -3853,6 +3869,68 @@ export const useCopilotStore = create<CopilotStore>()(
|
||||
return autoAllowedTools.includes(toolId)
|
||||
},
|
||||
|
||||
// Stream resumption
|
||||
attemptStreamResumption: async () => {
|
||||
const { isSendingMessage } = get()
|
||||
if (isSendingMessage) {
|
||||
logger.info('[Stream] Cannot attempt resumption while already sending')
|
||||
return false
|
||||
}
|
||||
|
||||
try {
|
||||
const { attemptStreamResumption, clearStoredActiveStream } = await import(
|
||||
'@/lib/copilot/stream-client'
|
||||
)
|
||||
|
||||
const resumption = await attemptStreamResumption()
|
||||
if (!resumption) {
|
||||
return false
|
||||
}
|
||||
|
||||
const { stream, metadata } = resumption
|
||||
|
||||
logger.info('[Stream] Resuming stream', {
|
||||
streamId: metadata.streamId,
|
||||
chatId: metadata.chatId,
|
||||
})
|
||||
|
||||
// Find or create the assistant message for this stream
|
||||
const { messages } = get()
|
||||
let assistantMessageId = metadata.assistantMessageId
|
||||
|
||||
// If we don't have the assistant message, create a placeholder
|
||||
if (!assistantMessageId || !messages.find((m) => m.id === assistantMessageId)) {
|
||||
assistantMessageId = crypto.randomUUID()
|
||||
const streamingMessage: CopilotMessage = {
|
||||
id: assistantMessageId,
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
timestamp: new Date().toISOString(),
|
||||
isStreaming: true,
|
||||
contentBlocks: [],
|
||||
}
|
||||
set((state) => ({
|
||||
messages: [...state.messages, streamingMessage],
|
||||
isSendingMessage: true,
|
||||
}))
|
||||
}
|
||||
|
||||
// Process the resumed stream
|
||||
await get().handleStreamingResponse(
|
||||
stream,
|
||||
assistantMessageId,
|
||||
true, // This is a continuation
|
||||
metadata.userMessageId
|
||||
)
|
||||
|
||||
clearStoredActiveStream()
|
||||
return true
|
||||
} catch (error) {
|
||||
logger.error('[Stream] Resumption failed', { error })
|
||||
return false
|
||||
}
|
||||
},
|
||||
|
||||
// Message queue actions
|
||||
addToQueue: (message, options) => {
|
||||
const queuedMessage: import('./types').QueuedMessage = {
|
||||
|
||||
@@ -63,6 +63,8 @@ export interface CopilotMessage {
|
||||
fileAttachments?: MessageFileAttachment[]
|
||||
contexts?: ChatContext[]
|
||||
errorType?: 'usage_limit' | 'unauthorized' | 'forbidden' | 'rate_limit' | 'upgrade_required'
|
||||
/** Whether this message is currently being streamed */
|
||||
isStreaming?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -235,6 +237,9 @@ export interface CopilotActions {
|
||||
removeAutoAllowedTool: (toolId: string) => Promise<void>
|
||||
isToolAutoAllowed: (toolId: string) => boolean
|
||||
|
||||
// Stream resumption
|
||||
attemptStreamResumption: () => Promise<boolean>
|
||||
|
||||
// Message queue actions
|
||||
addToQueue: (
|
||||
message: string,
|
||||
|
||||
@@ -1,11 +1,214 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import type { BrowserUseRunTaskParams, BrowserUseRunTaskResponse } from '@/tools/browser_use/types'
|
||||
import type { ToolConfig } from '@/tools/types'
|
||||
import type { ToolConfig, ToolResponse } from '@/tools/types'
|
||||
|
||||
const logger = createLogger('BrowserUseTool')
|
||||
|
||||
const POLL_INTERVAL_MS = 5000 // 5 seconds between polls
|
||||
const MAX_POLL_TIME_MS = 180000 // 3 minutes maximum polling time
|
||||
const POLL_INTERVAL_MS = 5000
|
||||
const MAX_POLL_TIME_MS = 180000
|
||||
const MAX_CONSECUTIVE_ERRORS = 3
|
||||
|
||||
async function createSessionWithProfile(
|
||||
profileId: string,
|
||||
apiKey: string
|
||||
): Promise<{ sessionId: string } | { error: string }> {
|
||||
try {
|
||||
const response = await fetch('https://api.browser-use.com/api/v2/sessions', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Browser-Use-API-Key': apiKey,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
profileId: profileId.trim(),
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
logger.error(`Failed to create session with profile: ${errorText}`)
|
||||
return { error: `Failed to create session with profile: ${response.statusText}` }
|
||||
}
|
||||
|
||||
const data = (await response.json()) as { id: string }
|
||||
logger.info(`Created session ${data.id} with profile ${profileId}`)
|
||||
return { sessionId: data.id }
|
||||
} catch (error: any) {
|
||||
logger.error('Error creating session with profile:', error)
|
||||
return { error: `Error creating session: ${error.message}` }
|
||||
}
|
||||
}
|
||||
|
||||
async function stopSession(sessionId: string, apiKey: string): Promise<void> {
|
||||
try {
|
||||
const response = await fetch(`https://api.browser-use.com/api/v2/sessions/${sessionId}`, {
|
||||
method: 'PATCH',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Browser-Use-API-Key': apiKey,
|
||||
},
|
||||
body: JSON.stringify({ action: 'stop' }),
|
||||
})
|
||||
|
||||
if (response.ok) {
|
||||
logger.info(`Stopped session ${sessionId}`)
|
||||
} else {
|
||||
logger.warn(`Failed to stop session ${sessionId}: ${response.statusText}`)
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.warn(`Error stopping session ${sessionId}:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
function buildRequestBody(
|
||||
params: BrowserUseRunTaskParams,
|
||||
sessionId?: string
|
||||
): Record<string, any> {
|
||||
const requestBody: Record<string, any> = {
|
||||
task: params.task,
|
||||
}
|
||||
|
||||
if (sessionId) {
|
||||
requestBody.sessionId = sessionId
|
||||
logger.info(`Using session ${sessionId} for task`)
|
||||
}
|
||||
|
||||
if (params.variables) {
|
||||
let secrets: Record<string, string> = {}
|
||||
|
||||
if (Array.isArray(params.variables)) {
|
||||
logger.info('Converting variables array to dictionary format')
|
||||
params.variables.forEach((row: any) => {
|
||||
if (row.cells?.Key && row.cells.Value !== undefined) {
|
||||
secrets[row.cells.Key] = row.cells.Value
|
||||
logger.info(`Added secret for key: ${row.cells.Key}`)
|
||||
} else if (row.Key && row.Value !== undefined) {
|
||||
secrets[row.Key] = row.Value
|
||||
logger.info(`Added secret for key: ${row.Key}`)
|
||||
}
|
||||
})
|
||||
} else if (typeof params.variables === 'object' && params.variables !== null) {
|
||||
logger.info('Using variables object directly')
|
||||
secrets = params.variables
|
||||
}
|
||||
|
||||
if (Object.keys(secrets).length > 0) {
|
||||
logger.info(`Found ${Object.keys(secrets).length} secrets to include`)
|
||||
requestBody.secrets = secrets
|
||||
} else {
|
||||
logger.warn('No usable secrets found in variables')
|
||||
}
|
||||
}
|
||||
|
||||
if (params.model) {
|
||||
requestBody.llm_model = params.model
|
||||
}
|
||||
|
||||
if (params.save_browser_data) {
|
||||
requestBody.save_browser_data = params.save_browser_data
|
||||
}
|
||||
|
||||
requestBody.use_adblock = true
|
||||
requestBody.highlight_elements = true
|
||||
|
||||
return requestBody
|
||||
}
|
||||
|
||||
async function fetchTaskStatus(
|
||||
taskId: string,
|
||||
apiKey: string
|
||||
): Promise<{ ok: true; data: any } | { ok: false; error: string }> {
|
||||
try {
|
||||
const response = await fetch(`https://api.browser-use.com/api/v2/tasks/${taskId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'X-Browser-Use-API-Key': apiKey,
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
return { ok: false, error: `HTTP ${response.status}: ${response.statusText}` }
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
return { ok: true, data }
|
||||
} catch (error: any) {
|
||||
return { ok: false, error: error.message || 'Network error' }
|
||||
}
|
||||
}
|
||||
|
||||
async function pollForCompletion(
|
||||
taskId: string,
|
||||
apiKey: string
|
||||
): Promise<{ success: boolean; output: any; steps: any[]; error?: string }> {
|
||||
let liveUrlLogged = false
|
||||
let consecutiveErrors = 0
|
||||
const startTime = Date.now()
|
||||
|
||||
while (Date.now() - startTime < MAX_POLL_TIME_MS) {
|
||||
const result = await fetchTaskStatus(taskId, apiKey)
|
||||
|
||||
if (!result.ok) {
|
||||
consecutiveErrors++
|
||||
logger.warn(
|
||||
`Error polling task ${taskId} (attempt ${consecutiveErrors}/${MAX_CONSECUTIVE_ERRORS}): ${result.error}`
|
||||
)
|
||||
|
||||
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
|
||||
logger.error(`Max consecutive errors reached for task ${taskId}`)
|
||||
return {
|
||||
success: false,
|
||||
output: null,
|
||||
steps: [],
|
||||
error: `Failed to poll task status after ${MAX_CONSECUTIVE_ERRORS} attempts: ${result.error}`,
|
||||
}
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, POLL_INTERVAL_MS))
|
||||
continue
|
||||
}
|
||||
|
||||
consecutiveErrors = 0
|
||||
const taskData = result.data
|
||||
const status = taskData.status
|
||||
|
||||
logger.info(`BrowserUse task ${taskId} status: ${status}`)
|
||||
|
||||
if (['finished', 'failed', 'stopped'].includes(status)) {
|
||||
return {
|
||||
success: status === 'finished',
|
||||
output: taskData.output ?? null,
|
||||
steps: taskData.steps || [],
|
||||
}
|
||||
}
|
||||
|
||||
if (!liveUrlLogged && taskData.live_url) {
|
||||
logger.info(`BrowserUse task ${taskId} live URL: ${taskData.live_url}`)
|
||||
liveUrlLogged = true
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, POLL_INTERVAL_MS))
|
||||
}
|
||||
|
||||
const finalResult = await fetchTaskStatus(taskId, apiKey)
|
||||
if (finalResult.ok && ['finished', 'failed', 'stopped'].includes(finalResult.data.status)) {
|
||||
return {
|
||||
success: finalResult.data.status === 'finished',
|
||||
output: finalResult.data.output ?? null,
|
||||
steps: finalResult.data.steps || [],
|
||||
}
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`Task ${taskId} did not complete within the maximum polling time (${MAX_POLL_TIME_MS / 1000}s)`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
output: null,
|
||||
steps: [],
|
||||
error: `Task did not complete within the maximum polling time (${MAX_POLL_TIME_MS / 1000}s)`,
|
||||
}
|
||||
}
|
||||
|
||||
export const runTaskTool: ToolConfig<BrowserUseRunTaskParams, BrowserUseRunTaskResponse> = {
|
||||
id: 'browser_use_run_task',
|
||||
@@ -44,7 +247,14 @@ export const runTaskTool: ToolConfig<BrowserUseRunTaskParams, BrowserUseRunTaskR
|
||||
visibility: 'user-only',
|
||||
description: 'API key for BrowserUse API',
|
||||
},
|
||||
profile_id: {
|
||||
type: 'string',
|
||||
required: false,
|
||||
visibility: 'user-only',
|
||||
description: 'Browser profile ID for persistent sessions (cookies, login state)',
|
||||
},
|
||||
},
|
||||
|
||||
request: {
|
||||
url: 'https://api.browser-use.com/api/v2/tasks',
|
||||
method: 'POST',
|
||||
@@ -52,155 +262,94 @@ export const runTaskTool: ToolConfig<BrowserUseRunTaskParams, BrowserUseRunTaskR
|
||||
'Content-Type': 'application/json',
|
||||
'X-Browser-Use-API-Key': params.apiKey,
|
||||
}),
|
||||
body: (params) => {
|
||||
const requestBody: Record<string, any> = {
|
||||
task: params.task,
|
||||
}
|
||||
|
||||
if (params.variables) {
|
||||
let secrets: Record<string, string> = {}
|
||||
|
||||
if (Array.isArray(params.variables)) {
|
||||
logger.info('Converting variables array to dictionary format')
|
||||
params.variables.forEach((row) => {
|
||||
if (row.cells?.Key && row.cells.Value !== undefined) {
|
||||
secrets[row.cells.Key] = row.cells.Value
|
||||
logger.info(`Added secret for key: ${row.cells.Key}`)
|
||||
} else if (row.Key && row.Value !== undefined) {
|
||||
secrets[row.Key] = row.Value
|
||||
logger.info(`Added secret for key: ${row.Key}`)
|
||||
}
|
||||
})
|
||||
} else if (typeof params.variables === 'object' && params.variables !== null) {
|
||||
logger.info('Using variables object directly')
|
||||
secrets = params.variables
|
||||
}
|
||||
|
||||
if (Object.keys(secrets).length > 0) {
|
||||
logger.info(`Found ${Object.keys(secrets).length} secrets to include`)
|
||||
requestBody.secrets = secrets
|
||||
} else {
|
||||
logger.warn('No usable secrets found in variables')
|
||||
}
|
||||
}
|
||||
|
||||
if (params.model) {
|
||||
requestBody.llm_model = params.model
|
||||
}
|
||||
|
||||
if (params.save_browser_data) {
|
||||
requestBody.save_browser_data = params.save_browser_data
|
||||
}
|
||||
|
||||
requestBody.use_adblock = true
|
||||
requestBody.highlight_elements = true
|
||||
|
||||
return requestBody
|
||||
},
|
||||
},
|
||||
|
||||
transformResponse: async (response: Response) => {
|
||||
const data = (await response.json()) as { id: string }
|
||||
return {
|
||||
success: true,
|
||||
output: {
|
||||
id: data.id,
|
||||
success: true,
|
||||
output: null,
|
||||
steps: [],
|
||||
},
|
||||
}
|
||||
},
|
||||
directExecution: async (params: BrowserUseRunTaskParams): Promise<ToolResponse> => {
|
||||
let sessionId: string | undefined
|
||||
|
||||
postProcess: async (result, params) => {
|
||||
if (!result.success) {
|
||||
return result
|
||||
if (params.profile_id) {
|
||||
logger.info(`Creating session with profile ID: ${params.profile_id}`)
|
||||
const sessionResult = await createSessionWithProfile(params.profile_id, params.apiKey)
|
||||
if ('error' in sessionResult) {
|
||||
return {
|
||||
success: false,
|
||||
output: {
|
||||
id: null,
|
||||
success: false,
|
||||
output: null,
|
||||
steps: [],
|
||||
},
|
||||
error: sessionResult.error,
|
||||
}
|
||||
}
|
||||
sessionId = sessionResult.sessionId
|
||||
}
|
||||
|
||||
const taskId = result.output.id
|
||||
let liveUrlLogged = false
|
||||
const requestBody = buildRequestBody(params, sessionId)
|
||||
logger.info('Creating BrowserUse task', { hasSession: !!sessionId })
|
||||
|
||||
try {
|
||||
const initialTaskResponse = await fetch(
|
||||
`https://api.browser-use.com/api/v2/tasks/${taskId}`,
|
||||
{
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'X-Browser-Use-API-Key': params.apiKey,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if (initialTaskResponse.ok) {
|
||||
const initialTaskData = await initialTaskResponse.json()
|
||||
if (initialTaskData.live_url) {
|
||||
logger.info(
|
||||
`BrowserUse task ${taskId} launched with live URL: ${initialTaskData.live_url}`
|
||||
)
|
||||
liveUrlLogged = true
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to get initial task details for ${taskId}:`, error)
|
||||
}
|
||||
|
||||
let elapsedTime = 0
|
||||
|
||||
while (elapsedTime < MAX_POLL_TIME_MS) {
|
||||
try {
|
||||
const statusResponse = await fetch(`https://api.browser-use.com/api/v2/tasks/${taskId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'X-Browser-Use-API-Key': params.apiKey,
|
||||
},
|
||||
})
|
||||
|
||||
if (!statusResponse.ok) {
|
||||
throw new Error(`Failed to get task status: ${statusResponse.statusText}`)
|
||||
}
|
||||
|
||||
const taskData = await statusResponse.json()
|
||||
const status = taskData.status
|
||||
|
||||
logger.info(`BrowserUse task ${taskId} status: ${status}`)
|
||||
|
||||
if (['finished', 'failed', 'stopped'].includes(status)) {
|
||||
result.output = {
|
||||
id: taskId,
|
||||
success: status === 'finished',
|
||||
output: taskData.output ?? null,
|
||||
steps: taskData.steps || [],
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
if (!liveUrlLogged && status === 'running' && taskData.live_url) {
|
||||
logger.info(`BrowserUse task ${taskId} running with live URL: ${taskData.live_url}`)
|
||||
liveUrlLogged = true
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, POLL_INTERVAL_MS))
|
||||
elapsedTime += POLL_INTERVAL_MS
|
||||
} catch (error: any) {
|
||||
logger.error('Error polling for task status:', {
|
||||
message: error.message || 'Unknown error',
|
||||
taskId,
|
||||
})
|
||||
const response = await fetch('https://api.browser-use.com/api/v2/tasks', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Browser-Use-API-Key': params.apiKey,
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
logger.error(`Failed to create task: ${errorText}`)
|
||||
return {
|
||||
...result,
|
||||
error: `Error polling for task status: ${error.message || 'Unknown error'}`,
|
||||
success: false,
|
||||
output: {
|
||||
id: null,
|
||||
success: false,
|
||||
output: null,
|
||||
steps: [],
|
||||
},
|
||||
error: `Failed to create task: ${response.statusText}`,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`Task ${taskId} did not complete within the maximum polling time (${MAX_POLL_TIME_MS / 1000}s)`
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
error: `Task did not complete within the maximum polling time (${MAX_POLL_TIME_MS / 1000}s)`,
|
||||
const data = (await response.json()) as { id: string }
|
||||
const taskId = data.id
|
||||
logger.info(`Created BrowserUse task: ${taskId}`)
|
||||
|
||||
const result = await pollForCompletion(taskId, params.apiKey)
|
||||
|
||||
if (sessionId) {
|
||||
await stopSession(sessionId, params.apiKey)
|
||||
}
|
||||
|
||||
return {
|
||||
success: result.success && !result.error,
|
||||
output: {
|
||||
id: taskId,
|
||||
success: result.success,
|
||||
output: result.output,
|
||||
steps: result.steps,
|
||||
},
|
||||
error: result.error,
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error('Error creating BrowserUse task:', error)
|
||||
|
||||
if (sessionId) {
|
||||
await stopSession(sessionId, params.apiKey)
|
||||
}
|
||||
|
||||
return {
|
||||
success: false,
|
||||
output: {
|
||||
id: null,
|
||||
success: false,
|
||||
output: null,
|
||||
steps: [],
|
||||
},
|
||||
error: `Error creating task: ${error.message}`,
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ export interface BrowserUseRunTaskParams {
|
||||
variables?: Record<string, string>
|
||||
model?: string
|
||||
save_browser_data?: boolean
|
||||
profile_id?: string
|
||||
}
|
||||
|
||||
export interface BrowserUseTaskStep {
|
||||
|
||||
Reference in New Issue
Block a user