mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
fix(router): update router to handle azure creds the same way the agent block does (#2572)
* fix(router): update router to handle azure creds the same way the agent block does * cleanup
This commit is contained in:
@@ -187,12 +187,16 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
|
||||
type: 'combobox',
|
||||
placeholder: 'Type or select a model...',
|
||||
required: true,
|
||||
defaultValue: 'claude-sonnet-4-5',
|
||||
options: () => {
|
||||
const providersState = useProvidersStore.getState()
|
||||
const baseModels = providersState.providers.base.models
|
||||
const ollamaModels = providersState.providers.ollama.models
|
||||
const vllmModels = providersState.providers.vllm.models
|
||||
const openrouterModels = providersState.providers.openrouter.models
|
||||
const allModels = Array.from(new Set([...baseModels, ...ollamaModels, ...openrouterModels]))
|
||||
const allModels = Array.from(
|
||||
new Set([...baseModels, ...ollamaModels, ...vllmModels, ...openrouterModels])
|
||||
)
|
||||
|
||||
return allModels.map((model) => {
|
||||
const icon = getProviderIcon(model)
|
||||
|
||||
@@ -135,12 +135,16 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
|
||||
type: 'combobox',
|
||||
placeholder: 'Type or select a model...',
|
||||
required: true,
|
||||
defaultValue: 'claude-sonnet-4-5',
|
||||
options: () => {
|
||||
const providersState = useProvidersStore.getState()
|
||||
const baseModels = providersState.providers.base.models
|
||||
const ollamaModels = providersState.providers.ollama.models
|
||||
const vllmModels = providersState.providers.vllm.models
|
||||
const openrouterModels = providersState.providers.openrouter.models
|
||||
const allModels = Array.from(new Set([...baseModels, ...ollamaModels, ...openrouterModels]))
|
||||
const allModels = Array.from(
|
||||
new Set([...baseModels, ...ollamaModels, ...vllmModels, ...openrouterModels])
|
||||
)
|
||||
|
||||
return allModels.map((model) => {
|
||||
const icon = getProviderIcon(model)
|
||||
|
||||
@@ -178,13 +178,13 @@ export const MEMORY = {
|
||||
} as const
|
||||
|
||||
export const ROUTER = {
|
||||
DEFAULT_MODEL: 'gpt-4o',
|
||||
DEFAULT_MODEL: 'claude-sonnet-4-5',
|
||||
DEFAULT_TEMPERATURE: 0,
|
||||
INFERENCE_TEMPERATURE: 0.1,
|
||||
} as const
|
||||
|
||||
export const EVALUATOR = {
|
||||
DEFAULT_MODEL: 'gpt-4o',
|
||||
DEFAULT_MODEL: 'claude-sonnet-4-5',
|
||||
DEFAULT_TEMPERATURE: 0.1,
|
||||
RESPONSE_SCHEMA_NAME: 'evaluation_response',
|
||||
JSON_INDENT: 2,
|
||||
|
||||
@@ -82,6 +82,7 @@ describe('EvaluatorBlockHandler', () => {
|
||||
{ name: 'score2', description: 'Second score', range: { min: 0, max: 10 } },
|
||||
],
|
||||
model: 'gpt-4o',
|
||||
apiKey: 'test-api-key',
|
||||
temperature: 0.1,
|
||||
}
|
||||
|
||||
@@ -97,7 +98,6 @@ describe('EvaluatorBlockHandler', () => {
|
||||
})
|
||||
)
|
||||
|
||||
// Verify the request body contains the expected data
|
||||
const fetchCallArgs = mockFetch.mock.calls[0]
|
||||
const requestBody = JSON.parse(fetchCallArgs[1].body)
|
||||
expect(requestBody).toMatchObject({
|
||||
@@ -137,6 +137,7 @@ describe('EvaluatorBlockHandler', () => {
|
||||
const inputs = {
|
||||
content: JSON.stringify(contentObj),
|
||||
metrics: [{ name: 'clarity', description: 'Clarity score', range: { min: 1, max: 5 } }],
|
||||
apiKey: 'test-api-key',
|
||||
}
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -169,6 +170,7 @@ describe('EvaluatorBlockHandler', () => {
|
||||
metrics: [
|
||||
{ name: 'completeness', description: 'Data completeness', range: { min: 0, max: 1 } },
|
||||
],
|
||||
apiKey: 'test-api-key',
|
||||
}
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -198,6 +200,7 @@ describe('EvaluatorBlockHandler', () => {
|
||||
const inputs = {
|
||||
content: 'Test content',
|
||||
metrics: [{ name: 'quality', description: 'Quality score', range: { min: 1, max: 10 } }],
|
||||
apiKey: 'test-api-key',
|
||||
}
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -223,6 +226,7 @@ describe('EvaluatorBlockHandler', () => {
|
||||
const inputs = {
|
||||
content: 'Test content',
|
||||
metrics: [{ name: 'score', description: 'Score', range: { min: 0, max: 5 } }],
|
||||
apiKey: 'test-api-key',
|
||||
}
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -251,6 +255,7 @@ describe('EvaluatorBlockHandler', () => {
|
||||
{ name: 'accuracy', description: 'Acc', range: { min: 0, max: 1 } },
|
||||
{ name: 'fluency', description: 'Flu', range: { min: 0, max: 1 } },
|
||||
],
|
||||
apiKey: 'test-api-key',
|
||||
}
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -276,6 +281,7 @@ describe('EvaluatorBlockHandler', () => {
|
||||
const inputs = {
|
||||
content: 'Test',
|
||||
metrics: [{ name: 'CamelCaseScore', description: 'Desc', range: { min: 0, max: 10 } }],
|
||||
apiKey: 'test-api-key',
|
||||
}
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -304,6 +310,7 @@ describe('EvaluatorBlockHandler', () => {
|
||||
{ name: 'presentScore', description: 'Desc1', range: { min: 0, max: 5 } },
|
||||
{ name: 'missingScore', description: 'Desc2', range: { min: 0, max: 5 } },
|
||||
],
|
||||
apiKey: 'test-api-key',
|
||||
}
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -327,7 +334,7 @@ describe('EvaluatorBlockHandler', () => {
|
||||
})
|
||||
|
||||
it('should handle server error responses', async () => {
|
||||
const inputs = { content: 'Test error handling.' }
|
||||
const inputs = { content: 'Test error handling.', apiKey: 'test-api-key' }
|
||||
|
||||
// Override fetch mock to return an error
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -340,4 +347,139 @@ describe('EvaluatorBlockHandler', () => {
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow('Server error')
|
||||
})
|
||||
|
||||
it('should handle Azure OpenAI models with endpoint and API version', async () => {
|
||||
const inputs = {
|
||||
content: 'Test content to evaluate',
|
||||
metrics: [{ name: 'quality', description: 'Quality score', range: { min: 1, max: 10 } }],
|
||||
model: 'gpt-4o',
|
||||
apiKey: 'test-azure-key',
|
||||
azureEndpoint: 'https://test.openai.azure.com',
|
||||
azureApiVersion: '2024-07-01-preview',
|
||||
}
|
||||
|
||||
mockGetProviderFromModel.mockReturnValue('azure-openai')
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
content: JSON.stringify({ quality: 8 }),
|
||||
model: 'gpt-4o',
|
||||
tokens: {},
|
||||
cost: 0,
|
||||
timing: {},
|
||||
}),
|
||||
})
|
||||
})
|
||||
|
||||
await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
const fetchCallArgs = mockFetch.mock.calls[0]
|
||||
const requestBody = JSON.parse(fetchCallArgs[1].body)
|
||||
|
||||
expect(requestBody).toMatchObject({
|
||||
provider: 'azure-openai',
|
||||
model: 'gpt-4o',
|
||||
apiKey: 'test-azure-key',
|
||||
azureEndpoint: 'https://test.openai.azure.com',
|
||||
azureApiVersion: '2024-07-01-preview',
|
||||
})
|
||||
})
|
||||
|
||||
it('should throw error when API key is missing for non-hosted models', async () => {
|
||||
const inputs = {
|
||||
content: 'Test content',
|
||||
metrics: [{ name: 'score', description: 'Score', range: { min: 0, max: 10 } }],
|
||||
model: 'gpt-4o',
|
||||
// No apiKey provided
|
||||
}
|
||||
|
||||
mockGetProviderFromModel.mockReturnValue('openai')
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
|
||||
/API key is required/
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle Vertex AI models with OAuth credential', async () => {
|
||||
const inputs = {
|
||||
content: 'Test content to evaluate',
|
||||
metrics: [{ name: 'quality', description: 'Quality score', range: { min: 1, max: 10 } }],
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
vertexCredential: 'test-vertex-credential-id',
|
||||
vertexProject: 'test-gcp-project',
|
||||
vertexLocation: 'us-central1',
|
||||
}
|
||||
|
||||
mockGetProviderFromModel.mockReturnValue('vertex')
|
||||
|
||||
// Mock the database query for Vertex credential
|
||||
const mockDb = await import('@sim/db')
|
||||
const mockAccount = {
|
||||
id: 'test-vertex-credential-id',
|
||||
accessToken: 'mock-access-token',
|
||||
refreshToken: 'mock-refresh-token',
|
||||
expiresAt: new Date(Date.now() + 3600000), // 1 hour from now
|
||||
}
|
||||
vi.spyOn(mockDb.db.query.account, 'findFirst').mockResolvedValue(mockAccount as any)
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
content: JSON.stringify({ quality: 9 }),
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
tokens: {},
|
||||
cost: 0,
|
||||
timing: {},
|
||||
}),
|
||||
})
|
||||
})
|
||||
|
||||
await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
const fetchCallArgs = mockFetch.mock.calls[0]
|
||||
const requestBody = JSON.parse(fetchCallArgs[1].body)
|
||||
|
||||
expect(requestBody).toMatchObject({
|
||||
provider: 'vertex',
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
vertexProject: 'test-gcp-project',
|
||||
vertexLocation: 'us-central1',
|
||||
})
|
||||
expect(requestBody.apiKey).toBe('mock-access-token')
|
||||
})
|
||||
|
||||
it('should use default model when not provided', async () => {
|
||||
const inputs = {
|
||||
content: 'Test content',
|
||||
metrics: [{ name: 'score', description: 'Score', range: { min: 0, max: 10 } }],
|
||||
apiKey: 'test-api-key',
|
||||
// No model provided - should use default
|
||||
}
|
||||
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
content: JSON.stringify({ score: 7 }),
|
||||
model: 'claude-sonnet-4-5',
|
||||
tokens: {},
|
||||
cost: 0,
|
||||
timing: {},
|
||||
}),
|
||||
})
|
||||
})
|
||||
|
||||
await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
const fetchCallArgs = mockFetch.mock.calls[0]
|
||||
const requestBody = JSON.parse(fetchCallArgs[1].body)
|
||||
|
||||
expect(requestBody.model).toBe('claude-sonnet-4-5')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -8,7 +8,7 @@ import { BlockType, DEFAULTS, EVALUATOR, HTTP } from '@/executor/constants'
|
||||
import type { BlockHandler, ExecutionContext } from '@/executor/types'
|
||||
import { buildAPIUrl, extractAPIErrorMessage } from '@/executor/utils/http'
|
||||
import { isJSONString, parseJSON, stringifyJSON } from '@/executor/utils/json'
|
||||
import { calculateCost, getProviderFromModel } from '@/providers/utils'
|
||||
import { calculateCost, getApiKey, getProviderFromModel } from '@/providers/utils'
|
||||
import type { SerializedBlock } from '@/serializer/types'
|
||||
|
||||
const logger = createLogger('EvaluatorBlockHandler')
|
||||
@@ -35,9 +35,11 @@ export class EvaluatorBlockHandler implements BlockHandler {
|
||||
}
|
||||
const providerId = getProviderFromModel(evaluatorConfig.model)
|
||||
|
||||
let finalApiKey = evaluatorConfig.apiKey
|
||||
let finalApiKey: string
|
||||
if (providerId === 'vertex' && evaluatorConfig.vertexCredential) {
|
||||
finalApiKey = await this.resolveVertexCredential(evaluatorConfig.vertexCredential)
|
||||
} else {
|
||||
finalApiKey = this.getApiKey(providerId, evaluatorConfig.model, evaluatorConfig.apiKey)
|
||||
}
|
||||
|
||||
const processedContent = this.processContent(inputs.content)
|
||||
@@ -122,6 +124,11 @@ export class EvaluatorBlockHandler implements BlockHandler {
|
||||
providerRequest.vertexLocation = evaluatorConfig.vertexLocation
|
||||
}
|
||||
|
||||
if (providerId === 'azure-openai') {
|
||||
providerRequest.azureEndpoint = inputs.azureEndpoint
|
||||
providerRequest.azureApiVersion = inputs.azureApiVersion
|
||||
}
|
||||
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -268,6 +275,20 @@ export class EvaluatorBlockHandler implements BlockHandler {
|
||||
return DEFAULTS.EXECUTION_TIME
|
||||
}
|
||||
|
||||
private getApiKey(providerId: string, model: string, inputApiKey: string): string {
|
||||
try {
|
||||
return getApiKey(providerId, model, inputApiKey)
|
||||
} catch (error) {
|
||||
logger.error('Failed to get API key:', {
|
||||
provider: providerId,
|
||||
model,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
hasProvidedApiKey: !!inputApiKey,
|
||||
})
|
||||
throw new Error(error instanceof Error ? error.message : 'API key error')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a Vertex AI OAuth credential to an access token
|
||||
*/
|
||||
|
||||
@@ -105,6 +105,7 @@ describe('RouterBlockHandler', () => {
|
||||
const inputs = {
|
||||
prompt: 'Choose the best option.',
|
||||
model: 'gpt-4o',
|
||||
apiKey: 'test-api-key',
|
||||
temperature: 0.1,
|
||||
}
|
||||
|
||||
@@ -187,7 +188,7 @@ describe('RouterBlockHandler', () => {
|
||||
})
|
||||
|
||||
it('should throw error if LLM response is not a valid target block ID', async () => {
|
||||
const inputs = { prompt: 'Test' }
|
||||
const inputs = { prompt: 'Test', apiKey: 'test-api-key' }
|
||||
|
||||
// Override fetch mock to return an invalid block ID
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -210,22 +211,22 @@ describe('RouterBlockHandler', () => {
|
||||
})
|
||||
|
||||
it('should use default model and temperature if not provided', async () => {
|
||||
const inputs = { prompt: 'Choose.' }
|
||||
const inputs = { prompt: 'Choose.', apiKey: 'test-api-key' }
|
||||
|
||||
await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
expect(mockGetProviderFromModel).toHaveBeenCalledWith('gpt-4o')
|
||||
expect(mockGetProviderFromModel).toHaveBeenCalledWith('claude-sonnet-4-5')
|
||||
|
||||
const fetchCallArgs = mockFetch.mock.calls[0]
|
||||
const requestBody = JSON.parse(fetchCallArgs[1].body)
|
||||
expect(requestBody).toMatchObject({
|
||||
model: 'gpt-4o',
|
||||
model: 'claude-sonnet-4-5',
|
||||
temperature: 0.1,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle server error responses', async () => {
|
||||
const inputs = { prompt: 'Test error handling.' }
|
||||
const inputs = { prompt: 'Test error handling.', apiKey: 'test-api-key' }
|
||||
|
||||
// Override fetch mock to return an error
|
||||
mockFetch.mockImplementationOnce(() => {
|
||||
@@ -238,4 +239,78 @@ describe('RouterBlockHandler', () => {
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow('Server error')
|
||||
})
|
||||
|
||||
it('should handle Azure OpenAI models with endpoint and API version', async () => {
|
||||
const inputs = {
|
||||
prompt: 'Choose the best option.',
|
||||
model: 'gpt-4o',
|
||||
apiKey: 'test-azure-key',
|
||||
azureEndpoint: 'https://test.openai.azure.com',
|
||||
azureApiVersion: '2024-07-01-preview',
|
||||
}
|
||||
|
||||
mockGetProviderFromModel.mockReturnValue('azure-openai')
|
||||
|
||||
await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
const fetchCallArgs = mockFetch.mock.calls[0]
|
||||
const requestBody = JSON.parse(fetchCallArgs[1].body)
|
||||
|
||||
expect(requestBody).toMatchObject({
|
||||
provider: 'azure-openai',
|
||||
model: 'gpt-4o',
|
||||
apiKey: 'test-azure-key',
|
||||
azureEndpoint: 'https://test.openai.azure.com',
|
||||
azureApiVersion: '2024-07-01-preview',
|
||||
})
|
||||
})
|
||||
|
||||
it('should throw error when API key is missing for non-hosted models', async () => {
|
||||
const inputs = {
|
||||
prompt: 'Test without API key',
|
||||
model: 'gpt-4o',
|
||||
// No apiKey provided
|
||||
}
|
||||
|
||||
mockGetProviderFromModel.mockReturnValue('openai')
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
|
||||
/API key is required/
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle Vertex AI models with OAuth credential', async () => {
|
||||
const inputs = {
|
||||
prompt: 'Choose the best option.',
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
vertexCredential: 'test-vertex-credential-id',
|
||||
vertexProject: 'test-gcp-project',
|
||||
vertexLocation: 'us-central1',
|
||||
}
|
||||
|
||||
mockGetProviderFromModel.mockReturnValue('vertex')
|
||||
|
||||
// Mock the database query for Vertex credential
|
||||
const mockDb = await import('@sim/db')
|
||||
const mockAccount = {
|
||||
id: 'test-vertex-credential-id',
|
||||
accessToken: 'mock-access-token',
|
||||
refreshToken: 'mock-refresh-token',
|
||||
expiresAt: new Date(Date.now() + 3600000), // 1 hour from now
|
||||
}
|
||||
vi.spyOn(mockDb.db.query.account, 'findFirst').mockResolvedValue(mockAccount as any)
|
||||
|
||||
await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
const fetchCallArgs = mockFetch.mock.calls[0]
|
||||
const requestBody = JSON.parse(fetchCallArgs[1].body)
|
||||
|
||||
expect(requestBody).toMatchObject({
|
||||
provider: 'vertex',
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
vertexProject: 'test-gcp-project',
|
||||
vertexLocation: 'us-central1',
|
||||
})
|
||||
expect(requestBody.apiKey).toBe('mock-access-token')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -8,7 +8,7 @@ import { generateRouterPrompt } from '@/blocks/blocks/router'
|
||||
import type { BlockOutput } from '@/blocks/types'
|
||||
import { BlockType, DEFAULTS, HTTP, isAgentBlockType, ROUTER } from '@/executor/constants'
|
||||
import type { BlockHandler, ExecutionContext } from '@/executor/types'
|
||||
import { calculateCost, getProviderFromModel } from '@/providers/utils'
|
||||
import { calculateCost, getApiKey, getProviderFromModel } from '@/providers/utils'
|
||||
import type { SerializedBlock } from '@/serializer/types'
|
||||
|
||||
const logger = createLogger('RouterBlockHandler')
|
||||
@@ -47,9 +47,11 @@ export class RouterBlockHandler implements BlockHandler {
|
||||
const messages = [{ role: 'user', content: routerConfig.prompt }]
|
||||
const systemPrompt = generateRouterPrompt(routerConfig.prompt, targetBlocks)
|
||||
|
||||
let finalApiKey = routerConfig.apiKey
|
||||
let finalApiKey: string
|
||||
if (providerId === 'vertex' && routerConfig.vertexCredential) {
|
||||
finalApiKey = await this.resolveVertexCredential(routerConfig.vertexCredential)
|
||||
} else {
|
||||
finalApiKey = this.getApiKey(providerId, routerConfig.model, routerConfig.apiKey)
|
||||
}
|
||||
|
||||
const providerRequest: Record<string, any> = {
|
||||
@@ -67,6 +69,11 @@ export class RouterBlockHandler implements BlockHandler {
|
||||
providerRequest.vertexLocation = routerConfig.vertexLocation
|
||||
}
|
||||
|
||||
if (providerId === 'azure-openai') {
|
||||
providerRequest.azureEndpoint = inputs.azureEndpoint
|
||||
providerRequest.azureApiVersion = inputs.azureApiVersion
|
||||
}
|
||||
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -171,6 +178,20 @@ export class RouterBlockHandler implements BlockHandler {
|
||||
})
|
||||
}
|
||||
|
||||
private getApiKey(providerId: string, model: string, inputApiKey: string): string {
|
||||
try {
|
||||
return getApiKey(providerId, model, inputApiKey)
|
||||
} catch (error) {
|
||||
logger.error('Failed to get API key:', {
|
||||
provider: providerId,
|
||||
model,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
hasProvidedApiKey: !!inputApiKey,
|
||||
})
|
||||
throw new Error(error instanceof Error ? error.message : 'API key error')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a Vertex AI OAuth credential to an access token
|
||||
*/
|
||||
|
||||
294
apps/sim/executor/handlers/wait/wait-handler.test.ts
Normal file
294
apps/sim/executor/handlers/wait/wait-handler.test.ts
Normal file
@@ -0,0 +1,294 @@
|
||||
import '@/executor/__test-utils__/mock-dependencies'
|
||||
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { BlockType } from '@/executor/constants'
|
||||
import { WaitBlockHandler } from '@/executor/handlers/wait/wait-handler'
|
||||
import type { ExecutionContext } from '@/executor/types'
|
||||
import type { SerializedBlock } from '@/serializer/types'
|
||||
|
||||
describe('WaitBlockHandler', () => {
|
||||
let handler: WaitBlockHandler
|
||||
let mockBlock: SerializedBlock
|
||||
let mockContext: ExecutionContext
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
|
||||
handler = new WaitBlockHandler()
|
||||
|
||||
mockBlock = {
|
||||
id: 'wait-block-1',
|
||||
metadata: { id: BlockType.WAIT, name: 'Test Wait' },
|
||||
position: { x: 50, y: 50 },
|
||||
config: { tool: BlockType.WAIT, params: {} },
|
||||
inputs: { timeValue: 'string', timeUnit: 'string' },
|
||||
outputs: {},
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
mockContext = {
|
||||
workflowId: 'test-workflow-id',
|
||||
blockStates: new Map(),
|
||||
blockLogs: [],
|
||||
metadata: { duration: 0 },
|
||||
environmentVariables: {},
|
||||
decisions: { router: new Map(), condition: new Map() },
|
||||
loopExecutions: new Map(),
|
||||
completedLoops: new Set(),
|
||||
executedBlocks: new Set(),
|
||||
activeExecutionPath: new Set(),
|
||||
}
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should handle wait blocks', () => {
|
||||
expect(handler.canHandle(mockBlock)).toBe(true)
|
||||
const nonWaitBlock: SerializedBlock = { ...mockBlock, metadata: { id: 'other' } }
|
||||
expect(handler.canHandle(nonWaitBlock)).toBe(false)
|
||||
})
|
||||
|
||||
it('should wait for specified seconds', async () => {
|
||||
const inputs = {
|
||||
timeValue: '5',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
const executePromise = handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(5000)
|
||||
|
||||
const result = await executePromise
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 5000,
|
||||
status: 'completed',
|
||||
})
|
||||
})
|
||||
|
||||
it('should wait for specified minutes', async () => {
|
||||
const inputs = {
|
||||
timeValue: '2',
|
||||
timeUnit: 'minutes',
|
||||
}
|
||||
|
||||
const executePromise = handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(120000)
|
||||
|
||||
const result = await executePromise
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 120000,
|
||||
status: 'completed',
|
||||
})
|
||||
})
|
||||
|
||||
it('should use default values when not provided', async () => {
|
||||
const inputs = {}
|
||||
|
||||
const executePromise = handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(10000)
|
||||
|
||||
const result = await executePromise
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 10000,
|
||||
status: 'completed',
|
||||
})
|
||||
})
|
||||
|
||||
it('should throw error for negative wait times', async () => {
|
||||
const inputs = {
|
||||
timeValue: '-5',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
|
||||
'Wait amount must be a positive number'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error for zero wait time', async () => {
|
||||
const inputs = {
|
||||
timeValue: '0',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
|
||||
'Wait amount must be a positive number'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error for non-numeric wait times', async () => {
|
||||
const inputs = {
|
||||
timeValue: 'abc',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
|
||||
'Wait amount must be a positive number'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error when wait time exceeds maximum (seconds)', async () => {
|
||||
const inputs = {
|
||||
timeValue: '601',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
|
||||
'Wait time exceeds maximum of 600 seconds'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error when wait time exceeds maximum (minutes)', async () => {
|
||||
const inputs = {
|
||||
timeValue: '11',
|
||||
timeUnit: 'minutes',
|
||||
}
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
|
||||
'Wait time exceeds maximum of 10 minutes'
|
||||
)
|
||||
})
|
||||
|
||||
it('should allow maximum wait time of exactly 10 minutes', async () => {
|
||||
const inputs = {
|
||||
timeValue: '10',
|
||||
timeUnit: 'minutes',
|
||||
}
|
||||
|
||||
const executePromise = handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(600000)
|
||||
|
||||
const result = await executePromise
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 600000,
|
||||
status: 'completed',
|
||||
})
|
||||
})
|
||||
|
||||
it('should allow maximum wait time of exactly 600 seconds', async () => {
|
||||
const inputs = {
|
||||
timeValue: '600',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
const executePromise = handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(600000)
|
||||
|
||||
const result = await executePromise
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 600000,
|
||||
status: 'completed',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle cancellation via AbortSignal', async () => {
|
||||
const abortController = new AbortController()
|
||||
mockContext.abortSignal = abortController.signal
|
||||
|
||||
const inputs = {
|
||||
timeValue: '30',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
const executePromise = handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(10000)
|
||||
abortController.abort()
|
||||
await vi.advanceTimersByTimeAsync(1)
|
||||
|
||||
const result = await executePromise
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 30000,
|
||||
status: 'cancelled',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return cancelled immediately if signal is already aborted', async () => {
|
||||
const abortController = new AbortController()
|
||||
abortController.abort()
|
||||
mockContext.abortSignal = abortController.signal
|
||||
|
||||
const inputs = {
|
||||
timeValue: '10',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
const result = await handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 10000,
|
||||
status: 'cancelled',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle partial completion before cancellation', async () => {
|
||||
const abortController = new AbortController()
|
||||
mockContext.abortSignal = abortController.signal
|
||||
|
||||
const inputs = {
|
||||
timeValue: '100',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
const executePromise = handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(50000)
|
||||
abortController.abort()
|
||||
await vi.advanceTimersByTimeAsync(1)
|
||||
|
||||
const result = await executePromise
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 100000,
|
||||
status: 'cancelled',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle fractional seconds by converting to integers', async () => {
|
||||
const inputs = {
|
||||
timeValue: '5.7',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
const executePromise = handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(5000)
|
||||
|
||||
const result = await executePromise
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 5000,
|
||||
status: 'completed',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle very short wait times', async () => {
|
||||
const inputs = {
|
||||
timeValue: '1',
|
||||
timeUnit: 'seconds',
|
||||
}
|
||||
|
||||
const executePromise = handler.execute(mockContext, mockBlock, inputs)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(1000)
|
||||
|
||||
const result = await executePromise
|
||||
|
||||
expect(result).toEqual({
|
||||
waitDuration: 1000,
|
||||
status: 'completed',
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -15,9 +15,7 @@ const sleep = async (ms: number, signal?: AbortSignal): Promise<boolean> => {
|
||||
let timeoutId: NodeJS.Timeout | undefined
|
||||
|
||||
const onAbort = () => {
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId)
|
||||
}
|
||||
if (timeoutId) clearTimeout(timeoutId)
|
||||
resolve(false)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user