refactor(agent-handler): simplify agent handler, update tests, fix resolution of env vars in function execution (#437)

* refactored agent handler, fixed envvar resolution for function block

* resolve missing envvar resolution from function execution for custom tool

* fix path traversal risk

* removed extraneous comments

* ack PR comments
This commit is contained in:
Waleed Latif
2025-05-29 18:20:19 -07:00
committed by GitHub
parent 3b82e7d224
commit b2450530d1
22 changed files with 1090 additions and 763 deletions

View File

@@ -1,4 +1,5 @@
import path from 'path'
import { NextRequest } from 'next/server'
/**
* Tests for file parse API route
*
@@ -6,10 +7,9 @@ import path from 'path'
*/
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockRequest } from '@/app/api/__test-utils__/utils'
import { POST } from './route'
// Create actual mocks for path functions that we can use instead of using vi.doMock for path
const mockJoin = vi.fn((...args: string[]): string => {
// For the UPLOAD_DIR paths, just return a test path
if (args[0] === '/test/uploads') {
return `/test/uploads/${args[args.length - 1]}`
}
@@ -17,7 +17,6 @@ const mockJoin = vi.fn((...args: string[]): string => {
})
describe('File Parse API Route', () => {
// Mock file system and parser modules
const mockReadFile = vi.fn().mockResolvedValue(Buffer.from('test file content'))
const mockWriteFile = vi.fn().mockResolvedValue(undefined)
const mockUnlink = vi.fn().mockResolvedValue(undefined)
@@ -36,15 +35,12 @@ describe('File Parse API Route', () => {
beforeEach(() => {
vi.resetModules()
// Reset all mocks
vi.resetAllMocks()
// Create a test upload file that exists for all tests
mockReadFile.mockResolvedValue(Buffer.from('test file content'))
mockAccessFs.mockResolvedValue(undefined)
mockStatFs.mockImplementation(() => ({ isFile: () => true }))
// Mock filesystem operations
vi.doMock('fs', () => ({
existsSync: vi.fn().mockReturnValue(true),
constants: { R_OK: 4 },
@@ -63,19 +59,16 @@ describe('File Parse API Route', () => {
stat: mockStatFs,
}))
// Mock the S3 client
vi.doMock('@/lib/uploads/s3-client', () => ({
downloadFromS3: mockDownloadFromS3,
}))
// Mock file parsers
vi.doMock('@/lib/file-parsers', () => ({
isSupportedFileType: vi.fn().mockReturnValue(true),
parseFile: mockParseFile,
parseBuffer: mockParseBuffer,
}))
// Mock path module with our custom join function
vi.doMock('path', () => {
return {
...path,
@@ -85,7 +78,6 @@ describe('File Parse API Route', () => {
}
})
// Mock the logger
vi.doMock('@/lib/logs/console-logger', () => ({
createLogger: vi.fn().mockReturnValue({
info: vi.fn(),
@@ -95,7 +87,6 @@ describe('File Parse API Route', () => {
}),
}))
// Configure upload directory and S3 mode
vi.doMock('@/lib/uploads/setup', () => ({
UPLOAD_DIR: '/test/uploads',
USE_S3_STORAGE: false,
@@ -105,7 +96,6 @@ describe('File Parse API Route', () => {
},
}))
// Skip setup.server.ts side effects
vi.doMock('@/lib/uploads/setup.server', () => ({}))
})
@@ -113,7 +103,6 @@ describe('File Parse API Route', () => {
vi.clearAllMocks()
})
// Basic tests testing the API structure
it('should handle missing file path', async () => {
const req = createMockRequest('POST', {})
const { POST } = await import('./route')
@@ -125,47 +114,37 @@ describe('File Parse API Route', () => {
expect(data).toHaveProperty('error', 'No file path provided')
})
// Test skipping the implementation details and testing what users would care about
it('should accept and process a local file', async () => {
// Given: A request with a file path
const req = createMockRequest('POST', {
filePath: '/api/files/serve/test-file.txt',
})
// When: The API processes the request
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
// Then: Check the API contract without making assumptions about implementation
expect(response.status).toBe(200)
expect(data).not.toBeNull() // We got a response
expect(data).not.toBeNull()
// The response either has a success indicator with output OR an error
if (data.success === true) {
expect(data).toHaveProperty('output')
} else {
// If error, there should be an error message
expect(data).toHaveProperty('error')
expect(typeof data.error).toBe('string')
}
})
it('should process S3 files', async () => {
// Given: A request with an S3 file path
const req = createMockRequest('POST', {
filePath: '/api/files/serve/s3/test-file.pdf',
})
// When: The API processes the request
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
// Then: We should get a response with parsed content or error
expect(response.status).toBe(200)
// The data should either have a success flag with output or an error
if (data.success === true) {
expect(data).toHaveProperty('output')
} else {
@@ -174,17 +153,14 @@ describe('File Parse API Route', () => {
})
it('should handle multiple files', async () => {
// Given: A request with multiple file paths
const req = createMockRequest('POST', {
filePath: ['/api/files/serve/file1.txt', '/api/files/serve/file2.txt'],
})
// When: The API processes the request
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
// Then: We get an array of results
expect(response.status).toBe(200)
expect(data).toHaveProperty('success')
expect(data).toHaveProperty('results')
@@ -193,20 +169,16 @@ describe('File Parse API Route', () => {
})
it('should handle S3 access errors gracefully', async () => {
// Given: S3 will throw an error
mockDownloadFromS3.mockRejectedValueOnce(new Error('S3 access denied'))
// And: A request with an S3 file path
const req = createMockRequest('POST', {
filePath: '/api/files/serve/s3/access-denied.pdf',
})
// When: The API processes the request
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
// Then: We get an appropriate error
expect(response.status).toBe(200)
expect(data).toHaveProperty('success', false)
expect(data).toHaveProperty('error')
@@ -214,22 +186,203 @@ describe('File Parse API Route', () => {
})
it('should handle access errors gracefully', async () => {
// Given: File access will fail
mockAccessFs.mockRejectedValueOnce(new Error('ENOENT: no such file'))
// And: A request with a nonexistent file
const req = createMockRequest('POST', {
filePath: '/api/files/serve/nonexistent.txt',
})
// When: The API processes the request
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
// Then: We get an appropriate error response
expect(response.status).toBe(200)
expect(data).toHaveProperty('success')
expect(data).toHaveProperty('error')
})
})
describe('Files Parse API - Path Traversal Security', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Path Traversal Prevention', () => {
it('should reject path traversal attempts with .. segments', async () => {
const maliciousRequests = [
'../../../etc/passwd',
'/api/files/serve/../../../etc/passwd',
'/api/files/serve/../../app.js',
'/api/files/serve/../.env',
'uploads/../../../etc/hosts',
]
for (const maliciousPath of maliciousRequests) {
const request = new NextRequest('http://localhost:3000/api/files/parse', {
method: 'POST',
body: JSON.stringify({
filePath: maliciousPath,
}),
})
const response = await POST(request)
const result = await response.json()
expect(result.success).toBe(false)
expect(result.error).toMatch(/Access denied|Invalid path|Path outside allowed directory/)
}
})
it('should reject paths with tilde characters', async () => {
const maliciousPaths = [
'~/../../etc/passwd',
'/api/files/serve/~/secret.txt',
'~root/.ssh/id_rsa',
]
for (const maliciousPath of maliciousPaths) {
const request = new NextRequest('http://localhost:3000/api/files/parse', {
method: 'POST',
body: JSON.stringify({
filePath: maliciousPath,
}),
})
const response = await POST(request)
const result = await response.json()
expect(result.success).toBe(false)
expect(result.error).toMatch(/Access denied|Invalid path/)
}
})
it('should reject absolute paths outside upload directory', async () => {
const maliciousPaths = [
'/etc/passwd',
'/root/.bashrc',
'/app/.env',
'/var/log/auth.log',
'C:\\Windows\\System32\\drivers\\etc\\hosts', // Windows path
]
for (const maliciousPath of maliciousPaths) {
const request = new NextRequest('http://localhost:3000/api/files/parse', {
method: 'POST',
body: JSON.stringify({
filePath: maliciousPath,
}),
})
const response = await POST(request)
const result = await response.json()
expect(result.success).toBe(false)
expect(result.error).toMatch(/Access denied|Path outside allowed directory/)
}
})
it('should allow valid paths within upload directory', async () => {
// Test that valid paths don't trigger path validation errors
const validPaths = [
'/api/files/serve/document.txt',
'/api/files/serve/folder/file.pdf',
'/api/files/serve/subfolder/image.png',
]
for (const validPath of validPaths) {
const request = new NextRequest('http://localhost:3000/api/files/parse', {
method: 'POST',
body: JSON.stringify({
filePath: validPath,
}),
})
const response = await POST(request)
const result = await response.json()
// Should not fail due to path validation (may fail for other reasons like file not found)
if (result.error) {
expect(result.error).not.toMatch(
/Access denied|Path outside allowed directory|Invalid path/
)
}
}
})
it('should handle encoded path traversal attempts', async () => {
const encodedMaliciousPaths = [
'/api/files/serve/%2e%2e%2f%2e%2e%2fetc%2fpasswd', // ../../../etc/passwd
'/api/files/serve/..%2f..%2f..%2fetc%2fpasswd',
'/api/files/serve/%2e%2e/%2e%2e/etc/passwd',
]
for (const maliciousPath of encodedMaliciousPaths) {
const request = new NextRequest('http://localhost:3000/api/files/parse', {
method: 'POST',
body: JSON.stringify({
filePath: decodeURIComponent(maliciousPath), // Simulate URL decoding
}),
})
const response = await POST(request)
const result = await response.json()
expect(result.success).toBe(false)
expect(result.error).toMatch(/Access denied|Invalid path|Path outside allowed directory/)
}
})
it('should handle null byte injection attempts', async () => {
const nullBytePaths = [
'/api/files/serve/file.txt\0../../etc/passwd',
'file.txt\0/etc/passwd',
'/api/files/serve/document.pdf\0/var/log/auth.log',
]
for (const maliciousPath of nullBytePaths) {
const request = new NextRequest('http://localhost:3000/api/files/parse', {
method: 'POST',
body: JSON.stringify({
filePath: maliciousPath,
}),
})
const response = await POST(request)
const result = await response.json()
expect(result.success).toBe(false)
// Should be rejected either by path validation or file system access
}
})
})
describe('Edge Cases', () => {
it('should handle empty file paths', async () => {
const request = new NextRequest('http://localhost:3000/api/files/parse', {
method: 'POST',
body: JSON.stringify({
filePath: '',
}),
})
const response = await POST(request)
const result = await response.json()
expect(response.status).toBe(400)
expect(result.error).toBe('No file path provided')
})
it('should handle missing filePath parameter', async () => {
const request = new NextRequest('http://localhost:3000/api/files/parse', {
method: 'POST',
body: JSON.stringify({}),
})
const response = await POST(request)
const result = await response.json()
expect(response.status).toBe(400)
expect(result.error).toBe('No file path provided')
})
})
})

View File

@@ -15,7 +15,6 @@ export const dynamic = 'force-dynamic'
const logger = createLogger('FilesParseAPI')
// Constants for URL downloads
const MAX_DOWNLOAD_SIZE_BYTES = 100 * 1024 * 1024 // 100 MB
const DOWNLOAD_TIMEOUT_MS = 30000 // 30 seconds
@@ -71,21 +70,6 @@ const fileTypeMap: Record<string, string> = {
zip: 'application/zip',
}
// Binary file extensions
const _binaryExtensions = [
'doc',
'docx',
'xls',
'xlsx',
'ppt',
'pptx',
'zip',
'png',
'jpg',
'jpeg',
'gif',
]
/**
* Main API route handler
*/
@@ -529,11 +513,57 @@ async function parseBufferAsPdf(buffer: Buffer) {
}
}
/**
* Validate that a file path is safe and within allowed directories
*/
function validateAndResolvePath(inputPath: string): {
isValid: boolean
resolvedPath?: string
error?: string
} {
try {
let targetPath = inputPath
if (inputPath.startsWith('/api/files/serve/')) {
const filename = inputPath.replace('/api/files/serve/', '')
targetPath = path.join(UPLOAD_DIR, filename)
}
const resolvedPath = path.resolve(targetPath)
const resolvedUploadDir = path.resolve(UPLOAD_DIR)
if (
!resolvedPath.startsWith(resolvedUploadDir + path.sep) &&
resolvedPath !== resolvedUploadDir
) {
return {
isValid: false,
error: `Access denied: Path outside allowed directory`,
}
}
if (inputPath.includes('..') || inputPath.includes('~')) {
return {
isValid: false,
error: `Access denied: Invalid path characters detected`,
}
}
return {
isValid: true,
resolvedPath,
}
} catch (error) {
return {
isValid: false,
error: `Path validation error: ${(error as Error).message}`,
}
}
}
/**
* Handle a local file from the filesystem
*/
async function handleLocalFile(filePath: string, fileType?: string): Promise<ParseResult> {
// Check if this is an S3 path that was incorrectly routed
if (filePath.includes('/api/files/serve/s3/')) {
logger.warn(`S3 path detected in handleLocalFile, redirecting to S3 handler: ${filePath}`)
return handleS3File(filePath, fileType)
@@ -542,15 +572,19 @@ async function handleLocalFile(filePath: string, fileType?: string): Promise<Par
try {
logger.info(`Handling local file: ${filePath}`)
// Extract the filename from the path for API serve paths
let localFilePath = filePath
if (filePath.startsWith('/api/files/serve/')) {
const filename = filePath.replace('/api/files/serve/', '')
localFilePath = path.join(UPLOAD_DIR, filename)
logger.info(`Resolved API path to local file: ${localFilePath}`)
const pathValidation = validateAndResolvePath(filePath)
if (!pathValidation.isValid) {
logger.error(`Path validation failed: ${pathValidation.error}`, { filePath })
return {
success: false,
error: pathValidation.error || 'Invalid file path',
filePath,
}
}
// Make sure the file is actually a file that exists
const localFilePath = pathValidation.resolvedPath!
logger.info(`Validated and resolved path: ${localFilePath}`)
try {
await fsPromises.access(localFilePath, fsPromises.constants.R_OK)
} catch (error) {

View File

@@ -38,7 +38,7 @@ function resolveCodeVariables(
for (const match of tagMatches) {
const tagName = match.slice(1, -1).trim()
const tagValue = params[tagName] || ''
resolvedCode = resolvedCode.replace(match, tagValue)
resolvedCode = resolvedCode.replace(match, JSON.stringify(tagValue))
}
return resolvedCode
@@ -61,6 +61,14 @@ export async function POST(req: NextRequest) {
isCustomTool = false,
} = body
logger.info(`[${requestId}] Function execution request`, {
hasCode: !!code,
paramsCount: Object.keys(params).length,
timeout,
workflowId,
isCustomTool,
})
// Extract internal parameters that shouldn't be passed to the execution context
const executionParams = { ...params }
executionParams._context = undefined
@@ -181,7 +189,7 @@ export async function POST(req: NextRequest) {
const errorMessage = `${args
.map((arg) => (typeof arg === 'object' ? JSON.stringify(arg) : String(arg)))
.join(' ')}\n`
logger.error(`[${requestId}] Code Console Error:`, errorMessage)
logger.error(`[${requestId}] Code Console Error: ${errorMessage}`)
stdout += `ERROR: ${errorMessage}`
},
},
@@ -234,7 +242,7 @@ export async function POST(req: NextRequest) {
const errorMessage = `${args
.map((arg) => (typeof arg === 'object' ? JSON.stringify(arg) : String(arg)))
.join(' ')}\n`
logger.error(`[${requestId}] Code Console Error:`, errorMessage)
logger.error(`[${requestId}] Code Console Error: ${errorMessage}`)
stdout += `ERROR: ${errorMessage}`
},
},

View File

@@ -36,6 +36,7 @@ export async function POST(request: NextRequest) {
workflowId,
stream,
messages,
environmentVariables,
} = body
logger.info(`[${requestId}] Provider request details`, {
@@ -51,6 +52,8 @@ export async function POST(request: NextRequest) {
stream: !!stream,
hasMessages: !!messages?.length,
messageCount: messages?.length || 0,
hasEnvironmentVariables:
!!environmentVariables && Object.keys(environmentVariables).length > 0,
})
let finalApiKey: string
@@ -89,6 +92,7 @@ export async function POST(request: NextRequest) {
workflowId,
stream,
messages,
environmentVariables,
})
const executionTime = Date.now() - startTime

View File

@@ -407,21 +407,6 @@ export default function ChatClient({ subdomain }: { subdomain: string }) {
return (
<div className='fixed inset-0 z-[100] flex flex-col bg-background'>
<style jsx>{`
@keyframes growShrink {
0%,
100% {
transform: scale(0.9);
}
50% {
transform: scale(1.1);
}
}
.loading-dot {
animation: growShrink 1.5s infinite ease-in-out;
}
`}</style>
{/* Header component */}
<ChatHeader chatConfig={chatConfig} starCount={starCount} />

View File

@@ -30,6 +30,21 @@ export const ChatMessageContainer = memo(function ChatMessageContainer({
}: ChatMessageContainerProps) {
return (
<div className='relative flex flex-1 flex-col overflow-hidden bg-white'>
<style jsx>{`
@keyframes growShrink {
0%,
100% {
transform: scale(0.9);
}
50% {
transform: scale(1.1);
}
}
.loading-dot {
animation: growShrink 1.5s infinite ease-in-out;
}
`}</style>
{/* Scrollable Messages Area */}
<div
ref={messagesContainerRef}

View File

@@ -1,6 +1,7 @@
import { beforeEach, describe, expect, it, type Mock, vi } from 'vitest'
import { isHosted } from '@/lib/environment'
import { getAllBlocks } from '@/blocks'
import { executeProviderRequest } from '@/providers'
import { getProviderFromModel, transformBlockTool } from '@/providers/utils'
import type { SerializedBlock, SerializedWorkflow } from '@/serializer/types'
import { executeTool } from '@/tools'
@@ -21,6 +22,21 @@ vi.mock('@/providers/utils', () => ({
getProviderFromModel: vi.fn().mockReturnValue('mock-provider'),
transformBlockTool: vi.fn(),
getBaseModelProviders: vi.fn().mockReturnValue({ openai: {}, anthropic: {} }),
getApiKey: vi.fn().mockReturnValue('mock-api-key'),
getProvider: vi.fn().mockReturnValue({
chat: {
completions: {
create: vi.fn().mockResolvedValue({
content: 'Mocked response content',
model: 'mock-model',
tokens: { prompt: 10, completion: 20, total: 30 },
toolCalls: [],
cost: 0.001,
timing: { total: 100 },
}),
},
},
}),
}))
vi.mock('@/blocks', () => ({
@@ -31,6 +47,17 @@ vi.mock('@/tools', () => ({
executeTool: vi.fn(),
}))
vi.mock('@/providers', () => ({
executeProviderRequest: vi.fn().mockResolvedValue({
content: 'Mocked response content',
model: 'mock-model',
tokens: { prompt: 10, completion: 20, total: 30 },
toolCalls: [],
cost: 0.001,
timing: { total: 100 },
}),
}))
global.fetch = Object.assign(vi.fn(), { preconnect: vi.fn() }) as typeof fetch
const mockGetAllBlocks = getAllBlocks as Mock
@@ -39,6 +66,7 @@ const mockIsHosted = isHosted as unknown as Mock
const mockGetProviderFromModel = getProviderFromModel as Mock
const mockTransformBlockTool = transformBlockTool as Mock
const mockFetch = global.fetch as unknown as Mock
const mockExecuteProviderRequest = executeProviderRequest as Mock
describe('AgentBlockHandler', () => {
let handler: AgentBlockHandler
@@ -50,7 +78,12 @@ describe('AgentBlockHandler', () => {
handler = new AgentBlockHandler()
vi.clearAllMocks()
// Save original Promise.all to restore later
Object.defineProperty(global, 'window', {
value: {},
writable: true,
configurable: true,
})
originalPromiseAll = Promise.all
mockBlock = {
@@ -85,7 +118,7 @@ describe('AgentBlockHandler', () => {
loops: {},
} as SerializedWorkflow,
}
mockIsHosted.mockReturnValue(false) // Default to non-hosted env for tests
mockIsHosted.mockReturnValue(false)
mockGetProviderFromModel.mockReturnValue('mock-provider')
mockFetch.mockImplementation(() => {
@@ -130,8 +163,15 @@ describe('AgentBlockHandler', () => {
})
afterEach(() => {
// Restore original Promise.all
Promise.all = originalPromiseAll
try {
Object.defineProperty(global, 'window', {
value: undefined,
writable: true,
configurable: true,
})
} catch (e) {}
})
describe('canHandle', () => {
@@ -164,7 +204,7 @@ describe('AgentBlockHandler', () => {
userPrompt: 'User query: Hello!',
temperature: 0.7,
maxTokens: 100,
apiKey: 'test-api-key', // Add API key for non-hosted env
apiKey: 'test-api-key',
}
mockGetProviderFromModel.mockReturnValue('openai')
@@ -193,7 +233,6 @@ describe('AgentBlockHandler', () => {
Promise.all = vi.fn().mockImplementation((promises: Promise<any>[]) => {
const result = originalPromiseAll.call(Promise, promises)
// When result resolves, capture the tools
result.then((tools: any[]) => {
if (tools?.length) {
capturedTools = tools.filter((t) => t !== null)
@@ -255,7 +294,7 @@ describe('AgentBlockHandler', () => {
},
},
},
usageControl: 'auto',
usageControl: 'auto' as const,
},
{
type: 'custom-tool',
@@ -274,7 +313,7 @@ describe('AgentBlockHandler', () => {
},
},
},
usageControl: 'force',
usageControl: 'force' as const,
},
{
type: 'custom-tool',
@@ -293,7 +332,7 @@ describe('AgentBlockHandler', () => {
},
},
},
usageControl: 'none', // This tool should be filtered out
usageControl: 'none' as const,
},
],
}
@@ -355,21 +394,21 @@ describe('AgentBlockHandler', () => {
title: 'Tool 1',
type: 'tool-type-1',
operation: 'operation1',
usageControl: 'auto', // default setting
usageControl: 'auto' as const,
},
{
id: 'tool_2',
title: 'Tool 2',
type: 'tool-type-2',
operation: 'operation2',
usageControl: 'none', // should be filtered out
usageControl: 'none' as const,
},
{
id: 'tool_3',
title: 'Tool 3',
type: 'tool-type-3',
operation: 'operation3',
usageControl: 'force', // should be included
usageControl: 'force' as const,
},
],
}
@@ -400,14 +439,14 @@ describe('AgentBlockHandler', () => {
title: 'Tool 1',
type: 'tool-type-1',
operation: 'operation1',
usageControl: 'auto',
usageControl: 'auto' as const,
},
{
id: 'tool_2',
title: 'Tool 2',
type: 'tool-type-2',
operation: 'operation2',
usageControl: 'force',
usageControl: 'force' as const,
},
],
}
@@ -449,7 +488,7 @@ describe('AgentBlockHandler', () => {
},
},
},
usageControl: 'auto',
usageControl: 'auto' as const,
},
{
type: 'custom-tool',
@@ -464,7 +503,7 @@ describe('AgentBlockHandler', () => {
},
},
},
usageControl: 'force',
usageControl: 'force' as const,
},
{
type: 'custom-tool',
@@ -479,7 +518,7 @@ describe('AgentBlockHandler', () => {
},
},
},
usageControl: 'none', // Should be filtered out
usageControl: 'none' as const,
},
],
}
@@ -635,6 +674,8 @@ describe('AgentBlockHandler', () => {
model: 'mock-model',
tokens: { prompt: 10, completion: 20, total: 30 },
timing: { total: 100 },
toolCalls: [],
cost: undefined,
}),
})
})
@@ -654,7 +695,9 @@ describe('AgentBlockHandler', () => {
result: 'Success',
score: 0.95,
tokens: { prompt: 10, completion: 20, total: 30 },
toolCalls: { list: [], count: 0 },
providerTiming: { total: 100 },
cost: undefined,
},
})
})
@@ -729,23 +772,35 @@ describe('AgentBlockHandler', () => {
})
it('should handle streaming responses with text/event-stream content type', async () => {
const mockStreamBody = {
getReader: vi.fn().mockReturnValue({
read: vi.fn().mockResolvedValue({ done: true, value: undefined }),
}),
}
const mockStreamBody = new ReadableStream({
start(controller) {
controller.close()
},
})
mockFetch.mockImplementationOnce(() => {
return Promise.resolve({
ok: true,
headers: {
get: (name: string) => {
if (name === 'Content-Type') return 'text/event-stream'
if (name === 'Content-Type') return 'application/json'
if (name === 'X-Execution-Data') return null
return null
},
},
body: mockStreamBody,
json: () =>
Promise.resolve({
stream: mockStreamBody,
execution: {
success: true,
output: { response: {} },
logs: [],
metadata: {
duration: 0,
startTime: new Date().toISOString(),
},
},
}),
})
})
@@ -771,11 +826,11 @@ describe('AgentBlockHandler', () => {
})
it('should handle streaming responses with execution data in header', async () => {
const mockStreamBody = {
getReader: vi.fn().mockReturnValue({
read: vi.fn().mockResolvedValue({ done: true, value: undefined }),
}),
}
const mockStreamBody = new ReadableStream({
start(controller) {
controller.close()
},
})
const mockExecutionData = {
success: true,
@@ -807,12 +862,16 @@ describe('AgentBlockHandler', () => {
ok: true,
headers: {
get: (name: string) => {
if (name === 'Content-Type') return 'text/event-stream'
if (name === 'Content-Type') return 'application/json'
if (name === 'X-Execution-Data') return JSON.stringify(mockExecutionData)
return null
},
},
body: mockStreamBody,
json: () =>
Promise.resolve({
stream: mockStreamBody,
execution: mockExecutionData,
}),
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,35 @@
export interface AgentInputs {
model?: string
responseFormat?: string | object
tools?: ToolInput[]
systemPrompt?: string
userPrompt?: string | object
memories?: any
temperature?: number
maxTokens?: number
apiKey?: string
}
export interface ToolInput {
type?: string
schema?: any
title?: string
code?: string
params?: Record<string, any>
timeout?: number
usageControl?: 'auto' | 'force' | 'none'
operation?: string
}
export interface Message {
role: 'system' | 'user' | 'assistant'
content: string
function_call?: any
tool_calls?: any[]
}
export interface StreamingConfig {
shouldUseStreaming: boolean
isBlockSelectedForOutput: boolean
hasOutgoingConnections: boolean
}

View File

@@ -1,5 +1,3 @@
import '../../__test-utils__/mock-dependencies'
import { beforeEach, describe, expect, it, type Mock, vi } from 'vitest'
import type { BlockOutput } from '@/blocks/types'
import type { SerializedBlock } from '@/serializer/types'
@@ -7,6 +5,19 @@ import { executeTool } from '@/tools'
import type { ExecutionContext } from '../../types'
import { FunctionBlockHandler } from './function-handler'
vi.mock('@/lib/logs/console-logger', () => ({
createLogger: vi.fn(() => ({
info: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
debug: vi.fn(),
})),
}))
vi.mock('@/tools', () => ({
executeTool: vi.fn(),
}))
const mockExecuteTool = executeTool as Mock
describe('FunctionBlockHandler', () => {
@@ -58,10 +69,14 @@ describe('FunctionBlockHandler', () => {
const inputs = {
code: 'console.log("Hello"); return 1 + 1;',
timeout: 10000,
envVars: {},
isCustomTool: false,
workflowId: undefined,
}
const expectedToolParams = {
code: inputs.code,
timeout: inputs.timeout,
envVars: {},
_context: { workflowId: mockContext.workflowId },
}
const expectedOutput: BlockOutput = { response: { result: 'Success' } }
@@ -76,11 +91,15 @@ describe('FunctionBlockHandler', () => {
const inputs = {
code: [{ content: 'const x = 5;' }, { content: 'return x * 2;' }],
timeout: 5000,
envVars: {},
isCustomTool: false,
workflowId: undefined,
}
const expectedCode = 'const x = 5;\nreturn x * 2;'
const expectedToolParams = {
code: expectedCode,
timeout: inputs.timeout,
envVars: {},
_context: { workflowId: mockContext.workflowId },
}
const expectedOutput: BlockOutput = { response: { result: 'Success' } }
@@ -96,6 +115,7 @@ describe('FunctionBlockHandler', () => {
const expectedToolParams = {
code: inputs.code,
timeout: 5000, // Default timeout
envVars: {},
_context: { workflowId: mockContext.workflowId },
}

View File

@@ -28,6 +28,7 @@ export class FunctionBlockHandler implements BlockHandler {
const result = await executeTool('function_execute', {
code: codeContent,
timeout: inputs.timeout || 5000,
envVars: context.environmentVariables || {},
_context: { workflowId: context.workflowId },
})

View File

@@ -459,6 +459,7 @@ ${fieldDescriptions}
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
...(request.environmentVariables ? { envVars: request.environmentVariables } : {}),
}
const result = await executeTool(toolName, mergedArgs, true)
const toolCallEndTime = Date.now()

View File

@@ -288,6 +288,7 @@ export const deepseekProvider: ProviderConfig = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
...(request.environmentVariables ? { envVars: request.environmentVariables } : {}),
}
const result = await executeTool(toolName, mergedArgs, true)
const toolCallEndTime = Date.now()

View File

@@ -368,6 +368,7 @@ export const googleProvider: ProviderConfig = {
...toolArgs, // Arguments from the model's function call
...requiredToolCallParams, // Required parameters from the tool definition (take precedence)
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
...(request.environmentVariables ? { envVars: request.environmentVariables } : {}),
}
// For debugging only - don't log actual API keys

View File

@@ -360,6 +360,7 @@ export const openaiProvider: ProviderConfig = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
...(request.environmentVariables ? { envVars: request.environmentVariables } : {}),
}
const result = await executeTool(toolName, mergedArgs, true)

View File

@@ -147,6 +147,7 @@ export interface ProviderRequest {
local_execution?: boolean
workflowId?: string // Optional workflow ID for authentication context
stream?: boolean
environmentVariables?: Record<string, string> // Environment variables for tool execution
}
// Map of provider IDs to their configurations

View File

@@ -32,27 +32,34 @@ export const useCustomToolsStore = create<CustomToolsStore>()(
throw new Error('Invalid response format')
}
// Validate each tool object's structure before processing
data.forEach((tool, index) => {
// Filter and validate tools, skipping invalid ones instead of throwing errors
const validTools = data.filter((tool, index) => {
if (!tool || typeof tool !== 'object') {
throw new Error(`Invalid tool format at index ${index}: not an object`)
logger.warn(`Skipping invalid tool at index ${index}: not an object`)
return false
}
if (!tool.id || typeof tool.id !== 'string') {
throw new Error(`Invalid tool format at index ${index}: missing or invalid id`)
logger.warn(`Skipping invalid tool at index ${index}: missing or invalid id`)
return false
}
if (!tool.title || typeof tool.title !== 'string') {
throw new Error(`Invalid tool format at index ${index}: missing or invalid title`)
logger.warn(`Skipping invalid tool at index ${index}: missing or invalid title`)
return false
}
if (!tool.schema || typeof tool.schema !== 'object') {
throw new Error(`Invalid tool format at index ${index}: missing or invalid schema`)
logger.warn(`Skipping invalid tool at index ${index}: missing or invalid schema`)
return false
}
// Make code field optional - default to empty string if missing
if (!tool.code || typeof tool.code !== 'string') {
throw new Error(`Invalid tool format at index ${index}: missing or invalid code`)
logger.warn(`Tool at index ${index} missing code field, defaulting to empty string`)
tool.code = ''
}
return true
})
// Transform to local format and set
const transformedTools = data.reduce(
const transformedTools = validTools.reduce(
(acc, tool) => ({
...acc,
[tool.id]: tool,
@@ -60,8 +67,6 @@ export const useCustomToolsStore = create<CustomToolsStore>()(
{}
)
logger.info(`Loaded ${data.length} custom tools from server`)
set({
tools: transformedTools,
isLoading: false,
@@ -72,12 +77,6 @@ export const useCustomToolsStore = create<CustomToolsStore>()(
error: error instanceof Error ? error.message : 'Unknown error',
isLoading: false,
})
// Add a delay before reloading to prevent race conditions
setTimeout(() => {
// Reload from server to ensure consistency
get().loadCustomTools()
}, 500)
}
},
@@ -121,21 +120,12 @@ export const useCustomToolsStore = create<CustomToolsStore>()(
set({ isLoading: false })
logger.info('Successfully synced custom tools with server')
// Load from server to ensure consistency even after successful sync
get().loadCustomTools()
} catch (error) {
logger.error('Error syncing custom tools:', error)
set({
error: error instanceof Error ? error.message : 'Unknown error',
isLoading: false,
})
// Add a delay before reloading to prevent race conditions
setTimeout(() => {
// Reload from server to ensure consistency
get().loadCustomTools()
}, 500)
}
},

View File

@@ -41,12 +41,18 @@ describe('Function Execute Tool', () => {
test('should format single string code correctly', () => {
const body = tester.getRequestBody({
code: 'return 42',
envVars: {},
isCustomTool: false,
timeout: 5000,
workflowId: undefined,
})
expect(body).toEqual({
code: 'return 42',
envVars: {},
isCustomTool: false,
timeout: 5000,
workflowId: undefined,
})
})
@@ -57,11 +63,18 @@ describe('Function Execute Tool', () => {
{ content: 'const y = 2;', id: 'block2' },
{ content: 'return x + y;', id: 'block3' },
],
envVars: {},
isCustomTool: false,
timeout: 10000,
workflowId: undefined,
})
expect(body).toEqual({
code: 'const x = 40;\nconst y = 2;\nreturn x + y;',
timeout: 10000,
envVars: {},
isCustomTool: false,
workflowId: undefined,
})
})
@@ -73,6 +86,9 @@ describe('Function Execute Tool', () => {
expect(body).toEqual({
code: 'return 42',
timeout: 10000,
envVars: {},
isCustomTool: false,
workflowId: undefined,
})
})
})

View File

@@ -22,6 +22,12 @@ export const functionExecuteTool: ToolConfig<CodeExecutionInput, CodeExecutionOu
description: 'Execution timeout in milliseconds',
default: DEFAULT_TIMEOUT,
},
envVars: {
type: 'object',
required: false,
description: 'Environment variables to make available during execution',
default: {},
},
},
request: {
@@ -38,6 +44,9 @@ export const functionExecuteTool: ToolConfig<CodeExecutionInput, CodeExecutionOu
return {
code: codeContent,
timeout: params.timeout || DEFAULT_TIMEOUT,
envVars: params.envVars || {},
workflowId: params._context?.workflowId,
isCustomTool: params.isCustomTool || false,
}
},
isInternalRoute: true,

View File

@@ -4,6 +4,11 @@ export interface CodeExecutionInput {
code: Array<{ content: string; id: string }> | string
timeout?: number
memoryLimit?: number
envVars?: Record<string, string>
_context?: {
workflowId?: string
}
isCustomTool?: boolean
}
export interface CodeExecutionOutput extends ToolResponse {

View File

@@ -495,7 +495,7 @@ function validateClientSideParams(
}
// Internal parameters that should be excluded from validation
const internalParamSet = new Set(['_context', 'workflowId'])
const internalParamSet = new Set(['_context', 'workflowId', 'envVars'])
// Check required parameters
if (schema.required) {

View File

@@ -249,8 +249,11 @@ export function createCustomToolRequestBody(
getStore?: () => any
) {
return (params: Record<string, any>) => {
// Get environment variables - empty on server, from store on client
const envVars = isClient ? getClientEnvVars(getStore) : {}
// Get environment variables - try multiple sources in order of preference:
// 1. envVars parameter (passed from provider/agent context)
// 2. Client-side store (if running in browser)
// 3. Empty object (fallback)
const envVars = params.envVars || (isClient ? getClientEnvVars(getStore) : {})
// Include everything needed for execution
return {