This commit is contained in:
Siddharth Ganesan
2026-02-09 11:23:38 -08:00
parent d2c028f7cd
commit ed613c3a4f
6 changed files with 125 additions and 50 deletions

View File

@@ -1230,7 +1230,7 @@ function shouldShowRunSkipButtons(toolCall: CopilotToolCall): boolean {
}
// Never show buttons for tools the user has marked as always-allowed
if (useCopilotStore.getState().autoAllowedTools.includes(toolCall.name)) {
if (useCopilotStore.getState().isToolAutoAllowed(toolCall.name)) {
return false
}
@@ -1438,10 +1438,10 @@ export function ToolCall({
const paramsRef = useRef(params)
// Check if this integration tool is auto-allowed
// Subscribe to autoAllowedTools so we re-render when it changes
const autoAllowedTools = useCopilotStore((s) => s.autoAllowedTools)
const { removeAutoAllowedTool, setToolCallState } = useCopilotStore()
const isAutoAllowed = isIntegrationTool(toolCall.name) && autoAllowedTools.includes(toolCall.name)
const isAutoAllowed = useCopilotStore(
(s) => isIntegrationTool(toolCall.name) && s.isToolAutoAllowed(toolCall.name)
)
// Update edited params when toolCall params change (deep comparison to avoid resetting user edits on ref change)
useEffect(() => {

View File

@@ -499,7 +499,7 @@ export const sseHandlers: Record<string, SSEHandler> = {
const { toolCallsById } = get()
if (!toolCallsById[toolCallId]) {
const isAutoAllowed = get().autoAllowedTools.includes(toolName)
const isAutoAllowed = get().isToolAutoAllowed(toolName)
const initialState = isAutoAllowed
? ClientToolCallState.executing
: ClientToolCallState.pending
@@ -528,10 +528,7 @@ export const sseHandlers: Record<string, SSEHandler> = {
const existing = toolCallsById[id]
const toolName = name || existing?.name || 'unknown_tool'
const autoAllowedTools = get().autoAllowedTools
const isAutoAllowed =
autoAllowedTools.includes(toolName) ||
(existing?.name ? autoAllowedTools.includes(existing.name) : false)
const isAutoAllowed = get().isToolAutoAllowed(toolName)
let initialState = isAutoAllowed ? ClientToolCallState.executing : ClientToolCallState.pending
// Avoid flickering back to pending on partial/duplicate events once a tool is executing.

View File

@@ -194,7 +194,7 @@ export const subAgentSSEHandlers: Record<string, SSEHandler> = {
existingIndex >= 0 ? context.subAgentToolCalls[parentToolCallId][existingIndex] : undefined
// Auto-allowed tools skip pending state to avoid flashing interrupt buttons
const isAutoAllowed = get().autoAllowedTools.includes(name)
const isAutoAllowed = get().isToolAutoAllowed(name)
let initialState = isAutoAllowed ? ClientToolCallState.executing : ClientToolCallState.pending
// Avoid flickering back to pending on partial/duplicate events once a tool is executing.

View File

@@ -90,11 +90,22 @@ export function isTerminalState(state: string): boolean {
)
}
/**
* Resolves the appropriate terminal state for a non-terminal tool call.
* 'executing' → 'success': the server was running it, assume it completed.
* Everything else → 'aborted': never reached execution.
*/
function resolveAbortState(currentState: string): ClientToolCallState {
return currentState === ClientToolCallState.executing
? ClientToolCallState.success
: ClientToolCallState.aborted
}
export function abortAllInProgressTools(set: StoreSet, get: () => CopilotStore) {
try {
const { toolCallsById, messages } = get()
const updatedMap = { ...toolCallsById }
const abortedIds = new Set<string>()
const resolvedIds = new Map<string, ClientToolCallState>()
let hasUpdates = false
for (const [id, tc] of Object.entries(toolCallsById)) {
const st = tc.state
@@ -104,12 +115,13 @@ export function abortAllInProgressTools(set: StoreSet, get: () => CopilotStore)
st === ClientToolCallState.rejected ||
st === ClientToolCallState.aborted
if (!isTerminal || isReviewState(st)) {
abortedIds.add(id)
const resolved = resolveAbortState(st)
resolvedIds.set(id, resolved)
updatedMap[id] = {
...tc,
state: ClientToolCallState.aborted,
state: resolved,
subAgentStreaming: false,
display: resolveToolDisplay(tc.name, ClientToolCallState.aborted, id, tc.params),
display: resolveToolDisplay(tc.name, resolved, id, tc.params),
}
hasUpdates = true
} else if (tc.subAgentStreaming) {
@@ -120,7 +132,7 @@ export function abortAllInProgressTools(set: StoreSet, get: () => CopilotStore)
hasUpdates = true
}
}
if (abortedIds.size > 0 || hasUpdates) {
if (resolvedIds.size > 0 || hasUpdates) {
set({ toolCallsById: updatedMap })
set((s: CopilotStore) => {
const msgs = [...s.messages]
@@ -129,17 +141,18 @@ export function abortAllInProgressTools(set: StoreSet, get: () => CopilotStore)
if (m.role !== 'assistant' || !Array.isArray(m.contentBlocks)) continue
let changed = false
const blocks = m.contentBlocks.map((b: any) => {
if (b?.type === 'tool_call' && b.toolCall?.id && abortedIds.has(b.toolCall.id)) {
if (b?.type === 'tool_call' && b.toolCall?.id && resolvedIds.has(b.toolCall.id)) {
changed = true
const prev = b.toolCall
const resolved = resolvedIds.get(b.toolCall.id)!
return {
...b,
toolCall: {
...prev,
state: ClientToolCallState.aborted,
state: resolved,
display: resolveToolDisplay(
prev?.name,
ClientToolCallState.aborted,
resolved,
prev?.id,
prev?.params
),

View File

@@ -138,6 +138,41 @@ function updateActiveStreamEventId(
writeActiveStreamToStorage(next)
}
const AUTO_ALLOWED_TOOLS_STORAGE_KEY = 'copilot_auto_allowed_tools'
function readAutoAllowedToolsFromStorage(): string[] | null {
if (typeof window === 'undefined') return null
try {
const raw = window.localStorage.getItem(AUTO_ALLOWED_TOOLS_STORAGE_KEY)
if (!raw) return null
const parsed = JSON.parse(raw)
if (!Array.isArray(parsed)) return null
return parsed.filter((item): item is string => typeof item === 'string')
} catch (error) {
logger.warn('[AutoAllowedTools] Failed to read local cache', {
error: error instanceof Error ? error.message : String(error),
})
return null
}
}
function writeAutoAllowedToolsToStorage(tools: string[]): void {
if (typeof window === 'undefined') return
try {
window.localStorage.setItem(AUTO_ALLOWED_TOOLS_STORAGE_KEY, JSON.stringify(tools))
} catch (error) {
logger.warn('[AutoAllowedTools] Failed to write local cache', {
error: error instanceof Error ? error.message : String(error),
})
}
}
function isToolAutoAllowedByList(toolId: string, autoAllowedTools: string[]): boolean {
if (!toolId) return false
const normalizedTarget = toolId.trim()
return autoAllowedTools.some((allowed) => allowed?.trim() === normalizedTarget)
}
/**
* Clear any lingering diff preview from a previous session.
* Called lazily when the store is first activated (setWorkflowId).
@@ -870,6 +905,8 @@ async function resumeFromLiveStream(
return false
}
const cachedAutoAllowedTools = readAutoAllowedToolsFromStorage()
// Initial state (subset required for UI/streaming)
const initialState = {
mode: 'build' as const,
@@ -903,7 +940,8 @@ const initialState = {
streamingPlanContent: '',
toolCallsById: {} as Record<string, CopilotToolCall>,
suppressAutoSelect: false,
autoAllowedTools: [] as string[],
autoAllowedTools: cachedAutoAllowedTools ?? ([] as string[]),
autoAllowedToolsLoaded: cachedAutoAllowedTools !== null,
activeStream: null as CopilotStreamInfo | null,
messageQueue: [] as import('./types').QueuedMessage[],
suppressAbortContinueOption: false,
@@ -940,6 +978,9 @@ export const useCopilotStore = create<CopilotStore>()(
mode: get().mode,
selectedModel: get().selectedModel,
agentPrefetch: get().agentPrefetch,
enabledModels: get().enabledModels,
autoAllowedTools: get().autoAllowedTools,
autoAllowedToolsLoaded: get().autoAllowedToolsLoaded,
})
},
@@ -1245,6 +1286,16 @@ export const useCopilotStore = create<CopilotStore>()(
// Send a message (streaming only)
sendMessage: async (message: string, options = {}) => {
if (!get().autoAllowedToolsLoaded) {
try {
await get().loadAutoAllowedTools()
} catch (error) {
logger.warn('[Copilot] Failed to preload auto-allowed tools before send', {
error: error instanceof Error ? error.message : String(error),
})
}
}
const prepared = prepareSendContext(get, set, message, options as SendMessageOptionsInput)
if (!prepared) return
@@ -1848,6 +1899,8 @@ export const useCopilotStore = create<CopilotStore>()(
context.wasAborted && !context.suppressContinueOption
? appendContinueOption(finalContent)
: finalContentStripped
// Step 1: Update messages in state but keep isSendingMessage: true.
// This prevents loadChats from overwriting with stale DB data during persist.
set((state) => {
const snapshotId = state.currentUserMessageId
const nextSnapshots =
@@ -1868,9 +1921,7 @@ export const useCopilotStore = create<CopilotStore>()(
}
: msg
),
isSendingMessage: false,
isAborting: false,
abortController: null,
currentUserMessageId: null,
messageSnapshots: nextSnapshots,
}
@@ -1887,31 +1938,9 @@ export const useCopilotStore = create<CopilotStore>()(
await get().handleNewChatCreation(context.newChatId)
}
// Process next message in queue if any
const nextInQueue = get().messageQueue[0]
if (nextInQueue) {
// Use originalMessageId if available (from edit/resend), otherwise use queue entry id
const messageIdToUse = nextInQueue.originalMessageId || nextInQueue.id
logger.debug('[Queue] Processing next queued message', {
id: nextInQueue.id,
originalMessageId: nextInQueue.originalMessageId,
messageIdToUse,
queueLength: get().messageQueue.length,
})
// Remove from queue and send
get().removeFromQueue(nextInQueue.id)
// Use setTimeout to avoid blocking the current execution
setTimeout(() => {
get().sendMessage(nextInQueue.content, {
stream: true,
fileAttachments: nextInQueue.fileAttachments,
contexts: nextInQueue.contexts,
messageId: messageIdToUse,
})
}, QUEUE_PROCESS_DELAY_MS)
}
// Persist full message state (including contentBlocks), plan artifact, and config to database
// Step 2: Persist messages to DB BEFORE marking stream as done.
// loadChats checks isSendingMessage — while true it preserves in-memory messages.
// Persisting first ensures the DB is up-to-date before we allow overwrites.
const { currentChat, streamingPlanContent, mode, selectedModel } = get()
if (currentChat) {
try {
@@ -1964,6 +1993,34 @@ export const useCopilotStore = create<CopilotStore>()(
}
}
// Step 3: NOW mark stream as done. DB is up-to-date, so if loadChats
// overwrites messages it will use the persisted (correct) data.
set({ isSendingMessage: false, abortController: null })
// Process next message in queue if any
const nextInQueue = get().messageQueue[0]
if (nextInQueue) {
// Use originalMessageId if available (from edit/resend), otherwise use queue entry id
const messageIdToUse = nextInQueue.originalMessageId || nextInQueue.id
logger.debug('[Queue] Processing next queued message', {
id: nextInQueue.id,
originalMessageId: nextInQueue.originalMessageId,
messageIdToUse,
queueLength: get().messageQueue.length,
})
// Remove from queue and send
get().removeFromQueue(nextInQueue.id)
// Use setTimeout to avoid blocking the current execution
setTimeout(() => {
get().sendMessage(nextInQueue.content, {
stream: true,
fileAttachments: nextInQueue.fileAttachments,
contexts: nextInQueue.contexts,
messageId: messageIdToUse,
})
}, QUEUE_PROCESS_DELAY_MS)
}
// Invalidate subscription queries to update usage
setTimeout(() => {
const queryClient = getQueryClient()
@@ -2142,12 +2199,15 @@ export const useCopilotStore = create<CopilotStore>()(
if (res.ok) {
const data = await res.json()
const tools = data.autoAllowedTools ?? []
set({ autoAllowedTools: tools })
set({ autoAllowedTools: tools, autoAllowedToolsLoaded: true })
writeAutoAllowedToolsToStorage(tools)
logger.debug('[AutoAllowedTools] Loaded successfully', { count: tools.length, tools })
} else {
set({ autoAllowedToolsLoaded: true })
logger.warn('[AutoAllowedTools] Load failed with status', { status: res.status })
}
} catch (err) {
set({ autoAllowedToolsLoaded: true })
logger.error('[AutoAllowedTools] Failed to load', { error: err })
}
},
@@ -2164,7 +2224,9 @@ export const useCopilotStore = create<CopilotStore>()(
if (res.ok) {
const data = await res.json()
logger.debug('[AutoAllowedTools] API returned', { toolId, tools: data.autoAllowedTools })
set({ autoAllowedTools: data.autoAllowedTools ?? [] })
const tools = data.autoAllowedTools ?? []
set({ autoAllowedTools: tools, autoAllowedToolsLoaded: true })
writeAutoAllowedToolsToStorage(tools)
logger.debug('[AutoAllowedTools] Added tool to store', { toolId })
}
} catch (err) {
@@ -2182,7 +2244,9 @@ export const useCopilotStore = create<CopilotStore>()(
)
if (res.ok) {
const data = await res.json()
set({ autoAllowedTools: data.autoAllowedTools ?? [] })
const tools = data.autoAllowedTools ?? []
set({ autoAllowedTools: tools, autoAllowedToolsLoaded: true })
writeAutoAllowedToolsToStorage(tools)
logger.debug('[AutoAllowedTools] Removed tool', { toolId })
}
} catch (err) {
@@ -2192,7 +2256,7 @@ export const useCopilotStore = create<CopilotStore>()(
isToolAutoAllowed: (toolId: string) => {
const { autoAllowedTools } = get()
return autoAllowedTools.includes(toolId)
return isToolAutoAllowedByList(toolId, autoAllowedTools)
},
// Credential masking

View File

@@ -167,6 +167,7 @@ export interface CopilotState {
// Auto-allowed integration tools (tools that can run without confirmation)
autoAllowedTools: string[]
autoAllowedToolsLoaded: boolean
// Active stream metadata for reconnect/replay
activeStream: CopilotStreamInfo | null