fix(tools) Directly query db for custom tool id (#3875)

* Directly query db for custom tool id

* Switch back to inline imports

* Fix lint

* Fix test

* Fix greptile comments

* Fix lint

* Make userId and workspaceId required

* Add back nullable userId and workspaceId fields

---------

Co-authored-by: Theodore Li <theo@sim.ai>
This commit is contained in:
Theodore Li
2026-04-02 19:13:37 -07:00
committed by GitHub
parent b0c0ee29a8
commit 6866da590c
5 changed files with 121 additions and 156 deletions

View File

@@ -21,6 +21,7 @@ vi.mock('@/lib/core/config/feature-flags', () => ({
isEmailVerificationEnabled: false,
isBillingEnabled: false,
isOrganizationsEnabled: false,
isAccessControlEnabled: false,
}))
vi.mock('@/providers/utils', () => ({
@@ -110,6 +111,12 @@ vi.mock('@sim/db/schema', () => ({
},
}))
const mockGetCustomToolById = vi.fn()
vi.mock('@/lib/workflows/custom-tools/operations', () => ({
getCustomToolById: (...args: unknown[]) => mockGetCustomToolById(...args),
}))
setupGlobalFetchMock()
const mockGetAllBlocks = getAllBlocks as Mock
@@ -1957,49 +1964,22 @@ describe('AgentBlockHandler', () => {
const staleInlineCode = 'return { title, content };'
const dbCode = 'return { title, content, format };'
function mockFetchForCustomTool(toolId: string) {
mockFetch.mockImplementation((url: string) => {
if (typeof url === 'string' && url.includes('/api/tools/custom')) {
function mockDBForCustomTool(toolId: string) {
mockGetCustomToolById.mockImplementation(({ toolId: id }: { toolId: string }) => {
if (id === toolId) {
return Promise.resolve({
ok: true,
headers: { get: () => null },
json: () =>
Promise.resolve({
data: [
{
id: toolId,
title: 'formatReport',
schema: dbSchema,
code: dbCode,
},
],
}),
id: toolId,
title: 'formatReport',
schema: dbSchema,
code: dbCode,
})
}
return Promise.resolve({
ok: true,
headers: { get: () => null },
json: () => Promise.resolve({}),
})
return Promise.resolve(null)
})
}
function mockFetchFailure() {
mockFetch.mockImplementation((url: string) => {
if (typeof url === 'string' && url.includes('/api/tools/custom')) {
return Promise.resolve({
ok: false,
status: 500,
headers: { get: () => null },
json: () => Promise.resolve({}),
})
}
return Promise.resolve({
ok: true,
headers: { get: () => null },
json: () => Promise.resolve({}),
})
})
function mockDBFailure() {
mockGetCustomToolById.mockRejectedValue(new Error('DB connection failed'))
}
beforeEach(() => {
@@ -2008,11 +1988,13 @@ describe('AgentBlockHandler', () => {
writable: true,
configurable: true,
})
mockGetCustomToolById.mockReset()
mockContext.userId = 'test-user'
})
it('should always fetch latest schema from DB when customToolId is present', async () => {
const toolId = 'custom-tool-123'
mockFetchForCustomTool(toolId)
mockDBForCustomTool(toolId)
const inputs = {
model: 'gpt-4o',
@@ -2046,7 +2028,7 @@ describe('AgentBlockHandler', () => {
it('should fetch from DB when customToolId has no inline schema', async () => {
const toolId = 'custom-tool-123'
mockFetchForCustomTool(toolId)
mockDBForCustomTool(toolId)
const inputs = {
model: 'gpt-4o',
@@ -2075,7 +2057,7 @@ describe('AgentBlockHandler', () => {
})
it('should fall back to inline schema when DB fetch fails and inline exists', async () => {
mockFetchFailure()
mockDBFailure()
const inputs = {
model: 'gpt-4o',
@@ -2107,7 +2089,7 @@ describe('AgentBlockHandler', () => {
})
it('should return null when DB fetch fails and no inline schema exists', async () => {
mockFetchFailure()
mockDBFailure()
const inputs = {
model: 'gpt-4o',
@@ -2135,7 +2117,7 @@ describe('AgentBlockHandler', () => {
it('should use DB schema when customToolId resolves', async () => {
const toolId = 'custom-tool-123'
mockFetchForCustomTool(toolId)
mockDBForCustomTool(toolId)
const inputs = {
model: 'gpt-4o',
@@ -2185,10 +2167,7 @@ describe('AgentBlockHandler', () => {
await handler.execute(mockContext, mockBlock, inputs)
const customToolFetches = mockFetch.mock.calls.filter(
(call: any[]) => typeof call[0] === 'string' && call[0].includes('/api/tools/custom')
)
expect(customToolFetches.length).toBe(0)
expect(mockGetCustomToolById).not.toHaveBeenCalled()
expect(mockExecuteProviderRequest).toHaveBeenCalled()
const providerCall = mockExecuteProviderRequest.mock.calls[0]

View File

@@ -3,6 +3,7 @@ import { mcpServers } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, inArray, isNull } from 'drizzle-orm'
import { createMcpToolId } from '@/lib/mcp/utils'
import { getCustomToolById } from '@/lib/workflows/custom-tools/operations'
import { getAllBlocks } from '@/blocks'
import type { BlockOutput } from '@/blocks/types'
import {
@@ -277,39 +278,18 @@ export class AgentBlockHandler implements BlockHandler {
ctx: ExecutionContext,
customToolId: string
): Promise<{ schema: any; title: string } | null> {
if (!ctx.userId) {
logger.error('Cannot fetch custom tool without userId:', { customToolId })
return null
}
try {
const headers = await buildAuthHeaders(ctx.userId)
const params: Record<string, string> = {}
if (ctx.workspaceId) {
params.workspaceId = ctx.workspaceId
}
if (ctx.workflowId) {
params.workflowId = ctx.workflowId
}
if (ctx.userId) {
params.userId = ctx.userId
}
const url = buildAPIUrl('/api/tools/custom', params)
const response = await fetch(url.toString(), {
method: 'GET',
headers,
const tool = await getCustomToolById({
toolId: customToolId,
userId: ctx.userId,
workspaceId: ctx.workspaceId,
})
if (!response.ok) {
await response.text().catch(() => {})
logger.error(`Failed to fetch custom tools: ${response.status}`)
return null
}
const data = await response.json()
if (!data.data || !Array.isArray(data.data)) {
logger.error('Invalid custom tools API response')
return null
}
const tool = data.data.find((t: any) => t.id === customToolId)
if (!tool) {
logger.warn(`Custom tool not found by ID: ${customToolId}`)
return null

View File

@@ -158,6 +158,32 @@ export async function getCustomToolById(params: {
return legacyTool[0] || null
}
export async function getCustomToolByIdOrTitle(params: {
identifier: string
userId: string
workspaceId?: string
}) {
const { identifier, userId, workspaceId } = params
const conditions = [or(eq(customTools.id, identifier), eq(customTools.title, identifier))]
if (workspaceId) {
const workspaceTool = await db
.select()
.from(customTools)
.where(and(eq(customTools.workspaceId, workspaceId), ...conditions))
.limit(1)
if (workspaceTool[0]) return workspaceTool[0]
}
const legacyTool = await db
.select()
.from(customTools)
.where(and(isNull(customTools.workspaceId), eq(customTools.userId, userId), ...conditions))
.limit(1)
return legacyTool[0] || null
}
export async function deleteCustomTool(params: {
toolId: string
userId: string

View File

@@ -16,19 +16,29 @@ import {
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
// Hoisted mock state - these are available to vi.mock factories
const { mockIsHosted, mockEnv, mockGetBYOKKey, mockGetToolAsync, mockRateLimiterFns } = vi.hoisted(
() => ({
mockIsHosted: { value: false },
mockEnv: { NEXT_PUBLIC_APP_URL: 'http://localhost:3000' } as Record<string, string | undefined>,
mockGetBYOKKey: vi.fn(),
mockGetToolAsync: vi.fn(),
mockRateLimiterFns: {
acquireKey: vi.fn(),
preConsumeCapacity: vi.fn(),
consumeCapacity: vi.fn(),
},
})
)
const {
mockIsHosted,
mockEnv,
mockGetBYOKKey,
mockGetToolAsync,
mockRateLimiterFns,
mockGetCustomToolById,
mockListCustomTools,
mockGetCustomToolByIdOrTitle,
} = vi.hoisted(() => ({
mockIsHosted: { value: false },
mockEnv: { NEXT_PUBLIC_APP_URL: 'http://localhost:3000' } as Record<string, string | undefined>,
mockGetBYOKKey: vi.fn(),
mockGetToolAsync: vi.fn(),
mockRateLimiterFns: {
acquireKey: vi.fn(),
preConsumeCapacity: vi.fn(),
consumeCapacity: vi.fn(),
},
mockGetCustomToolById: vi.fn(),
mockListCustomTools: vi.fn(),
mockGetCustomToolByIdOrTitle: vi.fn(),
}))
// Mock feature flags
vi.mock('@/lib/core/config/feature-flags', () => ({
@@ -214,6 +224,12 @@ vi.mock('@/hooks/queries/utils/custom-tool-cache', () => {
}
})
vi.mock('@/lib/workflows/custom-tools/operations', () => ({
getCustomToolById: mockGetCustomToolById,
listCustomTools: mockListCustomTools,
getCustomToolByIdOrTitle: mockGetCustomToolByIdOrTitle,
}))
vi.mock('@/tools/utils.server', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/tools/utils.server')>()
mockGetToolAsync.mockImplementation(actual.getToolAsync)
@@ -307,30 +323,23 @@ describe('Custom Tools', () => {
})
it('resolves custom tools through the async helper', async () => {
setupFetchMock({
json: {
data: [
{
id: 'remote-tool-123',
title: 'Custom Weather Tool',
schema: {
function: {
name: 'weather_tool',
description: 'Get weather information',
parameters: {
type: 'object',
properties: {
location: { type: 'string', description: 'City name' },
},
required: ['location'],
},
},
mockGetCustomToolByIdOrTitle.mockResolvedValue({
id: 'remote-tool-123',
title: 'Custom Weather Tool',
schema: {
function: {
name: 'weather_tool',
description: 'Get weather information',
parameters: {
type: 'object',
properties: {
location: { type: 'string', description: 'City name' },
},
required: ['location'],
},
],
},
},
status: 200,
headers: { 'content-type': 'application/json' },
code: '',
})
const customTool = await getToolAsync('custom_remote-tool-123', {

View File

@@ -1,10 +1,9 @@
import { createLogger } from '@sim/logger'
import { generateInternalToken } from '@/lib/auth/internal'
import {
secureFetchWithPinnedIP,
validateUrlWithDNS,
} from '@/lib/core/security/input-validation.server'
import { getInternalApiBaseUrl } from '@/lib/core/utils/urls'
import { getCustomToolByIdOrTitle } from '@/lib/workflows/custom-tools/operations'
import { isCustomTool } from '@/executor/constants'
import type { CustomToolDefinition } from '@/hooks/queries/custom-tools'
import { extractErrorMessage } from '@/tools/error-extractors'
@@ -97,67 +96,39 @@ export async function getToolAsync(
if (builtInTool) return builtInTool
if (isCustomTool(toolId)) {
return fetchCustomToolFromAPI(toolId, context)
return fetchCustomToolFromDB(toolId, context)
}
return undefined
}
async function fetchCustomToolFromAPI(
async function fetchCustomToolFromDB(
customToolId: string,
context: GetToolAsyncContext
): Promise<ToolConfig | undefined> {
const { workflowId, userId, workspaceId } = context
const identifier = customToolId.replace('custom_', '')
if (!userId) {
throw new Error(`Cannot fetch custom tool without userId: ${identifier}`)
}
if (!workspaceId) {
throw new Error(`Cannot fetch custom tool without workspaceId: ${identifier}`)
}
try {
const baseUrl = getInternalApiBaseUrl()
const url = new URL('/api/tools/custom', baseUrl)
if (workflowId) {
url.searchParams.append('workflowId', workflowId)
}
if (userId) {
url.searchParams.append('userId', userId)
}
if (workspaceId) {
url.searchParams.append('workspaceId', workspaceId)
}
const headers: Record<string, string> = {}
try {
const internalToken = await generateInternalToken(userId)
headers.Authorization = `Bearer ${internalToken}`
} catch (error) {
logger.warn('Failed to generate internal token for custom tools fetch', { error })
}
const response = await fetch(url.toString(), { headers })
if (!response.ok) {
await response.text().catch(() => {})
logger.error(`Failed to fetch custom tools: ${response.statusText}`)
return undefined
}
const result = await response.json()
if (!result.data || !Array.isArray(result.data)) {
logger.error(`Invalid response when fetching custom tools: ${JSON.stringify(result)}`)
return undefined
}
const customTool = result.data.find(
(tool: CustomToolDefinition) => tool.id === identifier || tool.title === identifier
) as CustomToolDefinition | undefined
const customTool = await getCustomToolByIdOrTitle({
identifier,
userId,
workspaceId,
})
if (!customTool) {
logger.error(`Custom tool not found: ${identifier}`)
return undefined
}
const toolConfig = createToolConfig(customTool, customToolId)
const toolConfig = createToolConfig(customTool as unknown as CustomToolDefinition, customToolId)
return {
...toolConfig,
@@ -168,7 +139,7 @@ async function fetchCustomToolFromAPI(
},
}
} catch (error) {
logger.error(`Error fetching custom tool ${identifier} from API:`, error)
logger.error(`Error fetching custom tool ${identifier} from DB:`, error)
return undefined
}
}