Compare commits

...

5 Commits

Author             SHA1        Message                                      Date
Siddharth Ganesan  f67caf0798  Fix tests                                    2025-12-15 10:56:33 -08:00
Siddharth Ganesan  0b853a7d95  Fixes                                        2025-12-15 10:33:41 -08:00
Vikhyath Mondreti  4f31560a0e  use isCallEndRef correctly                   2025-12-13 12:50:24 -08:00
Vikhyath Mondreti  8b5027f2a6  Merge branch 'staging' into fix/chat-tools   2025-12-13 12:34:12 -08:00
Siddharth Ganesan  c6c658a6e1  Fix chat tools                               2025-12-13 11:56:56 -08:00
18 changed files with 487 additions and 98 deletions

View File

@@ -80,6 +80,10 @@ export function VoiceInterface({
const currentStateRef = useRef<'idle' | 'listening' | 'agent_speaking'>('idle')
const isCallEndedRef = useRef(false)
useEffect(() => {
isCallEndedRef.current = false
}, [])
useEffect(() => {
currentStateRef.current = state
}, [state])
@@ -119,6 +123,8 @@ export function VoiceInterface({
}, [])
useEffect(() => {
if (isCallEndedRef.current) return
if (isPlayingAudio && state !== 'agent_speaking') {
clearResponseTimeout()
setState('agent_speaking')
@@ -139,6 +145,9 @@ export function VoiceInterface({
}
}
} else if (!isPlayingAudio && state === 'agent_speaking') {
// Don't unmute/restart if call has ended
if (isCallEndedRef.current) return
setState('idle')
setCurrentTranscript('')
@@ -226,6 +235,8 @@ export function VoiceInterface({
recognition.onstart = () => {}
recognition.onresult = (event: SpeechRecognitionEvent) => {
if (isCallEndedRef.current) return
const currentState = currentStateRef.current
if (isMutedRef.current || currentState !== 'listening') {
@@ -303,6 +314,8 @@ export function VoiceInterface({
}, [isSupported, onVoiceTranscript, setResponseTimeout])
const startListening = useCallback(() => {
if (isCallEndedRef.current) return
if (!isInitialized || isMuted || state !== 'idle') {
return
}
@@ -320,6 +333,9 @@ export function VoiceInterface({
}, [isInitialized, isMuted, state])
const stopListening = useCallback(() => {
// Don't process if call has ended
if (isCallEndedRef.current) return
setState('idle')
setCurrentTranscript('')
@@ -333,12 +349,15 @@ export function VoiceInterface({
}, [])
const handleInterrupt = useCallback(() => {
if (isCallEndedRef.current) return
if (state === 'agent_speaking') {
onInterrupt?.()
setState('listening')
setCurrentTranscript('')
setIsMuted(false)
isMutedRef.current = false
if (mediaStreamRef.current) {
mediaStreamRef.current.getAudioTracks().forEach((track) => {
track.enabled = true
@@ -356,11 +375,22 @@ export function VoiceInterface({
}, [state, onInterrupt])
const handleCallEnd = useCallback(() => {
// Mark call as ended FIRST to prevent any effects from restarting recognition
isCallEndedRef.current = true
// Set muted to true to prevent auto-start effect from triggering
setIsMuted(true)
isMutedRef.current = true
setState('idle')
setCurrentTranscript('')
setIsMuted(false)
// Immediately disable audio tracks to stop listening
if (mediaStreamRef.current) {
mediaStreamRef.current.getAudioTracks().forEach((track) => {
track.enabled = false
})
}
if (recognitionRef.current) {
try {
@@ -377,6 +407,8 @@ export function VoiceInterface({
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (isCallEndedRef.current) return
if (event.code === 'Space') {
event.preventDefault()
handleInterrupt()
@@ -388,6 +420,8 @@ export function VoiceInterface({
}, [handleInterrupt])
const toggleMute = useCallback(() => {
if (isCallEndedRef.current) return
if (state === 'agent_speaking') {
handleInterrupt()
return
@@ -395,6 +429,7 @@ export function VoiceInterface({
const newMutedState = !isMuted
setIsMuted(newMutedState)
isMutedRef.current = newMutedState
if (mediaStreamRef.current) {
mediaStreamRef.current.getAudioTracks().forEach((track) => {
@@ -417,6 +452,8 @@ export function VoiceInterface({
}, [isSupported, setupSpeechRecognition, setupAudio])
useEffect(() => {
if (isCallEndedRef.current) return
if (isInitialized && !isMuted && state === 'idle') {
startListening()
}
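
The changes in this file all apply one pattern: an isCallEndedRef latch that is set on hangup and checked at the top of every effect and callback, so nothing re-arms speech recognition after the call ends. A minimal sketch of the pattern, as a hypothetical hook with the component's state simplified away:

```ts
import { useCallback, useEffect, useRef, useState } from 'react'

// Sketch only: a ref-based "call ended" latch. Refs are read synchronously
// inside callbacks and effects, so the guard holds even before React re-renders.
export function useCallEndGuard() {
  const isCallEndedRef = useRef(false)
  const [state, setState] = useState<'idle' | 'listening'>('idle')

  const startListening = useCallback(() => {
    if (isCallEndedRef.current) return // never restart after hangup
    setState('listening')
  }, [])

  const handleCallEnd = useCallback(() => {
    // Set the latch FIRST so any effect that fires during teardown sees it
    isCallEndedRef.current = true
    setState('idle')
  }, [])

  // Auto-start effect: without the guard, returning to 'idle' on hangup
  // would immediately re-arm listening.
  useEffect(() => {
    if (isCallEndedRef.current) return
    if (state === 'idle') startListening()
  }, [state, startListening])

  return { startListening, handleCallEnd }
}
```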

View File

@@ -987,18 +987,21 @@ export class AgentBlockHandler implements BlockHandler {
try {
const executionData = JSON.parse(executionDataHeader)
// If execution data contains full content, persist to memory
if (ctx && inputs && executionData.output?.content) {
const assistantMessage: Message = {
role: 'assistant',
content: executionData.output.content,
}
// Fire and forget - don't await
memoryService
.persistMemoryMessage(ctx, inputs, assistantMessage, block.id)
.catch((error) =>
logger.error('Failed to persist streaming response to memory:', error)
// If execution data contains content or tool calls, persist to memory
if (
ctx &&
inputs &&
(executionData.output?.content || executionData.output?.toolCalls?.list?.length)
) {
const toolCalls = executionData.output?.toolCalls?.list
const messages = this.buildMessagesForMemory(executionData.output.content, toolCalls)
// Fire and forget - don't await, persist all messages
Promise.all(
messages.map((message) =>
memoryService.persistMemoryMessage(ctx, inputs, message, block.id)
)
).catch((error) => logger.error('Failed to persist streaming response to memory:', error))
}
return {
@@ -1117,25 +1120,28 @@ export class AgentBlockHandler implements BlockHandler {
return
}
// Extract content from regular response
// Extract content and tool calls from regular response
const blockOutput = result as any
const content = blockOutput?.content
const toolCalls = blockOutput?.toolCalls?.list
if (!content || typeof content !== 'string') {
// Build messages to persist
const messages = this.buildMessagesForMemory(content, toolCalls)
if (messages.length === 0) {
return
}
const assistantMessage: Message = {
role: 'assistant',
content,
// Persist all messages
for (const message of messages) {
await memoryService.persistMemoryMessage(ctx, inputs, message, blockId)
}
await memoryService.persistMemoryMessage(ctx, inputs, assistantMessage, blockId)
logger.debug('Persisted assistant response to memory', {
workflowId: ctx.workflowId,
memoryType: inputs.memoryType,
conversationId: inputs.conversationId,
messageCount: messages.length,
})
} catch (error) {
logger.error('Failed to persist response to memory:', error)
@@ -1143,6 +1149,69 @@ export class AgentBlockHandler implements BlockHandler {
}
}
/**
* Builds messages for memory storage including tool calls and results
* Returns proper OpenAI-compatible message format:
* - Assistant message with tool_calls array (if tools were used)
* - Tool role messages with results (one per tool call)
* - Final assistant message with content (if present)
*/
private buildMessagesForMemory(
content: string | undefined,
toolCalls: any[] | undefined
): Message[] {
const messages: Message[] = []
if (toolCalls?.length) {
// Generate stable IDs for each tool call (only if not provided by provider)
// Use index to ensure uniqueness even for same tool name in same millisecond
const toolCallsWithIds = toolCalls.map((tc: any, index: number) => ({
...tc,
_stableId:
tc.id ||
`call_${tc.name}_${Date.now()}_${index}_${Math.random().toString(36).slice(2, 7)}`,
}))
// Add assistant message with tool_calls
const formattedToolCalls = toolCallsWithIds.map((tc: any) => ({
id: tc._stableId,
type: 'function' as const,
function: {
name: tc.name,
arguments: tc.rawArguments || JSON.stringify(tc.arguments || {}),
},
}))
messages.push({
role: 'assistant',
content: null,
tool_calls: formattedToolCalls,
})
// Add tool result messages using the same stable IDs
for (const tc of toolCallsWithIds) {
const resultContent =
typeof tc.result === 'string' ? tc.result : JSON.stringify(tc.result || {})
messages.push({
role: 'tool',
content: resultContent,
tool_call_id: tc._stableId,
name: tc.name, // Store tool name for providers that need it (e.g., Google/Gemini)
})
}
}
// Add final assistant response if present
if (content && typeof content === 'string') {
messages.push({
role: 'assistant',
content,
})
}
return messages
}
private processProviderResponse(
response: any,
block: SerializedBlock,
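
For clarity, a hand-written example (not captured output) of the sequence buildMessagesForMemory returns for one tool call followed by a final answer. The Message type is the one extended in this PR; the ID follows the stable-ID format generated above:

```ts
// Hypothetical output for one weather lookup plus a closing response:
const persisted: Message[] = [
  {
    role: 'assistant',
    content: null, // null while the assistant turn is a tool call
    tool_calls: [
      {
        id: 'call_get_weather_1734300000000_0_a1b2c',
        type: 'function',
        function: { name: 'get_weather', arguments: '{"city":"Tokyo"}' },
      },
    ],
  },
  {
    role: 'tool',
    content: '{"tempC":12,"sky":"clear"}',
    tool_call_id: 'call_get_weather_1734300000000_0_a1b2c', // same stable ID
    name: 'get_weather', // kept for providers like Google/Gemini
  },
  { role: 'assistant', content: 'It is 12°C and clear in Tokyo.' },
]
```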

View File

@@ -32,7 +32,7 @@ describe('Memory', () => {
})
describe('applySlidingWindow (message-based)', () => {
it('should keep last N conversation messages', () => {
it('should keep last N turns (turn = user message + assistant response)', () => {
const messages: Message[] = [
{ role: 'system', content: 'System prompt' },
{ role: 'user', content: 'Message 1' },
@@ -43,9 +43,10 @@ describe('Memory', () => {
{ role: 'assistant', content: 'Response 3' },
]
const result = (memoryService as any).applySlidingWindow(messages, '4')
// Limit to 2 turns: should keep turns 2 and 3
const result = (memoryService as any).applySlidingWindow(messages, '2')
expect(result.length).toBe(5)
expect(result.length).toBe(5) // system + 2 turns (4 messages)
expect(result[0].role).toBe('system')
expect(result[0].content).toBe('System prompt')
expect(result[1].content).toBe('Message 2')
@@ -113,19 +114,18 @@ describe('Memory', () => {
it('should preserve first system message and exclude it from token count', () => {
const messages: Message[] = [
{ role: 'system', content: 'A' }, // System message - always preserved
{ role: 'user', content: 'B' }, // ~1 token
{ role: 'assistant', content: 'C' }, // ~1 token
{ role: 'user', content: 'D' }, // ~1 token
{ role: 'user', content: 'B' }, // ~1 token (turn 1)
{ role: 'assistant', content: 'C' }, // ~1 token (turn 1)
{ role: 'user', content: 'D' }, // ~1 token (turn 2)
]
// Limit to 2 tokens - should fit system message + last 2 conversation messages (D, C)
// Limit to 2 tokens - fits turn 2 (D=1 token), but turn 1 (B+C=2 tokens) would exceed
const result = (memoryService as any).applySlidingWindowByTokens(messages, '2', 'gpt-4o')
// Should have: system message + 2 conversation messages = 3 total
expect(result.length).toBe(3)
// Should have: system message + turn 2 (1 message) = 2 total
expect(result.length).toBe(2)
expect(result[0].role).toBe('system') // First system message preserved
expect(result[1].content).toBe('C') // Second most recent conversation message
expect(result[2].content).toBe('D') // Most recent conversation message
expect(result[1].content).toBe('D') // Most recent turn
})
it('should process messages from newest to oldest', () => {
@@ -249,29 +249,29 @@ describe('Memory', () => {
})
describe('Token-based vs Message-based comparison', () => {
it('should produce different results for same message count limit', () => {
it('should produce different results based on turn limits vs token limits', () => {
const messages: Message[] = [
{ role: 'user', content: 'A' }, // Short message (~1 token)
{ role: 'user', content: 'A' }, // Short message (~1 token) - turn 1
{
role: 'assistant',
content: 'This is a much longer response that takes many more tokens',
}, // Long message (~15 tokens)
{ role: 'user', content: 'B' }, // Short message (~1 token)
}, // Long message (~15 tokens) - turn 1
{ role: 'user', content: 'B' }, // Short message (~1 token) - turn 2
]
// Message-based: last 2 messages
const messageResult = (memoryService as any).applySlidingWindow(messages, '2')
expect(messageResult.length).toBe(2)
// Turn-based with limit 1: keeps last turn only
const messageResult = (memoryService as any).applySlidingWindow(messages, '1')
expect(messageResult.length).toBe(1) // Only turn 2 (message B)
// Token-based: with limit of 10 tokens, might fit all 3 messages or just last 2
// Token-based: with limit of 10 tokens, fits turn 2 (1 token) but not turn 1 (~16 tokens)
const tokenResult = (memoryService as any).applySlidingWindowByTokens(
messages,
'10',
'gpt-4o'
)
// The long message should affect what fits
expect(tokenResult.length).toBeGreaterThanOrEqual(1)
// Both should only fit the last turn due to the long assistant message
expect(tokenResult.length).toBe(1)
})
})
})

View File

@@ -202,13 +202,51 @@ export class Memory {
const systemMessages = messages.filter((msg) => msg.role === 'system')
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
const recentMessages = conversationMessages.slice(-limit)
// Group messages into conversation turns
// A turn = user message + any tool calls/results + assistant response
const turns = this.groupMessagesIntoTurns(conversationMessages)
// Take the last N turns
const recentTurns = turns.slice(-limit)
// Flatten back to messages
const recentMessages = recentTurns.flat()
const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []
return [...firstSystemMessage, ...recentMessages]
}
/**
* Groups messages into conversation turns.
* A turn starts with a user message and includes all subsequent messages
* until the next user message (tool calls, tool results, assistant response).
*/
private groupMessagesIntoTurns(messages: Message[]): Message[][] {
const turns: Message[][] = []
let currentTurn: Message[] = []
for (const msg of messages) {
if (msg.role === 'user') {
// Start a new turn
if (currentTurn.length > 0) {
turns.push(currentTurn)
}
currentTurn = [msg]
} else {
// Add to current turn (assistant, tool, etc.)
currentTurn.push(msg)
}
}
// Don't forget the last turn
if (currentTurn.length > 0) {
turns.push(currentTurn)
}
return turns
}
/**
* Apply token-based sliding window to limit conversation by token count
*
@@ -216,6 +254,11 @@ export class Memory {
* - For consistency with message-based sliding window, the first system message is preserved
* - System messages are excluded from the token count
* - This ensures system prompts are always available while limiting conversation history
*
* Turn handling:
* - Messages are grouped into turns (user + tool calls/results + assistant response)
* - Complete turns are added to stay within token limit
* - This prevents breaking tool call/result pairs
*/
private applySlidingWindowByTokens(
messages: Message[],
@@ -233,25 +276,31 @@ export class Memory {
const systemMessages = messages.filter((msg) => msg.role === 'system')
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
// Group into turns to keep tool call/result pairs together
const turns = this.groupMessagesIntoTurns(conversationMessages)
const result: Message[] = []
let currentTokenCount = 0
// Add conversation messages from most recent backwards
for (let i = conversationMessages.length - 1; i >= 0; i--) {
const message = conversationMessages[i]
const messageTokens = getAccurateTokenCount(message.content, model)
// Add turns from most recent backwards
for (let i = turns.length - 1; i >= 0; i--) {
const turn = turns[i]
const turnTokens = turn.reduce(
(sum, msg) => sum + getAccurateTokenCount(msg.content || '', model),
0
)
if (currentTokenCount + messageTokens <= tokenLimit) {
result.unshift(message)
currentTokenCount += messageTokens
if (currentTokenCount + turnTokens <= tokenLimit) {
result.unshift(...turn)
currentTokenCount += turnTokens
} else if (result.length === 0) {
logger.warn('Single message exceeds token limit, including anyway', {
messageTokens,
logger.warn('Single turn exceeds token limit, including anyway', {
turnTokens,
tokenLimit,
messageRole: message.role,
turnMessages: turn.length,
})
result.unshift(message)
currentTokenCount += messageTokens
result.unshift(...turn)
currentTokenCount += turnTokens
break
} else {
// Token limit reached, stop processing
@@ -259,17 +308,20 @@ export class Memory {
}
}
// No need to remove orphaned messages - turns are already complete
const cleanedResult = result
logger.debug('Applied token-based sliding window', {
totalMessages: messages.length,
conversationMessages: conversationMessages.length,
includedMessages: result.length,
includedMessages: cleanedResult.length,
totalTokens: currentTokenCount,
tokenLimit,
})
// Preserve first system message and prepend to results (consistent with message-based window)
const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []
return [...firstSystemMessage, ...result]
return [...firstSystemMessage, ...cleanedResult]
}
/**
@@ -324,7 +376,7 @@ export class Memory {
// Count tokens used by system messages first
let systemTokenCount = 0
for (const msg of systemMessages) {
systemTokenCount += getAccurateTokenCount(msg.content, model)
systemTokenCount += getAccurateTokenCount(msg.content || '', model)
}
// Calculate remaining tokens available for conversation messages
@@ -339,30 +391,36 @@ export class Memory {
return systemMessages
}
// Group into turns to keep tool call/result pairs together
const turns = this.groupMessagesIntoTurns(conversationMessages)
const result: Message[] = []
let currentTokenCount = 0
for (let i = conversationMessages.length - 1; i >= 0; i--) {
const message = conversationMessages[i]
const messageTokens = getAccurateTokenCount(message.content, model)
for (let i = turns.length - 1; i >= 0; i--) {
const turn = turns[i]
const turnTokens = turn.reduce(
(sum, msg) => sum + getAccurateTokenCount(msg.content || '', model),
0
)
if (currentTokenCount + messageTokens <= remainingTokens) {
result.unshift(message)
currentTokenCount += messageTokens
if (currentTokenCount + turnTokens <= remainingTokens) {
result.unshift(...turn)
currentTokenCount += turnTokens
} else if (result.length === 0) {
logger.warn('Single message exceeds remaining context window, including anyway', {
messageTokens,
logger.warn('Single turn exceeds remaining context window, including anyway', {
turnTokens,
remainingTokens,
systemTokenCount,
messageRole: message.role,
turnMessages: turn.length,
})
result.unshift(message)
currentTokenCount += messageTokens
result.unshift(...turn)
currentTokenCount += turnTokens
break
} else {
logger.info('Auto-trimmed conversation history to fit context window', {
originalMessages: conversationMessages.length,
trimmedMessages: result.length,
originalTurns: turns.length,
trimmedTurns: turns.length - i - 1,
conversationTokens: currentTokenCount,
systemTokens: systemTokenCount,
totalTokens: currentTokenCount + systemTokenCount,
@@ -372,6 +430,7 @@ export class Memory {
}
}
// No need to remove orphaned messages - turns are already complete
return [...systemMessages, ...result]
}
@@ -638,7 +697,7 @@ export class Memory {
/**
* Validate inputs to prevent malicious data or performance issues
*/
private validateInputs(conversationId?: string, content?: string): void {
private validateInputs(conversationId?: string, content?: string | null): void {
if (conversationId) {
if (conversationId.length > 255) {
throw new Error('Conversation ID too long (max 255 characters)')
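
To make the turn semantics concrete, here is how groupMessagesIntoTurns would partition a short history, and what a limit of 1 turn keeps. This is hand-traced from the code above, with the tool_calls payload elided:

```ts
const history: Message[] = [
  { role: 'user', content: 'Find flights to Oslo' },                  // turn 1 starts
  { role: 'assistant', content: null, tool_calls: [/* ... */] },      // turn 1
  { role: 'tool', content: '{"flights":[]}', tool_call_id: 'call_1' }, // turn 1
  { role: 'assistant', content: 'No flights found.' },                // turn 1 ends
  { role: 'user', content: 'Try trains instead' },                    // turn 2 starts
  { role: 'assistant', content: 'There is a night train.' },          // turn 2
]

// groupMessagesIntoTurns(history) => [[4 messages], [2 messages]]
// applySlidingWindow(history, '1') keeps only turn 2, so the tool
// call/result pair in turn 1 is dropped as a unit, never split.
```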

View File

@@ -37,10 +37,22 @@ export interface ToolInput {
}
export interface Message {
role: 'system' | 'user' | 'assistant'
content: string
role: 'system' | 'user' | 'assistant' | 'tool'
content: string | null
function_call?: any
tool_calls?: any[]
tool_calls?: ToolCallMessage[]
tool_call_id?: string
/** Tool name for tool role messages (used by providers like Google/Gemini) */
name?: string
}
export interface ToolCallMessage {
id: string
type: 'function'
function: {
name: string
arguments: string
}
}
export interface StreamingConfig {
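
Because Message.content is now string | null, downstream consumers need the null guard that memory.ts adds (msg.content || ''). A tiny hypothetical helper making the same point:

```ts
// With nullable content, always coalesce before measuring or tokenizing:
function contentLength(msg: Message): number {
  return (msg.content ?? '').length
}
```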

View File

@@ -4,7 +4,12 @@ import type { StreamingExecution } from '@/executor/types'
import { executeTool } from '@/tools'
import { getProviderDefaultModel, getProviderModels } from '../models'
import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types'
import { prepareToolExecution, prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils'
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '../utils'
const logger = createLogger('AnthropicProvider')
@@ -68,8 +73,12 @@ export const anthropicProvider: ProviderConfig = {
// Add remaining messages
if (request.messages) {
request.messages.forEach((msg) => {
// Sanitize messages to ensure proper tool call/result pairing
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
sanitizedMessages.forEach((msg) => {
if (msg.role === 'function') {
// Legacy function role format
messages.push({
role: 'user',
content: [
@@ -80,7 +89,41 @@ export const anthropicProvider: ProviderConfig = {
},
],
})
} else if (msg.role === 'tool') {
// Modern tool role format (OpenAI-compatible)
messages.push({
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: (msg as any).tool_call_id,
content: msg.content || '',
},
],
})
} else if (msg.tool_calls && Array.isArray(msg.tool_calls)) {
// Modern tool_calls format (OpenAI-compatible)
const toolUseContent = msg.tool_calls.map((tc: any) => ({
type: 'tool_use',
id: tc.id,
name: tc.function?.name || tc.name,
input:
typeof tc.function?.arguments === 'string'
? (() => {
try {
return JSON.parse(tc.function.arguments)
} catch {
return {}
}
})()
: tc.function?.arguments || tc.arguments || {},
}))
messages.push({
role: 'assistant',
content: toolUseContent,
})
} else if (msg.function_call) {
// Legacy function_call format
const toolUseId = `${msg.function_call.name}-${Date.now()}`
messages.push({
role: 'assistant',
@@ -490,9 +533,14 @@ ${fieldDescriptions}
}
}
// Use the original tool use ID from the API response
const toolUseId = toolUse.id || generateToolUseId(toolName)
toolCalls.push({
id: toolUseId,
name: toolName,
arguments: toolParams,
rawArguments: JSON.stringify(toolArgs),
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
@@ -501,7 +549,6 @@ ${fieldDescriptions}
})
// Add the tool call and result to messages (both success and failure)
const toolUseId = generateToolUseId(toolName)
currentMessages.push({
role: 'assistant',
@@ -840,9 +887,14 @@ ${fieldDescriptions}
}
}
// Use the original tool use ID from the API response
const toolUseId = toolUse.id || generateToolUseId(toolName)
toolCalls.push({
id: toolUseId,
name: toolName,
arguments: toolParams,
rawArguments: JSON.stringify(toolArgs),
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
@@ -851,7 +903,6 @@ ${fieldDescriptions}
})
// Add the tool call and result to messages (both success and failure)
const toolUseId = generateToolUseId(toolName)
currentMessages.push({
role: 'assistant',
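
To make the conversion concrete, a hand-written example (IDs and payloads invented) of how one OpenAI-style tool call/result pair comes out in Anthropic's block format under this change:

```ts
// Input (OpenAI-compatible):
//   { role: 'assistant', content: null, tool_calls: [{ id: 'toolu_01', type: 'function',
//       function: { name: 'get_weather', arguments: '{"city":"Oslo"}' } }] }
//   { role: 'tool', tool_call_id: 'toolu_01', content: '{"tempC":3}' }
//
// Output in Anthropic's content-block format:
const anthropicMessages = [
  {
    role: 'assistant',
    content: [
      { type: 'tool_use', id: 'toolu_01', name: 'get_weather', input: { city: 'Oslo' } },
    ],
  },
  {
    role: 'user', // Anthropic carries tool results on user-role messages
    content: [{ type: 'tool_result', tool_use_id: 'toolu_01', content: '{"tempC":3}' }],
  },
]
```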

View File

@@ -12,6 +12,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
@@ -120,9 +121,10 @@ export const azureOpenAIProvider: ProviderConfig = {
})
}
// Add remaining messages
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
// Transform tools to Azure OpenAI format if provided
@@ -417,8 +419,10 @@ export const azureOpenAIProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,

View File

@@ -11,6 +11,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
@@ -86,9 +87,10 @@ export const cerebrasProvider: ProviderConfig = {
})
}
// Add remaining messages
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
// Transform tools to Cerebras format if provided
@@ -323,8 +325,10 @@ export const cerebrasProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,

View File

@@ -11,6 +11,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
@@ -84,9 +85,10 @@ export const deepseekProvider: ProviderConfig = {
})
}
// Add remaining messages
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
// Transform tools to OpenAI format if provided
@@ -323,8 +325,10 @@ export const deepseekProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,

View File

@@ -10,6 +10,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
@@ -552,9 +553,14 @@ export const googleProvider: ProviderConfig = {
}
}
// Generate a unique ID for this tool call (Google doesn't provide one)
const toolCallId = `call_${toolName}_${Date.now()}_${Math.random().toString(36).slice(2, 9)}`
toolCalls.push({
id: toolCallId,
name: toolName,
arguments: toolParams,
rawArguments: JSON.stringify(toolArgs),
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
@@ -1087,9 +1093,10 @@ function convertToGeminiFormat(request: ProviderRequest): {
contents.push({ role: 'user', parts: [{ text: request.context }] })
}
// Process messages
// Process messages (sanitized to ensure proper tool call/result pairing)
if (request.messages && request.messages.length > 0) {
for (const message of request.messages) {
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
for (const message of sanitizedMessages) {
if (message.role === 'system') {
// Add to system instruction
if (!systemInstruction) {
@@ -1119,10 +1126,28 @@ function convertToGeminiFormat(request: ProviderRequest): {
contents.push({ role: 'model', parts: functionCalls })
}
} else if (message.role === 'tool') {
// Convert tool response (Gemini only accepts user/model roles)
// Convert tool response to Gemini's functionResponse format
// Gemini uses 'user' role for function responses
const functionName = (message as any).name || 'function'
let responseData: any
try {
responseData =
typeof message.content === 'string' ? JSON.parse(message.content) : message.content
} catch {
responseData = { result: message.content }
}
contents.push({
role: 'user',
parts: [{ text: `Function result: ${message.content}` }],
parts: [
{
functionResponse: {
name: functionName,
response: responseData,
},
},
],
})
}
}
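
An illustrative example (values invented) of what the new conversion produces for a single tool message:

```ts
// An OpenAI-style tool message such as
//   { role: 'tool', name: 'get_weather', content: '{"tempC":3}' }
// now becomes a Gemini functionResponse part instead of plain text:
const geminiContent = {
  role: 'user', // Gemini expects function responses on the user role
  parts: [
    {
      functionResponse: {
        name: 'get_weather',
        response: { tempC: 3 }, // parsed from JSON; non-JSON falls back to { result: content }
      },
    },
  ],
}
```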

View File

@@ -11,6 +11,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
@@ -75,9 +76,10 @@ export const groqProvider: ProviderConfig = {
})
}
// Add remaining messages
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
// Transform tools to function format if provided
@@ -296,8 +298,10 @@ export const groqProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,

View File

@@ -11,6 +11,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
@@ -100,8 +101,10 @@ export const mistralProvider: ProviderConfig = {
})
}
// Sanitize messages to ensure proper tool call/result pairing
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
const tools = request.tools?.length
@@ -355,8 +358,10 @@ export const mistralProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,

View File

@@ -9,7 +9,7 @@ import type {
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import { prepareToolExecution } from '@/providers/utils'
import { prepareToolExecution, sanitizeMessagesForProvider } from '@/providers/utils'
import { useProvidersStore } from '@/stores/providers/store'
import { executeTool } from '@/tools'
@@ -126,9 +126,10 @@ export const ollamaProvider: ProviderConfig = {
})
}
// Add remaining messages
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
// Transform tools to OpenAI format if provided
@@ -407,8 +408,10 @@ export const ollamaProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,

View File

@@ -11,6 +11,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
@@ -103,9 +104,10 @@ export const openaiProvider: ProviderConfig = {
})
}
// Add remaining messages
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
// Transform tools to OpenAI format if provided
@@ -398,8 +400,10 @@ export const openaiProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,

View File

@@ -11,6 +11,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
@@ -93,8 +94,10 @@ export const openRouterProvider: ProviderConfig = {
allMessages.push({ role: 'user', content: request.context })
}
// Sanitize messages to ensure proper tool call/result pairing
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
const tools = request.tools?.length
@@ -303,8 +306,10 @@ export const openRouterProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,

View File

@@ -1049,3 +1049,96 @@ export function prepareToolExecution(
return { toolParams, executionParams }
}
/**
* Sanitizes messages array to ensure proper tool call/result pairing
* This prevents provider errors like "tool_result without corresponding tool_use"
*
* Rules enforced:
* 1. Every tool message must have a matching tool_calls message before it
* 2. Every tool_calls in an assistant message should have corresponding tool results
* 3. Messages maintain their original order
*/
export function sanitizeMessagesForProvider(
messages: Array<{
role: string
content?: string | null
tool_calls?: Array<{ id: string; [key: string]: any }>
tool_call_id?: string
[key: string]: any
}>
): typeof messages {
if (!messages || messages.length === 0) {
return messages
}
// Build a map of tool_call IDs to their positions
const toolCallIdToIndex = new Map<string, number>()
const toolResultIds = new Set<string>()
// First pass: collect all tool_call IDs and tool result IDs
for (let i = 0; i < messages.length; i++) {
const msg = messages[i]
if (msg.tool_calls && Array.isArray(msg.tool_calls)) {
for (const tc of msg.tool_calls) {
if (tc.id) {
toolCallIdToIndex.set(tc.id, i)
}
}
}
if (msg.role === 'tool' && msg.tool_call_id) {
toolResultIds.add(msg.tool_call_id)
}
}
// Second pass: filter messages
const result: typeof messages = []
for (const msg of messages) {
// For tool messages: only include if there's a matching tool_calls before it
if (msg.role === 'tool') {
const toolCallId = msg.tool_call_id
if (toolCallId && toolCallIdToIndex.has(toolCallId)) {
result.push(msg)
} else {
logger.debug('Removing orphaned tool message', { toolCallId })
}
continue
}
// For assistant messages with tool_calls: only include tool_calls that have results
if (msg.role === 'assistant' && msg.tool_calls && Array.isArray(msg.tool_calls)) {
const validToolCalls = msg.tool_calls.filter((tc) => tc.id && toolResultIds.has(tc.id))
if (validToolCalls.length === 0) {
// No valid tool calls - if there's content, keep as regular message
if (msg.content) {
const { tool_calls, ...msgWithoutToolCalls } = msg
result.push(msgWithoutToolCalls)
} else {
logger.debug('Removing assistant message with orphaned tool_calls', {
toolCallIds: msg.tool_calls.map((tc) => tc.id),
})
}
} else if (validToolCalls.length === msg.tool_calls.length) {
// All tool calls are valid
result.push(msg)
} else {
// Some tool calls are orphaned - keep only valid ones
result.push({ ...msg, tool_calls: validToolCalls })
logger.debug('Filtered orphaned tool_calls from message', {
original: msg.tool_calls.length,
kept: validToolCalls.length,
})
}
continue
}
// All other messages pass through
result.push(msg)
}
return result
}
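
A hand-traced usage example (tool_calls entries simplified to their ids, which is all the sanitizer inspects) showing what gets kept and dropped:

```ts
const input = [
  { role: 'assistant', content: null, tool_calls: [{ id: 'call_a' }, { id: 'call_b' }] },
  { role: 'tool', tool_call_id: 'call_a', content: '{"ok":true}' },
  { role: 'tool', tool_call_id: 'call_ghost', content: '{}' }, // orphan: no matching tool_calls
  { role: 'user', content: 'thanks' },
]

const output = sanitizeMessagesForProvider(input)
// => the assistant message keeps only tool_calls [{ id: 'call_a' }]
//    (call_b has no result and is filtered out),
//    the call_a tool result survives, the call_ghost result is removed,
//    and the user message passes through untouched.
```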

View File

@@ -12,6 +12,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { useProvidersStore } from '@/stores/providers/store'
@@ -140,8 +141,10 @@ export const vllmProvider: ProviderConfig = {
})
}
// Sanitize messages to ensure proper tool call/result pairing
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
const tools = request.tools?.length
@@ -400,8 +403,10 @@ export const vllmProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,

View File

@@ -11,6 +11,7 @@ import type {
import {
prepareToolExecution,
prepareToolsWithUsageControl,
sanitizeMessagesForProvider,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
@@ -83,8 +84,10 @@ export const xAIProvider: ProviderConfig = {
})
}
// Sanitize messages to ensure proper tool call/result pairing
if (request.messages) {
allMessages.push(...request.messages)
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
allMessages.push(...sanitizedMessages)
}
// Set up tools
@@ -364,8 +367,10 @@ export const xAIProvider: ProviderConfig = {
}
toolCalls.push({
id: toolCall.id,
name: toolName,
arguments: toolParams,
rawArguments: toolCall.function.arguments,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,