mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
Compare commits
35 Commits
feat/agent
...
v0.6.8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c78c870fda | ||
|
|
0c80438ede | ||
|
|
41a7d247ea | ||
|
|
092525e8aa | ||
|
|
19442f19e2 | ||
|
|
1731a4d7f0 | ||
|
|
9fcd02fd3b | ||
|
|
ff7b5b528c | ||
|
|
30f2d1a0fc | ||
|
|
4bd0731871 | ||
|
|
4f3bc37fe4 | ||
|
|
84d6fdc423 | ||
|
|
4c12914d35 | ||
|
|
e9bdc57616 | ||
|
|
36612ae42a | ||
|
|
1c2c2c65d4 | ||
|
|
ecd3536a72 | ||
|
|
8c0a2e04b1 | ||
|
|
6586c5ce40 | ||
|
|
3ce947566d | ||
|
|
70c36cb7aa | ||
|
|
f1ec5fe824 | ||
|
|
e07e3c34cc | ||
|
|
0d2e6ff31d | ||
|
|
4fd0989264 | ||
|
|
67f8a687f6 | ||
|
|
af592349d3 | ||
|
|
0d86ea01f0 | ||
|
|
115f04e989 | ||
|
|
34d92fae89 | ||
|
|
67aa4bb332 | ||
|
|
15ace5e63f | ||
|
|
fdca73679d | ||
|
|
da46a387c9 | ||
|
|
b7e377ec4b |
@@ -1,6 +1,11 @@
|
||||
import { NextResponse } from 'next/server'
|
||||
import { abortActiveStream } from '@/lib/copilot/chat-streaming'
|
||||
import { getLatestRunForStream } from '@/lib/copilot/async-runs/repository'
|
||||
import { abortActiveStream, waitForPendingChatStream } from '@/lib/copilot/chat-streaming'
|
||||
import { SIM_AGENT_API_URL } from '@/lib/copilot/constants'
|
||||
import { authenticateCopilotRequestSessionOnly } from '@/lib/copilot/request-helpers'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
|
||||
const GO_EXPLICIT_ABORT_TIMEOUT_MS = 3000
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const { userId: authenticatedUserId, isAuthenticated } =
|
||||
@@ -12,11 +17,48 @@ export async function POST(request: Request) {
|
||||
|
||||
const body = await request.json().catch(() => ({}))
|
||||
const streamId = typeof body.streamId === 'string' ? body.streamId : ''
|
||||
let chatId = typeof body.chatId === 'string' ? body.chatId : ''
|
||||
|
||||
if (!streamId) {
|
||||
return NextResponse.json({ error: 'streamId is required' }, { status: 400 })
|
||||
}
|
||||
|
||||
const aborted = abortActiveStream(streamId)
|
||||
if (!chatId) {
|
||||
const run = await getLatestRunForStream(streamId, authenticatedUserId).catch(() => null)
|
||||
if (run?.chatId) {
|
||||
chatId = run.chatId
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const headers: Record<string, string> = { 'Content-Type': 'application/json' }
|
||||
if (env.COPILOT_API_KEY) {
|
||||
headers['x-api-key'] = env.COPILOT_API_KEY
|
||||
}
|
||||
const controller = new AbortController()
|
||||
const timeout = setTimeout(() => controller.abort(), GO_EXPLICIT_ABORT_TIMEOUT_MS)
|
||||
const response = await fetch(`${SIM_AGENT_API_URL}/api/streams/explicit-abort`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
signal: controller.signal,
|
||||
body: JSON.stringify({
|
||||
messageId: streamId,
|
||||
userId: authenticatedUserId,
|
||||
...(chatId ? { chatId } : {}),
|
||||
}),
|
||||
}).finally(() => clearTimeout(timeout))
|
||||
if (!response.ok) {
|
||||
throw new Error(`Explicit abort marker request failed: ${response.status}`)
|
||||
}
|
||||
} catch {
|
||||
// best effort: local abort should still proceed even if Go marker fails
|
||||
}
|
||||
|
||||
const aborted = await abortActiveStream(streamId)
|
||||
if (chatId) {
|
||||
await waitForPendingChatStream(chatId, GO_EXPLICIT_ABORT_TIMEOUT_MS + 1000, streamId).catch(
|
||||
() => false
|
||||
)
|
||||
}
|
||||
return NextResponse.json({ aborted })
|
||||
}
|
||||
|
||||
@@ -8,7 +8,9 @@ import { getSession } from '@/lib/auth'
|
||||
import { getAccessibleCopilotChat, resolveOrCreateChat } from '@/lib/copilot/chat-lifecycle'
|
||||
import { buildCopilotRequestPayload } from '@/lib/copilot/chat-payload'
|
||||
import {
|
||||
acquirePendingChatStream,
|
||||
createSSEStream,
|
||||
releasePendingChatStream,
|
||||
requestChatTitle,
|
||||
SSE_RESPONSE_HEADERS,
|
||||
} from '@/lib/copilot/chat-streaming'
|
||||
@@ -16,6 +18,7 @@ import { COPILOT_REQUEST_MODES } from '@/lib/copilot/models'
|
||||
import { orchestrateCopilotStream } from '@/lib/copilot/orchestrator'
|
||||
import { getStreamMeta, readStreamEvents } from '@/lib/copilot/orchestrator/stream/buffer'
|
||||
import type { OrchestratorResult } from '@/lib/copilot/orchestrator/types'
|
||||
import { resolveActiveResourceContext } from '@/lib/copilot/process-contents'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createBadRequestResponse,
|
||||
@@ -44,6 +47,13 @@ const FileAttachmentSchema = z.object({
|
||||
size: z.number(),
|
||||
})
|
||||
|
||||
const ResourceAttachmentSchema = z.object({
|
||||
type: z.enum(['workflow', 'table', 'file', 'knowledgebase']),
|
||||
id: z.string().min(1),
|
||||
title: z.string().optional(),
|
||||
active: z.boolean().optional(),
|
||||
})
|
||||
|
||||
const ChatMessageSchema = z.object({
|
||||
message: z.string().min(1, 'Message is required'),
|
||||
userMessageId: z.string().optional(),
|
||||
@@ -58,6 +68,7 @@ const ChatMessageSchema = z.object({
|
||||
stream: z.boolean().optional().default(true),
|
||||
implicitFeedback: z.string().optional(),
|
||||
fileAttachments: z.array(FileAttachmentSchema).optional(),
|
||||
resourceAttachments: z.array(ResourceAttachmentSchema).optional(),
|
||||
provider: z.string().optional(),
|
||||
contexts: z
|
||||
.array(
|
||||
@@ -98,6 +109,10 @@ const ChatMessageSchema = z.object({
|
||||
*/
|
||||
export async function POST(req: NextRequest) {
|
||||
const tracker = createRequestTracker()
|
||||
let actualChatId: string | undefined
|
||||
let pendingChatStreamAcquired = false
|
||||
let pendingChatStreamHandedOff = false
|
||||
let pendingChatStreamID: string | undefined
|
||||
|
||||
try {
|
||||
// Get session to access user information including name
|
||||
@@ -124,6 +139,7 @@ export async function POST(req: NextRequest) {
|
||||
stream,
|
||||
implicitFeedback,
|
||||
fileAttachments,
|
||||
resourceAttachments,
|
||||
provider,
|
||||
contexts,
|
||||
commands,
|
||||
@@ -189,7 +205,7 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
let currentChat: any = null
|
||||
let conversationHistory: any[] = []
|
||||
let actualChatId = chatId
|
||||
actualChatId = chatId
|
||||
const selectedModel = model || 'claude-opus-4-6'
|
||||
|
||||
if (chatId || createNewChat) {
|
||||
@@ -241,6 +257,39 @@ export async function POST(req: NextRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
Array.isArray(resourceAttachments) &&
|
||||
resourceAttachments.length > 0 &&
|
||||
resolvedWorkspaceId
|
||||
) {
|
||||
const results = await Promise.allSettled(
|
||||
resourceAttachments.map(async (r) => {
|
||||
const ctx = await resolveActiveResourceContext(
|
||||
r.type,
|
||||
r.id,
|
||||
resolvedWorkspaceId!,
|
||||
authenticatedUserId,
|
||||
actualChatId
|
||||
)
|
||||
if (!ctx) return null
|
||||
return {
|
||||
...ctx,
|
||||
tag: r.active ? '@active_tab' : '@open_tab',
|
||||
}
|
||||
})
|
||||
)
|
||||
for (const result of results) {
|
||||
if (result.status === 'fulfilled' && result.value) {
|
||||
agentContexts.push(result.value)
|
||||
} else if (result.status === 'rejected') {
|
||||
logger.error(
|
||||
`[${tracker.requestId}] Failed to resolve resource attachment`,
|
||||
result.reason
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const effectiveMode = mode === 'agent' ? 'build' : mode
|
||||
|
||||
const userPermission = resolvedWorkspaceId
|
||||
@@ -291,6 +340,21 @@ export async function POST(req: NextRequest) {
|
||||
})
|
||||
} catch {}
|
||||
|
||||
if (stream && actualChatId) {
|
||||
const acquired = await acquirePendingChatStream(actualChatId, userMessageIdToUse)
|
||||
if (!acquired) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error:
|
||||
'A response is already in progress for this chat. Wait for it to finish or use Stop.',
|
||||
},
|
||||
{ status: 409 }
|
||||
)
|
||||
}
|
||||
pendingChatStreamAcquired = true
|
||||
pendingChatStreamID = userMessageIdToUse
|
||||
}
|
||||
|
||||
if (actualChatId) {
|
||||
const userMsg = {
|
||||
id: userMessageIdToUse,
|
||||
@@ -337,6 +401,7 @@ export async function POST(req: NextRequest) {
|
||||
titleProvider: provider,
|
||||
requestId: tracker.requestId,
|
||||
workspaceId: resolvedWorkspaceId,
|
||||
pendingChatStreamAlreadyRegistered: Boolean(actualChatId && stream),
|
||||
orchestrateOptions: {
|
||||
userId: authenticatedUserId,
|
||||
workflowId,
|
||||
@@ -348,6 +413,7 @@ export async function POST(req: NextRequest) {
|
||||
interactive: true,
|
||||
onComplete: async (result: OrchestratorResult) => {
|
||||
if (!actualChatId) return
|
||||
if (!result.success) return
|
||||
|
||||
const assistantMessage: Record<string, unknown> = {
|
||||
id: crypto.randomUUID(),
|
||||
@@ -423,6 +489,7 @@ export async function POST(req: NextRequest) {
|
||||
},
|
||||
},
|
||||
})
|
||||
pendingChatStreamHandedOff = true
|
||||
|
||||
return new Response(sseStream, { headers: SSE_RESPONSE_HEADERS })
|
||||
}
|
||||
@@ -528,6 +595,14 @@ export async function POST(req: NextRequest) {
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
if (
|
||||
actualChatId &&
|
||||
pendingChatStreamAcquired &&
|
||||
!pendingChatStreamHandedOff &&
|
||||
pendingChatStreamID
|
||||
) {
|
||||
await releasePendingChatStream(actualChatId, pendingChatStreamID).catch(() => {})
|
||||
}
|
||||
const duration = tracker.getDuration()
|
||||
|
||||
if (error instanceof z.ZodError) {
|
||||
|
||||
@@ -8,9 +8,9 @@ import { getSession } from '@/lib/auth'
|
||||
import { resolveOrCreateChat } from '@/lib/copilot/chat-lifecycle'
|
||||
import { buildCopilotRequestPayload } from '@/lib/copilot/chat-payload'
|
||||
import {
|
||||
acquirePendingChatStream,
|
||||
createSSEStream,
|
||||
SSE_RESPONSE_HEADERS,
|
||||
waitForPendingChatStream,
|
||||
} from '@/lib/copilot/chat-streaming'
|
||||
import type { OrchestratorResult } from '@/lib/copilot/orchestrator/types'
|
||||
import { processContextsServer, resolveActiveResourceContext } from '@/lib/copilot/process-contents'
|
||||
@@ -253,7 +253,16 @@ export async function POST(req: NextRequest) {
|
||||
)
|
||||
|
||||
if (actualChatId) {
|
||||
await waitForPendingChatStream(actualChatId)
|
||||
const acquired = await acquirePendingChatStream(actualChatId, userMessageId)
|
||||
if (!acquired) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error:
|
||||
'A response is already in progress for this chat. Wait for it to finish or use Stop.',
|
||||
},
|
||||
{ status: 409 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const executionId = crypto.randomUUID()
|
||||
@@ -271,6 +280,7 @@ export async function POST(req: NextRequest) {
|
||||
titleModel: 'claude-opus-4-6',
|
||||
requestId: tracker.requestId,
|
||||
workspaceId,
|
||||
pendingChatStreamAlreadyRegistered: Boolean(actualChatId),
|
||||
orchestrateOptions: {
|
||||
userId: authenticatedUserId,
|
||||
workspaceId,
|
||||
@@ -282,6 +292,7 @@ export async function POST(req: NextRequest) {
|
||||
interactive: true,
|
||||
onComplete: async (result: OrchestratorResult) => {
|
||||
if (!actualChatId) return
|
||||
if (!result.success) return
|
||||
|
||||
const assistantMessage: Record<string, unknown> = {
|
||||
id: crypto.randomUUID(),
|
||||
|
||||
@@ -323,8 +323,8 @@ export function useChat(
|
||||
reader: ReadableStreamDefaultReader<Uint8Array>,
|
||||
assistantId: string,
|
||||
expectedGen?: number
|
||||
) => Promise<void>
|
||||
>(async () => {})
|
||||
) => Promise<boolean>
|
||||
>(async () => false)
|
||||
const finalizeRef = useRef<(options?: { error?: boolean }) => void>(() => {})
|
||||
|
||||
const abortControllerRef = useRef<AbortController | null>(null)
|
||||
@@ -415,6 +415,8 @@ export function useChat(
|
||||
setIsReconnecting(false)
|
||||
setResources([])
|
||||
setActiveResourceId(null)
|
||||
setStreamingFile(null)
|
||||
streamingFileRef.current = null
|
||||
setMessageQueue([])
|
||||
}, [initialChatId, queryClient])
|
||||
|
||||
@@ -433,6 +435,8 @@ export function useChat(
|
||||
setIsReconnecting(false)
|
||||
setResources([])
|
||||
setActiveResourceId(null)
|
||||
setStreamingFile(null)
|
||||
streamingFileRef.current = null
|
||||
setMessageQueue([])
|
||||
}, [isHomePage])
|
||||
|
||||
@@ -441,12 +445,6 @@ export function useChat(
|
||||
|
||||
const activeStreamId = chatHistory.activeStreamId
|
||||
const snapshot = chatHistory.streamSnapshot
|
||||
|
||||
if (activeStreamId && !snapshot && !sendingRef.current) {
|
||||
queryClient.invalidateQueries({ queryKey: taskKeys.detail(chatHistory.id) })
|
||||
return
|
||||
}
|
||||
|
||||
appliedChatIdRef.current = chatHistory.id
|
||||
const mappedMessages = chatHistory.messages.map(mapStoredMessage)
|
||||
const shouldPreserveActiveStreamingMessage =
|
||||
@@ -497,7 +495,6 @@ export function useChat(
|
||||
}
|
||||
|
||||
if (activeStreamId && !sendingRef.current) {
|
||||
abortControllerRef.current?.abort()
|
||||
const gen = ++streamGenRef.current
|
||||
const abortController = new AbortController()
|
||||
abortControllerRef.current = abortController
|
||||
@@ -508,6 +505,7 @@ export function useChat(
|
||||
const assistantId = crypto.randomUUID()
|
||||
|
||||
const reconnect = async () => {
|
||||
let reconnectFailed = false
|
||||
try {
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
@@ -515,14 +513,8 @@ export function useChat(
|
||||
const streamStatus = snapshot?.status ?? ''
|
||||
|
||||
if (batchEvents.length === 0 && streamStatus === 'unknown') {
|
||||
const cid = chatIdRef.current
|
||||
if (cid) {
|
||||
fetch(stopPathRef.current, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ chatId: cid, streamId: activeStreamId, content: '' }),
|
||||
}).catch(() => {})
|
||||
}
|
||||
reconnectFailed = true
|
||||
setError(RECONNECT_TAIL_ERROR)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -550,6 +542,7 @@ export function useChat(
|
||||
{ signal: abortController.signal }
|
||||
)
|
||||
if (!sseRes.ok || !sseRes.body) {
|
||||
reconnectFailed = true
|
||||
logger.warn('SSE tail reconnect returned no readable body', {
|
||||
status: sseRes.status,
|
||||
streamId: activeStreamId,
|
||||
@@ -565,6 +558,7 @@ export function useChat(
|
||||
}
|
||||
} catch (err) {
|
||||
if (!(err instanceof Error && err.name === 'AbortError')) {
|
||||
reconnectFailed = true
|
||||
logger.warn('SSE tail failed during reconnect', err)
|
||||
setError(RECONNECT_TAIL_ERROR)
|
||||
}
|
||||
@@ -575,13 +569,21 @@ export function useChat(
|
||||
},
|
||||
})
|
||||
|
||||
await processSSEStreamRef.current(combinedStream.getReader(), assistantId, gen)
|
||||
const hadStreamError = await processSSEStreamRef.current(
|
||||
combinedStream.getReader(),
|
||||
assistantId,
|
||||
gen
|
||||
)
|
||||
if (hadStreamError) {
|
||||
reconnectFailed = true
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === 'AbortError') return
|
||||
reconnectFailed = true
|
||||
} finally {
|
||||
setIsReconnecting(false)
|
||||
if (streamGenRef.current === gen) {
|
||||
finalizeRef.current()
|
||||
finalizeRef.current(reconnectFailed ? { error: true } : undefined)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -619,7 +621,34 @@ export function useChat(
|
||||
return b
|
||||
}
|
||||
|
||||
const appendInlineErrorTag = (tag: string) => {
|
||||
if (runningText.includes(tag)) return
|
||||
const tb = ensureTextBlock()
|
||||
const prefix = runningText.length > 0 && !runningText.endsWith('\n') ? '\n' : ''
|
||||
tb.content = `${tb.content ?? ''}${prefix}${tag}`
|
||||
if (activeSubagent) tb.subagent = activeSubagent
|
||||
runningText += `${prefix}${tag}`
|
||||
streamingContentRef.current = runningText
|
||||
flush()
|
||||
}
|
||||
|
||||
const buildInlineErrorTag = (payload: SSEPayload) => {
|
||||
const data = getPayloadData(payload) as Record<string, unknown> | undefined
|
||||
const message =
|
||||
(data?.displayMessage as string | undefined) ||
|
||||
payload.error ||
|
||||
'An unexpected error occurred'
|
||||
const provider = (data?.provider as string | undefined) || undefined
|
||||
const code = (data?.code as string | undefined) || undefined
|
||||
return `<mothership-error>${JSON.stringify({
|
||||
message,
|
||||
...(code ? { code } : {}),
|
||||
...(provider ? { provider } : {}),
|
||||
})}</mothership-error>`
|
||||
}
|
||||
|
||||
const isStale = () => expectedGen !== undefined && streamGenRef.current !== expectedGen
|
||||
let sawStreamError = false
|
||||
|
||||
const flush = () => {
|
||||
if (isStale()) return
|
||||
@@ -644,12 +673,9 @@ export function useChat(
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
if (isStale()) {
|
||||
reader.cancel().catch(() => {})
|
||||
break
|
||||
}
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
if (isStale()) continue
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
const lines = buffer.split('\n')
|
||||
@@ -1113,21 +1139,20 @@ export function useChat(
|
||||
break
|
||||
}
|
||||
case 'error': {
|
||||
sawStreamError = true
|
||||
setError(parsed.error || 'An error occurred')
|
||||
appendInlineErrorTag(buildInlineErrorTag(parsed))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if (isStale()) {
|
||||
reader.cancel().catch(() => {})
|
||||
break
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (streamReaderRef.current === reader) {
|
||||
streamReaderRef.current = null
|
||||
}
|
||||
}
|
||||
return sawStreamError
|
||||
},
|
||||
[workspaceId, queryClient, addResource, removeResource]
|
||||
)
|
||||
@@ -1354,7 +1379,10 @@ export function useChat(
|
||||
|
||||
if (!response.body) throw new Error('No response body')
|
||||
|
||||
await processSSEStream(response.body.getReader(), assistantId, gen)
|
||||
const hadStreamError = await processSSEStream(response.body.getReader(), assistantId, gen)
|
||||
if (streamGenRef.current === gen) {
|
||||
finalize(hadStreamError ? { error: true } : undefined)
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === 'AbortError') return
|
||||
setError(err instanceof Error ? err.message : 'Failed to send message')
|
||||
@@ -1363,9 +1391,6 @@ export function useChat(
|
||||
}
|
||||
return
|
||||
}
|
||||
if (streamGenRef.current === gen) {
|
||||
finalize()
|
||||
}
|
||||
},
|
||||
[workspaceId, queryClient, processSSEStream, finalize]
|
||||
)
|
||||
@@ -1387,6 +1412,25 @@ export function useChat(
|
||||
sendingRef.current = false
|
||||
setIsSending(false)
|
||||
|
||||
setMessages((prev) =>
|
||||
prev.map((msg) => {
|
||||
if (!msg.contentBlocks?.some((b) => b.toolCall?.status === 'executing')) return msg
|
||||
const updated = msg.contentBlocks!.map((block) => {
|
||||
if (block.toolCall?.status !== 'executing') return block
|
||||
return {
|
||||
...block,
|
||||
toolCall: {
|
||||
...block.toolCall,
|
||||
status: 'cancelled' as const,
|
||||
displayTitle: 'Stopped by user',
|
||||
},
|
||||
}
|
||||
})
|
||||
updated.push({ type: 'stopped' as const })
|
||||
return { ...msg, contentBlocks: updated }
|
||||
})
|
||||
)
|
||||
|
||||
if (sid) {
|
||||
fetch('/api/copilot/chat/abort', {
|
||||
method: 'POST',
|
||||
@@ -1410,25 +1454,6 @@ export function useChat(
|
||||
streamingFileRef.current = null
|
||||
setResources((rs) => rs.filter((resource) => resource.id !== 'streaming-file'))
|
||||
|
||||
setMessages((prev) =>
|
||||
prev.map((msg) => {
|
||||
if (!msg.contentBlocks?.some((b) => b.toolCall?.status === 'executing')) return msg
|
||||
const updated = msg.contentBlocks!.map((block) => {
|
||||
if (block.toolCall?.status !== 'executing') return block
|
||||
return {
|
||||
...block,
|
||||
toolCall: {
|
||||
...block.toolCall,
|
||||
status: 'cancelled' as const,
|
||||
displayTitle: 'Stopped by user',
|
||||
},
|
||||
}
|
||||
})
|
||||
updated.push({ type: 'stopped' as const })
|
||||
return { ...msg, contentBlocks: updated }
|
||||
})
|
||||
)
|
||||
|
||||
const execState = useExecutionStore.getState()
|
||||
const consoleStore = useTerminalConsoleStore.getState()
|
||||
for (const [workflowId, wfExec] of execState.workflowExecutions) {
|
||||
@@ -1500,7 +1525,6 @@ export function useChat(
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
streamReaderRef.current?.cancel().catch(() => {})
|
||||
streamReaderRef.current = null
|
||||
abortControllerRef.current = null
|
||||
streamGenRef.current++
|
||||
|
||||
@@ -86,6 +86,20 @@ export async function getLatestRunForExecution(executionId: string) {
|
||||
return run ?? null
|
||||
}
|
||||
|
||||
export async function getLatestRunForStream(streamId: string, userId?: string) {
|
||||
const conditions = userId
|
||||
? and(eq(copilotRuns.streamId, streamId), eq(copilotRuns.userId, userId))
|
||||
: eq(copilotRuns.streamId, streamId)
|
||||
const [run] = await db
|
||||
.select()
|
||||
.from(copilotRuns)
|
||||
.where(conditions)
|
||||
.orderBy(desc(copilotRuns.startedAt))
|
||||
.limit(1)
|
||||
|
||||
return run ?? null
|
||||
}
|
||||
|
||||
export async function getRunSegment(runId: string) {
|
||||
const [run] = await db.select().from(copilotRuns).where(eq(copilotRuns.id, runId)).limit(1)
|
||||
return run ?? null
|
||||
@@ -121,6 +135,20 @@ export async function upsertAsyncToolCall(input: {
|
||||
status?: CopilotAsyncToolStatus
|
||||
}) {
|
||||
const existing = await getAsyncToolCall(input.toolCallId)
|
||||
const incomingStatus = input.status ?? 'pending'
|
||||
if (
|
||||
existing &&
|
||||
(isTerminalAsyncStatus(existing.status) || isDeliveredAsyncStatus(existing.status)) &&
|
||||
!isTerminalAsyncStatus(incomingStatus) &&
|
||||
!isDeliveredAsyncStatus(incomingStatus)
|
||||
) {
|
||||
logger.info('Ignoring async tool upsert that would downgrade terminal state', {
|
||||
toolCallId: input.toolCallId,
|
||||
existingStatus: existing.status,
|
||||
incomingStatus,
|
||||
})
|
||||
return existing
|
||||
}
|
||||
const effectiveRunId = input.runId ?? existing?.runId ?? null
|
||||
if (!effectiveRunId) {
|
||||
logger.warn('upsertAsyncToolCall missing runId and no existing row', {
|
||||
@@ -140,7 +168,7 @@ export async function upsertAsyncToolCall(input: {
|
||||
toolCallId: input.toolCallId,
|
||||
toolName: input.toolName,
|
||||
args: input.args ?? {},
|
||||
status: input.status ?? 'pending',
|
||||
status: incomingStatus,
|
||||
updatedAt: now,
|
||||
})
|
||||
.onConflictDoUpdate({
|
||||
@@ -150,7 +178,7 @@ export async function upsertAsyncToolCall(input: {
|
||||
checkpointId: input.checkpointId ?? null,
|
||||
toolName: input.toolName,
|
||||
args: input.args ?? {},
|
||||
status: input.status ?? 'pending',
|
||||
status: incomingStatus,
|
||||
updatedAt: now,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -8,14 +8,19 @@ import type { OrchestrateStreamOptions } from '@/lib/copilot/orchestrator'
|
||||
import { orchestrateCopilotStream } from '@/lib/copilot/orchestrator'
|
||||
import {
|
||||
createStreamEventWriter,
|
||||
getStreamMeta,
|
||||
resetStreamBuffer,
|
||||
setStreamMeta,
|
||||
} from '@/lib/copilot/orchestrator/stream/buffer'
|
||||
import { taskPubSub } from '@/lib/copilot/task-events'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { acquireLock, getRedisClient, releaseLock } from '@/lib/core/config/redis'
|
||||
import { SSE_HEADERS } from '@/lib/core/utils/sse'
|
||||
|
||||
const logger = createLogger('CopilotChatStreaming')
|
||||
const CHAT_STREAM_LOCK_TTL_SECONDS = 2 * 60 * 60
|
||||
const STREAM_ABORT_TTL_SECONDS = 10 * 60
|
||||
const STREAM_ABORT_POLL_MS = 1000
|
||||
|
||||
// Registry of in-flight Sim→Go streams so the explicit abort endpoint can
|
||||
// reach them. Keyed by streamId, cleaned up when the stream completes.
|
||||
@@ -48,25 +53,138 @@ function resolvePendingChatStream(chatId: string, streamId: string): void {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Abort any in-flight stream on `chatId` and wait for it to fully settle
|
||||
* (including onComplete and Go-side persistence). Returns immediately if
|
||||
* no stream is active. Gives up after `timeoutMs`.
|
||||
*/
|
||||
export async function waitForPendingChatStream(chatId: string, timeoutMs = 5_000): Promise<void> {
|
||||
const entry = pendingChatStreams.get(chatId)
|
||||
if (!entry) return
|
||||
|
||||
// Force-abort the previous stream so we don't passively wait for it to
|
||||
// finish naturally (which could take tens of seconds for a subagent).
|
||||
abortActiveStream(entry.streamId)
|
||||
|
||||
await Promise.race([entry.promise, new Promise<void>((r) => setTimeout(r, timeoutMs))])
|
||||
function getChatStreamLockKey(chatId: string): string {
|
||||
return `copilot:chat-stream-lock:${chatId}`
|
||||
}
|
||||
|
||||
export function abortActiveStream(streamId: string): boolean {
|
||||
function getStreamAbortKey(streamId: string): string {
|
||||
return `copilot:stream-abort:${streamId}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for any in-flight stream on `chatId` to settle without force-aborting it.
|
||||
* Returns true when no stream is active (or it settles in time), false on timeout.
|
||||
*/
|
||||
export async function waitForPendingChatStream(
|
||||
chatId: string,
|
||||
timeoutMs = 5_000,
|
||||
expectedStreamId?: string
|
||||
): Promise<boolean> {
|
||||
const redis = getRedisClient()
|
||||
const deadline = Date.now() + timeoutMs
|
||||
|
||||
for (;;) {
|
||||
const entry = pendingChatStreams.get(chatId)
|
||||
const localPending = !!entry && (!expectedStreamId || entry.streamId === expectedStreamId)
|
||||
|
||||
if (redis) {
|
||||
try {
|
||||
const ownerStreamId = await redis.get(getChatStreamLockKey(chatId))
|
||||
const lockReleased =
|
||||
!ownerStreamId || (expectedStreamId !== undefined && ownerStreamId !== expectedStreamId)
|
||||
if (!localPending && lockReleased) {
|
||||
return true
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Failed to check distributed chat stream lock while waiting', {
|
||||
chatId,
|
||||
expectedStreamId,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
}
|
||||
} else if (!localPending) {
|
||||
return true
|
||||
}
|
||||
|
||||
if (Date.now() >= deadline) return false
|
||||
await new Promise((resolve) => setTimeout(resolve, 200))
|
||||
}
|
||||
}
|
||||
|
||||
export async function releasePendingChatStream(chatId: string, streamId: string): Promise<void> {
|
||||
const redis = getRedisClient()
|
||||
if (redis) {
|
||||
await releaseLock(getChatStreamLockKey(chatId), streamId).catch(() => false)
|
||||
}
|
||||
resolvePendingChatStream(chatId, streamId)
|
||||
}
|
||||
|
||||
export async function acquirePendingChatStream(
|
||||
chatId: string,
|
||||
streamId: string,
|
||||
timeoutMs = 5_000
|
||||
): Promise<boolean> {
|
||||
const redis = getRedisClient()
|
||||
if (redis) {
|
||||
const deadline = Date.now() + timeoutMs
|
||||
for (;;) {
|
||||
try {
|
||||
const acquired = await acquireLock(
|
||||
getChatStreamLockKey(chatId),
|
||||
streamId,
|
||||
CHAT_STREAM_LOCK_TTL_SECONDS
|
||||
)
|
||||
if (acquired) {
|
||||
registerPendingChatStream(chatId, streamId)
|
||||
return true
|
||||
}
|
||||
if (!pendingChatStreams.has(chatId)) {
|
||||
const ownerStreamId = await redis.get(getChatStreamLockKey(chatId))
|
||||
if (ownerStreamId) {
|
||||
const ownerMeta = await getStreamMeta(ownerStreamId)
|
||||
const ownerTerminal =
|
||||
ownerMeta?.status === 'complete' ||
|
||||
ownerMeta?.status === 'error' ||
|
||||
ownerMeta?.status === 'cancelled'
|
||||
if (ownerTerminal) {
|
||||
await releaseLock(getChatStreamLockKey(chatId), ownerStreamId).catch(() => false)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Distributed chat stream lock failed; retrying distributed coordination', {
|
||||
chatId,
|
||||
streamId,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
}
|
||||
if (Date.now() >= deadline) return false
|
||||
await new Promise((resolve) => setTimeout(resolve, 200))
|
||||
}
|
||||
}
|
||||
|
||||
for (;;) {
|
||||
const existing = pendingChatStreams.get(chatId)
|
||||
if (!existing) {
|
||||
registerPendingChatStream(chatId, streamId)
|
||||
return true
|
||||
}
|
||||
|
||||
const settled = await Promise.race([
|
||||
existing.promise.then(() => true),
|
||||
new Promise<boolean>((r) => setTimeout(() => r(false), timeoutMs)),
|
||||
])
|
||||
if (!settled) return false
|
||||
}
|
||||
}
|
||||
|
||||
export async function abortActiveStream(streamId: string): Promise<boolean> {
|
||||
const redis = getRedisClient()
|
||||
let published = false
|
||||
if (redis) {
|
||||
try {
|
||||
await redis.set(getStreamAbortKey(streamId), '1', 'EX', STREAM_ABORT_TTL_SECONDS)
|
||||
published = true
|
||||
} catch (error) {
|
||||
logger.warn('Failed to publish distributed stream abort', {
|
||||
streamId,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
}
|
||||
}
|
||||
const controller = activeStreams.get(streamId)
|
||||
if (!controller) return false
|
||||
if (!controller) return published
|
||||
controller.abort()
|
||||
activeStreams.delete(streamId)
|
||||
return true
|
||||
@@ -135,6 +253,7 @@ export interface StreamingOrchestrationParams {
|
||||
requestId: string
|
||||
workspaceId?: string
|
||||
orchestrateOptions: Omit<OrchestrateStreamOptions, 'onEvent'>
|
||||
pendingChatStreamAlreadyRegistered?: boolean
|
||||
}
|
||||
|
||||
export function createSSEStream(params: StreamingOrchestrationParams): ReadableStream {
|
||||
@@ -153,6 +272,7 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
|
||||
requestId,
|
||||
workspaceId,
|
||||
orchestrateOptions,
|
||||
pendingChatStreamAlreadyRegistered = false,
|
||||
} = params
|
||||
|
||||
let eventWriter: ReturnType<typeof createStreamEventWriter> | null = null
|
||||
@@ -160,7 +280,7 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
|
||||
const abortController = new AbortController()
|
||||
activeStreams.set(streamId, abortController)
|
||||
|
||||
if (chatId) {
|
||||
if (chatId && !pendingChatStreamAlreadyRegistered) {
|
||||
registerPendingChatStream(chatId, streamId)
|
||||
}
|
||||
|
||||
@@ -191,14 +311,47 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
|
||||
eventWriter = createStreamEventWriter(streamId)
|
||||
|
||||
let localSeq = 0
|
||||
let abortPoller: ReturnType<typeof setInterval> | null = null
|
||||
|
||||
const redis = getRedisClient()
|
||||
if (redis) {
|
||||
abortPoller = setInterval(() => {
|
||||
void (async () => {
|
||||
try {
|
||||
const shouldAbort = await redis.get(getStreamAbortKey(streamId))
|
||||
if (shouldAbort && !abortController.signal.aborted) {
|
||||
abortController.abort()
|
||||
await redis.del(getStreamAbortKey(streamId))
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to poll distributed stream abort`, {
|
||||
streamId,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
}
|
||||
})()
|
||||
}, STREAM_ABORT_POLL_MS)
|
||||
}
|
||||
|
||||
const pushEvent = async (event: Record<string, any>) => {
|
||||
if (!eventWriter) return
|
||||
|
||||
const eventId = ++localSeq
|
||||
|
||||
// Enqueue to client stream FIRST for minimal latency.
|
||||
// Redis persistence happens after so the client never waits on I/O.
|
||||
try {
|
||||
await eventWriter.write(event)
|
||||
if (FLUSH_EVENT_TYPES.has(event.type)) {
|
||||
await eventWriter.flush()
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Failed to persist stream event`, {
|
||||
eventType: event.type,
|
||||
eventId,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
// Keep the live SSE stream going even if durable buffering hiccups.
|
||||
}
|
||||
|
||||
try {
|
||||
if (!clientDisconnected) {
|
||||
controller.enqueue(
|
||||
@@ -208,16 +361,16 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
|
||||
} catch {
|
||||
clientDisconnected = true
|
||||
}
|
||||
}
|
||||
|
||||
const pushEventBestEffort = async (event: Record<string, any>) => {
|
||||
try {
|
||||
await eventWriter.write(event)
|
||||
if (FLUSH_EVENT_TYPES.has(event.type)) {
|
||||
await eventWriter.flush()
|
||||
}
|
||||
} catch {
|
||||
if (clientDisconnected) {
|
||||
await eventWriter.flush().catch(() => {})
|
||||
}
|
||||
await pushEvent(event)
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Failed to push event`, {
|
||||
eventType: event.type,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,7 +437,7 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
|
||||
logger.error(`[${requestId}] Orchestration returned failure`, {
|
||||
error: errorMessage,
|
||||
})
|
||||
await pushEvent({
|
||||
await pushEventBestEffort({
|
||||
type: 'error',
|
||||
error: errorMessage,
|
||||
data: {
|
||||
@@ -324,7 +477,7 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
|
||||
}
|
||||
logger.error(`[${requestId}] Orchestration error:`, error)
|
||||
const errorMessage = error instanceof Error ? error.message : 'Stream error'
|
||||
await pushEvent({
|
||||
await pushEventBestEffort({
|
||||
type: 'error',
|
||||
error: errorMessage,
|
||||
data: {
|
||||
@@ -345,10 +498,19 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
|
||||
}).catch(() => {})
|
||||
} finally {
|
||||
clearInterval(keepaliveInterval)
|
||||
if (abortPoller) {
|
||||
clearInterval(abortPoller)
|
||||
}
|
||||
activeStreams.delete(streamId)
|
||||
if (chatId) {
|
||||
if (redis) {
|
||||
await releaseLock(getChatStreamLockKey(chatId), streamId).catch(() => false)
|
||||
}
|
||||
resolvePendingChatStream(chatId, streamId)
|
||||
}
|
||||
if (redis) {
|
||||
await redis.del(getStreamAbortKey(streamId)).catch(() => {})
|
||||
}
|
||||
try {
|
||||
controller.close()
|
||||
} catch {
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
/**
|
||||
* @vitest-environment node
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import type { OrchestratorOptions } from './types'
|
||||
|
||||
const {
|
||||
prepareExecutionContext,
|
||||
getEffectiveDecryptedEnv,
|
||||
runStreamLoop,
|
||||
claimCompletedAsyncToolCall,
|
||||
getAsyncToolCall,
|
||||
getAsyncToolCalls,
|
||||
markAsyncToolDelivered,
|
||||
releaseCompletedAsyncToolClaim,
|
||||
updateRunStatus,
|
||||
} = vi.hoisted(() => ({
|
||||
prepareExecutionContext: vi.fn(),
|
||||
getEffectiveDecryptedEnv: vi.fn(),
|
||||
runStreamLoop: vi.fn(),
|
||||
claimCompletedAsyncToolCall: vi.fn(),
|
||||
getAsyncToolCall: vi.fn(),
|
||||
getAsyncToolCalls: vi.fn(),
|
||||
markAsyncToolDelivered: vi.fn(),
|
||||
releaseCompletedAsyncToolClaim: vi.fn(),
|
||||
updateRunStatus: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/copilot/orchestrator/tool-executor', () => ({
|
||||
prepareExecutionContext,
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/environment/utils', () => ({
|
||||
getEffectiveDecryptedEnv,
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/copilot/async-runs/repository', () => ({
|
||||
claimCompletedAsyncToolCall,
|
||||
getAsyncToolCall,
|
||||
getAsyncToolCalls,
|
||||
markAsyncToolDelivered,
|
||||
releaseCompletedAsyncToolClaim,
|
||||
updateRunStatus,
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/copilot/orchestrator/stream/core', async () => {
|
||||
const actual = await vi.importActual<typeof import('./stream/core')>('./stream/core')
|
||||
return {
|
||||
...actual,
|
||||
buildToolCallSummaries: vi.fn(() => []),
|
||||
runStreamLoop,
|
||||
}
|
||||
})
|
||||
|
||||
import { orchestrateCopilotStream } from './index'
|
||||
|
||||
describe('orchestrateCopilotStream async continuation', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
prepareExecutionContext.mockResolvedValue({
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
})
|
||||
getEffectiveDecryptedEnv.mockResolvedValue({})
|
||||
claimCompletedAsyncToolCall.mockResolvedValue({ toolCallId: 'tool-1' })
|
||||
getAsyncToolCall.mockResolvedValue({
|
||||
toolCallId: 'tool-1',
|
||||
toolName: 'read',
|
||||
status: 'completed',
|
||||
result: { ok: true },
|
||||
error: null,
|
||||
})
|
||||
getAsyncToolCalls.mockResolvedValue([
|
||||
{
|
||||
toolCallId: 'tool-1',
|
||||
toolName: 'read',
|
||||
status: 'completed',
|
||||
result: { ok: true },
|
||||
error: null,
|
||||
},
|
||||
])
|
||||
markAsyncToolDelivered.mockResolvedValue(null)
|
||||
releaseCompletedAsyncToolClaim.mockResolvedValue(null)
|
||||
updateRunStatus.mockResolvedValue(null)
|
||||
})
|
||||
|
||||
it('builds resume payloads with success=true for claimed completed rows', async () => {
|
||||
runStreamLoop
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.awaitingAsyncContinuation = {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
pendingToolCallIds: ['tool-1'],
|
||||
}
|
||||
})
|
||||
.mockImplementationOnce(async (url: string, opts: RequestInit) => {
|
||||
expect(url).toContain('/api/tools/resume')
|
||||
const body = JSON.parse(String(opts.body))
|
||||
expect(body).toEqual({
|
||||
checkpointId: 'checkpoint-1',
|
||||
results: [
|
||||
{
|
||||
callId: 'tool-1',
|
||||
name: 'read',
|
||||
data: { ok: true },
|
||||
success: true,
|
||||
},
|
||||
],
|
||||
})
|
||||
})
|
||||
|
||||
const result = await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(markAsyncToolDelivered).toHaveBeenCalledWith('tool-1')
|
||||
})
|
||||
|
||||
it('marks claimed tool calls delivered even when the resumed stream later records errors', async () => {
|
||||
runStreamLoop
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.awaitingAsyncContinuation = {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
pendingToolCallIds: ['tool-1'],
|
||||
}
|
||||
})
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.errors.push('resume stream failed after handoff')
|
||||
})
|
||||
|
||||
const result = await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(markAsyncToolDelivered).toHaveBeenCalledWith('tool-1')
|
||||
})
|
||||
|
||||
it('forwards done events while still marking async pauses on the run', async () => {
|
||||
const onEvent = vi.fn()
|
||||
const streamOptions: OrchestratorOptions = { onEvent }
|
||||
runStreamLoop.mockImplementationOnce(
|
||||
async (_url: string, _opts: RequestInit, _context: any, _exec: any, loopOptions: any) => {
|
||||
await loopOptions.onEvent({
|
||||
type: 'done',
|
||||
data: {
|
||||
response: {
|
||||
async_pause: {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
...streamOptions,
|
||||
}
|
||||
)
|
||||
|
||||
expect(onEvent).toHaveBeenCalledWith(expect.objectContaining({ type: 'done' }))
|
||||
expect(updateRunStatus).toHaveBeenCalledWith('run-1', 'paused_waiting_for_tool')
|
||||
})
|
||||
|
||||
it('waits for a local running tool before retrying the claim', async () => {
|
||||
const localPendingPromise = Promise.resolve({
|
||||
status: 'success',
|
||||
data: { ok: true },
|
||||
})
|
||||
|
||||
claimCompletedAsyncToolCall
|
||||
.mockResolvedValueOnce(null)
|
||||
.mockResolvedValueOnce({ toolCallId: 'tool-1' })
|
||||
getAsyncToolCall
|
||||
.mockResolvedValueOnce({
|
||||
toolCallId: 'tool-1',
|
||||
toolName: 'read',
|
||||
status: 'running',
|
||||
result: null,
|
||||
error: null,
|
||||
})
|
||||
.mockResolvedValue({
|
||||
toolCallId: 'tool-1',
|
||||
toolName: 'read',
|
||||
status: 'completed',
|
||||
result: { ok: true },
|
||||
error: null,
|
||||
})
|
||||
|
||||
runStreamLoop
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.awaitingAsyncContinuation = {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
pendingToolCallIds: ['tool-1'],
|
||||
}
|
||||
context.pendingToolPromises.set('tool-1', localPendingPromise)
|
||||
})
|
||||
.mockImplementationOnce(async (url: string, opts: RequestInit) => {
|
||||
expect(url).toContain('/api/tools/resume')
|
||||
const body = JSON.parse(String(opts.body))
|
||||
expect(body.results[0]).toEqual({
|
||||
callId: 'tool-1',
|
||||
name: 'read',
|
||||
data: { ok: true },
|
||||
success: true,
|
||||
})
|
||||
})
|
||||
|
||||
const result = await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(runStreamLoop).toHaveBeenCalledTimes(2)
|
||||
expect(markAsyncToolDelivered).toHaveBeenCalledWith('tool-1')
|
||||
})
|
||||
|
||||
it('releases claimed rows if the resume stream throws before delivery is marked', async () => {
|
||||
runStreamLoop
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.awaitingAsyncContinuation = {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
pendingToolCallIds: ['tool-1'],
|
||||
}
|
||||
})
|
||||
.mockImplementationOnce(async () => {
|
||||
throw new Error('resume failed')
|
||||
})
|
||||
|
||||
const result = await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(releaseCompletedAsyncToolClaim).toHaveBeenCalledWith('tool-1', 'run-1')
|
||||
expect(markAsyncToolDelivered).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -14,12 +14,17 @@ import {
|
||||
updateRunStatus,
|
||||
} from '@/lib/copilot/async-runs/repository'
|
||||
import { SIM_AGENT_API_URL, SIM_AGENT_VERSION } from '@/lib/copilot/constants'
|
||||
import { prepareExecutionContext } from '@/lib/copilot/orchestrator/tool-executor'
|
||||
import type {
|
||||
ExecutionContext,
|
||||
OrchestratorOptions,
|
||||
OrchestratorResult,
|
||||
SSEEvent,
|
||||
import {
|
||||
isToolAvailableOnSimSide,
|
||||
prepareExecutionContext,
|
||||
} from '@/lib/copilot/orchestrator/tool-executor'
|
||||
import {
|
||||
type ExecutionContext,
|
||||
isTerminalToolCallStatus,
|
||||
type OrchestratorOptions,
|
||||
type OrchestratorResult,
|
||||
type SSEEvent,
|
||||
type ToolCallState,
|
||||
} from '@/lib/copilot/orchestrator/types'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { getEffectiveDecryptedEnv } from '@/lib/environment/utils'
|
||||
@@ -31,18 +36,9 @@ function didAsyncToolSucceed(input: {
|
||||
durableStatus?: string | null
|
||||
durableResult?: Record<string, unknown>
|
||||
durableError?: string | null
|
||||
completion?: { status: string } | undefined
|
||||
toolStateSuccess?: boolean | undefined
|
||||
toolStateStatus?: string | undefined
|
||||
}) {
|
||||
const {
|
||||
durableStatus,
|
||||
durableResult,
|
||||
durableError,
|
||||
completion,
|
||||
toolStateSuccess,
|
||||
toolStateStatus,
|
||||
} = input
|
||||
const { durableStatus, durableResult, durableError, toolStateStatus } = input
|
||||
|
||||
if (durableStatus === ASYNC_TOOL_STATUS.completed) {
|
||||
return true
|
||||
@@ -61,7 +57,15 @@ function didAsyncToolSucceed(input: {
|
||||
if (toolStateStatus === 'success') return true
|
||||
if (toolStateStatus === 'error' || toolStateStatus === 'cancelled') return false
|
||||
|
||||
return completion?.status === 'success' || toolStateSuccess === true
|
||||
return false
|
||||
}
|
||||
|
||||
interface ReadyContinuationTool {
|
||||
toolCallId: string
|
||||
toolState?: ToolCallState
|
||||
durableRow?: Awaited<ReturnType<typeof getAsyncToolCall>>
|
||||
needsDurableClaim: boolean
|
||||
alreadyClaimedByWorker: boolean
|
||||
}
|
||||
|
||||
export interface OrchestrateStreamOptions extends OrchestratorOptions {
|
||||
@@ -190,32 +194,21 @@ export async function orchestrateCopilotStream(
|
||||
if (!continuation) break
|
||||
|
||||
let resumeReady = false
|
||||
let resumeRetries = 0
|
||||
for (;;) {
|
||||
claimedToolCallIds = []
|
||||
claimedByWorkerId = null
|
||||
const resumeWorkerId = continuation.runId || context.runId || context.messageId
|
||||
claimedByWorkerId = resumeWorkerId
|
||||
const claimableToolCallIds: string[] = []
|
||||
const readyTools: ReadyContinuationTool[] = []
|
||||
const localPendingPromises: Promise<unknown>[] = []
|
||||
const missingToolCallIds: string[] = []
|
||||
|
||||
for (const toolCallId of continuation.pendingToolCallIds) {
|
||||
const claimed = await claimCompletedAsyncToolCall(toolCallId, resumeWorkerId).catch(
|
||||
() => null
|
||||
)
|
||||
if (claimed) {
|
||||
claimableToolCallIds.push(toolCallId)
|
||||
claimedToolCallIds.push(toolCallId)
|
||||
continue
|
||||
}
|
||||
const durableRow = await getAsyncToolCall(toolCallId).catch(() => null)
|
||||
const localPendingPromise = context.pendingToolPromises.get(toolCallId)
|
||||
if (!durableRow && localPendingPromise) {
|
||||
claimableToolCallIds.push(toolCallId)
|
||||
continue
|
||||
}
|
||||
if (
|
||||
durableRow &&
|
||||
durableRow.status === ASYNC_TOOL_STATUS.running &&
|
||||
localPendingPromise
|
||||
) {
|
||||
const toolState = context.toolCalls.get(toolCallId)
|
||||
|
||||
if (localPendingPromise) {
|
||||
localPendingPromises.push(localPendingPromise)
|
||||
logger.info('Waiting for local async tool completion before retrying resume claim', {
|
||||
toolCallId,
|
||||
@@ -223,21 +216,55 @@ export async function orchestrateCopilotStream(
|
||||
})
|
||||
continue
|
||||
}
|
||||
const toolState = context.toolCalls.get(toolCallId)
|
||||
if (!durableRow && !localPendingPromise && toolState) {
|
||||
|
||||
if (durableRow && isTerminalAsyncStatus(durableRow.status)) {
|
||||
if (durableRow.claimedBy && durableRow.claimedBy !== resumeWorkerId) {
|
||||
missingToolCallIds.push(toolCallId)
|
||||
logger.warn('Async tool continuation is waiting on a claim held by another worker', {
|
||||
toolCallId,
|
||||
runId: continuation.runId,
|
||||
claimedBy: durableRow.claimedBy,
|
||||
})
|
||||
continue
|
||||
}
|
||||
readyTools.push({
|
||||
toolCallId,
|
||||
toolState,
|
||||
durableRow,
|
||||
needsDurableClaim: durableRow.claimedBy !== resumeWorkerId,
|
||||
alreadyClaimedByWorker: durableRow.claimedBy === resumeWorkerId,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if (
|
||||
!durableRow &&
|
||||
toolState &&
|
||||
isTerminalToolCallStatus(toolState.status) &&
|
||||
!isToolAvailableOnSimSide(toolState.name)
|
||||
) {
|
||||
logger.info('Including Go-handled tool in resume payload (no Sim-side row)', {
|
||||
toolCallId,
|
||||
toolName: toolState.name,
|
||||
status: toolState.status,
|
||||
runId: continuation.runId,
|
||||
})
|
||||
claimableToolCallIds.push(toolCallId)
|
||||
readyTools.push({
|
||||
toolCallId,
|
||||
toolState,
|
||||
needsDurableClaim: false,
|
||||
alreadyClaimedByWorker: false,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.warn('Skipping already-claimed or missing async tool resume', {
|
||||
toolCallId,
|
||||
runId: continuation.runId,
|
||||
durableStatus: durableRow?.status,
|
||||
toolStateStatus: toolState?.status,
|
||||
})
|
||||
missingToolCallIds.push(toolCallId)
|
||||
}
|
||||
|
||||
if (localPendingPromises.length > 0) {
|
||||
@@ -245,30 +272,104 @@ export async function orchestrateCopilotStream(
|
||||
continue
|
||||
}
|
||||
|
||||
if (claimableToolCallIds.length === 0) {
|
||||
logger.warn('Skipping async resume because no tool calls were claimable', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
})
|
||||
context.awaitingAsyncContinuation = undefined
|
||||
break
|
||||
if (missingToolCallIds.length > 0) {
|
||||
if (resumeRetries < 3) {
|
||||
resumeRetries++
|
||||
logger.info('Retrying async resume after some tool calls were not yet ready', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
retry: resumeRetries,
|
||||
missingToolCallIds,
|
||||
})
|
||||
await new Promise((resolve) => setTimeout(resolve, 250 * resumeRetries))
|
||||
continue
|
||||
}
|
||||
throw new Error(
|
||||
`Failed to resume async tool continuation: pending tool calls were not ready (${missingToolCallIds.join(', ')})`
|
||||
)
|
||||
}
|
||||
|
||||
if (readyTools.length === 0) {
|
||||
if (resumeRetries < 3 && continuation.pendingToolCallIds.length > 0) {
|
||||
resumeRetries++
|
||||
logger.info('Retrying async resume because no tool calls were ready yet', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
retry: resumeRetries,
|
||||
})
|
||||
await new Promise((resolve) => setTimeout(resolve, 250 * resumeRetries))
|
||||
continue
|
||||
}
|
||||
throw new Error('Failed to resume async tool continuation: no tool calls were ready')
|
||||
}
|
||||
|
||||
const claimCandidates = readyTools.filter((tool) => tool.needsDurableClaim)
|
||||
const newlyClaimedToolCallIds: string[] = []
|
||||
const claimFailures: string[] = []
|
||||
|
||||
for (const tool of claimCandidates) {
|
||||
const claimed = await claimCompletedAsyncToolCall(tool.toolCallId, resumeWorkerId).catch(
|
||||
() => null
|
||||
)
|
||||
if (!claimed) {
|
||||
claimFailures.push(tool.toolCallId)
|
||||
continue
|
||||
}
|
||||
newlyClaimedToolCallIds.push(tool.toolCallId)
|
||||
}
|
||||
|
||||
if (claimFailures.length > 0) {
|
||||
if (newlyClaimedToolCallIds.length > 0) {
|
||||
logger.info('Releasing async tool claims after claim contention during resume', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
newlyClaimedToolCallIds,
|
||||
claimFailures,
|
||||
})
|
||||
await Promise.all(
|
||||
newlyClaimedToolCallIds.map((toolCallId) =>
|
||||
releaseCompletedAsyncToolClaim(toolCallId, resumeWorkerId).catch(() => null)
|
||||
)
|
||||
)
|
||||
}
|
||||
if (resumeRetries < 3) {
|
||||
resumeRetries++
|
||||
logger.info('Retrying async resume after claim contention', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
retry: resumeRetries,
|
||||
claimFailures,
|
||||
})
|
||||
await new Promise((resolve) => setTimeout(resolve, 250 * resumeRetries))
|
||||
continue
|
||||
}
|
||||
throw new Error(
|
||||
`Failed to resume async tool continuation: unable to claim tool calls (${claimFailures.join(', ')})`
|
||||
)
|
||||
}
|
||||
|
||||
claimedToolCallIds = [
|
||||
...readyTools
|
||||
.filter((tool) => tool.alreadyClaimedByWorker)
|
||||
.map((tool) => tool.toolCallId),
|
||||
...newlyClaimedToolCallIds,
|
||||
]
|
||||
claimedByWorkerId = claimedToolCallIds.length > 0 ? resumeWorkerId : null
|
||||
|
||||
logger.info('Resuming async tool continuation', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
toolCallIds: claimableToolCallIds,
|
||||
toolCallIds: readyTools.map((tool) => tool.toolCallId),
|
||||
})
|
||||
|
||||
const durableRows = await getAsyncToolCalls(claimableToolCallIds).catch(() => [])
|
||||
const durableRows = await getAsyncToolCalls(
|
||||
readyTools.map((tool) => tool.toolCallId)
|
||||
).catch(() => [])
|
||||
const durableByToolCallId = new Map(durableRows.map((row) => [row.toolCallId, row]))
|
||||
|
||||
const results = await Promise.all(
|
||||
claimableToolCallIds.map(async (toolCallId) => {
|
||||
const completion = await context.pendingToolPromises.get(toolCallId)
|
||||
const toolState = context.toolCalls.get(toolCallId)
|
||||
|
||||
const durable = durableByToolCallId.get(toolCallId)
|
||||
readyTools.map(async (tool) => {
|
||||
const durable = durableByToolCallId.get(tool.toolCallId) || tool.durableRow
|
||||
const durableStatus = durable?.status
|
||||
const durableResult =
|
||||
durable?.result && typeof durable.result === 'object'
|
||||
@@ -278,19 +379,15 @@ export async function orchestrateCopilotStream(
|
||||
durableStatus,
|
||||
durableResult,
|
||||
durableError: durable?.error,
|
||||
completion,
|
||||
toolStateSuccess: toolState?.result?.success,
|
||||
toolStateStatus: toolState?.status,
|
||||
toolStateStatus: tool.toolState?.status,
|
||||
})
|
||||
const data =
|
||||
durableResult ||
|
||||
completion?.data ||
|
||||
(toolState?.result?.output as Record<string, unknown> | undefined) ||
|
||||
(tool.toolState?.result?.output as Record<string, unknown> | undefined) ||
|
||||
(success
|
||||
? { message: completion?.message || 'Tool completed' }
|
||||
? { message: 'Tool completed' }
|
||||
: {
|
||||
error:
|
||||
completion?.message || durable?.error || toolState?.error || 'Tool failed',
|
||||
error: durable?.error || tool.toolState?.error || 'Tool failed',
|
||||
})
|
||||
|
||||
if (
|
||||
@@ -299,14 +396,14 @@ export async function orchestrateCopilotStream(
|
||||
!isDeliveredAsyncStatus(durableStatus)
|
||||
) {
|
||||
logger.warn('Async tool row was claimed for resume without terminal durable state', {
|
||||
toolCallId,
|
||||
toolCallId: tool.toolCallId,
|
||||
status: durableStatus,
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
callId: toolCallId,
|
||||
name: durable?.toolName || toolState?.name || '',
|
||||
callId: tool.toolCallId,
|
||||
name: durable?.toolName || tool.toolState?.name || '',
|
||||
data,
|
||||
success,
|
||||
}
|
||||
|
||||
@@ -209,4 +209,76 @@ describe('sse-handlers tool lifecycle', () => {
|
||||
expect(markToolComplete).toHaveBeenCalledTimes(1)
|
||||
expect(context.toolCalls.get('tool-upsert-fail')?.status).toBe('success')
|
||||
})
|
||||
|
||||
it('does not execute a tool if a terminal tool_result arrives before local execution starts', async () => {
|
||||
let resolveUpsert: ((value: null) => void) | undefined
|
||||
upsertAsyncToolCall.mockImplementationOnce(
|
||||
() =>
|
||||
new Promise((resolve) => {
|
||||
resolveUpsert = resolve
|
||||
})
|
||||
)
|
||||
const onEvent = vi.fn()
|
||||
|
||||
await sseHandlers.tool_call(
|
||||
{
|
||||
type: 'tool_call',
|
||||
data: { id: 'tool-race', name: 'read', arguments: { workflowId: 'workflow-1' } },
|
||||
} as any,
|
||||
context,
|
||||
execContext,
|
||||
{ onEvent, interactive: false, timeout: 1000 }
|
||||
)
|
||||
|
||||
await sseHandlers.tool_result(
|
||||
{
|
||||
type: 'tool_result',
|
||||
toolCallId: 'tool-race',
|
||||
data: { id: 'tool-race', success: true, result: { ok: true } },
|
||||
} as any,
|
||||
context,
|
||||
execContext,
|
||||
{ onEvent, interactive: false, timeout: 1000 }
|
||||
)
|
||||
|
||||
resolveUpsert?.(null)
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
|
||||
expect(executeToolServerSide).not.toHaveBeenCalled()
|
||||
expect(markToolComplete).not.toHaveBeenCalled()
|
||||
expect(context.toolCalls.get('tool-race')?.status).toBe('success')
|
||||
expect(context.toolCalls.get('tool-race')?.result?.output).toEqual({ ok: true })
|
||||
})
|
||||
|
||||
it('does not execute a tool if a tool_result arrives before the tool_call event', async () => {
|
||||
const onEvent = vi.fn()
|
||||
|
||||
await sseHandlers.tool_result(
|
||||
{
|
||||
type: 'tool_result',
|
||||
toolCallId: 'tool-early-result',
|
||||
toolName: 'read',
|
||||
data: { id: 'tool-early-result', name: 'read', success: true, result: { ok: true } },
|
||||
} as any,
|
||||
context,
|
||||
execContext,
|
||||
{ onEvent, interactive: false, timeout: 1000 }
|
||||
)
|
||||
|
||||
await sseHandlers.tool_call(
|
||||
{
|
||||
type: 'tool_call',
|
||||
data: { id: 'tool-early-result', name: 'read', arguments: { workflowId: 'workflow-1' } },
|
||||
} as any,
|
||||
context,
|
||||
execContext,
|
||||
{ onEvent, interactive: false, timeout: 1000 }
|
||||
)
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
|
||||
expect(executeToolServerSide).not.toHaveBeenCalled()
|
||||
expect(markToolComplete).not.toHaveBeenCalled()
|
||||
expect(context.toolCalls.get('tool-early-result')?.status).toBe('success')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -213,6 +213,27 @@ function inferToolSuccess(data: Record<string, unknown> | undefined): {
|
||||
return { success, hasResultData, hasError }
|
||||
}
|
||||
|
||||
function ensureTerminalToolCallState(
|
||||
context: StreamingContext,
|
||||
toolCallId: string,
|
||||
toolName: string
|
||||
): ToolCallState {
|
||||
const existing = context.toolCalls.get(toolCallId)
|
||||
if (existing) {
|
||||
return existing
|
||||
}
|
||||
|
||||
const toolCall: ToolCallState = {
|
||||
id: toolCallId,
|
||||
name: toolName || 'unknown_tool',
|
||||
status: 'pending',
|
||||
startTime: Date.now(),
|
||||
}
|
||||
context.toolCalls.set(toolCallId, toolCall)
|
||||
addContentBlock(context, { type: 'tool_call', toolCall })
|
||||
return toolCall
|
||||
}
|
||||
|
||||
export type SSEHandler = (
|
||||
event: SSEEvent,
|
||||
context: StreamingContext,
|
||||
@@ -246,8 +267,12 @@ export const sseHandlers: Record<string, SSEHandler> = {
|
||||
const data = getEventData(event)
|
||||
const toolCallId = event.toolCallId || (data?.id as string | undefined)
|
||||
if (!toolCallId) return
|
||||
const current = context.toolCalls.get(toolCallId)
|
||||
if (!current) return
|
||||
const toolName =
|
||||
event.toolName ||
|
||||
(data?.name as string | undefined) ||
|
||||
context.toolCalls.get(toolCallId)?.name ||
|
||||
''
|
||||
const current = ensureTerminalToolCallState(context, toolCallId, toolName)
|
||||
|
||||
const { success, hasResultData, hasError } = inferToolSuccess(data)
|
||||
|
||||
@@ -263,16 +288,22 @@ export const sseHandlers: Record<string, SSEHandler> = {
|
||||
const resultObj = asRecord(data?.result)
|
||||
current.error = (data?.error || resultObj.error) as string | undefined
|
||||
}
|
||||
markToolResultSeen(toolCallId)
|
||||
},
|
||||
tool_error: (event, context) => {
|
||||
const data = getEventData(event)
|
||||
const toolCallId = event.toolCallId || (data?.id as string | undefined)
|
||||
if (!toolCallId) return
|
||||
const current = context.toolCalls.get(toolCallId)
|
||||
if (!current) return
|
||||
const toolName =
|
||||
event.toolName ||
|
||||
(data?.name as string | undefined) ||
|
||||
context.toolCalls.get(toolCallId)?.name ||
|
||||
''
|
||||
const current = ensureTerminalToolCallState(context, toolCallId, toolName)
|
||||
current.status = 'error'
|
||||
current.error = (data?.error as string | undefined) || 'Tool execution failed'
|
||||
current.endTime = Date.now()
|
||||
markToolResultSeen(toolCallId)
|
||||
},
|
||||
tool_call_delta: () => {
|
||||
// Argument streaming delta — no action needed on orchestrator side
|
||||
@@ -313,6 +344,9 @@ export const sseHandlers: Record<string, SSEHandler> = {
|
||||
existing?.endTime ||
|
||||
(existing && existing.status !== 'pending' && existing.status !== 'executing')
|
||||
) {
|
||||
if (!existing.name && toolName) {
|
||||
existing.name = toolName
|
||||
}
|
||||
if (!existing.params && args) {
|
||||
existing.params = args
|
||||
}
|
||||
@@ -558,6 +592,12 @@ export const subAgentHandlers: Record<string, SSEHandler> = {
|
||||
const existing = context.toolCalls.get(toolCallId)
|
||||
// Ignore late/duplicate tool_call events once we already have a result.
|
||||
if (wasToolResultSeen(toolCallId) || existing?.endTime) {
|
||||
if (existing && !existing.name && toolName) {
|
||||
existing.name = toolName
|
||||
}
|
||||
if (existing && !existing.params && args) {
|
||||
existing.params = args
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -686,13 +726,14 @@ export const subAgentHandlers: Record<string, SSEHandler> = {
|
||||
const data = getEventData(event)
|
||||
const toolCallId = event.toolCallId || (data?.id as string | undefined)
|
||||
if (!toolCallId) return
|
||||
const toolName = event.toolName || (data?.name as string | undefined) || ''
|
||||
|
||||
// Update in subAgentToolCalls.
|
||||
const toolCalls = context.subAgentToolCalls[parentToolCallId] || []
|
||||
const subAgentToolCall = toolCalls.find((tc) => tc.id === toolCallId)
|
||||
|
||||
// Also update in main toolCalls (where we added it for execution).
|
||||
const mainToolCall = context.toolCalls.get(toolCallId)
|
||||
const mainToolCall = ensureTerminalToolCallState(context, toolCallId, toolName)
|
||||
|
||||
const { success, hasResultData, hasError } = inferToolSuccess(data)
|
||||
|
||||
@@ -719,6 +760,9 @@ export const subAgentHandlers: Record<string, SSEHandler> = {
|
||||
mainToolCall.error = (data?.error || resultObj.error) as string | undefined
|
||||
}
|
||||
}
|
||||
if (subAgentToolCall || mainToolCall) {
|
||||
markToolResultSeen(toolCallId)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -4,18 +4,15 @@ import { createLogger } from '@sim/logger'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { completeAsyncToolCall, markAsyncToolRunning } from '@/lib/copilot/async-runs/repository'
|
||||
import { waitForToolConfirmation } from '@/lib/copilot/orchestrator/persistence'
|
||||
import {
|
||||
asRecord,
|
||||
markToolResultSeen,
|
||||
wasToolResultSeen,
|
||||
} from '@/lib/copilot/orchestrator/sse/utils'
|
||||
import { asRecord, markToolResultSeen } from '@/lib/copilot/orchestrator/sse/utils'
|
||||
import { executeToolServerSide, markToolComplete } from '@/lib/copilot/orchestrator/tool-executor'
|
||||
import type {
|
||||
ExecutionContext,
|
||||
OrchestratorOptions,
|
||||
SSEEvent,
|
||||
StreamingContext,
|
||||
ToolCallResult,
|
||||
import {
|
||||
type ExecutionContext,
|
||||
isTerminalToolCallStatus,
|
||||
type OrchestratorOptions,
|
||||
type SSEEvent,
|
||||
type StreamingContext,
|
||||
type ToolCallResult,
|
||||
} from '@/lib/copilot/orchestrator/types'
|
||||
import {
|
||||
extractDeletedResourcesFromToolResult,
|
||||
@@ -117,6 +114,20 @@ const FORMAT_TO_CONTENT_TYPE: Record<OutputFormat, string> = {
|
||||
html: 'text/html',
|
||||
}
|
||||
|
||||
function normalizeOutputWorkspaceFileName(outputPath: string): string {
|
||||
const trimmed = outputPath.trim().replace(/^\/+/, '')
|
||||
const withoutPrefix = trimmed.startsWith('files/') ? trimmed.slice('files/'.length) : trimmed
|
||||
if (!withoutPrefix) {
|
||||
throw new Error('outputPath must include a file name, e.g. "files/result.json"')
|
||||
}
|
||||
if (withoutPrefix.includes('/')) {
|
||||
throw new Error(
|
||||
'outputPath must target a flat workspace file, e.g. "files/result.json". Nested paths like "files/reports/result.json" are not supported.'
|
||||
)
|
||||
}
|
||||
return withoutPrefix
|
||||
}
|
||||
|
||||
function resolveOutputFormat(fileName: string, explicit?: string): OutputFormat {
|
||||
if (explicit && explicit in FORMAT_TO_CONTENT_TYPE) return explicit as OutputFormat
|
||||
const ext = fileName.slice(fileName.lastIndexOf('.')).toLowerCase()
|
||||
@@ -153,10 +164,10 @@ async function maybeWriteOutputToFile(
|
||||
|
||||
const explicitFormat =
|
||||
(params?.outputFormat as string | undefined) ?? (args?.outputFormat as string | undefined)
|
||||
const fileName = outputPath.replace(/^files\//, '')
|
||||
const format = resolveOutputFormat(fileName, explicitFormat)
|
||||
|
||||
try {
|
||||
const fileName = normalizeOutputWorkspaceFileName(outputPath)
|
||||
const format = resolveOutputFormat(fileName, explicitFormat)
|
||||
if (context.abortSignal?.aborted) {
|
||||
throw new Error('Request aborted before tool mutation could be applied')
|
||||
}
|
||||
@@ -193,12 +204,16 @@ async function maybeWriteOutputToFile(
|
||||
},
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err)
|
||||
logger.warn('Failed to write tool output to file', {
|
||||
toolName,
|
||||
outputPath,
|
||||
error: err instanceof Error ? err.message : String(err),
|
||||
error: message,
|
||||
})
|
||||
return result
|
||||
return {
|
||||
success: false,
|
||||
error: `Failed to write output file: ${message}`,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,6 +244,48 @@ function cancelledCompletion(message: string): AsyncToolCompletion {
|
||||
}
|
||||
}
|
||||
|
||||
function terminalCompletionFromToolCall(toolCall: {
|
||||
status: string
|
||||
error?: string
|
||||
result?: { output?: unknown; error?: string }
|
||||
}): AsyncToolCompletion {
|
||||
if (toolCall.status === 'cancelled') {
|
||||
return cancelledCompletion(toolCall.error || 'Tool execution cancelled')
|
||||
}
|
||||
|
||||
if (toolCall.status === 'success') {
|
||||
return {
|
||||
status: 'success',
|
||||
message: 'Tool completed',
|
||||
data:
|
||||
toolCall.result?.output &&
|
||||
typeof toolCall.result.output === 'object' &&
|
||||
!Array.isArray(toolCall.result.output)
|
||||
? (toolCall.result.output as Record<string, unknown>)
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCall.status === 'skipped') {
|
||||
return {
|
||||
status: 'success',
|
||||
message: 'Tool skipped',
|
||||
data:
|
||||
toolCall.result?.output &&
|
||||
typeof toolCall.result.output === 'object' &&
|
||||
!Array.isArray(toolCall.result.output)
|
||||
? (toolCall.result.output as Record<string, unknown>)
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
status: toolCall.status === 'rejected' ? 'rejected' : 'error',
|
||||
message: toolCall.error || toolCall.result?.error || 'Tool failed',
|
||||
data: { error: toolCall.error || toolCall.result?.error || 'Tool failed' },
|
||||
}
|
||||
}
|
||||
|
||||
function reportCancelledTool(
|
||||
toolCall: { id: string; name: string },
|
||||
message: string,
|
||||
@@ -491,8 +548,8 @@ export async function executeToolAndReport(
|
||||
if (toolCall.status === 'executing') {
|
||||
return { status: 'running', message: 'Tool already executing' }
|
||||
}
|
||||
if (wasToolResultSeen(toolCall.id)) {
|
||||
return { status: 'success', message: 'Tool result already processed' }
|
||||
if (toolCall.endTime || isTerminalToolCallStatus(toolCall.status)) {
|
||||
return terminalCompletionFromToolCall(toolCall)
|
||||
}
|
||||
|
||||
if (abortRequested(context, execContext, options)) {
|
||||
@@ -520,6 +577,9 @@ export async function executeToolAndReport(
|
||||
|
||||
try {
|
||||
let result = await executeToolServerSide(toolCall, execContext)
|
||||
if (toolCall.endTime || isTerminalToolCallStatus(toolCall.status)) {
|
||||
return terminalCompletionFromToolCall(toolCall)
|
||||
}
|
||||
if (abortRequested(context, execContext, options)) {
|
||||
toolCall.status = 'cancelled'
|
||||
toolCall.endTime = Date.now()
|
||||
@@ -581,10 +641,17 @@ export async function executeToolAndReport(
|
||||
toolCall.endTime = Date.now()
|
||||
|
||||
if (result.success) {
|
||||
const raw = result.output
|
||||
const preview =
|
||||
typeof raw === 'string'
|
||||
? raw.slice(0, 200)
|
||||
: raw && typeof raw === 'object'
|
||||
? JSON.stringify(raw).slice(0, 200)
|
||||
: undefined
|
||||
logger.info('Tool execution succeeded', {
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
output: result.output,
|
||||
outputPreview: preview,
|
||||
})
|
||||
} else {
|
||||
logger.warn('Tool execution failed', {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import {
|
||||
markToolResultSeen,
|
||||
normalizeSseEvent,
|
||||
shouldSkipToolCallEvent,
|
||||
shouldSkipToolResultEvent,
|
||||
@@ -37,6 +38,7 @@ describe('sse-utils', () => {
|
||||
it.concurrent('dedupes tool_result events', () => {
|
||||
const event = { type: 'tool_result', data: { id: 'tool_result_1', name: 'plan' } }
|
||||
expect(shouldSkipToolResultEvent(event as any)).toBe(false)
|
||||
markToolResultSeen('tool_result_1')
|
||||
expect(shouldSkipToolResultEvent(event as any)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -125,7 +125,5 @@ export function shouldSkipToolResultEvent(event: SSEEvent): boolean {
|
||||
if (event.type !== 'tool_result') return false
|
||||
const toolCallId = getToolCallIdFromEvent(event)
|
||||
if (!toolCallId) return false
|
||||
if (wasToolResultSeen(toolCallId)) return true
|
||||
markToolResultSeen(toolCallId)
|
||||
return false
|
||||
return wasToolResultSeen(toolCallId)
|
||||
}
|
||||
|
||||
@@ -1050,21 +1050,21 @@ const SIM_WORKFLOW_TOOL_HANDLERS: Record<
|
||||
return {
|
||||
success: false,
|
||||
error:
|
||||
'Opening a workspace file requires workspace context. Pass the file UUID from files/<name>/meta.json.',
|
||||
'Opening a workspace file requires workspace context. Pass the canonical file UUID from files/by-id/<fileId>/meta.json.',
|
||||
}
|
||||
}
|
||||
if (!isUuid(params.id)) {
|
||||
return {
|
||||
success: false,
|
||||
error:
|
||||
'open_resource for files requires the canonical UUID from files/<name>/meta.json (the "id" field). Do not pass VFS paths, display names, or file_<name> strings.',
|
||||
'open_resource for files requires the canonical file UUID. Read files/by-id/<fileId>/meta.json or files/<name>/meta.json and pass the "id" field. Do not pass VFS paths or display names.',
|
||||
}
|
||||
}
|
||||
const record = await getWorkspaceFile(c.workspaceId, params.id)
|
||||
if (!record) {
|
||||
return {
|
||||
success: false,
|
||||
error: `No workspace file with id "${params.id}". Confirm the UUID from meta.json.`,
|
||||
error: `No workspace file with id "${params.id}". Confirm the UUID from files/by-id/<fileId>/meta.json.`,
|
||||
}
|
||||
}
|
||||
resourceId = record.id
|
||||
|
||||
@@ -16,6 +16,7 @@ import { getTableById, queryRows } from '@/lib/table/service'
|
||||
import {
|
||||
downloadWorkspaceFile,
|
||||
findWorkspaceFileRecord,
|
||||
getSandboxWorkspaceFilePath,
|
||||
listWorkspaceFiles,
|
||||
} from '@/lib/uploads/contexts/workspace/workspace-file-manager'
|
||||
import { getWorkflowById } from '@/lib/workflows/utils'
|
||||
@@ -179,23 +180,30 @@ export async function executeIntegrationToolDirect(
|
||||
])
|
||||
let totalSize = 0
|
||||
|
||||
const inputFilePaths = executionParams.inputFiles as string[] | undefined
|
||||
if (inputFilePaths?.length) {
|
||||
const inputFileIds = executionParams.inputFiles as string[] | undefined
|
||||
if (inputFileIds?.length) {
|
||||
const allFiles = await listWorkspaceFiles(workspaceId)
|
||||
for (const filePath of inputFilePaths) {
|
||||
const fileName = filePath.replace(/^files\//, '')
|
||||
const ext = fileName.split('.').pop()?.toLowerCase() ?? ''
|
||||
if (!TEXT_EXTENSIONS.has(ext)) {
|
||||
logger.warn('Skipping non-text sandbox input file', { fileName, ext })
|
||||
for (const fileRef of inputFileIds) {
|
||||
const record = findWorkspaceFileRecord(allFiles, fileRef)
|
||||
if (!record) {
|
||||
logger.warn('Sandbox input file not found', { fileRef })
|
||||
continue
|
||||
}
|
||||
const record = findWorkspaceFileRecord(allFiles, filePath)
|
||||
if (!record) {
|
||||
logger.warn('Sandbox input file not found', { fileName })
|
||||
const ext = record.name.split('.').pop()?.toLowerCase() ?? ''
|
||||
if (!TEXT_EXTENSIONS.has(ext)) {
|
||||
logger.warn('Skipping non-text sandbox input file', {
|
||||
fileId: record.id,
|
||||
fileName: record.name,
|
||||
ext,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if (record.size > MAX_FILE_SIZE) {
|
||||
logger.warn('Sandbox input file exceeds size limit', { fileName, size: record.size })
|
||||
logger.warn('Sandbox input file exceeds size limit', {
|
||||
fileId: record.id,
|
||||
fileName: record.name,
|
||||
size: record.size,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if (totalSize + record.size > MAX_TOTAL_SIZE) {
|
||||
@@ -204,7 +212,15 @@ export async function executeIntegrationToolDirect(
|
||||
}
|
||||
const buffer = await downloadWorkspaceFile(record)
|
||||
totalSize += buffer.length
|
||||
sandboxFiles.push({ path: `/home/user/${fileName}`, content: buffer.toString('utf-8') })
|
||||
const textContent = buffer.toString('utf-8')
|
||||
sandboxFiles.push({
|
||||
path: getSandboxWorkspaceFilePath(record),
|
||||
content: textContent,
|
||||
})
|
||||
sandboxFiles.push({
|
||||
path: `/home/user/${record.name}`,
|
||||
content: textContent,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -59,6 +59,18 @@ export type ToolCallStatus =
|
||||
| 'rejected'
|
||||
| 'cancelled'
|
||||
|
||||
const TERMINAL_TOOL_STATUSES: ReadonlySet<ToolCallStatus> = new Set([
|
||||
'success',
|
||||
'error',
|
||||
'cancelled',
|
||||
'skipped',
|
||||
'rejected',
|
||||
])
|
||||
|
||||
export function isTerminalToolCallStatus(status?: string): boolean {
|
||||
return TERMINAL_TOOL_STATUSES.has(status as ToolCallStatus)
|
||||
}
|
||||
|
||||
export interface ToolCallState {
|
||||
id: string
|
||||
name: string
|
||||
|
||||
@@ -35,6 +35,15 @@ function inferContentType(fileName: string, explicitType?: string): string {
|
||||
return EXT_TO_MIME[ext] || 'text/plain'
|
||||
}
|
||||
|
||||
function validateFlatWorkspaceFileName(fileName: string): string | null {
|
||||
const trimmed = fileName.trim()
|
||||
if (!trimmed) return 'File name cannot be empty'
|
||||
if (trimmed.includes('/')) {
|
||||
return 'Workspace files use a flat namespace. Use a plain file name like "report.csv", not a path like "files/reports/report.csv".'
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
export const workspaceFileServerTool: BaseServerTool<WorkspaceFileArgs, WorkspaceFileResult> = {
|
||||
name: 'workspace_file',
|
||||
async execute(
|
||||
@@ -67,6 +76,10 @@ export const workspaceFileServerTool: BaseServerTool<WorkspaceFileArgs, Workspac
|
||||
if (content === undefined || content === null) {
|
||||
return { success: false, message: 'content is required for write operation' }
|
||||
}
|
||||
const fileNameValidationError = validateFlatWorkspaceFileName(fileName)
|
||||
if (fileNameValidationError) {
|
||||
return { success: false, message: fileNameValidationError }
|
||||
}
|
||||
|
||||
const isPptx = fileName.toLowerCase().endsWith('.pptx')
|
||||
let contentType: string
|
||||
@@ -188,6 +201,10 @@ export const workspaceFileServerTool: BaseServerTool<WorkspaceFileArgs, Workspac
|
||||
if (!newName) {
|
||||
return { success: false, message: 'newName is required for rename operation' }
|
||||
}
|
||||
const fileNameValidationError = validateFlatWorkspaceFileName(newName)
|
||||
if (fileNameValidationError) {
|
||||
return { success: false, message: fileNameValidationError }
|
||||
}
|
||||
|
||||
const fileRecord = await getWorkspaceFile(workspaceId, fileId)
|
||||
if (!fileRecord) {
|
||||
|
||||
@@ -27,6 +27,15 @@ const ASPECT_RATIO_TO_SIZE: Record<string, string> = {
|
||||
'3:4': '768x1024',
|
||||
}
|
||||
|
||||
function validateGeneratedWorkspaceFileName(fileName: string): string | null {
|
||||
const trimmed = fileName.trim()
|
||||
if (!trimmed) return 'File name cannot be empty'
|
||||
if (trimmed.includes('/')) {
|
||||
return 'Workspace files use a flat namespace. Use a plain file name like "generated-image.png", not a path like "images/generated-image.png".'
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
interface GenerateImageArgs {
|
||||
prompt: string
|
||||
referenceFileIds?: string[]
|
||||
@@ -151,6 +160,10 @@ export const generateImageServerTool: BaseServerTool<GenerateImageArgs, Generate
|
||||
|
||||
const ext = mimeType.includes('jpeg') || mimeType.includes('jpg') ? '.jpg' : '.png'
|
||||
const fileName = params.fileName || `generated-image${ext}`
|
||||
const fileNameValidationError = validateGeneratedWorkspaceFileName(fileName)
|
||||
if (fileNameValidationError) {
|
||||
return { success: false, message: fileNameValidationError }
|
||||
}
|
||||
const imageBuffer = Buffer.from(imageBase64, 'base64')
|
||||
|
||||
if (params.overwriteFileId) {
|
||||
|
||||
@@ -230,10 +230,12 @@ export const knowledgeBaseServerTool: BaseServerTool<KnowledgeBaseArgs, Knowledg
|
||||
}
|
||||
}
|
||||
|
||||
if (!args.filePath) {
|
||||
const fileReference = args.fileId || args.filePath
|
||||
if (!fileReference) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'filePath is required (e.g. "files/report.pdf")',
|
||||
message:
|
||||
'fileId is required for add_file. Read files/{name}/meta.json or files/by-id/*/meta.json to get the canonical file ID.',
|
||||
}
|
||||
}
|
||||
|
||||
@@ -246,12 +248,12 @@ export const knowledgeBaseServerTool: BaseServerTool<KnowledgeBaseArgs, Knowledg
|
||||
}
|
||||
|
||||
const kbWorkspaceId: string = targetKb.workspaceId
|
||||
const fileRecord = await resolveWorkspaceFileReference(kbWorkspaceId, args.filePath)
|
||||
const fileRecord = await resolveWorkspaceFileReference(kbWorkspaceId, fileReference)
|
||||
|
||||
if (!fileRecord) {
|
||||
return {
|
||||
success: false,
|
||||
message: `Workspace file not found: "${args.filePath}"`,
|
||||
message: `Workspace file not found: "${fileReference}"`,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -41,14 +41,53 @@ const SCHEMA_SAMPLE_SIZE = 100
|
||||
|
||||
type ColumnType = 'string' | 'number' | 'boolean' | 'date' | 'json'
|
||||
|
||||
function sanitizeColumnName(raw: string): string {
|
||||
let name = raw
|
||||
.trim()
|
||||
.replace(/[^a-zA-Z0-9_]/g, '_')
|
||||
.replace(/_+/g, '_')
|
||||
.replace(/^_|_$/g, '')
|
||||
if (!name || /^\d/.test(name)) name = `col_${name}`
|
||||
return name
|
||||
}
|
||||
|
||||
function sanitizeHeaders(
|
||||
headers: string[],
|
||||
rows: Record<string, unknown>[]
|
||||
): { headers: string[]; rows: Record<string, unknown>[] } {
|
||||
const renamed = new Map<string, string>()
|
||||
const seen = new Set<string>()
|
||||
|
||||
for (const raw of headers) {
|
||||
let safe = sanitizeColumnName(raw)
|
||||
while (seen.has(safe)) safe = `${safe}_`
|
||||
seen.add(safe)
|
||||
renamed.set(raw, safe)
|
||||
}
|
||||
|
||||
const noChange = headers.every((h) => renamed.get(h) === h)
|
||||
if (noChange) return { headers, rows }
|
||||
|
||||
return {
|
||||
headers: headers.map((h) => renamed.get(h)!),
|
||||
rows: rows.map((row) => {
|
||||
const out: Record<string, unknown> = {}
|
||||
for (const [raw, safe] of renamed) {
|
||||
if (raw in row) out[safe] = row[raw]
|
||||
}
|
||||
return out
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async function resolveWorkspaceFile(
|
||||
filePath: string,
|
||||
fileReference: string,
|
||||
workspaceId: string
|
||||
): Promise<{ buffer: Buffer; name: string; type: string }> {
|
||||
const record = await resolveWorkspaceFileReference(workspaceId, filePath)
|
||||
const record = await resolveWorkspaceFileReference(workspaceId, fileReference)
|
||||
if (!record) {
|
||||
throw new Error(
|
||||
`File not found: "${filePath}". Use glob("files/*/meta.json") to list available files.`
|
||||
`File not found: "${fileReference}". Use glob("files/by-id/*/meta.json") to list canonical file IDs.`
|
||||
)
|
||||
}
|
||||
const buffer = await downloadWorkspaceFile(record)
|
||||
@@ -87,7 +126,7 @@ async function parseJsonRows(
|
||||
}
|
||||
for (const key of Object.keys(row)) headerSet.add(key)
|
||||
}
|
||||
return { headers: [...headerSet], rows: parsed }
|
||||
return sanitizeHeaders([...headerSet], parsed)
|
||||
}
|
||||
|
||||
async function parseCsvRows(
|
||||
@@ -110,7 +149,7 @@ async function parseCsvRows(
|
||||
if (headers.length === 0) {
|
||||
throw new Error('CSV file has no headers')
|
||||
}
|
||||
return { headers, rows: parsed }
|
||||
return sanitizeHeaders(headers, parsed)
|
||||
}
|
||||
|
||||
function inferColumnType(values: unknown[]): ColumnType {
|
||||
@@ -645,15 +684,21 @@ export const userTableServerTool: BaseServerTool<UserTableArgs, UserTableResult>
|
||||
}
|
||||
|
||||
case 'create_from_file': {
|
||||
const fileId = (args as Record<string, unknown>).fileId as string | undefined
|
||||
const filePath = (args as Record<string, unknown>).filePath as string | undefined
|
||||
if (!filePath) {
|
||||
return { success: false, message: 'filePath is required (e.g. "files/data.csv")' }
|
||||
const fileReference = fileId || filePath
|
||||
if (!fileReference) {
|
||||
return {
|
||||
success: false,
|
||||
message:
|
||||
'fileId is required for create_from_file. Read files/{name}/meta.json or files/by-id/*/meta.json to get the canonical file ID.',
|
||||
}
|
||||
}
|
||||
if (!workspaceId) {
|
||||
return { success: false, message: 'Workspace ID is required' }
|
||||
}
|
||||
|
||||
const file = await resolveWorkspaceFile(filePath, workspaceId)
|
||||
const file = await resolveWorkspaceFile(fileReference, workspaceId)
|
||||
const { headers, rows } = await parseFileRows(file.buffer, file.name, file.type)
|
||||
if (rows.length === 0) {
|
||||
return { success: false, message: 'File contains no data rows' }
|
||||
@@ -700,10 +745,16 @@ export const userTableServerTool: BaseServerTool<UserTableArgs, UserTableResult>
|
||||
}
|
||||
|
||||
case 'import_file': {
|
||||
const fileId = (args as Record<string, unknown>).fileId as string | undefined
|
||||
const filePath = (args as Record<string, unknown>).filePath as string | undefined
|
||||
const tableId = (args as Record<string, unknown>).tableId as string | undefined
|
||||
if (!filePath) {
|
||||
return { success: false, message: 'filePath is required (e.g. "files/data.csv")' }
|
||||
const fileReference = fileId || filePath
|
||||
if (!fileReference) {
|
||||
return {
|
||||
success: false,
|
||||
message:
|
||||
'fileId is required for import_file. Read files/{name}/meta.json or files/by-id/*/meta.json to get the canonical file ID.',
|
||||
}
|
||||
}
|
||||
if (!tableId) {
|
||||
return { success: false, message: 'tableId is required for import_file' }
|
||||
@@ -717,7 +768,7 @@ export const userTableServerTool: BaseServerTool<UserTableArgs, UserTableResult>
|
||||
return { success: false, message: `Table not found: ${tableId}` }
|
||||
}
|
||||
|
||||
const file = await resolveWorkspaceFile(filePath, workspaceId)
|
||||
const file = await resolveWorkspaceFile(fileReference, workspaceId)
|
||||
const { headers, rows } = await parseFileRows(file.buffer, file.name, file.type)
|
||||
if (rows.length === 0) {
|
||||
return { success: false, message: 'File contains no data rows' }
|
||||
|
||||
@@ -11,6 +11,7 @@ import { getServePathPrefix } from '@/lib/uploads'
|
||||
import {
|
||||
downloadWorkspaceFile,
|
||||
findWorkspaceFileRecord,
|
||||
getSandboxWorkspaceFilePath,
|
||||
getWorkspaceFile,
|
||||
listWorkspaceFiles,
|
||||
updateWorkspaceFileContent,
|
||||
@@ -49,6 +50,15 @@ const TEXT_EXTENSIONS = new Set(['csv', 'json', 'txt', 'md', 'html', 'xml', 'tsv
|
||||
const MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
const MAX_TOTAL_SIZE = 50 * 1024 * 1024
|
||||
|
||||
function validateGeneratedWorkspaceFileName(fileName: string): string | null {
|
||||
const trimmed = fileName.trim()
|
||||
if (!trimmed) return 'File name cannot be empty'
|
||||
if (trimmed.includes('/')) {
|
||||
return 'Workspace files use a flat namespace. Use a plain file name like "chart.png", not a path like "charts/chart.png".'
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
async function collectSandboxFiles(
|
||||
workspaceId: string,
|
||||
inputFiles?: string[],
|
||||
@@ -59,20 +69,27 @@ async function collectSandboxFiles(
|
||||
|
||||
if (inputFiles?.length) {
|
||||
const allFiles = await listWorkspaceFiles(workspaceId)
|
||||
for (const filePath of inputFiles) {
|
||||
const fileName = filePath.replace(/^files\//, '')
|
||||
const ext = fileName.split('.').pop()?.toLowerCase() ?? ''
|
||||
if (!TEXT_EXTENSIONS.has(ext)) {
|
||||
logger.warn('Skipping non-text sandbox input file', { fileName, ext })
|
||||
for (const fileRef of inputFiles) {
|
||||
const record = findWorkspaceFileRecord(allFiles, fileRef)
|
||||
if (!record) {
|
||||
logger.warn('Sandbox input file not found', { fileRef })
|
||||
continue
|
||||
}
|
||||
const record = findWorkspaceFileRecord(allFiles, filePath)
|
||||
if (!record) {
|
||||
logger.warn('Sandbox input file not found', { fileName })
|
||||
const ext = record.name.split('.').pop()?.toLowerCase() ?? ''
|
||||
if (!TEXT_EXTENSIONS.has(ext)) {
|
||||
logger.warn('Skipping non-text sandbox input file', {
|
||||
fileId: record.id,
|
||||
fileName: record.name,
|
||||
ext,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if (record.size > MAX_FILE_SIZE) {
|
||||
logger.warn('Sandbox input file exceeds size limit', { fileName, size: record.size })
|
||||
logger.warn('Sandbox input file exceeds size limit', {
|
||||
fileId: record.id,
|
||||
fileName: record.name,
|
||||
size: record.size,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if (totalSize + record.size > MAX_TOTAL_SIZE) {
|
||||
@@ -81,7 +98,15 @@ async function collectSandboxFiles(
|
||||
}
|
||||
const buffer = await downloadWorkspaceFile(record)
|
||||
totalSize += buffer.length
|
||||
sandboxFiles.push({ path: `/home/user/${fileName}`, content: buffer.toString('utf-8') })
|
||||
const textContent = buffer.toString('utf-8')
|
||||
sandboxFiles.push({
|
||||
path: getSandboxWorkspaceFilePath(record),
|
||||
content: textContent,
|
||||
})
|
||||
sandboxFiles.push({
|
||||
path: `/home/user/${record.name}`,
|
||||
content: textContent,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +210,10 @@ export const generateVisualizationServerTool: BaseServerTool<
|
||||
}
|
||||
|
||||
const fileName = params.fileName || 'chart.png'
|
||||
const fileNameValidationError = validateGeneratedWorkspaceFileName(fileName)
|
||||
if (fileNameValidationError) {
|
||||
return { success: false, message: fileNameValidationError }
|
||||
}
|
||||
const imageBuffer = Buffer.from(imageBase64, 'base64')
|
||||
|
||||
if (params.overwriteFileId) {
|
||||
|
||||
@@ -50,7 +50,9 @@ export const KnowledgeBaseArgsSchema = z.object({
|
||||
workspaceId: z.string().optional(),
|
||||
/** Knowledge base ID (required for get, query, add_file, list_tags, create_tag, get_tag_usage, add_connector) */
|
||||
knowledgeBaseId: z.string().optional(),
|
||||
/** Workspace file path to add as a document (required for add_file). Example: "files/report.pdf" */
|
||||
/** Workspace file ID to add as a document (required for add_file). */
|
||||
fileId: z.string().optional(),
|
||||
/** Legacy workspace file reference for add_file. Prefer fileId. */
|
||||
filePath: z.string().optional(),
|
||||
/** Search query text (required for query) */
|
||||
query: z.string().optional(),
|
||||
@@ -145,6 +147,7 @@ export const UserTableArgsSchema = z.object({
|
||||
sort: z.record(z.enum(['asc', 'desc'])).optional(),
|
||||
limit: z.number().optional(),
|
||||
offset: z.number().optional(),
|
||||
fileId: z.string().optional(),
|
||||
filePath: z.string().optional(),
|
||||
column: z
|
||||
.object({
|
||||
|
||||
@@ -9,6 +9,16 @@ function vfsFromEntries(entries: [string, string][]): Map<string, string> {
|
||||
}
|
||||
|
||||
describe('glob', () => {
|
||||
it('matches canonical file metadata paths by id', () => {
|
||||
const files = vfsFromEntries([
|
||||
['files/by-id/wf_123/meta.json', '{}'],
|
||||
['files/data.csv/meta.json', '{}'],
|
||||
])
|
||||
const hits = glob(files, 'files/by-id/*/meta.json')
|
||||
expect(hits).toContain('files/by-id/wf_123/meta.json')
|
||||
expect(hits).not.toContain('files/data.csv/meta.json')
|
||||
})
|
||||
|
||||
it('matches one path segment for single star (files listing pattern)', () => {
|
||||
const files = vfsFromEntries([
|
||||
['files/a/meta.json', '{}'],
|
||||
|
||||
@@ -262,6 +262,7 @@ export function serializeConnectorOverview(connectors: SerializableConnectorConf
|
||||
|
||||
/**
|
||||
* Serialize workspace file metadata for VFS files/{name}/meta.json
|
||||
* and files/by-id/{id}/meta.json.
|
||||
*/
|
||||
export function serializeFileMeta(file: {
|
||||
id: string
|
||||
|
||||
@@ -271,6 +271,7 @@ function getStaticComponentFiles(): Map<string, string> {
|
||||
* knowledgebases/{name}/connectors.json
|
||||
* tables/{name}/meta.json
|
||||
* files/{name}/meta.json
|
||||
* files/by-id/{id}/meta.json
|
||||
* jobs/{title}/meta.json
|
||||
* jobs/{title}/history.json
|
||||
* jobs/{title}/executions.json
|
||||
@@ -390,7 +391,7 @@ export class WorkspaceVFS {
|
||||
/**
|
||||
* Attempt to read dynamic workspace file content from storage.
|
||||
* Handles images (base64), parseable documents (PDF, etc.), and text files.
|
||||
* Returns null if the path doesn't match `files/{name}` or the file isn't found.
|
||||
* Returns null if the path doesn't match `files/{name}` / `files/by-id/{id}` or the file isn't found.
|
||||
*/
|
||||
async readFileContent(path: string): Promise<FileReadResult | null> {
|
||||
const match = path.match(/^files\/(.+?)(?:\/content)?$/)
|
||||
@@ -676,6 +677,16 @@ export class WorkspaceVFS {
|
||||
uploadedAt: file.uploadedAt,
|
||||
})
|
||||
)
|
||||
this.files.set(
|
||||
`files/by-id/${file.id}/meta.json`,
|
||||
serializeFileMeta({
|
||||
id: file.id,
|
||||
name: file.name,
|
||||
contentType: file.type,
|
||||
size: file.size,
|
||||
uploadedAt: file.uploadedAt,
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
return files.map((f) => ({ name: f.name, type: f.type, size: f.size }))
|
||||
|
||||
@@ -382,16 +382,21 @@ export async function listWorkspaceFiles(
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize a workspace file reference to its display name.
|
||||
* Supports raw names and VFS-style paths like `files/name`, `files/name/content`,
|
||||
* and `files/name/meta.json`.
|
||||
*
|
||||
* Used by storage resolution (`findWorkspaceFileRecord`), not by `open_resource`, which
|
||||
* requires the canonical database UUID only.
|
||||
* Normalize a workspace file reference to either a display name or canonical file ID.
|
||||
* Supports raw IDs, `files/{name}`, `files/{name}/content`, `files/{name}/meta.json`,
|
||||
* and canonical VFS aliases like `files/by-id/{fileId}/content`.
|
||||
*/
|
||||
export function normalizeWorkspaceFileReference(fileReference: string): string {
|
||||
const trimmed = fileReference.trim().replace(/^\/+/, '')
|
||||
|
||||
if (trimmed.startsWith('files/by-id/')) {
|
||||
const byIdRef = trimmed.slice('files/by-id/'.length)
|
||||
const match = byIdRef.match(/^([^/]+)(?:\/(?:meta\.json|content))?$/)
|
||||
if (match?.[1]) {
|
||||
return match[1]
|
||||
}
|
||||
}
|
||||
|
||||
if (trimmed.startsWith('files/')) {
|
||||
const withoutPrefix = trimmed.slice('files/'.length)
|
||||
if (withoutPrefix.endsWith('/meta.json')) {
|
||||
@@ -406,6 +411,13 @@ export function normalizeWorkspaceFileReference(fileReference: string): string {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
/**
|
||||
* Canonical sandbox mount path for an existing workspace file.
|
||||
*/
|
||||
export function getSandboxWorkspaceFilePath(file: Pick<WorkspaceFileRecord, 'id' | 'name'>): string {
|
||||
return `/home/user/files/${file.id}/${file.name}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a workspace file record in an existing list from either its id or a VFS/name reference.
|
||||
* For copilot `open_resource` and the resource panel, use {@link getWorkspaceFile} with a UUID only.
|
||||
@@ -420,10 +432,13 @@ export function findWorkspaceFileRecord(
|
||||
}
|
||||
|
||||
const normalizedReference = normalizeWorkspaceFileReference(fileReference)
|
||||
const normalizedIdMatch = files.find((file) => file.id === normalizedReference)
|
||||
if (normalizedIdMatch) {
|
||||
return normalizedIdMatch
|
||||
}
|
||||
|
||||
const segmentKey = normalizeVfsSegment(normalizedReference)
|
||||
return (
|
||||
files.find((file) => normalizeVfsSegment(file.name) === segmentKey) ?? null
|
||||
)
|
||||
return files.find((file) => normalizeVfsSegment(file.name) === segmentKey) ?? null
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user