mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-09 23:17:59 -05:00
Compare commits
5 Commits
main
...
fix/chat-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f67caf0798 | ||
|
|
0b853a7d95 | ||
|
|
4f31560a0e | ||
|
|
8b5027f2a6 | ||
|
|
c6c658a6e1 |
@@ -80,6 +80,10 @@ export function VoiceInterface({
|
||||
const currentStateRef = useRef<'idle' | 'listening' | 'agent_speaking'>('idle')
|
||||
const isCallEndedRef = useRef(false)
|
||||
|
||||
useEffect(() => {
|
||||
isCallEndedRef.current = false
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
currentStateRef.current = state
|
||||
}, [state])
|
||||
@@ -119,6 +123,8 @@ export function VoiceInterface({
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
if (isCallEndedRef.current) return
|
||||
|
||||
if (isPlayingAudio && state !== 'agent_speaking') {
|
||||
clearResponseTimeout()
|
||||
setState('agent_speaking')
|
||||
@@ -139,6 +145,9 @@ export function VoiceInterface({
|
||||
}
|
||||
}
|
||||
} else if (!isPlayingAudio && state === 'agent_speaking') {
|
||||
// Don't unmute/restart if call has ended
|
||||
if (isCallEndedRef.current) return
|
||||
|
||||
setState('idle')
|
||||
setCurrentTranscript('')
|
||||
|
||||
@@ -226,6 +235,8 @@ export function VoiceInterface({
|
||||
recognition.onstart = () => {}
|
||||
|
||||
recognition.onresult = (event: SpeechRecognitionEvent) => {
|
||||
if (isCallEndedRef.current) return
|
||||
|
||||
const currentState = currentStateRef.current
|
||||
|
||||
if (isMutedRef.current || currentState !== 'listening') {
|
||||
@@ -303,6 +314,8 @@ export function VoiceInterface({
|
||||
}, [isSupported, onVoiceTranscript, setResponseTimeout])
|
||||
|
||||
const startListening = useCallback(() => {
|
||||
if (isCallEndedRef.current) return
|
||||
|
||||
if (!isInitialized || isMuted || state !== 'idle') {
|
||||
return
|
||||
}
|
||||
@@ -320,6 +333,9 @@ export function VoiceInterface({
|
||||
}, [isInitialized, isMuted, state])
|
||||
|
||||
const stopListening = useCallback(() => {
|
||||
// Don't process if call has ended
|
||||
if (isCallEndedRef.current) return
|
||||
|
||||
setState('idle')
|
||||
setCurrentTranscript('')
|
||||
|
||||
@@ -333,12 +349,15 @@ export function VoiceInterface({
|
||||
}, [])
|
||||
|
||||
const handleInterrupt = useCallback(() => {
|
||||
if (isCallEndedRef.current) return
|
||||
|
||||
if (state === 'agent_speaking') {
|
||||
onInterrupt?.()
|
||||
setState('listening')
|
||||
setCurrentTranscript('')
|
||||
|
||||
setIsMuted(false)
|
||||
isMutedRef.current = false
|
||||
if (mediaStreamRef.current) {
|
||||
mediaStreamRef.current.getAudioTracks().forEach((track) => {
|
||||
track.enabled = true
|
||||
@@ -356,11 +375,22 @@ export function VoiceInterface({
|
||||
}, [state, onInterrupt])
|
||||
|
||||
const handleCallEnd = useCallback(() => {
|
||||
// Mark call as ended FIRST to prevent any effects from restarting recognition
|
||||
isCallEndedRef.current = true
|
||||
|
||||
// Set muted to true to prevent auto-start effect from triggering
|
||||
setIsMuted(true)
|
||||
isMutedRef.current = true
|
||||
|
||||
setState('idle')
|
||||
setCurrentTranscript('')
|
||||
setIsMuted(false)
|
||||
|
||||
// Immediately disable audio tracks to stop listening
|
||||
if (mediaStreamRef.current) {
|
||||
mediaStreamRef.current.getAudioTracks().forEach((track) => {
|
||||
track.enabled = false
|
||||
})
|
||||
}
|
||||
|
||||
if (recognitionRef.current) {
|
||||
try {
|
||||
@@ -377,6 +407,8 @@ export function VoiceInterface({
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (isCallEndedRef.current) return
|
||||
|
||||
if (event.code === 'Space') {
|
||||
event.preventDefault()
|
||||
handleInterrupt()
|
||||
@@ -388,6 +420,8 @@ export function VoiceInterface({
|
||||
}, [handleInterrupt])
|
||||
|
||||
const toggleMute = useCallback(() => {
|
||||
if (isCallEndedRef.current) return
|
||||
|
||||
if (state === 'agent_speaking') {
|
||||
handleInterrupt()
|
||||
return
|
||||
@@ -395,6 +429,7 @@ export function VoiceInterface({
|
||||
|
||||
const newMutedState = !isMuted
|
||||
setIsMuted(newMutedState)
|
||||
isMutedRef.current = newMutedState
|
||||
|
||||
if (mediaStreamRef.current) {
|
||||
mediaStreamRef.current.getAudioTracks().forEach((track) => {
|
||||
@@ -417,6 +452,8 @@ export function VoiceInterface({
|
||||
}, [isSupported, setupSpeechRecognition, setupAudio])
|
||||
|
||||
useEffect(() => {
|
||||
if (isCallEndedRef.current) return
|
||||
|
||||
if (isInitialized && !isMuted && state === 'idle') {
|
||||
startListening()
|
||||
}
|
||||
|
||||
@@ -987,18 +987,21 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
try {
|
||||
const executionData = JSON.parse(executionDataHeader)
|
||||
|
||||
// If execution data contains full content, persist to memory
|
||||
if (ctx && inputs && executionData.output?.content) {
|
||||
const assistantMessage: Message = {
|
||||
role: 'assistant',
|
||||
content: executionData.output.content,
|
||||
}
|
||||
// Fire and forget - don't await
|
||||
memoryService
|
||||
.persistMemoryMessage(ctx, inputs, assistantMessage, block.id)
|
||||
.catch((error) =>
|
||||
logger.error('Failed to persist streaming response to memory:', error)
|
||||
// If execution data contains content or tool calls, persist to memory
|
||||
if (
|
||||
ctx &&
|
||||
inputs &&
|
||||
(executionData.output?.content || executionData.output?.toolCalls?.list?.length)
|
||||
) {
|
||||
const toolCalls = executionData.output?.toolCalls?.list
|
||||
const messages = this.buildMessagesForMemory(executionData.output.content, toolCalls)
|
||||
|
||||
// Fire and forget - don't await, persist all messages
|
||||
Promise.all(
|
||||
messages.map((message) =>
|
||||
memoryService.persistMemoryMessage(ctx, inputs, message, block.id)
|
||||
)
|
||||
).catch((error) => logger.error('Failed to persist streaming response to memory:', error))
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -1117,25 +1120,28 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract content from regular response
|
||||
// Extract content and tool calls from regular response
|
||||
const blockOutput = result as any
|
||||
const content = blockOutput?.content
|
||||
const toolCalls = blockOutput?.toolCalls?.list
|
||||
|
||||
if (!content || typeof content !== 'string') {
|
||||
// Build messages to persist
|
||||
const messages = this.buildMessagesForMemory(content, toolCalls)
|
||||
|
||||
if (messages.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const assistantMessage: Message = {
|
||||
role: 'assistant',
|
||||
content,
|
||||
// Persist all messages
|
||||
for (const message of messages) {
|
||||
await memoryService.persistMemoryMessage(ctx, inputs, message, blockId)
|
||||
}
|
||||
|
||||
await memoryService.persistMemoryMessage(ctx, inputs, assistantMessage, blockId)
|
||||
|
||||
logger.debug('Persisted assistant response to memory', {
|
||||
workflowId: ctx.workflowId,
|
||||
memoryType: inputs.memoryType,
|
||||
conversationId: inputs.conversationId,
|
||||
messageCount: messages.length,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to persist response to memory:', error)
|
||||
@@ -1143,6 +1149,69 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds messages for memory storage including tool calls and results
|
||||
* Returns proper OpenAI-compatible message format:
|
||||
* - Assistant message with tool_calls array (if tools were used)
|
||||
* - Tool role messages with results (one per tool call)
|
||||
* - Final assistant message with content (if present)
|
||||
*/
|
||||
private buildMessagesForMemory(
|
||||
content: string | undefined,
|
||||
toolCalls: any[] | undefined
|
||||
): Message[] {
|
||||
const messages: Message[] = []
|
||||
|
||||
if (toolCalls?.length) {
|
||||
// Generate stable IDs for each tool call (only if not provided by provider)
|
||||
// Use index to ensure uniqueness even for same tool name in same millisecond
|
||||
const toolCallsWithIds = toolCalls.map((tc: any, index: number) => ({
|
||||
...tc,
|
||||
_stableId:
|
||||
tc.id ||
|
||||
`call_${tc.name}_${Date.now()}_${index}_${Math.random().toString(36).slice(2, 7)}`,
|
||||
}))
|
||||
|
||||
// Add assistant message with tool_calls
|
||||
const formattedToolCalls = toolCallsWithIds.map((tc: any) => ({
|
||||
id: tc._stableId,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: tc.rawArguments || JSON.stringify(tc.arguments || {}),
|
||||
},
|
||||
}))
|
||||
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: formattedToolCalls,
|
||||
})
|
||||
|
||||
// Add tool result messages using the same stable IDs
|
||||
for (const tc of toolCallsWithIds) {
|
||||
const resultContent =
|
||||
typeof tc.result === 'string' ? tc.result : JSON.stringify(tc.result || {})
|
||||
messages.push({
|
||||
role: 'tool',
|
||||
content: resultContent,
|
||||
tool_call_id: tc._stableId,
|
||||
name: tc.name, // Store tool name for providers that need it (e.g., Google/Gemini)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Add final assistant response if present
|
||||
if (content && typeof content === 'string') {
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content,
|
||||
})
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
private processProviderResponse(
|
||||
response: any,
|
||||
block: SerializedBlock,
|
||||
|
||||
@@ -32,7 +32,7 @@ describe('Memory', () => {
|
||||
})
|
||||
|
||||
describe('applySlidingWindow (message-based)', () => {
|
||||
it('should keep last N conversation messages', () => {
|
||||
it('should keep last N turns (turn = user message + assistant response)', () => {
|
||||
const messages: Message[] = [
|
||||
{ role: 'system', content: 'System prompt' },
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
@@ -43,9 +43,10 @@ describe('Memory', () => {
|
||||
{ role: 'assistant', content: 'Response 3' },
|
||||
]
|
||||
|
||||
const result = (memoryService as any).applySlidingWindow(messages, '4')
|
||||
// Limit to 2 turns: should keep turns 2 and 3
|
||||
const result = (memoryService as any).applySlidingWindow(messages, '2')
|
||||
|
||||
expect(result.length).toBe(5)
|
||||
expect(result.length).toBe(5) // system + 2 turns (4 messages)
|
||||
expect(result[0].role).toBe('system')
|
||||
expect(result[0].content).toBe('System prompt')
|
||||
expect(result[1].content).toBe('Message 2')
|
||||
@@ -113,19 +114,18 @@ describe('Memory', () => {
|
||||
it('should preserve first system message and exclude it from token count', () => {
|
||||
const messages: Message[] = [
|
||||
{ role: 'system', content: 'A' }, // System message - always preserved
|
||||
{ role: 'user', content: 'B' }, // ~1 token
|
||||
{ role: 'assistant', content: 'C' }, // ~1 token
|
||||
{ role: 'user', content: 'D' }, // ~1 token
|
||||
{ role: 'user', content: 'B' }, // ~1 token (turn 1)
|
||||
{ role: 'assistant', content: 'C' }, // ~1 token (turn 1)
|
||||
{ role: 'user', content: 'D' }, // ~1 token (turn 2)
|
||||
]
|
||||
|
||||
// Limit to 2 tokens - should fit system message + last 2 conversation messages (D, C)
|
||||
// Limit to 2 tokens - fits turn 2 (D=1 token), but turn 1 (B+C=2 tokens) would exceed
|
||||
const result = (memoryService as any).applySlidingWindowByTokens(messages, '2', 'gpt-4o')
|
||||
|
||||
// Should have: system message + 2 conversation messages = 3 total
|
||||
expect(result.length).toBe(3)
|
||||
// Should have: system message + turn 2 (1 message) = 2 total
|
||||
expect(result.length).toBe(2)
|
||||
expect(result[0].role).toBe('system') // First system message preserved
|
||||
expect(result[1].content).toBe('C') // Second most recent conversation message
|
||||
expect(result[2].content).toBe('D') // Most recent conversation message
|
||||
expect(result[1].content).toBe('D') // Most recent turn
|
||||
})
|
||||
|
||||
it('should process messages from newest to oldest', () => {
|
||||
@@ -249,29 +249,29 @@ describe('Memory', () => {
|
||||
})
|
||||
|
||||
describe('Token-based vs Message-based comparison', () => {
|
||||
it('should produce different results for same message count limit', () => {
|
||||
it('should produce different results based on turn limits vs token limits', () => {
|
||||
const messages: Message[] = [
|
||||
{ role: 'user', content: 'A' }, // Short message (~1 token)
|
||||
{ role: 'user', content: 'A' }, // Short message (~1 token) - turn 1
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'This is a much longer response that takes many more tokens',
|
||||
}, // Long message (~15 tokens)
|
||||
{ role: 'user', content: 'B' }, // Short message (~1 token)
|
||||
}, // Long message (~15 tokens) - turn 1
|
||||
{ role: 'user', content: 'B' }, // Short message (~1 token) - turn 2
|
||||
]
|
||||
|
||||
// Message-based: last 2 messages
|
||||
const messageResult = (memoryService as any).applySlidingWindow(messages, '2')
|
||||
expect(messageResult.length).toBe(2)
|
||||
// Turn-based with limit 1: keeps last turn only
|
||||
const messageResult = (memoryService as any).applySlidingWindow(messages, '1')
|
||||
expect(messageResult.length).toBe(1) // Only turn 2 (message B)
|
||||
|
||||
// Token-based: with limit of 10 tokens, might fit all 3 messages or just last 2
|
||||
// Token-based: with limit of 10 tokens, fits turn 2 (1 token) but not turn 1 (~16 tokens)
|
||||
const tokenResult = (memoryService as any).applySlidingWindowByTokens(
|
||||
messages,
|
||||
'10',
|
||||
'gpt-4o'
|
||||
)
|
||||
|
||||
// The long message should affect what fits
|
||||
expect(tokenResult.length).toBeGreaterThanOrEqual(1)
|
||||
// Both should only fit the last turn due to the long assistant message
|
||||
expect(tokenResult.length).toBe(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -202,13 +202,51 @@ export class Memory {
|
||||
const systemMessages = messages.filter((msg) => msg.role === 'system')
|
||||
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
|
||||
|
||||
const recentMessages = conversationMessages.slice(-limit)
|
||||
// Group messages into conversation turns
|
||||
// A turn = user message + any tool calls/results + assistant response
|
||||
const turns = this.groupMessagesIntoTurns(conversationMessages)
|
||||
|
||||
// Take the last N turns
|
||||
const recentTurns = turns.slice(-limit)
|
||||
|
||||
// Flatten back to messages
|
||||
const recentMessages = recentTurns.flat()
|
||||
|
||||
const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []
|
||||
|
||||
return [...firstSystemMessage, ...recentMessages]
|
||||
}
|
||||
|
||||
/**
|
||||
* Groups messages into conversation turns.
|
||||
* A turn starts with a user message and includes all subsequent messages
|
||||
* until the next user message (tool calls, tool results, assistant response).
|
||||
*/
|
||||
private groupMessagesIntoTurns(messages: Message[]): Message[][] {
|
||||
const turns: Message[][] = []
|
||||
let currentTurn: Message[] = []
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.role === 'user') {
|
||||
// Start a new turn
|
||||
if (currentTurn.length > 0) {
|
||||
turns.push(currentTurn)
|
||||
}
|
||||
currentTurn = [msg]
|
||||
} else {
|
||||
// Add to current turn (assistant, tool, etc.)
|
||||
currentTurn.push(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Don't forget the last turn
|
||||
if (currentTurn.length > 0) {
|
||||
turns.push(currentTurn)
|
||||
}
|
||||
|
||||
return turns
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply token-based sliding window to limit conversation by token count
|
||||
*
|
||||
@@ -216,6 +254,11 @@ export class Memory {
|
||||
* - For consistency with message-based sliding window, the first system message is preserved
|
||||
* - System messages are excluded from the token count
|
||||
* - This ensures system prompts are always available while limiting conversation history
|
||||
*
|
||||
* Turn handling:
|
||||
* - Messages are grouped into turns (user + tool calls/results + assistant response)
|
||||
* - Complete turns are added to stay within token limit
|
||||
* - This prevents breaking tool call/result pairs
|
||||
*/
|
||||
private applySlidingWindowByTokens(
|
||||
messages: Message[],
|
||||
@@ -233,25 +276,31 @@ export class Memory {
|
||||
const systemMessages = messages.filter((msg) => msg.role === 'system')
|
||||
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
|
||||
|
||||
// Group into turns to keep tool call/result pairs together
|
||||
const turns = this.groupMessagesIntoTurns(conversationMessages)
|
||||
|
||||
const result: Message[] = []
|
||||
let currentTokenCount = 0
|
||||
|
||||
// Add conversation messages from most recent backwards
|
||||
for (let i = conversationMessages.length - 1; i >= 0; i--) {
|
||||
const message = conversationMessages[i]
|
||||
const messageTokens = getAccurateTokenCount(message.content, model)
|
||||
// Add turns from most recent backwards
|
||||
for (let i = turns.length - 1; i >= 0; i--) {
|
||||
const turn = turns[i]
|
||||
const turnTokens = turn.reduce(
|
||||
(sum, msg) => sum + getAccurateTokenCount(msg.content || '', model),
|
||||
0
|
||||
)
|
||||
|
||||
if (currentTokenCount + messageTokens <= tokenLimit) {
|
||||
result.unshift(message)
|
||||
currentTokenCount += messageTokens
|
||||
if (currentTokenCount + turnTokens <= tokenLimit) {
|
||||
result.unshift(...turn)
|
||||
currentTokenCount += turnTokens
|
||||
} else if (result.length === 0) {
|
||||
logger.warn('Single message exceeds token limit, including anyway', {
|
||||
messageTokens,
|
||||
logger.warn('Single turn exceeds token limit, including anyway', {
|
||||
turnTokens,
|
||||
tokenLimit,
|
||||
messageRole: message.role,
|
||||
turnMessages: turn.length,
|
||||
})
|
||||
result.unshift(message)
|
||||
currentTokenCount += messageTokens
|
||||
result.unshift(...turn)
|
||||
currentTokenCount += turnTokens
|
||||
break
|
||||
} else {
|
||||
// Token limit reached, stop processing
|
||||
@@ -259,17 +308,20 @@ export class Memory {
|
||||
}
|
||||
}
|
||||
|
||||
// No need to remove orphaned messages - turns are already complete
|
||||
const cleanedResult = result
|
||||
|
||||
logger.debug('Applied token-based sliding window', {
|
||||
totalMessages: messages.length,
|
||||
conversationMessages: conversationMessages.length,
|
||||
includedMessages: result.length,
|
||||
includedMessages: cleanedResult.length,
|
||||
totalTokens: currentTokenCount,
|
||||
tokenLimit,
|
||||
})
|
||||
|
||||
// Preserve first system message and prepend to results (consistent with message-based window)
|
||||
const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []
|
||||
return [...firstSystemMessage, ...result]
|
||||
return [...firstSystemMessage, ...cleanedResult]
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -324,7 +376,7 @@ export class Memory {
|
||||
// Count tokens used by system messages first
|
||||
let systemTokenCount = 0
|
||||
for (const msg of systemMessages) {
|
||||
systemTokenCount += getAccurateTokenCount(msg.content, model)
|
||||
systemTokenCount += getAccurateTokenCount(msg.content || '', model)
|
||||
}
|
||||
|
||||
// Calculate remaining tokens available for conversation messages
|
||||
@@ -339,30 +391,36 @@ export class Memory {
|
||||
return systemMessages
|
||||
}
|
||||
|
||||
// Group into turns to keep tool call/result pairs together
|
||||
const turns = this.groupMessagesIntoTurns(conversationMessages)
|
||||
|
||||
const result: Message[] = []
|
||||
let currentTokenCount = 0
|
||||
|
||||
for (let i = conversationMessages.length - 1; i >= 0; i--) {
|
||||
const message = conversationMessages[i]
|
||||
const messageTokens = getAccurateTokenCount(message.content, model)
|
||||
for (let i = turns.length - 1; i >= 0; i--) {
|
||||
const turn = turns[i]
|
||||
const turnTokens = turn.reduce(
|
||||
(sum, msg) => sum + getAccurateTokenCount(msg.content || '', model),
|
||||
0
|
||||
)
|
||||
|
||||
if (currentTokenCount + messageTokens <= remainingTokens) {
|
||||
result.unshift(message)
|
||||
currentTokenCount += messageTokens
|
||||
if (currentTokenCount + turnTokens <= remainingTokens) {
|
||||
result.unshift(...turn)
|
||||
currentTokenCount += turnTokens
|
||||
} else if (result.length === 0) {
|
||||
logger.warn('Single message exceeds remaining context window, including anyway', {
|
||||
messageTokens,
|
||||
logger.warn('Single turn exceeds remaining context window, including anyway', {
|
||||
turnTokens,
|
||||
remainingTokens,
|
||||
systemTokenCount,
|
||||
messageRole: message.role,
|
||||
turnMessages: turn.length,
|
||||
})
|
||||
result.unshift(message)
|
||||
currentTokenCount += messageTokens
|
||||
result.unshift(...turn)
|
||||
currentTokenCount += turnTokens
|
||||
break
|
||||
} else {
|
||||
logger.info('Auto-trimmed conversation history to fit context window', {
|
||||
originalMessages: conversationMessages.length,
|
||||
trimmedMessages: result.length,
|
||||
originalTurns: turns.length,
|
||||
trimmedTurns: turns.length - i - 1,
|
||||
conversationTokens: currentTokenCount,
|
||||
systemTokens: systemTokenCount,
|
||||
totalTokens: currentTokenCount + systemTokenCount,
|
||||
@@ -372,6 +430,7 @@ export class Memory {
|
||||
}
|
||||
}
|
||||
|
||||
// No need to remove orphaned messages - turns are already complete
|
||||
return [...systemMessages, ...result]
|
||||
}
|
||||
|
||||
@@ -638,7 +697,7 @@ export class Memory {
|
||||
/**
|
||||
* Validate inputs to prevent malicious data or performance issues
|
||||
*/
|
||||
private validateInputs(conversationId?: string, content?: string): void {
|
||||
private validateInputs(conversationId?: string, content?: string | null): void {
|
||||
if (conversationId) {
|
||||
if (conversationId.length > 255) {
|
||||
throw new Error('Conversation ID too long (max 255 characters)')
|
||||
|
||||
@@ -37,10 +37,22 @@ export interface ToolInput {
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
role: 'system' | 'user' | 'assistant'
|
||||
content: string
|
||||
role: 'system' | 'user' | 'assistant' | 'tool'
|
||||
content: string | null
|
||||
function_call?: any
|
||||
tool_calls?: any[]
|
||||
tool_calls?: ToolCallMessage[]
|
||||
tool_call_id?: string
|
||||
/** Tool name for tool role messages (used by providers like Google/Gemini) */
|
||||
name?: string
|
||||
}
|
||||
|
||||
export interface ToolCallMessage {
|
||||
id: string
|
||||
type: 'function'
|
||||
function: {
|
||||
name: string
|
||||
arguments: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface StreamingConfig {
|
||||
|
||||
@@ -4,7 +4,12 @@ import type { StreamingExecution } from '@/executor/types'
|
||||
import { executeTool } from '@/tools'
|
||||
import { getProviderDefaultModel, getProviderModels } from '../models'
|
||||
import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types'
|
||||
import { prepareToolExecution, prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils'
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '../utils'
|
||||
|
||||
const logger = createLogger('AnthropicProvider')
|
||||
|
||||
@@ -68,8 +73,12 @@ export const anthropicProvider: ProviderConfig = {
|
||||
|
||||
// Add remaining messages
|
||||
if (request.messages) {
|
||||
request.messages.forEach((msg) => {
|
||||
// Sanitize messages to ensure proper tool call/result pairing
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
|
||||
sanitizedMessages.forEach((msg) => {
|
||||
if (msg.role === 'function') {
|
||||
// Legacy function role format
|
||||
messages.push({
|
||||
role: 'user',
|
||||
content: [
|
||||
@@ -80,7 +89,41 @@ export const anthropicProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
})
|
||||
} else if (msg.role === 'tool') {
|
||||
// Modern tool role format (OpenAI-compatible)
|
||||
messages.push({
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: (msg as any).tool_call_id,
|
||||
content: msg.content || '',
|
||||
},
|
||||
],
|
||||
})
|
||||
} else if (msg.tool_calls && Array.isArray(msg.tool_calls)) {
|
||||
// Modern tool_calls format (OpenAI-compatible)
|
||||
const toolUseContent = msg.tool_calls.map((tc: any) => ({
|
||||
type: 'tool_use',
|
||||
id: tc.id,
|
||||
name: tc.function?.name || tc.name,
|
||||
input:
|
||||
typeof tc.function?.arguments === 'string'
|
||||
? (() => {
|
||||
try {
|
||||
return JSON.parse(tc.function.arguments)
|
||||
} catch {
|
||||
return {}
|
||||
}
|
||||
})()
|
||||
: tc.function?.arguments || tc.arguments || {},
|
||||
}))
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content: toolUseContent,
|
||||
})
|
||||
} else if (msg.function_call) {
|
||||
// Legacy function_call format
|
||||
const toolUseId = `${msg.function_call.name}-${Date.now()}`
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
@@ -490,9 +533,14 @@ ${fieldDescriptions}
|
||||
}
|
||||
}
|
||||
|
||||
// Use the original tool use ID from the API response
|
||||
const toolUseId = toolUse.id || generateToolUseId(toolName)
|
||||
|
||||
toolCalls.push({
|
||||
id: toolUseId,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: JSON.stringify(toolArgs),
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
@@ -501,7 +549,6 @@ ${fieldDescriptions}
|
||||
})
|
||||
|
||||
// Add the tool call and result to messages (both success and failure)
|
||||
const toolUseId = generateToolUseId(toolName)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
@@ -840,9 +887,14 @@ ${fieldDescriptions}
|
||||
}
|
||||
}
|
||||
|
||||
// Use the original tool use ID from the API response
|
||||
const toolUseId = toolUse.id || generateToolUseId(toolName)
|
||||
|
||||
toolCalls.push({
|
||||
id: toolUseId,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: JSON.stringify(toolArgs),
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
@@ -851,7 +903,6 @@ ${fieldDescriptions}
|
||||
})
|
||||
|
||||
// Add the tool call and result to messages (both success and failure)
|
||||
const toolUseId = generateToolUseId(toolName)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
|
||||
@@ -12,6 +12,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -120,9 +121,10 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
// Transform tools to Azure OpenAI format if provided
|
||||
@@ -417,8 +419,10 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -86,9 +87,10 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
// Transform tools to Cerebras format if provided
|
||||
@@ -323,8 +325,10 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -84,9 +85,10 @@ export const deepseekProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
// Transform tools to OpenAI format if provided
|
||||
@@ -323,8 +325,10 @@ export const deepseekProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
@@ -10,6 +10,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -552,9 +553,14 @@ export const googleProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a unique ID for this tool call (Google doesn't provide one)
|
||||
const toolCallId = `call_${toolName}_${Date.now()}_${Math.random().toString(36).slice(2, 9)}`
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: JSON.stringify(toolArgs),
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
@@ -1087,9 +1093,10 @@ function convertToGeminiFormat(request: ProviderRequest): {
|
||||
contents.push({ role: 'user', parts: [{ text: request.context }] })
|
||||
}
|
||||
|
||||
// Process messages
|
||||
// Process messages (sanitized to ensure proper tool call/result pairing)
|
||||
if (request.messages && request.messages.length > 0) {
|
||||
for (const message of request.messages) {
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
for (const message of sanitizedMessages) {
|
||||
if (message.role === 'system') {
|
||||
// Add to system instruction
|
||||
if (!systemInstruction) {
|
||||
@@ -1119,10 +1126,28 @@ function convertToGeminiFormat(request: ProviderRequest): {
|
||||
contents.push({ role: 'model', parts: functionCalls })
|
||||
}
|
||||
} else if (message.role === 'tool') {
|
||||
// Convert tool response (Gemini only accepts user/model roles)
|
||||
// Convert tool response to Gemini's functionResponse format
|
||||
// Gemini uses 'user' role for function responses
|
||||
const functionName = (message as any).name || 'function'
|
||||
|
||||
let responseData: any
|
||||
try {
|
||||
responseData =
|
||||
typeof message.content === 'string' ? JSON.parse(message.content) : message.content
|
||||
} catch {
|
||||
responseData = { result: message.content }
|
||||
}
|
||||
|
||||
contents.push({
|
||||
role: 'user',
|
||||
parts: [{ text: `Function result: ${message.content}` }],
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: functionName,
|
||||
response: responseData,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -75,9 +76,10 @@ export const groqProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
// Transform tools to function format if provided
|
||||
@@ -296,8 +298,10 @@ export const groqProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -100,8 +101,10 @@ export const mistralProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Sanitize messages to ensure proper tool call/result pairing
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
const tools = request.tools?.length
|
||||
@@ -355,8 +358,10 @@ export const mistralProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
@@ -9,7 +9,7 @@ import type {
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import { prepareToolExecution } from '@/providers/utils'
|
||||
import { prepareToolExecution, sanitizeMessagesForProvider } from '@/providers/utils'
|
||||
import { useProvidersStore } from '@/stores/providers/store'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
@@ -126,9 +126,10 @@ export const ollamaProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
// Transform tools to OpenAI format if provided
|
||||
@@ -407,8 +408,10 @@ export const ollamaProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -103,9 +104,10 @@ export const openaiProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
// Add remaining messages (sanitized to ensure proper tool call/result pairing)
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
// Transform tools to OpenAI format if provided
|
||||
@@ -398,8 +400,10 @@ export const openaiProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -93,8 +94,10 @@ export const openRouterProvider: ProviderConfig = {
|
||||
allMessages.push({ role: 'user', content: request.context })
|
||||
}
|
||||
|
||||
// Sanitize messages to ensure proper tool call/result pairing
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
const tools = request.tools?.length
|
||||
@@ -303,8 +306,10 @@ export const openRouterProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
@@ -1049,3 +1049,96 @@ export function prepareToolExecution(
|
||||
|
||||
return { toolParams, executionParams }
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitizes messages array to ensure proper tool call/result pairing
|
||||
* This prevents provider errors like "tool_result without corresponding tool_use"
|
||||
*
|
||||
* Rules enforced:
|
||||
* 1. Every tool message must have a matching tool_calls message before it
|
||||
* 2. Every tool_calls in an assistant message should have corresponding tool results
|
||||
* 3. Messages maintain their original order
|
||||
*/
|
||||
export function sanitizeMessagesForProvider(
|
||||
messages: Array<{
|
||||
role: string
|
||||
content?: string | null
|
||||
tool_calls?: Array<{ id: string; [key: string]: any }>
|
||||
tool_call_id?: string
|
||||
[key: string]: any
|
||||
}>
|
||||
): typeof messages {
|
||||
if (!messages || messages.length === 0) {
|
||||
return messages
|
||||
}
|
||||
|
||||
// Build a map of tool_call IDs to their positions
|
||||
const toolCallIdToIndex = new Map<string, number>()
|
||||
const toolResultIds = new Set<string>()
|
||||
|
||||
// First pass: collect all tool_call IDs and tool result IDs
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
const msg = messages[i]
|
||||
|
||||
if (msg.tool_calls && Array.isArray(msg.tool_calls)) {
|
||||
for (const tc of msg.tool_calls) {
|
||||
if (tc.id) {
|
||||
toolCallIdToIndex.set(tc.id, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (msg.role === 'tool' && msg.tool_call_id) {
|
||||
toolResultIds.add(msg.tool_call_id)
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: filter messages
|
||||
const result: typeof messages = []
|
||||
|
||||
for (const msg of messages) {
|
||||
// For tool messages: only include if there's a matching tool_calls before it
|
||||
if (msg.role === 'tool') {
|
||||
const toolCallId = msg.tool_call_id
|
||||
if (toolCallId && toolCallIdToIndex.has(toolCallId)) {
|
||||
result.push(msg)
|
||||
} else {
|
||||
logger.debug('Removing orphaned tool message', { toolCallId })
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// For assistant messages with tool_calls: only include tool_calls that have results
|
||||
if (msg.role === 'assistant' && msg.tool_calls && Array.isArray(msg.tool_calls)) {
|
||||
const validToolCalls = msg.tool_calls.filter((tc) => tc.id && toolResultIds.has(tc.id))
|
||||
|
||||
if (validToolCalls.length === 0) {
|
||||
// No valid tool calls - if there's content, keep as regular message
|
||||
if (msg.content) {
|
||||
const { tool_calls, ...msgWithoutToolCalls } = msg
|
||||
result.push(msgWithoutToolCalls)
|
||||
} else {
|
||||
logger.debug('Removing assistant message with orphaned tool_calls', {
|
||||
toolCallIds: msg.tool_calls.map((tc) => tc.id),
|
||||
})
|
||||
}
|
||||
} else if (validToolCalls.length === msg.tool_calls.length) {
|
||||
// All tool calls are valid
|
||||
result.push(msg)
|
||||
} else {
|
||||
// Some tool calls are orphaned - keep only valid ones
|
||||
result.push({ ...msg, tool_calls: validToolCalls })
|
||||
logger.debug('Filtered orphaned tool_calls from message', {
|
||||
original: msg.tool_calls.length,
|
||||
kept: validToolCalls.length,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// All other messages pass through
|
||||
result.push(msg)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { useProvidersStore } from '@/stores/providers/store'
|
||||
@@ -140,8 +141,10 @@ export const vllmProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Sanitize messages to ensure proper tool call/result pairing
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
const tools = request.tools?.length
|
||||
@@ -400,8 +403,10 @@ export const vllmProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
sanitizeMessagesForProvider,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
@@ -83,8 +84,10 @@ export const xAIProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Sanitize messages to ensure proper tool call/result pairing
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
|
||||
allMessages.push(...sanitizedMessages)
|
||||
}
|
||||
|
||||
// Set up tools
|
||||
@@ -364,8 +367,10 @@ export const xAIProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
id: toolCall.id,
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
rawArguments: toolCall.function.arguments,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
|
||||
Reference in New Issue
Block a user