mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-07 22:24:06 -05:00
feat(copilot): superagent (#2201)
* Superagent poc * Checkpoint brokeN * tool call rag * Fix * Fixes * Improvements * Creds stuff * Fix * Fix tools * Fix stream * Prompt * Update sheets descriptions * Better * Copilot components * Delete stuff * Remove db migration * Fix migrations * Fix things * Copilot side superagent * Build workflow from chat * Combine superagent into copilkot * Render tools * Function execution * Max mode indicators * Tool call confirmations * Credential settings * Remove betas * Bump version * Dropdown options in block metadata * Copilot kb tools * Fix lint * Credentials modal * Fix lint * Cleanup * Env var resolution in superagent tools * Get id for workflow vars * Fix insert into subflow * Fix executor for while and do while loops * Fix metadata for parallel * Remove db migration * Rebase * Add migrations back * Clean up code * Fix executor logic issue * Cleanup * Diagram tool * Fix tool naems * Comment out g3p * Remove popup option * Hide o3 * Remove db migration * Fix merge conflicts * Fix lint * Fix tests * Remove webhook change * Remove cb change * Fix lint * Fix * Fix lint * Fix build * comment out gemini * Add gemini back * Remove bad test * Fix * Fix test * Fix * Nuke bad test * Fix lint --------- Co-authored-by: Vikhyath Mondreti <vikhyathvikku@gmail.com> Co-authored-by: Waleed <walif6@gmail.com> Co-authored-by: waleedlatif1 <waleedlatif1@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8ef9a45125
commit
58251e28e6
150
apps/sim/app/api/copilot/auto-allowed-tools/route.ts
Normal file
150
apps/sim/app/api/copilot/auto-allowed-tools/route.ts
Normal file
@@ -0,0 +1,150 @@
|
||||
import { db } from '@sim/db'
|
||||
import { settings } from '@sim/db/schema'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { auth } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('CopilotAutoAllowedToolsAPI')
|
||||
|
||||
/**
|
||||
* GET - Fetch user's auto-allowed integration tools
|
||||
*/
|
||||
export async function GET(request: NextRequest) {
|
||||
try {
|
||||
const session = await auth.api.getSession({ headers: request.headers })
|
||||
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const userId = session.user.id
|
||||
|
||||
const [userSettings] = await db
|
||||
.select()
|
||||
.from(settings)
|
||||
.where(eq(settings.userId, userId))
|
||||
.limit(1)
|
||||
|
||||
if (userSettings) {
|
||||
const autoAllowedTools = (userSettings.copilotAutoAllowedTools as string[]) || []
|
||||
return NextResponse.json({ autoAllowedTools })
|
||||
}
|
||||
|
||||
// If no settings record exists, create one with empty array
|
||||
await db.insert(settings).values({
|
||||
id: userId,
|
||||
userId,
|
||||
copilotAutoAllowedTools: [],
|
||||
})
|
||||
|
||||
return NextResponse.json({ autoAllowedTools: [] })
|
||||
} catch (error) {
|
||||
logger.error('Failed to fetch auto-allowed tools', { error })
|
||||
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* POST - Add a tool to the auto-allowed list
|
||||
*/
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const session = await auth.api.getSession({ headers: request.headers })
|
||||
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const userId = session.user.id
|
||||
const body = await request.json()
|
||||
|
||||
if (!body.toolId || typeof body.toolId !== 'string') {
|
||||
return NextResponse.json({ error: 'toolId must be a string' }, { status: 400 })
|
||||
}
|
||||
|
||||
const toolId = body.toolId
|
||||
|
||||
// Get existing settings
|
||||
const [existing] = await db.select().from(settings).where(eq(settings.userId, userId)).limit(1)
|
||||
|
||||
if (existing) {
|
||||
const currentTools = (existing.copilotAutoAllowedTools as string[]) || []
|
||||
|
||||
// Add tool if not already present
|
||||
if (!currentTools.includes(toolId)) {
|
||||
const updatedTools = [...currentTools, toolId]
|
||||
await db
|
||||
.update(settings)
|
||||
.set({
|
||||
copilotAutoAllowedTools: updatedTools,
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(eq(settings.userId, userId))
|
||||
|
||||
logger.info('Added tool to auto-allowed list', { userId, toolId })
|
||||
return NextResponse.json({ success: true, autoAllowedTools: updatedTools })
|
||||
}
|
||||
|
||||
return NextResponse.json({ success: true, autoAllowedTools: currentTools })
|
||||
}
|
||||
|
||||
// Create new settings record with the tool
|
||||
await db.insert(settings).values({
|
||||
id: userId,
|
||||
userId,
|
||||
copilotAutoAllowedTools: [toolId],
|
||||
})
|
||||
|
||||
logger.info('Created settings and added tool to auto-allowed list', { userId, toolId })
|
||||
return NextResponse.json({ success: true, autoAllowedTools: [toolId] })
|
||||
} catch (error) {
|
||||
logger.error('Failed to add auto-allowed tool', { error })
|
||||
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* DELETE - Remove a tool from the auto-allowed list
|
||||
*/
|
||||
export async function DELETE(request: NextRequest) {
|
||||
try {
|
||||
const session = await auth.api.getSession({ headers: request.headers })
|
||||
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const userId = session.user.id
|
||||
const { searchParams } = new URL(request.url)
|
||||
const toolId = searchParams.get('toolId')
|
||||
|
||||
if (!toolId) {
|
||||
return NextResponse.json({ error: 'toolId query parameter is required' }, { status: 400 })
|
||||
}
|
||||
|
||||
// Get existing settings
|
||||
const [existing] = await db.select().from(settings).where(eq(settings.userId, userId)).limit(1)
|
||||
|
||||
if (existing) {
|
||||
const currentTools = (existing.copilotAutoAllowedTools as string[]) || []
|
||||
const updatedTools = currentTools.filter((t) => t !== toolId)
|
||||
|
||||
await db
|
||||
.update(settings)
|
||||
.set({
|
||||
copilotAutoAllowedTools: updatedTools,
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(eq(settings.userId, userId))
|
||||
|
||||
logger.info('Removed tool from auto-allowed list', { userId, toolId })
|
||||
return NextResponse.json({ success: true, autoAllowedTools: updatedTools })
|
||||
}
|
||||
|
||||
return NextResponse.json({ success: true, autoAllowedTools: [] })
|
||||
} catch (error) {
|
||||
logger.error('Failed to remove auto-allowed tool', { error })
|
||||
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
@@ -1,634 +0,0 @@
|
||||
/**
|
||||
* Tests for copilot chat API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { NextRequest } from 'next/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockAuth,
|
||||
mockCryptoUuid,
|
||||
setupCommonApiMocks,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
|
||||
describe('Copilot Chat API Route', () => {
|
||||
const mockSelect = vi.fn()
|
||||
const mockFrom = vi.fn()
|
||||
const mockWhere = vi.fn()
|
||||
const mockLimit = vi.fn()
|
||||
const mockOrderBy = vi.fn()
|
||||
const mockInsert = vi.fn()
|
||||
const mockValues = vi.fn()
|
||||
const mockReturning = vi.fn()
|
||||
const mockUpdate = vi.fn()
|
||||
const mockSet = vi.fn()
|
||||
|
||||
const mockExecuteProviderRequest = vi.fn()
|
||||
const mockGetCopilotModel = vi.fn()
|
||||
const mockGetRotatingApiKey = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
setupCommonApiMocks()
|
||||
mockCryptoUuid()
|
||||
|
||||
mockSelect.mockReturnValue({ from: mockFrom })
|
||||
mockFrom.mockReturnValue({ where: mockWhere })
|
||||
mockWhere.mockReturnValue({
|
||||
orderBy: mockOrderBy,
|
||||
limit: mockLimit,
|
||||
})
|
||||
mockOrderBy.mockResolvedValue([])
|
||||
mockLimit.mockResolvedValue([])
|
||||
mockInsert.mockReturnValue({ values: mockValues })
|
||||
mockValues.mockReturnValue({ returning: mockReturning })
|
||||
mockUpdate.mockReturnValue({ set: mockSet })
|
||||
mockSet.mockReturnValue({ where: mockWhere })
|
||||
|
||||
vi.doMock('@sim/db', () => ({
|
||||
db: {
|
||||
select: mockSelect,
|
||||
insert: mockInsert,
|
||||
update: mockUpdate,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('@sim/db/schema', () => ({
|
||||
copilotChats: {
|
||||
id: 'id',
|
||||
userId: 'userId',
|
||||
messages: 'messages',
|
||||
title: 'title',
|
||||
model: 'model',
|
||||
workflowId: 'workflowId',
|
||||
createdAt: 'createdAt',
|
||||
updatedAt: 'updatedAt',
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('drizzle-orm', () => ({
|
||||
and: vi.fn((...conditions) => ({ conditions, type: 'and' })),
|
||||
eq: vi.fn((field, value) => ({ field, value, type: 'eq' })),
|
||||
desc: vi.fn((field) => ({ field, type: 'desc' })),
|
||||
}))
|
||||
|
||||
mockGetCopilotModel.mockReturnValue({
|
||||
provider: 'anthropic',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
})
|
||||
|
||||
vi.doMock('@/lib/copilot/config', () => ({
|
||||
getCopilotModel: mockGetCopilotModel,
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/copilot/prompts', () => ({
|
||||
TITLE_GENERATION_SYSTEM_PROMPT: 'Generate a title',
|
||||
TITLE_GENERATION_USER_PROMPT: vi.fn((msg) => `Generate title for: ${msg}`),
|
||||
}))
|
||||
|
||||
mockExecuteProviderRequest.mockResolvedValue({
|
||||
content: 'Generated Title',
|
||||
})
|
||||
|
||||
vi.doMock('@/providers', () => ({
|
||||
executeProviderRequest: mockExecuteProviderRequest,
|
||||
}))
|
||||
|
||||
mockGetRotatingApiKey.mockReturnValue('test-api-key')
|
||||
|
||||
vi.doMock('@/lib/core/config/api-keys', () => ({
|
||||
getRotatingApiKey: mockGetRotatingApiKey,
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/core/utils/request', () => ({
|
||||
generateRequestId: vi.fn(() => 'test-request-id'),
|
||||
}))
|
||||
|
||||
const mockEnvValues = {
|
||||
SIM_AGENT_API_URL: 'http://localhost:8000',
|
||||
COPILOT_API_KEY: 'test-sim-agent-key',
|
||||
BETTER_AUTH_URL: 'http://localhost:3000',
|
||||
NEXT_PUBLIC_APP_URL: 'http://localhost:3000',
|
||||
NODE_ENV: 'test',
|
||||
} as const
|
||||
|
||||
vi.doMock('@/lib/core/config/env', () => ({
|
||||
env: mockEnvValues,
|
||||
getEnv: (variable: string) => mockEnvValues[variable as keyof typeof mockEnvValues],
|
||||
isTruthy: (value: string | boolean | number | undefined) =>
|
||||
typeof value === 'string'
|
||||
? value.toLowerCase() === 'true' || value === '1'
|
||||
: Boolean(value),
|
||||
isFalsy: (value: string | boolean | number | undefined) =>
|
||||
typeof value === 'string'
|
||||
? value.toLowerCase() === 'false' || value === '0'
|
||||
: value === false,
|
||||
}))
|
||||
|
||||
global.fetch = vi.fn()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('POST', () => {
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setUnauthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should return 400 for invalid request body', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
// Missing required fields
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Invalid request data')
|
||||
expect(responseData.details).toBeDefined()
|
||||
})
|
||||
|
||||
it('should handle new chat creation and forward to sim agent', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock successful chat creation
|
||||
const newChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
title: null,
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [],
|
||||
}
|
||||
mockReturning.mockResolvedValue([newChat])
|
||||
|
||||
// Mock successful sim agent response
|
||||
const mockReadableStream = new ReadableStream({
|
||||
start(controller) {
|
||||
const encoder = new TextEncoder()
|
||||
controller.enqueue(
|
||||
encoder.encode('data: {"type": "assistant_message", "content": "Hello response"}\\n\\n')
|
||||
)
|
||||
controller.close()
|
||||
},
|
||||
})
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: mockReadableStream,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
stream: true,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(mockInsert).toHaveBeenCalled()
|
||||
expect(mockValues).toHaveBeenCalledWith({
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
title: null,
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [],
|
||||
})
|
||||
|
||||
// Verify sim agent was called
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:8000/api/chat-completion-streaming',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-sim-agent-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
userId: 'user-123',
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
model: 'claude-4.5-sonnet',
|
||||
mode: 'agent',
|
||||
messageId: 'mock-uuid-1234-5678',
|
||||
version: '1.0.2',
|
||||
chatId: 'chat-123',
|
||||
}),
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should load existing chat and include conversation history', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock existing chat with history
|
||||
const existingChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
title: 'Existing Chat',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Previous message' },
|
||||
{ role: 'assistant', content: 'Previous response' },
|
||||
],
|
||||
}
|
||||
// For POST route, the select query uses limit not orderBy
|
||||
mockLimit.mockResolvedValue([existingChat])
|
||||
|
||||
// Mock sim agent response
|
||||
const mockReadableStream = new ReadableStream({
|
||||
start(controller) {
|
||||
controller.close()
|
||||
},
|
||||
})
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: mockReadableStream,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'New message',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
|
||||
// Verify conversation history was included
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:8000/api/chat-completion-streaming',
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
message: 'New message',
|
||||
workflowId: 'workflow-123',
|
||||
userId: 'user-123',
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
model: 'claude-4.5-sonnet',
|
||||
mode: 'agent',
|
||||
messageId: 'mock-uuid-1234-5678',
|
||||
version: '1.0.2',
|
||||
chatId: 'chat-123',
|
||||
}),
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should include implicit feedback in messages', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const newChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
messages: [],
|
||||
}
|
||||
mockReturning.mockResolvedValue([newChat])
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: new ReadableStream({
|
||||
start(controller) {
|
||||
controller.close()
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
implicitFeedback: 'User seems confused about the workflow',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
await POST(req)
|
||||
|
||||
// Verify implicit feedback was included
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:8000/api/chat-completion-streaming',
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
userId: 'user-123',
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
model: 'claude-4.5-sonnet',
|
||||
mode: 'agent',
|
||||
messageId: 'mock-uuid-1234-5678',
|
||||
version: '1.0.2',
|
||||
chatId: 'chat-123',
|
||||
}),
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle sim agent API errors', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
text: () => Promise.resolve('Internal server error'),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toContain('Sim agent API error')
|
||||
})
|
||||
|
||||
it('should handle database errors during chat creation', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database error
|
||||
mockReturning.mockRejectedValue(new Error('Database connection failed'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Database connection failed')
|
||||
})
|
||||
|
||||
it('should use ask mode when specified', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: new ReadableStream({
|
||||
start(controller) {
|
||||
controller.close()
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'What is this workflow?',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
mode: 'ask',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
await POST(req)
|
||||
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:8000/api/chat-completion-streaming',
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
message: 'What is this workflow?',
|
||||
workflowId: 'workflow-123',
|
||||
userId: 'user-123',
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
model: 'claude-4.5-sonnet',
|
||||
mode: 'ask',
|
||||
messageId: 'mock-uuid-1234-5678',
|
||||
version: '1.0.2',
|
||||
chatId: 'chat-123',
|
||||
}),
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('GET', () => {
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setUnauthenticated()
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat?workflowId=workflow-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should return 400 when workflowId is missing', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('workflowId is required')
|
||||
})
|
||||
|
||||
it('should return chats for authenticated user and workflow', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database response (what comes from DB)
|
||||
const mockDbChats = [
|
||||
{
|
||||
id: 'chat-1',
|
||||
title: 'First Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
{ role: 'user', content: 'Message 2' },
|
||||
{ role: 'assistant', content: 'Response 2' },
|
||||
],
|
||||
createdAt: new Date('2024-01-01'),
|
||||
updatedAt: new Date('2024-01-02'),
|
||||
},
|
||||
{
|
||||
id: 'chat-2',
|
||||
title: 'Second Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
],
|
||||
createdAt: new Date('2024-01-03'),
|
||||
updatedAt: new Date('2024-01-04'),
|
||||
},
|
||||
]
|
||||
|
||||
// Expected transformed response (what the route returns)
|
||||
const expectedChats = [
|
||||
{
|
||||
id: 'chat-1',
|
||||
title: 'First Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
{ role: 'user', content: 'Message 2' },
|
||||
{ role: 'assistant', content: 'Response 2' },
|
||||
],
|
||||
messageCount: 4,
|
||||
previewYaml: null,
|
||||
createdAt: new Date('2024-01-01'),
|
||||
updatedAt: new Date('2024-01-02'),
|
||||
},
|
||||
{
|
||||
id: 'chat-2',
|
||||
title: 'Second Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
],
|
||||
messageCount: 2,
|
||||
previewYaml: null,
|
||||
createdAt: new Date('2024-01-03'),
|
||||
updatedAt: new Date('2024-01-04'),
|
||||
},
|
||||
]
|
||||
|
||||
mockOrderBy.mockResolvedValue(mockDbChats)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat?workflowId=workflow-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
chats: [
|
||||
{
|
||||
id: 'chat-1',
|
||||
title: 'First Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
{ role: 'user', content: 'Message 2' },
|
||||
{ role: 'assistant', content: 'Response 2' },
|
||||
],
|
||||
messageCount: 4,
|
||||
previewYaml: null,
|
||||
config: null,
|
||||
planArtifact: null,
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-02T00:00:00.000Z',
|
||||
},
|
||||
{
|
||||
id: 'chat-2',
|
||||
title: 'Second Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
],
|
||||
messageCount: 2,
|
||||
previewYaml: null,
|
||||
config: null,
|
||||
planArtifact: null,
|
||||
createdAt: '2024-01-03T00:00:00.000Z',
|
||||
updatedAt: '2024-01-04T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
// Verify database query was made correctly
|
||||
expect(mockSelect).toHaveBeenCalled()
|
||||
expect(mockWhere).toHaveBeenCalled()
|
||||
expect(mockOrderBy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle database errors when fetching chats', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database error
|
||||
mockOrderBy.mockRejectedValue(new Error('Database query failed'))
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat?workflowId=workflow-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to fetch chats')
|
||||
})
|
||||
|
||||
it('should return empty array when no chats found', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
mockOrderBy.mockResolvedValue([])
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat?workflowId=workflow-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
chats: [],
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -14,11 +14,13 @@ import {
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/request-helpers'
|
||||
import { getCredentialsServerTool } from '@/lib/copilot/tools/server/user/get-credentials'
|
||||
import type { CopilotProviderConfig } from '@/lib/copilot/types'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { CopilotFiles } from '@/lib/uploads'
|
||||
import { createFileContent } from '@/lib/uploads/utils/file-utils'
|
||||
import { tools } from '@/tools/registry'
|
||||
|
||||
const logger = createLogger('CopilotChatAPI')
|
||||
|
||||
@@ -57,9 +59,10 @@ const ChatMessageSchema = z.object({
|
||||
'claude-4.5-sonnet',
|
||||
'claude-4.5-opus',
|
||||
'claude-4.1-opus',
|
||||
'gemini-3-pro',
|
||||
])
|
||||
.optional()
|
||||
.default('claude-4.5-sonnet'),
|
||||
.default('claude-4.5-opus'),
|
||||
mode: z.enum(['ask', 'agent', 'plan']).optional().default('agent'),
|
||||
prefetch: z.boolean().optional(),
|
||||
createNewChat: z.boolean().optional().default(false),
|
||||
@@ -313,6 +316,119 @@ export async function POST(req: NextRequest) {
|
||||
const effectiveConversationId =
|
||||
(currentChat?.conversationId as string | undefined) || conversationId
|
||||
|
||||
// For agent/build mode, fetch credentials and build tool definitions
|
||||
let integrationTools: any[] = []
|
||||
let baseTools: any[] = []
|
||||
let credentials: {
|
||||
oauth: Record<
|
||||
string,
|
||||
{ accessToken: string; accountId: string; name: string; expiresAt?: string }
|
||||
>
|
||||
apiKeys: string[]
|
||||
metadata?: {
|
||||
connectedOAuth: Array<{ provider: string; name: string; scopes?: string[] }>
|
||||
configuredApiKeys: string[]
|
||||
}
|
||||
} | null = null
|
||||
|
||||
if (mode === 'agent') {
|
||||
// Build base tools (executed locally, not deferred)
|
||||
// Include function_execute for code execution capability
|
||||
baseTools = [
|
||||
{
|
||||
name: 'function_execute',
|
||||
description:
|
||||
'Execute JavaScript code to perform calculations, data transformations, API calls, or any programmatic task. Code runs in a secure sandbox with fetch() available. Write plain statements (not wrapped in functions). Example: const res = await fetch(url); const data = await res.json(); return data;',
|
||||
input_schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
code: {
|
||||
type: 'string',
|
||||
description:
|
||||
'Raw JavaScript statements to execute. Code is auto-wrapped in async context. Use fetch() for HTTP requests. Write like: const res = await fetch(url); return await res.json();',
|
||||
},
|
||||
},
|
||||
required: ['code'],
|
||||
},
|
||||
executeLocally: true,
|
||||
},
|
||||
]
|
||||
// Fetch user credentials (OAuth + API keys)
|
||||
try {
|
||||
const rawCredentials = await getCredentialsServerTool.execute(
|
||||
{},
|
||||
{ userId: authenticatedUserId }
|
||||
)
|
||||
|
||||
// Transform OAuth credentials to map format: { [provider]: { accessToken, accountId, ... } }
|
||||
const oauthMap: Record<
|
||||
string,
|
||||
{ accessToken: string; accountId: string; name: string; expiresAt?: string }
|
||||
> = {}
|
||||
const connectedOAuth: Array<{ provider: string; name: string; scopes?: string[] }> = []
|
||||
for (const cred of rawCredentials?.oauth?.connected?.credentials || []) {
|
||||
if (cred.accessToken) {
|
||||
oauthMap[cred.provider] = {
|
||||
accessToken: cred.accessToken,
|
||||
accountId: cred.id,
|
||||
name: cred.name,
|
||||
}
|
||||
connectedOAuth.push({
|
||||
provider: cred.provider,
|
||||
name: cred.name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
credentials = {
|
||||
oauth: oauthMap,
|
||||
apiKeys: rawCredentials?.environment?.variableNames || [],
|
||||
metadata: {
|
||||
connectedOAuth,
|
||||
configuredApiKeys: rawCredentials?.environment?.variableNames || [],
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(`[${tracker.requestId}] Fetched credentials for build mode`, {
|
||||
oauthProviders: Object.keys(oauthMap),
|
||||
apiKeyCount: credentials.apiKeys.length,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.warn(`[${tracker.requestId}] Failed to fetch credentials`, {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
}
|
||||
|
||||
// Build tool definitions (schemas only)
|
||||
try {
|
||||
const { createUserToolSchema } = await import('@/tools/params')
|
||||
|
||||
integrationTools = Object.entries(tools).map(([toolId, toolConfig]) => {
|
||||
const userSchema = createUserToolSchema(toolConfig)
|
||||
return {
|
||||
name: toolId,
|
||||
description: toolConfig.description || toolConfig.name || toolId,
|
||||
input_schema: userSchema,
|
||||
defer_loading: true, // Anthropic Advanced Tool Use
|
||||
...(toolConfig.oauth?.required && {
|
||||
oauth: {
|
||||
required: true,
|
||||
provider: toolConfig.oauth.provider,
|
||||
},
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(`[${tracker.requestId}] Built tool definitions for build mode`, {
|
||||
integrationToolCount: integrationTools.length,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.warn(`[${tracker.requestId}] Failed to build tool definitions`, {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const requestPayload = {
|
||||
message: message, // Just send the current user message text
|
||||
workflowId,
|
||||
@@ -330,6 +446,10 @@ export async function POST(req: NextRequest) {
|
||||
...(agentContexts.length > 0 && { context: agentContexts }),
|
||||
...(actualChatId ? { chatId: actualChatId } : {}),
|
||||
...(processedFileContents.length > 0 && { fileAttachments: processedFileContents }),
|
||||
// For build/agent mode, include tools and credentials
|
||||
...(integrationTools.length > 0 && { tools: integrationTools }),
|
||||
...(baseTools.length > 0 && { baseTools }),
|
||||
...(credentials && { credentials }),
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -339,6 +459,12 @@ export async function POST(req: NextRequest) {
|
||||
hasConversationId: !!effectiveConversationId,
|
||||
hasFileAttachments: processedFileContents.length > 0,
|
||||
messageLength: message.length,
|
||||
mode,
|
||||
hasTools: integrationTools.length > 0,
|
||||
toolCount: integrationTools.length,
|
||||
hasBaseTools: baseTools.length > 0,
|
||||
baseToolCount: baseTools.length,
|
||||
hasCredentials: !!credentials,
|
||||
})
|
||||
} catch {}
|
||||
|
||||
|
||||
275
apps/sim/app/api/copilot/execute-tool/route.ts
Normal file
275
apps/sim/app/api/copilot/execute-tool/route.ts
Normal file
@@ -0,0 +1,275 @@
|
||||
import { db } from '@sim/db'
|
||||
import { account, workflow } from '@sim/db/schema'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import {
|
||||
createBadRequestResponse,
|
||||
createInternalServerErrorResponse,
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/request-helpers'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { getEffectiveDecryptedEnv } from '@/lib/environment/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
import { getTool } from '@/tools/utils'
|
||||
|
||||
const logger = createLogger('CopilotExecuteToolAPI')
|
||||
|
||||
const ExecuteToolSchema = z.object({
|
||||
toolCallId: z.string(),
|
||||
toolName: z.string(),
|
||||
arguments: z.record(z.any()).optional().default({}),
|
||||
workflowId: z.string().optional(),
|
||||
})
|
||||
|
||||
/**
|
||||
* Resolves all {{ENV_VAR}} references in a value recursively
|
||||
* Works with strings, arrays, and objects
|
||||
*/
|
||||
function resolveEnvVarReferences(value: any, envVars: Record<string, string>): any {
|
||||
if (typeof value === 'string') {
|
||||
// Check for exact match: entire string is "{{VAR_NAME}}"
|
||||
const exactMatch = /^\{\{([^}]+)\}\}$/.exec(value)
|
||||
if (exactMatch) {
|
||||
const envVarName = exactMatch[1].trim()
|
||||
return envVars[envVarName] ?? value
|
||||
}
|
||||
|
||||
// Check for embedded references: "prefix {{VAR}} suffix"
|
||||
return value.replace(/\{\{([^}]+)\}\}/g, (match, varName) => {
|
||||
const trimmedName = varName.trim()
|
||||
return envVars[trimmedName] ?? match
|
||||
})
|
||||
}
|
||||
|
||||
if (Array.isArray(value)) {
|
||||
return value.map((item) => resolveEnvVarReferences(item, envVars))
|
||||
}
|
||||
|
||||
if (value !== null && typeof value === 'object') {
|
||||
const resolved: Record<string, any> = {}
|
||||
for (const [key, val] of Object.entries(value)) {
|
||||
resolved[key] = resolveEnvVarReferences(val, envVars)
|
||||
}
|
||||
return resolved
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
export async function POST(req: NextRequest) {
|
||||
const tracker = createRequestTracker()
|
||||
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const userId = session.user.id
|
||||
const body = await req.json()
|
||||
|
||||
try {
|
||||
const preview = JSON.stringify(body).slice(0, 300)
|
||||
logger.debug(`[${tracker.requestId}] Incoming execute-tool request`, { preview })
|
||||
} catch {}
|
||||
|
||||
const { toolCallId, toolName, arguments: toolArgs, workflowId } = ExecuteToolSchema.parse(body)
|
||||
|
||||
logger.info(`[${tracker.requestId}] Executing tool`, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
workflowId,
|
||||
hasArgs: Object.keys(toolArgs).length > 0,
|
||||
})
|
||||
|
||||
// Get tool config from registry
|
||||
const toolConfig = getTool(toolName)
|
||||
if (!toolConfig) {
|
||||
// Find similar tool names to help debug
|
||||
const { tools: allTools } = await import('@/tools/registry')
|
||||
const allToolNames = Object.keys(allTools)
|
||||
const prefix = toolName.split('_').slice(0, 2).join('_')
|
||||
const similarTools = allToolNames
|
||||
.filter((name) => name.startsWith(`${prefix.split('_')[0]}_`))
|
||||
.slice(0, 10)
|
||||
|
||||
logger.warn(`[${tracker.requestId}] Tool not found in registry`, {
|
||||
toolName,
|
||||
prefix,
|
||||
similarTools,
|
||||
totalToolsInRegistry: allToolNames.length,
|
||||
})
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: `Tool not found: ${toolName}. Similar tools: ${similarTools.join(', ')}`,
|
||||
toolCallId,
|
||||
},
|
||||
{ status: 404 }
|
||||
)
|
||||
}
|
||||
|
||||
// Get the workspaceId from the workflow (env vars are stored at workspace level)
|
||||
let workspaceId: string | undefined
|
||||
if (workflowId) {
|
||||
const workflowResult = await db
|
||||
.select({ workspaceId: workflow.workspaceId })
|
||||
.from(workflow)
|
||||
.where(eq(workflow.id, workflowId))
|
||||
.limit(1)
|
||||
workspaceId = workflowResult[0]?.workspaceId ?? undefined
|
||||
}
|
||||
|
||||
// Get decrypted environment variables early so we can resolve all {{VAR}} references
|
||||
const decryptedEnvVars = await getEffectiveDecryptedEnv(userId, workspaceId)
|
||||
|
||||
logger.info(`[${tracker.requestId}] Fetched environment variables`, {
|
||||
workflowId,
|
||||
workspaceId,
|
||||
envVarCount: Object.keys(decryptedEnvVars).length,
|
||||
envVarKeys: Object.keys(decryptedEnvVars),
|
||||
})
|
||||
|
||||
// Build execution params starting with LLM-provided arguments
|
||||
// Resolve all {{ENV_VAR}} references in the arguments
|
||||
const executionParams: Record<string, any> = resolveEnvVarReferences(toolArgs, decryptedEnvVars)
|
||||
|
||||
logger.info(`[${tracker.requestId}] Resolved env var references in arguments`, {
|
||||
toolName,
|
||||
originalArgKeys: Object.keys(toolArgs),
|
||||
resolvedArgKeys: Object.keys(executionParams),
|
||||
})
|
||||
|
||||
// Resolve OAuth access token if required
|
||||
if (toolConfig.oauth?.required && toolConfig.oauth.provider) {
|
||||
const provider = toolConfig.oauth.provider
|
||||
logger.info(`[${tracker.requestId}] Resolving OAuth token`, { provider })
|
||||
|
||||
try {
|
||||
// Find the account for this provider and user
|
||||
const accounts = await db
|
||||
.select()
|
||||
.from(account)
|
||||
.where(and(eq(account.providerId, provider), eq(account.userId, userId)))
|
||||
.limit(1)
|
||||
|
||||
if (accounts.length > 0) {
|
||||
const acc = accounts[0]
|
||||
const requestId = generateRequestId()
|
||||
const { accessToken } = await refreshTokenIfNeeded(requestId, acc as any, acc.id)
|
||||
|
||||
if (accessToken) {
|
||||
executionParams.accessToken = accessToken
|
||||
logger.info(`[${tracker.requestId}] OAuth token resolved`, { provider })
|
||||
} else {
|
||||
logger.warn(`[${tracker.requestId}] No access token available`, { provider })
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: `OAuth token not available for ${provider}. Please reconnect your account.`,
|
||||
toolCallId,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
} else {
|
||||
logger.warn(`[${tracker.requestId}] No account found for provider`, { provider })
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: `No ${provider} account connected. Please connect your account first.`,
|
||||
toolCallId,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`[${tracker.requestId}] Failed to resolve OAuth token`, {
|
||||
provider,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: `Failed to get OAuth token for ${provider}`,
|
||||
toolCallId,
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if tool requires an API key that wasn't resolved via {{ENV_VAR}} reference
|
||||
const needsApiKey = toolConfig.params?.apiKey?.required
|
||||
|
||||
if (needsApiKey && !executionParams.apiKey) {
|
||||
logger.warn(`[${tracker.requestId}] No API key found for tool`, { toolName })
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: `API key not provided for ${toolName}. Use {{YOUR_API_KEY_ENV_VAR}} to reference your environment variable.`,
|
||||
toolCallId,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
// Add execution context
|
||||
executionParams._context = {
|
||||
workflowId,
|
||||
userId,
|
||||
}
|
||||
|
||||
// Special handling for function_execute - inject environment variables
|
||||
if (toolName === 'function_execute') {
|
||||
executionParams.envVars = decryptedEnvVars
|
||||
executionParams.workflowVariables = {} // No workflow variables in copilot context
|
||||
executionParams.blockData = {} // No block data in copilot context
|
||||
executionParams.blockNameMapping = {} // No block mapping in copilot context
|
||||
executionParams.language = executionParams.language || 'javascript'
|
||||
executionParams.timeout = executionParams.timeout || 30000
|
||||
|
||||
logger.info(`[${tracker.requestId}] Injected env vars for function_execute`, {
|
||||
envVarCount: Object.keys(decryptedEnvVars).length,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
logger.info(`[${tracker.requestId}] Executing tool with resolved credentials`, {
|
||||
toolName,
|
||||
hasAccessToken: !!executionParams.accessToken,
|
||||
hasApiKey: !!executionParams.apiKey,
|
||||
})
|
||||
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
|
||||
logger.info(`[${tracker.requestId}] Tool execution complete`, {
|
||||
toolName,
|
||||
success: result.success,
|
||||
hasOutput: !!result.output,
|
||||
})
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
toolCallId,
|
||||
result: {
|
||||
success: result.success,
|
||||
output: result.output,
|
||||
error: result.error,
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.debug(`[${tracker.requestId}] Zod validation error`, { issues: error.issues })
|
||||
return createBadRequestResponse('Invalid request body for execute-tool')
|
||||
}
|
||||
logger.error(`[${tracker.requestId}] Failed to execute tool:`, error)
|
||||
const errorMessage = error instanceof Error ? error.message : 'Failed to execute tool'
|
||||
return createInternalServerErrorResponse(errorMessage)
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,7 @@ const DEFAULT_ENABLED_MODELS: Record<string, boolean> = {
|
||||
'claude-4.5-sonnet': true,
|
||||
'claude-4.5-opus': true,
|
||||
// 'claude-4.1-opus': true,
|
||||
'gemini-3-pro': true,
|
||||
}
|
||||
|
||||
// GET - Fetch user's enabled models
|
||||
|
||||
@@ -965,7 +965,7 @@ The system will substitute actual values when these placeholders are used, keepi
|
||||
instruction:
|
||||
'Extract the requested information from this page according to the schema',
|
||||
schema: zodSchema,
|
||||
})
|
||||
} as any)
|
||||
|
||||
logger.info('Successfully extracted structured data as fallback', {
|
||||
keys: structuredOutput ? Object.keys(structuredOutput) : [],
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
'use client'
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { AlertCircle } from 'lucide-react'
|
||||
import mermaid from 'mermaid'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('MermaidDiagram')
|
||||
|
||||
interface MermaidDiagramProps {
|
||||
diagramText: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Renders mermaid diagrams with pan/zoom support
|
||||
*/
|
||||
export function MermaidDiagram({ diagramText }: MermaidDiagramProps) {
|
||||
const [dataUrl, setDataUrl] = useState<string | null>(null)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [zoom, setZoom] = useState(0.6)
|
||||
const [pan, setPan] = useState({ x: 0, y: 0 })
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const dragStart = useRef({ x: 0, y: 0, panX: 0, panY: 0 })
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const renderDiagram = useCallback(async () => {
|
||||
if (!diagramText?.trim()) {
|
||||
setError('No diagram text provided')
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
setError(null)
|
||||
mermaid.initialize({
|
||||
startOnLoad: false,
|
||||
theme: 'base',
|
||||
securityLevel: 'loose',
|
||||
fontFamily: 'system-ui, -apple-system, sans-serif',
|
||||
fontSize: 14,
|
||||
flowchart: {
|
||||
useMaxWidth: false,
|
||||
htmlLabels: true,
|
||||
padding: 20,
|
||||
nodeSpacing: 50,
|
||||
rankSpacing: 60,
|
||||
},
|
||||
themeVariables: {
|
||||
primaryColor: '#dbeafe',
|
||||
primaryTextColor: '#1e3a5f',
|
||||
primaryBorderColor: '#3b82f6',
|
||||
lineColor: '#64748b',
|
||||
secondaryColor: '#fef3c7',
|
||||
secondaryTextColor: '#92400e',
|
||||
secondaryBorderColor: '#f59e0b',
|
||||
tertiaryColor: '#d1fae5',
|
||||
tertiaryTextColor: '#065f46',
|
||||
tertiaryBorderColor: '#10b981',
|
||||
background: '#ffffff',
|
||||
mainBkg: '#dbeafe',
|
||||
nodeBorder: '#3b82f6',
|
||||
nodeTextColor: '#1e3a5f',
|
||||
clusterBkg: '#f1f5f9',
|
||||
clusterBorder: '#94a3b8',
|
||||
titleColor: '#0f172a',
|
||||
textColor: '#334155',
|
||||
edgeLabelBackground: '#ffffff',
|
||||
},
|
||||
})
|
||||
|
||||
const id = `mermaid-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`
|
||||
// Replace \n with <br> for proper line breaks in labels
|
||||
const processedText = diagramText.trim().replace(/\\n/g, '<br>')
|
||||
const { svg } = await mermaid.render(id, processedText)
|
||||
const encoded = btoa(unescape(encodeURIComponent(svg)))
|
||||
setDataUrl(`data:image/svg+xml;base64,${encoded}`)
|
||||
} catch (err) {
|
||||
const msg = err instanceof Error ? err.message : 'Failed to render diagram'
|
||||
logger.error('Mermaid render error', { error: msg })
|
||||
setError(msg)
|
||||
}
|
||||
}, [diagramText])
|
||||
|
||||
useEffect(() => {
|
||||
renderDiagram()
|
||||
}, [renderDiagram])
|
||||
|
||||
const handleWheel = useCallback((e: React.WheelEvent) => {
|
||||
e.preventDefault()
|
||||
const delta = e.deltaY > 0 ? 0.9 : 1.1
|
||||
setZoom((z) => Math.min(Math.max(z * delta, 0.1), 3))
|
||||
}, [])
|
||||
|
||||
const handleMouseDown = useCallback(
|
||||
(e: React.MouseEvent) => {
|
||||
e.preventDefault()
|
||||
setIsDragging(true)
|
||||
dragStart.current = { x: e.clientX, y: e.clientY, panX: pan.x, panY: pan.y }
|
||||
},
|
||||
[pan]
|
||||
)
|
||||
|
||||
const handleMouseMove = useCallback(
|
||||
(e: React.MouseEvent) => {
|
||||
if (!isDragging) return
|
||||
setPan({
|
||||
x: dragStart.current.panX + (e.clientX - dragStart.current.x),
|
||||
y: dragStart.current.panY + (e.clientY - dragStart.current.y),
|
||||
})
|
||||
},
|
||||
[isDragging]
|
||||
)
|
||||
|
||||
const handleMouseUp = useCallback(() => setIsDragging(false), [])
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className='flex items-center gap-2 rounded-md border border-red-500/30 bg-red-500/10 p-3 text-red-400 text-sm'>
|
||||
<AlertCircle className='h-4 w-4 flex-shrink-0' />
|
||||
<span>{error}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (!dataUrl) {
|
||||
return (
|
||||
<div className='flex h-24 items-center justify-center text-[var(--text-tertiary)] text-sm'>
|
||||
Rendering...
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={containerRef}
|
||||
className='select-none overflow-hidden rounded-md border border-[var(--border-strong)] bg-white'
|
||||
style={{
|
||||
height: 500,
|
||||
minHeight: 150,
|
||||
resize: 'vertical',
|
||||
cursor: isDragging ? 'grabbing' : 'grab',
|
||||
}}
|
||||
onWheel={handleWheel}
|
||||
onMouseDown={handleMouseDown}
|
||||
onMouseMove={handleMouseMove}
|
||||
onMouseUp={handleMouseUp}
|
||||
onMouseLeave={handleMouseUp}
|
||||
title='Scroll to zoom, drag to pan, drag edge to resize'
|
||||
>
|
||||
<img
|
||||
src={dataUrl}
|
||||
alt='Mermaid diagram'
|
||||
className='pointer-events-none h-full w-full object-contain'
|
||||
style={{ transform: `translate(${pan.x}px, ${pan.y}px) scale(${zoom})` }}
|
||||
draggable={false}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { Loader2 } from 'lucide-react'
|
||||
import useDrivePicker from 'react-google-drive-picker'
|
||||
import { Button } from '@/components/emcn'
|
||||
import { Button, Code } from '@/components/emcn'
|
||||
import { GoogleDriveIcon } from '@/components/icons'
|
||||
import { ClientToolCallState } from '@/lib/copilot/tools/client/base-tool'
|
||||
import { getClientTool } from '@/lib/copilot/tools/client/manager'
|
||||
@@ -11,6 +11,7 @@ import { getRegisteredTools } from '@/lib/copilot/tools/client/registry'
|
||||
import { getEnv } from '@/lib/core/config/env'
|
||||
import { CLASS_TOOL_METADATA, useCopilotStore } from '@/stores/panel/copilot/store'
|
||||
import type { CopilotToolCall } from '@/stores/panel/copilot/types'
|
||||
import { MermaidDiagram } from '../mermaid-diagram/mermaid-diagram'
|
||||
|
||||
interface ToolCallProps {
|
||||
toolCall?: CopilotToolCall
|
||||
@@ -100,6 +101,10 @@ const ACTION_VERBS = [
|
||||
'Create',
|
||||
'Creating',
|
||||
'Created',
|
||||
'Generating',
|
||||
'Generated',
|
||||
'Rendering',
|
||||
'Rendered',
|
||||
] as const
|
||||
|
||||
/**
|
||||
@@ -295,7 +300,43 @@ function getDisplayName(toolCall: CopilotToolCall): string {
|
||||
const byState = def?.metadata?.displayNames?.[toolCall.state]
|
||||
if (byState?.text) return byState.text
|
||||
} catch {}
|
||||
return toolCall.name
|
||||
|
||||
// For integration tools, format the tool name nicely
|
||||
// e.g., "google_calendar_list_events" -> "Running Google Calendar List Events"
|
||||
const stateVerb = getStateVerb(toolCall.state)
|
||||
const formattedName = formatToolName(toolCall.name)
|
||||
return `${stateVerb} ${formattedName}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Get verb prefix based on tool state
|
||||
*/
|
||||
function getStateVerb(state: string): string {
|
||||
switch (state) {
|
||||
case 'pending':
|
||||
case 'executing':
|
||||
return 'Running'
|
||||
case 'success':
|
||||
return 'Ran'
|
||||
case 'error':
|
||||
return 'Failed'
|
||||
case 'rejected':
|
||||
case 'aborted':
|
||||
return 'Skipped'
|
||||
default:
|
||||
return 'Running'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Format tool name for display
|
||||
* e.g., "google_calendar_list_events" -> "Google Calendar List Events"
|
||||
*/
|
||||
function formatToolName(name: string): string {
|
||||
return name
|
||||
.split('_')
|
||||
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join(' ')
|
||||
}
|
||||
|
||||
function RunSkipButtons({
|
||||
@@ -479,12 +520,19 @@ export function ToolCall({ toolCall: toolCallProp, toolCallId, onStateChange }:
|
||||
}
|
||||
}, [params])
|
||||
|
||||
// Skip rendering tools that are not in the registry or are explicitly omitted
|
||||
try {
|
||||
if (toolCall.name === 'checkoff_todo' || toolCall.name === 'mark_todo_in_progress') return null
|
||||
// Allow if tool id exists in CLASS_TOOL_METADATA (client tools)
|
||||
if (!CLASS_TOOL_METADATA[toolCall.name]) return null
|
||||
} catch {
|
||||
// Skip rendering some internal tools
|
||||
if (toolCall.name === 'checkoff_todo' || toolCall.name === 'mark_todo_in_progress') return null
|
||||
|
||||
// Get current mode from store to determine if we should render integration tools
|
||||
const mode = useCopilotStore.getState().mode
|
||||
|
||||
// Allow rendering if:
|
||||
// 1. Tool is in CLASS_TOOL_METADATA (client tools), OR
|
||||
// 2. We're in build mode (integration tools are executed server-side)
|
||||
const isClientTool = !!CLASS_TOOL_METADATA[toolCall.name]
|
||||
const isIntegrationToolInBuildMode = mode === 'build' && !isClientTool
|
||||
|
||||
if (!isClientTool && !isIntegrationToolInBuildMode) {
|
||||
return null
|
||||
}
|
||||
const isExpandableTool =
|
||||
@@ -874,6 +922,63 @@ export function ToolCall({ toolCall: toolCallProp, toolCallId, onStateChange }:
|
||||
)
|
||||
}
|
||||
|
||||
// Special rendering for function_execute - show code block
|
||||
if (toolCall.name === 'function_execute') {
|
||||
const code = params.code || ''
|
||||
|
||||
return (
|
||||
<div className='w-full'>
|
||||
<ShimmerOverlayText
|
||||
text={displayName}
|
||||
active={isLoadingState}
|
||||
isSpecial={false}
|
||||
className='font-[470] font-season text-[#939393] text-sm dark:text-[#939393]'
|
||||
/>
|
||||
{code && (
|
||||
<div className='mt-2'>
|
||||
<Code.Viewer code={code} language='javascript' showGutter />
|
||||
</div>
|
||||
)}
|
||||
{showButtons && (
|
||||
<RunSkipButtons
|
||||
toolCall={toolCall}
|
||||
onStateChange={handleStateChange}
|
||||
editedParams={editedParams}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Special rendering for generate_diagram - show mermaid diagram
|
||||
if (toolCall.name === 'generate_diagram') {
|
||||
const diagramText = params.diagramText || ''
|
||||
const language = params.language || 'mermaid'
|
||||
|
||||
return (
|
||||
<div className='w-full'>
|
||||
<ShimmerOverlayText
|
||||
text={displayName}
|
||||
active={isLoadingState}
|
||||
isSpecial={false}
|
||||
className='font-[470] font-season text-[#939393] text-sm dark:text-[#939393]'
|
||||
/>
|
||||
{diagramText && language === 'mermaid' && (
|
||||
<div className='mt-2'>
|
||||
<MermaidDiagram diagramText={diagramText} />
|
||||
</div>
|
||||
)}
|
||||
{showButtons && (
|
||||
<RunSkipButtons
|
||||
toolCall={toolCall}
|
||||
onStateChange={handleStateChange}
|
||||
editedParams={editedParams}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className='w-full'>
|
||||
<div
|
||||
|
||||
@@ -32,6 +32,13 @@ function getModelIconComponent(modelValue: string) {
|
||||
return <IconComponent className='h-3.5 w-3.5' />
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a model should display the MAX badge
|
||||
*/
|
||||
function isMaxModel(modelValue: string): boolean {
|
||||
return modelValue === 'claude-4.5-sonnet' || modelValue === 'claude-4.5-opus'
|
||||
}
|
||||
|
||||
/**
|
||||
* Model selector dropdown for choosing AI model.
|
||||
* Displays model icon and label.
|
||||
@@ -132,6 +139,11 @@ export function ModelSelector({ selectedModel, isNearTop, onModelSelect }: Model
|
||||
>
|
||||
{getModelIconComponent(option.value)}
|
||||
<span>{option.label}</span>
|
||||
{isMaxModel(option.value) && (
|
||||
<Badge variant='default' className='ml-auto px-[6px] py-[1px] text-[10px]'>
|
||||
MAX
|
||||
</Badge>
|
||||
)}
|
||||
</PopoverItem>
|
||||
))}
|
||||
</PopoverScrollArea>
|
||||
|
||||
@@ -20,23 +20,24 @@ export const MENTION_OPTIONS = [
|
||||
* Model configuration options
|
||||
*/
|
||||
export const MODEL_OPTIONS = [
|
||||
// { value: 'claude-4-sonnet', label: 'Claude 4 Sonnet' },
|
||||
{ value: 'claude-4.5-sonnet', label: 'Claude 4.5 Sonnet' },
|
||||
{ value: 'claude-4.5-haiku', label: 'Claude 4.5 Haiku' },
|
||||
{ value: 'claude-4.5-opus', label: 'Claude 4.5 Opus' },
|
||||
{ value: 'claude-4.5-sonnet', label: 'Claude 4.5 Sonnet' },
|
||||
// { value: 'claude-4-sonnet', label: 'Claude 4 Sonnet' },
|
||||
{ value: 'claude-4.5-haiku', label: 'Claude 4.5 Haiku' },
|
||||
// { value: 'claude-4.1-opus', label: 'Claude 4.1 Opus' },
|
||||
{ value: 'gpt-5.1-codex', label: 'GPT 5.1 Codex' },
|
||||
// { value: 'gpt-5-codex', label: 'GPT 5 Codex' },
|
||||
{ value: 'gpt-5.1-medium', label: 'GPT 5.1 Medium' },
|
||||
// { value: 'gpt-5-fast', label: 'GPT 5 Fast' },
|
||||
// { value: 'gpt-5', label: 'GPT 5' },
|
||||
// { value: 'gpt-5.1-fast', label: 'GPT 5.1 Fast' },
|
||||
// { value: 'gpt-5.1', label: 'GPT 5.1' },
|
||||
{ value: 'gpt-5.1-medium', label: 'GPT 5.1 Medium' },
|
||||
// { value: 'gpt-5.1-high', label: 'GPT 5.1 High' },
|
||||
// { value: 'gpt-5-codex', label: 'GPT 5 Codex' },
|
||||
{ value: 'gpt-5.1-codex', label: 'GPT 5.1 Codex' },
|
||||
// { value: 'gpt-5-high', label: 'GPT 5 High' },
|
||||
// { value: 'gpt-4o', label: 'GPT 4o' },
|
||||
// { value: 'gpt-4.1', label: 'GPT 4.1' },
|
||||
{ value: 'o3', label: 'o3' },
|
||||
// { value: 'o3', label: 'o3' },
|
||||
{ value: 'gemini-3-pro', label: 'Gemini 3 Pro' },
|
||||
] as const
|
||||
|
||||
/**
|
||||
|
||||
@@ -59,6 +59,15 @@ interface UserInputProps {
|
||||
panelWidth?: number
|
||||
clearOnSubmit?: boolean
|
||||
hasPlanArtifact?: boolean
|
||||
/** Override workflowId from store (for use outside copilot context) */
|
||||
workflowIdOverride?: string | null
|
||||
/** Override selectedModel from store (for use outside copilot context) */
|
||||
selectedModelOverride?: string
|
||||
/** Override setSelectedModel from store (for use outside copilot context) */
|
||||
onModelChangeOverride?: (model: string) => void
|
||||
hideModeSelector?: boolean
|
||||
/** Disable @mention functionality */
|
||||
disableMentions?: boolean
|
||||
}
|
||||
|
||||
interface UserInputRef {
|
||||
@@ -90,6 +99,11 @@ const UserInput = forwardRef<UserInputRef, UserInputProps>(
|
||||
panelWidth = 308,
|
||||
clearOnSubmit = true,
|
||||
hasPlanArtifact = false,
|
||||
workflowIdOverride,
|
||||
selectedModelOverride,
|
||||
onModelChangeOverride,
|
||||
hideModeSelector = false,
|
||||
disableMentions = false,
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
@@ -98,8 +112,13 @@ const UserInput = forwardRef<UserInputRef, UserInputProps>(
|
||||
const params = useParams()
|
||||
const workspaceId = params.workspaceId as string
|
||||
|
||||
// Store hooks
|
||||
const { workflowId, selectedModel, setSelectedModel, contextUsage } = useCopilotStore()
|
||||
const copilotStore = useCopilotStore()
|
||||
const workflowId =
|
||||
workflowIdOverride !== undefined ? workflowIdOverride : copilotStore.workflowId
|
||||
const selectedModel =
|
||||
selectedModelOverride !== undefined ? selectedModelOverride : copilotStore.selectedModel
|
||||
const setSelectedModel = onModelChangeOverride || copilotStore.setSelectedModel
|
||||
const contextUsage = copilotStore.contextUsage
|
||||
|
||||
// Internal state
|
||||
const [internalMessage, setInternalMessage] = useState('')
|
||||
@@ -459,6 +478,9 @@ const UserInput = forwardRef<UserInputRef, UserInputProps>(
|
||||
const newValue = e.target.value
|
||||
setMessage(newValue)
|
||||
|
||||
// Skip mention menu logic if mentions are disabled
|
||||
if (disableMentions) return
|
||||
|
||||
const caret = e.target.selectionStart ?? newValue.length
|
||||
const active = mentionMenu.getActiveMentionQueryAtPosition(caret, newValue)
|
||||
|
||||
@@ -477,7 +499,7 @@ const UserInput = forwardRef<UserInputRef, UserInputProps>(
|
||||
mentionMenu.setSubmenuQueryStart(null)
|
||||
}
|
||||
},
|
||||
[setMessage, mentionMenu]
|
||||
[setMessage, mentionMenu, disableMentions]
|
||||
)
|
||||
|
||||
const handleSelectAdjust = useCallback(() => {
|
||||
@@ -608,32 +630,27 @@ const UserInput = forwardRef<UserInputRef, UserInputProps>(
|
||||
{/* Top Row: Context controls + Build Workflow button */}
|
||||
<div className='mb-[6px] flex flex-wrap items-center justify-between gap-[6px]'>
|
||||
<div className='flex flex-wrap items-center gap-[6px]'>
|
||||
<Badge
|
||||
variant='outline'
|
||||
onClick={handleOpenMentionMenuWithAt}
|
||||
title='Insert @'
|
||||
className={cn(
|
||||
'cursor-pointer rounded-[6px] p-[4.5px]',
|
||||
(disabled || isLoading) && 'cursor-not-allowed'
|
||||
)}
|
||||
>
|
||||
<AtSign className='h-3 w-3' strokeWidth={1.75} />
|
||||
</Badge>
|
||||
{!disableMentions && (
|
||||
<>
|
||||
<Badge
|
||||
variant='outline'
|
||||
onClick={handleOpenMentionMenuWithAt}
|
||||
title='Insert @'
|
||||
className={cn(
|
||||
'cursor-pointer rounded-[6px] p-[4.5px]',
|
||||
(disabled || isLoading) && 'cursor-not-allowed'
|
||||
)}
|
||||
>
|
||||
<AtSign className='h-3 w-3' strokeWidth={1.75} />
|
||||
</Badge>
|
||||
|
||||
{/* Context Usage Indicator */}
|
||||
{/* {contextUsage && contextUsage.percentage > 0 && (
|
||||
<ContextUsageIndicator
|
||||
percentage={contextUsage.percentage}
|
||||
size={18}
|
||||
strokeWidth={2.5}
|
||||
/>
|
||||
)} */}
|
||||
|
||||
{/* Selected Context Pills */}
|
||||
<ContextPills
|
||||
contexts={contextManagement.selectedContexts}
|
||||
onRemoveContext={contextManagement.removeContext}
|
||||
/>
|
||||
{/* Selected Context Pills */}
|
||||
<ContextPills
|
||||
contexts={contextManagement.selectedContexts}
|
||||
onRemoveContext={contextManagement.removeContext}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{hasPlanArtifact && (
|
||||
@@ -690,7 +707,8 @@ const UserInput = forwardRef<UserInputRef, UserInputProps>(
|
||||
/>
|
||||
|
||||
{/* Mention Menu Portal */}
|
||||
{mentionMenu.showMentionMenu &&
|
||||
{!disableMentions &&
|
||||
mentionMenu.showMentionMenu &&
|
||||
createPortal(
|
||||
<MentionMenu
|
||||
mentionMenu={mentionMenu}
|
||||
@@ -706,12 +724,14 @@ const UserInput = forwardRef<UserInputRef, UserInputProps>(
|
||||
<div className='flex items-center justify-between gap-2'>
|
||||
{/* Left side: Mode Selector + Model Selector */}
|
||||
<div className='flex min-w-0 flex-1 items-center gap-[8px]'>
|
||||
<ModeSelector
|
||||
mode={mode}
|
||||
onModeChange={onModeChange}
|
||||
isNearTop={isNearTop}
|
||||
disabled={disabled}
|
||||
/>
|
||||
{!hideModeSelector && (
|
||||
<ModeSelector
|
||||
mode={mode}
|
||||
onModeChange={onModeChange}
|
||||
isNearTop={isNearTop}
|
||||
disabled={disabled}
|
||||
/>
|
||||
)}
|
||||
|
||||
<ModelSelector
|
||||
selectedModel={selectedModel}
|
||||
|
||||
@@ -107,6 +107,8 @@ export const Copilot = forwardRef<CopilotRef, CopilotProps>(({ panelWidth }, ref
|
||||
setPlanTodos,
|
||||
clearPlanArtifact,
|
||||
savePlanArtifact,
|
||||
setSelectedModel,
|
||||
loadAutoAllowedTools,
|
||||
} = useCopilotStore()
|
||||
|
||||
// Initialize copilot
|
||||
@@ -117,6 +119,7 @@ export const Copilot = forwardRef<CopilotRef, CopilotProps>(({ panelWidth }, ref
|
||||
setCopilotWorkflowId,
|
||||
loadChats,
|
||||
fetchContextUsage,
|
||||
loadAutoAllowedTools,
|
||||
currentChat,
|
||||
isSendingMessage,
|
||||
})
|
||||
|
||||
@@ -12,6 +12,7 @@ interface UseCopilotInitializationProps {
|
||||
setCopilotWorkflowId: (workflowId: string | null) => Promise<void>
|
||||
loadChats: (forceRefresh?: boolean) => Promise<void>
|
||||
fetchContextUsage: () => Promise<void>
|
||||
loadAutoAllowedTools: () => Promise<void>
|
||||
currentChat: any
|
||||
isSendingMessage: boolean
|
||||
}
|
||||
@@ -30,6 +31,7 @@ export function useCopilotInitialization(props: UseCopilotInitializationProps) {
|
||||
setCopilotWorkflowId,
|
||||
loadChats,
|
||||
fetchContextUsage,
|
||||
loadAutoAllowedTools,
|
||||
currentChat,
|
||||
isSendingMessage,
|
||||
} = props
|
||||
@@ -112,6 +114,19 @@ export function useCopilotInitialization(props: UseCopilotInitializationProps) {
|
||||
}
|
||||
}, [isInitialized, currentChat?.id, activeWorkflowId, fetchContextUsage])
|
||||
|
||||
/**
|
||||
* Load auto-allowed tools once on mount
|
||||
*/
|
||||
const hasLoadedAutoAllowedToolsRef = useRef(false)
|
||||
useEffect(() => {
|
||||
if (hasMountedRef.current && !hasLoadedAutoAllowedToolsRef.current) {
|
||||
hasLoadedAutoAllowedToolsRef.current = true
|
||||
loadAutoAllowedTools().catch((err) => {
|
||||
logger.warn('[Copilot] Failed to load auto-allowed tools', err)
|
||||
})
|
||||
}
|
||||
}, [loadAutoAllowedTools])
|
||||
|
||||
return {
|
||||
isInitialized,
|
||||
}
|
||||
|
||||
@@ -11,7 +11,9 @@ import ReactFlow, {
|
||||
useReactFlow,
|
||||
} from 'reactflow'
|
||||
import 'reactflow/dist/style.css'
|
||||
import type { OAuthConnectEventDetail } from '@/lib/copilot/tools/client/other/oauth-request-access'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { OAuthProvider } from '@/lib/oauth'
|
||||
import { TriggerUtils } from '@/lib/workflows/triggers/triggers'
|
||||
import { useUserPermissionsContext } from '@/app/workspace/[workspaceId]/providers/workspace-permissions-provider'
|
||||
import {
|
||||
@@ -27,6 +29,7 @@ import { Chat } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/ch
|
||||
import { Cursors } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/cursors/cursors'
|
||||
import { ErrorBoundary } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/error/index'
|
||||
import { NoteBlock } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/note-block/note-block'
|
||||
import { OAuthRequiredModal } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/editor/components/sub-block/components/credential-selector/components/oauth-required-modal'
|
||||
import { WorkflowBlock } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/workflow-block/workflow-block'
|
||||
import { WorkflowEdge } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/workflow-edge/workflow-edge'
|
||||
import {
|
||||
@@ -94,6 +97,13 @@ const WorkflowContent = React.memo(() => {
|
||||
|
||||
// Track whether the active connection drag started from an error handle
|
||||
const [isErrorConnectionDrag, setIsErrorConnectionDrag] = useState(false)
|
||||
const [oauthModal, setOauthModal] = useState<{
|
||||
provider: OAuthProvider
|
||||
serviceId: string
|
||||
providerName: string
|
||||
requiredScopes: string[]
|
||||
newScopes?: string[]
|
||||
} | null>(null)
|
||||
|
||||
// Hooks
|
||||
const params = useParams()
|
||||
@@ -163,6 +173,25 @@ const WorkflowContent = React.memo(() => {
|
||||
return Object.keys(blocks).length === 0
|
||||
}, [blocks])
|
||||
|
||||
// Listen for global OAuth connect events (from Copilot tool)
|
||||
useEffect(() => {
|
||||
const handleOpenOAuthConnect = (event: Event) => {
|
||||
const detail = (event as CustomEvent<OAuthConnectEventDetail>).detail
|
||||
if (!detail) return
|
||||
setOauthModal({
|
||||
provider: detail.providerId as OAuthProvider,
|
||||
serviceId: detail.serviceId,
|
||||
providerName: detail.providerName,
|
||||
requiredScopes: detail.requiredScopes || [],
|
||||
newScopes: detail.newScopes || [],
|
||||
})
|
||||
}
|
||||
|
||||
window.addEventListener('open-oauth-connect', handleOpenOAuthConnect as EventListener)
|
||||
return () =>
|
||||
window.removeEventListener('open-oauth-connect', handleOpenOAuthConnect as EventListener)
|
||||
}, [])
|
||||
|
||||
// Get diff analysis for edge reconstruction
|
||||
const { diffAnalysis, isShowingDiff, isDiffReady, reapplyDiffMarkers, hasActiveDiff } =
|
||||
useWorkflowDiffStore()
|
||||
@@ -2277,6 +2306,18 @@ const WorkflowContent = React.memo(() => {
|
||||
</div>
|
||||
|
||||
<Terminal />
|
||||
|
||||
{oauthModal && (
|
||||
<OAuthRequiredModal
|
||||
isOpen={true}
|
||||
onClose={() => setOauthModal(null)}
|
||||
provider={oauthModal.provider}
|
||||
toolName={oauthModal.providerName}
|
||||
serviceId={oauthModal.serviceId}
|
||||
requiredScopes={oauthModal.requiredScopes}
|
||||
newScopes={oauthModal.newScopes}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
@@ -493,7 +493,7 @@ const Combobox = forwardRef<HTMLDivElement, ComboboxProps>(
|
||||
<Search className='mr-2 h-[14px] w-[14px] shrink-0 text-[var(--text-muted)]' />
|
||||
<input
|
||||
ref={searchInputRef}
|
||||
className='w-full bg-transparent text-sm text-[var(--text-primary)] placeholder:text-[var(--text-muted)] focus:outline-none'
|
||||
className='w-full bg-transparent text-[var(--text-primary)] text-sm placeholder:text-[var(--text-muted)] focus:outline-none'
|
||||
placeholder={searchPlaceholder}
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
|
||||
@@ -163,10 +163,13 @@ export class EdgeConstructor {
|
||||
sourceIsParallelBlock ||
|
||||
targetIsParallelBlock
|
||||
) {
|
||||
if (sourceIsLoopBlock) {
|
||||
const sentinelEndId = buildSentinelEndId(source)
|
||||
let loopSentinelStartId: string | undefined
|
||||
|
||||
if (!dag.nodes.has(sentinelEndId)) {
|
||||
if (sourceIsLoopBlock) {
|
||||
const sentinelEndId = buildSentinelEndId(originalSource)
|
||||
loopSentinelStartId = buildSentinelStartId(originalSource)
|
||||
|
||||
if (!dag.nodes.has(sentinelEndId) || !dag.nodes.has(loopSentinelStartId)) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -184,6 +187,10 @@ export class EdgeConstructor {
|
||||
target = sentinelStartId
|
||||
}
|
||||
|
||||
if (loopSentinelStartId) {
|
||||
this.addEdge(dag, loopSentinelStartId, target, EDGE.LOOP_EXIT, targetHandle)
|
||||
}
|
||||
|
||||
if (sourceIsParallelBlock || targetIsParallelBlock) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -89,6 +89,14 @@ export class EdgeManager {
|
||||
private shouldActivateEdge(edge: DAGEdge, output: NormalizedBlockOutput): boolean {
|
||||
const handle = edge.sourceHandle
|
||||
|
||||
if (output.selectedRoute === EDGE.LOOP_EXIT) {
|
||||
return handle === EDGE.LOOP_EXIT
|
||||
}
|
||||
|
||||
if (output.selectedRoute === EDGE.LOOP_CONTINUE) {
|
||||
return handle === EDGE.LOOP_CONTINUE || handle === EDGE.LOOP_CONTINUE_ALT
|
||||
}
|
||||
|
||||
if (!handle) {
|
||||
return true
|
||||
}
|
||||
@@ -104,13 +112,6 @@ export class EdgeManager {
|
||||
}
|
||||
|
||||
switch (handle) {
|
||||
case EDGE.LOOP_CONTINUE:
|
||||
case EDGE.LOOP_CONTINUE_ALT:
|
||||
return output.selectedRoute === EDGE.LOOP_CONTINUE
|
||||
|
||||
case EDGE.LOOP_EXIT:
|
||||
return output.selectedRoute === EDGE.LOOP_EXIT
|
||||
|
||||
case EDGE.ERROR:
|
||||
return !!output.error
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ export interface LoopScope {
|
||||
item?: any
|
||||
items?: any[]
|
||||
condition?: string
|
||||
loopType?: 'for' | 'forEach' | 'while' | 'doWhile'
|
||||
skipFirstConditionCheck?: boolean
|
||||
}
|
||||
|
||||
|
||||
@@ -48,11 +48,13 @@ export class LoopOrchestrator {
|
||||
|
||||
switch (loopType) {
|
||||
case 'for':
|
||||
scope.loopType = 'for'
|
||||
scope.maxIterations = loopConfig.iterations || DEFAULTS.MAX_LOOP_ITERATIONS
|
||||
scope.condition = buildLoopIndexCondition(scope.maxIterations)
|
||||
break
|
||||
|
||||
case 'forEach': {
|
||||
scope.loopType = 'forEach'
|
||||
const items = this.resolveForEachItems(ctx, loopConfig.forEachItems)
|
||||
scope.items = items
|
||||
scope.maxIterations = items.length
|
||||
@@ -62,17 +64,18 @@ export class LoopOrchestrator {
|
||||
}
|
||||
|
||||
case 'while':
|
||||
scope.loopType = 'while'
|
||||
scope.condition = loopConfig.whileCondition
|
||||
break
|
||||
|
||||
case 'doWhile':
|
||||
scope.loopType = 'doWhile'
|
||||
if (loopConfig.doWhileCondition) {
|
||||
scope.condition = loopConfig.doWhileCondition
|
||||
} else {
|
||||
scope.maxIterations = loopConfig.iterations || DEFAULTS.MAX_LOOP_ITERATIONS
|
||||
scope.condition = buildLoopIndexCondition(scope.maxIterations)
|
||||
}
|
||||
scope.skipFirstConditionCheck = true
|
||||
break
|
||||
|
||||
default:
|
||||
@@ -130,12 +133,8 @@ export class LoopOrchestrator {
|
||||
|
||||
scope.currentIterationOutputs.clear()
|
||||
|
||||
const isFirstIteration = scope.iteration === 0
|
||||
const shouldSkipFirstCheck = scope.skipFirstConditionCheck && isFirstIteration
|
||||
if (!shouldSkipFirstCheck) {
|
||||
if (!this.evaluateCondition(ctx, scope, scope.iteration + 1)) {
|
||||
return this.createExitResult(ctx, loopId, scope)
|
||||
}
|
||||
if (!this.evaluateCondition(ctx, scope, scope.iteration + 1)) {
|
||||
return this.createExitResult(ctx, loopId, scope)
|
||||
}
|
||||
|
||||
scope.iteration++
|
||||
@@ -245,6 +244,43 @@ export class LoopOrchestrator {
|
||||
return ctx.loopExecutions?.get(loopId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates the initial condition for while loops at the sentinel start.
|
||||
* For while loops, the condition must be checked BEFORE the first iteration.
|
||||
* If the condition is false, the loop body should be skipped entirely.
|
||||
*
|
||||
* @returns true if the loop should execute, false if it should be skipped
|
||||
*/
|
||||
evaluateInitialCondition(ctx: ExecutionContext, loopId: string): boolean {
|
||||
const scope = ctx.loopExecutions?.get(loopId)
|
||||
if (!scope) {
|
||||
logger.warn('Loop scope not found for initial condition evaluation', { loopId })
|
||||
return true
|
||||
}
|
||||
|
||||
// Only while loops need an initial condition check
|
||||
// - for/forEach: always execute based on iteration count/items
|
||||
// - doWhile: always execute at least once, check condition after
|
||||
// - while: check condition before first iteration
|
||||
if (scope.loopType !== 'while') {
|
||||
return true
|
||||
}
|
||||
|
||||
if (!scope.condition) {
|
||||
logger.warn('No condition defined for while loop', { loopId })
|
||||
return false
|
||||
}
|
||||
|
||||
const result = this.evaluateWhileCondition(ctx, scope.condition, scope)
|
||||
logger.info('While loop initial condition evaluation', {
|
||||
loopId,
|
||||
condition: scope.condition,
|
||||
result,
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
shouldExecuteLoopNode(_ctx: ExecutionContext, _nodeId: string, _loopId: string): boolean {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -78,6 +78,17 @@ export class NodeExecutionOrchestrator {
|
||||
|
||||
switch (sentinelType) {
|
||||
case 'start': {
|
||||
if (loopId) {
|
||||
const shouldExecute = this.loopOrchestrator.evaluateInitialCondition(ctx, loopId)
|
||||
if (!shouldExecute) {
|
||||
logger.info('While loop initial condition false, skipping loop body', { loopId })
|
||||
return {
|
||||
sentinelStart: true,
|
||||
shouldExit: true,
|
||||
selectedRoute: EDGE.LOOP_EXIT,
|
||||
}
|
||||
}
|
||||
}
|
||||
return { sentinelStart: true }
|
||||
}
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ export interface SendMessageRequest {
|
||||
| 'claude-4.5-sonnet'
|
||||
| 'claude-4.5-opus'
|
||||
| 'claude-4.1-opus'
|
||||
| 'gemini-3-pro'
|
||||
prefetch?: boolean
|
||||
createNewChat?: boolean
|
||||
stream?: boolean
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
export const SIM_AGENT_API_URL_DEFAULT = 'https://copilot.sim.ai'
|
||||
export const SIM_AGENT_VERSION = '1.0.2'
|
||||
export const SIM_AGENT_VERSION = '1.0.3'
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { z } from 'zod'
|
||||
import { KnowledgeBaseArgsSchema, KnowledgeBaseResultSchema } from './tools/shared/schemas'
|
||||
|
||||
// Tool IDs supported by the Copilot runtime
|
||||
export const ToolIds = z.enum([
|
||||
@@ -32,6 +33,8 @@ export const ToolIds = z.enum([
|
||||
'deploy_workflow',
|
||||
'check_deployment_status',
|
||||
'navigate_ui',
|
||||
'knowledge_base',
|
||||
'generate_diagram',
|
||||
])
|
||||
export type ToolId = z.infer<typeof ToolIds>
|
||||
|
||||
@@ -71,7 +74,9 @@ export const ToolArgSchemas = {
|
||||
),
|
||||
}),
|
||||
// New
|
||||
oauth_request_access: z.object({}),
|
||||
oauth_request_access: z.object({
|
||||
providerName: z.string(),
|
||||
}),
|
||||
|
||||
deploy_workflow: z.object({
|
||||
action: z.enum(['deploy', 'undeploy']).optional().default('deploy'),
|
||||
@@ -195,6 +200,13 @@ export const ToolArgSchemas = {
|
||||
reason: z.object({
|
||||
reasoning: z.string(),
|
||||
}),
|
||||
|
||||
knowledge_base: KnowledgeBaseArgsSchema,
|
||||
|
||||
generate_diagram: z.object({
|
||||
diagramText: z.string().describe('The raw diagram text content (e.g., mermaid syntax)'),
|
||||
language: z.enum(['mermaid']).default('mermaid').describe('The diagram language/format'),
|
||||
}),
|
||||
} as const
|
||||
export type ToolArgSchemaMap = typeof ToolArgSchemas
|
||||
|
||||
@@ -267,6 +279,8 @@ export const ToolSSESchemas = {
|
||||
ToolArgSchemas.check_deployment_status
|
||||
),
|
||||
navigate_ui: toolCallSSEFor('navigate_ui', ToolArgSchemas.navigate_ui),
|
||||
knowledge_base: toolCallSSEFor('knowledge_base', ToolArgSchemas.knowledge_base),
|
||||
generate_diagram: toolCallSSEFor('generate_diagram', ToolArgSchemas.generate_diagram),
|
||||
} as const
|
||||
export type ToolSSESchemaMap = typeof ToolSSESchemas
|
||||
|
||||
@@ -464,6 +478,12 @@ export const ToolResultSchemas = {
|
||||
workflowName: z.string().optional(),
|
||||
navigated: z.boolean(),
|
||||
}),
|
||||
knowledge_base: KnowledgeBaseResultSchema,
|
||||
generate_diagram: z.object({
|
||||
diagramText: z.string(),
|
||||
language: z.enum(['mermaid']),
|
||||
rendered: z.boolean().optional(),
|
||||
}),
|
||||
} as const
|
||||
export type ToolResultSchemaMap = typeof ToolResultSchemas
|
||||
|
||||
|
||||
130
apps/sim/lib/copilot/tools/client/knowledge/knowledge-base.ts
Normal file
130
apps/sim/lib/copilot/tools/client/knowledge/knowledge-base.ts
Normal file
@@ -0,0 +1,130 @@
|
||||
import { Database, Loader2, MinusCircle, PlusCircle, XCircle } from 'lucide-react'
|
||||
import {
|
||||
BaseClientTool,
|
||||
type BaseClientToolMetadata,
|
||||
ClientToolCallState,
|
||||
} from '@/lib/copilot/tools/client/base-tool'
|
||||
import {
|
||||
ExecuteResponseSuccessSchema,
|
||||
type KnowledgeBaseArgs,
|
||||
} from '@/lib/copilot/tools/shared/schemas'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { useCopilotStore } from '@/stores/panel/copilot/store'
|
||||
|
||||
/**
|
||||
* Client tool for knowledge base operations
|
||||
*/
|
||||
export class KnowledgeBaseClientTool extends BaseClientTool {
|
||||
static readonly id = 'knowledge_base'
|
||||
|
||||
constructor(toolCallId: string) {
|
||||
super(toolCallId, KnowledgeBaseClientTool.id, KnowledgeBaseClientTool.metadata)
|
||||
}
|
||||
|
||||
/**
|
||||
* Only show interrupt for create operation
|
||||
*/
|
||||
getInterruptDisplays(): BaseClientToolMetadata['interrupt'] | undefined {
|
||||
const toolCallsById = useCopilotStore.getState().toolCallsById
|
||||
const toolCall = toolCallsById[this.toolCallId]
|
||||
const params = toolCall?.params as KnowledgeBaseArgs | undefined
|
||||
|
||||
// Only require confirmation for create operation
|
||||
if (params?.operation === 'create') {
|
||||
const name = params?.args?.name || 'new knowledge base'
|
||||
return {
|
||||
accept: { text: `Create "${name}"`, icon: PlusCircle },
|
||||
reject: { text: 'Skip', icon: XCircle },
|
||||
}
|
||||
}
|
||||
|
||||
// No interrupt for list, get, query - auto-execute
|
||||
return undefined
|
||||
}
|
||||
|
||||
static readonly metadata: BaseClientToolMetadata = {
|
||||
displayNames: {
|
||||
[ClientToolCallState.generating]: { text: 'Accessing knowledge base', icon: Loader2 },
|
||||
[ClientToolCallState.pending]: { text: 'Accessing knowledge base', icon: Loader2 },
|
||||
[ClientToolCallState.executing]: { text: 'Accessing knowledge base', icon: Loader2 },
|
||||
[ClientToolCallState.success]: { text: 'Accessed knowledge base', icon: Database },
|
||||
[ClientToolCallState.error]: { text: 'Failed to access knowledge base', icon: XCircle },
|
||||
[ClientToolCallState.aborted]: { text: 'Aborted knowledge base access', icon: MinusCircle },
|
||||
[ClientToolCallState.rejected]: { text: 'Skipped knowledge base access', icon: MinusCircle },
|
||||
},
|
||||
getDynamicText: (params: Record<string, any>, state: ClientToolCallState) => {
|
||||
const operation = params?.operation as string | undefined
|
||||
const name = params?.args?.name as string | undefined
|
||||
|
||||
const opVerbs: Record<string, { active: string; past: string; pending?: string }> = {
|
||||
create: {
|
||||
active: 'Creating knowledge base',
|
||||
past: 'Created knowledge base',
|
||||
pending: name ? `Create knowledge base "${name}"?` : 'Create knowledge base?',
|
||||
},
|
||||
list: { active: 'Listing knowledge bases', past: 'Listed knowledge bases' },
|
||||
get: { active: 'Getting knowledge base', past: 'Retrieved knowledge base' },
|
||||
query: { active: 'Querying knowledge base', past: 'Queried knowledge base' },
|
||||
}
|
||||
const defaultVerb: { active: string; past: string; pending?: string } = {
|
||||
active: 'Accessing knowledge base',
|
||||
past: 'Accessed knowledge base',
|
||||
}
|
||||
const verb = operation ? opVerbs[operation] || defaultVerb : defaultVerb
|
||||
|
||||
if (state === ClientToolCallState.success) {
|
||||
return verb.past
|
||||
}
|
||||
if (state === ClientToolCallState.pending && verb.pending) {
|
||||
return verb.pending
|
||||
}
|
||||
if (
|
||||
state === ClientToolCallState.generating ||
|
||||
state === ClientToolCallState.pending ||
|
||||
state === ClientToolCallState.executing
|
||||
) {
|
||||
return verb.active
|
||||
}
|
||||
return undefined
|
||||
},
|
||||
}
|
||||
|
||||
async handleReject(): Promise<void> {
|
||||
await super.handleReject()
|
||||
this.setState(ClientToolCallState.rejected)
|
||||
}
|
||||
|
||||
async handleAccept(args?: KnowledgeBaseArgs): Promise<void> {
|
||||
await this.execute(args)
|
||||
}
|
||||
|
||||
async execute(args?: KnowledgeBaseArgs): Promise<void> {
|
||||
const logger = createLogger('KnowledgeBaseClientTool')
|
||||
try {
|
||||
this.setState(ClientToolCallState.executing)
|
||||
const payload: KnowledgeBaseArgs = { ...(args || { operation: 'list' }) }
|
||||
|
||||
const res = await fetch('/api/copilot/execute-copilot-server-tool', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ toolName: 'knowledge_base', payload }),
|
||||
})
|
||||
|
||||
if (!res.ok) {
|
||||
const txt = await res.text().catch(() => '')
|
||||
throw new Error(txt || `Server error (${res.status})`)
|
||||
}
|
||||
|
||||
const json = await res.json()
|
||||
const parsed = ExecuteResponseSuccessSchema.parse(json)
|
||||
|
||||
this.setState(ClientToolCallState.success)
|
||||
await this.markToolComplete(200, 'Knowledge base operation completed', parsed.result)
|
||||
this.setState(ClientToolCallState.success)
|
||||
} catch (e: any) {
|
||||
logger.error('execute failed', { message: e?.message })
|
||||
this.setState(ClientToolCallState.error)
|
||||
await this.markToolComplete(500, e?.message || 'Failed to access knowledge base')
|
||||
}
|
||||
}
|
||||
}
|
||||
64
apps/sim/lib/copilot/tools/client/other/generate-diagram.ts
Normal file
64
apps/sim/lib/copilot/tools/client/other/generate-diagram.ts
Normal file
@@ -0,0 +1,64 @@
|
||||
import { GitBranch, Loader2, MinusCircle, XCircle } from 'lucide-react'
|
||||
import {
|
||||
BaseClientTool,
|
||||
type BaseClientToolMetadata,
|
||||
ClientToolCallState,
|
||||
} from '@/lib/copilot/tools/client/base-tool'
|
||||
|
||||
interface GenerateDiagramArgs {
|
||||
diagramText: string
|
||||
language?: 'mermaid'
|
||||
}
|
||||
|
||||
/**
|
||||
* Client tool for rendering diagrams in the copilot chat.
|
||||
* This tool renders mermaid diagrams directly in the UI without server execution.
|
||||
*/
|
||||
export class GenerateDiagramClientTool extends BaseClientTool {
|
||||
static readonly id = 'generate_diagram'
|
||||
|
||||
constructor(toolCallId: string) {
|
||||
super(toolCallId, GenerateDiagramClientTool.id, GenerateDiagramClientTool.metadata)
|
||||
}
|
||||
|
||||
static readonly metadata: BaseClientToolMetadata = {
|
||||
displayNames: {
|
||||
[ClientToolCallState.generating]: { text: 'Designing workflow', icon: Loader2 },
|
||||
[ClientToolCallState.pending]: { text: 'Designing workflow', icon: Loader2 },
|
||||
[ClientToolCallState.executing]: { text: 'Designing workflow', icon: Loader2 },
|
||||
[ClientToolCallState.success]: { text: 'Designed workflow', icon: GitBranch },
|
||||
[ClientToolCallState.error]: { text: 'Failed to design workflow', icon: XCircle },
|
||||
[ClientToolCallState.aborted]: { text: 'Aborted designing workflow', icon: MinusCircle },
|
||||
[ClientToolCallState.rejected]: { text: 'Skipped designing workflow', icon: MinusCircle },
|
||||
},
|
||||
interrupt: undefined,
|
||||
}
|
||||
|
||||
async execute(args?: GenerateDiagramArgs): Promise<void> {
|
||||
try {
|
||||
this.setState(ClientToolCallState.executing)
|
||||
|
||||
const diagramText = args?.diagramText
|
||||
const language = args?.language || 'mermaid'
|
||||
|
||||
if (!diagramText?.trim()) {
|
||||
await this.markToolComplete(400, 'No diagram text provided')
|
||||
this.setState(ClientToolCallState.error)
|
||||
return
|
||||
}
|
||||
|
||||
// The actual rendering happens in the UI component (tool-call.tsx)
|
||||
// We just need to mark the tool as complete with the diagram data
|
||||
await this.markToolComplete(200, 'Diagram rendered successfully', {
|
||||
diagramText,
|
||||
language,
|
||||
rendered: true,
|
||||
})
|
||||
this.setState(ClientToolCallState.success)
|
||||
} catch (error: any) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
await this.markToolComplete(500, message)
|
||||
this.setState(ClientToolCallState.error)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,11 +4,64 @@ import {
|
||||
type BaseClientToolMetadata,
|
||||
ClientToolCallState,
|
||||
} from '@/lib/copilot/tools/client/base-tool'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { OAUTH_PROVIDERS, type OAuthServiceConfig } from '@/lib/oauth/oauth'
|
||||
|
||||
const logger = createLogger('OAuthRequestAccessClientTool')
|
||||
|
||||
interface OAuthRequestAccessArgs {
|
||||
providerName?: string
|
||||
}
|
||||
|
||||
interface ResolvedServiceInfo {
|
||||
serviceId: string
|
||||
providerId: string
|
||||
service: OAuthServiceConfig
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the service configuration from a provider name.
|
||||
* The providerName should match the exact `name` field returned by get_credentials tool's notConnected services.
|
||||
*/
|
||||
function findServiceByName(providerName: string): ResolvedServiceInfo | null {
|
||||
const normalizedName = providerName.toLowerCase().trim()
|
||||
|
||||
// First pass: exact match (case-insensitive)
|
||||
for (const [, providerConfig] of Object.entries(OAUTH_PROVIDERS)) {
|
||||
for (const [serviceId, service] of Object.entries(providerConfig.services)) {
|
||||
if (service.name.toLowerCase() === normalizedName) {
|
||||
return { serviceId, providerId: service.providerId, service }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: partial match as fallback for flexibility
|
||||
for (const [, providerConfig] of Object.entries(OAUTH_PROVIDERS)) {
|
||||
for (const [serviceId, service] of Object.entries(providerConfig.services)) {
|
||||
if (
|
||||
service.name.toLowerCase().includes(normalizedName) ||
|
||||
normalizedName.includes(service.name.toLowerCase())
|
||||
) {
|
||||
return { serviceId, providerId: service.providerId, service }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export interface OAuthConnectEventDetail {
|
||||
providerName: string
|
||||
serviceId: string
|
||||
providerId: string
|
||||
requiredScopes: string[]
|
||||
newScopes?: string[]
|
||||
}
|
||||
|
||||
export class OAuthRequestAccessClientTool extends BaseClientTool {
|
||||
static readonly id = 'oauth_request_access'
|
||||
|
||||
private cleanupListener?: () => void
|
||||
private providerName?: string
|
||||
|
||||
constructor(toolCallId: string) {
|
||||
super(toolCallId, OAuthRequestAccessClientTool.id, OAuthRequestAccessClientTool.metadata)
|
||||
@@ -18,7 +71,7 @@ export class OAuthRequestAccessClientTool extends BaseClientTool {
|
||||
displayNames: {
|
||||
[ClientToolCallState.generating]: { text: 'Requesting integration access', icon: Loader2 },
|
||||
[ClientToolCallState.pending]: { text: 'Requesting integration access', icon: Loader2 },
|
||||
[ClientToolCallState.executing]: { text: 'Requesting integration access', icon: Loader2 },
|
||||
[ClientToolCallState.executing]: { text: 'Connecting integration', icon: Loader2 },
|
||||
[ClientToolCallState.rejected]: { text: 'Skipped integration access', icon: MinusCircle },
|
||||
[ClientToolCallState.success]: { text: 'Integration connected', icon: CheckCircle },
|
||||
[ClientToolCallState.error]: { text: 'Failed to request integration access', icon: X },
|
||||
@@ -28,63 +81,92 @@ export class OAuthRequestAccessClientTool extends BaseClientTool {
|
||||
accept: { text: 'Connect', icon: PlugZap },
|
||||
reject: { text: 'Skip', icon: MinusCircle },
|
||||
},
|
||||
getDynamicText: (params, state) => {
|
||||
if (params.providerName) {
|
||||
const name = params.providerName
|
||||
switch (state) {
|
||||
case ClientToolCallState.generating:
|
||||
case ClientToolCallState.pending:
|
||||
return `Requesting ${name} access`
|
||||
case ClientToolCallState.executing:
|
||||
return `Connecting to ${name}`
|
||||
case ClientToolCallState.rejected:
|
||||
return `Skipped ${name} access`
|
||||
case ClientToolCallState.success:
|
||||
return `${name} connected`
|
||||
case ClientToolCallState.error:
|
||||
return `Failed to connect ${name}`
|
||||
case ClientToolCallState.aborted:
|
||||
return `Aborted ${name} connection`
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
},
|
||||
}
|
||||
|
||||
async handleAccept(): Promise<void> {
|
||||
async handleAccept(args?: OAuthRequestAccessArgs): Promise<void> {
|
||||
try {
|
||||
// Move to executing (we're waiting for the user to connect an integration)
|
||||
if (args?.providerName) {
|
||||
this.providerName = args.providerName
|
||||
}
|
||||
|
||||
if (!this.providerName) {
|
||||
logger.error('No provider name provided')
|
||||
this.setState(ClientToolCallState.error)
|
||||
await this.markToolComplete(400, 'No provider name specified')
|
||||
return
|
||||
}
|
||||
|
||||
// Find the service by name
|
||||
const serviceInfo = findServiceByName(this.providerName)
|
||||
if (!serviceInfo) {
|
||||
logger.error('Could not find OAuth service for provider', {
|
||||
providerName: this.providerName,
|
||||
})
|
||||
this.setState(ClientToolCallState.error)
|
||||
await this.markToolComplete(400, `Unknown provider: ${this.providerName}`)
|
||||
return
|
||||
}
|
||||
|
||||
const { serviceId, providerId, service } = serviceInfo
|
||||
|
||||
logger.info('Opening OAuth connect modal', {
|
||||
providerName: this.providerName,
|
||||
serviceId,
|
||||
providerId,
|
||||
})
|
||||
|
||||
// Move to executing state
|
||||
this.setState(ClientToolCallState.executing)
|
||||
|
||||
if (typeof window !== 'undefined') {
|
||||
// Listen for modal close; complete success on connection, otherwise mark skipped/rejected
|
||||
const onClosed = async (evt: Event) => {
|
||||
try {
|
||||
const detail = (evt as CustomEvent).detail as { success?: boolean }
|
||||
if (detail?.success) {
|
||||
await this.markToolComplete(200, { granted: true })
|
||||
this.setState(ClientToolCallState.success)
|
||||
} else {
|
||||
await this.markToolComplete(200, 'Tool execution was skipped by the user')
|
||||
this.setState(ClientToolCallState.rejected)
|
||||
}
|
||||
} finally {
|
||||
if (this.cleanupListener) this.cleanupListener()
|
||||
this.cleanupListener = undefined
|
||||
}
|
||||
}
|
||||
window.addEventListener(
|
||||
'oauth-integration-closed',
|
||||
onClosed as EventListener,
|
||||
{
|
||||
once: true,
|
||||
} as any
|
||||
)
|
||||
this.cleanupListener = () =>
|
||||
window.removeEventListener('oauth-integration-closed', onClosed as EventListener)
|
||||
// Dispatch event to open the OAuth modal (same pattern as open-settings)
|
||||
window.dispatchEvent(
|
||||
new CustomEvent<OAuthConnectEventDetail>('open-oauth-connect', {
|
||||
detail: {
|
||||
providerName: this.providerName,
|
||||
serviceId,
|
||||
providerId,
|
||||
requiredScopes: service.scopes || [],
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
window.dispatchEvent(new CustomEvent('open-settings', { detail: { tab: 'credentials' } }))
|
||||
}
|
||||
// Mark as success - the modal will handle the actual OAuth flow
|
||||
this.setState(ClientToolCallState.success)
|
||||
await this.markToolComplete(200, `Opened ${this.providerName} connection dialog`)
|
||||
} catch (e) {
|
||||
logger.error('Failed to open OAuth connect modal', { error: e })
|
||||
this.setState(ClientToolCallState.error)
|
||||
await this.markToolComplete(500, 'Failed to open integrations settings')
|
||||
await this.markToolComplete(500, 'Failed to open OAuth connection dialog')
|
||||
}
|
||||
}
|
||||
|
||||
async handleReject(): Promise<void> {
|
||||
await super.handleReject()
|
||||
this.setState(ClientToolCallState.rejected)
|
||||
if (this.cleanupListener) this.cleanupListener()
|
||||
this.cleanupListener = undefined
|
||||
}
|
||||
|
||||
async completeAfterConnection(): Promise<void> {
|
||||
await this.markToolComplete(200, { granted: true })
|
||||
this.setState(ClientToolCallState.success)
|
||||
if (this.cleanupListener) this.cleanupListener()
|
||||
this.cleanupListener = undefined
|
||||
}
|
||||
|
||||
async execute(): Promise<void> {
|
||||
await this.handleAccept()
|
||||
async execute(args?: OAuthRequestAccessArgs): Promise<void> {
|
||||
await this.handleAccept(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,8 +51,9 @@ export class GetGlobalWorkflowVariablesClientTool extends BaseClientTool {
|
||||
}
|
||||
const json = await res.json()
|
||||
const varsRecord = (json?.data as Record<string, any>) || {}
|
||||
// Convert to name/value pairs for clarity
|
||||
// Convert to id/name/value for clarity
|
||||
const variables = Object.values(varsRecord).map((v: any) => ({
|
||||
id: String(v?.id || ''),
|
||||
name: String(v?.name || ''),
|
||||
value: (v as any)?.value,
|
||||
}))
|
||||
|
||||
@@ -9,6 +9,7 @@ import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { registry as blockRegistry } from '@/blocks/registry'
|
||||
import type { BlockConfig } from '@/blocks/types'
|
||||
import { AuthMode } from '@/blocks/types'
|
||||
import { PROVIDER_DEFINITIONS } from '@/providers/models'
|
||||
import { tools as toolsRegistry } from '@/tools/registry'
|
||||
import { getTrigger, isTriggerValid } from '@/triggers'
|
||||
import { SYSTEM_SUBBLOCK_IDS } from '@/triggers/consts'
|
||||
@@ -381,14 +382,6 @@ function extractInputs(metadata: CopilotBlockMetadata): {
|
||||
continue
|
||||
}
|
||||
|
||||
if (
|
||||
schema.type === 'oauth-credential' ||
|
||||
schema.type === 'credential-input' ||
|
||||
schema.type === 'oauth-input'
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
if (schema.id === 'triggerConfig' || schema.type === 'trigger-config') {
|
||||
continue
|
||||
}
|
||||
@@ -469,15 +462,6 @@ function extractOperationInputs(
|
||||
continue
|
||||
}
|
||||
|
||||
const lowerKey = key.toLowerCase()
|
||||
if (
|
||||
lowerKey.includes('token') ||
|
||||
lowerKey.includes('credential') ||
|
||||
lowerKey.includes('apikey')
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
const input: any = {
|
||||
name: key,
|
||||
type: (inputDef as any)?.type || 'string',
|
||||
@@ -556,6 +540,7 @@ function mapSchemaTypeToSimpleType(schemaType: string, schema: CopilotSubblockMe
|
||||
'multi-select': 'array',
|
||||
'credential-input': 'credential',
|
||||
'oauth-credential': 'credential',
|
||||
'oauth-input': 'credential',
|
||||
}
|
||||
|
||||
const mappedType = typeMap[schemaType] || schemaType
|
||||
@@ -681,40 +666,131 @@ function resolveAuthType(
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all available models from PROVIDER_DEFINITIONS as static options.
|
||||
* This provides fallback data when store state is not available server-side.
|
||||
* Excludes dynamic providers (ollama, vllm, openrouter) which require runtime fetching.
|
||||
*/
|
||||
function getStaticModelOptions(): { id: string; label?: string }[] {
|
||||
const models: { id: string; label?: string }[] = []
|
||||
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
// Skip providers with dynamic/fetched models
|
||||
if (provider.id === 'ollama' || provider.id === 'vllm' || provider.id === 'openrouter') {
|
||||
continue
|
||||
}
|
||||
if (provider?.models) {
|
||||
for (const model of provider.models) {
|
||||
models.push({ id: model.id, label: model.id })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to call a dynamic options function with fallback data injected.
|
||||
* When the function accesses store state that's unavailable server-side,
|
||||
* this provides static fallback data from known sources.
|
||||
*
|
||||
* @param optionsFn - The options function to call
|
||||
* @returns Options array or undefined if options cannot be resolved
|
||||
*/
|
||||
function callOptionsWithFallback(
|
||||
optionsFn: () => any[]
|
||||
): { id: string; label?: string; hasIcon?: boolean }[] | undefined {
|
||||
// Get static model data to use as fallback
|
||||
const staticModels = getStaticModelOptions()
|
||||
|
||||
// Create a mock providers state with static data
|
||||
const mockProvidersState = {
|
||||
providers: {
|
||||
base: { models: staticModels.map((m) => m.id) },
|
||||
ollama: { models: [] },
|
||||
vllm: { models: [] },
|
||||
openrouter: { models: [] },
|
||||
},
|
||||
}
|
||||
|
||||
// Store original getState if it exists
|
||||
let originalGetState: (() => any) | undefined
|
||||
let store: any
|
||||
|
||||
try {
|
||||
// Try to get the providers store module
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
store = require('@/stores/providers/store')
|
||||
if (store?.useProvidersStore?.getState) {
|
||||
originalGetState = store.useProvidersStore.getState
|
||||
// Temporarily replace getState with our mock
|
||||
store.useProvidersStore.getState = () => mockProvidersState
|
||||
}
|
||||
} catch {
|
||||
// Store module not available, continue with mock
|
||||
}
|
||||
|
||||
try {
|
||||
const result = optionsFn()
|
||||
return result
|
||||
} finally {
|
||||
// Restore original getState
|
||||
if (store?.useProvidersStore && originalGetState) {
|
||||
store.useProvidersStore.getState = originalGetState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function resolveSubblockOptions(
|
||||
sb: any
|
||||
): { id: string; label?: string; hasIcon?: boolean }[] | undefined {
|
||||
try {
|
||||
const rawOptions = typeof sb.options === 'function' ? sb.options() : sb.options
|
||||
if (!Array.isArray(rawOptions)) return undefined
|
||||
|
||||
const normalized = rawOptions
|
||||
.map((opt: any) => {
|
||||
if (!opt) return undefined
|
||||
|
||||
const id = typeof opt === 'object' ? opt.id : opt
|
||||
if (id === undefined || id === null) return undefined
|
||||
|
||||
const result: { id: string; label?: string; hasIcon?: boolean } = {
|
||||
id: String(id),
|
||||
}
|
||||
|
||||
if (typeof opt === 'object' && typeof opt.label === 'string') {
|
||||
result.label = opt.label
|
||||
}
|
||||
|
||||
if (typeof opt === 'object' && opt.icon) {
|
||||
result.hasIcon = true
|
||||
}
|
||||
|
||||
return result
|
||||
})
|
||||
.filter((o): o is { id: string; label?: string; hasIcon?: boolean } => o !== undefined)
|
||||
|
||||
return normalized.length > 0 ? normalized : undefined
|
||||
} catch {
|
||||
// Skip if subblock uses fetchOptions (async network calls)
|
||||
if (sb.fetchOptions) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
let rawOptions: any[] | undefined
|
||||
|
||||
try {
|
||||
if (typeof sb.options === 'function') {
|
||||
// Try calling with fallback data injection for store-dependent options
|
||||
rawOptions = callOptionsWithFallback(sb.options)
|
||||
} else {
|
||||
rawOptions = sb.options
|
||||
}
|
||||
} catch {
|
||||
// Options function failed even with fallback, skip
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (!Array.isArray(rawOptions) || rawOptions.length === 0) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const normalized = rawOptions
|
||||
.map((opt: any) => {
|
||||
if (!opt) return undefined
|
||||
|
||||
const id = typeof opt === 'object' ? opt.id : opt
|
||||
if (id === undefined || id === null) return undefined
|
||||
|
||||
const result: { id: string; label?: string; hasIcon?: boolean } = {
|
||||
id: String(id),
|
||||
}
|
||||
|
||||
if (typeof opt === 'object' && typeof opt.label === 'string') {
|
||||
result.label = opt.label
|
||||
}
|
||||
|
||||
if (typeof opt === 'object' && opt.icon) {
|
||||
result.hasIcon = true
|
||||
}
|
||||
|
||||
return result
|
||||
})
|
||||
.filter((o): o is { id: string; label?: string; hasIcon?: boolean } => o !== undefined)
|
||||
|
||||
return normalized.length > 0 ? normalized : undefined
|
||||
}
|
||||
|
||||
function removeNullish(obj: any): any {
|
||||
@@ -883,6 +959,9 @@ const SPECIAL_BLOCKS_METADATA: Record<string, any> = {
|
||||
- Use forEach for collection processing, for loops for fixed iterations.
|
||||
- Cannot have loops/parallels inside a loop block.
|
||||
- For yaml it needs to connect blocks inside to the start field of the block.
|
||||
- IMPORTANT for while/doWhile: The condition is evaluated BEFORE each iteration starts, so blocks INSIDE the loop cannot be referenced in the condition (their outputs don't exist yet when the condition runs).
|
||||
- For while/doWhile conditions, use: <loop.index> for iteration count, workflow variables (set by blocks OUTSIDE the loop), or references to blocks OUTSIDE the loop.
|
||||
- To break a while/doWhile loop based on internal block results, use a variables block OUTSIDE the loop and update it from inside, then reference that variable in the condition.
|
||||
`,
|
||||
inputs: {
|
||||
loopType: {
|
||||
@@ -909,7 +988,8 @@ const SPECIAL_BLOCKS_METADATA: Record<string, any> = {
|
||||
condition: {
|
||||
type: 'string',
|
||||
required: false,
|
||||
description: "Condition to evaluate (for 'while' and 'doWhile' loopType)",
|
||||
description:
|
||||
"Condition to evaluate (for 'while' and 'doWhile' loopType). IMPORTANT: Cannot reference blocks INSIDE the loop - use <loop.index>, workflow variables, or blocks OUTSIDE the loop instead.",
|
||||
example: '<loop.index> < 10',
|
||||
},
|
||||
maxConcurrency: {
|
||||
@@ -962,7 +1042,9 @@ const SPECIAL_BLOCKS_METADATA: Record<string, any> = {
|
||||
title: 'Condition',
|
||||
type: 'code',
|
||||
language: 'javascript',
|
||||
placeholder: '<counter.value> < 10',
|
||||
placeholder: '<loop.index> < 10 or <variable.variablename>',
|
||||
description:
|
||||
'Cannot reference blocks inside the loop. Use <loop.index>, workflow variables, or blocks outside the loop.',
|
||||
condition: { field: 'loopType', value: ['while', 'doWhile'] },
|
||||
},
|
||||
{
|
||||
@@ -1020,12 +1102,12 @@ const SPECIAL_BLOCKS_METADATA: Record<string, any> = {
|
||||
},
|
||||
outputs: {
|
||||
results: { type: 'array', description: 'Array of results from all parallel branches' },
|
||||
branchId: { type: 'number', description: 'Current branch ID (0-based)' },
|
||||
branchItem: {
|
||||
index: { type: 'number', description: 'Current branch index (0-based)' },
|
||||
currentItem: {
|
||||
type: 'any',
|
||||
description: 'Current item for this branch (for collection type)',
|
||||
},
|
||||
totalBranches: { type: 'number', description: 'Total number of parallel branches' },
|
||||
items: { type: 'array', description: 'All distribution items' },
|
||||
},
|
||||
subBlocks: [
|
||||
{
|
||||
|
||||
238
apps/sim/lib/copilot/tools/server/knowledge/knowledge-base.ts
Normal file
238
apps/sim/lib/copilot/tools/server/knowledge/knowledge-base.ts
Normal file
@@ -0,0 +1,238 @@
|
||||
import type { BaseServerTool } from '@/lib/copilot/tools/server/base-tool'
|
||||
import {
|
||||
type KnowledgeBaseArgs,
|
||||
KnowledgeBaseArgsSchema,
|
||||
type KnowledgeBaseResult,
|
||||
} from '@/lib/copilot/tools/shared/schemas'
|
||||
import { generateSearchEmbedding } from '@/lib/knowledge/embeddings'
|
||||
import {
|
||||
createKnowledgeBase,
|
||||
getKnowledgeBaseById,
|
||||
getKnowledgeBases,
|
||||
} from '@/lib/knowledge/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getQueryStrategy, handleVectorOnlySearch } from '@/app/api/knowledge/search/utils'
|
||||
|
||||
const logger = createLogger('KnowledgeBaseServerTool')
|
||||
|
||||
// Re-export for backwards compatibility
|
||||
export const KnowledgeBaseInput = KnowledgeBaseArgsSchema
|
||||
export type KnowledgeBaseInputType = KnowledgeBaseArgs
|
||||
export type KnowledgeBaseResultType = KnowledgeBaseResult
|
||||
|
||||
/**
|
||||
* Knowledge base tool for copilot to create, list, and get knowledge bases
|
||||
*/
|
||||
export const knowledgeBaseServerTool: BaseServerTool<KnowledgeBaseArgs, KnowledgeBaseResult> = {
|
||||
name: 'knowledge_base',
|
||||
async execute(
|
||||
params: KnowledgeBaseArgs,
|
||||
context?: { userId: string }
|
||||
): Promise<KnowledgeBaseResult> {
|
||||
if (!context?.userId) {
|
||||
logger.error('Unauthorized attempt to access knowledge base - no authenticated user context')
|
||||
throw new Error('Authentication required')
|
||||
}
|
||||
|
||||
const { operation, args = {} } = params
|
||||
|
||||
try {
|
||||
switch (operation) {
|
||||
case 'create': {
|
||||
if (!args.name) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'Name is required for creating a knowledge base',
|
||||
}
|
||||
}
|
||||
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const newKnowledgeBase = await createKnowledgeBase(
|
||||
{
|
||||
name: args.name,
|
||||
description: args.description,
|
||||
workspaceId: args.workspaceId,
|
||||
userId: context.userId,
|
||||
embeddingModel: 'text-embedding-3-small',
|
||||
embeddingDimension: 1536,
|
||||
chunkingConfig: args.chunkingConfig || {
|
||||
maxSize: 1024,
|
||||
minSize: 1,
|
||||
overlap: 200,
|
||||
},
|
||||
},
|
||||
requestId
|
||||
)
|
||||
|
||||
logger.info('Knowledge base created via copilot', {
|
||||
knowledgeBaseId: newKnowledgeBase.id,
|
||||
name: newKnowledgeBase.name,
|
||||
userId: context.userId,
|
||||
})
|
||||
|
||||
return {
|
||||
success: true,
|
||||
message: `Knowledge base "${newKnowledgeBase.name}" created successfully`,
|
||||
data: {
|
||||
id: newKnowledgeBase.id,
|
||||
name: newKnowledgeBase.name,
|
||||
description: newKnowledgeBase.description,
|
||||
workspaceId: newKnowledgeBase.workspaceId,
|
||||
docCount: newKnowledgeBase.docCount,
|
||||
createdAt: newKnowledgeBase.createdAt,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
case 'list': {
|
||||
const knowledgeBases = await getKnowledgeBases(context.userId, args.workspaceId)
|
||||
|
||||
logger.info('Knowledge bases listed via copilot', {
|
||||
count: knowledgeBases.length,
|
||||
userId: context.userId,
|
||||
workspaceId: args.workspaceId,
|
||||
})
|
||||
|
||||
return {
|
||||
success: true,
|
||||
message: `Found ${knowledgeBases.length} knowledge base(s)`,
|
||||
data: knowledgeBases.map((kb) => ({
|
||||
id: kb.id,
|
||||
name: kb.name,
|
||||
description: kb.description,
|
||||
workspaceId: kb.workspaceId,
|
||||
docCount: kb.docCount,
|
||||
tokenCount: kb.tokenCount,
|
||||
createdAt: kb.createdAt,
|
||||
updatedAt: kb.updatedAt,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
case 'get': {
|
||||
if (!args.knowledgeBaseId) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'Knowledge base ID is required for get operation',
|
||||
}
|
||||
}
|
||||
|
||||
const knowledgeBase = await getKnowledgeBaseById(args.knowledgeBaseId)
|
||||
if (!knowledgeBase) {
|
||||
return {
|
||||
success: false,
|
||||
message: `Knowledge base with ID "${args.knowledgeBaseId}" not found`,
|
||||
}
|
||||
}
|
||||
|
||||
logger.info('Knowledge base metadata retrieved via copilot', {
|
||||
knowledgeBaseId: knowledgeBase.id,
|
||||
userId: context.userId,
|
||||
})
|
||||
|
||||
return {
|
||||
success: true,
|
||||
message: `Retrieved knowledge base "${knowledgeBase.name}"`,
|
||||
data: {
|
||||
id: knowledgeBase.id,
|
||||
name: knowledgeBase.name,
|
||||
description: knowledgeBase.description,
|
||||
workspaceId: knowledgeBase.workspaceId,
|
||||
docCount: knowledgeBase.docCount,
|
||||
tokenCount: knowledgeBase.tokenCount,
|
||||
embeddingModel: knowledgeBase.embeddingModel,
|
||||
chunkingConfig: knowledgeBase.chunkingConfig,
|
||||
createdAt: knowledgeBase.createdAt,
|
||||
updatedAt: knowledgeBase.updatedAt,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
case 'query': {
|
||||
if (!args.knowledgeBaseId) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'Knowledge base ID is required for query operation',
|
||||
}
|
||||
}
|
||||
|
||||
if (!args.query) {
|
||||
return {
|
||||
success: false,
|
||||
message: 'Query text is required for query operation',
|
||||
}
|
||||
}
|
||||
|
||||
// Verify knowledge base exists
|
||||
const kb = await getKnowledgeBaseById(args.knowledgeBaseId)
|
||||
if (!kb) {
|
||||
return {
|
||||
success: false,
|
||||
message: `Knowledge base with ID "${args.knowledgeBaseId}" not found`,
|
||||
}
|
||||
}
|
||||
|
||||
const topK = args.topK || 5
|
||||
|
||||
// Generate embedding for the query
|
||||
const queryEmbedding = await generateSearchEmbedding(args.query)
|
||||
const queryVector = JSON.stringify(queryEmbedding)
|
||||
|
||||
// Get search strategy
|
||||
const strategy = getQueryStrategy(1, topK)
|
||||
|
||||
// Perform vector search
|
||||
const results = await handleVectorOnlySearch({
|
||||
knowledgeBaseIds: [args.knowledgeBaseId],
|
||||
topK,
|
||||
queryVector,
|
||||
distanceThreshold: strategy.distanceThreshold,
|
||||
})
|
||||
|
||||
logger.info('Knowledge base queried via copilot', {
|
||||
knowledgeBaseId: args.knowledgeBaseId,
|
||||
query: args.query.substring(0, 100),
|
||||
resultCount: results.length,
|
||||
userId: context.userId,
|
||||
})
|
||||
|
||||
return {
|
||||
success: true,
|
||||
message: `Found ${results.length} result(s) for query "${args.query.substring(0, 50)}${args.query.length > 50 ? '...' : ''}"`,
|
||||
data: {
|
||||
knowledgeBaseId: args.knowledgeBaseId,
|
||||
knowledgeBaseName: kb.name,
|
||||
query: args.query,
|
||||
topK,
|
||||
totalResults: results.length,
|
||||
results: results.map((result) => ({
|
||||
documentId: result.documentId,
|
||||
content: result.content,
|
||||
chunkIndex: result.chunkIndex,
|
||||
similarity: 1 - result.distance,
|
||||
})),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return {
|
||||
success: false,
|
||||
message: `Unknown operation: ${operation}. Supported operations: create, list, get, query`,
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
|
||||
logger.error('Error in knowledge_base tool', {
|
||||
operation,
|
||||
error: errorMessage,
|
||||
userId: context.userId,
|
||||
})
|
||||
|
||||
return {
|
||||
success: false,
|
||||
message: `Failed to ${operation} knowledge base: ${errorMessage}`,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -5,6 +5,10 @@ import { getTriggerBlocksServerTool } from '@/lib/copilot/tools/server/blocks/ge
|
||||
import { searchDocumentationServerTool } from '@/lib/copilot/tools/server/docs/search-documentation'
|
||||
import { listGDriveFilesServerTool } from '@/lib/copilot/tools/server/gdrive/list-files'
|
||||
import { readGDriveFileServerTool } from '@/lib/copilot/tools/server/gdrive/read-file'
|
||||
import {
|
||||
KnowledgeBaseInput,
|
||||
knowledgeBaseServerTool,
|
||||
} from '@/lib/copilot/tools/server/knowledge/knowledge-base'
|
||||
import { makeApiRequestServerTool } from '@/lib/copilot/tools/server/other/make-api-request'
|
||||
import { searchOnlineServerTool } from '@/lib/copilot/tools/server/other/search-online'
|
||||
import { getCredentialsServerTool } from '@/lib/copilot/tools/server/user/get-credentials'
|
||||
@@ -43,6 +47,7 @@ serverToolRegistry[listGDriveFilesServerTool.name] = listGDriveFilesServerTool
|
||||
serverToolRegistry[readGDriveFileServerTool.name] = readGDriveFileServerTool
|
||||
serverToolRegistry[getCredentialsServerTool.name] = getCredentialsServerTool
|
||||
serverToolRegistry[makeApiRequestServerTool.name] = makeApiRequestServerTool
|
||||
serverToolRegistry[knowledgeBaseServerTool.name] = knowledgeBaseServerTool
|
||||
|
||||
export async function routeExecution(
|
||||
toolName: string,
|
||||
@@ -74,6 +79,9 @@ export async function routeExecution(
|
||||
if (toolName === 'get_trigger_blocks') {
|
||||
args = GetTriggerBlocksInput.parse(args)
|
||||
}
|
||||
if (toolName === 'knowledge_base') {
|
||||
args = KnowledgeBaseInput.parse(args)
|
||||
}
|
||||
|
||||
const result = await tool.execute(args, context)
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ import type { BaseServerTool } from '@/lib/copilot/tools/server/base-tool'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { getEnvironmentVariableKeys } from '@/lib/environment/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getAllOAuthServices } from '@/lib/oauth/oauth'
|
||||
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
|
||||
|
||||
interface GetCredentialsParams {
|
||||
userId?: string
|
||||
workflowId?: string
|
||||
}
|
||||
|
||||
@@ -55,17 +55,27 @@ export const getCredentialsServerTool: BaseServerTool<GetCredentialsParams, any>
|
||||
.limit(1)
|
||||
const userEmail = userRecord.length > 0 ? userRecord[0]?.email : null
|
||||
|
||||
const oauthCredentials: Array<{
|
||||
// Get all available OAuth services
|
||||
const allOAuthServices = getAllOAuthServices()
|
||||
|
||||
// Track connected provider IDs
|
||||
const connectedProviderIds = new Set<string>()
|
||||
|
||||
const connectedCredentials: Array<{
|
||||
id: string
|
||||
name: string
|
||||
provider: string
|
||||
serviceName: string
|
||||
lastUsed: string
|
||||
isDefault: boolean
|
||||
accessToken: string | null
|
||||
}> = []
|
||||
const requestId = generateRequestId()
|
||||
|
||||
for (const acc of accounts) {
|
||||
const providerId = acc.providerId
|
||||
connectedProviderIds.add(providerId)
|
||||
|
||||
const [baseProvider, featureType = 'default'] = providerId.split('-')
|
||||
let displayName = ''
|
||||
if (acc.idToken) {
|
||||
@@ -77,6 +87,11 @@ export const getCredentialsServerTool: BaseServerTool<GetCredentialsParams, any>
|
||||
if (!displayName && baseProvider === 'github') displayName = `${acc.accountId} (GitHub)`
|
||||
if (!displayName && userEmail) displayName = userEmail
|
||||
if (!displayName) displayName = `${acc.accountId} (${baseProvider})`
|
||||
|
||||
// Find the service name for this provider ID
|
||||
const service = allOAuthServices.find((s) => s.providerId === providerId)
|
||||
const serviceName = service?.name ?? providerId
|
||||
|
||||
let accessToken: string | null = acc.accessToken ?? null
|
||||
try {
|
||||
const { accessToken: refreshedToken } = await refreshTokenIfNeeded(
|
||||
@@ -86,29 +101,47 @@ export const getCredentialsServerTool: BaseServerTool<GetCredentialsParams, any>
|
||||
)
|
||||
accessToken = refreshedToken || accessToken
|
||||
} catch {}
|
||||
oauthCredentials.push({
|
||||
connectedCredentials.push({
|
||||
id: acc.id,
|
||||
name: displayName,
|
||||
provider: providerId,
|
||||
serviceName,
|
||||
lastUsed: acc.updatedAt.toISOString(),
|
||||
isDefault: featureType === 'default',
|
||||
accessToken,
|
||||
})
|
||||
}
|
||||
|
||||
// Build list of not connected services
|
||||
const notConnectedServices = allOAuthServices
|
||||
.filter((service) => !connectedProviderIds.has(service.providerId))
|
||||
.map((service) => ({
|
||||
providerId: service.providerId,
|
||||
name: service.name,
|
||||
description: service.description,
|
||||
baseProvider: service.baseProvider,
|
||||
}))
|
||||
|
||||
// Fetch environment variables
|
||||
const envResult = await getEnvironmentVariableKeys(userId)
|
||||
|
||||
logger.info('Fetched credentials', {
|
||||
userId,
|
||||
oauthCount: oauthCredentials.length,
|
||||
connectedCount: connectedCredentials.length,
|
||||
notConnectedCount: notConnectedServices.length,
|
||||
envVarCount: envResult.count,
|
||||
})
|
||||
|
||||
return {
|
||||
oauth: {
|
||||
credentials: oauthCredentials,
|
||||
total: oauthCredentials.length,
|
||||
connected: {
|
||||
credentials: connectedCredentials,
|
||||
total: connectedCredentials.length,
|
||||
},
|
||||
notConnected: {
|
||||
services: notConnectedServices,
|
||||
total: notConnectedServices.length,
|
||||
},
|
||||
},
|
||||
environment: {
|
||||
variableNames: envResult.variableNames,
|
||||
|
||||
@@ -8,10 +8,372 @@ import { getBlockOutputs } from '@/lib/workflows/blocks/block-outputs'
|
||||
import { extractAndPersistCustomTools } from '@/lib/workflows/persistence/custom-tools-persistence'
|
||||
import { loadWorkflowFromNormalizedTables } from '@/lib/workflows/persistence/utils'
|
||||
import { validateWorkflowState } from '@/lib/workflows/sanitization/validation'
|
||||
import { getAllBlocks } from '@/blocks/registry'
|
||||
import { getAllBlocks, getBlock } from '@/blocks/registry'
|
||||
import type { SubBlockConfig } from '@/blocks/types'
|
||||
import { generateLoopBlocks, generateParallelBlocks } from '@/stores/workflows/workflow/utils'
|
||||
import { TRIGGER_RUNTIME_SUBBLOCK_IDS } from '@/triggers/consts'
|
||||
|
||||
const validationLogger = createLogger('EditWorkflowValidation')
|
||||
|
||||
/**
|
||||
* Validation error for a specific field
|
||||
*/
|
||||
interface ValidationError {
|
||||
blockId: string
|
||||
blockType: string
|
||||
field: string
|
||||
value: any
|
||||
error: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of input validation
|
||||
*/
|
||||
interface ValidationResult {
|
||||
validInputs: Record<string, any>
|
||||
errors: ValidationError[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates and filters inputs against a block's subBlock configuration
|
||||
* Returns valid inputs and any validation errors encountered
|
||||
*/
|
||||
function validateInputsForBlock(
|
||||
blockType: string,
|
||||
inputs: Record<string, any>,
|
||||
blockId: string,
|
||||
existingInputs?: Record<string, any>
|
||||
): ValidationResult {
|
||||
const errors: ValidationError[] = []
|
||||
const blockConfig = getBlock(blockType)
|
||||
|
||||
if (!blockConfig) {
|
||||
// Unknown block type - return inputs as-is (let it fail later if invalid)
|
||||
validationLogger.warn(`Unknown block type: ${blockType}, skipping validation`)
|
||||
return { validInputs: inputs, errors: [] }
|
||||
}
|
||||
|
||||
const validatedInputs: Record<string, any> = {}
|
||||
const subBlockMap = new Map<string, SubBlockConfig>()
|
||||
|
||||
// Build map of subBlock id -> config
|
||||
for (const subBlock of blockConfig.subBlocks) {
|
||||
subBlockMap.set(subBlock.id, subBlock)
|
||||
}
|
||||
|
||||
// Merge existing inputs with new inputs to evaluate conditions properly
|
||||
const mergedInputs = { ...existingInputs, ...inputs }
|
||||
|
||||
for (const [key, value] of Object.entries(inputs)) {
|
||||
// Skip runtime subblock IDs
|
||||
if (TRIGGER_RUNTIME_SUBBLOCK_IDS.includes(key)) {
|
||||
continue
|
||||
}
|
||||
|
||||
const subBlockConfig = subBlockMap.get(key)
|
||||
|
||||
// If subBlock doesn't exist in config, skip it (unless it's a known dynamic field)
|
||||
if (!subBlockConfig) {
|
||||
// Some fields are valid but not in subBlocks (like loop/parallel config)
|
||||
// Allow these through for special block types
|
||||
if (blockType === 'loop' || blockType === 'parallel') {
|
||||
validatedInputs[key] = value
|
||||
} else {
|
||||
errors.push({
|
||||
blockId,
|
||||
blockType,
|
||||
field: key,
|
||||
value,
|
||||
error: `Unknown input field "${key}" for block type "${blockType}"`,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if the field's condition is met
|
||||
if (subBlockConfig.condition && !evaluateCondition(subBlockConfig.condition, mergedInputs)) {
|
||||
errors.push({
|
||||
blockId,
|
||||
blockType,
|
||||
field: key,
|
||||
value,
|
||||
error: `Field "${key}" condition not met - this field is not applicable for the current configuration`,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate value based on subBlock type
|
||||
const validationResult = validateValueForSubBlockType(
|
||||
subBlockConfig,
|
||||
value,
|
||||
key,
|
||||
blockType,
|
||||
blockId
|
||||
)
|
||||
if (validationResult.valid) {
|
||||
validatedInputs[key] = validationResult.value
|
||||
} else if (validationResult.error) {
|
||||
errors.push(validationResult.error)
|
||||
}
|
||||
}
|
||||
|
||||
return { validInputs: validatedInputs, errors }
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates a condition object against current inputs
|
||||
*/
|
||||
function evaluateCondition(
|
||||
condition: SubBlockConfig['condition'],
|
||||
inputs: Record<string, any>
|
||||
): boolean {
|
||||
if (!condition) return true
|
||||
|
||||
// Handle function conditions
|
||||
const resolvedCondition = typeof condition === 'function' ? condition() : condition
|
||||
|
||||
const fieldValue = inputs[resolvedCondition.field]
|
||||
const expectedValues = Array.isArray(resolvedCondition.value)
|
||||
? resolvedCondition.value
|
||||
: [resolvedCondition.value]
|
||||
|
||||
let matches = expectedValues.includes(fieldValue)
|
||||
if (resolvedCondition.not) {
|
||||
matches = !matches
|
||||
}
|
||||
|
||||
// Handle AND condition
|
||||
if (matches && resolvedCondition.and) {
|
||||
const andFieldValue = inputs[resolvedCondition.and.field]
|
||||
const andExpectedValues = Array.isArray(resolvedCondition.and.value)
|
||||
? resolvedCondition.and.value
|
||||
: [resolvedCondition.and.value]
|
||||
let andMatches = andExpectedValues.includes(andFieldValue)
|
||||
if (resolvedCondition.and.not) {
|
||||
andMatches = !andMatches
|
||||
}
|
||||
matches = matches && andMatches
|
||||
}
|
||||
|
||||
return matches
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of validating a single value
|
||||
*/
|
||||
interface ValueValidationResult {
|
||||
valid: boolean
|
||||
value?: any
|
||||
error?: ValidationError
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates a value against its expected subBlock type
|
||||
* Returns validation result with the value or an error
|
||||
*/
|
||||
function validateValueForSubBlockType(
|
||||
subBlockConfig: SubBlockConfig,
|
||||
value: any,
|
||||
fieldName: string,
|
||||
blockType: string,
|
||||
blockId: string
|
||||
): ValueValidationResult {
|
||||
const { type } = subBlockConfig
|
||||
|
||||
// Handle null/undefined - allow clearing fields
|
||||
if (value === null || value === undefined) {
|
||||
return { valid: true, value }
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case 'dropdown': {
|
||||
// Validate against allowed options
|
||||
const options =
|
||||
typeof subBlockConfig.options === 'function'
|
||||
? subBlockConfig.options()
|
||||
: subBlockConfig.options
|
||||
if (options && Array.isArray(options)) {
|
||||
const validIds = options.map((opt) => opt.id)
|
||||
if (!validIds.includes(value)) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
blockId,
|
||||
blockType,
|
||||
field: fieldName,
|
||||
value,
|
||||
error: `Invalid dropdown value "${value}" for field "${fieldName}". Valid options: ${validIds.join(', ')}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
return { valid: true, value }
|
||||
}
|
||||
|
||||
case 'slider': {
|
||||
// Validate numeric range
|
||||
const numValue = typeof value === 'number' ? value : Number(value)
|
||||
if (Number.isNaN(numValue)) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
blockId,
|
||||
blockType,
|
||||
field: fieldName,
|
||||
value,
|
||||
error: `Invalid slider value "${value}" for field "${fieldName}" - must be a number`,
|
||||
},
|
||||
}
|
||||
}
|
||||
// Clamp to range (allow but warn)
|
||||
let clampedValue = numValue
|
||||
if (subBlockConfig.min !== undefined && numValue < subBlockConfig.min) {
|
||||
clampedValue = subBlockConfig.min
|
||||
}
|
||||
if (subBlockConfig.max !== undefined && numValue > subBlockConfig.max) {
|
||||
clampedValue = subBlockConfig.max
|
||||
}
|
||||
return {
|
||||
valid: true,
|
||||
value: subBlockConfig.integer ? Math.round(clampedValue) : clampedValue,
|
||||
}
|
||||
}
|
||||
|
||||
case 'switch': {
|
||||
// Must be boolean
|
||||
if (typeof value !== 'boolean') {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
blockId,
|
||||
blockType,
|
||||
field: fieldName,
|
||||
value,
|
||||
error: `Invalid switch value "${value}" for field "${fieldName}" - must be true or false`,
|
||||
},
|
||||
}
|
||||
}
|
||||
return { valid: true, value }
|
||||
}
|
||||
|
||||
case 'file-upload': {
|
||||
// File upload should be an object with specific properties or null
|
||||
if (value === null) return { valid: true, value: null }
|
||||
if (typeof value !== 'object') {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
blockId,
|
||||
blockType,
|
||||
field: fieldName,
|
||||
value,
|
||||
error: `Invalid file-upload value for field "${fieldName}" - expected object with name and path properties, or null`,
|
||||
},
|
||||
}
|
||||
}
|
||||
// Validate file object has required properties
|
||||
if (value && (!value.name || !value.path)) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
blockId,
|
||||
blockType,
|
||||
field: fieldName,
|
||||
value,
|
||||
error: `Invalid file-upload object for field "${fieldName}" - must have "name" and "path" properties`,
|
||||
},
|
||||
}
|
||||
}
|
||||
return { valid: true, value }
|
||||
}
|
||||
|
||||
case 'input-format':
|
||||
case 'table': {
|
||||
// Should be an array
|
||||
if (!Array.isArray(value)) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
blockId,
|
||||
blockType,
|
||||
field: fieldName,
|
||||
value,
|
||||
error: `Invalid ${type} value for field "${fieldName}" - expected an array`,
|
||||
},
|
||||
}
|
||||
}
|
||||
return { valid: true, value }
|
||||
}
|
||||
|
||||
case 'tool-input': {
|
||||
// Should be an array of tool objects
|
||||
if (!Array.isArray(value)) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
blockId,
|
||||
blockType,
|
||||
field: fieldName,
|
||||
value,
|
||||
error: `Invalid tool-input value for field "${fieldName}" - expected an array of tool objects`,
|
||||
},
|
||||
}
|
||||
}
|
||||
return { valid: true, value }
|
||||
}
|
||||
|
||||
case 'response-format':
|
||||
case 'code': {
|
||||
// Can be string or object
|
||||
return { valid: true, value }
|
||||
}
|
||||
|
||||
case 'short-input':
|
||||
case 'long-input':
|
||||
case 'combobox': {
|
||||
// Should be string (combobox allows custom values)
|
||||
if (typeof value !== 'string' && typeof value !== 'number') {
|
||||
// Convert to string but don't error
|
||||
return { valid: true, value: String(value) }
|
||||
}
|
||||
return { valid: true, value }
|
||||
}
|
||||
|
||||
// Selector types - allow strings (IDs) or arrays of strings
|
||||
case 'oauth-input':
|
||||
case 'knowledge-base-selector':
|
||||
case 'document-selector':
|
||||
case 'file-selector':
|
||||
case 'project-selector':
|
||||
case 'channel-selector':
|
||||
case 'folder-selector':
|
||||
case 'mcp-server-selector':
|
||||
case 'mcp-tool-selector':
|
||||
case 'workflow-selector': {
|
||||
if (subBlockConfig.multiSelect && Array.isArray(value)) {
|
||||
return { valid: true, value }
|
||||
}
|
||||
if (typeof value === 'string') {
|
||||
return { valid: true, value }
|
||||
}
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
blockId,
|
||||
blockType,
|
||||
field: fieldName,
|
||||
value,
|
||||
error: `Invalid selector value for field "${fieldName}" - expected a string${subBlockConfig.multiSelect ? ' or array of strings' : ''}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// For unknown types, pass through
|
||||
return { valid: true, value }
|
||||
}
|
||||
}
|
||||
|
||||
interface EditWorkflowOperation {
|
||||
operation_type: 'add' | 'edit' | 'delete' | 'insert_into_subflow' | 'extract_from_subflow'
|
||||
block_id: string
|
||||
@@ -118,9 +480,24 @@ function topologicalSortInserts(
|
||||
/**
|
||||
* Helper to create a block state from operation params
|
||||
*/
|
||||
function createBlockFromParams(blockId: string, params: any, parentId?: string): any {
|
||||
function createBlockFromParams(
|
||||
blockId: string,
|
||||
params: any,
|
||||
parentId?: string,
|
||||
errorsCollector?: ValidationError[]
|
||||
): any {
|
||||
const blockConfig = getAllBlocks().find((b) => b.type === params.type)
|
||||
|
||||
// Validate inputs against block configuration
|
||||
let validatedInputs: Record<string, any> | undefined
|
||||
if (params.inputs) {
|
||||
const result = validateInputsForBlock(params.type, params.inputs, blockId)
|
||||
validatedInputs = result.validInputs
|
||||
if (errorsCollector && result.errors.length > 0) {
|
||||
errorsCollector.push(...result.errors)
|
||||
}
|
||||
}
|
||||
|
||||
// Determine outputs based on trigger mode
|
||||
const triggerMode = params.triggerMode || false
|
||||
let outputs: Record<string, any>
|
||||
@@ -129,8 +506,8 @@ function createBlockFromParams(blockId: string, params: any, parentId?: string):
|
||||
outputs = params.outputs
|
||||
} else if (blockConfig) {
|
||||
const subBlocks: Record<string, any> = {}
|
||||
if (params.inputs) {
|
||||
Object.entries(params.inputs).forEach(([key, value]) => {
|
||||
if (validatedInputs) {
|
||||
Object.entries(validatedInputs).forEach(([key, value]) => {
|
||||
// Skip runtime subblock IDs when computing outputs
|
||||
if (TRIGGER_RUNTIME_SUBBLOCK_IDS.includes(key)) {
|
||||
return
|
||||
@@ -158,9 +535,9 @@ function createBlockFromParams(blockId: string, params: any, parentId?: string):
|
||||
data: parentId ? { parentId, extent: 'parent' as const } : {},
|
||||
}
|
||||
|
||||
// Add inputs as subBlocks
|
||||
if (params.inputs) {
|
||||
Object.entries(params.inputs).forEach(([key, value]) => {
|
||||
// Add validated inputs as subBlocks
|
||||
if (validatedInputs) {
|
||||
Object.entries(validatedInputs).forEach(([key, value]) => {
|
||||
if (TRIGGER_RUNTIME_SUBBLOCK_IDS.includes(key)) {
|
||||
return
|
||||
}
|
||||
@@ -299,11 +676,24 @@ function normalizeResponseFormat(value: any): string {
|
||||
function addConnectionsAsEdges(
|
||||
modifiedState: any,
|
||||
blockId: string,
|
||||
connections: Record<string, any>
|
||||
connections: Record<string, any>,
|
||||
logger: ReturnType<typeof createLogger>
|
||||
): void {
|
||||
Object.entries(connections).forEach(([sourceHandle, targets]) => {
|
||||
const targetArray = Array.isArray(targets) ? targets : [targets]
|
||||
targetArray.forEach((targetId: string) => {
|
||||
// Validate target block exists (should always be true due to operation ordering)
|
||||
if (!modifiedState.blocks[targetId]) {
|
||||
logger.warn(
|
||||
`Target block "${targetId}" not found when creating connection from "${blockId}". ` +
|
||||
`This may indicate operations were processed in wrong order.`,
|
||||
{
|
||||
sourceBlockId: blockId,
|
||||
targetBlockId: targetId,
|
||||
existingBlocks: Object.keys(modifiedState.blocks),
|
||||
}
|
||||
)
|
||||
}
|
||||
modifiedState.edges.push({
|
||||
id: crypto.randomUUID(),
|
||||
source: blockId,
|
||||
@@ -348,16 +738,27 @@ function applyTriggerConfigToBlockSubblocks(block: any, triggerConfig: Record<st
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of applying operations to workflow state
|
||||
*/
|
||||
interface ApplyOperationsResult {
|
||||
state: any
|
||||
validationErrors: ValidationError[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply operations directly to the workflow JSON state
|
||||
*/
|
||||
function applyOperationsToWorkflowState(
|
||||
workflowState: any,
|
||||
operations: EditWorkflowOperation[]
|
||||
): any {
|
||||
): ApplyOperationsResult {
|
||||
// Deep clone the workflow state to avoid mutations
|
||||
const modifiedState = JSON.parse(JSON.stringify(workflowState))
|
||||
|
||||
// Collect validation errors across all operations
|
||||
const validationErrors: ValidationError[] = []
|
||||
|
||||
// Log initial state
|
||||
const logger = createLogger('EditWorkflowServerTool')
|
||||
logger.info('Applying operations to workflow:', {
|
||||
@@ -369,7 +770,18 @@ function applyOperationsToWorkflowState(
|
||||
initialBlockCount: Object.keys(modifiedState.blocks || {}).length,
|
||||
})
|
||||
|
||||
// Reorder operations: delete -> extract -> add -> insert -> edit
|
||||
/**
|
||||
* Reorder operations to ensure correct execution sequence:
|
||||
* 1. delete - Remove blocks first to free up IDs and clean state
|
||||
* 2. extract_from_subflow - Extract blocks from subflows before modifications
|
||||
* 3. add - Create new blocks so they exist before being referenced
|
||||
* 4. insert_into_subflow - Insert blocks into subflows (sorted by parent dependency)
|
||||
* 5. edit - Edit existing blocks last, so connections to newly added blocks work
|
||||
*
|
||||
* This ordering is CRITICAL: edit operations may reference blocks being added
|
||||
* in the same batch (e.g., connecting block A to newly added block B).
|
||||
* Without proper ordering, the target block wouldn't exist yet.
|
||||
*/
|
||||
const deletes = operations.filter((op) => op.operation_type === 'delete')
|
||||
const extracts = operations.filter((op) => op.operation_type === 'extract_from_subflow')
|
||||
const adds = operations.filter((op) => op.operation_type === 'add')
|
||||
@@ -445,7 +857,33 @@ function applyOperationsToWorkflowState(
|
||||
// Update inputs (convert to subBlocks format)
|
||||
if (params?.inputs) {
|
||||
if (!block.subBlocks) block.subBlocks = {}
|
||||
Object.entries(params.inputs).forEach(([key, value]) => {
|
||||
|
||||
// Get existing input values for condition evaluation
|
||||
const existingInputs: Record<string, any> = {}
|
||||
Object.entries(block.subBlocks).forEach(([key, subBlock]: [string, any]) => {
|
||||
existingInputs[key] = subBlock?.value
|
||||
})
|
||||
|
||||
// Validate inputs against block configuration
|
||||
const validationResult = validateInputsForBlock(
|
||||
block.type,
|
||||
params.inputs,
|
||||
block_id,
|
||||
existingInputs
|
||||
)
|
||||
validationErrors.push(...validationResult.errors)
|
||||
|
||||
Object.entries(validationResult.validInputs).forEach(([inputKey, value]) => {
|
||||
// Normalize common field name variations (LLM may use plural/singular inconsistently)
|
||||
let key = inputKey
|
||||
if (
|
||||
key === 'credentials' &&
|
||||
!block.subBlocks.credentials &&
|
||||
block.subBlocks.credential
|
||||
) {
|
||||
key = 'credential'
|
||||
}
|
||||
|
||||
if (TRIGGER_RUNTIME_SUBBLOCK_IDS.includes(key)) {
|
||||
return
|
||||
}
|
||||
@@ -496,23 +934,58 @@ function applyOperationsToWorkflowState(
|
||||
applyTriggerConfigToBlockSubblocks(block, block.subBlocks.triggerConfig.value)
|
||||
}
|
||||
|
||||
// Update loop/parallel configuration in block.data
|
||||
// Update loop/parallel configuration in block.data (strict validation)
|
||||
if (block.type === 'loop') {
|
||||
block.data = block.data || {}
|
||||
if (params.inputs.loopType !== undefined) block.data.loopType = params.inputs.loopType
|
||||
if (params.inputs.iterations !== undefined)
|
||||
// loopType is always valid
|
||||
if (params.inputs.loopType !== undefined) {
|
||||
const validLoopTypes = ['for', 'forEach', 'while', 'doWhile']
|
||||
if (validLoopTypes.includes(params.inputs.loopType)) {
|
||||
block.data.loopType = params.inputs.loopType
|
||||
}
|
||||
}
|
||||
const effectiveLoopType = params.inputs.loopType ?? block.data.loopType ?? 'for'
|
||||
// iterations only valid for 'for' loopType
|
||||
if (params.inputs.iterations !== undefined && effectiveLoopType === 'for') {
|
||||
block.data.count = params.inputs.iterations
|
||||
if (params.inputs.collection !== undefined)
|
||||
}
|
||||
// collection only valid for 'forEach' loopType
|
||||
if (params.inputs.collection !== undefined && effectiveLoopType === 'forEach') {
|
||||
block.data.collection = params.inputs.collection
|
||||
if (params.inputs.condition !== undefined)
|
||||
block.data.whileCondition = params.inputs.condition
|
||||
}
|
||||
// condition only valid for 'while' or 'doWhile' loopType
|
||||
if (
|
||||
params.inputs.condition !== undefined &&
|
||||
(effectiveLoopType === 'while' || effectiveLoopType === 'doWhile')
|
||||
) {
|
||||
if (effectiveLoopType === 'doWhile') {
|
||||
block.data.doWhileCondition = params.inputs.condition
|
||||
} else {
|
||||
block.data.whileCondition = params.inputs.condition
|
||||
}
|
||||
}
|
||||
} else if (block.type === 'parallel') {
|
||||
block.data = block.data || {}
|
||||
if (params.inputs.parallelType !== undefined)
|
||||
block.data.parallelType = params.inputs.parallelType
|
||||
if (params.inputs.count !== undefined) block.data.count = params.inputs.count
|
||||
if (params.inputs.collection !== undefined)
|
||||
// parallelType is always valid
|
||||
if (params.inputs.parallelType !== undefined) {
|
||||
const validParallelTypes = ['count', 'collection']
|
||||
if (validParallelTypes.includes(params.inputs.parallelType)) {
|
||||
block.data.parallelType = params.inputs.parallelType
|
||||
}
|
||||
}
|
||||
const effectiveParallelType =
|
||||
params.inputs.parallelType ?? block.data.parallelType ?? 'count'
|
||||
// count only valid for 'count' parallelType
|
||||
if (params.inputs.count !== undefined && effectiveParallelType === 'count') {
|
||||
block.data.count = params.inputs.count
|
||||
}
|
||||
// collection only valid for 'collection' parallelType
|
||||
if (
|
||||
params.inputs.collection !== undefined &&
|
||||
effectiveParallelType === 'collection'
|
||||
) {
|
||||
block.data.collection = params.inputs.collection
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -553,27 +1026,69 @@ function applyOperationsToWorkflowState(
|
||||
|
||||
// Add new nested blocks
|
||||
Object.entries(params.nestedNodes).forEach(([childId, childBlock]: [string, any]) => {
|
||||
const childBlockState = createBlockFromParams(childId, childBlock, block_id)
|
||||
const childBlockState = createBlockFromParams(
|
||||
childId,
|
||||
childBlock,
|
||||
block_id,
|
||||
validationErrors
|
||||
)
|
||||
modifiedState.blocks[childId] = childBlockState
|
||||
|
||||
// Add connections for child block
|
||||
if (childBlock.connections) {
|
||||
addConnectionsAsEdges(modifiedState, childId, childBlock.connections)
|
||||
addConnectionsAsEdges(modifiedState, childId, childBlock.connections, logger)
|
||||
}
|
||||
})
|
||||
|
||||
// Update loop/parallel configuration based on type
|
||||
// Update loop/parallel configuration based on type (strict validation)
|
||||
if (block.type === 'loop') {
|
||||
block.data = block.data || {}
|
||||
if (params.inputs?.loopType) block.data.loopType = params.inputs.loopType
|
||||
if (params.inputs?.iterations) block.data.count = params.inputs.iterations
|
||||
if (params.inputs?.collection) block.data.collection = params.inputs.collection
|
||||
if (params.inputs?.condition) block.data.whileCondition = params.inputs.condition
|
||||
// loopType is always valid
|
||||
if (params.inputs?.loopType) {
|
||||
const validLoopTypes = ['for', 'forEach', 'while', 'doWhile']
|
||||
if (validLoopTypes.includes(params.inputs.loopType)) {
|
||||
block.data.loopType = params.inputs.loopType
|
||||
}
|
||||
}
|
||||
const effectiveLoopType = params.inputs?.loopType ?? block.data.loopType ?? 'for'
|
||||
// iterations only valid for 'for' loopType
|
||||
if (params.inputs?.iterations && effectiveLoopType === 'for') {
|
||||
block.data.count = params.inputs.iterations
|
||||
}
|
||||
// collection only valid for 'forEach' loopType
|
||||
if (params.inputs?.collection && effectiveLoopType === 'forEach') {
|
||||
block.data.collection = params.inputs.collection
|
||||
}
|
||||
// condition only valid for 'while' or 'doWhile' loopType
|
||||
if (
|
||||
params.inputs?.condition &&
|
||||
(effectiveLoopType === 'while' || effectiveLoopType === 'doWhile')
|
||||
) {
|
||||
if (effectiveLoopType === 'doWhile') {
|
||||
block.data.doWhileCondition = params.inputs.condition
|
||||
} else {
|
||||
block.data.whileCondition = params.inputs.condition
|
||||
}
|
||||
}
|
||||
} else if (block.type === 'parallel') {
|
||||
block.data = block.data || {}
|
||||
if (params.inputs?.parallelType) block.data.parallelType = params.inputs.parallelType
|
||||
if (params.inputs?.count) block.data.count = params.inputs.count
|
||||
if (params.inputs?.collection) block.data.collection = params.inputs.collection
|
||||
// parallelType is always valid
|
||||
if (params.inputs?.parallelType) {
|
||||
const validParallelTypes = ['count', 'collection']
|
||||
if (validParallelTypes.includes(params.inputs.parallelType)) {
|
||||
block.data.parallelType = params.inputs.parallelType
|
||||
}
|
||||
}
|
||||
const effectiveParallelType =
|
||||
params.inputs?.parallelType ?? block.data.parallelType ?? 'count'
|
||||
// count only valid for 'count' parallelType
|
||||
if (params.inputs?.count && effectiveParallelType === 'count') {
|
||||
block.data.count = params.inputs.count
|
||||
}
|
||||
// collection only valid for 'collection' parallelType
|
||||
if (params.inputs?.collection && effectiveParallelType === 'collection') {
|
||||
block.data.collection = params.inputs.collection
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -600,6 +1115,18 @@ function applyOperationsToWorkflowState(
|
||||
const actualSourceHandle = mapConnectionTypeToHandle(connectionType)
|
||||
|
||||
const addEdge = (targetBlock: string, targetHandle?: string) => {
|
||||
// Validate target block exists (should always be true due to operation ordering)
|
||||
if (!modifiedState.blocks[targetBlock]) {
|
||||
logger.warn(
|
||||
`Target block "${targetBlock}" not found when creating connection from "${block_id}". ` +
|
||||
`This may indicate operations were processed in wrong order.`,
|
||||
{
|
||||
sourceBlockId: block_id,
|
||||
targetBlockId: targetBlock,
|
||||
existingBlocks: Object.keys(modifiedState.blocks),
|
||||
}
|
||||
)
|
||||
}
|
||||
modifiedState.edges.push({
|
||||
id: crypto.randomUUID(),
|
||||
source: block_id,
|
||||
@@ -646,23 +1173,44 @@ function applyOperationsToWorkflowState(
|
||||
case 'add': {
|
||||
if (params?.type && params?.name) {
|
||||
// Create new block with proper structure
|
||||
const newBlock = createBlockFromParams(block_id, params)
|
||||
const newBlock = createBlockFromParams(block_id, params, undefined, validationErrors)
|
||||
|
||||
// Set loop/parallel data on parent block BEFORE adding to blocks
|
||||
// Set loop/parallel data on parent block BEFORE adding to blocks (strict validation)
|
||||
if (params.nestedNodes) {
|
||||
if (params.type === 'loop') {
|
||||
const validLoopTypes = ['for', 'forEach', 'while', 'doWhile']
|
||||
const loopType =
|
||||
params.inputs?.loopType && validLoopTypes.includes(params.inputs.loopType)
|
||||
? params.inputs.loopType
|
||||
: 'for'
|
||||
newBlock.data = {
|
||||
...newBlock.data,
|
||||
loopType: params.inputs?.loopType || 'for',
|
||||
...(params.inputs?.collection && { collection: params.inputs.collection }),
|
||||
...(params.inputs?.iterations && { count: params.inputs.iterations }),
|
||||
loopType,
|
||||
// Only include type-appropriate fields
|
||||
...(loopType === 'forEach' &&
|
||||
params.inputs?.collection && { collection: params.inputs.collection }),
|
||||
...(loopType === 'for' &&
|
||||
params.inputs?.iterations && { count: params.inputs.iterations }),
|
||||
...(loopType === 'while' &&
|
||||
params.inputs?.condition && { whileCondition: params.inputs.condition }),
|
||||
...(loopType === 'doWhile' &&
|
||||
params.inputs?.condition && { doWhileCondition: params.inputs.condition }),
|
||||
}
|
||||
} else if (params.type === 'parallel') {
|
||||
const validParallelTypes = ['count', 'collection']
|
||||
const parallelType =
|
||||
params.inputs?.parallelType &&
|
||||
validParallelTypes.includes(params.inputs.parallelType)
|
||||
? params.inputs.parallelType
|
||||
: 'count'
|
||||
newBlock.data = {
|
||||
...newBlock.data,
|
||||
parallelType: params.inputs?.parallelType || 'count',
|
||||
...(params.inputs?.collection && { collection: params.inputs.collection }),
|
||||
...(params.inputs?.count && { count: params.inputs.count }),
|
||||
parallelType,
|
||||
// Only include type-appropriate fields
|
||||
...(parallelType === 'collection' &&
|
||||
params.inputs?.collection && { collection: params.inputs.collection }),
|
||||
...(parallelType === 'count' &&
|
||||
params.inputs?.count && { count: params.inputs.count }),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -674,18 +1222,23 @@ function applyOperationsToWorkflowState(
|
||||
// Handle nested nodes (for loops/parallels created from scratch)
|
||||
if (params.nestedNodes) {
|
||||
Object.entries(params.nestedNodes).forEach(([childId, childBlock]: [string, any]) => {
|
||||
const childBlockState = createBlockFromParams(childId, childBlock, block_id)
|
||||
const childBlockState = createBlockFromParams(
|
||||
childId,
|
||||
childBlock,
|
||||
block_id,
|
||||
validationErrors
|
||||
)
|
||||
modifiedState.blocks[childId] = childBlockState
|
||||
|
||||
if (childBlock.connections) {
|
||||
addConnectionsAsEdges(modifiedState, childId, childBlock.connections)
|
||||
addConnectionsAsEdges(modifiedState, childId, childBlock.connections, logger)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Add connections as edges
|
||||
if (params.connections) {
|
||||
addConnectionsAsEdges(modifiedState, block_id, params.connections)
|
||||
addConnectionsAsEdges(modifiedState, block_id, params.connections, logger)
|
||||
}
|
||||
}
|
||||
break
|
||||
@@ -734,9 +1287,26 @@ function applyOperationsToWorkflowState(
|
||||
extent: 'parent' as const,
|
||||
}
|
||||
|
||||
// Update inputs if provided
|
||||
// Update inputs if provided (with validation)
|
||||
if (params.inputs) {
|
||||
Object.entries(params.inputs).forEach(([key, value]) => {
|
||||
// Get existing input values for condition evaluation
|
||||
const existingInputs: Record<string, any> = {}
|
||||
Object.entries(existingBlock.subBlocks || {}).forEach(
|
||||
([key, subBlock]: [string, any]) => {
|
||||
existingInputs[key] = subBlock?.value
|
||||
}
|
||||
)
|
||||
|
||||
// Validate inputs against block configuration
|
||||
const validationResult = validateInputsForBlock(
|
||||
existingBlock.type,
|
||||
params.inputs,
|
||||
block_id,
|
||||
existingInputs
|
||||
)
|
||||
validationErrors.push(...validationResult.errors)
|
||||
|
||||
Object.entries(validationResult.validInputs).forEach(([key, value]) => {
|
||||
// Skip runtime subblock IDs (webhookId, triggerPath, testUrl, testUrlExpiresAt, scheduleId)
|
||||
if (TRIGGER_RUNTIME_SUBBLOCK_IDS.includes(key)) {
|
||||
return
|
||||
@@ -773,7 +1343,7 @@ function applyOperationsToWorkflowState(
|
||||
}
|
||||
} else {
|
||||
// Create new block as child of subflow
|
||||
const newBlock = createBlockFromParams(block_id, params, subflowId)
|
||||
const newBlock = createBlockFromParams(block_id, params, subflowId, validationErrors)
|
||||
modifiedState.blocks[block_id] = newBlock
|
||||
}
|
||||
|
||||
@@ -783,7 +1353,7 @@ function applyOperationsToWorkflowState(
|
||||
modifiedState.edges = modifiedState.edges.filter((edge: any) => edge.source !== block_id)
|
||||
|
||||
// Add new connections
|
||||
addConnectionsAsEdges(modifiedState, block_id, params.connections)
|
||||
addConnectionsAsEdges(modifiedState, block_id, params.connections, logger)
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -854,7 +1424,7 @@ function applyOperationsToWorkflowState(
|
||||
)
|
||||
}
|
||||
|
||||
return modifiedState
|
||||
return { state: modifiedState, validationErrors }
|
||||
}
|
||||
|
||||
async function getCurrentWorkflowStateFromDb(
|
||||
@@ -937,7 +1507,10 @@ export const editWorkflowServerTool: BaseServerTool<EditWorkflowParams, any> = {
|
||||
}
|
||||
|
||||
// Apply operations directly to the workflow state
|
||||
const modifiedWorkflowState = applyOperationsToWorkflowState(workflowState, operations)
|
||||
const { state: modifiedWorkflowState, validationErrors } = applyOperationsToWorkflowState(
|
||||
workflowState,
|
||||
operations
|
||||
)
|
||||
|
||||
// Validate the workflow state
|
||||
const validation = validateWorkflowState(modifiedWorkflowState, { sanitize: true })
|
||||
@@ -997,14 +1570,26 @@ export const editWorkflowServerTool: BaseServerTool<EditWorkflowParams, any> = {
|
||||
operationCount: operations.length,
|
||||
blocksCount: Object.keys(modifiedWorkflowState.blocks).length,
|
||||
edgesCount: modifiedWorkflowState.edges.length,
|
||||
validationErrors: validation.errors.length,
|
||||
inputValidationErrors: validationErrors.length,
|
||||
schemaValidationErrors: validation.errors.length,
|
||||
validationWarnings: validation.warnings.length,
|
||||
})
|
||||
|
||||
// Format validation errors for LLM feedback
|
||||
const inputErrors =
|
||||
validationErrors.length > 0
|
||||
? validationErrors.map((e) => `Block "${e.blockId}" (${e.blockType}): ${e.error}`)
|
||||
: undefined
|
||||
|
||||
// Return the modified workflow state for the client to convert to YAML if needed
|
||||
return {
|
||||
success: true,
|
||||
workflowState: validation.sanitizedState || modifiedWorkflowState,
|
||||
// Include input validation errors so the LLM can see what was rejected
|
||||
...(inputErrors && {
|
||||
inputValidationErrors: inputErrors,
|
||||
inputValidationMessage: `${inputErrors.length} input(s) were rejected due to validation errors. The workflow was still updated with valid inputs only. Errors: ${inputErrors.join('; ')}`,
|
||||
}),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -34,3 +34,40 @@ export const GetTriggerBlocksResult = z.object({
|
||||
triggerBlockIds: z.array(z.string()),
|
||||
})
|
||||
export type GetTriggerBlocksResultType = z.infer<typeof GetTriggerBlocksResult>
|
||||
|
||||
// knowledge_base - shared schema used by client tool, server tool, and registry
|
||||
export const KnowledgeBaseArgsSchema = z.object({
|
||||
operation: z.enum(['create', 'list', 'get', 'query']),
|
||||
args: z
|
||||
.object({
|
||||
/** Name of the knowledge base (required for create) */
|
||||
name: z.string().optional(),
|
||||
/** Description of the knowledge base (optional for create) */
|
||||
description: z.string().optional(),
|
||||
/** Workspace ID to associate with (optional for create/list) */
|
||||
workspaceId: z.string().optional(),
|
||||
/** Knowledge base ID (required for get, query) */
|
||||
knowledgeBaseId: z.string().optional(),
|
||||
/** Search query text (required for query) */
|
||||
query: z.string().optional(),
|
||||
/** Number of results to return (optional for query, defaults to 5) */
|
||||
topK: z.number().min(1).max(50).optional(),
|
||||
/** Chunking configuration (optional for create) */
|
||||
chunkingConfig: z
|
||||
.object({
|
||||
maxSize: z.number().min(100).max(4000).default(1024),
|
||||
minSize: z.number().min(1).max(2000).default(1),
|
||||
overlap: z.number().min(0).max(500).default(200),
|
||||
})
|
||||
.optional(),
|
||||
})
|
||||
.optional(),
|
||||
})
|
||||
export type KnowledgeBaseArgs = z.infer<typeof KnowledgeBaseArgsSchema>
|
||||
|
||||
export const KnowledgeBaseResultSchema = z.object({
|
||||
success: z.boolean(),
|
||||
message: z.string(),
|
||||
data: z.any().optional(),
|
||||
})
|
||||
export type KnowledgeBaseResult = z.infer<typeof KnowledgeBaseResultSchema>
|
||||
|
||||
@@ -878,6 +878,37 @@ export const OAUTH_PROVIDERS: Record<string, OAuthProviderConfig> = {
|
||||
},
|
||||
}
|
||||
|
||||
/**
|
||||
* Service metadata without React components - safe for server-side use
|
||||
*/
|
||||
export interface OAuthServiceMetadata {
|
||||
providerId: string
|
||||
name: string
|
||||
description: string
|
||||
baseProvider: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a flat list of all available OAuth services with metadata.
|
||||
* This is safe to use on the server as it doesn't include React components.
|
||||
*/
|
||||
export function getAllOAuthServices(): OAuthServiceMetadata[] {
|
||||
const services: OAuthServiceMetadata[] = []
|
||||
|
||||
for (const [baseProviderId, provider] of Object.entries(OAUTH_PROVIDERS)) {
|
||||
for (const service of Object.values(provider.services)) {
|
||||
services.push({
|
||||
providerId: service.providerId,
|
||||
name: service.name,
|
||||
description: service.description,
|
||||
baseProvider: baseProviderId,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return services
|
||||
}
|
||||
|
||||
export function getServiceByProviderAndId(
|
||||
provider: OAuthProvider,
|
||||
serviceId?: string
|
||||
|
||||
@@ -610,24 +610,21 @@ export class WorkflowDiffEngine {
|
||||
const finalEdges: Edge[] = Array.from(edgeMap.values())
|
||||
|
||||
// Build final proposed state
|
||||
// Always regenerate loops and parallels from finalBlocks because the block IDs may have
|
||||
// been remapped (via idMap) and the server's loops/parallels would have stale references.
|
||||
// This ensures the nodes arrays in loops/parallels contain the correct (remapped) block IDs,
|
||||
// which is critical for variable resolution in the tag dropdown.
|
||||
const { generateLoopBlocks, generateParallelBlocks } = await import(
|
||||
'@/stores/workflows/workflow/utils'
|
||||
)
|
||||
const finalProposedState: WorkflowState = {
|
||||
blocks: finalBlocks,
|
||||
edges: finalEdges,
|
||||
loops: proposedState.loops || {},
|
||||
parallels: proposedState.parallels || {},
|
||||
loops: generateLoopBlocks(finalBlocks),
|
||||
parallels: generateParallelBlocks(finalBlocks),
|
||||
lastSaved: Date.now(),
|
||||
}
|
||||
|
||||
// Ensure loops and parallels are generated
|
||||
if (Object.keys(finalProposedState.loops).length === 0) {
|
||||
const { generateLoopBlocks } = await import('@/stores/workflows/workflow/utils')
|
||||
finalProposedState.loops = generateLoopBlocks(finalProposedState.blocks)
|
||||
}
|
||||
if (Object.keys(finalProposedState.parallels).length === 0) {
|
||||
const { generateParallelBlocks } = await import('@/stores/workflows/workflow/utils')
|
||||
finalProposedState.parallels = generateParallelBlocks(finalProposedState.blocks)
|
||||
}
|
||||
|
||||
// Transfer block heights from baseline workflow for better measurements in diff view
|
||||
// If editing on top of diff, this transfers from the diff (which already has good heights)
|
||||
// Otherwise transfers from original workflow
|
||||
|
||||
@@ -321,13 +321,37 @@ export function sanitizeForCopilot(state: WorkflowState): CopilotWorkflowState {
|
||||
let inputs: Record<string, string | number | string[][] | object>
|
||||
|
||||
if (block.type === 'loop' || block.type === 'parallel') {
|
||||
// Extract configuration from block.data
|
||||
// Extract configuration from block.data (only include type-appropriate fields)
|
||||
const loopInputs: Record<string, string | number | string[][] | object> = {}
|
||||
if (block.data?.loopType) loopInputs.loopType = block.data.loopType
|
||||
if (block.data?.count !== undefined) loopInputs.iterations = block.data.count
|
||||
if (block.data?.collection !== undefined) loopInputs.collection = block.data.collection
|
||||
if (block.data?.whileCondition !== undefined) loopInputs.condition = block.data.whileCondition
|
||||
if (block.data?.parallelType) loopInputs.parallelType = block.data.parallelType
|
||||
|
||||
if (block.type === 'loop') {
|
||||
const loopType = block.data?.loopType || 'for'
|
||||
loopInputs.loopType = loopType
|
||||
// Only export fields relevant to the current loopType
|
||||
if (loopType === 'for' && block.data?.count !== undefined) {
|
||||
loopInputs.iterations = block.data.count
|
||||
}
|
||||
if (loopType === 'forEach' && block.data?.collection !== undefined) {
|
||||
loopInputs.collection = block.data.collection
|
||||
}
|
||||
if (loopType === 'while' && block.data?.whileCondition !== undefined) {
|
||||
loopInputs.condition = block.data.whileCondition
|
||||
}
|
||||
if (loopType === 'doWhile' && block.data?.doWhileCondition !== undefined) {
|
||||
loopInputs.condition = block.data.doWhileCondition
|
||||
}
|
||||
} else if (block.type === 'parallel') {
|
||||
const parallelType = block.data?.parallelType || 'count'
|
||||
loopInputs.parallelType = parallelType
|
||||
// Only export fields relevant to the current parallelType
|
||||
if (parallelType === 'count' && block.data?.count !== undefined) {
|
||||
loopInputs.iterations = block.data.count
|
||||
}
|
||||
if (parallelType === 'collection' && block.data?.collection !== undefined) {
|
||||
loopInputs.collection = block.data.collection
|
||||
}
|
||||
}
|
||||
|
||||
inputs = loopInputs
|
||||
} else {
|
||||
// For regular blocks, sanitize subBlocks
|
||||
|
||||
@@ -96,6 +96,7 @@
|
||||
"lodash": "4.17.21",
|
||||
"lucide-react": "^0.479.0",
|
||||
"mammoth": "^1.9.0",
|
||||
"mermaid": "^11.4.1",
|
||||
"mysql2": "3.14.3",
|
||||
"nanoid": "^3.3.7",
|
||||
"next": "16.0.7",
|
||||
|
||||
@@ -18,6 +18,7 @@ import { SummarizeClientTool } from '@/lib/copilot/tools/client/examples/summari
|
||||
import { ListGDriveFilesClientTool } from '@/lib/copilot/tools/client/gdrive/list-files'
|
||||
import { ReadGDriveFileClientTool } from '@/lib/copilot/tools/client/gdrive/read-file'
|
||||
import { GDriveRequestAccessClientTool } from '@/lib/copilot/tools/client/google/gdrive-request-access'
|
||||
import { KnowledgeBaseClientTool } from '@/lib/copilot/tools/client/knowledge/knowledge-base'
|
||||
import {
|
||||
getClientTool,
|
||||
registerClientTool,
|
||||
@@ -25,6 +26,7 @@ import {
|
||||
} from '@/lib/copilot/tools/client/manager'
|
||||
import { NavigateUIClientTool } from '@/lib/copilot/tools/client/navigation/navigate-ui'
|
||||
import { CheckoffTodoClientTool } from '@/lib/copilot/tools/client/other/checkoff-todo'
|
||||
import { GenerateDiagramClientTool } from '@/lib/copilot/tools/client/other/generate-diagram'
|
||||
import { MakeApiRequestClientTool } from '@/lib/copilot/tools/client/other/make-api-request'
|
||||
import { MarkTodoInProgressClientTool } from '@/lib/copilot/tools/client/other/mark-todo-in-progress'
|
||||
import { OAuthRequestAccessClientTool } from '@/lib/copilot/tools/client/other/oauth-request-access'
|
||||
@@ -85,6 +87,7 @@ const CLIENT_TOOL_INSTANTIATORS: Record<string, (id: string) => any> = {
|
||||
list_gdrive_files: (id) => new ListGDriveFilesClientTool(id),
|
||||
read_gdrive_file: (id) => new ReadGDriveFileClientTool(id),
|
||||
get_credentials: (id) => new GetCredentialsClientTool(id),
|
||||
knowledge_base: (id) => new KnowledgeBaseClientTool(id),
|
||||
make_api_request: (id) => new MakeApiRequestClientTool(id),
|
||||
plan: (id) => new PlanClientTool(id),
|
||||
checkoff_todo: (id) => new CheckoffTodoClientTool(id),
|
||||
@@ -104,6 +107,7 @@ const CLIENT_TOOL_INSTANTIATORS: Record<string, (id: string) => any> = {
|
||||
deploy_workflow: (id) => new DeployWorkflowClientTool(id),
|
||||
check_deployment_status: (id) => new CheckDeploymentStatusClientTool(id),
|
||||
navigate_ui: (id) => new NavigateUIClientTool(id),
|
||||
generate_diagram: (id) => new GenerateDiagramClientTool(id),
|
||||
}
|
||||
|
||||
// Read-only static metadata for class-based tools (no instances)
|
||||
@@ -122,6 +126,7 @@ export const CLASS_TOOL_METADATA: Record<string, BaseClientToolMetadata | undefi
|
||||
list_gdrive_files: (ListGDriveFilesClientTool as any)?.metadata,
|
||||
read_gdrive_file: (ReadGDriveFileClientTool as any)?.metadata,
|
||||
get_credentials: (GetCredentialsClientTool as any)?.metadata,
|
||||
knowledge_base: (KnowledgeBaseClientTool as any)?.metadata,
|
||||
make_api_request: (MakeApiRequestClientTool as any)?.metadata,
|
||||
plan: (PlanClientTool as any)?.metadata,
|
||||
checkoff_todo: (CheckoffTodoClientTool as any)?.metadata,
|
||||
@@ -141,6 +146,7 @@ export const CLASS_TOOL_METADATA: Record<string, BaseClientToolMetadata | undefi
|
||||
deploy_workflow: (DeployWorkflowClientTool as any)?.metadata,
|
||||
check_deployment_status: (CheckDeploymentStatusClientTool as any)?.metadata,
|
||||
navigate_ui: (NavigateUIClientTool as any)?.metadata,
|
||||
generate_diagram: (GenerateDiagramClientTool as any)?.metadata,
|
||||
}
|
||||
|
||||
function ensureClientToolInstance(toolName: string | undefined, toolCallId: string | undefined) {
|
||||
@@ -1019,8 +1025,36 @@ const sseHandlers: Record<string, SSEHandler> = {
|
||||
set({ toolCallsById: errorMap })
|
||||
})
|
||||
}, 0)
|
||||
return
|
||||
}
|
||||
} catch {}
|
||||
|
||||
// Integration tools: Check if auto-allowed, otherwise wait for user confirmation
|
||||
// This handles tools like google_calendar_*, exa_*, etc. that aren't in the client registry
|
||||
// Only relevant if mode is 'build' (agent)
|
||||
const { mode, workflowId, autoAllowedTools } = get()
|
||||
if (mode === 'build' && workflowId) {
|
||||
// Check if tool was NOT found in client registry (def is undefined from above)
|
||||
const def = name ? getTool(name) : undefined
|
||||
const inst = getClientTool(id) as any
|
||||
if (!def && !inst && name) {
|
||||
// Check if this tool is auto-allowed
|
||||
if (autoAllowedTools.includes(name)) {
|
||||
logger.info('[build mode] Integration tool auto-allowed, executing', { id, name })
|
||||
|
||||
// Auto-execute the tool
|
||||
setTimeout(() => {
|
||||
get().executeIntegrationTool(id)
|
||||
}, 0)
|
||||
} else {
|
||||
// Integration tools stay in pending state until user confirms
|
||||
logger.info('[build mode] Integration tool awaiting user confirmation', {
|
||||
id,
|
||||
name,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
reasoning: (data, context, _get, set) => {
|
||||
const phase = (data && (data.phase || data?.data?.phase)) as string | undefined
|
||||
@@ -1504,7 +1538,7 @@ async function* parseSSEStream(
|
||||
// Initial state (subset required for UI/streaming)
|
||||
const initialState = {
|
||||
mode: 'build' as const,
|
||||
selectedModel: 'claude-4.5-sonnet' as CopilotStore['selectedModel'],
|
||||
selectedModel: 'claude-4.5-opus' as CopilotStore['selectedModel'],
|
||||
agentPrefetch: false,
|
||||
enabledModels: null as string[] | null, // Null means not loaded yet, empty array means all disabled
|
||||
isCollapsed: false,
|
||||
@@ -1535,6 +1569,7 @@ const initialState = {
|
||||
toolCallsById: {} as Record<string, CopilotToolCall>,
|
||||
suppressAutoSelect: false,
|
||||
contextUsage: null,
|
||||
autoAllowedTools: [] as string[],
|
||||
}
|
||||
|
||||
export const useCopilotStore = create<CopilotStore>()(
|
||||
@@ -1766,6 +1801,7 @@ export const useCopilotStore = create<CopilotStore>()(
|
||||
|
||||
loadChats: async (_forceRefresh = false) => {
|
||||
const { workflowId } = get()
|
||||
|
||||
if (!workflowId) {
|
||||
set({ chats: [], isLoadingChats: false })
|
||||
return
|
||||
@@ -1774,7 +1810,8 @@ export const useCopilotStore = create<CopilotStore>()(
|
||||
// For now always fetch fresh
|
||||
set({ isLoadingChats: true })
|
||||
try {
|
||||
const response = await fetch(`/api/copilot/chat?workflowId=${workflowId}`)
|
||||
const url = `/api/copilot/chat?workflowId=${workflowId}`
|
||||
const response = await fetch(url)
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch chats: ${response.status}`)
|
||||
}
|
||||
@@ -1902,6 +1939,7 @@ export const useCopilotStore = create<CopilotStore>()(
|
||||
contexts?: ChatContext[]
|
||||
messageId?: string
|
||||
}
|
||||
|
||||
if (!workflowId) return
|
||||
|
||||
const abortController = new AbortController()
|
||||
@@ -1972,13 +2010,14 @@ export const useCopilotStore = create<CopilotStore>()(
|
||||
})
|
||||
}
|
||||
|
||||
// Call copilot API
|
||||
const apiMode: 'ask' | 'agent' | 'plan' =
|
||||
mode === 'ask' ? 'ask' : mode === 'plan' ? 'plan' : 'agent'
|
||||
const result = await sendStreamingMessage({
|
||||
message: messageToSend,
|
||||
userMessageId: userMessage.id,
|
||||
chatId: currentChat?.id,
|
||||
workflowId,
|
||||
workflowId: workflowId || undefined,
|
||||
mode: apiMode,
|
||||
model: get().selectedModel,
|
||||
prefetch: get().agentPrefetch,
|
||||
@@ -2812,6 +2851,190 @@ export const useCopilotStore = create<CopilotStore>()(
|
||||
logger.error('[Context Usage] Error fetching:', err)
|
||||
}
|
||||
},
|
||||
|
||||
executeIntegrationTool: async (toolCallId: string) => {
|
||||
const { toolCallsById, workflowId } = get()
|
||||
const toolCall = toolCallsById[toolCallId]
|
||||
if (!toolCall || !workflowId) return
|
||||
|
||||
const { id, name, params } = toolCall
|
||||
|
||||
// Set to executing state
|
||||
const executingMap = { ...get().toolCallsById }
|
||||
executingMap[id] = {
|
||||
...executingMap[id],
|
||||
state: ClientToolCallState.executing,
|
||||
display: resolveToolDisplay(name, ClientToolCallState.executing, id, params),
|
||||
}
|
||||
set({ toolCallsById: executingMap })
|
||||
logger.info('[toolCallsById] pending → executing (integration tool)', { id, name })
|
||||
|
||||
try {
|
||||
const res = await fetch('/api/copilot/execute-tool', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
toolCallId: id,
|
||||
toolName: name,
|
||||
arguments: params || {},
|
||||
workflowId,
|
||||
}),
|
||||
})
|
||||
|
||||
const result = await res.json()
|
||||
const success = result.success && result.result?.success
|
||||
const completeMap = { ...get().toolCallsById }
|
||||
|
||||
// Do not override terminal review/rejected
|
||||
if (
|
||||
isRejectedState(completeMap[id]?.state) ||
|
||||
isReviewState(completeMap[id]?.state) ||
|
||||
isBackgroundState(completeMap[id]?.state)
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
completeMap[id] = {
|
||||
...completeMap[id],
|
||||
state: success ? ClientToolCallState.success : ClientToolCallState.error,
|
||||
display: resolveToolDisplay(
|
||||
name,
|
||||
success ? ClientToolCallState.success : ClientToolCallState.error,
|
||||
id,
|
||||
params
|
||||
),
|
||||
}
|
||||
set({ toolCallsById: completeMap })
|
||||
logger.info(`[toolCallsById] executing → ${success ? 'success' : 'error'} (integration)`, {
|
||||
id,
|
||||
name,
|
||||
result,
|
||||
})
|
||||
|
||||
// Notify backend tool mark-complete endpoint
|
||||
try {
|
||||
await fetch('/api/copilot/tools/mark-complete', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
id,
|
||||
name: name || 'unknown_tool',
|
||||
status: success ? 200 : 500,
|
||||
message: success
|
||||
? result.result?.output?.content
|
||||
: result.result?.error || result.error || 'Tool execution failed',
|
||||
data: success
|
||||
? result.result?.output
|
||||
: {
|
||||
error: result.result?.error || result.error,
|
||||
output: result.result?.output,
|
||||
},
|
||||
}),
|
||||
})
|
||||
} catch {}
|
||||
} catch (e) {
|
||||
const errorMap = { ...get().toolCallsById }
|
||||
// Do not override terminal review/rejected
|
||||
if (
|
||||
isRejectedState(errorMap[id]?.state) ||
|
||||
isReviewState(errorMap[id]?.state) ||
|
||||
isBackgroundState(errorMap[id]?.state)
|
||||
) {
|
||||
return
|
||||
}
|
||||
errorMap[id] = {
|
||||
...errorMap[id],
|
||||
state: ClientToolCallState.error,
|
||||
display: resolveToolDisplay(name, ClientToolCallState.error, id, params),
|
||||
}
|
||||
set({ toolCallsById: errorMap })
|
||||
logger.error('Integration tool execution failed', { id, name, error: e })
|
||||
}
|
||||
},
|
||||
|
||||
skipIntegrationTool: (toolCallId: string) => {
|
||||
const { toolCallsById } = get()
|
||||
const toolCall = toolCallsById[toolCallId]
|
||||
if (!toolCall) return
|
||||
|
||||
const { id, name, params } = toolCall
|
||||
|
||||
// Set to rejected state
|
||||
const rejectedMap = { ...get().toolCallsById }
|
||||
rejectedMap[id] = {
|
||||
...rejectedMap[id],
|
||||
state: ClientToolCallState.rejected,
|
||||
display: resolveToolDisplay(name, ClientToolCallState.rejected, id, params),
|
||||
}
|
||||
set({ toolCallsById: rejectedMap })
|
||||
logger.info('[toolCallsById] pending → rejected (integration tool skipped)', { id, name })
|
||||
|
||||
// Notify backend tool mark-complete endpoint with skip status
|
||||
fetch('/api/copilot/tools/mark-complete', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
id,
|
||||
name: name || 'unknown_tool',
|
||||
status: 200,
|
||||
message: 'Tool execution skipped by user',
|
||||
data: { skipped: true },
|
||||
}),
|
||||
}).catch(() => {})
|
||||
},
|
||||
|
||||
loadAutoAllowedTools: async () => {
|
||||
try {
|
||||
const res = await fetch('/api/copilot/auto-allowed-tools')
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
set({ autoAllowedTools: data.autoAllowedTools || [] })
|
||||
logger.info('[AutoAllowedTools] Loaded', { tools: data.autoAllowedTools })
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('[AutoAllowedTools] Failed to load', { error: err })
|
||||
}
|
||||
},
|
||||
|
||||
addAutoAllowedTool: async (toolId: string) => {
|
||||
try {
|
||||
const res = await fetch('/api/copilot/auto-allowed-tools', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ toolId }),
|
||||
})
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
set({ autoAllowedTools: data.autoAllowedTools || [] })
|
||||
logger.info('[AutoAllowedTools] Added tool', { toolId })
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('[AutoAllowedTools] Failed to add tool', { toolId, error: err })
|
||||
}
|
||||
},
|
||||
|
||||
removeAutoAllowedTool: async (toolId: string) => {
|
||||
try {
|
||||
const res = await fetch(
|
||||
`/api/copilot/auto-allowed-tools?toolId=${encodeURIComponent(toolId)}`,
|
||||
{
|
||||
method: 'DELETE',
|
||||
}
|
||||
)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
set({ autoAllowedTools: data.autoAllowedTools || [] })
|
||||
logger.info('[AutoAllowedTools] Removed tool', { toolId })
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('[AutoAllowedTools] Failed to remove tool', { toolId, error: err })
|
||||
}
|
||||
},
|
||||
|
||||
isToolAutoAllowed: (toolId: string) => {
|
||||
const { autoAllowedTools } = get()
|
||||
return autoAllowedTools.includes(toolId)
|
||||
},
|
||||
}))
|
||||
)
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ export interface CopilotState {
|
||||
| 'claude-4.5-sonnet'
|
||||
| 'claude-4.5-opus'
|
||||
| 'claude-4.1-opus'
|
||||
| 'gemini-3-pro'
|
||||
agentPrefetch: boolean
|
||||
enabledModels: string[] | null // Null means not loaded yet, array of model IDs when loaded
|
||||
isCollapsed: boolean
|
||||
@@ -138,6 +139,9 @@ export interface CopilotState {
|
||||
when: 'start' | 'end'
|
||||
estimatedTokens?: number
|
||||
} | null
|
||||
|
||||
// Auto-allowed integration tools (tools that can run without confirmation)
|
||||
autoAllowedTools: string[]
|
||||
}
|
||||
|
||||
export interface CopilotActions {
|
||||
@@ -213,6 +217,12 @@ export interface CopilotActions {
|
||||
handleNewChatCreation: (newChatId: string) => Promise<void>
|
||||
updateDiffStore: (yamlContent: string, toolName?: string) => Promise<void>
|
||||
updateDiffStoreWithWorkflowState: (workflowState: any, toolName?: string) => Promise<void>
|
||||
executeIntegrationTool: (toolCallId: string) => Promise<void>
|
||||
skipIntegrationTool: (toolCallId: string) => void
|
||||
loadAutoAllowedTools: () => Promise<void>
|
||||
addAutoAllowedTool: (toolId: string) => Promise<void>
|
||||
removeAutoAllowedTool: (toolId: string) => Promise<void>
|
||||
isToolAutoAllowed: (toolId: string) => boolean
|
||||
}
|
||||
|
||||
export type CopilotStore = CopilotState & CopilotActions
|
||||
|
||||
@@ -53,19 +53,23 @@ export const createTool: ToolConfig<GoogleCalendarCreateParams, GoogleCalendarCr
|
||||
type: 'string',
|
||||
required: true,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'Start date and time (RFC3339 format, e.g., 2025-06-03T10:00:00-08:00)',
|
||||
description:
|
||||
'Start date and time. MUST include timezone offset (e.g., 2025-06-03T10:00:00-08:00) OR provide timeZone parameter',
|
||||
},
|
||||
endDateTime: {
|
||||
type: 'string',
|
||||
required: true,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'End date and time (RFC3339 format, e.g., 2025-06-03T11:00:00-08:00)',
|
||||
description:
|
||||
'End date and time. MUST include timezone offset (e.g., 2025-06-03T11:00:00-08:00) OR provide timeZone parameter',
|
||||
},
|
||||
timeZone: {
|
||||
type: 'string',
|
||||
required: false,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'Time zone (e.g., America/Los_Angeles)',
|
||||
description:
|
||||
'Time zone (e.g., America/Los_Angeles). Required if datetime does not include offset. Defaults to America/Los_Angeles if not provided.',
|
||||
default: 'America/Los_Angeles',
|
||||
},
|
||||
attendees: {
|
||||
type: 'array',
|
||||
@@ -101,13 +105,20 @@ export const createTool: ToolConfig<GoogleCalendarCreateParams, GoogleCalendarCr
|
||||
'Content-Type': 'application/json',
|
||||
}),
|
||||
body: (params: GoogleCalendarCreateParams): GoogleCalendarEventRequestBody => {
|
||||
// Default timezone if not provided and datetime doesn't include offset
|
||||
const timeZone = params.timeZone || 'America/Los_Angeles'
|
||||
const needsTimezone =
|
||||
!params.startDateTime.includes('+') && !params.startDateTime.includes('-', 10)
|
||||
|
||||
const eventData: GoogleCalendarEventRequestBody = {
|
||||
summary: params.summary,
|
||||
start: {
|
||||
dateTime: params.startDateTime,
|
||||
...(needsTimezone ? { timeZone } : {}),
|
||||
},
|
||||
end: {
|
||||
dateTime: params.endDateTime,
|
||||
...(needsTimezone ? { timeZone } : {}),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -119,6 +130,7 @@ export const createTool: ToolConfig<GoogleCalendarCreateParams, GoogleCalendarCr
|
||||
eventData.location = params.location
|
||||
}
|
||||
|
||||
// Always set timezone if explicitly provided
|
||||
if (params.timeZone) {
|
||||
eventData.start.timeZone = params.timeZone
|
||||
eventData.end.timeZone = params.timeZone
|
||||
|
||||
@@ -35,13 +35,14 @@ export const listTool: ToolConfig<GoogleDriveToolParams, GoogleDriveListResponse
|
||||
type: 'string',
|
||||
required: false,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'A query to filter the files',
|
||||
description:
|
||||
'Search term to filter files by name (e.g. "budget" finds files with "budget" in the name). Do NOT use Google Drive query syntax here - just provide a plain search term.',
|
||||
},
|
||||
pageSize: {
|
||||
type: 'number',
|
||||
required: false,
|
||||
visibility: 'user-only',
|
||||
description: 'The number of files to return',
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The maximum number of files to return (default: 100)',
|
||||
},
|
||||
pageToken: {
|
||||
type: 'string',
|
||||
|
||||
@@ -25,20 +25,21 @@ export const appendTool: ToolConfig<GoogleSheetsToolParams, GoogleSheetsAppendRe
|
||||
spreadsheetId: {
|
||||
type: 'string',
|
||||
required: true,
|
||||
visibility: 'user-only',
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The ID of the spreadsheet to append to',
|
||||
},
|
||||
range: {
|
||||
type: 'string',
|
||||
required: false,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The range of cells to append after',
|
||||
description: 'The A1 notation range to append after (e.g. "Sheet1", "Sheet1!A:D")',
|
||||
},
|
||||
values: {
|
||||
type: 'array',
|
||||
required: true,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The data to append to the spreadsheet',
|
||||
description:
|
||||
'The data to append as a 2D array (e.g. [["Alice", 30], ["Bob", 25]]) or array of objects.',
|
||||
},
|
||||
valueInputOption: {
|
||||
type: 'string',
|
||||
|
||||
@@ -22,14 +22,16 @@ export const readTool: ToolConfig<GoogleSheetsToolParams, GoogleSheetsReadRespon
|
||||
spreadsheetId: {
|
||||
type: 'string',
|
||||
required: true,
|
||||
visibility: 'user-only',
|
||||
description: 'The ID of the spreadsheet to read from',
|
||||
visibility: 'user-or-llm',
|
||||
description:
|
||||
'The ID of the spreadsheet (found in the URL: docs.google.com/spreadsheets/d/{SPREADSHEET_ID}/edit).',
|
||||
},
|
||||
range: {
|
||||
type: 'string',
|
||||
required: false,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The range of cells to read from',
|
||||
description:
|
||||
'The A1 notation range to read (e.g. "Sheet1!A1:D10", "A1:B5"). Defaults to first sheet A1:Z1000 if not specified.',
|
||||
},
|
||||
},
|
||||
|
||||
|
||||
@@ -25,20 +25,21 @@ export const updateTool: ToolConfig<GoogleSheetsToolParams, GoogleSheetsUpdateRe
|
||||
spreadsheetId: {
|
||||
type: 'string',
|
||||
required: true,
|
||||
visibility: 'user-only',
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The ID of the spreadsheet to update',
|
||||
},
|
||||
range: {
|
||||
type: 'string',
|
||||
required: false,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The range of cells to update',
|
||||
description: 'The A1 notation range to update (e.g. "Sheet1!A1:D10", "A1:B5")',
|
||||
},
|
||||
values: {
|
||||
type: 'array',
|
||||
required: true,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The data to update in the spreadsheet',
|
||||
description:
|
||||
'The data to update as a 2D array (e.g. [["Name", "Age"], ["Alice", 30]]) or array of objects.',
|
||||
},
|
||||
valueInputOption: {
|
||||
type: 'string',
|
||||
|
||||
@@ -22,20 +22,21 @@ export const writeTool: ToolConfig<GoogleSheetsToolParams, GoogleSheetsWriteResp
|
||||
spreadsheetId: {
|
||||
type: 'string',
|
||||
required: true,
|
||||
visibility: 'user-only',
|
||||
description: 'The ID of the spreadsheet to write to',
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The ID of the spreadsheet',
|
||||
},
|
||||
range: {
|
||||
type: 'string',
|
||||
required: false,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The range of cells to write to',
|
||||
description: 'The A1 notation range to write to (e.g. "Sheet1!A1:D10", "A1:B5")',
|
||||
},
|
||||
values: {
|
||||
type: 'array',
|
||||
required: true,
|
||||
visibility: 'user-or-llm',
|
||||
description: 'The data to write to the spreadsheet',
|
||||
description:
|
||||
'The data to write as a 2D array (e.g. [["Name", "Age"], ["Alice", 30], ["Bob", 25]]) or array of objects.',
|
||||
},
|
||||
valueInputOption: {
|
||||
type: 'string',
|
||||
|
||||
@@ -223,6 +223,41 @@ export async function executeTool(
|
||||
}
|
||||
}
|
||||
|
||||
// Check for direct execution (no HTTP request needed)
|
||||
if (tool.directExecution) {
|
||||
logger.info(`[${requestId}] Using directExecution for ${toolId}`)
|
||||
const result = await tool.directExecution(contextParams)
|
||||
|
||||
// Apply post-processing if available and not skipped
|
||||
let finalResult = result
|
||||
if (tool.postProcess && result.success && !skipPostProcess) {
|
||||
try {
|
||||
finalResult = await tool.postProcess(result, contextParams, executeTool)
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Post-processing error for ${toolId}:`, {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
finalResult = result
|
||||
}
|
||||
}
|
||||
|
||||
// Process file outputs if execution context is available
|
||||
finalResult = await processFileOutputs(finalResult, tool, executionContext)
|
||||
|
||||
// Add timing data to the result
|
||||
const endTime = new Date()
|
||||
const endTimeISO = endTime.toISOString()
|
||||
const duration = endTime.getTime() - startTime.getTime()
|
||||
return {
|
||||
...finalResult,
|
||||
timing: {
|
||||
startTime: startTimeISO,
|
||||
endTime: endTimeISO,
|
||||
duration,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// For internal routes or when skipProxy is true, call the API directly
|
||||
// Internal routes are automatically detected by checking if URL starts with /api/
|
||||
const endpointUrl =
|
||||
|
||||
@@ -2,6 +2,7 @@ import { describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createExecutionToolSchema,
|
||||
createLLMToolSchema,
|
||||
createUserToolSchema,
|
||||
filterSchemaForLLM,
|
||||
formatParameterLabel,
|
||||
getToolParametersConfig,
|
||||
@@ -110,6 +111,38 @@ describe('Tool Parameters Utils', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('createUserToolSchema', () => {
|
||||
it.concurrent('should include user-only parameters and omit hidden ones', () => {
|
||||
const toolWithHiddenParam = {
|
||||
...mockToolConfig,
|
||||
id: 'user_schema_tool',
|
||||
params: {
|
||||
...mockToolConfig.params,
|
||||
spreadsheetId: {
|
||||
type: 'string',
|
||||
required: true,
|
||||
visibility: 'user-only' as ParameterVisibility,
|
||||
description: 'Spreadsheet ID to operate on',
|
||||
},
|
||||
accessToken: {
|
||||
type: 'string',
|
||||
required: true,
|
||||
visibility: 'hidden' as ParameterVisibility,
|
||||
description: 'OAuth access token',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const schema = createUserToolSchema(toolWithHiddenParam)
|
||||
|
||||
expect(schema.properties).toHaveProperty('spreadsheetId')
|
||||
expect(schema.required).toContain('spreadsheetId')
|
||||
expect(schema.properties).not.toHaveProperty('accessToken')
|
||||
expect(schema.required).not.toContain('accessToken')
|
||||
expect(schema.properties).toHaveProperty('message')
|
||||
})
|
||||
})
|
||||
|
||||
describe('createExecutionToolSchema', () => {
|
||||
it.concurrent('should create complete schema with all parameters', () => {
|
||||
const schema = createExecutionToolSchema(mockToolConfig)
|
||||
|
||||
@@ -3,6 +3,7 @@ import type { ParameterVisibility, ToolConfig } from '@/tools/types'
|
||||
import { getTool } from '@/tools/utils'
|
||||
|
||||
const logger = createLogger('ToolsParams')
|
||||
type ToolParamDefinition = ToolConfig['params'][string]
|
||||
|
||||
export interface Option {
|
||||
label: string
|
||||
@@ -73,6 +74,9 @@ export interface BlockConfig {
|
||||
export interface SchemaProperty {
|
||||
type: string
|
||||
description: string
|
||||
items?: Record<string, any>
|
||||
properties?: Record<string, SchemaProperty>
|
||||
required?: string[]
|
||||
}
|
||||
|
||||
export interface ToolSchema {
|
||||
@@ -326,6 +330,59 @@ export function getToolParametersConfig(
|
||||
/**
|
||||
* Creates a tool schema for LLM with user-provided parameters excluded
|
||||
*/
|
||||
function buildParameterSchema(
|
||||
toolId: string,
|
||||
paramId: string,
|
||||
param: ToolParamDefinition
|
||||
): SchemaProperty {
|
||||
let schemaType = param.type
|
||||
if (schemaType === 'json' || schemaType === 'any') {
|
||||
schemaType = 'object'
|
||||
}
|
||||
|
||||
const propertySchema: SchemaProperty = {
|
||||
type: schemaType,
|
||||
description: param.description || '',
|
||||
}
|
||||
|
||||
if (param.type === 'array' && param.items) {
|
||||
propertySchema.items = {
|
||||
...param.items,
|
||||
...(param.items.properties && {
|
||||
properties: { ...param.items.properties },
|
||||
}),
|
||||
}
|
||||
} else if (param.items) {
|
||||
logger.warn(`items property ignored for non-array param "${paramId}" in tool "${toolId}"`)
|
||||
}
|
||||
|
||||
return propertySchema
|
||||
}
|
||||
|
||||
export function createUserToolSchema(toolConfig: ToolConfig): ToolSchema {
|
||||
const schema: ToolSchema = {
|
||||
type: 'object',
|
||||
properties: {},
|
||||
required: [],
|
||||
}
|
||||
|
||||
for (const [paramId, param] of Object.entries(toolConfig.params)) {
|
||||
const visibility = param.visibility ?? 'user-or-llm'
|
||||
if (visibility === 'hidden') {
|
||||
continue
|
||||
}
|
||||
|
||||
const propertySchema = buildParameterSchema(toolConfig.id, paramId, param)
|
||||
schema.properties[paramId] = propertySchema
|
||||
|
||||
if (param.required) {
|
||||
schema.required.push(paramId)
|
||||
}
|
||||
}
|
||||
|
||||
return schema
|
||||
}
|
||||
|
||||
export async function createLLMToolSchema(
|
||||
toolConfig: ToolConfig,
|
||||
userProvidedParams: Record<string, unknown>
|
||||
@@ -359,29 +416,7 @@ export async function createLLMToolSchema(
|
||||
}
|
||||
|
||||
// Add parameter to LLM schema
|
||||
let schemaType = param.type
|
||||
if (param.type === 'json' || param.type === 'any') {
|
||||
schemaType = 'object'
|
||||
}
|
||||
|
||||
const propertySchema: any = {
|
||||
type: schemaType,
|
||||
description: param.description || '',
|
||||
}
|
||||
|
||||
// Include items property for arrays
|
||||
if (param.type === 'array' && param.items) {
|
||||
propertySchema.items = {
|
||||
...param.items,
|
||||
...(param.items.properties && {
|
||||
properties: { ...param.items.properties },
|
||||
}),
|
||||
}
|
||||
} else if (param.items) {
|
||||
logger.warn(
|
||||
`items property ignored for non-array param "${paramId}" in tool "${toolConfig.id}"`
|
||||
)
|
||||
}
|
||||
const propertySchema = buildParameterSchema(toolConfig.id, paramId, param)
|
||||
|
||||
// Special handling for workflow_executor's inputMapping parameter
|
||||
if (toolConfig.id === 'workflow_executor' && paramId === 'inputMapping') {
|
||||
|
||||
@@ -105,6 +105,12 @@ export interface ToolConfig<P = any, R = any> {
|
||||
|
||||
// Response handling
|
||||
transformResponse?: (response: Response, params?: P) => Promise<R>
|
||||
|
||||
/**
|
||||
* Direct execution function for tools that don't need HTTP requests.
|
||||
* If provided, this will be called instead of making an HTTP request.
|
||||
*/
|
||||
directExecution?: (params: P) => Promise<ToolResponse>
|
||||
}
|
||||
|
||||
export interface TableRow {
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"allowImportingTsExtensions": true,
|
||||
"jsx": "preserve",
|
||||
"jsx": "react-jsx",
|
||||
"plugins": [
|
||||
{
|
||||
"name": "next"
|
||||
|
||||
@@ -44,6 +44,7 @@
|
||||
"drizzle-orm": "^0.44.5",
|
||||
"ffmpeg-static": "5.3.0",
|
||||
"fluent-ffmpeg": "2.1.3",
|
||||
"mermaid": "11.12.2",
|
||||
"mongodb": "6.19.0",
|
||||
"neo4j-driver": "6.0.1",
|
||||
"nodemailer": "7.0.11",
|
||||
|
||||
1
packages/db/migrations/0117_silly_purifiers.sql
Normal file
1
packages/db/migrations/0117_silly_purifiers.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE "settings" ADD COLUMN "copilot_auto_allowed_tools" jsonb DEFAULT '[]' NOT NULL;
|
||||
7762
packages/db/migrations/meta/0117_snapshot.json
Normal file
7762
packages/db/migrations/meta/0117_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -813,6 +813,13 @@
|
||||
"when": 1764820826997,
|
||||
"tag": "0116_flimsy_shape",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 117,
|
||||
"version": "7",
|
||||
"when": 1764909191102,
|
||||
"tag": "0117_silly_purifiers",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -430,6 +430,9 @@ export const settings = pgTable('settings', {
|
||||
// Copilot preferences - maps model_id to enabled/disabled boolean
|
||||
copilotEnabledModels: jsonb('copilot_enabled_models').notNull().default('{}'),
|
||||
|
||||
// Copilot auto-allowed integration tools - array of tool IDs that can run without confirmation
|
||||
copilotAutoAllowedTools: jsonb('copilot_auto_allowed_tools').notNull().default('[]'),
|
||||
|
||||
updatedAt: timestamp('updated_at').notNull().defaultNow(),
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user