mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
fix(mothership): async resume and tool result ordering (#3735)
* fix(mothership): async resume and tool result ordering * ensure tool call terminal state * address comments
This commit is contained in:
committed by
GitHub
parent
41a7d247ea
commit
0c80438ede
@@ -1,318 +0,0 @@
|
||||
/**
|
||||
* @vitest-environment node
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import type { OrchestratorOptions } from './types'
|
||||
|
||||
const {
|
||||
prepareExecutionContext,
|
||||
getEffectiveDecryptedEnv,
|
||||
runStreamLoop,
|
||||
claimCompletedAsyncToolCall,
|
||||
getAsyncToolCall,
|
||||
getAsyncToolCalls,
|
||||
markAsyncToolDelivered,
|
||||
releaseCompletedAsyncToolClaim,
|
||||
updateRunStatus,
|
||||
} = vi.hoisted(() => ({
|
||||
prepareExecutionContext: vi.fn(),
|
||||
getEffectiveDecryptedEnv: vi.fn(),
|
||||
runStreamLoop: vi.fn(),
|
||||
claimCompletedAsyncToolCall: vi.fn(),
|
||||
getAsyncToolCall: vi.fn(),
|
||||
getAsyncToolCalls: vi.fn(),
|
||||
markAsyncToolDelivered: vi.fn(),
|
||||
releaseCompletedAsyncToolClaim: vi.fn(),
|
||||
updateRunStatus: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/copilot/orchestrator/tool-executor', () => ({
|
||||
prepareExecutionContext,
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/environment/utils', () => ({
|
||||
getEffectiveDecryptedEnv,
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/copilot/async-runs/repository', () => ({
|
||||
claimCompletedAsyncToolCall,
|
||||
getAsyncToolCall,
|
||||
getAsyncToolCalls,
|
||||
markAsyncToolDelivered,
|
||||
releaseCompletedAsyncToolClaim,
|
||||
updateRunStatus,
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/copilot/orchestrator/stream/core', async () => {
|
||||
const actual = await vi.importActual<typeof import('./stream/core')>('./stream/core')
|
||||
return {
|
||||
...actual,
|
||||
buildToolCallSummaries: vi.fn(() => []),
|
||||
runStreamLoop,
|
||||
}
|
||||
})
|
||||
|
||||
import { orchestrateCopilotStream } from './index'
|
||||
|
||||
describe('orchestrateCopilotStream async continuation', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
prepareExecutionContext.mockResolvedValue({
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
})
|
||||
getEffectiveDecryptedEnv.mockResolvedValue({})
|
||||
claimCompletedAsyncToolCall.mockResolvedValue({ toolCallId: 'tool-1' })
|
||||
getAsyncToolCall.mockResolvedValue({
|
||||
toolCallId: 'tool-1',
|
||||
toolName: 'read',
|
||||
status: 'completed',
|
||||
result: { ok: true },
|
||||
error: null,
|
||||
})
|
||||
getAsyncToolCalls.mockResolvedValue([
|
||||
{
|
||||
toolCallId: 'tool-1',
|
||||
toolName: 'read',
|
||||
status: 'completed',
|
||||
result: { ok: true },
|
||||
error: null,
|
||||
},
|
||||
])
|
||||
markAsyncToolDelivered.mockResolvedValue(null)
|
||||
releaseCompletedAsyncToolClaim.mockResolvedValue(null)
|
||||
updateRunStatus.mockResolvedValue(null)
|
||||
})
|
||||
|
||||
it('builds resume payloads with success=true for claimed completed rows', async () => {
|
||||
runStreamLoop
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.awaitingAsyncContinuation = {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
pendingToolCallIds: ['tool-1'],
|
||||
}
|
||||
})
|
||||
.mockImplementationOnce(async (url: string, opts: RequestInit) => {
|
||||
expect(url).toContain('/api/tools/resume')
|
||||
const body = JSON.parse(String(opts.body))
|
||||
expect(body).toEqual({
|
||||
checkpointId: 'checkpoint-1',
|
||||
results: [
|
||||
{
|
||||
callId: 'tool-1',
|
||||
name: 'read',
|
||||
data: { ok: true },
|
||||
success: true,
|
||||
},
|
||||
],
|
||||
})
|
||||
})
|
||||
|
||||
const result = await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(markAsyncToolDelivered).toHaveBeenCalledWith('tool-1')
|
||||
})
|
||||
|
||||
it('marks claimed tool calls delivered even when the resumed stream later records errors', async () => {
|
||||
runStreamLoop
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.awaitingAsyncContinuation = {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
pendingToolCallIds: ['tool-1'],
|
||||
}
|
||||
})
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.errors.push('resume stream failed after handoff')
|
||||
})
|
||||
|
||||
const result = await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(markAsyncToolDelivered).toHaveBeenCalledWith('tool-1')
|
||||
})
|
||||
|
||||
it('forwards done events while still marking async pauses on the run', async () => {
|
||||
const onEvent = vi.fn()
|
||||
const streamOptions: OrchestratorOptions = { onEvent }
|
||||
runStreamLoop.mockImplementationOnce(
|
||||
async (_url: string, _opts: RequestInit, _context: any, _exec: any, loopOptions: any) => {
|
||||
await loopOptions.onEvent({
|
||||
type: 'done',
|
||||
data: {
|
||||
response: {
|
||||
async_pause: {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
...streamOptions,
|
||||
}
|
||||
)
|
||||
|
||||
expect(onEvent).toHaveBeenCalledWith(expect.objectContaining({ type: 'done' }))
|
||||
expect(updateRunStatus).toHaveBeenCalledWith('run-1', 'paused_waiting_for_tool')
|
||||
})
|
||||
|
||||
it('waits for a local running tool before retrying the claim', async () => {
|
||||
const localPendingPromise = Promise.resolve({
|
||||
status: 'success',
|
||||
data: { ok: true },
|
||||
})
|
||||
|
||||
claimCompletedAsyncToolCall
|
||||
.mockResolvedValueOnce(null)
|
||||
.mockResolvedValueOnce({ toolCallId: 'tool-1' })
|
||||
getAsyncToolCall
|
||||
.mockResolvedValueOnce({
|
||||
toolCallId: 'tool-1',
|
||||
toolName: 'read',
|
||||
status: 'running',
|
||||
result: null,
|
||||
error: null,
|
||||
})
|
||||
.mockResolvedValue({
|
||||
toolCallId: 'tool-1',
|
||||
toolName: 'read',
|
||||
status: 'completed',
|
||||
result: { ok: true },
|
||||
error: null,
|
||||
})
|
||||
|
||||
runStreamLoop
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.awaitingAsyncContinuation = {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
pendingToolCallIds: ['tool-1'],
|
||||
}
|
||||
context.pendingToolPromises.set('tool-1', localPendingPromise)
|
||||
})
|
||||
.mockImplementationOnce(async (url: string, opts: RequestInit) => {
|
||||
expect(url).toContain('/api/tools/resume')
|
||||
const body = JSON.parse(String(opts.body))
|
||||
expect(body.results[0]).toEqual({
|
||||
callId: 'tool-1',
|
||||
name: 'read',
|
||||
data: { ok: true },
|
||||
success: true,
|
||||
})
|
||||
})
|
||||
|
||||
const result = await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(runStreamLoop).toHaveBeenCalledTimes(2)
|
||||
expect(markAsyncToolDelivered).toHaveBeenCalledWith('tool-1')
|
||||
})
|
||||
|
||||
it('releases claimed rows if the resume stream throws before delivery is marked', async () => {
|
||||
runStreamLoop
|
||||
.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.awaitingAsyncContinuation = {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
pendingToolCallIds: ['tool-1'],
|
||||
}
|
||||
})
|
||||
.mockImplementationOnce(async () => {
|
||||
throw new Error('resume failed')
|
||||
})
|
||||
|
||||
const result = await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
expect(releaseCompletedAsyncToolClaim).toHaveBeenCalledWith('tool-1', 'run-1')
|
||||
expect(markAsyncToolDelivered).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not send a partial resume payload when only some pending tool calls are claimable', async () => {
|
||||
claimCompletedAsyncToolCall
|
||||
.mockResolvedValueOnce({ toolCallId: 'tool-1' })
|
||||
.mockResolvedValueOnce(null)
|
||||
.mockResolvedValueOnce({ toolCallId: 'tool-1' })
|
||||
.mockResolvedValueOnce(null)
|
||||
.mockResolvedValueOnce({ toolCallId: 'tool-1' })
|
||||
.mockResolvedValueOnce(null)
|
||||
.mockResolvedValueOnce({ toolCallId: 'tool-1' })
|
||||
.mockResolvedValueOnce(null)
|
||||
getAsyncToolCall.mockResolvedValue(null)
|
||||
|
||||
runStreamLoop.mockImplementationOnce(async (_url: string, _opts: RequestInit, context: any) => {
|
||||
context.awaitingAsyncContinuation = {
|
||||
checkpointId: 'checkpoint-1',
|
||||
runId: 'run-1',
|
||||
pendingToolCallIds: ['tool-1', 'tool-2'],
|
||||
}
|
||||
})
|
||||
|
||||
const result = await orchestrateCopilotStream(
|
||||
{ message: 'hello' },
|
||||
{
|
||||
userId: 'user-1',
|
||||
workflowId: 'workflow-1',
|
||||
chatId: 'chat-1',
|
||||
executionId: 'exec-1',
|
||||
runId: 'run-1',
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
expect(runStreamLoop).toHaveBeenCalledTimes(1)
|
||||
expect(releaseCompletedAsyncToolClaim).toHaveBeenCalledWith('tool-1', 'run-1')
|
||||
expect(markAsyncToolDelivered).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -14,12 +14,17 @@ import {
|
||||
updateRunStatus,
|
||||
} from '@/lib/copilot/async-runs/repository'
|
||||
import { SIM_AGENT_API_URL, SIM_AGENT_VERSION } from '@/lib/copilot/constants'
|
||||
import { prepareExecutionContext } from '@/lib/copilot/orchestrator/tool-executor'
|
||||
import type {
|
||||
ExecutionContext,
|
||||
OrchestratorOptions,
|
||||
OrchestratorResult,
|
||||
SSEEvent,
|
||||
import {
|
||||
isToolAvailableOnSimSide,
|
||||
prepareExecutionContext,
|
||||
} from '@/lib/copilot/orchestrator/tool-executor'
|
||||
import {
|
||||
type ExecutionContext,
|
||||
isTerminalToolCallStatus,
|
||||
type OrchestratorOptions,
|
||||
type OrchestratorResult,
|
||||
type SSEEvent,
|
||||
type ToolCallState,
|
||||
} from '@/lib/copilot/orchestrator/types'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { getEffectiveDecryptedEnv } from '@/lib/environment/utils'
|
||||
@@ -31,18 +36,9 @@ function didAsyncToolSucceed(input: {
|
||||
durableStatus?: string | null
|
||||
durableResult?: Record<string, unknown>
|
||||
durableError?: string | null
|
||||
completion?: { status: string } | undefined
|
||||
toolStateSuccess?: boolean | undefined
|
||||
toolStateStatus?: string | undefined
|
||||
}) {
|
||||
const {
|
||||
durableStatus,
|
||||
durableResult,
|
||||
durableError,
|
||||
completion,
|
||||
toolStateSuccess,
|
||||
toolStateStatus,
|
||||
} = input
|
||||
const { durableStatus, durableResult, durableError, toolStateStatus } = input
|
||||
|
||||
if (durableStatus === ASYNC_TOOL_STATUS.completed) {
|
||||
return true
|
||||
@@ -61,7 +57,15 @@ function didAsyncToolSucceed(input: {
|
||||
if (toolStateStatus === 'success') return true
|
||||
if (toolStateStatus === 'error' || toolStateStatus === 'cancelled') return false
|
||||
|
||||
return completion?.status === 'success' || toolStateSuccess === true
|
||||
return false
|
||||
}
|
||||
|
||||
interface ReadyContinuationTool {
|
||||
toolCallId: string
|
||||
toolState?: ToolCallState
|
||||
durableRow?: Awaited<ReturnType<typeof getAsyncToolCall>>
|
||||
needsDurableClaim: boolean
|
||||
alreadyClaimedByWorker: boolean
|
||||
}
|
||||
|
||||
export interface OrchestrateStreamOptions extends OrchestratorOptions {
|
||||
@@ -190,33 +194,21 @@ export async function orchestrateCopilotStream(
|
||||
if (!continuation) break
|
||||
|
||||
let resumeReady = false
|
||||
let emptyClaimRetries = 0
|
||||
let resumeRetries = 0
|
||||
for (;;) {
|
||||
claimedToolCallIds = []
|
||||
claimedByWorkerId = null
|
||||
const resumeWorkerId = continuation.runId || context.runId || context.messageId
|
||||
claimedByWorkerId = resumeWorkerId
|
||||
const claimableToolCallIds: string[] = []
|
||||
const readyTools: ReadyContinuationTool[] = []
|
||||
const localPendingPromises: Promise<unknown>[] = []
|
||||
const missingToolCallIds: string[] = []
|
||||
|
||||
for (const toolCallId of continuation.pendingToolCallIds) {
|
||||
const claimed = await claimCompletedAsyncToolCall(toolCallId, resumeWorkerId).catch(
|
||||
() => null
|
||||
)
|
||||
if (claimed) {
|
||||
claimableToolCallIds.push(toolCallId)
|
||||
claimedToolCallIds.push(toolCallId)
|
||||
continue
|
||||
}
|
||||
const durableRow = await getAsyncToolCall(toolCallId).catch(() => null)
|
||||
const localPendingPromise = context.pendingToolPromises.get(toolCallId)
|
||||
if (!durableRow && localPendingPromise) {
|
||||
claimableToolCallIds.push(toolCallId)
|
||||
continue
|
||||
}
|
||||
if (
|
||||
durableRow &&
|
||||
durableRow.status === ASYNC_TOOL_STATUS.running &&
|
||||
localPendingPromise
|
||||
) {
|
||||
const toolState = context.toolCalls.get(toolCallId)
|
||||
|
||||
if (localPendingPromise) {
|
||||
localPendingPromises.push(localPendingPromise)
|
||||
logger.info('Waiting for local async tool completion before retrying resume claim', {
|
||||
toolCallId,
|
||||
@@ -224,21 +216,55 @@ export async function orchestrateCopilotStream(
|
||||
})
|
||||
continue
|
||||
}
|
||||
const toolState = context.toolCalls.get(toolCallId)
|
||||
if (!durableRow && !localPendingPromise && toolState) {
|
||||
|
||||
if (durableRow && isTerminalAsyncStatus(durableRow.status)) {
|
||||
if (durableRow.claimedBy && durableRow.claimedBy !== resumeWorkerId) {
|
||||
missingToolCallIds.push(toolCallId)
|
||||
logger.warn('Async tool continuation is waiting on a claim held by another worker', {
|
||||
toolCallId,
|
||||
runId: continuation.runId,
|
||||
claimedBy: durableRow.claimedBy,
|
||||
})
|
||||
continue
|
||||
}
|
||||
readyTools.push({
|
||||
toolCallId,
|
||||
toolState,
|
||||
durableRow,
|
||||
needsDurableClaim: durableRow.claimedBy !== resumeWorkerId,
|
||||
alreadyClaimedByWorker: durableRow.claimedBy === resumeWorkerId,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if (
|
||||
!durableRow &&
|
||||
toolState &&
|
||||
isTerminalToolCallStatus(toolState.status) &&
|
||||
!isToolAvailableOnSimSide(toolState.name)
|
||||
) {
|
||||
logger.info('Including Go-handled tool in resume payload (no Sim-side row)', {
|
||||
toolCallId,
|
||||
toolName: toolState.name,
|
||||
status: toolState.status,
|
||||
runId: continuation.runId,
|
||||
})
|
||||
claimableToolCallIds.push(toolCallId)
|
||||
readyTools.push({
|
||||
toolCallId,
|
||||
toolState,
|
||||
needsDurableClaim: false,
|
||||
alreadyClaimedByWorker: false,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.warn('Skipping already-claimed or missing async tool resume', {
|
||||
toolCallId,
|
||||
runId: continuation.runId,
|
||||
durableStatus: durableRow?.status,
|
||||
toolStateStatus: toolState?.status,
|
||||
})
|
||||
missingToolCallIds.push(toolCallId)
|
||||
}
|
||||
|
||||
if (localPendingPromises.length > 0) {
|
||||
@@ -246,83 +272,104 @@ export async function orchestrateCopilotStream(
|
||||
continue
|
||||
}
|
||||
|
||||
const missingToolCallIds = continuation.pendingToolCallIds.filter(
|
||||
(toolCallId) => !claimableToolCallIds.includes(toolCallId)
|
||||
)
|
||||
if (missingToolCallIds.length > 0) {
|
||||
if (claimedToolCallIds.length > 0 && claimedByWorkerId) {
|
||||
logger.info('Releasing partial async tool claims before retrying resume', {
|
||||
if (resumeRetries < 3) {
|
||||
resumeRetries++
|
||||
logger.info('Retrying async resume after some tool calls were not yet ready', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
claimedToolCallIds,
|
||||
retry: resumeRetries,
|
||||
missingToolCallIds,
|
||||
})
|
||||
await Promise.all(
|
||||
claimedToolCallIds.map((toolCallId) =>
|
||||
releaseCompletedAsyncToolClaim(toolCallId, claimedByWorkerId!).catch(() => null)
|
||||
)
|
||||
)
|
||||
claimedToolCallIds = []
|
||||
claimedByWorkerId = null
|
||||
}
|
||||
if (emptyClaimRetries < 3) {
|
||||
emptyClaimRetries++
|
||||
logger.info(
|
||||
'Retrying async resume claim after only a subset of tool calls were claimable',
|
||||
{
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
retry: emptyClaimRetries,
|
||||
missingToolCallIds,
|
||||
}
|
||||
)
|
||||
await new Promise((resolve) => setTimeout(resolve, 250 * emptyClaimRetries))
|
||||
await new Promise((resolve) => setTimeout(resolve, 250 * resumeRetries))
|
||||
continue
|
||||
}
|
||||
logger.warn('Skipping async resume because not all tool calls were claimable', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
claimableToolCallIds,
|
||||
missingToolCallIds,
|
||||
})
|
||||
context.awaitingAsyncContinuation = undefined
|
||||
break
|
||||
throw new Error(
|
||||
`Failed to resume async tool continuation: pending tool calls were not ready (${missingToolCallIds.join(', ')})`
|
||||
)
|
||||
}
|
||||
|
||||
if (claimableToolCallIds.length === 0) {
|
||||
if (emptyClaimRetries < 3 && continuation.pendingToolCallIds.length > 0) {
|
||||
emptyClaimRetries++
|
||||
logger.info('Retrying async resume claim after no tool calls were claimable', {
|
||||
if (readyTools.length === 0) {
|
||||
if (resumeRetries < 3 && continuation.pendingToolCallIds.length > 0) {
|
||||
resumeRetries++
|
||||
logger.info('Retrying async resume because no tool calls were ready yet', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
retry: emptyClaimRetries,
|
||||
retry: resumeRetries,
|
||||
})
|
||||
await new Promise((resolve) => setTimeout(resolve, 250 * emptyClaimRetries))
|
||||
await new Promise((resolve) => setTimeout(resolve, 250 * resumeRetries))
|
||||
continue
|
||||
}
|
||||
logger.warn('Skipping async resume because no tool calls were claimable', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
})
|
||||
context.awaitingAsyncContinuation = undefined
|
||||
break
|
||||
throw new Error('Failed to resume async tool continuation: no tool calls were ready')
|
||||
}
|
||||
|
||||
const claimCandidates = readyTools.filter((tool) => tool.needsDurableClaim)
|
||||
const newlyClaimedToolCallIds: string[] = []
|
||||
const claimFailures: string[] = []
|
||||
|
||||
for (const tool of claimCandidates) {
|
||||
const claimed = await claimCompletedAsyncToolCall(tool.toolCallId, resumeWorkerId).catch(
|
||||
() => null
|
||||
)
|
||||
if (!claimed) {
|
||||
claimFailures.push(tool.toolCallId)
|
||||
continue
|
||||
}
|
||||
newlyClaimedToolCallIds.push(tool.toolCallId)
|
||||
}
|
||||
|
||||
if (claimFailures.length > 0) {
|
||||
if (newlyClaimedToolCallIds.length > 0) {
|
||||
logger.info('Releasing async tool claims after claim contention during resume', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
newlyClaimedToolCallIds,
|
||||
claimFailures,
|
||||
})
|
||||
await Promise.all(
|
||||
newlyClaimedToolCallIds.map((toolCallId) =>
|
||||
releaseCompletedAsyncToolClaim(toolCallId, resumeWorkerId).catch(() => null)
|
||||
)
|
||||
)
|
||||
}
|
||||
if (resumeRetries < 3) {
|
||||
resumeRetries++
|
||||
logger.info('Retrying async resume after claim contention', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
retry: resumeRetries,
|
||||
claimFailures,
|
||||
})
|
||||
await new Promise((resolve) => setTimeout(resolve, 250 * resumeRetries))
|
||||
continue
|
||||
}
|
||||
throw new Error(
|
||||
`Failed to resume async tool continuation: unable to claim tool calls (${claimFailures.join(', ')})`
|
||||
)
|
||||
}
|
||||
|
||||
claimedToolCallIds = [
|
||||
...readyTools
|
||||
.filter((tool) => tool.alreadyClaimedByWorker)
|
||||
.map((tool) => tool.toolCallId),
|
||||
...newlyClaimedToolCallIds,
|
||||
]
|
||||
claimedByWorkerId = claimedToolCallIds.length > 0 ? resumeWorkerId : null
|
||||
|
||||
logger.info('Resuming async tool continuation', {
|
||||
checkpointId: continuation.checkpointId,
|
||||
runId: continuation.runId,
|
||||
toolCallIds: claimableToolCallIds,
|
||||
toolCallIds: readyTools.map((tool) => tool.toolCallId),
|
||||
})
|
||||
|
||||
const durableRows = await getAsyncToolCalls(claimableToolCallIds).catch(() => [])
|
||||
const durableRows = await getAsyncToolCalls(
|
||||
readyTools.map((tool) => tool.toolCallId)
|
||||
).catch(() => [])
|
||||
const durableByToolCallId = new Map(durableRows.map((row) => [row.toolCallId, row]))
|
||||
|
||||
const results = await Promise.all(
|
||||
claimableToolCallIds.map(async (toolCallId) => {
|
||||
const completion = await context.pendingToolPromises.get(toolCallId)
|
||||
const toolState = context.toolCalls.get(toolCallId)
|
||||
|
||||
const durable = durableByToolCallId.get(toolCallId)
|
||||
readyTools.map(async (tool) => {
|
||||
const durable = durableByToolCallId.get(tool.toolCallId) || tool.durableRow
|
||||
const durableStatus = durable?.status
|
||||
const durableResult =
|
||||
durable?.result && typeof durable.result === 'object'
|
||||
@@ -332,19 +379,15 @@ export async function orchestrateCopilotStream(
|
||||
durableStatus,
|
||||
durableResult,
|
||||
durableError: durable?.error,
|
||||
completion,
|
||||
toolStateSuccess: toolState?.result?.success,
|
||||
toolStateStatus: toolState?.status,
|
||||
toolStateStatus: tool.toolState?.status,
|
||||
})
|
||||
const data =
|
||||
durableResult ||
|
||||
completion?.data ||
|
||||
(toolState?.result?.output as Record<string, unknown> | undefined) ||
|
||||
(tool.toolState?.result?.output as Record<string, unknown> | undefined) ||
|
||||
(success
|
||||
? { message: completion?.message || 'Tool completed' }
|
||||
? { message: 'Tool completed' }
|
||||
: {
|
||||
error:
|
||||
completion?.message || durable?.error || toolState?.error || 'Tool failed',
|
||||
error: durable?.error || tool.toolState?.error || 'Tool failed',
|
||||
})
|
||||
|
||||
if (
|
||||
@@ -353,14 +396,14 @@ export async function orchestrateCopilotStream(
|
||||
!isDeliveredAsyncStatus(durableStatus)
|
||||
) {
|
||||
logger.warn('Async tool row was claimed for resume without terminal durable state', {
|
||||
toolCallId,
|
||||
toolCallId: tool.toolCallId,
|
||||
status: durableStatus,
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
callId: toolCallId,
|
||||
name: durable?.toolName || toolState?.name || '',
|
||||
callId: tool.toolCallId,
|
||||
name: durable?.toolName || tool.toolState?.name || '',
|
||||
data,
|
||||
success,
|
||||
}
|
||||
|
||||
@@ -209,4 +209,76 @@ describe('sse-handlers tool lifecycle', () => {
|
||||
expect(markToolComplete).toHaveBeenCalledTimes(1)
|
||||
expect(context.toolCalls.get('tool-upsert-fail')?.status).toBe('success')
|
||||
})
|
||||
|
||||
it('does not execute a tool if a terminal tool_result arrives before local execution starts', async () => {
|
||||
let resolveUpsert: ((value: null) => void) | undefined
|
||||
upsertAsyncToolCall.mockImplementationOnce(
|
||||
() =>
|
||||
new Promise((resolve) => {
|
||||
resolveUpsert = resolve
|
||||
})
|
||||
)
|
||||
const onEvent = vi.fn()
|
||||
|
||||
await sseHandlers.tool_call(
|
||||
{
|
||||
type: 'tool_call',
|
||||
data: { id: 'tool-race', name: 'read', arguments: { workflowId: 'workflow-1' } },
|
||||
} as any,
|
||||
context,
|
||||
execContext,
|
||||
{ onEvent, interactive: false, timeout: 1000 }
|
||||
)
|
||||
|
||||
await sseHandlers.tool_result(
|
||||
{
|
||||
type: 'tool_result',
|
||||
toolCallId: 'tool-race',
|
||||
data: { id: 'tool-race', success: true, result: { ok: true } },
|
||||
} as any,
|
||||
context,
|
||||
execContext,
|
||||
{ onEvent, interactive: false, timeout: 1000 }
|
||||
)
|
||||
|
||||
resolveUpsert?.(null)
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
|
||||
expect(executeToolServerSide).not.toHaveBeenCalled()
|
||||
expect(markToolComplete).not.toHaveBeenCalled()
|
||||
expect(context.toolCalls.get('tool-race')?.status).toBe('success')
|
||||
expect(context.toolCalls.get('tool-race')?.result?.output).toEqual({ ok: true })
|
||||
})
|
||||
|
||||
it('does not execute a tool if a tool_result arrives before the tool_call event', async () => {
|
||||
const onEvent = vi.fn()
|
||||
|
||||
await sseHandlers.tool_result(
|
||||
{
|
||||
type: 'tool_result',
|
||||
toolCallId: 'tool-early-result',
|
||||
toolName: 'read',
|
||||
data: { id: 'tool-early-result', name: 'read', success: true, result: { ok: true } },
|
||||
} as any,
|
||||
context,
|
||||
execContext,
|
||||
{ onEvent, interactive: false, timeout: 1000 }
|
||||
)
|
||||
|
||||
await sseHandlers.tool_call(
|
||||
{
|
||||
type: 'tool_call',
|
||||
data: { id: 'tool-early-result', name: 'read', arguments: { workflowId: 'workflow-1' } },
|
||||
} as any,
|
||||
context,
|
||||
execContext,
|
||||
{ onEvent, interactive: false, timeout: 1000 }
|
||||
)
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
|
||||
expect(executeToolServerSide).not.toHaveBeenCalled()
|
||||
expect(markToolComplete).not.toHaveBeenCalled()
|
||||
expect(context.toolCalls.get('tool-early-result')?.status).toBe('success')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -213,6 +213,27 @@ function inferToolSuccess(data: Record<string, unknown> | undefined): {
|
||||
return { success, hasResultData, hasError }
|
||||
}
|
||||
|
||||
function ensureTerminalToolCallState(
|
||||
context: StreamingContext,
|
||||
toolCallId: string,
|
||||
toolName: string
|
||||
): ToolCallState {
|
||||
const existing = context.toolCalls.get(toolCallId)
|
||||
if (existing) {
|
||||
return existing
|
||||
}
|
||||
|
||||
const toolCall: ToolCallState = {
|
||||
id: toolCallId,
|
||||
name: toolName || 'unknown_tool',
|
||||
status: 'pending',
|
||||
startTime: Date.now(),
|
||||
}
|
||||
context.toolCalls.set(toolCallId, toolCall)
|
||||
addContentBlock(context, { type: 'tool_call', toolCall })
|
||||
return toolCall
|
||||
}
|
||||
|
||||
export type SSEHandler = (
|
||||
event: SSEEvent,
|
||||
context: StreamingContext,
|
||||
@@ -246,8 +267,12 @@ export const sseHandlers: Record<string, SSEHandler> = {
|
||||
const data = getEventData(event)
|
||||
const toolCallId = event.toolCallId || (data?.id as string | undefined)
|
||||
if (!toolCallId) return
|
||||
const current = context.toolCalls.get(toolCallId)
|
||||
if (!current) return
|
||||
const toolName =
|
||||
event.toolName ||
|
||||
(data?.name as string | undefined) ||
|
||||
context.toolCalls.get(toolCallId)?.name ||
|
||||
''
|
||||
const current = ensureTerminalToolCallState(context, toolCallId, toolName)
|
||||
|
||||
const { success, hasResultData, hasError } = inferToolSuccess(data)
|
||||
|
||||
@@ -263,16 +288,22 @@ export const sseHandlers: Record<string, SSEHandler> = {
|
||||
const resultObj = asRecord(data?.result)
|
||||
current.error = (data?.error || resultObj.error) as string | undefined
|
||||
}
|
||||
markToolResultSeen(toolCallId)
|
||||
},
|
||||
tool_error: (event, context) => {
|
||||
const data = getEventData(event)
|
||||
const toolCallId = event.toolCallId || (data?.id as string | undefined)
|
||||
if (!toolCallId) return
|
||||
const current = context.toolCalls.get(toolCallId)
|
||||
if (!current) return
|
||||
const toolName =
|
||||
event.toolName ||
|
||||
(data?.name as string | undefined) ||
|
||||
context.toolCalls.get(toolCallId)?.name ||
|
||||
''
|
||||
const current = ensureTerminalToolCallState(context, toolCallId, toolName)
|
||||
current.status = 'error'
|
||||
current.error = (data?.error as string | undefined) || 'Tool execution failed'
|
||||
current.endTime = Date.now()
|
||||
markToolResultSeen(toolCallId)
|
||||
},
|
||||
tool_call_delta: () => {
|
||||
// Argument streaming delta — no action needed on orchestrator side
|
||||
@@ -313,6 +344,9 @@ export const sseHandlers: Record<string, SSEHandler> = {
|
||||
existing?.endTime ||
|
||||
(existing && existing.status !== 'pending' && existing.status !== 'executing')
|
||||
) {
|
||||
if (!existing.name && toolName) {
|
||||
existing.name = toolName
|
||||
}
|
||||
if (!existing.params && args) {
|
||||
existing.params = args
|
||||
}
|
||||
@@ -558,6 +592,12 @@ export const subAgentHandlers: Record<string, SSEHandler> = {
|
||||
const existing = context.toolCalls.get(toolCallId)
|
||||
// Ignore late/duplicate tool_call events once we already have a result.
|
||||
if (wasToolResultSeen(toolCallId) || existing?.endTime) {
|
||||
if (existing && !existing.name && toolName) {
|
||||
existing.name = toolName
|
||||
}
|
||||
if (existing && !existing.params && args) {
|
||||
existing.params = args
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -686,13 +726,14 @@ export const subAgentHandlers: Record<string, SSEHandler> = {
|
||||
const data = getEventData(event)
|
||||
const toolCallId = event.toolCallId || (data?.id as string | undefined)
|
||||
if (!toolCallId) return
|
||||
const toolName = event.toolName || (data?.name as string | undefined) || ''
|
||||
|
||||
// Update in subAgentToolCalls.
|
||||
const toolCalls = context.subAgentToolCalls[parentToolCallId] || []
|
||||
const subAgentToolCall = toolCalls.find((tc) => tc.id === toolCallId)
|
||||
|
||||
// Also update in main toolCalls (where we added it for execution).
|
||||
const mainToolCall = context.toolCalls.get(toolCallId)
|
||||
const mainToolCall = ensureTerminalToolCallState(context, toolCallId, toolName)
|
||||
|
||||
const { success, hasResultData, hasError } = inferToolSuccess(data)
|
||||
|
||||
@@ -719,6 +760,9 @@ export const subAgentHandlers: Record<string, SSEHandler> = {
|
||||
mainToolCall.error = (data?.error || resultObj.error) as string | undefined
|
||||
}
|
||||
}
|
||||
if (subAgentToolCall || mainToolCall) {
|
||||
markToolResultSeen(toolCallId)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -4,18 +4,15 @@ import { createLogger } from '@sim/logger'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { completeAsyncToolCall, markAsyncToolRunning } from '@/lib/copilot/async-runs/repository'
|
||||
import { waitForToolConfirmation } from '@/lib/copilot/orchestrator/persistence'
|
||||
import {
|
||||
asRecord,
|
||||
markToolResultSeen,
|
||||
wasToolResultSeen,
|
||||
} from '@/lib/copilot/orchestrator/sse/utils'
|
||||
import { asRecord, markToolResultSeen } from '@/lib/copilot/orchestrator/sse/utils'
|
||||
import { executeToolServerSide, markToolComplete } from '@/lib/copilot/orchestrator/tool-executor'
|
||||
import type {
|
||||
ExecutionContext,
|
||||
OrchestratorOptions,
|
||||
SSEEvent,
|
||||
StreamingContext,
|
||||
ToolCallResult,
|
||||
import {
|
||||
type ExecutionContext,
|
||||
isTerminalToolCallStatus,
|
||||
type OrchestratorOptions,
|
||||
type SSEEvent,
|
||||
type StreamingContext,
|
||||
type ToolCallResult,
|
||||
} from '@/lib/copilot/orchestrator/types'
|
||||
import {
|
||||
extractDeletedResourcesFromToolResult,
|
||||
@@ -247,6 +244,48 @@ function cancelledCompletion(message: string): AsyncToolCompletion {
|
||||
}
|
||||
}
|
||||
|
||||
function terminalCompletionFromToolCall(toolCall: {
|
||||
status: string
|
||||
error?: string
|
||||
result?: { output?: unknown; error?: string }
|
||||
}): AsyncToolCompletion {
|
||||
if (toolCall.status === 'cancelled') {
|
||||
return cancelledCompletion(toolCall.error || 'Tool execution cancelled')
|
||||
}
|
||||
|
||||
if (toolCall.status === 'success') {
|
||||
return {
|
||||
status: 'success',
|
||||
message: 'Tool completed',
|
||||
data:
|
||||
toolCall.result?.output &&
|
||||
typeof toolCall.result.output === 'object' &&
|
||||
!Array.isArray(toolCall.result.output)
|
||||
? (toolCall.result.output as Record<string, unknown>)
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCall.status === 'skipped') {
|
||||
return {
|
||||
status: 'success',
|
||||
message: 'Tool skipped',
|
||||
data:
|
||||
toolCall.result?.output &&
|
||||
typeof toolCall.result.output === 'object' &&
|
||||
!Array.isArray(toolCall.result.output)
|
||||
? (toolCall.result.output as Record<string, unknown>)
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
status: toolCall.status === 'rejected' ? 'rejected' : 'error',
|
||||
message: toolCall.error || toolCall.result?.error || 'Tool failed',
|
||||
data: { error: toolCall.error || toolCall.result?.error || 'Tool failed' },
|
||||
}
|
||||
}
|
||||
|
||||
function reportCancelledTool(
|
||||
toolCall: { id: string; name: string },
|
||||
message: string,
|
||||
@@ -509,8 +548,8 @@ export async function executeToolAndReport(
|
||||
if (toolCall.status === 'executing') {
|
||||
return { status: 'running', message: 'Tool already executing' }
|
||||
}
|
||||
if (wasToolResultSeen(toolCall.id)) {
|
||||
return { status: 'success', message: 'Tool result already processed' }
|
||||
if (toolCall.endTime || isTerminalToolCallStatus(toolCall.status)) {
|
||||
return terminalCompletionFromToolCall(toolCall)
|
||||
}
|
||||
|
||||
if (abortRequested(context, execContext, options)) {
|
||||
@@ -538,6 +577,9 @@ export async function executeToolAndReport(
|
||||
|
||||
try {
|
||||
let result = await executeToolServerSide(toolCall, execContext)
|
||||
if (toolCall.endTime || isTerminalToolCallStatus(toolCall.status)) {
|
||||
return terminalCompletionFromToolCall(toolCall)
|
||||
}
|
||||
if (abortRequested(context, execContext, options)) {
|
||||
toolCall.status = 'cancelled'
|
||||
toolCall.endTime = Date.now()
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import {
|
||||
markToolResultSeen,
|
||||
normalizeSseEvent,
|
||||
shouldSkipToolCallEvent,
|
||||
shouldSkipToolResultEvent,
|
||||
@@ -37,6 +38,7 @@ describe('sse-utils', () => {
|
||||
it.concurrent('dedupes tool_result events', () => {
|
||||
const event = { type: 'tool_result', data: { id: 'tool_result_1', name: 'plan' } }
|
||||
expect(shouldSkipToolResultEvent(event as any)).toBe(false)
|
||||
markToolResultSeen('tool_result_1')
|
||||
expect(shouldSkipToolResultEvent(event as any)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -125,7 +125,5 @@ export function shouldSkipToolResultEvent(event: SSEEvent): boolean {
|
||||
if (event.type !== 'tool_result') return false
|
||||
const toolCallId = getToolCallIdFromEvent(event)
|
||||
if (!toolCallId) return false
|
||||
if (wasToolResultSeen(toolCallId)) return true
|
||||
markToolResultSeen(toolCallId)
|
||||
return false
|
||||
return wasToolResultSeen(toolCallId)
|
||||
}
|
||||
|
||||
@@ -59,6 +59,18 @@ export type ToolCallStatus =
|
||||
| 'rejected'
|
||||
| 'cancelled'
|
||||
|
||||
const TERMINAL_TOOL_STATUSES: ReadonlySet<ToolCallStatus> = new Set([
|
||||
'success',
|
||||
'error',
|
||||
'cancelled',
|
||||
'skipped',
|
||||
'rejected',
|
||||
])
|
||||
|
||||
export function isTerminalToolCallStatus(status?: string): boolean {
|
||||
return TERMINAL_TOOL_STATUSES.has(status as ToolCallStatus)
|
||||
}
|
||||
|
||||
export interface ToolCallState {
|
||||
id: string
|
||||
name: string
|
||||
|
||||
Reference in New Issue
Block a user