mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
fix(tools) Directly query db for custom tool id (#3875)
* Directly query db for custom tool id * Switch back to inline imports * Fix lint * Fix test * Fix greptile comments * Fix lint * Make userId and workspaceId required * Add back nullable userId and workspaceId fields --------- Co-authored-by: Theodore Li <theo@sim.ai>
This commit is contained in:
@@ -21,6 +21,7 @@ vi.mock('@/lib/core/config/feature-flags', () => ({
|
||||
isEmailVerificationEnabled: false,
|
||||
isBillingEnabled: false,
|
||||
isOrganizationsEnabled: false,
|
||||
isAccessControlEnabled: false,
|
||||
}))
|
||||
|
||||
vi.mock('@/providers/utils', () => ({
|
||||
@@ -110,6 +111,12 @@ vi.mock('@sim/db/schema', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
const mockGetCustomToolById = vi.fn()
|
||||
|
||||
vi.mock('@/lib/workflows/custom-tools/operations', () => ({
|
||||
getCustomToolById: (...args: unknown[]) => mockGetCustomToolById(...args),
|
||||
}))
|
||||
|
||||
setupGlobalFetchMock()
|
||||
|
||||
const mockGetAllBlocks = getAllBlocks as Mock
|
||||
@@ -1957,49 +1964,22 @@ describe('AgentBlockHandler', () => {
|
||||
const staleInlineCode = 'return { title, content };'
|
||||
const dbCode = 'return { title, content, format };'
|
||||
|
||||
function mockFetchForCustomTool(toolId: string) {
|
||||
mockFetch.mockImplementation((url: string) => {
|
||||
if (typeof url === 'string' && url.includes('/api/tools/custom')) {
|
||||
function mockDBForCustomTool(toolId: string) {
|
||||
mockGetCustomToolById.mockImplementation(({ toolId: id }: { toolId: string }) => {
|
||||
if (id === toolId) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
headers: { get: () => null },
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
data: [
|
||||
{
|
||||
id: toolId,
|
||||
title: 'formatReport',
|
||||
schema: dbSchema,
|
||||
code: dbCode,
|
||||
},
|
||||
],
|
||||
}),
|
||||
id: toolId,
|
||||
title: 'formatReport',
|
||||
schema: dbSchema,
|
||||
code: dbCode,
|
||||
})
|
||||
}
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
headers: { get: () => null },
|
||||
json: () => Promise.resolve({}),
|
||||
})
|
||||
return Promise.resolve(null)
|
||||
})
|
||||
}
|
||||
|
||||
function mockFetchFailure() {
|
||||
mockFetch.mockImplementation((url: string) => {
|
||||
if (typeof url === 'string' && url.includes('/api/tools/custom')) {
|
||||
return Promise.resolve({
|
||||
ok: false,
|
||||
status: 500,
|
||||
headers: { get: () => null },
|
||||
json: () => Promise.resolve({}),
|
||||
})
|
||||
}
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
headers: { get: () => null },
|
||||
json: () => Promise.resolve({}),
|
||||
})
|
||||
})
|
||||
function mockDBFailure() {
|
||||
mockGetCustomToolById.mockRejectedValue(new Error('DB connection failed'))
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -2008,11 +1988,13 @@ describe('AgentBlockHandler', () => {
|
||||
writable: true,
|
||||
configurable: true,
|
||||
})
|
||||
mockGetCustomToolById.mockReset()
|
||||
mockContext.userId = 'test-user'
|
||||
})
|
||||
|
||||
it('should always fetch latest schema from DB when customToolId is present', async () => {
|
||||
const toolId = 'custom-tool-123'
|
||||
mockFetchForCustomTool(toolId)
|
||||
mockDBForCustomTool(toolId)
|
||||
|
||||
const inputs = {
|
||||
model: 'gpt-4o',
|
||||
@@ -2046,7 +2028,7 @@ describe('AgentBlockHandler', () => {
|
||||
|
||||
it('should fetch from DB when customToolId has no inline schema', async () => {
|
||||
const toolId = 'custom-tool-123'
|
||||
mockFetchForCustomTool(toolId)
|
||||
mockDBForCustomTool(toolId)
|
||||
|
||||
const inputs = {
|
||||
model: 'gpt-4o',
|
||||
@@ -2075,7 +2057,7 @@ describe('AgentBlockHandler', () => {
|
||||
})
|
||||
|
||||
it('should fall back to inline schema when DB fetch fails and inline exists', async () => {
|
||||
mockFetchFailure()
|
||||
mockDBFailure()
|
||||
|
||||
const inputs = {
|
||||
model: 'gpt-4o',
|
||||
@@ -2107,7 +2089,7 @@ describe('AgentBlockHandler', () => {
|
||||
})
|
||||
|
||||
it('should return null when DB fetch fails and no inline schema exists', async () => {
|
||||
mockFetchFailure()
|
||||
mockDBFailure()
|
||||
|
||||
const inputs = {
|
||||
model: 'gpt-4o',
|
||||
@@ -2135,7 +2117,7 @@ describe('AgentBlockHandler', () => {
|
||||
|
||||
it('should use DB schema when customToolId resolves', async () => {
|
||||
const toolId = 'custom-tool-123'
|
||||
mockFetchForCustomTool(toolId)
|
||||
mockDBForCustomTool(toolId)
|
||||
|
||||
const inputs = {
|
||||
model: 'gpt-4o',
|
||||
@@ -2185,10 +2167,7 @@ describe('AgentBlockHandler', () => {
|
||||
|
||||
await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
const customToolFetches = mockFetch.mock.calls.filter(
|
||||
(call: any[]) => typeof call[0] === 'string' && call[0].includes('/api/tools/custom')
|
||||
)
|
||||
expect(customToolFetches.length).toBe(0)
|
||||
expect(mockGetCustomToolById).not.toHaveBeenCalled()
|
||||
|
||||
expect(mockExecuteProviderRequest).toHaveBeenCalled()
|
||||
const providerCall = mockExecuteProviderRequest.mock.calls[0]
|
||||
|
||||
@@ -3,6 +3,7 @@ import { mcpServers } from '@sim/db/schema'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { and, eq, inArray, isNull } from 'drizzle-orm'
|
||||
import { createMcpToolId } from '@/lib/mcp/utils'
|
||||
import { getCustomToolById } from '@/lib/workflows/custom-tools/operations'
|
||||
import { getAllBlocks } from '@/blocks'
|
||||
import type { BlockOutput } from '@/blocks/types'
|
||||
import {
|
||||
@@ -277,39 +278,18 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
ctx: ExecutionContext,
|
||||
customToolId: string
|
||||
): Promise<{ schema: any; title: string } | null> {
|
||||
if (!ctx.userId) {
|
||||
logger.error('Cannot fetch custom tool without userId:', { customToolId })
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const headers = await buildAuthHeaders(ctx.userId)
|
||||
const params: Record<string, string> = {}
|
||||
|
||||
if (ctx.workspaceId) {
|
||||
params.workspaceId = ctx.workspaceId
|
||||
}
|
||||
if (ctx.workflowId) {
|
||||
params.workflowId = ctx.workflowId
|
||||
}
|
||||
if (ctx.userId) {
|
||||
params.userId = ctx.userId
|
||||
}
|
||||
|
||||
const url = buildAPIUrl('/api/tools/custom', params)
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers,
|
||||
const tool = await getCustomToolById({
|
||||
toolId: customToolId,
|
||||
userId: ctx.userId,
|
||||
workspaceId: ctx.workspaceId,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
await response.text().catch(() => {})
|
||||
logger.error(`Failed to fetch custom tools: ${response.status}`)
|
||||
return null
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
if (!data.data || !Array.isArray(data.data)) {
|
||||
logger.error('Invalid custom tools API response')
|
||||
return null
|
||||
}
|
||||
|
||||
const tool = data.data.find((t: any) => t.id === customToolId)
|
||||
if (!tool) {
|
||||
logger.warn(`Custom tool not found by ID: ${customToolId}`)
|
||||
return null
|
||||
|
||||
@@ -158,6 +158,32 @@ export async function getCustomToolById(params: {
|
||||
return legacyTool[0] || null
|
||||
}
|
||||
|
||||
export async function getCustomToolByIdOrTitle(params: {
|
||||
identifier: string
|
||||
userId: string
|
||||
workspaceId?: string
|
||||
}) {
|
||||
const { identifier, userId, workspaceId } = params
|
||||
|
||||
const conditions = [or(eq(customTools.id, identifier), eq(customTools.title, identifier))]
|
||||
|
||||
if (workspaceId) {
|
||||
const workspaceTool = await db
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(and(eq(customTools.workspaceId, workspaceId), ...conditions))
|
||||
.limit(1)
|
||||
if (workspaceTool[0]) return workspaceTool[0]
|
||||
}
|
||||
|
||||
const legacyTool = await db
|
||||
.select()
|
||||
.from(customTools)
|
||||
.where(and(isNull(customTools.workspaceId), eq(customTools.userId, userId), ...conditions))
|
||||
.limit(1)
|
||||
return legacyTool[0] || null
|
||||
}
|
||||
|
||||
export async function deleteCustomTool(params: {
|
||||
toolId: string
|
||||
userId: string
|
||||
|
||||
@@ -16,19 +16,29 @@ import {
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// Hoisted mock state - these are available to vi.mock factories
|
||||
const { mockIsHosted, mockEnv, mockGetBYOKKey, mockGetToolAsync, mockRateLimiterFns } = vi.hoisted(
|
||||
() => ({
|
||||
mockIsHosted: { value: false },
|
||||
mockEnv: { NEXT_PUBLIC_APP_URL: 'http://localhost:3000' } as Record<string, string | undefined>,
|
||||
mockGetBYOKKey: vi.fn(),
|
||||
mockGetToolAsync: vi.fn(),
|
||||
mockRateLimiterFns: {
|
||||
acquireKey: vi.fn(),
|
||||
preConsumeCapacity: vi.fn(),
|
||||
consumeCapacity: vi.fn(),
|
||||
},
|
||||
})
|
||||
)
|
||||
const {
|
||||
mockIsHosted,
|
||||
mockEnv,
|
||||
mockGetBYOKKey,
|
||||
mockGetToolAsync,
|
||||
mockRateLimiterFns,
|
||||
mockGetCustomToolById,
|
||||
mockListCustomTools,
|
||||
mockGetCustomToolByIdOrTitle,
|
||||
} = vi.hoisted(() => ({
|
||||
mockIsHosted: { value: false },
|
||||
mockEnv: { NEXT_PUBLIC_APP_URL: 'http://localhost:3000' } as Record<string, string | undefined>,
|
||||
mockGetBYOKKey: vi.fn(),
|
||||
mockGetToolAsync: vi.fn(),
|
||||
mockRateLimiterFns: {
|
||||
acquireKey: vi.fn(),
|
||||
preConsumeCapacity: vi.fn(),
|
||||
consumeCapacity: vi.fn(),
|
||||
},
|
||||
mockGetCustomToolById: vi.fn(),
|
||||
mockListCustomTools: vi.fn(),
|
||||
mockGetCustomToolByIdOrTitle: vi.fn(),
|
||||
}))
|
||||
|
||||
// Mock feature flags
|
||||
vi.mock('@/lib/core/config/feature-flags', () => ({
|
||||
@@ -214,6 +224,12 @@ vi.mock('@/hooks/queries/utils/custom-tool-cache', () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/lib/workflows/custom-tools/operations', () => ({
|
||||
getCustomToolById: mockGetCustomToolById,
|
||||
listCustomTools: mockListCustomTools,
|
||||
getCustomToolByIdOrTitle: mockGetCustomToolByIdOrTitle,
|
||||
}))
|
||||
|
||||
vi.mock('@/tools/utils.server', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/tools/utils.server')>()
|
||||
mockGetToolAsync.mockImplementation(actual.getToolAsync)
|
||||
@@ -307,30 +323,23 @@ describe('Custom Tools', () => {
|
||||
})
|
||||
|
||||
it('resolves custom tools through the async helper', async () => {
|
||||
setupFetchMock({
|
||||
json: {
|
||||
data: [
|
||||
{
|
||||
id: 'remote-tool-123',
|
||||
title: 'Custom Weather Tool',
|
||||
schema: {
|
||||
function: {
|
||||
name: 'weather_tool',
|
||||
description: 'Get weather information',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: { type: 'string', description: 'City name' },
|
||||
},
|
||||
required: ['location'],
|
||||
},
|
||||
},
|
||||
mockGetCustomToolByIdOrTitle.mockResolvedValue({
|
||||
id: 'remote-tool-123',
|
||||
title: 'Custom Weather Tool',
|
||||
schema: {
|
||||
function: {
|
||||
name: 'weather_tool',
|
||||
description: 'Get weather information',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: { type: 'string', description: 'City name' },
|
||||
},
|
||||
required: ['location'],
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
code: '',
|
||||
})
|
||||
|
||||
const customTool = await getToolAsync('custom_remote-tool-123', {
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { generateInternalToken } from '@/lib/auth/internal'
|
||||
import {
|
||||
secureFetchWithPinnedIP,
|
||||
validateUrlWithDNS,
|
||||
} from '@/lib/core/security/input-validation.server'
|
||||
import { getInternalApiBaseUrl } from '@/lib/core/utils/urls'
|
||||
import { getCustomToolByIdOrTitle } from '@/lib/workflows/custom-tools/operations'
|
||||
import { isCustomTool } from '@/executor/constants'
|
||||
import type { CustomToolDefinition } from '@/hooks/queries/custom-tools'
|
||||
import { extractErrorMessage } from '@/tools/error-extractors'
|
||||
@@ -97,67 +96,39 @@ export async function getToolAsync(
|
||||
if (builtInTool) return builtInTool
|
||||
|
||||
if (isCustomTool(toolId)) {
|
||||
return fetchCustomToolFromAPI(toolId, context)
|
||||
return fetchCustomToolFromDB(toolId, context)
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
async function fetchCustomToolFromAPI(
|
||||
async function fetchCustomToolFromDB(
|
||||
customToolId: string,
|
||||
context: GetToolAsyncContext
|
||||
): Promise<ToolConfig | undefined> {
|
||||
const { workflowId, userId, workspaceId } = context
|
||||
const identifier = customToolId.replace('custom_', '')
|
||||
|
||||
if (!userId) {
|
||||
throw new Error(`Cannot fetch custom tool without userId: ${identifier}`)
|
||||
}
|
||||
if (!workspaceId) {
|
||||
throw new Error(`Cannot fetch custom tool without workspaceId: ${identifier}`)
|
||||
}
|
||||
|
||||
try {
|
||||
const baseUrl = getInternalApiBaseUrl()
|
||||
const url = new URL('/api/tools/custom', baseUrl)
|
||||
|
||||
if (workflowId) {
|
||||
url.searchParams.append('workflowId', workflowId)
|
||||
}
|
||||
if (userId) {
|
||||
url.searchParams.append('userId', userId)
|
||||
}
|
||||
if (workspaceId) {
|
||||
url.searchParams.append('workspaceId', workspaceId)
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {}
|
||||
|
||||
try {
|
||||
const internalToken = await generateInternalToken(userId)
|
||||
headers.Authorization = `Bearer ${internalToken}`
|
||||
} catch (error) {
|
||||
logger.warn('Failed to generate internal token for custom tools fetch', { error })
|
||||
}
|
||||
|
||||
const response = await fetch(url.toString(), { headers })
|
||||
|
||||
if (!response.ok) {
|
||||
await response.text().catch(() => {})
|
||||
logger.error(`Failed to fetch custom tools: ${response.statusText}`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
const result = await response.json()
|
||||
|
||||
if (!result.data || !Array.isArray(result.data)) {
|
||||
logger.error(`Invalid response when fetching custom tools: ${JSON.stringify(result)}`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
const customTool = result.data.find(
|
||||
(tool: CustomToolDefinition) => tool.id === identifier || tool.title === identifier
|
||||
) as CustomToolDefinition | undefined
|
||||
const customTool = await getCustomToolByIdOrTitle({
|
||||
identifier,
|
||||
userId,
|
||||
workspaceId,
|
||||
})
|
||||
|
||||
if (!customTool) {
|
||||
logger.error(`Custom tool not found: ${identifier}`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
const toolConfig = createToolConfig(customTool, customToolId)
|
||||
const toolConfig = createToolConfig(customTool as unknown as CustomToolDefinition, customToolId)
|
||||
|
||||
return {
|
||||
...toolConfig,
|
||||
@@ -168,7 +139,7 @@ async function fetchCustomToolFromAPI(
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error fetching custom tool ${identifier} from API:`, error)
|
||||
logger.error(`Error fetching custom tool ${identifier} from DB:`, error)
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user