mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-10 15:38:00 -05:00
fix(custom-tools): updates to legacy + copilot generated custom tools (#1960)
* fix(custom-tools): updates to existing tools * don't reorder custom tools in modal based on edit time * restructure custom tools to persist copilot generated tools * fix tests
This commit is contained in:
committed by
GitHub
parent
6513cbb7c1
commit
d50aefcc68
@@ -70,6 +70,7 @@ describe('Custom Tools API Routes', () => {
|
||||
const mockSelect = vi.fn()
|
||||
const mockFrom = vi.fn()
|
||||
const mockWhere = vi.fn()
|
||||
const mockOrderBy = vi.fn()
|
||||
const mockInsert = vi.fn()
|
||||
const mockValues = vi.fn()
|
||||
const mockUpdate = vi.fn()
|
||||
@@ -84,10 +85,23 @@ describe('Custom Tools API Routes', () => {
|
||||
// Reset all mock implementations
|
||||
mockSelect.mockReturnValue({ from: mockFrom })
|
||||
mockFrom.mockReturnValue({ where: mockWhere })
|
||||
// where() can be called with limit() or directly awaited
|
||||
// Create a mock query builder that supports both patterns
|
||||
// where() can be called with orderBy(), limit(), or directly awaited
|
||||
// Create a mock query builder that supports all patterns
|
||||
mockWhere.mockImplementation((condition) => {
|
||||
// Return an object that is both awaitable and has a limit() method
|
||||
// Return an object that is both awaitable and has orderBy() and limit() methods
|
||||
const queryBuilder = {
|
||||
orderBy: mockOrderBy,
|
||||
limit: mockLimit,
|
||||
then: (resolve: (value: typeof sampleTools) => void) => {
|
||||
resolve(sampleTools)
|
||||
return queryBuilder
|
||||
},
|
||||
catch: (reject: (error: Error) => void) => queryBuilder,
|
||||
}
|
||||
return queryBuilder
|
||||
})
|
||||
mockOrderBy.mockImplementation(() => {
|
||||
// orderBy returns an awaitable query builder
|
||||
const queryBuilder = {
|
||||
limit: mockLimit,
|
||||
then: (resolve: (value: typeof sampleTools) => void) => {
|
||||
@@ -120,9 +134,22 @@ describe('Custom Tools API Routes', () => {
|
||||
const txMockUpdate = vi.fn().mockReturnValue({ set: mockSet })
|
||||
const txMockDelete = vi.fn().mockReturnValue({ where: mockWhere })
|
||||
|
||||
// Transaction where() should also support the query builder pattern
|
||||
// Transaction where() should also support the query builder pattern with orderBy
|
||||
const txMockOrderBy = vi.fn().mockImplementation(() => {
|
||||
const queryBuilder = {
|
||||
limit: mockLimit,
|
||||
then: (resolve: (value: typeof sampleTools) => void) => {
|
||||
resolve(sampleTools)
|
||||
return queryBuilder
|
||||
},
|
||||
catch: (reject: (error: Error) => void) => queryBuilder,
|
||||
}
|
||||
return queryBuilder
|
||||
})
|
||||
|
||||
const txMockWhere = vi.fn().mockImplementation((condition) => {
|
||||
const queryBuilder = {
|
||||
orderBy: txMockOrderBy,
|
||||
limit: mockLimit,
|
||||
then: (resolve: (value: typeof sampleTools) => void) => {
|
||||
resolve(sampleTools)
|
||||
@@ -201,6 +228,7 @@ describe('Custom Tools API Routes', () => {
|
||||
or: vi.fn().mockImplementation((...conditions) => ({ operator: 'or', conditions })),
|
||||
isNull: vi.fn().mockImplementation((field) => ({ field, operator: 'isNull' })),
|
||||
ne: vi.fn().mockImplementation((field, value) => ({ field, value, operator: 'ne' })),
|
||||
desc: vi.fn().mockImplementation((field) => ({ field, operator: 'desc' })),
|
||||
}
|
||||
})
|
||||
|
||||
@@ -208,6 +236,11 @@ describe('Custom Tools API Routes', () => {
|
||||
vi.doMock('@/lib/utils', () => ({
|
||||
generateRequestId: vi.fn().mockReturnValue('test-request-id'),
|
||||
}))
|
||||
|
||||
// Mock custom tools operations
|
||||
vi.doMock('@/lib/custom-tools/operations', () => ({
|
||||
upsertCustomTools: vi.fn().mockResolvedValue(sampleTools),
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
@@ -224,8 +257,10 @@ describe('Custom Tools API Routes', () => {
|
||||
'http://localhost:3000/api/tools/custom?workspaceId=workspace-123'
|
||||
)
|
||||
|
||||
// Simulate DB returning tools
|
||||
mockWhere.mockReturnValueOnce(Promise.resolve(sampleTools))
|
||||
// Simulate DB returning tools with orderBy chain
|
||||
mockWhere.mockReturnValueOnce({
|
||||
orderBy: mockOrderBy.mockReturnValueOnce(Promise.resolve(sampleTools)),
|
||||
})
|
||||
|
||||
// Import handler after mocks are set up
|
||||
const { GET } = await import('@/app/api/tools/custom/route')
|
||||
@@ -243,6 +278,7 @@ describe('Custom Tools API Routes', () => {
|
||||
expect(mockSelect).toHaveBeenCalled()
|
||||
expect(mockFrom).toHaveBeenCalled()
|
||||
expect(mockWhere).toHaveBeenCalled()
|
||||
expect(mockOrderBy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle unauthorized access', async () => {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { db } from '@sim/db'
|
||||
import { customTools, workflow } from '@sim/db/schema'
|
||||
import { and, eq, isNull, ne, or } from 'drizzle-orm'
|
||||
import { and, desc, eq, isNull, or } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { checkHybridAuth } from '@/lib/auth/hybrid'
|
||||
import { upsertCustomTools } from '@/lib/custom-tools/operations'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUserEntityPermissions } from '@/lib/permissions/utils'
|
||||
import { generateRequestId } from '@/lib/utils'
|
||||
@@ -101,6 +102,7 @@ export async function GET(request: NextRequest) {
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(or(...conditions))
|
||||
.orderBy(desc(customTools.createdAt))
|
||||
|
||||
return NextResponse.json({ data: result }, { status: 200 })
|
||||
} catch (error) {
|
||||
@@ -150,96 +152,15 @@ export async function POST(req: NextRequest) {
|
||||
return NextResponse.json({ error: 'Write permission required' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Use a transaction for multi-step database operations
|
||||
return await db.transaction(async (tx) => {
|
||||
// Process each tool: either update existing or create new
|
||||
for (const tool of tools) {
|
||||
const nowTime = new Date()
|
||||
|
||||
if (tool.id) {
|
||||
// Check if tool exists and belongs to the workspace
|
||||
const existingTool = await tx
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(and(eq(customTools.id, tool.id), eq(customTools.workspaceId, workspaceId)))
|
||||
.limit(1)
|
||||
|
||||
if (existingTool.length > 0) {
|
||||
// Tool exists - check if name changed and if new name conflicts
|
||||
if (existingTool[0].title !== tool.title) {
|
||||
// Check for duplicate name in workspace (excluding current tool)
|
||||
const duplicateTool = await tx
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(
|
||||
and(
|
||||
eq(customTools.workspaceId, workspaceId),
|
||||
eq(customTools.title, tool.title),
|
||||
ne(customTools.id, tool.id)
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (duplicateTool.length > 0) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: `A tool with the name "${tool.title}" already exists in this workspace`,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Update existing tool
|
||||
await tx
|
||||
.update(customTools)
|
||||
.set({
|
||||
title: tool.title,
|
||||
schema: tool.schema,
|
||||
code: tool.code,
|
||||
updatedAt: nowTime,
|
||||
})
|
||||
.where(and(eq(customTools.id, tool.id), eq(customTools.workspaceId, workspaceId)))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Creating new tool - check for duplicate names in workspace
|
||||
const duplicateTool = await tx
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(and(eq(customTools.workspaceId, workspaceId), eq(customTools.title, tool.title)))
|
||||
.limit(1)
|
||||
|
||||
if (duplicateTool.length > 0) {
|
||||
return NextResponse.json(
|
||||
{ error: `A tool with the name "${tool.title}" already exists in this workspace` },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
// Create new tool
|
||||
const newToolId = tool.id || crypto.randomUUID()
|
||||
await tx.insert(customTools).values({
|
||||
id: newToolId,
|
||||
workspaceId,
|
||||
userId,
|
||||
title: tool.title,
|
||||
schema: tool.schema,
|
||||
code: tool.code,
|
||||
createdAt: nowTime,
|
||||
updatedAt: nowTime,
|
||||
})
|
||||
}
|
||||
|
||||
// Fetch and return the created/updated tools
|
||||
const resultTools = await tx
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(eq(customTools.workspaceId, workspaceId))
|
||||
|
||||
return NextResponse.json({ success: true, data: resultTools })
|
||||
// Use the extracted upsert function
|
||||
const resultTools = await upsertCustomTools({
|
||||
tools,
|
||||
workspaceId,
|
||||
userId,
|
||||
requestId,
|
||||
})
|
||||
|
||||
return NextResponse.json({ success: true, data: resultTools })
|
||||
} catch (validationError) {
|
||||
if (validationError instanceof z.ZodError) {
|
||||
logger.warn(`[${requestId}] Invalid custom tools data`, {
|
||||
|
||||
@@ -477,6 +477,7 @@ export function ToolInput({
|
||||
const [draggedIndex, setDraggedIndex] = useState<number | null>(null)
|
||||
const [dragOverIndex, setDragOverIndex] = useState<number | null>(null)
|
||||
const customTools = useCustomToolsStore((state) => state.getAllTools())
|
||||
const fetchCustomTools = useCustomToolsStore((state) => state.fetchTools)
|
||||
const subBlockStore = useSubBlockStore()
|
||||
|
||||
// MCP tools integration
|
||||
@@ -487,6 +488,13 @@ export function ToolInput({
|
||||
refreshTools,
|
||||
} = useMcpTools(workspaceId)
|
||||
|
||||
// Fetch custom tools on mount
|
||||
useEffect(() => {
|
||||
if (workspaceId) {
|
||||
fetchCustomTools(workspaceId)
|
||||
}
|
||||
}, [workspaceId, fetchCustomTools])
|
||||
|
||||
// Get the current model from the 'model' subblock
|
||||
const modelValue = useSubBlockStore.getState().getValue(blockId, 'model')
|
||||
const model = typeof modelValue === 'string' ? modelValue : ''
|
||||
|
||||
138
apps/sim/lib/custom-tools/operations.ts
Normal file
138
apps/sim/lib/custom-tools/operations.ts
Normal file
@@ -0,0 +1,138 @@
|
||||
import { db } from '@sim/db'
|
||||
import { customTools } from '@sim/db/schema'
|
||||
import { and, desc, eq, isNull } from 'drizzle-orm'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { generateRequestId } from '@/lib/utils'
|
||||
|
||||
const logger = createLogger('CustomToolsOperations')
|
||||
|
||||
/**
|
||||
* Internal function to create/update custom tools
|
||||
* Can be called from API routes or internal services
|
||||
*/
|
||||
export async function upsertCustomTools(params: {
|
||||
tools: Array<{
|
||||
id?: string
|
||||
title: string
|
||||
schema: any
|
||||
code: string
|
||||
}>
|
||||
workspaceId: string
|
||||
userId: string
|
||||
requestId?: string
|
||||
}) {
|
||||
const { tools, workspaceId, userId, requestId = generateRequestId() } = params
|
||||
|
||||
// Use a transaction for multi-step database operations
|
||||
return await db.transaction(async (tx) => {
|
||||
// Process each tool: either update existing or create new
|
||||
for (const tool of tools) {
|
||||
const nowTime = new Date()
|
||||
|
||||
if (tool.id) {
|
||||
// First, check if tool exists in the workspace
|
||||
const existingWorkspaceTool = await tx
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(and(eq(customTools.id, tool.id), eq(customTools.workspaceId, workspaceId)))
|
||||
.limit(1)
|
||||
|
||||
if (existingWorkspaceTool.length > 0) {
|
||||
// Tool exists in workspace
|
||||
const newFunctionName = tool.schema?.function?.name
|
||||
if (!newFunctionName) {
|
||||
throw new Error('Tool schema must include a function name')
|
||||
}
|
||||
|
||||
// Check if function name has changed
|
||||
if (tool.id !== newFunctionName) {
|
||||
throw new Error(
|
||||
`Cannot change function name from "${tool.id}" to "${newFunctionName}". Please create a new tool instead.`
|
||||
)
|
||||
}
|
||||
|
||||
// Update existing workspace tool
|
||||
await tx
|
||||
.update(customTools)
|
||||
.set({
|
||||
title: tool.title,
|
||||
schema: tool.schema,
|
||||
code: tool.code,
|
||||
updatedAt: nowTime,
|
||||
})
|
||||
.where(and(eq(customTools.id, tool.id), eq(customTools.workspaceId, workspaceId)))
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a legacy tool (no workspaceId, belongs to user)
|
||||
const existingLegacyTool = await tx
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(
|
||||
and(
|
||||
eq(customTools.id, tool.id),
|
||||
isNull(customTools.workspaceId),
|
||||
eq(customTools.userId, userId)
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (existingLegacyTool.length > 0) {
|
||||
// Legacy tool found - update it without migrating to workspace
|
||||
await tx
|
||||
.update(customTools)
|
||||
.set({
|
||||
title: tool.title,
|
||||
schema: tool.schema,
|
||||
code: tool.code,
|
||||
updatedAt: nowTime,
|
||||
})
|
||||
.where(eq(customTools.id, tool.id))
|
||||
|
||||
logger.info(`[${requestId}] Updated legacy tool ${tool.id}`)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Creating new tool - use function name as ID for consistency
|
||||
const functionName = tool.schema?.function?.name
|
||||
if (!functionName) {
|
||||
throw new Error('Tool schema must include a function name')
|
||||
}
|
||||
|
||||
// Check for duplicate function names in workspace
|
||||
const duplicateFunction = await tx
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(and(eq(customTools.workspaceId, workspaceId), eq(customTools.id, functionName)))
|
||||
.limit(1)
|
||||
|
||||
if (duplicateFunction.length > 0) {
|
||||
throw new Error(
|
||||
`A tool with the function name "${functionName}" already exists in this workspace`
|
||||
)
|
||||
}
|
||||
|
||||
// Create new tool using function name as ID
|
||||
await tx.insert(customTools).values({
|
||||
id: functionName,
|
||||
workspaceId,
|
||||
userId,
|
||||
title: tool.title,
|
||||
schema: tool.schema,
|
||||
code: tool.code,
|
||||
createdAt: nowTime,
|
||||
updatedAt: nowTime,
|
||||
})
|
||||
}
|
||||
|
||||
// Fetch and return the created/updated tools
|
||||
const resultTools = await tx
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(eq(customTools.workspaceId, workspaceId))
|
||||
.orderBy(desc(customTools.createdAt))
|
||||
|
||||
return resultTools
|
||||
})
|
||||
}
|
||||
@@ -1,6 +1,4 @@
|
||||
import { db } from '@sim/db'
|
||||
import { customTools } from '@sim/db/schema'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { upsertCustomTools } from '@/lib/custom-tools/operations'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('CustomToolsPersistence')
|
||||
@@ -91,7 +89,7 @@ export function extractCustomToolsFromWorkflowState(workflowState: any): CustomT
|
||||
}
|
||||
|
||||
/**
|
||||
* Persist custom tools to the database
|
||||
* Persist custom tools to the database using the upsert function
|
||||
* Creates new tools or updates existing ones
|
||||
*/
|
||||
export async function persistCustomToolsToDatabase(
|
||||
@@ -113,70 +111,36 @@ export async function persistCustomToolsToDatabase(
|
||||
const errors: string[] = []
|
||||
let saved = 0
|
||||
|
||||
// Filter out tools without function names
|
||||
const validTools = customToolsList.filter((tool) => {
|
||||
if (!tool.schema?.function?.name) {
|
||||
logger.warn(`Skipping custom tool without function name: ${tool.title}`)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if (validTools.length === 0) {
|
||||
return { saved: 0, errors: [] }
|
||||
}
|
||||
|
||||
try {
|
||||
await db.transaction(async (tx) => {
|
||||
for (const tool of customToolsList) {
|
||||
try {
|
||||
// Extract the base identifier (without 'custom_' prefix) for database storage
|
||||
// If toolId exists and has the prefix, strip it; otherwise use title as base
|
||||
let baseId: string
|
||||
if (tool.toolId) {
|
||||
baseId = tool.toolId.startsWith('custom_')
|
||||
? tool.toolId.replace('custom_', '')
|
||||
: tool.toolId
|
||||
} else {
|
||||
// Use title as the base identifier (agent handler will add 'custom_' prefix)
|
||||
baseId = tool.title
|
||||
}
|
||||
|
||||
const nowTime = new Date()
|
||||
|
||||
// Check if tool already exists in this workspace
|
||||
const existingTool = await tx
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(and(eq(customTools.id, baseId), eq(customTools.workspaceId, workspaceId)))
|
||||
.limit(1)
|
||||
|
||||
if (existingTool.length === 0) {
|
||||
// Create new tool
|
||||
await tx.insert(customTools).values({
|
||||
id: baseId,
|
||||
workspaceId,
|
||||
userId,
|
||||
title: tool.title,
|
||||
schema: tool.schema,
|
||||
code: tool.code,
|
||||
createdAt: nowTime,
|
||||
updatedAt: nowTime,
|
||||
})
|
||||
|
||||
logger.info(`Created custom tool: ${tool.title}`, { toolId: baseId, workspaceId })
|
||||
saved++
|
||||
} else {
|
||||
// Update existing tool in workspace (workspace members can update)
|
||||
await tx
|
||||
.update(customTools)
|
||||
.set({
|
||||
title: tool.title,
|
||||
schema: tool.schema,
|
||||
code: tool.code,
|
||||
updatedAt: nowTime,
|
||||
})
|
||||
.where(and(eq(customTools.id, baseId), eq(customTools.workspaceId, workspaceId)))
|
||||
|
||||
logger.info(`Updated custom tool: ${tool.title}`, { toolId: baseId, workspaceId })
|
||||
saved++
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMsg = `Failed to persist tool ${tool.title}: ${error instanceof Error ? error.message : String(error)}`
|
||||
logger.error(errorMsg, { error })
|
||||
errors.push(errorMsg)
|
||||
}
|
||||
}
|
||||
// Call the upsert function from lib
|
||||
await upsertCustomTools({
|
||||
tools: validTools.map((tool) => ({
|
||||
id: tool.schema.function.name, // Use function name as ID for updates
|
||||
title: tool.title,
|
||||
schema: tool.schema,
|
||||
code: tool.code,
|
||||
})),
|
||||
workspaceId,
|
||||
userId,
|
||||
})
|
||||
|
||||
saved = validTools.length
|
||||
logger.info(`Persisted ${saved} custom tool(s)`, { workspaceId })
|
||||
} catch (error) {
|
||||
const errorMsg = `Transaction failed while persisting custom tools: ${error instanceof Error ? error.message : String(error)}`
|
||||
const errorMsg = `Failed to persist custom tools: ${error instanceof Error ? error.message : String(error)}`
|
||||
logger.error(errorMsg, { error })
|
||||
errors.push(errorMsg)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user