mirror of https://github.com/simstudioai/sim.git
synced 2026-01-11 07:58:06 -05:00
Compare commits
24 Commits
fix/chat-t...v0.5.28
f526c36fc0
e24f31cbce
3fbd57caf1
b5da61377c
18b7032494
b7bbef8620
52edbea659
d480057fd3
c27c233da0
ebef5f3a27
12c4c2d44f
929a352edb
6cd078b0fe
31874939ee
e157ce5fbc
774e5d585c
54cc93743f
8c32ad4c0d
1d08796853
ebcd243942
b7e814b721
842ef27ed9
31c34b2ea3
8f0ef58056
File diff suppressed because one or more lines are too long
@@ -119,116 +119,116 @@ import {
type IconComponent = ComponentType<SVGProps<SVGSVGElement>>

export const blockTypeToIconMap: Record<string, IconComponent> = {
  zoom: ZoomIcon,
  zep: ZepIcon,
  calendly: CalendlyIcon,
  mailchimp: MailchimpIcon,
  postgresql: PostgresIcon,
  twilio_voice: TwilioIcon,
  elasticsearch: ElasticsearchIcon,
  rds: RDSIcon,
  translate: TranslateIcon,
  dynamodb: DynamoDBIcon,
  wordpress: WordpressIcon,
  tavily: TavilyIcon,
  zendesk: ZendeskIcon,
  youtube: YouTubeIcon,
  x: xIcon,
  wordpress: WordpressIcon,
  wikipedia: WikipediaIcon,
  whatsapp: WhatsAppIcon,
  webflow: WebflowIcon,
  wealthbox: WealthboxIcon,
  vision: EyeIcon,
  video_generator: VideoIcon,
  typeform: TypeformIcon,
  twilio_voice: TwilioIcon,
  twilio_sms: TwilioIcon,
  tts: TTSIcon,
  trello: TrelloIcon,
  translate: TranslateIcon,
  thinking: BrainIcon,
  telegram: TelegramIcon,
  tavily: TavilyIcon,
  supabase: SupabaseIcon,
  stt: STTIcon,
  stripe: StripeIcon,
  stagehand: StagehandIcon,
  ssh: SshIcon,
  sqs: SQSIcon,
  spotify: SpotifyIcon,
  smtp: SmtpIcon,
  slack: SlackIcon,
  shopify: ShopifyIcon,
  sharepoint: MicrosoftSharepointIcon,
  sftp: SftpIcon,
  serper: SerperIcon,
  sentry: SentryIcon,
  sendgrid: SendgridIcon,
  search: SearchIcon,
  salesforce: SalesforceIcon,
  s3: S3Icon,
  resend: ResendIcon,
  reddit: RedditIcon,
  rds: RDSIcon,
  qdrant: QdrantIcon,
  posthog: PosthogIcon,
  postgresql: PostgresIcon,
  polymarket: PolymarketIcon,
  pipedrive: PipedriveIcon,
  pinecone: PineconeIcon,
  perplexity: PerplexityIcon,
  parallel_ai: ParallelIcon,
  outlook: OutlookIcon,
  openai: OpenAIIcon,
  onedrive: MicrosoftOneDriveIcon,
  notion: NotionIcon,
  neo4j: Neo4jIcon,
  mysql: MySQLIcon,
  mongodb: MongoDBIcon,
  mistral_parse: MistralIcon,
  microsoft_teams: MicrosoftTeamsIcon,
  microsoft_planner: MicrosoftPlannerIcon,
  microsoft_excel: MicrosoftExcelIcon,
  memory: BrainIcon,
  mem0: Mem0Icon,
  mailgun: MailgunIcon,
  mailchimp: MailchimpIcon,
  linkup: LinkupIcon,
  linkedin: LinkedInIcon,
  linear: LinearIcon,
  knowledge: PackageSearchIcon,
  kalshi: KalshiIcon,
  jira: JiraIcon,
  jina: JinaAIIcon,
  intercom: IntercomIcon,
  incidentio: IncidentioIcon,
  image_generator: ImageIcon,
  hunter: HunterIOIcon,
  huggingface: HuggingFaceIcon,
  hubspot: HubspotIcon,
  grafana: GrafanaIcon,
  google_vault: GoogleVaultIcon,
  google_slides: GoogleSlidesIcon,
  google_sheets: GoogleSheetsIcon,
  google_groups: GoogleGroupsIcon,
  google_forms: GoogleFormsIcon,
  google_drive: GoogleDriveIcon,
  google_docs: GoogleDocsIcon,
  google_calendar: GoogleCalendarIcon,
  google_search: GoogleIcon,
  gmail: GmailIcon,
  gitlab: GitLabIcon,
  github: GithubIcon,
  firecrawl: FirecrawlIcon,
  file: DocumentIcon,
  exa: ExaAIIcon,
  elevenlabs: ElevenLabsIcon,
  elasticsearch: ElasticsearchIcon,
  dynamodb: DynamoDBIcon,
  duckduckgo: DuckDuckGoIcon,
  dropbox: DropboxIcon,
  discord: DiscordIcon,
  datadog: DatadogIcon,
  cursor: CursorIcon,
  vision: EyeIcon,
  zoom: ZoomIcon,
  confluence: ConfluenceIcon,
  clay: ClayIcon,
  calendly: CalendlyIcon,
  browser_use: BrowserUseIcon,
  asana: AsanaIcon,
  arxiv: ArxivIcon,
  webflow: WebflowIcon,
  pinecone: PineconeIcon,
  apollo: ApolloIcon,
  whatsapp: WhatsAppIcon,
  typeform: TypeformIcon,
  qdrant: QdrantIcon,
  shopify: ShopifyIcon,
  asana: AsanaIcon,
  sqs: SQSIcon,
  apify: ApifyIcon,
  memory: BrainIcon,
  gitlab: GitLabIcon,
  polymarket: PolymarketIcon,
  serper: SerperIcon,
  linear: LinearIcon,
  exa: ExaAIIcon,
  telegram: TelegramIcon,
  salesforce: SalesforceIcon,
  hubspot: HubspotIcon,
  hunter: HunterIOIcon,
  linkup: LinkupIcon,
  mongodb: MongoDBIcon,
  airtable: AirtableIcon,
  discord: DiscordIcon,
  ahrefs: AhrefsIcon,
  neo4j: Neo4jIcon,
  tts: TTSIcon,
  jina: JinaAIIcon,
  google_docs: GoogleDocsIcon,
  perplexity: PerplexityIcon,
  google_search: GoogleIcon,
  x: xIcon,
  kalshi: KalshiIcon,
  google_calendar: GoogleCalendarIcon,
  zep: ZepIcon,
  posthog: PosthogIcon,
  grafana: GrafanaIcon,
  google_slides: GoogleSlidesIcon,
  microsoft_planner: MicrosoftPlannerIcon,
  thinking: BrainIcon,
  pipedrive: PipedriveIcon,
  dropbox: DropboxIcon,
  stagehand: StagehandIcon,
  google_forms: GoogleFormsIcon,
  file: DocumentIcon,
  mistral_parse: MistralIcon,
  gmail: GmailIcon,
  openai: OpenAIIcon,
  outlook: OutlookIcon,
  incidentio: IncidentioIcon,
  onedrive: MicrosoftOneDriveIcon,
  resend: ResendIcon,
  google_vault: GoogleVaultIcon,
  sharepoint: MicrosoftSharepointIcon,
  huggingface: HuggingFaceIcon,
  sendgrid: SendgridIcon,
  video_generator: VideoIcon,
  smtp: SmtpIcon,
  google_groups: GoogleGroupsIcon,
  mailgun: MailgunIcon,
  clay: ClayIcon,
  jira: JiraIcon,
  search: SearchIcon,
  linkedin: LinkedInIcon,
  wealthbox: WealthboxIcon,
  notion: NotionIcon,
  elevenlabs: ElevenLabsIcon,
  microsoft_teams: MicrosoftTeamsIcon,
  github: GithubIcon,
  sftp: SftpIcon,
  ssh: SshIcon,
  google_drive: GoogleDriveIcon,
  sentry: SentryIcon,
  reddit: RedditIcon,
  parallel_ai: ParallelIcon,
  spotify: SpotifyIcon,
  stripe: StripeIcon,
  s3: S3Icon,
  trello: TrelloIcon,
  mem0: Mem0Icon,
  knowledge: PackageSearchIcon,
  intercom: IntercomIcon,
  twilio_sms: TwilioIcon,
  duckduckgo: DuckDuckGoIcon,
  slack: SlackIcon,
  datadog: DatadogIcon,
  microsoft_excel: MicrosoftExcelIcon,
  image_generator: ImageIcon,
  google_sheets: GoogleSheetsIcon,
  wikipedia: WikipediaIcon,
  cursor: CursorIcon,
  firecrawl: FirecrawlIcon,
  mysql: MySQLIcon,
  browser_use: BrowserUseIcon,
  stt: STTIcon,
}
@@ -1,6 +1,6 @@
import { createLogger } from '@/lib/logs/console/logger'

const DEFAULT_STARS = '19.4k'
const DEFAULT_STARS = '18.6k'

const logger = createLogger('GitHubStars')
@@ -132,7 +132,7 @@ export async function POST(
  if ((password || email) && !input) {
    const response = addCorsHeaders(createSuccessResponse({ authenticated: true }), request)

    setChatAuthCookie(response, deployment.id, deployment.authType, deployment.password)
    setChatAuthCookie(response, deployment.id, deployment.authType)

    return response
  }

@@ -315,7 +315,7 @@ export async function GET(
  if (
    deployment.authType !== 'public' &&
    authCookie &&
    validateAuthToken(authCookie.value, deployment.id, deployment.password)
    validateAuthToken(authCookie.value, deployment.id)
  ) {
    return addCorsHeaders(
      createSuccessResponse({
@@ -1,4 +1,3 @@
import { createHash } from 'crypto'
import { db } from '@sim/db'
import { chat, workflow } from '@sim/db/schema'
import { eq } from 'drizzle-orm'

@@ -10,10 +9,6 @@ import { hasAdminPermission } from '@/lib/workspaces/permissions/utils'

const logger = createLogger('ChatAuthUtils')

function hashPassword(encryptedPassword: string): string {
  return createHash('sha256').update(encryptedPassword).digest('hex').substring(0, 8)
}

/**
 * Check if user has permission to create a chat for a specific workflow
 * Either the user owns the workflow directly OR has admin permission for the workflow's workspace

@@ -82,20 +77,14 @@ export async function checkChatAccess(
  return { hasAccess: false }
}

function encryptAuthToken(chatId: string, type: string, encryptedPassword?: string | null): string {
  const pwHash = encryptedPassword ? hashPassword(encryptedPassword) : ''
  return Buffer.from(`${chatId}:${type}:${Date.now()}:${pwHash}`).toString('base64')
const encryptAuthToken = (chatId: string, type: string): string => {
  return Buffer.from(`${chatId}:${type}:${Date.now()}`).toString('base64')
}

export function validateAuthToken(
  token: string,
  chatId: string,
  encryptedPassword?: string | null
): boolean {
export const validateAuthToken = (token: string, chatId: string): boolean => {
  try {
    const decoded = Buffer.from(token, 'base64').toString()
    const parts = decoded.split(':')
    const [storedId, _type, timestamp, storedPwHash] = parts
    const [storedId, _type, timestamp] = decoded.split(':')

    if (storedId !== chatId) {
      return false

@@ -103,32 +92,20 @@ export function validateAuthToken(

    const createdAt = Number.parseInt(timestamp)
    const now = Date.now()
    const expireTime = 24 * 60 * 60 * 1000
    const expireTime = 24 * 60 * 60 * 1000 // 24 hours

    if (now - createdAt > expireTime) {
      return false
    }

    if (encryptedPassword) {
      const currentPwHash = hashPassword(encryptedPassword)
      if (storedPwHash !== currentPwHash) {
        return false
      }
    }

    return true
  } catch (_e) {
    return false
  }
}

export function setChatAuthCookie(
  response: NextResponse,
  chatId: string,
  type: string,
  encryptedPassword?: string | null
): void {
  const token = encryptAuthToken(chatId, type, encryptedPassword)
export const setChatAuthCookie = (response: NextResponse, chatId: string, type: string): void => {
  const token = encryptAuthToken(chatId, type)
  response.cookies.set({
    name: `chat_auth_${chatId}`,
    value: token,

@@ -136,7 +113,7 @@ export function setChatAuthCookie(
    secure: !isDev,
    sameSite: 'lax',
    path: '/',
    maxAge: 60 * 60 * 24,
    maxAge: 60 * 60 * 24, // 24 hours
  })
}

@@ -168,7 +145,7 @@ export async function validateChatAuth(
  const cookieName = `chat_auth_${deployment.id}`
  const authCookie = request.cookies.get(cookieName)

  if (authCookie && validateAuthToken(authCookie.value, deployment.id, deployment.password)) {
  if (authCookie && validateAuthToken(authCookie.value, deployment.id)) {
    return { authorized: true }
  }
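The token scheme in this diff is small enough to exercise on its own. Below is a minimal, self-contained sketch of the simpler variant (the one without the password hash), assuming Node's Buffer; the chat IDs are illustrative:

```ts
// Sketch of the password-less token pair from the diff above; IDs are made up
const encryptAuthToken = (chatId: string, type: string): string =>
  Buffer.from(`${chatId}:${type}:${Date.now()}`).toString('base64')

const validateAuthToken = (token: string, chatId: string): boolean => {
  try {
    const [storedId, _type, timestamp] = Buffer.from(token, 'base64').toString().split(':')
    if (storedId !== chatId) return false
    // Reject tokens older than 24 hours, matching the cookie's maxAge
    return Date.now() - Number.parseInt(timestamp) <= 24 * 60 * 60 * 1000
  } catch {
    return false
  }
}

const token = encryptAuthToken('chat-abc', 'password')
console.log(validateAuthToken(token, 'chat-abc')) // true
console.log(validateAuthToken(token, 'chat-xyz')) // false: chat ID mismatch
```

The trade-off visible in the diff: without the password hash baked into the token, an issued cookie stays valid for its full 24-hour lifetime even if the chat's password changes in the meantime.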
@@ -1,81 +1,26 @@
import { db } from '@sim/db'
import { chat } from '@sim/db/schema'
import { eq } from 'drizzle-orm'
import type { NextRequest } from 'next/server'
import { checkHybridAuth } from '@/lib/auth/hybrid'
import { env } from '@/lib/core/config/env'
import { validateAlphanumericId } from '@/lib/core/security/input-validation'
import { createLogger } from '@/lib/logs/console/logger'
import { validateAuthToken } from '@/app/api/chat/utils'

const logger = createLogger('ProxyTTSStreamAPI')

/**
 * Validates chat-based authentication for deployed chat voice mode
 * Checks if the user has a valid chat auth cookie for the given chatId
 */
async function validateChatAuth(request: NextRequest, chatId: string): Promise<boolean> {
  try {
    const chatResult = await db
      .select({
        id: chat.id,
        isActive: chat.isActive,
        authType: chat.authType,
        password: chat.password,
      })
      .from(chat)
      .where(eq(chat.id, chatId))
      .limit(1)

    if (chatResult.length === 0 || !chatResult[0].isActive) {
      logger.warn('Chat not found or inactive for TTS auth:', chatId)
      return false
    }

    const chatData = chatResult[0]

    if (chatData.authType === 'public') {
      return true
    }

    const cookieName = `chat_auth_${chatId}`
    const authCookie = request.cookies.get(cookieName)

    if (authCookie && validateAuthToken(authCookie.value, chatId, chatData.password)) {
      return true
    }

    return false
  } catch (error) {
    logger.error('Error validating chat auth for TTS:', error)
    return false
  }
}

export async function POST(request: NextRequest) {
  try {
    let body: any
    try {
      body = await request.json()
    } catch {
      return new Response('Invalid request body', { status: 400 })
    const authResult = await checkHybridAuth(request, { requireWorkflowId: false })
    if (!authResult.success) {
      logger.error('Authentication failed for TTS stream proxy:', authResult.error)
      return new Response('Unauthorized', { status: 401 })
    }

    const { text, voiceId, modelId = 'eleven_turbo_v2_5', chatId } = body

    if (!chatId) {
      return new Response('chatId is required', { status: 400 })
    }
    const body = await request.json()
    const { text, voiceId, modelId = 'eleven_turbo_v2_5' } = body

    if (!text || !voiceId) {
      return new Response('Missing required parameters', { status: 400 })
    }

    const isChatAuthed = await validateChatAuth(request, chatId)
    if (!isChatAuthed) {
      logger.warn('Chat authentication failed for TTS, chatId:', chatId)
      return new Response('Unauthorized', { status: 401 })
    }

    const voiceIdValidation = validateAlphanumericId(voiceId, 'voiceId', 255)
    if (!voiceIdValidation.isValid) {
      logger.error(`Invalid voice ID: ${voiceIdValidation.error}`)
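A caller of this proxy now authenticates through the hybrid auth mechanism rather than a per-chat cookie. A rough sketch of the request shape follows; the route path and voice ID are illustrative assumptions, not taken from the diff:

```ts
// Hypothetical client call; only text, voiceId, and modelId survive in the new body shape
const res = await fetch('/api/proxy/tts/stream', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    text: 'Hello from the voice agent',
    voiceId: 'EXAVITQu4vr4xnSDxMaL', // example voice ID; must pass validateAlphanumericId
    modelId: 'eleven_turbo_v2_5',
  }),
})
if (!res.ok) throw new Error(`TTS proxy failed: ${res.status}`)
```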
@@ -23,13 +23,13 @@ export async function GET() {
    if (!response.ok) {
      console.warn('GitHub API request failed:', response.status)
      return NextResponse.json({ stars: formatStarCount(19400) })
      return NextResponse.json({ stars: formatStarCount(14500) })
    }

    const data = await response.json()
    return NextResponse.json({ stars: formatStarCount(Number(data?.stargazers_count ?? 19400)) })
    return NextResponse.json({ stars: formatStarCount(Number(data?.stargazers_count ?? 14500)) })
  } catch (error) {
    console.warn('Error fetching GitHub stars:', error)
    return NextResponse.json({ stars: formatStarCount(19400) })
    return NextResponse.json({ stars: formatStarCount(14500) })
  }
}
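formatStarCount itself is not shown in this diff; a plausible implementation consistent with defaults like '19.4k' might look as follows (this is an inference, not the repository's code):

```ts
// Hypothetical helper, inferred from the '19.4k'-style values above
function formatStarCount(count: number): string {
  if (count < 1000) return String(count)
  // 19400 -> '19.4' -> '19.4k'; trim a trailing '.0' so 14000 -> '14k'
  return `${(count / 1000).toFixed(1).replace(/\.0$/, '')}k`
}

console.log(formatStarCount(19400)) // '19.4k'
console.log(formatStarCount(14500)) // '14.5k'
```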
@@ -39,7 +39,6 @@ interface ChatConfig {
interface AudioStreamingOptions {
  voiceId: string
  chatId?: string
  onError: (error: Error) => void
}

@@ -63,19 +62,16 @@ function fileToBase64(file: File): Promise<string> {
 * Creates an audio stream handler for text-to-speech conversion
 * @param streamTextToAudio - Function to stream text to audio
 * @param voiceId - The voice ID to use for TTS
 * @param chatId - Optional chat ID for deployed chat authentication
 * @returns Audio stream handler function or undefined
 */
function createAudioStreamHandler(
  streamTextToAudio: (text: string, options: AudioStreamingOptions) => Promise<void>,
  voiceId: string,
  chatId?: string
  voiceId: string
) {
  return async (text: string) => {
    try {
      await streamTextToAudio(text, {
        voiceId,
        chatId,
        onError: (error: Error) => {
          logger.error('Audio streaming error:', error)
        },

@@ -117,7 +113,7 @@ export default function ChatClient({ identifier }: { identifier: string }) {
  const [error, setError] = useState<string | null>(null)
  const messagesEndRef = useRef<HTMLDivElement>(null)
  const messagesContainerRef = useRef<HTMLDivElement>(null)
  const [starCount, setStarCount] = useState('19.4k')
  const [starCount, setStarCount] = useState('3.4k')
  const [conversationId, setConversationId] = useState('')

  const [showScrollButton, setShowScrollButton] = useState(false)

@@ -395,11 +391,7 @@ export default function ChatClient({ identifier }: { identifier: string }) {
  // Use the streaming hook with audio support
  const shouldPlayAudio = isVoiceInput || isVoiceFirstMode
  const audioHandler = shouldPlayAudio
    ? createAudioStreamHandler(
        streamTextToAudio,
        DEFAULT_VOICE_SETTINGS.voiceId,
        chatConfig?.id
      )
    ? createAudioStreamHandler(streamTextToAudio, DEFAULT_VOICE_SETTINGS.voiceId)
    : undefined

  logger.info('Starting to handle streamed response:', { shouldPlayAudio })
@@ -68,6 +68,7 @@ export function VoiceInterface({
  messages = [],
  className,
}: VoiceInterfaceProps) {
  // Simple state machine
  const [state, setState] = useState<'idle' | 'listening' | 'agent_speaking'>('idle')
  const [isInitialized, setIsInitialized] = useState(false)
  const [isMuted, setIsMuted] = useState(false)

@@ -75,14 +76,12 @@ export function VoiceInterface({
  const [permissionStatus, setPermissionStatus] = useState<'prompt' | 'granted' | 'denied'>(
    'prompt'
  )

  // Current turn transcript (subtitle)
  const [currentTranscript, setCurrentTranscript] = useState('')

  // State tracking
  const currentStateRef = useRef<'idle' | 'listening' | 'agent_speaking'>('idle')
  const isCallEndedRef = useRef(false)

  useEffect(() => {
    isCallEndedRef.current = false
  }, [])

  useEffect(() => {
    currentStateRef.current = state

@@ -99,10 +98,12 @@ export function VoiceInterface({
  const isSupported =
    typeof window !== 'undefined' && !!(window.SpeechRecognition || window.webkitSpeechRecognition)

  // Update muted ref
  useEffect(() => {
    isMutedRef.current = isMuted
  }, [isMuted])

  // Timeout to handle cases where agent doesn't provide audio response
  const setResponseTimeout = useCallback(() => {
    if (responseTimeoutRef.current) {
      clearTimeout(responseTimeoutRef.current)

@@ -112,7 +113,7 @@ export function VoiceInterface({
      if (currentStateRef.current === 'listening') {
        setState('idle')
      }
    }, 5000)
    }, 5000) // 5 second timeout (increased from 3)
  }, [])

  const clearResponseTimeout = useCallback(() => {

@@ -122,14 +123,14 @@ export function VoiceInterface({
    }
  }, [])

  // Sync with external state
  useEffect(() => {
    if (isCallEndedRef.current) return

    if (isPlayingAudio && state !== 'agent_speaking') {
      clearResponseTimeout()
      clearResponseTimeout() // Clear timeout since agent is responding
      setState('agent_speaking')
      setCurrentTranscript('')

      // Mute microphone immediately
      setIsMuted(true)
      if (mediaStreamRef.current) {
        mediaStreamRef.current.getAudioTracks().forEach((track) => {

@@ -137,6 +138,7 @@ export function VoiceInterface({
        })
      }

      // Stop speech recognition completely
      if (recognitionRef.current) {
        try {
          recognitionRef.current.abort()

@@ -145,12 +147,10 @@ export function VoiceInterface({
        }
      }
    } else if (!isPlayingAudio && state === 'agent_speaking') {
      // Don't unmute/restart if call has ended
      if (isCallEndedRef.current) return

      setState('idle')
      setCurrentTranscript('')

      // Re-enable microphone
      setIsMuted(false)
      if (mediaStreamRef.current) {
        mediaStreamRef.current.getAudioTracks().forEach((track) => {

@@ -160,6 +160,7 @@ export function VoiceInterface({
    }
  }, [isPlayingAudio, state, clearResponseTimeout])

  // Audio setup
  const setupAudio = useCallback(async () => {
    try {
      const stream = await navigator.mediaDevices.getUserMedia({

@@ -174,6 +175,7 @@ export function VoiceInterface({
      setPermissionStatus('granted')
      mediaStreamRef.current = stream

      // Setup audio context for visualization
      if (!audioContextRef.current) {
        const AudioContext = window.AudioContext || window.webkitAudioContext
        audioContextRef.current = new AudioContext()

@@ -192,6 +194,7 @@ export function VoiceInterface({
      source.connect(analyser)
      analyserRef.current = analyser

      // Start visualization
      const updateVisualization = () => {
        if (!analyserRef.current) return

@@ -220,6 +223,7 @@ export function VoiceInterface({
    }
  }, [])

  // Speech recognition setup
  const setupSpeechRecognition = useCallback(() => {
    if (!isSupported) return

@@ -235,8 +239,6 @@ export function VoiceInterface({
    recognition.onstart = () => {}

    recognition.onresult = (event: SpeechRecognitionEvent) => {
      if (isCallEndedRef.current) return

      const currentState = currentStateRef.current

      if (isMutedRef.current || currentState !== 'listening') {

@@ -257,11 +259,14 @@ export function VoiceInterface({
        }
      }

      // Update live transcript
      setCurrentTranscript(interimTranscript || finalTranscript)

      // Send final transcript (but keep listening state until agent responds)
      if (finalTranscript.trim()) {
        setCurrentTranscript('')
        setCurrentTranscript('') // Clear transcript

        // Stop recognition to avoid interference while waiting for response
        if (recognitionRef.current) {
          try {
            recognitionRef.current.stop()

@@ -270,6 +275,7 @@ export function VoiceInterface({
          }
        }

        // Start timeout in case agent doesn't provide audio response
        setResponseTimeout()

        onVoiceTranscript?.(finalTranscript)

@@ -277,14 +283,13 @@ export function VoiceInterface({
    }

    recognition.onend = () => {
      if (isCallEndedRef.current) return

      const currentState = currentStateRef.current

      // Only restart recognition if we're in listening state and not muted
      if (currentState === 'listening' && !isMutedRef.current) {
        // Add a delay to avoid immediate restart after sending transcript
        setTimeout(() => {
          if (isCallEndedRef.current) return

          // Double-check state hasn't changed during delay
          if (
            recognitionRef.current &&
            currentStateRef.current === 'listening' &&

@@ -296,12 +301,14 @@ export function VoiceInterface({
            logger.debug('Error restarting speech recognition:', error)
          }
        }
        }, 1000)
        }, 1000) // Longer delay to give agent time to respond
      }
    }

    recognition.onerror = (event: SpeechRecognitionErrorEvent) => {
      // Filter out "aborted" errors - these are expected when we intentionally stop recognition
      if (event.error === 'aborted') {
        // Ignore
        return
      }

@@ -313,9 +320,8 @@ export function VoiceInterface({
    recognitionRef.current = recognition
  }, [isSupported, onVoiceTranscript, setResponseTimeout])

  // Start/stop listening
  const startListening = useCallback(() => {
    if (isCallEndedRef.current) return

    if (!isInitialized || isMuted || state !== 'idle') {
      return
    }

@@ -333,9 +339,6 @@ export function VoiceInterface({
  }, [isInitialized, isMuted, state])

  const stopListening = useCallback(() => {
    // Don't process if call has ended
    if (isCallEndedRef.current) return

    setState('idle')
    setCurrentTranscript('')

@@ -348,22 +351,25 @@ export function VoiceInterface({
    }
  }, [])

  // Handle interrupt
  const handleInterrupt = useCallback(() => {
    if (isCallEndedRef.current) return

    if (state === 'agent_speaking') {
      // Clear any subtitle timeouts and text
      // (No longer needed after removing subtitle system)

      onInterrupt?.()
      setState('listening')
      setCurrentTranscript('')

      // Unmute microphone for user input
      setIsMuted(false)
      isMutedRef.current = false
      if (mediaStreamRef.current) {
        mediaStreamRef.current.getAudioTracks().forEach((track) => {
          track.enabled = true
        })
      }

      // Start listening immediately
      if (recognitionRef.current) {
        try {
          recognitionRef.current.start()

@@ -374,24 +380,14 @@ export function VoiceInterface({
    }
  }, [state, onInterrupt])

  // Handle call end with proper cleanup
  const handleCallEnd = useCallback(() => {
    // Mark call as ended FIRST to prevent any effects from restarting recognition
    isCallEndedRef.current = true

    // Set muted to true to prevent auto-start effect from triggering
    setIsMuted(true)
    isMutedRef.current = true

    // Stop everything immediately
    setState('idle')
    setCurrentTranscript('')
    setIsMuted(false)

    // Immediately disable audio tracks to stop listening
    if (mediaStreamRef.current) {
      mediaStreamRef.current.getAudioTracks().forEach((track) => {
        track.enabled = false
      })
    }

    // Stop speech recognition
    if (recognitionRef.current) {
      try {
        recognitionRef.current.abort()

@@ -400,15 +396,19 @@ export function VoiceInterface({
      }
    }

    // Clear timeouts
    clearResponseTimeout()

    // Stop audio playback and streaming immediately
    onInterrupt?.()

    // Call the original onCallEnd
    onCallEnd?.()
  }, [onCallEnd, onInterrupt, clearResponseTimeout])

  // Keyboard handler
  useEffect(() => {
    const handleKeyDown = (event: KeyboardEvent) => {
      if (isCallEndedRef.current) return

      if (event.code === 'Space') {
        event.preventDefault()
        handleInterrupt()

@@ -419,9 +419,8 @@ export function VoiceInterface({
    return () => document.removeEventListener('keydown', handleKeyDown)
  }, [handleInterrupt])

  // Mute toggle
  const toggleMute = useCallback(() => {
    if (isCallEndedRef.current) return

    if (state === 'agent_speaking') {
      handleInterrupt()
      return

@@ -429,7 +428,6 @@ export function VoiceInterface({
    const newMutedState = !isMuted
    setIsMuted(newMutedState)
    isMutedRef.current = newMutedState

    if (mediaStreamRef.current) {
      mediaStreamRef.current.getAudioTracks().forEach((track) => {

@@ -444,6 +442,7 @@ export function VoiceInterface({
    }
  }, [isMuted, state, handleInterrupt, stopListening, startListening])

  // Initialize
  useEffect(() => {
    if (isSupported) {
      setupSpeechRecognition()

@@ -451,42 +450,47 @@ export function VoiceInterface({
    }
  }, [isSupported, setupSpeechRecognition, setupAudio])

  // Auto-start listening when ready
  useEffect(() => {
    if (isCallEndedRef.current) return

    if (isInitialized && !isMuted && state === 'idle') {
      startListening()
    }
  }, [isInitialized, isMuted, state, startListening])

  // Cleanup when call ends or component unmounts
  useEffect(() => {
    return () => {
      isCallEndedRef.current = true

      // Stop speech recognition
      if (recognitionRef.current) {
        try {
          recognitionRef.current.abort()
        } catch (_e) {
        } catch (error) {
          // Ignore
        }
        recognitionRef.current = null
      }

      // Stop media stream
      if (mediaStreamRef.current) {
        mediaStreamRef.current.getTracks().forEach((track) => track.stop())
        mediaStreamRef.current.getTracks().forEach((track) => {
          track.stop()
        })
        mediaStreamRef.current = null
      }

      // Stop audio context
      if (audioContextRef.current) {
        audioContextRef.current.close()
        audioContextRef.current = null
      }

      // Cancel animation frame
      if (animationFrameRef.current) {
        cancelAnimationFrame(animationFrameRef.current)
        animationFrameRef.current = null

@@ -494,6 +498,7 @@ export function VoiceInterface({
    }
  }, [])

  // Get status text
  const getStatusText = () => {
    switch (state) {
      case 'listening':

@@ -505,6 +510,7 @@ export function VoiceInterface({
    }
  }

  // Get button content
  const getButtonContent = () => {
    if (state === 'agent_speaking') {
      return (

@@ -518,7 +524,9 @@ export function VoiceInterface({

  return (
    <div className={cn('fixed inset-0 z-[100] flex flex-col bg-white text-gray-900', className)}>
      {/* Main content */}
      <div className='flex flex-1 flex-col items-center justify-center px-8'>
        {/* Voice visualization */}
        <div className='relative mb-16'>
          <ParticlesVisualization
            audioLevels={audioLevels}

@@ -530,6 +538,7 @@ export function VoiceInterface({
          />
        </div>

        {/* Live transcript - subtitle style */}
        <div className='mb-16 flex h-24 items-center justify-center'>
          {currentTranscript && (
            <div className='max-w-2xl px-8'>

@@ -540,14 +549,17 @@ export function VoiceInterface({
          )}
        </div>

        {/* Status */}
        <p className='mb-8 text-center text-gray-600 text-lg'>
          {getStatusText()}
          {isMuted && <span className='ml-2 text-gray-400 text-sm'>(Muted)</span>}
        </p>
      </div>

      {/* Controls */}
      <div className='px-8 pb-12'>
        <div className='flex items-center justify-center space-x-12'>
          {/* End call */}
          <Button
            onClick={handleCallEnd}
            variant='outline'

@@ -557,6 +569,7 @@ export function VoiceInterface({
            <Phone className='h-6 w-6 rotate-[135deg]' />
          </Button>

          {/* Mic/Stop button */}
          <Button
            onClick={toggleMute}
            variant='outline'
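The component above is driven by a three-state machine ('idle', 'listening', 'agent_speaking') plus a call-ended guard. A stripped-down, framework-free sketch of the core transitions — event names are paraphrased from the handlers, not an exact API:

```ts
type VoiceState = 'idle' | 'listening' | 'agent_speaking'
type VoiceEvent = 'startListening' | 'audioStarted' | 'audioEnded' | 'interrupt' | 'responseTimeout'

// Minimal sketch of the transitions implemented by the effects above
function next(state: VoiceState, event: VoiceEvent): VoiceState {
  switch (event) {
    case 'startListening':
      return state === 'idle' ? 'listening' : state
    case 'audioStarted':
      return 'agent_speaking' // mic is muted while the agent talks
    case 'audioEnded':
      return state === 'agent_speaking' ? 'idle' : state
    case 'interrupt':
      return state === 'agent_speaking' ? 'listening' : state // space bar cuts the agent off
    case 'responseTimeout':
      return state === 'listening' ? 'idle' : state // 5s with no agent audio
  }
}

console.log(next('idle', 'startListening')) // 'listening'
console.log(next('agent_speaking', 'interrupt')) // 'listening'
```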
@@ -14,7 +14,6 @@ declare global {
interface AudioStreamingOptions {
  voiceId: string
  modelId?: string
  chatId?: string
  onAudioStart?: () => void
  onAudioEnd?: () => void
  onError?: (error: Error) => void
}

@@ -77,14 +76,7 @@ export function useAudioStreaming(sharedAudioContextRef?: RefObject<AudioContext
    }

    const { text, options } = item
    const {
      voiceId,
      modelId = 'eleven_turbo_v2_5',
      chatId,
      onAudioStart,
      onAudioEnd,
      onError,
    } = options
    const { voiceId, modelId = 'eleven_turbo_v2_5', onAudioStart, onAudioEnd, onError } = options

    try {
      const audioContext = getAudioContext()

@@ -101,7 +93,6 @@ export function useAudioStreaming(sharedAudioContextRef?: RefObject<AudioContext
        text,
        voiceId,
        modelId,
        chatId,
      }),
      signal: abortControllerRef.current?.signal,
    })
@@ -262,24 +262,6 @@ const SCOPE_DESCRIPTIONS: Record<string, string> = {
  'sharing.write': 'Share files and folders with others',
  // WordPress.com scopes
  global: 'Full access to manage your WordPress.com sites, posts, pages, media, and settings',
  // Spotify scopes
  'user-read-private': 'View your Spotify account details',
  'user-read-email': 'View your email address on Spotify',
  'user-library-read': 'View your saved tracks and albums',
  'user-library-modify': 'Save and remove tracks and albums from your library',
  'playlist-read-private': 'View your private playlists',
  'playlist-read-collaborative': 'View collaborative playlists you have access to',
  'playlist-modify-public': 'Create and manage your public playlists',
  'playlist-modify-private': 'Create and manage your private playlists',
  'user-read-playback-state': 'View your current playback state',
  'user-modify-playback-state': 'Control playback on your Spotify devices',
  'user-read-currently-playing': 'View your currently playing track',
  'user-read-recently-played': 'View your recently played tracks',
  'user-top-read': 'View your top artists and tracks',
  'user-follow-read': 'View artists and users you follow',
  'user-follow-modify': 'Follow and unfollow artists and users',
  'user-read-playback-position': 'View your playback position in podcasts',
  'ugc-image-upload': 'Upload images to your Spotify playlists',
}

function getScopeDescription(scope: string): string {
@@ -153,14 +153,6 @@ export const SpotifyBlock: BlockConfig<ToolResponse> = {
      value: () => 'spotify_search',
    },

    {
      id: 'credential',
      title: 'Spotify Account',
      type: 'oauth-input',
      serviceId: 'spotify',
      required: true,
    },

    // === SEARCH ===
    {
      id: 'query',

@@ -655,6 +647,15 @@ export const SpotifyBlock: BlockConfig<ToolResponse> = {
        ],
      },
    },

    // === OAUTH CREDENTIAL ===
    {
      id: 'credential',
      title: 'Spotify Account',
      type: 'oauth-input',
      serviceId: 'spotify',
      required: true,
    },
  ],
  tools: {
    access: [
File diff suppressed because one or more lines are too long
@@ -987,21 +987,18 @@ export class AgentBlockHandler implements BlockHandler {
    try {
      const executionData = JSON.parse(executionDataHeader)

      // If execution data contains content or tool calls, persist to memory
      if (
        ctx &&
        inputs &&
        (executionData.output?.content || executionData.output?.toolCalls?.list?.length)
      ) {
        const toolCalls = executionData.output?.toolCalls?.list
        const messages = this.buildMessagesForMemory(executionData.output.content, toolCalls)

        // Fire and forget - don't await, persist all messages
        Promise.all(
          messages.map((message) =>
            memoryService.persistMemoryMessage(ctx, inputs, message, block.id)
      // If execution data contains full content, persist to memory
      if (ctx && inputs && executionData.output?.content) {
        const assistantMessage: Message = {
          role: 'assistant',
          content: executionData.output.content,
        }
        // Fire and forget - don't await
        memoryService
          .persistMemoryMessage(ctx, inputs, assistantMessage, block.id)
          .catch((error) =>
            logger.error('Failed to persist streaming response to memory:', error)
          )
        ).catch((error) => logger.error('Failed to persist streaming response to memory:', error))
      }

      return {

@@ -1120,28 +1117,25 @@ export class AgentBlockHandler implements BlockHandler {
      return
    }

    // Extract content and tool calls from regular response
    // Extract content from regular response
    const blockOutput = result as any
    const content = blockOutput?.content
    const toolCalls = blockOutput?.toolCalls?.list

    // Build messages to persist
    const messages = this.buildMessagesForMemory(content, toolCalls)

    if (messages.length === 0) {
    if (!content || typeof content !== 'string') {
      return
    }

    // Persist all messages
    for (const message of messages) {
      await memoryService.persistMemoryMessage(ctx, inputs, message, blockId)
    const assistantMessage: Message = {
      role: 'assistant',
      content,
    }

    await memoryService.persistMemoryMessage(ctx, inputs, assistantMessage, blockId)

    logger.debug('Persisted assistant response to memory', {
      workflowId: ctx.workflowId,
      memoryType: inputs.memoryType,
      conversationId: inputs.conversationId,
      messageCount: messages.length,
    })
  } catch (error) {
    logger.error('Failed to persist response to memory:', error)

@@ -1149,69 +1143,6 @@ export class AgentBlockHandler implements BlockHandler {
    }
  }

  /**
   * Builds messages for memory storage including tool calls and results
   * Returns proper OpenAI-compatible message format:
   * - Assistant message with tool_calls array (if tools were used)
   * - Tool role messages with results (one per tool call)
   * - Final assistant message with content (if present)
   */
  private buildMessagesForMemory(
    content: string | undefined,
    toolCalls: any[] | undefined
  ): Message[] {
    const messages: Message[] = []

    if (toolCalls?.length) {
      // Generate stable IDs for each tool call (only if not provided by provider)
      // Use index to ensure uniqueness even for same tool name in same millisecond
      const toolCallsWithIds = toolCalls.map((tc: any, index: number) => ({
        ...tc,
        _stableId:
          tc.id ||
          `call_${tc.name}_${Date.now()}_${index}_${Math.random().toString(36).slice(2, 7)}`,
      }))

      // Add assistant message with tool_calls
      const formattedToolCalls = toolCallsWithIds.map((tc: any) => ({
        id: tc._stableId,
        type: 'function' as const,
        function: {
          name: tc.name,
          arguments: tc.rawArguments || JSON.stringify(tc.arguments || {}),
        },
      }))

      messages.push({
        role: 'assistant',
        content: null,
        tool_calls: formattedToolCalls,
      })

      // Add tool result messages using the same stable IDs
      for (const tc of toolCallsWithIds) {
        const resultContent =
          typeof tc.result === 'string' ? tc.result : JSON.stringify(tc.result || {})
        messages.push({
          role: 'tool',
          content: resultContent,
          tool_call_id: tc._stableId,
          name: tc.name, // Store tool name for providers that need it (e.g., Google/Gemini)
        })
      }
    }

    // Add final assistant response if present
    if (content && typeof content === 'string') {
      messages.push({
        role: 'assistant',
        content,
      })
    }

    return messages
  }

  private processProviderResponse(
    response: any,
    block: SerializedBlock,
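For a turn in which the model called one tool and then answered, buildMessagesForMemory (removed above) produced an OpenAI-style sequence like the following — the IDs are illustrative, using the fallback format from the code:

```ts
// Illustrative output shape of buildMessagesForMemory('It looks sunny today.', [searchCall])
const persisted = [
  {
    role: 'assistant',
    content: null,
    tool_calls: [
      {
        id: 'call_search_1736600000000_0_ab1cd', // tc.id, or the generated stable ID
        type: 'function',
        function: { name: 'search', arguments: '{"query":"weather"}' },
      },
    ],
  },
  {
    role: 'tool',
    content: '{"result":"sunny"}',
    tool_call_id: 'call_search_1736600000000_0_ab1cd', // same stable ID pairs call and result
    name: 'search',
  },
  { role: 'assistant', content: 'It looks sunny today.' },
]
```

After the change, only the final assistant message survives; the tool call and its result are no longer persisted to memory.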
@@ -32,7 +32,7 @@ describe('Memory', () => {
  })

  describe('applySlidingWindow (message-based)', () => {
    it('should keep last N turns (turn = user message + assistant response)', () => {
    it('should keep last N conversation messages', () => {
      const messages: Message[] = [
        { role: 'system', content: 'System prompt' },
        { role: 'user', content: 'Message 1' },

@@ -43,10 +43,9 @@ describe('Memory', () => {
        { role: 'assistant', content: 'Response 3' },
      ]

      // Limit to 2 turns: should keep turns 2 and 3
      const result = (memoryService as any).applySlidingWindow(messages, '2')
      const result = (memoryService as any).applySlidingWindow(messages, '4')

      expect(result.length).toBe(5) // system + 2 turns (4 messages)
      expect(result.length).toBe(5)
      expect(result[0].role).toBe('system')
      expect(result[0].content).toBe('System prompt')
      expect(result[1].content).toBe('Message 2')

@@ -114,18 +113,19 @@ describe('Memory', () => {
    it('should preserve first system message and exclude it from token count', () => {
      const messages: Message[] = [
        { role: 'system', content: 'A' }, // System message - always preserved
        { role: 'user', content: 'B' }, // ~1 token (turn 1)
        { role: 'assistant', content: 'C' }, // ~1 token (turn 1)
        { role: 'user', content: 'D' }, // ~1 token (turn 2)
        { role: 'user', content: 'B' }, // ~1 token
        { role: 'assistant', content: 'C' }, // ~1 token
        { role: 'user', content: 'D' }, // ~1 token
      ]

      // Limit to 2 tokens - fits turn 2 (D=1 token), but turn 1 (B+C=2 tokens) would exceed
      // Limit to 2 tokens - should fit system message + last 2 conversation messages (D, C)
      const result = (memoryService as any).applySlidingWindowByTokens(messages, '2', 'gpt-4o')

      // Should have: system message + turn 2 (1 message) = 2 total
      expect(result.length).toBe(2)
      // Should have: system message + 2 conversation messages = 3 total
      expect(result.length).toBe(3)
      expect(result[0].role).toBe('system') // First system message preserved
      expect(result[1].content).toBe('D') // Most recent turn
      expect(result[1].content).toBe('C') // Second most recent conversation message
      expect(result[2].content).toBe('D') // Most recent conversation message
    })

    it('should process messages from newest to oldest', () => {

@@ -249,29 +249,29 @@ describe('Memory', () => {
    })

  describe('Token-based vs Message-based comparison', () => {
    it('should produce different results based on turn limits vs token limits', () => {
    it('should produce different results for same message count limit', () => {
      const messages: Message[] = [
        { role: 'user', content: 'A' }, // Short message (~1 token) - turn 1
        { role: 'user', content: 'A' }, // Short message (~1 token)
        {
          role: 'assistant',
          content: 'This is a much longer response that takes many more tokens',
        }, // Long message (~15 tokens) - turn 1
        { role: 'user', content: 'B' }, // Short message (~1 token) - turn 2
        }, // Long message (~15 tokens)
        { role: 'user', content: 'B' }, // Short message (~1 token)
      ]

      // Turn-based with limit 1: keeps last turn only
      const messageResult = (memoryService as any).applySlidingWindow(messages, '1')
      expect(messageResult.length).toBe(1) // Only turn 2 (message B)
      // Message-based: last 2 messages
      const messageResult = (memoryService as any).applySlidingWindow(messages, '2')
      expect(messageResult.length).toBe(2)

      // Token-based: with limit of 10 tokens, fits turn 2 (1 token) but not turn 1 (~16 tokens)
      // Token-based: with limit of 10 tokens, might fit all 3 messages or just last 2
      const tokenResult = (memoryService as any).applySlidingWindowByTokens(
        messages,
        '10',
        'gpt-4o'
      )

      // Both should only fit the last turn due to the long assistant message
      expect(tokenResult.length).toBe(1)
      // The long message should affect what fits
      expect(tokenResult.length).toBeGreaterThanOrEqual(1)
    })
  })
})
@@ -202,51 +202,13 @@ export class Memory {
    const systemMessages = messages.filter((msg) => msg.role === 'system')
    const conversationMessages = messages.filter((msg) => msg.role !== 'system')

    // Group messages into conversation turns
    // A turn = user message + any tool calls/results + assistant response
    const turns = this.groupMessagesIntoTurns(conversationMessages)

    // Take the last N turns
    const recentTurns = turns.slice(-limit)

    // Flatten back to messages
    const recentMessages = recentTurns.flat()
    const recentMessages = conversationMessages.slice(-limit)

    const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []

    return [...firstSystemMessage, ...recentMessages]
  }

  /**
   * Groups messages into conversation turns.
   * A turn starts with a user message and includes all subsequent messages
   * until the next user message (tool calls, tool results, assistant response).
   */
  private groupMessagesIntoTurns(messages: Message[]): Message[][] {
    const turns: Message[][] = []
    let currentTurn: Message[] = []

    for (const msg of messages) {
      if (msg.role === 'user') {
        // Start a new turn
        if (currentTurn.length > 0) {
          turns.push(currentTurn)
        }
        currentTurn = [msg]
      } else {
        // Add to current turn (assistant, tool, etc.)
        currentTurn.push(msg)
      }
    }

    // Don't forget the last turn
    if (currentTurn.length > 0) {
      turns.push(currentTurn)
    }

    return turns
  }

  /**
   * Apply token-based sliding window to limit conversation by token count
   *

@@ -254,11 +216,6 @@ export class Memory {
   * - For consistency with message-based sliding window, the first system message is preserved
   * - System messages are excluded from the token count
   * - This ensures system prompts are always available while limiting conversation history
   *
   * Turn handling:
   * - Messages are grouped into turns (user + tool calls/results + assistant response)
   * - Complete turns are added to stay within token limit
   * - This prevents breaking tool call/result pairs
   */
  private applySlidingWindowByTokens(
    messages: Message[],

@@ -276,31 +233,25 @@ export class Memory {
    const systemMessages = messages.filter((msg) => msg.role === 'system')
    const conversationMessages = messages.filter((msg) => msg.role !== 'system')

    // Group into turns to keep tool call/result pairs together
    const turns = this.groupMessagesIntoTurns(conversationMessages)

    const result: Message[] = []
    let currentTokenCount = 0

    // Add turns from most recent backwards
    for (let i = turns.length - 1; i >= 0; i--) {
      const turn = turns[i]
      const turnTokens = turn.reduce(
        (sum, msg) => sum + getAccurateTokenCount(msg.content || '', model),
        0
      )
    // Add conversation messages from most recent backwards
    for (let i = conversationMessages.length - 1; i >= 0; i--) {
      const message = conversationMessages[i]
      const messageTokens = getAccurateTokenCount(message.content, model)

      if (currentTokenCount + turnTokens <= tokenLimit) {
        result.unshift(...turn)
        currentTokenCount += turnTokens
      if (currentTokenCount + messageTokens <= tokenLimit) {
        result.unshift(message)
        currentTokenCount += messageTokens
      } else if (result.length === 0) {
        logger.warn('Single turn exceeds token limit, including anyway', {
          turnTokens,
        logger.warn('Single message exceeds token limit, including anyway', {
          messageTokens,
          tokenLimit,
          turnMessages: turn.length,
          messageRole: message.role,
        })
        result.unshift(...turn)
        currentTokenCount += turnTokens
        result.unshift(message)
        currentTokenCount += messageTokens
        break
      } else {
        // Token limit reached, stop processing

@@ -308,20 +259,17 @@ export class Memory {
      }
    }

    // No need to remove orphaned messages - turns are already complete
    const cleanedResult = result

    logger.debug('Applied token-based sliding window', {
      totalMessages: messages.length,
      conversationMessages: conversationMessages.length,
      includedMessages: cleanedResult.length,
      includedMessages: result.length,
      totalTokens: currentTokenCount,
      tokenLimit,
    })

    // Preserve first system message and prepend to results (consistent with message-based window)
    const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []
    return [...firstSystemMessage, ...cleanedResult]
    return [...firstSystemMessage, ...result]
  }

  /**

@@ -376,7 +324,7 @@ export class Memory {
    // Count tokens used by system messages first
    let systemTokenCount = 0
    for (const msg of systemMessages) {
      systemTokenCount += getAccurateTokenCount(msg.content || '', model)
      systemTokenCount += getAccurateTokenCount(msg.content, model)
    }

    // Calculate remaining tokens available for conversation messages

@@ -391,36 +339,30 @@ export class Memory {
      return systemMessages
    }

    // Group into turns to keep tool call/result pairs together
    const turns = this.groupMessagesIntoTurns(conversationMessages)

    const result: Message[] = []
    let currentTokenCount = 0

    for (let i = turns.length - 1; i >= 0; i--) {
      const turn = turns[i]
      const turnTokens = turn.reduce(
        (sum, msg) => sum + getAccurateTokenCount(msg.content || '', model),
        0
      )
    for (let i = conversationMessages.length - 1; i >= 0; i--) {
      const message = conversationMessages[i]
      const messageTokens = getAccurateTokenCount(message.content, model)

      if (currentTokenCount + turnTokens <= remainingTokens) {
        result.unshift(...turn)
        currentTokenCount += turnTokens
      if (currentTokenCount + messageTokens <= remainingTokens) {
        result.unshift(message)
        currentTokenCount += messageTokens
      } else if (result.length === 0) {
        logger.warn('Single turn exceeds remaining context window, including anyway', {
          turnTokens,
        logger.warn('Single message exceeds remaining context window, including anyway', {
          messageTokens,
          remainingTokens,
          systemTokenCount,
          turnMessages: turn.length,
          messageRole: message.role,
        })
        result.unshift(...turn)
        currentTokenCount += turnTokens
        result.unshift(message)
        currentTokenCount += messageTokens
        break
      } else {
        logger.info('Auto-trimmed conversation history to fit context window', {
          originalTurns: turns.length,
          trimmedTurns: turns.length - i - 1,
          originalMessages: conversationMessages.length,
          trimmedMessages: result.length,
          conversationTokens: currentTokenCount,
          systemTokens: systemTokenCount,
          totalTokens: currentTokenCount + systemTokenCount,

@@ -430,7 +372,6 @@ export class Memory {
      }
    }

    // No need to remove orphaned messages - turns are already complete
    return [...systemMessages, ...result]
  }
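The change above replaces turn-based grouping with a plain slice over conversation messages. A self-contained sketch of the simplified message-based window:

```ts
interface Msg {
  role: 'system' | 'user' | 'assistant'
  content: string
}

// Minimal sketch of the simplified applySlidingWindow: first system message + last N others
function applySlidingWindow(messages: Msg[], limit: number): Msg[] {
  const system = messages.filter((m) => m.role === 'system')
  const conversation = messages.filter((m) => m.role !== 'system')
  const firstSystem = system.length > 0 ? [system[0]] : []
  return [...firstSystem, ...conversation.slice(-limit)]
}

const history: Msg[] = [
  { role: 'system', content: 'System prompt' },
  { role: 'user', content: 'Message 1' },
  { role: 'assistant', content: 'Response 1' },
  { role: 'user', content: 'Message 2' },
  { role: 'assistant', content: 'Response 2' },
]
console.log(applySlidingWindow(history, 2).map((m) => m.content))
// [ 'System prompt', 'Message 2', 'Response 2' ]
```

As the removed comments note, the cost of the simplification is that a window boundary can now fall between a tool call and its result, where the turn-based grouping kept such pairs together.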
@@ -697,7 +638,7 @@ export class Memory {
  /**
   * Validate inputs to prevent malicious data or performance issues
   */
  private validateInputs(conversationId?: string, content?: string | null): void {
  private validateInputs(conversationId?: string, content?: string): void {
    if (conversationId) {
      if (conversationId.length > 255) {
        throw new Error('Conversation ID too long (max 255 characters)')
@@ -37,22 +37,10 @@ export interface ToolInput {
}

export interface Message {
  role: 'system' | 'user' | 'assistant' | 'tool'
  content: string | null
  role: 'system' | 'user' | 'assistant'
  content: string
  function_call?: any
  tool_calls?: ToolCallMessage[]
  tool_call_id?: string
  /** Tool name for tool role messages (used by providers like Google/Gemini) */
  name?: string
}

export interface ToolCallMessage {
  id: string
  type: 'function'
  function: {
    name: string
    arguments: string
  }
  tool_calls?: any[]
}

export interface StreamingConfig {
@@ -11,7 +11,7 @@ export { BLOCK_DIMENSIONS, CONTAINER_DIMENSIONS } from '@/lib/workflows/blocks/b
/**
 * Horizontal spacing between layers (columns)
 */
export const DEFAULT_HORIZONTAL_SPACING = 250
export const DEFAULT_HORIZONTAL_SPACING = 350

/**
 * Vertical spacing between blocks in the same layer

@@ -4,12 +4,7 @@ import type { StreamingExecution } from '@/executor/types'
 import { executeTool } from '@/tools'
 import { getProviderDefaultModel, getProviderModels } from '../models'
 import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types'
-import {
-  prepareToolExecution,
-  prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
-  trackForcedToolUsage,
-} from '../utils'
+import { prepareToolExecution, prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils'

 const logger = createLogger('AnthropicProvider')

@@ -73,12 +68,8 @@ export const anthropicProvider: ProviderConfig = {

     // Add remaining messages
     if (request.messages) {
-      // Sanitize messages to ensure proper tool call/result pairing
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-
-      sanitizedMessages.forEach((msg) => {
+      request.messages.forEach((msg) => {
         if (msg.role === 'function') {
           // Legacy function role format
           messages.push({
             role: 'user',
             content: [

@@ -89,41 +80,7 @@ export const anthropicProvider: ProviderConfig = {
               },
             ],
           })
-        } else if (msg.role === 'tool') {
-          // Modern tool role format (OpenAI-compatible)
-          messages.push({
-            role: 'user',
-            content: [
-              {
-                type: 'tool_result',
-                tool_use_id: (msg as any).tool_call_id,
-                content: msg.content || '',
-              },
-            ],
-          })
-        } else if (msg.tool_calls && Array.isArray(msg.tool_calls)) {
-          // Modern tool_calls format (OpenAI-compatible)
-          const toolUseContent = msg.tool_calls.map((tc: any) => ({
-            type: 'tool_use',
-            id: tc.id,
-            name: tc.function?.name || tc.name,
-            input:
-              typeof tc.function?.arguments === 'string'
-                ? (() => {
-                    try {
-                      return JSON.parse(tc.function.arguments)
-                    } catch {
-                      return {}
-                    }
-                  })()
-                : tc.function?.arguments || tc.arguments || {},
-          }))
-          messages.push({
-            role: 'assistant',
-            content: toolUseContent,
-          })
         } else if (msg.function_call) {
           // Legacy function_call format
           const toolUseId = `${msg.function_call.name}-${Date.now()}`
           messages.push({
             role: 'assistant',

@@ -533,14 +490,9 @@ ${fieldDescriptions}
           }
         }

-        // Use the original tool use ID from the API response
-        const toolUseId = toolUse.id || generateToolUseId(toolName)
-
         toolCalls.push({
           id: toolUseId,
           name: toolName,
           arguments: toolParams,
           rawArguments: JSON.stringify(toolArgs),
           startTime: new Date(toolCallStartTime).toISOString(),
           endTime: new Date(toolCallEndTime).toISOString(),
           duration: toolCallDuration,

@@ -549,6 +501,7 @@ ${fieldDescriptions}
         })

         // Add the tool call and result to messages (both success and failure)
+        const toolUseId = generateToolUseId(toolName)

         currentMessages.push({
           role: 'assistant',

@@ -887,14 +840,9 @@ ${fieldDescriptions}
           }
         }

-        // Use the original tool use ID from the API response
-        const toolUseId = toolUse.id || generateToolUseId(toolName)
-
         toolCalls.push({
           id: toolUseId,
           name: toolName,
           arguments: toolParams,
           rawArguments: JSON.stringify(toolArgs),
           startTime: new Date(toolCallStartTime).toISOString(),
           endTime: new Date(toolCallEndTime).toISOString(),
           duration: toolCallDuration,

@@ -903,6 +851,7 @@ ${fieldDescriptions}
         })

         // Add the tool call and result to messages (both success and failure)
+        const toolUseId = generateToolUseId(toolName)

         currentMessages.push({
           role: 'assistant',
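
The removed branches above mapped OpenAI-style tool messages onto Anthropic's content blocks. A self-contained sketch of that mapping, reassembled from the removed lines (the surrounding message types are assumed):

// Reassembled from the removed code: OpenAI-style tool_calls/tool messages
// become Anthropic tool_use/tool_result content blocks.
function toAnthropicBlocks(msg: any): { role: string; content: any[] } | null {
  if (msg.role === 'tool') {
    // Tool results travel back as a user turn holding a tool_result block
    return {
      role: 'user',
      content: [{ type: 'tool_result', tool_use_id: msg.tool_call_id, content: msg.content || '' }],
    }
  }
  if (Array.isArray(msg.tool_calls)) {
    // Assistant tool calls become tool_use blocks; arguments may arrive as a JSON string
    return {
      role: 'assistant',
      content: msg.tool_calls.map((tc: any) => ({
        type: 'tool_use',
        id: tc.id,
        name: tc.function?.name ?? tc.name,
        input: parseArgs(tc.function?.arguments ?? tc.arguments),
      })),
    }
  }
  return null
}

function parseArgs(args: unknown): Record<string, unknown> {
  if (typeof args !== 'string') return (args as Record<string, unknown>) ?? {}
  try {
    return JSON.parse(args)
  } catch {
    return {}
  }
}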

@@ -12,7 +12,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { executeTool } from '@/tools'

@@ -121,10 +120,9 @@ export const azureOpenAIProvider: ProviderConfig = {
       })
     }

-    // Add remaining messages (sanitized to ensure proper tool call/result pairing)
+    // Add remaining messages
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     // Transform tools to Azure OpenAI format if provided

@@ -419,10 +417,8 @@ export const azureOpenAIProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,
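
The same two-line change repeats for each OpenAI-compatible provider below: the chat history is now forwarded verbatim instead of being filtered first. A minimal sketch of the before/after (surrounding types assumed):

// Before the change (removed): history was filtered for orphaned tool
// calls/results first:
//   allMessages.push(...sanitizeMessagesForProvider(request.messages))
// After: history is forwarded verbatim.
function appendHistory(allMessages: unknown[], request: { messages?: unknown[] }): void {
  if (request.messages) {
    allMessages.push(...request.messages)
  }
}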

@@ -11,7 +11,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { executeTool } from '@/tools'

@@ -87,10 +86,9 @@ export const cerebrasProvider: ProviderConfig = {
       })
     }

-    // Add remaining messages (sanitized to ensure proper tool call/result pairing)
+    // Add remaining messages
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     // Transform tools to Cerebras format if provided

@@ -325,10 +323,8 @@ export const cerebrasProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,

@@ -11,7 +11,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { executeTool } from '@/tools'

@@ -85,10 +84,9 @@ export const deepseekProvider: ProviderConfig = {
       })
     }

-    // Add remaining messages (sanitized to ensure proper tool call/result pairing)
+    // Add remaining messages
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     // Transform tools to OpenAI format if provided

@@ -325,10 +323,8 @@ export const deepseekProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,

@@ -10,7 +10,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { executeTool } from '@/tools'

@@ -553,14 +552,9 @@ export const googleProvider: ProviderConfig = {
         }
       }

-      // Generate a unique ID for this tool call (Google doesn't provide one)
-      const toolCallId = `call_${toolName}_${Date.now()}_${Math.random().toString(36).slice(2, 9)}`
-
       toolCalls.push({
         id: toolCallId,
         name: toolName,
         arguments: toolParams,
         rawArguments: JSON.stringify(toolArgs),
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,

@@ -1093,10 +1087,9 @@ function convertToGeminiFormat(request: ProviderRequest): {
     contents.push({ role: 'user', parts: [{ text: request.context }] })
   }

-  // Process messages (sanitized to ensure proper tool call/result pairing)
+  // Process messages
   if (request.messages && request.messages.length > 0) {
-    const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-    for (const message of sanitizedMessages) {
+    for (const message of request.messages) {
       if (message.role === 'system') {
         // Add to system instruction
         if (!systemInstruction) {

@@ -1126,28 +1119,10 @@ function convertToGeminiFormat(request: ProviderRequest): {
         contents.push({ role: 'model', parts: functionCalls })
       }
     } else if (message.role === 'tool') {
-      // Convert tool response to Gemini's functionResponse format
-      // Gemini uses 'user' role for function responses
-      const functionName = (message as any).name || 'function'
-
-      let responseData: any
-      try {
-        responseData =
-          typeof message.content === 'string' ? JSON.parse(message.content) : message.content
-      } catch {
-        responseData = { result: message.content }
-      }
-
+      // Convert tool response (Gemini only accepts user/model roles)
       contents.push({
         role: 'user',
-        parts: [
-          {
-            functionResponse: {
-              name: functionName,
-              response: responseData,
-            },
-          },
-        ],
+        parts: [{ text: `Function result: ${message.content}` }],
       })
     }
   }
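
The removed Gemini branch emitted structured functionResponse parts; the replacement inlines results as text. The two shapes side by side, grounded in the lines above (sample values are illustrative):

// Structured form (removed): function responses ride on a 'user' turn
const structured = {
  role: 'user',
  parts: [
    {
      functionResponse: {
        name: 'get_weather', // from message.name, defaulting to 'function'
        response: { tempC: 21 }, // JSON.parse(message.content), or { result: content }
      },
    },
  ],
}

// Flattened form (added): the result is inlined as plain text
const flattened = {
  role: 'user',
  parts: [{ text: 'Function result: {"tempC":21}' }],
}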

@@ -11,7 +11,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { executeTool } from '@/tools'

@@ -76,10 +75,9 @@ export const groqProvider: ProviderConfig = {
       })
     }

-    // Add remaining messages (sanitized to ensure proper tool call/result pairing)
+    // Add remaining messages
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     // Transform tools to function format if provided

@@ -298,10 +296,8 @@ export const groqProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,

@@ -11,7 +11,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { executeTool } from '@/tools'

@@ -101,10 +100,8 @@ export const mistralProvider: ProviderConfig = {
       })
     }

-    // Sanitize messages to ensure proper tool call/result pairing
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     const tools = request.tools?.length

@@ -358,10 +355,8 @@ export const mistralProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,

@@ -9,7 +9,7 @@ import type {
   ProviderResponse,
   TimeSegment,
 } from '@/providers/types'
-import { prepareToolExecution, sanitizeMessagesForProvider } from '@/providers/utils'
+import { prepareToolExecution } from '@/providers/utils'
 import { useProvidersStore } from '@/stores/providers/store'
 import { executeTool } from '@/tools'

@@ -126,10 +126,9 @@ export const ollamaProvider: ProviderConfig = {
       })
     }

-    // Add remaining messages (sanitized to ensure proper tool call/result pairing)
+    // Add remaining messages
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     // Transform tools to OpenAI format if provided

@@ -408,10 +407,8 @@ export const ollamaProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,

@@ -11,7 +11,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { executeTool } from '@/tools'

@@ -104,10 +103,9 @@ export const openaiProvider: ProviderConfig = {
       })
     }

-    // Add remaining messages (sanitized to ensure proper tool call/result pairing)
+    // Add remaining messages
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     // Transform tools to OpenAI format if provided

@@ -400,10 +398,8 @@ export const openaiProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,

@@ -11,7 +11,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { executeTool } from '@/tools'

@@ -94,10 +93,8 @@ export const openRouterProvider: ProviderConfig = {
       allMessages.push({ role: 'user', content: request.context })
     }

-    // Sanitize messages to ensure proper tool call/result pairing
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     const tools = request.tools?.length

@@ -306,10 +303,8 @@ export const openRouterProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,

@@ -1049,96 +1049,3 @@ export function prepareToolExecution(

   return { toolParams, executionParams }
 }
-
-/**
- * Sanitizes messages array to ensure proper tool call/result pairing
- * This prevents provider errors like "tool_result without corresponding tool_use"
- *
- * Rules enforced:
- * 1. Every tool message must have a matching tool_calls message before it
- * 2. Every tool_calls in an assistant message should have corresponding tool results
- * 3. Messages maintain their original order
- */
-export function sanitizeMessagesForProvider(
-  messages: Array<{
-    role: string
-    content?: string | null
-    tool_calls?: Array<{ id: string; [key: string]: any }>
-    tool_call_id?: string
-    [key: string]: any
-  }>
-): typeof messages {
-  if (!messages || messages.length === 0) {
-    return messages
-  }
-
-  // Build a map of tool_call IDs to their positions
-  const toolCallIdToIndex = new Map<string, number>()
-  const toolResultIds = new Set<string>()
-
-  // First pass: collect all tool_call IDs and tool result IDs
-  for (let i = 0; i < messages.length; i++) {
-    const msg = messages[i]
-
-    if (msg.tool_calls && Array.isArray(msg.tool_calls)) {
-      for (const tc of msg.tool_calls) {
-        if (tc.id) {
-          toolCallIdToIndex.set(tc.id, i)
-        }
-      }
-    }
-
-    if (msg.role === 'tool' && msg.tool_call_id) {
-      toolResultIds.add(msg.tool_call_id)
-    }
-  }
-
-  // Second pass: filter messages
-  const result: typeof messages = []
-
-  for (const msg of messages) {
-    // For tool messages: only include if there's a matching tool_calls before it
-    if (msg.role === 'tool') {
-      const toolCallId = msg.tool_call_id
-      if (toolCallId && toolCallIdToIndex.has(toolCallId)) {
-        result.push(msg)
-      } else {
-        logger.debug('Removing orphaned tool message', { toolCallId })
-      }
-      continue
-    }
-
-    // For assistant messages with tool_calls: only include tool_calls that have results
-    if (msg.role === 'assistant' && msg.tool_calls && Array.isArray(msg.tool_calls)) {
-      const validToolCalls = msg.tool_calls.filter((tc) => tc.id && toolResultIds.has(tc.id))
-
-      if (validToolCalls.length === 0) {
-        // No valid tool calls - if there's content, keep as regular message
-        if (msg.content) {
-          const { tool_calls, ...msgWithoutToolCalls } = msg
-          result.push(msgWithoutToolCalls)
-        } else {
-          logger.debug('Removing assistant message with orphaned tool_calls', {
-            toolCallIds: msg.tool_calls.map((tc) => tc.id),
-          })
-        }
-      } else if (validToolCalls.length === msg.tool_calls.length) {
-        // All tool calls are valid
-        result.push(msg)
-      } else {
-        // Some tool calls are orphaned - keep only valid ones
-        result.push({ ...msg, tool_calls: validToolCalls })
-        logger.debug('Filtered orphaned tool_calls from message', {
-          original: msg.tool_calls.length,
-          kept: validToolCalls.length,
-        })
-      }
-      continue
-    }
-
-    // All other messages pass through
-    result.push(msg)
-  }
-
-  return result
-}
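
For reference, how the now-removed sanitizer behaved on an orphaned pairing (illustrative values):

const history = [
  { role: 'assistant', content: null, tool_calls: [{ id: 'call_1', type: 'function' }] },
  { role: 'tool', tool_call_id: 'call_1', content: '{"ok":true}' },
  // Orphaned: no preceding assistant tool_calls entry carries id 'call_2'
  { role: 'tool', tool_call_id: 'call_2', content: '{"ok":false}' },
]

// sanitizeMessagesForProvider(history) returned the first two messages and
// dropped the third, logging 'Removing orphaned tool message'. Conversely,
// an assistant tool_call with no matching tool result was filtered out of
// its message (or the whole message dropped if it had no content).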

@@ -12,7 +12,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { useProvidersStore } from '@/stores/providers/store'

@@ -141,10 +140,8 @@ export const vllmProvider: ProviderConfig = {
       })
     }

-    // Sanitize messages to ensure proper tool call/result pairing
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     const tools = request.tools?.length

@@ -403,10 +400,8 @@ export const vllmProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,

@@ -11,7 +11,6 @@ import type {
 import {
   prepareToolExecution,
   prepareToolsWithUsageControl,
-  sanitizeMessagesForProvider,
   trackForcedToolUsage,
 } from '@/providers/utils'
 import { executeTool } from '@/tools'

@@ -84,10 +83,8 @@ export const xAIProvider: ProviderConfig = {
       })
     }

-    // Sanitize messages to ensure proper tool call/result pairing
     if (request.messages) {
-      const sanitizedMessages = sanitizeMessagesForProvider(request.messages)
-      allMessages.push(...sanitizedMessages)
+      allMessages.push(...request.messages)
     }

     // Set up tools

@@ -367,10 +364,8 @@ export const xAIProvider: ProviderConfig = {
       }

       toolCalls.push({
         id: toolCall.id,
         name: toolName,
         arguments: toolParams,
         rawArguments: toolCall.function.arguments,
         startTime: new Date(toolCallStartTime).toISOString(),
         endTime: new Date(toolCallEndTime).toISOString(),
         duration: toolCallDuration,