mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
improvement(copilot): code hygiene + tests (#856)
* improvement(copilot): code hygiene + tests * add remaining copilot tests * fix typ
This commit is contained in:
617
apps/sim/app/api/copilot/chat/route.test.ts
Normal file
617
apps/sim/app/api/copilot/chat/route.test.ts
Normal file
@@ -0,0 +1,617 @@
|
||||
/**
|
||||
* Tests for copilot chat API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { NextRequest } from 'next/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockAuth,
|
||||
mockCryptoUuid,
|
||||
setupCommonApiMocks,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
|
||||
describe('Copilot Chat API Route', () => {
|
||||
const mockSelect = vi.fn()
|
||||
const mockFrom = vi.fn()
|
||||
const mockWhere = vi.fn()
|
||||
const mockLimit = vi.fn()
|
||||
const mockOrderBy = vi.fn()
|
||||
const mockInsert = vi.fn()
|
||||
const mockValues = vi.fn()
|
||||
const mockReturning = vi.fn()
|
||||
const mockUpdate = vi.fn()
|
||||
const mockSet = vi.fn()
|
||||
|
||||
const mockExecuteProviderRequest = vi.fn()
|
||||
const mockGetCopilotModel = vi.fn()
|
||||
const mockGetRotatingApiKey = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
setupCommonApiMocks()
|
||||
mockCryptoUuid()
|
||||
|
||||
mockSelect.mockReturnValue({ from: mockFrom })
|
||||
mockFrom.mockReturnValue({ where: mockWhere })
|
||||
mockWhere.mockReturnValue({
|
||||
orderBy: mockOrderBy,
|
||||
limit: mockLimit,
|
||||
})
|
||||
mockOrderBy.mockResolvedValue([])
|
||||
mockLimit.mockResolvedValue([])
|
||||
mockInsert.mockReturnValue({ values: mockValues })
|
||||
mockValues.mockReturnValue({ returning: mockReturning })
|
||||
mockUpdate.mockReturnValue({ set: mockSet })
|
||||
mockSet.mockReturnValue({ where: mockWhere })
|
||||
|
||||
vi.doMock('@/db', () => ({
|
||||
db: {
|
||||
select: mockSelect,
|
||||
insert: mockInsert,
|
||||
update: mockUpdate,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('@/db/schema', () => ({
|
||||
copilotChats: {
|
||||
id: 'id',
|
||||
userId: 'userId',
|
||||
messages: 'messages',
|
||||
title: 'title',
|
||||
model: 'model',
|
||||
workflowId: 'workflowId',
|
||||
createdAt: 'createdAt',
|
||||
updatedAt: 'updatedAt',
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('drizzle-orm', () => ({
|
||||
and: vi.fn((...conditions) => ({ conditions, type: 'and' })),
|
||||
eq: vi.fn((field, value) => ({ field, value, type: 'eq' })),
|
||||
desc: vi.fn((field) => ({ field, type: 'desc' })),
|
||||
}))
|
||||
|
||||
mockGetCopilotModel.mockReturnValue({
|
||||
provider: 'anthropic',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
})
|
||||
|
||||
vi.doMock('@/lib/copilot/config', () => ({
|
||||
getCopilotModel: mockGetCopilotModel,
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/copilot/prompts', () => ({
|
||||
TITLE_GENERATION_SYSTEM_PROMPT: 'Generate a title',
|
||||
TITLE_GENERATION_USER_PROMPT: vi.fn((msg) => `Generate title for: ${msg}`),
|
||||
}))
|
||||
|
||||
mockExecuteProviderRequest.mockResolvedValue({
|
||||
content: 'Generated Title',
|
||||
})
|
||||
|
||||
vi.doMock('@/providers', () => ({
|
||||
executeProviderRequest: mockExecuteProviderRequest,
|
||||
}))
|
||||
|
||||
mockGetRotatingApiKey.mockReturnValue('test-api-key')
|
||||
|
||||
vi.doMock('@/lib/utils', () => ({
|
||||
getRotatingApiKey: mockGetRotatingApiKey,
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
SIM_AGENT_API_URL: 'http://localhost:8000',
|
||||
SIM_AGENT_API_KEY: 'test-sim-agent-key',
|
||||
},
|
||||
}))
|
||||
|
||||
global.fetch = vi.fn()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('POST', () => {
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setUnauthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should return 400 for invalid request body', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
// Missing required fields
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Invalid request data')
|
||||
expect(responseData.details).toBeDefined()
|
||||
})
|
||||
|
||||
it('should handle new chat creation and forward to sim agent', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock successful chat creation
|
||||
const newChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
title: null,
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [],
|
||||
}
|
||||
mockReturning.mockResolvedValue([newChat])
|
||||
|
||||
// Mock successful sim agent response
|
||||
const mockReadableStream = new ReadableStream({
|
||||
start(controller) {
|
||||
const encoder = new TextEncoder()
|
||||
controller.enqueue(
|
||||
encoder.encode('data: {"type": "assistant_message", "content": "Hello response"}\\n\\n')
|
||||
)
|
||||
controller.close()
|
||||
},
|
||||
})
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: mockReadableStream,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
stream: true,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(mockInsert).toHaveBeenCalled()
|
||||
expect(mockValues).toHaveBeenCalledWith({
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
title: null,
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [],
|
||||
})
|
||||
|
||||
// Verify sim agent was called
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:8000/api/chat-completion-streaming',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-sim-agent-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
},
|
||||
],
|
||||
workflowId: 'workflow-123',
|
||||
userId: 'user-123',
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
mode: 'agent',
|
||||
}),
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should load existing chat and include conversation history', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock existing chat with history
|
||||
const existingChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
title: 'Existing Chat',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Previous message' },
|
||||
{ role: 'assistant', content: 'Previous response' },
|
||||
],
|
||||
}
|
||||
// For POST route, the select query uses limit not orderBy
|
||||
mockLimit.mockResolvedValue([existingChat])
|
||||
|
||||
// Mock sim agent response
|
||||
const mockReadableStream = new ReadableStream({
|
||||
start(controller) {
|
||||
controller.close()
|
||||
},
|
||||
})
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: mockReadableStream,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'New message',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
|
||||
// Verify conversation history was included
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:8000/api/chat-completion-streaming',
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
messages: [
|
||||
{ role: 'user', content: 'Previous message' },
|
||||
{ role: 'assistant', content: 'Previous response' },
|
||||
{ role: 'user', content: 'New message' },
|
||||
],
|
||||
workflowId: 'workflow-123',
|
||||
userId: 'user-123',
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
mode: 'agent',
|
||||
}),
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should include implicit feedback in messages', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock new chat creation
|
||||
const newChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
messages: [],
|
||||
}
|
||||
mockReturning.mockResolvedValue([newChat])
|
||||
|
||||
// Mock sim agent response
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: new ReadableStream({
|
||||
start(controller) {
|
||||
controller.close()
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
implicitFeedback: 'User seems confused about the workflow',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
await POST(req)
|
||||
|
||||
// Verify implicit feedback was included as system message
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:8000/api/chat-completion-streaming',
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
messages: [
|
||||
{ role: 'system', content: 'User seems confused about the workflow' },
|
||||
{ role: 'user', content: 'Hello' },
|
||||
],
|
||||
workflowId: 'workflow-123',
|
||||
userId: 'user-123',
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
mode: 'agent',
|
||||
}),
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle sim agent API errors', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock new chat creation
|
||||
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
|
||||
|
||||
// Mock sim agent error
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
text: () => Promise.resolve('Internal server error'),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toContain('Sim agent API error')
|
||||
})
|
||||
|
||||
it('should handle database errors during chat creation', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database error
|
||||
mockReturning.mockRejectedValue(new Error('Database connection failed'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'Hello',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Database connection failed')
|
||||
})
|
||||
|
||||
it('should use ask mode when specified', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock new chat creation
|
||||
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
|
||||
|
||||
// Mock sim agent response
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: new ReadableStream({
|
||||
start(controller) {
|
||||
controller.close()
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
message: 'What is this workflow?',
|
||||
workflowId: 'workflow-123',
|
||||
createNewChat: true,
|
||||
mode: 'ask',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/route')
|
||||
await POST(req)
|
||||
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:8000/api/chat-completion-streaming',
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
messages: [{ role: 'user', content: 'What is this workflow?' }],
|
||||
workflowId: 'workflow-123',
|
||||
userId: 'user-123',
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
mode: 'ask',
|
||||
}),
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('GET', () => {
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setUnauthenticated()
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat?workflowId=workflow-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should return 400 when workflowId is missing', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('workflowId is required')
|
||||
})
|
||||
|
||||
it('should return chats for authenticated user and workflow', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database response (what comes from DB)
|
||||
const mockDbChats = [
|
||||
{
|
||||
id: 'chat-1',
|
||||
title: 'First Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
{ role: 'user', content: 'Message 2' },
|
||||
{ role: 'assistant', content: 'Response 2' },
|
||||
],
|
||||
createdAt: new Date('2024-01-01'),
|
||||
updatedAt: new Date('2024-01-02'),
|
||||
},
|
||||
{
|
||||
id: 'chat-2',
|
||||
title: 'Second Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
],
|
||||
createdAt: new Date('2024-01-03'),
|
||||
updatedAt: new Date('2024-01-04'),
|
||||
},
|
||||
]
|
||||
|
||||
// Expected transformed response (what the route returns)
|
||||
const expectedChats = [
|
||||
{
|
||||
id: 'chat-1',
|
||||
title: 'First Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
{ role: 'user', content: 'Message 2' },
|
||||
{ role: 'assistant', content: 'Response 2' },
|
||||
],
|
||||
messageCount: 4,
|
||||
previewYaml: null,
|
||||
createdAt: new Date('2024-01-01'),
|
||||
updatedAt: new Date('2024-01-02'),
|
||||
},
|
||||
{
|
||||
id: 'chat-2',
|
||||
title: 'Second Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
],
|
||||
messageCount: 2,
|
||||
previewYaml: null,
|
||||
createdAt: new Date('2024-01-03'),
|
||||
updatedAt: new Date('2024-01-04'),
|
||||
},
|
||||
]
|
||||
|
||||
mockOrderBy.mockResolvedValue(mockDbChats)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat?workflowId=workflow-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
chats: [
|
||||
{
|
||||
id: 'chat-1',
|
||||
title: 'First Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
{ role: 'user', content: 'Message 2' },
|
||||
{ role: 'assistant', content: 'Response 2' },
|
||||
],
|
||||
messageCount: 4,
|
||||
previewYaml: null,
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-02T00:00:00.000Z',
|
||||
},
|
||||
{
|
||||
id: 'chat-2',
|
||||
title: 'Second Chat',
|
||||
model: 'claude-3-haiku-20240307',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
],
|
||||
messageCount: 2,
|
||||
previewYaml: null,
|
||||
createdAt: '2024-01-03T00:00:00.000Z',
|
||||
updatedAt: '2024-01-04T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
// Verify database query was made correctly
|
||||
expect(mockSelect).toHaveBeenCalled()
|
||||
expect(mockWhere).toHaveBeenCalled()
|
||||
expect(mockOrderBy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle database errors when fetching chats', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database error
|
||||
mockOrderBy.mockRejectedValue(new Error('Database query failed'))
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat?workflowId=workflow-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to fetch chats')
|
||||
})
|
||||
|
||||
it('should return empty array when no chats found', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
mockOrderBy.mockResolvedValue([])
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat?workflowId=workflow-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/chat/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
chats: [],
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,7 +1,13 @@
|
||||
import { and, desc, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createBadRequestResponse,
|
||||
createInternalServerErrorResponse,
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { getCopilotModel } from '@/lib/copilot/config'
|
||||
import { TITLE_GENERATION_SYSTEM_PROMPT, TITLE_GENERATION_USER_PROMPT } from '@/lib/copilot/prompts'
|
||||
import { env } from '@/lib/env'
|
||||
@@ -116,16 +122,15 @@ async function generateChatTitleAsync(
|
||||
* Send messages to sim agent and handle chat persistence
|
||||
*/
|
||||
export async function POST(req: NextRequest) {
|
||||
const requestId = crypto.randomUUID()
|
||||
const startTime = Date.now()
|
||||
const tracker = createRequestTracker()
|
||||
|
||||
try {
|
||||
// Authenticate user
|
||||
const session = await getSession()
|
||||
const authenticatedUserId: string | null = session?.user?.id || null
|
||||
// Authenticate user using consolidated helper
|
||||
const { userId: authenticatedUserId, isAuthenticated } =
|
||||
await authenticateCopilotRequestSessionOnly()
|
||||
|
||||
if (!authenticatedUserId) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
if (!isAuthenticated || !authenticatedUserId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
@@ -140,7 +145,7 @@ export async function POST(req: NextRequest) {
|
||||
implicitFeedback,
|
||||
} = ChatMessageSchema.parse(body)
|
||||
|
||||
logger.info(`[${requestId}] Processing copilot chat request`, {
|
||||
logger.info(`[${tracker.requestId}] Processing copilot chat request`, {
|
||||
userId: authenticatedUserId,
|
||||
workflowId,
|
||||
chatId,
|
||||
@@ -215,11 +220,11 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
// Start title generation in parallel if this is a new chat with first message
|
||||
if (actualChatId && !currentChat?.title && conversationHistory.length === 0) {
|
||||
logger.info(`[${requestId}] Will start parallel title generation inside stream`)
|
||||
logger.info(`[${tracker.requestId}] Will start parallel title generation inside stream`)
|
||||
}
|
||||
|
||||
// Forward to sim agent API
|
||||
logger.info(`[${requestId}] Sending request to sim agent API`, {
|
||||
logger.info(`[${tracker.requestId}] Sending request to sim agent API`, {
|
||||
messageCount: messages.length,
|
||||
endpoint: `${SIM_AGENT_API_URL}/api/chat-completion-streaming`,
|
||||
})
|
||||
@@ -242,7 +247,7 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
if (!simAgentResponse.ok) {
|
||||
const errorText = await simAgentResponse.text()
|
||||
logger.error(`[${requestId}] Sim agent API error:`, {
|
||||
logger.error(`[${tracker.requestId}] Sim agent API error:`, {
|
||||
status: simAgentResponse.status,
|
||||
error: errorText,
|
||||
})
|
||||
@@ -254,7 +259,7 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
// If streaming is requested, forward the stream and update chat later
|
||||
if (stream && simAgentResponse.body) {
|
||||
logger.info(`[${requestId}] Streaming response from sim agent`)
|
||||
logger.info(`[${tracker.requestId}] Streaming response from sim agent`)
|
||||
|
||||
// Create user message to save
|
||||
const userMessage = {
|
||||
@@ -280,22 +285,24 @@ export async function POST(req: NextRequest) {
|
||||
chatId: actualChatId,
|
||||
})}\n\n`
|
||||
controller.enqueue(encoder.encode(chatIdEvent))
|
||||
logger.debug(`[${requestId}] Sent initial chatId event to client`)
|
||||
logger.debug(`[${tracker.requestId}] Sent initial chatId event to client`)
|
||||
}
|
||||
|
||||
// Start title generation in parallel if needed
|
||||
if (actualChatId && !currentChat?.title && conversationHistory.length === 0) {
|
||||
logger.info(`[${requestId}] Starting title generation with stream updates`, {
|
||||
logger.info(`[${tracker.requestId}] Starting title generation with stream updates`, {
|
||||
chatId: actualChatId,
|
||||
hasTitle: !!currentChat?.title,
|
||||
conversationLength: conversationHistory.length,
|
||||
message: message.substring(0, 100) + (message.length > 100 ? '...' : ''),
|
||||
})
|
||||
generateChatTitleAsync(actualChatId, message, requestId, controller).catch((error) => {
|
||||
logger.error(`[${requestId}] Title generation failed:`, error)
|
||||
})
|
||||
generateChatTitleAsync(actualChatId, message, tracker.requestId, controller).catch(
|
||||
(error) => {
|
||||
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
|
||||
}
|
||||
)
|
||||
} else {
|
||||
logger.debug(`[${requestId}] Skipping title generation`, {
|
||||
logger.debug(`[${tracker.requestId}] Skipping title generation`, {
|
||||
chatId: actualChatId,
|
||||
hasTitle: !!currentChat?.title,
|
||||
conversationLength: conversationHistory.length,
|
||||
@@ -317,7 +324,7 @@ export async function POST(req: NextRequest) {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
logger.info(`[${requestId}] Stream reading completed`)
|
||||
logger.info(`[${tracker.requestId}] Stream reading completed`)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -327,7 +334,9 @@ export async function POST(req: NextRequest) {
|
||||
controller.enqueue(value)
|
||||
} catch (error) {
|
||||
// Client disconnected - stop reading from sim agent
|
||||
logger.info(`[${requestId}] Client disconnected, stopping stream processing`)
|
||||
logger.info(
|
||||
`[${tracker.requestId}] Client disconnected, stopping stream processing`
|
||||
)
|
||||
reader.cancel() // Stop reading from sim agent
|
||||
break
|
||||
}
|
||||
@@ -350,7 +359,7 @@ export async function POST(req: NextRequest) {
|
||||
// Check if the JSON string is unusually large (potential streaming issue)
|
||||
if (jsonStr.length > 50000) {
|
||||
// 50KB limit
|
||||
logger.warn(`[${requestId}] Large SSE event detected`, {
|
||||
logger.warn(`[${tracker.requestId}] Large SSE event detected`, {
|
||||
size: jsonStr.length,
|
||||
preview: `${jsonStr.substring(0, 100)}...`,
|
||||
})
|
||||
@@ -368,7 +377,7 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
case 'tool_call':
|
||||
logger.info(
|
||||
`[${requestId}] Tool call ${event.data?.partial ? '(partial)' : '(complete)'}:`,
|
||||
`[${tracker.requestId}] Tool call ${event.data?.partial ? '(partial)' : '(complete)'}:`,
|
||||
{
|
||||
id: event.data?.id,
|
||||
name: event.data?.name,
|
||||
@@ -382,7 +391,7 @@ export async function POST(req: NextRequest) {
|
||||
break
|
||||
|
||||
case 'tool_execution':
|
||||
logger.info(`[${requestId}] Tool execution started:`, {
|
||||
logger.info(`[${tracker.requestId}] Tool execution started:`, {
|
||||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
status: event.status,
|
||||
@@ -390,7 +399,7 @@ export async function POST(req: NextRequest) {
|
||||
break
|
||||
|
||||
case 'tool_result':
|
||||
logger.info(`[${requestId}] Tool result received:`, {
|
||||
logger.info(`[${tracker.requestId}] Tool result received:`, {
|
||||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
success: event.success,
|
||||
@@ -400,7 +409,7 @@ export async function POST(req: NextRequest) {
|
||||
break
|
||||
|
||||
case 'tool_error':
|
||||
logger.error(`[${requestId}] Tool error:`, {
|
||||
logger.error(`[${tracker.requestId}] Tool error:`, {
|
||||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
error: event.error,
|
||||
@@ -411,20 +420,23 @@ export async function POST(req: NextRequest) {
|
||||
case 'done':
|
||||
if (isFirstDone) {
|
||||
logger.info(
|
||||
`[${requestId}] Initial AI response complete, tool count: ${toolCalls.length}`
|
||||
`[${tracker.requestId}] Initial AI response complete, tool count: ${toolCalls.length}`
|
||||
)
|
||||
isFirstDone = false
|
||||
} else {
|
||||
logger.info(`[${requestId}] Conversation round complete`)
|
||||
logger.info(`[${tracker.requestId}] Conversation round complete`)
|
||||
}
|
||||
break
|
||||
|
||||
case 'error':
|
||||
logger.error(`[${requestId}] Stream error event:`, event.error)
|
||||
logger.error(`[${tracker.requestId}] Stream error event:`, event.error)
|
||||
break
|
||||
|
||||
default:
|
||||
logger.debug(`[${requestId}] Unknown event type: ${event.type}`, event)
|
||||
logger.debug(
|
||||
`[${tracker.requestId}] Unknown event type: ${event.type}`,
|
||||
event
|
||||
)
|
||||
}
|
||||
} catch (e) {
|
||||
// Enhanced error handling for large payloads and parsing issues
|
||||
@@ -433,7 +445,7 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
if (isLargePayload) {
|
||||
logger.error(
|
||||
`[${requestId}] Failed to parse large SSE event (${lineLength} chars)`,
|
||||
`[${tracker.requestId}] Failed to parse large SSE event (${lineLength} chars)`,
|
||||
{
|
||||
error: e,
|
||||
preview: `${line.substring(0, 200)}...`,
|
||||
@@ -442,20 +454,20 @@ export async function POST(req: NextRequest) {
|
||||
)
|
||||
} else {
|
||||
logger.warn(
|
||||
`[${requestId}] Failed to parse SSE event: "${line.substring(0, 200)}..."`,
|
||||
`[${tracker.requestId}] Failed to parse SSE event: "${line.substring(0, 200)}..."`,
|
||||
e
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if (line.trim() && line !== 'data: [DONE]') {
|
||||
logger.debug(`[${requestId}] Non-SSE line from sim agent: "${line}"`)
|
||||
logger.debug(`[${tracker.requestId}] Non-SSE line from sim agent: "${line}"`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process any remaining buffer
|
||||
if (buffer.trim()) {
|
||||
logger.debug(`[${requestId}] Processing remaining buffer: "${buffer}"`)
|
||||
logger.debug(`[${tracker.requestId}] Processing remaining buffer: "${buffer}"`)
|
||||
if (buffer.startsWith('data: ')) {
|
||||
try {
|
||||
const event = JSON.parse(buffer.slice(6))
|
||||
@@ -463,13 +475,13 @@ export async function POST(req: NextRequest) {
|
||||
assistantContent += event.data
|
||||
}
|
||||
} catch (e) {
|
||||
logger.warn(`[${requestId}] Failed to parse final buffer: "${buffer}"`)
|
||||
logger.warn(`[${tracker.requestId}] Failed to parse final buffer: "${buffer}"`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log final streaming summary
|
||||
logger.info(`[${requestId}] Streaming complete summary:`, {
|
||||
logger.info(`[${tracker.requestId}] Streaming complete summary:`, {
|
||||
totalContentLength: assistantContent.length,
|
||||
toolCallsCount: toolCalls.length,
|
||||
hasContent: assistantContent.length > 0,
|
||||
@@ -491,11 +503,11 @@ export async function POST(req: NextRequest) {
|
||||
}
|
||||
updatedMessages.push(assistantMessage)
|
||||
logger.info(
|
||||
`[${requestId}] Saving assistant message with content (${assistantContent.length} chars) and ${toolCalls.length} tool calls`
|
||||
`[${tracker.requestId}] Saving assistant message with content (${assistantContent.length} chars) and ${toolCalls.length} tool calls`
|
||||
)
|
||||
} else {
|
||||
logger.info(
|
||||
`[${requestId}] No assistant content or tool calls to save (aborted before response)`
|
||||
`[${tracker.requestId}] No assistant content or tool calls to save (aborted before response)`
|
||||
)
|
||||
}
|
||||
|
||||
@@ -508,14 +520,14 @@ export async function POST(req: NextRequest) {
|
||||
})
|
||||
.where(eq(copilotChats.id, actualChatId!))
|
||||
|
||||
logger.info(`[${requestId}] Updated chat ${actualChatId} with new messages`, {
|
||||
logger.info(`[${tracker.requestId}] Updated chat ${actualChatId} with new messages`, {
|
||||
messageCount: updatedMessages.length,
|
||||
savedUserMessage: true,
|
||||
savedAssistantMessage: assistantContent.trim().length > 0,
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error processing stream:`, error)
|
||||
logger.error(`[${tracker.requestId}] Error processing stream:`, error)
|
||||
controller.error(error)
|
||||
} finally {
|
||||
controller.close()
|
||||
@@ -532,8 +544,8 @@ export async function POST(req: NextRequest) {
|
||||
},
|
||||
})
|
||||
|
||||
logger.info(`[${requestId}] Returning streaming response to client`, {
|
||||
duration: Date.now() - startTime,
|
||||
logger.info(`[${tracker.requestId}] Returning streaming response to client`, {
|
||||
duration: tracker.getDuration(),
|
||||
chatId: actualChatId,
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream',
|
||||
@@ -547,7 +559,7 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
// For non-streaming responses
|
||||
const responseData = await simAgentResponse.json()
|
||||
logger.info(`[${requestId}] Non-streaming response from sim agent:`, {
|
||||
logger.info(`[${tracker.requestId}] Non-streaming response from sim agent:`, {
|
||||
hasContent: !!responseData.content,
|
||||
contentLength: responseData.content?.length || 0,
|
||||
model: responseData.model,
|
||||
@@ -559,7 +571,7 @@ export async function POST(req: NextRequest) {
|
||||
// Log tool calls if present
|
||||
if (responseData.toolCalls?.length > 0) {
|
||||
responseData.toolCalls.forEach((toolCall: any) => {
|
||||
logger.info(`[${requestId}] Tool call in response:`, {
|
||||
logger.info(`[${tracker.requestId}] Tool call in response:`, {
|
||||
id: toolCall.id,
|
||||
name: toolCall.name,
|
||||
success: toolCall.success,
|
||||
@@ -588,9 +600,9 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
// Start title generation in parallel if this is first message (non-streaming)
|
||||
if (actualChatId && !currentChat.title && conversationHistory.length === 0) {
|
||||
logger.info(`[${requestId}] Starting title generation for non-streaming response`)
|
||||
generateChatTitleAsync(actualChatId, message, requestId).catch((error) => {
|
||||
logger.error(`[${requestId}] Title generation failed:`, error)
|
||||
logger.info(`[${tracker.requestId}] Starting title generation for non-streaming response`)
|
||||
generateChatTitleAsync(actualChatId, message, tracker.requestId).catch((error) => {
|
||||
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -604,8 +616,8 @@ export async function POST(req: NextRequest) {
|
||||
.where(eq(copilotChats.id, actualChatId!))
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Returning non-streaming response`, {
|
||||
duration: Date.now() - startTime,
|
||||
logger.info(`[${tracker.requestId}] Returning non-streaming response`, {
|
||||
duration: tracker.getDuration(),
|
||||
chatId: actualChatId,
|
||||
responseLength: responseData.content?.length || 0,
|
||||
})
|
||||
@@ -615,16 +627,16 @@ export async function POST(req: NextRequest) {
|
||||
response: responseData,
|
||||
chatId: actualChatId,
|
||||
metadata: {
|
||||
requestId,
|
||||
requestId: tracker.requestId,
|
||||
message,
|
||||
duration: Date.now() - startTime,
|
||||
duration: tracker.getDuration(),
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
const duration = Date.now() - startTime
|
||||
const duration = tracker.getDuration()
|
||||
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.error(`[${requestId}] Validation error:`, {
|
||||
logger.error(`[${tracker.requestId}] Validation error:`, {
|
||||
duration,
|
||||
errors: error.errors,
|
||||
})
|
||||
@@ -634,7 +646,7 @@ export async function POST(req: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
logger.error(`[${requestId}] Error handling copilot chat:`, {
|
||||
logger.error(`[${tracker.requestId}] Error handling copilot chat:`, {
|
||||
duration,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
stack: error instanceof Error ? error.stack : undefined,
|
||||
@@ -653,17 +665,16 @@ export async function GET(req: NextRequest) {
|
||||
const workflowId = searchParams.get('workflowId')
|
||||
|
||||
if (!workflowId) {
|
||||
return NextResponse.json({ error: 'workflowId is required' }, { status: 400 })
|
||||
return createBadRequestResponse('workflowId is required')
|
||||
}
|
||||
|
||||
// Get authenticated user
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
// Get authenticated user using consolidated helper
|
||||
const { userId: authenticatedUserId, isAuthenticated } =
|
||||
await authenticateCopilotRequestSessionOnly()
|
||||
if (!isAuthenticated || !authenticatedUserId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const authenticatedUserId = session.user.id
|
||||
|
||||
// Fetch chats for this user and workflow
|
||||
const chats = await db
|
||||
.select({
|
||||
@@ -700,6 +711,6 @@ export async function GET(req: NextRequest) {
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error fetching copilot chats:', error)
|
||||
return NextResponse.json({ error: 'Failed to fetch chats' }, { status: 500 })
|
||||
return createInternalServerErrorResponse('Failed to fetch chats')
|
||||
}
|
||||
}
|
||||
|
||||
561
apps/sim/app/api/copilot/chat/update-messages/route.test.ts
Normal file
561
apps/sim/app/api/copilot/chat/update-messages/route.test.ts
Normal file
@@ -0,0 +1,561 @@
|
||||
/**
|
||||
* Tests for copilot chat update-messages API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { NextRequest } from 'next/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockAuth,
|
||||
mockCryptoUuid,
|
||||
setupCommonApiMocks,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
|
||||
describe('Copilot Chat Update Messages API Route', () => {
|
||||
const mockSelect = vi.fn()
|
||||
const mockFrom = vi.fn()
|
||||
const mockWhere = vi.fn()
|
||||
const mockLimit = vi.fn()
|
||||
const mockUpdate = vi.fn()
|
||||
const mockSet = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
setupCommonApiMocks()
|
||||
mockCryptoUuid()
|
||||
|
||||
mockSelect.mockReturnValue({ from: mockFrom })
|
||||
mockFrom.mockReturnValue({ where: mockWhere })
|
||||
mockWhere.mockReturnValue({ limit: mockLimit })
|
||||
mockLimit.mockResolvedValue([]) // Default: no chat found
|
||||
mockUpdate.mockReturnValue({ set: mockSet })
|
||||
mockSet.mockReturnValue({ where: vi.fn().mockResolvedValue(undefined) }) // Different where for update
|
||||
|
||||
vi.doMock('@/db', () => ({
|
||||
db: {
|
||||
select: mockSelect,
|
||||
update: mockUpdate,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('@/db/schema', () => ({
|
||||
copilotChats: {
|
||||
id: 'id',
|
||||
userId: 'userId',
|
||||
messages: 'messages',
|
||||
updatedAt: 'updatedAt',
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('drizzle-orm', () => ({
|
||||
and: vi.fn((...conditions) => ({ conditions, type: 'and' })),
|
||||
eq: vi.fn((field, value) => ({ field, value, type: 'eq' })),
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('POST', () => {
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setUnauthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-123',
|
||||
messages: [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
timestamp: '2024-01-01T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should return 400 for invalid request body - missing chatId', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
messages: [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
timestamp: '2024-01-01T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
// Missing chatId
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update chat messages')
|
||||
})
|
||||
|
||||
it('should return 400 for invalid request body - missing messages', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-123',
|
||||
// Missing messages
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update chat messages')
|
||||
})
|
||||
|
||||
it('should return 400 for invalid message structure - missing required fields', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-123',
|
||||
messages: [
|
||||
{
|
||||
id: 'msg-1',
|
||||
// Missing role, content, timestamp
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update chat messages')
|
||||
})
|
||||
|
||||
it('should return 400 for invalid message role', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-123',
|
||||
messages: [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'invalid-role',
|
||||
content: 'Hello',
|
||||
timestamp: '2024-01-01T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update chat messages')
|
||||
})
|
||||
|
||||
it('should return 404 when chat is not found', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat not found
|
||||
mockLimit.mockResolvedValueOnce([])
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'non-existent-chat',
|
||||
messages: [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
timestamp: '2024-01-01T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Chat not found or unauthorized')
|
||||
})
|
||||
|
||||
it('should return 404 when chat belongs to different user', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat not found (due to user mismatch)
|
||||
mockLimit.mockResolvedValueOnce([])
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'other-user-chat',
|
||||
messages: [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
timestamp: '2024-01-01T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Chat not found or unauthorized')
|
||||
})
|
||||
|
||||
it('should successfully update chat messages', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists - override the default empty array
|
||||
const existingChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
messages: [],
|
||||
}
|
||||
mockLimit.mockResolvedValueOnce([existingChat])
|
||||
|
||||
const messages = [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'user',
|
||||
content: 'Hello, how are you?',
|
||||
timestamp: '2024-01-01T10:00:00.000Z',
|
||||
},
|
||||
{
|
||||
id: 'msg-2',
|
||||
role: 'assistant',
|
||||
content: 'I am doing well, thank you!',
|
||||
timestamp: '2024-01-01T10:01:00.000Z',
|
||||
},
|
||||
]
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-123',
|
||||
messages,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
messageCount: 2,
|
||||
})
|
||||
|
||||
// Verify database operations
|
||||
expect(mockSelect).toHaveBeenCalled()
|
||||
expect(mockUpdate).toHaveBeenCalled()
|
||||
expect(mockSet).toHaveBeenCalledWith({
|
||||
messages,
|
||||
updatedAt: expect.any(Date),
|
||||
})
|
||||
})
|
||||
|
||||
it('should successfully update chat messages with optional fields', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-456',
|
||||
userId: 'user-123',
|
||||
messages: [],
|
||||
}
|
||||
mockLimit.mockResolvedValueOnce([existingChat])
|
||||
|
||||
const messages = [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
timestamp: '2024-01-01T10:00:00.000Z',
|
||||
},
|
||||
{
|
||||
id: 'msg-2',
|
||||
role: 'assistant',
|
||||
content: 'Hi there!',
|
||||
timestamp: '2024-01-01T10:01:00.000Z',
|
||||
toolCalls: [
|
||||
{
|
||||
id: 'tool-1',
|
||||
name: 'get_weather',
|
||||
arguments: { location: 'NYC' },
|
||||
},
|
||||
],
|
||||
contentBlocks: [
|
||||
{
|
||||
type: 'text',
|
||||
content: 'Here is the weather information',
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-456',
|
||||
messages,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
messageCount: 2,
|
||||
})
|
||||
|
||||
expect(mockSet).toHaveBeenCalledWith({
|
||||
messages,
|
||||
updatedAt: expect.any(Date),
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle empty messages array', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-789',
|
||||
userId: 'user-123',
|
||||
messages: [],
|
||||
}
|
||||
mockLimit.mockResolvedValueOnce([existingChat])
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-789',
|
||||
messages: [],
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
messageCount: 0,
|
||||
})
|
||||
|
||||
expect(mockSet).toHaveBeenCalledWith({
|
||||
messages: [],
|
||||
updatedAt: expect.any(Date),
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle database errors during chat lookup', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database error during chat lookup
|
||||
mockLimit.mockRejectedValueOnce(new Error('Database connection failed'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-123',
|
||||
messages: [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
timestamp: '2024-01-01T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update chat messages')
|
||||
})
|
||||
|
||||
it('should handle database errors during update operation', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
messages: [],
|
||||
}
|
||||
mockLimit.mockResolvedValueOnce([existingChat])
|
||||
|
||||
// Mock database error during update
|
||||
mockSet.mockReturnValueOnce({
|
||||
where: vi.fn().mockRejectedValue(new Error('Update operation failed')),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-123',
|
||||
messages: [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
timestamp: '2024-01-01T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update chat messages')
|
||||
})
|
||||
|
||||
it('should handle JSON parsing errors in request body', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Create a request with invalid JSON
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat/update-messages', {
|
||||
method: 'POST',
|
||||
body: '{invalid-json',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update chat messages')
|
||||
})
|
||||
|
||||
it('should handle large message arrays', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-large',
|
||||
userId: 'user-123',
|
||||
messages: [],
|
||||
}
|
||||
mockLimit.mockResolvedValueOnce([existingChat])
|
||||
|
||||
// Create a large array of messages
|
||||
const messages = Array.from({ length: 100 }, (_, i) => ({
|
||||
id: `msg-${i + 1}`,
|
||||
role: i % 2 === 0 ? 'user' : 'assistant',
|
||||
content: `Message ${i + 1}`,
|
||||
timestamp: new Date(2024, 0, 1, 10, i).toISOString(),
|
||||
}))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-large',
|
||||
messages,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
messageCount: 100,
|
||||
})
|
||||
|
||||
expect(mockSet).toHaveBeenCalledWith({
|
||||
messages,
|
||||
updatedAt: expect.any(Date),
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle messages with both user and assistant roles', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-mixed',
|
||||
userId: 'user-123',
|
||||
messages: [],
|
||||
}
|
||||
mockLimit.mockResolvedValueOnce([existingChat])
|
||||
|
||||
const messages = [
|
||||
{
|
||||
id: 'msg-1',
|
||||
role: 'user',
|
||||
content: 'What is the weather like?',
|
||||
timestamp: '2024-01-01T10:00:00.000Z',
|
||||
},
|
||||
{
|
||||
id: 'msg-2',
|
||||
role: 'assistant',
|
||||
content: 'Let me check the weather for you.',
|
||||
timestamp: '2024-01-01T10:01:00.000Z',
|
||||
toolCalls: [
|
||||
{
|
||||
id: 'tool-weather',
|
||||
name: 'get_weather',
|
||||
arguments: { location: 'current' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 'msg-3',
|
||||
role: 'assistant',
|
||||
content: 'The weather is sunny and 75°F.',
|
||||
timestamp: '2024-01-01T10:02:00.000Z',
|
||||
},
|
||||
{
|
||||
id: 'msg-4',
|
||||
role: 'user',
|
||||
content: 'Thank you!',
|
||||
timestamp: '2024-01-01T10:03:00.000Z',
|
||||
},
|
||||
]
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
chatId: 'chat-mixed',
|
||||
messages,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/chat/update-messages/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
messageCount: 4,
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,7 +1,13 @@
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createInternalServerErrorResponse,
|
||||
createNotFoundResponse,
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { copilotChats } from '@/db/schema'
|
||||
@@ -23,19 +29,19 @@ const UpdateMessagesSchema = z.object({
|
||||
})
|
||||
|
||||
export async function POST(req: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const tracker = createRequestTracker()
|
||||
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
|
||||
if (!isAuthenticated || !userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
const { chatId, messages } = UpdateMessagesSchema.parse(body)
|
||||
|
||||
logger.info(`[${requestId}] Updating chat messages`, {
|
||||
userId: session.user.id,
|
||||
logger.info(`[${tracker.requestId}] Updating chat messages`, {
|
||||
userId,
|
||||
chatId,
|
||||
messageCount: messages.length,
|
||||
})
|
||||
@@ -44,11 +50,11 @@ export async function POST(req: NextRequest) {
|
||||
const [chat] = await db
|
||||
.select()
|
||||
.from(copilotChats)
|
||||
.where(and(eq(copilotChats.id, chatId), eq(copilotChats.userId, session.user.id)))
|
||||
.where(and(eq(copilotChats.id, chatId), eq(copilotChats.userId, userId)))
|
||||
.limit(1)
|
||||
|
||||
if (!chat) {
|
||||
return NextResponse.json({ error: 'Chat not found or unauthorized' }, { status: 404 })
|
||||
return createNotFoundResponse('Chat not found or unauthorized')
|
||||
}
|
||||
|
||||
// Update chat with new messages
|
||||
@@ -60,7 +66,7 @@ export async function POST(req: NextRequest) {
|
||||
})
|
||||
.where(eq(copilotChats.id, chatId))
|
||||
|
||||
logger.info(`[${requestId}] Successfully updated chat messages`, {
|
||||
logger.info(`[${tracker.requestId}] Successfully updated chat messages`, {
|
||||
chatId,
|
||||
newMessageCount: messages.length,
|
||||
})
|
||||
@@ -70,7 +76,7 @@ export async function POST(req: NextRequest) {
|
||||
messageCount: messages.length,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error updating chat messages:`, error)
|
||||
return NextResponse.json({ error: 'Failed to update chat messages' }, { status: 500 })
|
||||
logger.error(`[${tracker.requestId}] Error updating chat messages:`, error)
|
||||
return createInternalServerErrorResponse('Failed to update chat messages')
|
||||
}
|
||||
}
|
||||
|
||||
778
apps/sim/app/api/copilot/checkpoints/revert/route.test.ts
Normal file
778
apps/sim/app/api/copilot/checkpoints/revert/route.test.ts
Normal file
@@ -0,0 +1,778 @@
|
||||
/**
|
||||
* Tests for copilot checkpoints revert API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { NextRequest } from 'next/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockAuth,
|
||||
mockCryptoUuid,
|
||||
setupCommonApiMocks,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
|
||||
describe('Copilot Checkpoints Revert API Route', () => {
|
||||
const mockSelect = vi.fn()
|
||||
const mockFrom = vi.fn()
|
||||
const mockWhere = vi.fn()
|
||||
const mockThen = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
setupCommonApiMocks()
|
||||
mockCryptoUuid()
|
||||
|
||||
mockSelect.mockReturnValue({ from: mockFrom })
|
||||
mockFrom.mockReturnValue({ where: mockWhere })
|
||||
mockWhere.mockReturnValue({ then: mockThen })
|
||||
mockThen.mockResolvedValue(null) // Default: no data found
|
||||
|
||||
vi.doMock('@/db', () => ({
|
||||
db: {
|
||||
select: mockSelect,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('@/db/schema', () => ({
|
||||
workflowCheckpoints: {
|
||||
id: 'id',
|
||||
userId: 'userId',
|
||||
workflowId: 'workflowId',
|
||||
workflowState: 'workflowState',
|
||||
},
|
||||
workflow: {
|
||||
id: 'id',
|
||||
userId: 'userId',
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('drizzle-orm', () => ({
|
||||
and: vi.fn((...conditions) => ({ conditions, type: 'and' })),
|
||||
eq: vi.fn((field, value) => ({ field, value, type: 'eq' })),
|
||||
}))
|
||||
|
||||
global.fetch = vi.fn()
|
||||
|
||||
vi.spyOn(Date, 'now').mockReturnValue(1640995200000)
|
||||
|
||||
const originalDate = Date
|
||||
vi.spyOn(global, 'Date').mockImplementation(((...args: any[]) => {
|
||||
if (args.length === 0) {
|
||||
const mockDate = new originalDate('2024-01-01T00:00:00.000Z')
|
||||
return mockDate
|
||||
}
|
||||
if (args.length === 1) {
|
||||
return new originalDate(args[0])
|
||||
}
|
||||
return new originalDate(args[0], args[1], args[2], args[3], args[4], args[5], args[6])
|
||||
}) as any)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('POST', () => {
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setUnauthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should return 500 for invalid request body - missing checkpointId', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
// Missing checkpointId
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to revert to checkpoint')
|
||||
})
|
||||
|
||||
it('should return 500 for empty checkpointId', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: '',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to revert to checkpoint')
|
||||
})
|
||||
|
||||
it('should return 404 when checkpoint is not found', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock checkpoint not found
|
||||
mockThen.mockResolvedValueOnce(undefined)
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'non-existent-checkpoint',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Checkpoint not found or access denied')
|
||||
})
|
||||
|
||||
it('should return 404 when checkpoint belongs to different user', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock checkpoint not found (due to user mismatch in query)
|
||||
mockThen.mockResolvedValueOnce(undefined)
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'other-user-checkpoint',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Checkpoint not found or access denied')
|
||||
})
|
||||
|
||||
it('should return 404 when workflow is not found', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock checkpoint found but workflow not found
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-123',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: { blocks: {}, edges: [] },
|
||||
}
|
||||
|
||||
mockThen
|
||||
.mockResolvedValueOnce(mockCheckpoint) // Checkpoint found
|
||||
.mockResolvedValueOnce(undefined) // Workflow not found
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Workflow not found')
|
||||
})
|
||||
|
||||
it('should return 401 when workflow belongs to different user', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock checkpoint found but workflow belongs to different user
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-123',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: { blocks: {}, edges: [] },
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'different-user',
|
||||
}
|
||||
|
||||
mockThen
|
||||
.mockResolvedValueOnce(mockCheckpoint) // Checkpoint found
|
||||
.mockResolvedValueOnce(mockWorkflow) // Workflow found but different user
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should successfully revert checkpoint with basic workflow state', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-123',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: {
|
||||
blocks: { block1: { type: 'start' } },
|
||||
edges: [{ from: 'block1', to: 'block2' }],
|
||||
loops: {},
|
||||
parallels: {},
|
||||
isDeployed: true,
|
||||
deploymentStatuses: { production: 'deployed' },
|
||||
hasActiveWebhook: false,
|
||||
},
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
}
|
||||
|
||||
mockThen
|
||||
.mockResolvedValueOnce(mockCheckpoint) // Checkpoint found
|
||||
.mockResolvedValueOnce(mockWorkflow) // Workflow found
|
||||
|
||||
// Mock successful state API call
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ success: true }),
|
||||
})
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/checkpoints/revert', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Cookie: 'session=test-session',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
checkpointId: 'checkpoint-123',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
workflowId: 'workflow-456',
|
||||
checkpointId: 'checkpoint-123',
|
||||
revertedAt: '2024-01-01T00:00:00.000Z',
|
||||
checkpoint: {
|
||||
id: 'checkpoint-123',
|
||||
workflowState: {
|
||||
blocks: { block1: { type: 'start' } },
|
||||
edges: [{ from: 'block1', to: 'block2' }],
|
||||
loops: {},
|
||||
parallels: {},
|
||||
isDeployed: true,
|
||||
deploymentStatuses: { production: 'deployed' },
|
||||
hasActiveWebhook: false,
|
||||
lastSaved: 1640995200000,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Verify fetch was called with correct parameters
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:3000/api/workflows/workflow-456/state',
|
||||
{
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Cookie: 'session=test-session',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
blocks: { block1: { type: 'start' } },
|
||||
edges: [{ from: 'block1', to: 'block2' }],
|
||||
loops: {},
|
||||
parallels: {},
|
||||
isDeployed: true,
|
||||
deploymentStatuses: { production: 'deployed' },
|
||||
hasActiveWebhook: false,
|
||||
lastSaved: 1640995200000,
|
||||
}),
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle checkpoint state with valid deployedAt date', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-with-date',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: {
|
||||
blocks: {},
|
||||
edges: [],
|
||||
deployedAt: '2024-01-01T12:00:00.000Z',
|
||||
isDeployed: true,
|
||||
},
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
}
|
||||
|
||||
mockThen.mockResolvedValueOnce(mockCheckpoint).mockResolvedValueOnce(mockWorkflow)
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ success: true }),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-with-date',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.checkpoint.workflowState.deployedAt).toBeDefined()
|
||||
expect(responseData.checkpoint.workflowState.deployedAt).toEqual('2024-01-01T12:00:00.000Z')
|
||||
})
|
||||
|
||||
it('should handle checkpoint state with invalid deployedAt date', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-invalid-date',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: {
|
||||
blocks: {},
|
||||
edges: [],
|
||||
deployedAt: 'invalid-date',
|
||||
isDeployed: true,
|
||||
},
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
}
|
||||
|
||||
mockThen.mockResolvedValueOnce(mockCheckpoint).mockResolvedValueOnce(mockWorkflow)
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ success: true }),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-invalid-date',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
// Invalid date should be filtered out
|
||||
expect(responseData.checkpoint.workflowState.deployedAt).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle checkpoint state with null/undefined values', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-null-values',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: {
|
||||
blocks: null,
|
||||
edges: undefined,
|
||||
loops: null,
|
||||
parallels: undefined,
|
||||
deploymentStatuses: null,
|
||||
},
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
}
|
||||
|
||||
mockThen.mockResolvedValueOnce(mockCheckpoint).mockResolvedValueOnce(mockWorkflow)
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ success: true }),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-null-values',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
|
||||
// Null/undefined values should be replaced with defaults
|
||||
expect(responseData.checkpoint.workflowState).toEqual({
|
||||
blocks: {},
|
||||
edges: [],
|
||||
loops: {},
|
||||
parallels: {},
|
||||
isDeployed: false,
|
||||
deploymentStatuses: {},
|
||||
hasActiveWebhook: false,
|
||||
lastSaved: 1640995200000,
|
||||
})
|
||||
})
|
||||
|
||||
it('should return 500 when state API call fails', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-123',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: { blocks: {}, edges: [] },
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
}
|
||||
|
||||
mockThen
|
||||
.mockResolvedValueOnce(mockCheckpoint)
|
||||
.mockResolvedValueOnce(mockWorkflow)
|
||||
|
||||
// Mock failed state API call
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: false,
|
||||
text: () => Promise.resolve('State validation failed'),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to revert workflow to checkpoint')
|
||||
})
|
||||
|
||||
it('should handle database errors during checkpoint lookup', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database error
|
||||
mockThen.mockRejectedValueOnce(new Error('Database connection failed'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to revert to checkpoint')
|
||||
})
|
||||
|
||||
it('should handle database errors during workflow lookup', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-123',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: { blocks: {}, edges: [] },
|
||||
}
|
||||
|
||||
mockThen
|
||||
.mockResolvedValueOnce(mockCheckpoint) // Checkpoint found
|
||||
.mockRejectedValueOnce(new Error('Database error during workflow lookup')) // Workflow lookup fails
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to revert to checkpoint')
|
||||
})
|
||||
|
||||
it('should handle fetch network errors', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-123',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: { blocks: {}, edges: [] },
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
}
|
||||
|
||||
mockThen
|
||||
.mockResolvedValueOnce(mockCheckpoint)
|
||||
.mockResolvedValueOnce(mockWorkflow)
|
||||
|
||||
// Mock fetch network error
|
||||
|
||||
;(global.fetch as any).mockRejectedValue(new Error('Network error'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-123',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to revert to checkpoint')
|
||||
})
|
||||
|
||||
it('should handle JSON parsing errors in request body', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Create a request with invalid JSON
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/checkpoints/revert', {
|
||||
method: 'POST',
|
||||
body: '{invalid-json',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to revert to checkpoint')
|
||||
})
|
||||
|
||||
it('should forward cookies to state API call', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-123',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: { blocks: {}, edges: [] },
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
}
|
||||
|
||||
mockThen.mockResolvedValueOnce(mockCheckpoint).mockResolvedValueOnce(mockWorkflow)
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ success: true }),
|
||||
})
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/checkpoints/revert', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Cookie: 'session=test-session; auth=token123',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
checkpointId: 'checkpoint-123',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
await POST(req)
|
||||
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:3000/api/workflows/workflow-456/state',
|
||||
{
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Cookie: 'session=test-session; auth=token123',
|
||||
},
|
||||
body: expect.any(String),
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle missing cookies gracefully', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-123',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: { blocks: {}, edges: [] },
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
}
|
||||
|
||||
mockThen.mockResolvedValueOnce(mockCheckpoint).mockResolvedValueOnce(mockWorkflow)
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ success: true }),
|
||||
})
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/checkpoints/revert', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
// No Cookie header
|
||||
},
|
||||
body: JSON.stringify({
|
||||
checkpointId: 'checkpoint-123',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'http://localhost:3000/api/workflows/workflow-456/state',
|
||||
{
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Cookie: '', // Empty string when no cookies
|
||||
},
|
||||
body: expect.any(String),
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle complex checkpoint state with all fields', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoint = {
|
||||
id: 'checkpoint-complex',
|
||||
workflowId: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
workflowState: {
|
||||
blocks: {
|
||||
start: { type: 'start', config: {} },
|
||||
http: { type: 'http', config: { url: 'https://api.example.com' } },
|
||||
end: { type: 'end', config: {} },
|
||||
},
|
||||
edges: [
|
||||
{ from: 'start', to: 'http' },
|
||||
{ from: 'http', to: 'end' },
|
||||
],
|
||||
loops: {
|
||||
loop1: { condition: 'true', iterations: 3 },
|
||||
},
|
||||
parallels: {
|
||||
parallel1: { branches: ['branch1', 'branch2'] },
|
||||
},
|
||||
isDeployed: true,
|
||||
deploymentStatuses: {
|
||||
production: 'deployed',
|
||||
staging: 'pending',
|
||||
},
|
||||
hasActiveWebhook: true,
|
||||
deployedAt: '2024-01-01T10:00:00.000Z',
|
||||
},
|
||||
}
|
||||
|
||||
const mockWorkflow = {
|
||||
id: 'workflow-456',
|
||||
userId: 'user-123',
|
||||
}
|
||||
|
||||
mockThen.mockResolvedValueOnce(mockCheckpoint).mockResolvedValueOnce(mockWorkflow)
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ success: true }),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
checkpointId: 'checkpoint-complex',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/revert/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.checkpoint.workflowState).toEqual({
|
||||
blocks: {
|
||||
start: { type: 'start', config: {} },
|
||||
http: { type: 'http', config: { url: 'https://api.example.com' } },
|
||||
end: { type: 'end', config: {} },
|
||||
},
|
||||
edges: [
|
||||
{ from: 'start', to: 'http' },
|
||||
{ from: 'http', to: 'end' },
|
||||
],
|
||||
loops: {
|
||||
loop1: { condition: 'true', iterations: 3 },
|
||||
},
|
||||
parallels: {
|
||||
parallel1: { branches: ['branch1', 'branch2'] },
|
||||
},
|
||||
isDeployed: true,
|
||||
deploymentStatuses: {
|
||||
production: 'deployed',
|
||||
staging: 'pending',
|
||||
},
|
||||
hasActiveWebhook: true,
|
||||
deployedAt: '2024-01-01T10:00:00.000Z',
|
||||
lastSaved: 1640995200000,
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,7 +1,13 @@
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createInternalServerErrorResponse,
|
||||
createNotFoundResponse,
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { workflowCheckpoints, workflow as workflowTable } from '@/db/schema'
|
||||
@@ -17,33 +23,28 @@ const RevertCheckpointSchema = z.object({
|
||||
* Revert workflow to a specific checkpoint state
|
||||
*/
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const tracker = createRequestTracker()
|
||||
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
|
||||
if (!isAuthenticated || !userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const body = await request.json()
|
||||
const { checkpointId } = RevertCheckpointSchema.parse(body)
|
||||
|
||||
logger.info(`[${requestId}] Reverting to checkpoint ${checkpointId}`)
|
||||
logger.info(`[${tracker.requestId}] Reverting to checkpoint ${checkpointId}`)
|
||||
|
||||
// Get the checkpoint and verify ownership
|
||||
const checkpoint = await db
|
||||
.select()
|
||||
.from(workflowCheckpoints)
|
||||
.where(
|
||||
and(
|
||||
eq(workflowCheckpoints.id, checkpointId),
|
||||
eq(workflowCheckpoints.userId, session.user.id)
|
||||
)
|
||||
)
|
||||
.where(and(eq(workflowCheckpoints.id, checkpointId), eq(workflowCheckpoints.userId, userId)))
|
||||
.then((rows) => rows[0])
|
||||
|
||||
if (!checkpoint) {
|
||||
return NextResponse.json({ error: 'Checkpoint not found or access denied' }, { status: 404 })
|
||||
return createNotFoundResponse('Checkpoint not found or access denied')
|
||||
}
|
||||
|
||||
// Verify user still has access to the workflow
|
||||
@@ -54,11 +55,11 @@ export async function POST(request: NextRequest) {
|
||||
.then((rows) => rows[0])
|
||||
|
||||
if (!workflowData) {
|
||||
return NextResponse.json({ error: 'Workflow not found' }, { status: 404 })
|
||||
return createNotFoundResponse('Workflow not found')
|
||||
}
|
||||
|
||||
if (workflowData.userId !== session.user.id) {
|
||||
return NextResponse.json({ error: 'Access denied to workflow' }, { status: 403 })
|
||||
if (workflowData.userId !== userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
// Apply the checkpoint state to the workflow using the existing state endpoint
|
||||
@@ -83,7 +84,7 @@ export async function POST(request: NextRequest) {
|
||||
: {}),
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Applying cleaned checkpoint state`, {
|
||||
logger.info(`[${tracker.requestId}] Applying cleaned checkpoint state`, {
|
||||
blocksCount: Object.keys(cleanedState.blocks).length,
|
||||
edgesCount: cleanedState.edges.length,
|
||||
hasDeployedAt: !!cleanedState.deployedAt,
|
||||
@@ -104,7 +105,7 @@ export async function POST(request: NextRequest) {
|
||||
|
||||
if (!stateResponse.ok) {
|
||||
const errorData = await stateResponse.text()
|
||||
logger.error(`[${requestId}] Failed to apply checkpoint state: ${errorData}`)
|
||||
logger.error(`[${tracker.requestId}] Failed to apply checkpoint state: ${errorData}`)
|
||||
return NextResponse.json(
|
||||
{ error: 'Failed to revert workflow to checkpoint' },
|
||||
{ status: 500 }
|
||||
@@ -113,7 +114,7 @@ export async function POST(request: NextRequest) {
|
||||
|
||||
const result = await stateResponse.json()
|
||||
logger.info(
|
||||
`[${requestId}] Successfully reverted workflow ${checkpoint.workflowId} to checkpoint ${checkpointId}`
|
||||
`[${tracker.requestId}] Successfully reverted workflow ${checkpoint.workflowId} to checkpoint ${checkpointId}`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
@@ -127,7 +128,7 @@ export async function POST(request: NextRequest) {
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error reverting to checkpoint:`, error)
|
||||
return NextResponse.json({ error: 'Failed to revert to checkpoint' }, { status: 500 })
|
||||
logger.error(`[${tracker.requestId}] Error reverting to checkpoint:`, error)
|
||||
return createInternalServerErrorResponse('Failed to revert to checkpoint')
|
||||
}
|
||||
}
|
||||
|
||||
438
apps/sim/app/api/copilot/checkpoints/route.test.ts
Normal file
438
apps/sim/app/api/copilot/checkpoints/route.test.ts
Normal file
@@ -0,0 +1,438 @@
|
||||
/**
|
||||
* Tests for copilot checkpoints API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { NextRequest } from 'next/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockAuth,
|
||||
mockCryptoUuid,
|
||||
setupCommonApiMocks,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
|
||||
describe('Copilot Checkpoints API Route', () => {
|
||||
const mockSelect = vi.fn()
|
||||
const mockFrom = vi.fn()
|
||||
const mockWhere = vi.fn()
|
||||
const mockLimit = vi.fn()
|
||||
const mockOrderBy = vi.fn()
|
||||
const mockInsert = vi.fn()
|
||||
const mockValues = vi.fn()
|
||||
const mockReturning = vi.fn()
|
||||
|
||||
const mockCopilotChats = { id: 'id', userId: 'userId' }
|
||||
const mockWorkflowCheckpoints = {
|
||||
id: 'id',
|
||||
userId: 'userId',
|
||||
workflowId: 'workflowId',
|
||||
chatId: 'chatId',
|
||||
messageId: 'messageId',
|
||||
createdAt: 'createdAt',
|
||||
updatedAt: 'updatedAt',
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
setupCommonApiMocks()
|
||||
mockCryptoUuid()
|
||||
|
||||
mockSelect.mockReturnValue({ from: mockFrom })
|
||||
mockFrom.mockReturnValue({ where: mockWhere })
|
||||
mockWhere.mockReturnValue({
|
||||
orderBy: mockOrderBy,
|
||||
limit: mockLimit,
|
||||
})
|
||||
mockOrderBy.mockResolvedValue([])
|
||||
mockLimit.mockResolvedValue([])
|
||||
mockInsert.mockReturnValue({ values: mockValues })
|
||||
mockValues.mockReturnValue({ returning: mockReturning })
|
||||
|
||||
vi.doMock('@/db', () => ({
|
||||
db: {
|
||||
select: mockSelect,
|
||||
insert: mockInsert,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('@/db/schema', () => ({
|
||||
copilotChats: mockCopilotChats,
|
||||
workflowCheckpoints: mockWorkflowCheckpoints,
|
||||
}))
|
||||
|
||||
vi.doMock('drizzle-orm', () => ({
|
||||
and: vi.fn((...conditions) => ({ conditions, type: 'and' })),
|
||||
eq: vi.fn((field, value) => ({ field, value, type: 'eq' })),
|
||||
desc: vi.fn((field) => ({ field, type: 'desc' })),
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('POST', () => {
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setUnauthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
workflowState: '{"blocks": []}',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should return 500 for invalid request body', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
// Missing required fields
|
||||
workflowId: 'workflow-123',
|
||||
// Missing chatId and workflowState
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to create checkpoint')
|
||||
})
|
||||
|
||||
it('should return 400 when chat not found or unauthorized', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat not found
|
||||
mockLimit.mockResolvedValue([])
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
workflowState: '{"blocks": []}',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Chat not found or unauthorized')
|
||||
})
|
||||
|
||||
it('should return 400 for invalid workflow state JSON', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const chat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
}
|
||||
mockLimit.mockResolvedValue([chat])
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
workflowState: 'invalid-json', // Invalid JSON
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Invalid workflow state JSON')
|
||||
})
|
||||
|
||||
it('should successfully create a checkpoint', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const chat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
}
|
||||
mockLimit.mockResolvedValue([chat])
|
||||
|
||||
// Mock successful checkpoint creation
|
||||
const checkpoint = {
|
||||
id: 'checkpoint-123',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
messageId: 'message-123',
|
||||
createdAt: new Date('2024-01-01'),
|
||||
updatedAt: new Date('2024-01-01'),
|
||||
}
|
||||
mockReturning.mockResolvedValue([checkpoint])
|
||||
|
||||
const workflowState = { blocks: [], connections: [] }
|
||||
const req = createMockRequest('POST', {
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
messageId: 'message-123',
|
||||
workflowState: JSON.stringify(workflowState),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
checkpoint: {
|
||||
id: 'checkpoint-123',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
messageId: 'message-123',
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-01T00:00:00.000Z',
|
||||
},
|
||||
})
|
||||
|
||||
// Verify database operations
|
||||
expect(mockInsert).toHaveBeenCalled()
|
||||
expect(mockValues).toHaveBeenCalledWith({
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
messageId: 'message-123',
|
||||
workflowState: workflowState, // Should be parsed JSON object
|
||||
})
|
||||
})
|
||||
|
||||
it('should create checkpoint without messageId', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const chat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
}
|
||||
mockLimit.mockResolvedValue([chat])
|
||||
|
||||
// Mock successful checkpoint creation
|
||||
const checkpoint = {
|
||||
id: 'checkpoint-123',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
messageId: undefined,
|
||||
createdAt: new Date('2024-01-01'),
|
||||
updatedAt: new Date('2024-01-01'),
|
||||
}
|
||||
mockReturning.mockResolvedValue([checkpoint])
|
||||
|
||||
const workflowState = { blocks: [] }
|
||||
const req = createMockRequest('POST', {
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
// No messageId provided
|
||||
workflowState: JSON.stringify(workflowState),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(true)
|
||||
expect(responseData.checkpoint.messageId).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle database errors during checkpoint creation', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const chat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
}
|
||||
mockLimit.mockResolvedValue([chat])
|
||||
|
||||
// Mock database error
|
||||
mockReturning.mockRejectedValue(new Error('Database insert failed'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
workflowState: '{"blocks": []}',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to create checkpoint')
|
||||
})
|
||||
|
||||
it('should handle database errors during chat lookup', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database error during chat lookup
|
||||
mockLimit.mockRejectedValue(new Error('Database query failed'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
workflowState: '{"blocks": []}',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to create checkpoint')
|
||||
})
|
||||
})
|
||||
|
||||
describe('GET', () => {
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setUnauthenticated()
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/checkpoints?chatId=chat-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should return 400 when chatId is missing', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/checkpoints')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('chatId is required')
|
||||
})
|
||||
|
||||
it('should return checkpoints for authenticated user and chat', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const mockCheckpoints = [
|
||||
{
|
||||
id: 'checkpoint-1',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
messageId: 'message-1',
|
||||
createdAt: new Date('2024-01-01'),
|
||||
updatedAt: new Date('2024-01-01'),
|
||||
},
|
||||
{
|
||||
id: 'checkpoint-2',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
messageId: 'message-2',
|
||||
createdAt: new Date('2024-01-02'),
|
||||
updatedAt: new Date('2024-01-02'),
|
||||
},
|
||||
]
|
||||
|
||||
mockOrderBy.mockResolvedValue(mockCheckpoints)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/checkpoints?chatId=chat-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
checkpoints: [
|
||||
{
|
||||
id: 'checkpoint-1',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
messageId: 'message-1',
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-01T00:00:00.000Z',
|
||||
},
|
||||
{
|
||||
id: 'checkpoint-2',
|
||||
userId: 'user-123',
|
||||
workflowId: 'workflow-123',
|
||||
chatId: 'chat-123',
|
||||
messageId: 'message-2',
|
||||
createdAt: '2024-01-02T00:00:00.000Z',
|
||||
updatedAt: '2024-01-02T00:00:00.000Z',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
// Verify database query was made correctly
|
||||
expect(mockSelect).toHaveBeenCalled()
|
||||
expect(mockWhere).toHaveBeenCalled()
|
||||
expect(mockOrderBy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle database errors when fetching checkpoints', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database error
|
||||
mockOrderBy.mockRejectedValue(new Error('Database query failed'))
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/checkpoints?chatId=chat-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to fetch checkpoints')
|
||||
})
|
||||
|
||||
it('should return empty array when no checkpoints found', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
mockOrderBy.mockResolvedValue([])
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/checkpoints?chatId=chat-123')
|
||||
|
||||
const { GET } = await import('@/app/api/copilot/checkpoints/route')
|
||||
const response = await GET(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
checkpoints: [],
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,7 +1,13 @@
|
||||
import { and, desc, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createBadRequestResponse,
|
||||
createInternalServerErrorResponse,
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { copilotChats, workflowCheckpoints } from '@/db/schema'
|
||||
@@ -20,19 +26,19 @@ const CreateCheckpointSchema = z.object({
|
||||
* Create a new checkpoint with JSON workflow state
|
||||
*/
|
||||
export async function POST(req: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const tracker = createRequestTracker()
|
||||
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
|
||||
if (!isAuthenticated || !userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
const { workflowId, chatId, messageId, workflowState } = CreateCheckpointSchema.parse(body)
|
||||
|
||||
logger.info(`[${requestId}] Creating workflow checkpoint`, {
|
||||
userId: session.user.id,
|
||||
logger.info(`[${tracker.requestId}] Creating workflow checkpoint`, {
|
||||
userId,
|
||||
workflowId,
|
||||
chatId,
|
||||
messageId,
|
||||
@@ -46,11 +52,11 @@ export async function POST(req: NextRequest) {
|
||||
const [chat] = await db
|
||||
.select()
|
||||
.from(copilotChats)
|
||||
.where(and(eq(copilotChats.id, chatId), eq(copilotChats.userId, session.user.id)))
|
||||
.where(and(eq(copilotChats.id, chatId), eq(copilotChats.userId, userId)))
|
||||
.limit(1)
|
||||
|
||||
if (!chat) {
|
||||
return NextResponse.json({ error: 'Chat not found or unauthorized' }, { status: 404 })
|
||||
return createBadRequestResponse('Chat not found or unauthorized')
|
||||
}
|
||||
|
||||
// Parse the workflow state to validate it's valid JSON
|
||||
@@ -58,14 +64,14 @@ export async function POST(req: NextRequest) {
|
||||
try {
|
||||
parsedWorkflowState = JSON.parse(workflowState)
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Invalid workflow state JSON' }, { status: 400 })
|
||||
return createBadRequestResponse('Invalid workflow state JSON')
|
||||
}
|
||||
|
||||
// Create checkpoint with JSON workflow state
|
||||
const [checkpoint] = await db
|
||||
.insert(workflowCheckpoints)
|
||||
.values({
|
||||
userId: session.user.id,
|
||||
userId,
|
||||
workflowId,
|
||||
chatId,
|
||||
messageId,
|
||||
@@ -73,7 +79,7 @@ export async function POST(req: NextRequest) {
|
||||
})
|
||||
.returning()
|
||||
|
||||
logger.info(`[${requestId}] Workflow checkpoint created successfully`, {
|
||||
logger.info(`[${tracker.requestId}] Workflow checkpoint created successfully`, {
|
||||
checkpointId: checkpoint.id,
|
||||
savedData: {
|
||||
checkpointId: checkpoint.id,
|
||||
@@ -98,8 +104,8 @@ export async function POST(req: NextRequest) {
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Failed to create workflow checkpoint:`, error)
|
||||
return NextResponse.json({ error: 'Failed to create checkpoint' }, { status: 500 })
|
||||
logger.error(`[${tracker.requestId}] Failed to create workflow checkpoint:`, error)
|
||||
return createInternalServerErrorResponse('Failed to create checkpoint')
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,23 +114,23 @@ export async function POST(req: NextRequest) {
|
||||
* Retrieve workflow checkpoints for a chat
|
||||
*/
|
||||
export async function GET(req: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const tracker = createRequestTracker()
|
||||
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
|
||||
if (!isAuthenticated || !userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const { searchParams } = new URL(req.url)
|
||||
const chatId = searchParams.get('chatId')
|
||||
|
||||
if (!chatId) {
|
||||
return NextResponse.json({ error: 'chatId is required' }, { status: 400 })
|
||||
return createBadRequestResponse('chatId is required')
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Fetching workflow checkpoints for chat`, {
|
||||
userId: session.user.id,
|
||||
logger.info(`[${tracker.requestId}] Fetching workflow checkpoints for chat`, {
|
||||
userId,
|
||||
chatId,
|
||||
})
|
||||
|
||||
@@ -140,19 +146,17 @@ export async function GET(req: NextRequest) {
|
||||
updatedAt: workflowCheckpoints.updatedAt,
|
||||
})
|
||||
.from(workflowCheckpoints)
|
||||
.where(
|
||||
and(eq(workflowCheckpoints.chatId, chatId), eq(workflowCheckpoints.userId, session.user.id))
|
||||
)
|
||||
.where(and(eq(workflowCheckpoints.chatId, chatId), eq(workflowCheckpoints.userId, userId)))
|
||||
.orderBy(desc(workflowCheckpoints.createdAt))
|
||||
|
||||
logger.info(`[${requestId}] Retrieved ${checkpoints.length} workflow checkpoints`)
|
||||
logger.info(`[${tracker.requestId}] Retrieved ${checkpoints.length} workflow checkpoints`)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
checkpoints,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Failed to fetch workflow checkpoints:`, error)
|
||||
return NextResponse.json({ error: 'Failed to fetch checkpoints' }, { status: 500 })
|
||||
logger.error(`[${tracker.requestId}] Failed to fetch workflow checkpoints:`, error)
|
||||
return createInternalServerErrorResponse('Failed to fetch checkpoints')
|
||||
}
|
||||
}
|
||||
|
||||
393
apps/sim/app/api/copilot/confirm/route.test.ts
Normal file
393
apps/sim/app/api/copilot/confirm/route.test.ts
Normal file
@@ -0,0 +1,393 @@
|
||||
/**
|
||||
* Tests for copilot confirm API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { NextRequest } from 'next/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockAuth,
|
||||
mockCryptoUuid,
|
||||
setupCommonApiMocks,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
|
||||
describe('Copilot Confirm API Route', () => {
|
||||
const mockRedisExists = vi.fn()
|
||||
const mockRedisSet = vi.fn()
|
||||
const mockGetRedisClient = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
setupCommonApiMocks()
|
||||
mockCryptoUuid()
|
||||
|
||||
const mockRedisClient = {
|
||||
exists: mockRedisExists,
|
||||
set: mockRedisSet,
|
||||
}
|
||||
|
||||
mockGetRedisClient.mockReturnValue(mockRedisClient)
|
||||
mockRedisExists.mockResolvedValue(1) // Tool call exists by default
|
||||
mockRedisSet.mockResolvedValue('OK')
|
||||
|
||||
vi.doMock('@/lib/redis', () => ({
|
||||
getRedisClient: mockGetRedisClient,
|
||||
}))
|
||||
|
||||
// Mock setTimeout to control polling behavior
|
||||
vi.spyOn(global, 'setTimeout').mockImplementation((callback, _delay) => {
|
||||
// Immediately call callback to avoid delays
|
||||
if (typeof callback === 'function') {
|
||||
setImmediate(callback)
|
||||
}
|
||||
return setTimeout(() => {}, 0) as any
|
||||
})
|
||||
|
||||
// Mock Date.now to control timeout behavior
|
||||
let mockTime = 1640995200000
|
||||
vi.spyOn(Date, 'now').mockImplementation(() => {
|
||||
// Increment time rapidly to trigger timeout for non-existent keys
|
||||
mockTime += 10000 // Add 10 seconds each call
|
||||
return mockTime
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('POST', () => {
|
||||
it('should return 401 when user is not authenticated', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setUnauthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-123',
|
||||
status: 'success',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({ error: 'Unauthorized' })
|
||||
})
|
||||
|
||||
it('should return 400 for invalid request body - missing toolCallId', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
status: 'success',
|
||||
// Missing toolCallId
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toContain('Required')
|
||||
})
|
||||
|
||||
it('should return 400 for invalid request body - missing status', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-123',
|
||||
// Missing status
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toContain('Invalid request data')
|
||||
})
|
||||
|
||||
it('should return 400 for invalid status value', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-123',
|
||||
status: 'invalid-status',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toContain('Invalid notification status')
|
||||
})
|
||||
|
||||
it('should successfully confirm tool call with success status', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-123',
|
||||
status: 'success',
|
||||
message: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
message: 'Tool executed successfully',
|
||||
toolCallId: 'tool-call-123',
|
||||
status: 'success',
|
||||
})
|
||||
|
||||
// Verify Redis operations were called
|
||||
expect(mockRedisExists).toHaveBeenCalled()
|
||||
expect(mockRedisSet).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should successfully confirm tool call with error status', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-456',
|
||||
status: 'error',
|
||||
message: 'Tool execution failed',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
message: 'Tool execution failed',
|
||||
toolCallId: 'tool-call-456',
|
||||
status: 'error',
|
||||
})
|
||||
|
||||
expect(mockRedisSet).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should successfully confirm tool call with accepted status', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-789',
|
||||
status: 'accepted',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
message: 'Tool call tool-call-789 has been accepted',
|
||||
toolCallId: 'tool-call-789',
|
||||
status: 'accepted',
|
||||
})
|
||||
|
||||
expect(mockRedisSet).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should successfully confirm tool call with rejected status', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-101',
|
||||
status: 'rejected',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
message: 'Tool call tool-call-101 has been rejected',
|
||||
toolCallId: 'tool-call-101',
|
||||
status: 'rejected',
|
||||
})
|
||||
})
|
||||
|
||||
it('should successfully confirm tool call with background status', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-bg',
|
||||
status: 'background',
|
||||
message: 'Moved to background execution',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
message: 'Moved to background execution',
|
||||
toolCallId: 'tool-call-bg',
|
||||
status: 'background',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return 400 when Redis client is not available', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock Redis client as unavailable
|
||||
mockGetRedisClient.mockReturnValue(null)
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-123',
|
||||
status: 'success',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update tool call status or tool call not found')
|
||||
})
|
||||
|
||||
it('should return 400 when tool call is not found in Redis', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock tool call as not existing in Redis
|
||||
mockRedisExists.mockResolvedValue(0)
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'non-existent-tool',
|
||||
status: 'success',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update tool call status or tool call not found')
|
||||
}, 10000) // 10 second timeout for this specific test
|
||||
|
||||
it('should handle Redis errors gracefully', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock Redis operations to throw an error
|
||||
mockRedisExists.mockRejectedValue(new Error('Redis connection failed'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-123',
|
||||
status: 'success',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update tool call status or tool call not found')
|
||||
})
|
||||
|
||||
it('should handle Redis set operation failure', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Tool call exists but set operation fails
|
||||
mockRedisExists.mockResolvedValue(1)
|
||||
mockRedisSet.mockRejectedValue(new Error('Redis set failed'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: 'tool-call-123',
|
||||
status: 'success',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toBe('Failed to update tool call status or tool call not found')
|
||||
})
|
||||
|
||||
it('should handle JSON parsing errors in request body', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Create a request with invalid JSON
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/confirm', {
|
||||
method: 'POST',
|
||||
body: '{invalid-json',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toContain('JSON')
|
||||
})
|
||||
|
||||
it('should validate empty toolCallId', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: '',
|
||||
status: 'success',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.error).toContain('Tool call ID is required')
|
||||
})
|
||||
|
||||
it('should handle all valid status types', async () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
const validStatuses = ['success', 'error', 'accepted', 'rejected', 'background']
|
||||
|
||||
for (const status of validStatuses) {
|
||||
const req = createMockRequest('POST', {
|
||||
toolCallId: `tool-call-${status}`,
|
||||
status,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/confirm/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(true)
|
||||
expect(responseData.status).toBe(status)
|
||||
expect(responseData.toolCallId).toBe(`tool-call-${status}`)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,9 +1,15 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createBadRequestResponse,
|
||||
createInternalServerErrorResponse,
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
type NotificationStatus,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getRedisClient } from '@/lib/redis'
|
||||
import type { NotificationStatus } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/copilot/lib/tools/types'
|
||||
|
||||
const logger = createLogger('CopilotConfirmAPI')
|
||||
|
||||
@@ -98,22 +104,21 @@ async function updateToolCallStatus(
|
||||
* Update tool call status (Accept/Reject)
|
||||
*/
|
||||
export async function POST(req: NextRequest) {
|
||||
const requestId = crypto.randomUUID()
|
||||
const startTime = Date.now()
|
||||
const tracker = createRequestTracker()
|
||||
|
||||
try {
|
||||
// Authenticate user (same pattern as copilot chat)
|
||||
const session = await getSession()
|
||||
const authenticatedUserId: string | null = session?.user?.id || null
|
||||
// Authenticate user using consolidated helper
|
||||
const { userId: authenticatedUserId, isAuthenticated } =
|
||||
await authenticateCopilotRequestSessionOnly()
|
||||
|
||||
if (!authenticatedUserId) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
if (!isAuthenticated) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
const { toolCallId, status, message } = ConfirmationSchema.parse(body)
|
||||
|
||||
logger.info(`[${requestId}] Tool call confirmation request`, {
|
||||
logger.info(`[${tracker.requestId}] Tool call confirmation request`, {
|
||||
userId: authenticatedUserId,
|
||||
toolCallId,
|
||||
status,
|
||||
@@ -124,21 +129,18 @@ export async function POST(req: NextRequest) {
|
||||
const updated = await updateToolCallStatus(toolCallId, status, message)
|
||||
|
||||
if (!updated) {
|
||||
logger.error(`[${requestId}] Failed to update tool call status`, {
|
||||
logger.error(`[${tracker.requestId}] Failed to update tool call status`, {
|
||||
userId: authenticatedUserId,
|
||||
toolCallId,
|
||||
status,
|
||||
internalStatus: status,
|
||||
message,
|
||||
})
|
||||
return NextResponse.json(
|
||||
{ success: false, error: 'Failed to update tool call status or tool call not found' },
|
||||
{ status: 400 }
|
||||
)
|
||||
return createBadRequestResponse('Failed to update tool call status or tool call not found')
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`[${requestId}] Tool call confirmation completed`, {
|
||||
const duration = tracker.getDuration()
|
||||
logger.info(`[${tracker.requestId}] Tool call confirmation completed`, {
|
||||
userId: authenticatedUserId,
|
||||
toolCallId,
|
||||
status,
|
||||
@@ -148,36 +150,31 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
message: message || `Tool call ${toolCallId} has been ${status.toLowerCase()}ed`,
|
||||
message: message || `Tool call ${toolCallId} has been ${status.toLowerCase()}`,
|
||||
toolCallId,
|
||||
status,
|
||||
})
|
||||
} catch (error) {
|
||||
const duration = Date.now() - startTime
|
||||
const duration = tracker.getDuration()
|
||||
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.error(`[${requestId}] Request validation error:`, {
|
||||
logger.error(`[${tracker.requestId}] Request validation error:`, {
|
||||
duration,
|
||||
errors: error.errors,
|
||||
})
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: `Invalid request data: ${error.errors.map((e) => e.message).join(', ')}`,
|
||||
},
|
||||
{ status: 400 }
|
||||
return createBadRequestResponse(
|
||||
`Invalid request data: ${error.errors.map((e) => e.message).join(', ')}`
|
||||
)
|
||||
}
|
||||
|
||||
logger.error(`[${requestId}] Unexpected error:`, {
|
||||
logger.error(`[${tracker.requestId}] Unexpected error:`, {
|
||||
duration,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
stack: error instanceof Error ? error.stack : undefined,
|
||||
})
|
||||
|
||||
return NextResponse.json(
|
||||
{ success: false, error: error instanceof Error ? error.message : 'Internal server error' },
|
||||
{ status: 500 }
|
||||
return createInternalServerErrorResponse(
|
||||
error instanceof Error ? error.message : 'Internal server error'
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
751
apps/sim/app/api/copilot/methods/route.test.ts
Normal file
751
apps/sim/app/api/copilot/methods/route.test.ts
Normal file
@@ -0,0 +1,751 @@
|
||||
/**
|
||||
* Tests for copilot methods API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { NextRequest } from 'next/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockCryptoUuid,
|
||||
setupCommonApiMocks,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
|
||||
describe('Copilot Methods API Route', () => {
|
||||
const mockRedisGet = vi.fn()
|
||||
const mockRedisSet = vi.fn()
|
||||
const mockGetRedisClient = vi.fn()
|
||||
const mockToolRegistryHas = vi.fn()
|
||||
const mockToolRegistryGet = vi.fn()
|
||||
const mockToolRegistryExecute = vi.fn()
|
||||
const mockToolRegistryGetAvailableIds = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
setupCommonApiMocks()
|
||||
mockCryptoUuid()
|
||||
|
||||
// Mock Redis client
|
||||
const mockRedisClient = {
|
||||
get: mockRedisGet,
|
||||
set: mockRedisSet,
|
||||
}
|
||||
|
||||
mockGetRedisClient.mockReturnValue(mockRedisClient)
|
||||
mockRedisGet.mockResolvedValue(null)
|
||||
mockRedisSet.mockResolvedValue('OK')
|
||||
|
||||
vi.doMock('@/lib/redis', () => ({
|
||||
getRedisClient: mockGetRedisClient,
|
||||
}))
|
||||
|
||||
// Mock tool registry
|
||||
const mockToolRegistry = {
|
||||
has: mockToolRegistryHas,
|
||||
get: mockToolRegistryGet,
|
||||
execute: mockToolRegistryExecute,
|
||||
getAvailableIds: mockToolRegistryGetAvailableIds,
|
||||
}
|
||||
|
||||
mockToolRegistryHas.mockReturnValue(true)
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: false })
|
||||
mockToolRegistryExecute.mockResolvedValue({ success: true, data: 'Tool executed successfully' })
|
||||
mockToolRegistryGetAvailableIds.mockReturnValue(['test-tool', 'another-tool'])
|
||||
|
||||
vi.doMock('@/lib/copilot/tools/server-tools/registry', () => ({
|
||||
copilotToolRegistry: mockToolRegistry,
|
||||
}))
|
||||
|
||||
// Mock environment variables
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
INTERNAL_API_SECRET: 'test-secret-key',
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock setTimeout for polling
|
||||
vi.spyOn(global, 'setTimeout').mockImplementation((callback, _delay) => {
|
||||
if (typeof callback === 'function') {
|
||||
setImmediate(callback)
|
||||
}
|
||||
return setTimeout(() => {}, 0) as any
|
||||
})
|
||||
|
||||
// Mock Date.now for timeout control
|
||||
let mockTime = 1640995200000
|
||||
vi.spyOn(Date, 'now').mockImplementation(() => {
|
||||
mockTime += 1000 // Add 1 second each call
|
||||
return mockTime
|
||||
})
|
||||
|
||||
// Mock crypto.randomUUID for request IDs
|
||||
vi.spyOn(crypto, 'randomUUID').mockReturnValue('test-request-id')
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('POST', () => {
|
||||
it('should return 401 when API key is missing', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
methodId: 'test-tool',
|
||||
params: {},
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: false,
|
||||
error: 'API key required',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return 401 when API key is invalid', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'invalid-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'test-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: false,
|
||||
error: 'Invalid API key',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return 401 when internal API key is not configured', async () => {
|
||||
// Mock environment with no API key
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
INTERNAL_API_SECRET: undefined,
|
||||
},
|
||||
}))
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'any-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'test-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: false,
|
||||
error: 'Internal API key not configured',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return 400 for invalid request body - missing methodId', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
params: {},
|
||||
// Missing methodId
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toContain('Required')
|
||||
})
|
||||
|
||||
it('should return 400 for empty methodId', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: '',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toContain('Method ID is required')
|
||||
})
|
||||
|
||||
it('should return 400 when tool is not found in registry', async () => {
|
||||
mockToolRegistryHas.mockReturnValue(false)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'unknown-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toContain('Unknown method: unknown-tool')
|
||||
expect(responseData.error).toContain('Available methods: test-tool, another-tool')
|
||||
})
|
||||
|
||||
it('should successfully execute a tool without interruption', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'test-tool',
|
||||
params: { key: 'value' },
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalledWith('test-tool', { key: 'value' })
|
||||
})
|
||||
|
||||
it('should handle tool execution with default empty params', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'test-tool',
|
||||
// No params provided
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalledWith('test-tool', {})
|
||||
})
|
||||
|
||||
it('should return 400 when tool requires interrupt but no toolCallId provided', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
// No toolCallId provided
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe(
|
||||
'This tool requires approval but no tool call ID was provided'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - user approval', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return accepted status immediately (simulate quick approval)
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'accepted', message: 'User approved' })
|
||||
)
|
||||
|
||||
// Reset Date.now mock to not trigger timeout
|
||||
let mockTime = 1640995200000
|
||||
vi.spyOn(Date, 'now').mockImplementation(() => {
|
||||
mockTime += 100 // Small increment to avoid timeout
|
||||
return mockTime
|
||||
})
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: { key: 'value' },
|
||||
toolCallId: 'tool-call-123',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
// Verify Redis operations
|
||||
expect(mockRedisSet).toHaveBeenCalledWith(
|
||||
'tool_call:tool-call-123',
|
||||
expect.stringContaining('"status":"pending"'),
|
||||
'EX',
|
||||
86400
|
||||
)
|
||||
expect(mockRedisGet).toHaveBeenCalledWith('tool_call:tool-call-123')
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalledWith('interrupt-tool', { key: 'value' })
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - user rejection', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return rejected status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'rejected', message: 'User rejected' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-456',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200) // User rejection returns 200
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe(
|
||||
'The user decided to skip running this tool. This was a user decision.'
|
||||
)
|
||||
|
||||
// Tool should not be executed when rejected
|
||||
expect(mockToolRegistryExecute).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - error status', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return error status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'error', message: 'Tool execution failed' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-error',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Tool execution failed')
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - background status', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return background status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'background', message: 'Running in background' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-bg',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - success status', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return success status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'success', message: 'Completed successfully' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-success',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - timeout', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to never return a status (timeout scenario)
|
||||
mockRedisGet.mockResolvedValue(null)
|
||||
|
||||
// Mock Date.now to trigger timeout quickly
|
||||
let mockTime = 1640995200000
|
||||
vi.spyOn(Date, 'now').mockImplementation(() => {
|
||||
mockTime += 100000 // Add 100 seconds each call to trigger timeout
|
||||
return mockTime
|
||||
})
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-timeout',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(408) // Request Timeout
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Tool execution request timed out')
|
||||
|
||||
expect(mockToolRegistryExecute).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle unexpected status in interrupt flow', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return unexpected status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'unknown-status', message: 'Unknown' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-unknown',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Unexpected tool call status: unknown-status')
|
||||
})
|
||||
|
||||
it('should handle Redis client unavailable for interrupt flow', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
mockGetRedisClient.mockReturnValue(null)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-no-redis',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(408) // Timeout due to Redis unavailable
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Tool execution request timed out')
|
||||
})
|
||||
|
||||
it('should handle no_op tool with confirmation message', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return accepted status with message
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'accepted', message: 'Confirmation message' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'no_op',
|
||||
params: { existing: 'param' },
|
||||
toolCallId: 'tool-call-noop',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
|
||||
// Verify confirmation message was added to params
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalledWith('no_op', {
|
||||
existing: 'param',
|
||||
confirmationMessage: 'Confirmation message',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle Redis errors in interrupt flow', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to throw an error
|
||||
mockRedisGet.mockRejectedValue(new Error('Redis connection failed'))
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-redis-error',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(408) // Timeout due to Redis error
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Tool execution request timed out')
|
||||
})
|
||||
|
||||
it('should handle tool execution failure', async () => {
|
||||
mockToolRegistryExecute.mockResolvedValue({
|
||||
success: false,
|
||||
error: 'Tool execution failed',
|
||||
})
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'failing-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200) // Still returns 200, but with success: false
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: false,
|
||||
error: 'Tool execution failed',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle JSON parsing errors in request body', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: '{invalid-json',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toContain('JSON')
|
||||
})
|
||||
|
||||
it('should handle tool registry execution throwing an error', async () => {
|
||||
mockToolRegistryExecute.mockRejectedValue(new Error('Registry execution failed'))
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'error-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Registry execution failed')
|
||||
})
|
||||
|
||||
it('should handle old format Redis status (string instead of JSON)', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return old format (direct status string)
|
||||
mockRedisGet.mockResolvedValue('accepted')
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-old-format',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,16 +1,14 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { copilotToolRegistry } from '@/lib/copilot/tools/server-tools/registry'
|
||||
import type { NotificationStatus } from '@/lib/copilot/types'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getRedisClient } from '@/lib/redis'
|
||||
import { copilotToolRegistry } from '../tools/registry'
|
||||
import { createErrorResponse } from './utils'
|
||||
import { createErrorResponse } from '@/app/api/copilot/methods/utils'
|
||||
|
||||
const logger = createLogger('CopilotMethodsAPI')
|
||||
|
||||
// Tool call status types - should match NotificationStatus from frontend
|
||||
type ToolCallStatus = 'pending' | 'accepted' | 'rejected' | 'error' | 'success' | 'background'
|
||||
|
||||
/**
|
||||
* Add a tool call to Redis with 'pending' status
|
||||
*/
|
||||
@@ -28,7 +26,7 @@ async function addToolToRedis(toolCallId: string): Promise<void> {
|
||||
|
||||
try {
|
||||
const key = `tool_call:${toolCallId}`
|
||||
const status: ToolCallStatus = 'pending'
|
||||
const status: NotificationStatus = 'pending'
|
||||
|
||||
// Store as JSON object for consistency with confirm API
|
||||
const toolCallData = {
|
||||
@@ -59,7 +57,7 @@ async function addToolToRedis(toolCallId: string): Promise<void> {
|
||||
*/
|
||||
async function pollRedisForTool(
|
||||
toolCallId: string
|
||||
): Promise<{ status: ToolCallStatus; message?: string } | null> {
|
||||
): Promise<{ status: NotificationStatus; message?: string } | null> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) {
|
||||
logger.warn('pollRedisForTool: Redis client not available')
|
||||
@@ -86,17 +84,17 @@ async function pollRedisForTool(
|
||||
continue
|
||||
}
|
||||
|
||||
let status: ToolCallStatus | null = null
|
||||
let status: NotificationStatus | null = null
|
||||
let message: string | undefined
|
||||
|
||||
// Try to parse as JSON (new format), fallback to string (old format)
|
||||
try {
|
||||
const parsedData = JSON.parse(redisValue)
|
||||
status = parsedData.status as ToolCallStatus
|
||||
status = parsedData.status as NotificationStatus
|
||||
message = parsedData.message || undefined
|
||||
} catch {
|
||||
// Fallback to old format (direct status string)
|
||||
status = redisValue as ToolCallStatus
|
||||
status = redisValue as NotificationStatus
|
||||
}
|
||||
|
||||
if (status !== 'pending') {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { CopilotToolResponse } from '@/lib/copilot/tools/server-tools/base'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { CopilotToolResponse } from '../tools/base'
|
||||
|
||||
const logger = createLogger('CopilotMethodsUtils')
|
||||
|
||||
@@ -12,13 +12,3 @@ export function createErrorResponse(error: string): CopilotToolResponse {
|
||||
error,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a standardized success response
|
||||
*/
|
||||
export function createSuccessResponse(data: any): CopilotToolResponse {
|
||||
return {
|
||||
success: true,
|
||||
data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,17 +6,16 @@ import ReactMarkdown from 'react-markdown'
|
||||
import remarkGfm from 'remark-gfm'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'
|
||||
import { COPILOT_TOOL_DISPLAY_NAMES } from '@/stores/constants'
|
||||
import { InlineToolCall } from '@/lib/copilot/tools/inline-tool-call'
|
||||
import { useCopilotStore } from '@/stores/copilot/store'
|
||||
import type { CopilotMessage } from '@/stores/copilot/types'
|
||||
import { InlineToolCall } from '../../lib/tools/inline-tool-call'
|
||||
import type { CopilotMessage as CopilotMessageType } from '@/stores/copilot/types'
|
||||
|
||||
interface ProfessionalMessageProps {
|
||||
message: CopilotMessage
|
||||
interface CopilotMessageProps {
|
||||
message: CopilotMessageType
|
||||
isStreaming?: boolean
|
||||
}
|
||||
|
||||
// Link component with preview (from CopilotMarkdownRenderer)
|
||||
// Link component with preview
|
||||
function LinkWithPreview({ href, children }: { href: string; children: React.ReactNode }) {
|
||||
return (
|
||||
<Tooltip delayDuration={300}>
|
||||
@@ -201,12 +200,7 @@ const WordWrap = ({ text }: { text: string }) => {
|
||||
)
|
||||
}
|
||||
|
||||
// Helper function to get tool display name based on state
|
||||
function getToolDisplayName(toolName: string): string {
|
||||
return COPILOT_TOOL_DISPLAY_NAMES[toolName] || toolName
|
||||
}
|
||||
|
||||
const ProfessionalMessage: FC<ProfessionalMessageProps> = memo(
|
||||
const CopilotMessage: FC<CopilotMessageProps> = memo(
|
||||
({ message, isStreaming }) => {
|
||||
const isUser = message.role === 'user'
|
||||
const isAssistant = message.role === 'assistant'
|
||||
@@ -293,10 +287,6 @@ const ProfessionalMessage: FC<ProfessionalMessageProps> = memo(
|
||||
}
|
||||
}, [showDownvoteSuccess])
|
||||
|
||||
const formatTimestamp = (timestamp: string) => {
|
||||
return new Date(timestamp).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
|
||||
}
|
||||
|
||||
// Get clean text content with double newline parsing
|
||||
const cleanTextContent = useMemo(() => {
|
||||
if (!message.content) return ''
|
||||
@@ -357,7 +347,6 @@ const ProfessionalMessage: FC<ProfessionalMessageProps> = memo(
|
||||
li: ({
|
||||
children,
|
||||
ordered,
|
||||
...props
|
||||
}: React.LiHTMLAttributes<HTMLLIElement> & { ordered?: boolean }) => (
|
||||
<li
|
||||
className='font-geist-sans text-gray-800 dark:text-gray-200'
|
||||
@@ -369,7 +358,6 @@ const ProfessionalMessage: FC<ProfessionalMessageProps> = memo(
|
||||
|
||||
// Code blocks
|
||||
pre: ({ children }: React.HTMLAttributes<HTMLPreElement>) => {
|
||||
let codeProps: React.HTMLAttributes<HTMLElement> = {}
|
||||
let codeContent: React.ReactNode = children
|
||||
let language = 'code'
|
||||
|
||||
@@ -381,7 +369,6 @@ const ProfessionalMessage: FC<ProfessionalMessageProps> = memo(
|
||||
className?: string
|
||||
children?: React.ReactNode
|
||||
}>
|
||||
codeProps = { className: childElement.props.className }
|
||||
codeContent = childElement.props.children
|
||||
language = childElement.props.className?.replace('language-', '') || 'code'
|
||||
}
|
||||
@@ -789,6 +776,6 @@ const ProfessionalMessage: FC<ProfessionalMessageProps> = memo(
|
||||
}
|
||||
)
|
||||
|
||||
ProfessionalMessage.displayName = 'ProfessionalMessage'
|
||||
CopilotMessage.displayName = 'CopilotMessage'
|
||||
|
||||
export { ProfessionalMessage }
|
||||
export { CopilotMessage }
|
||||
@@ -0,0 +1,4 @@
|
||||
export * from './checkpoint-panel/checkpoint-panel'
|
||||
export * from './copilot-message/copilot-message'
|
||||
export * from './user-input/user-input'
|
||||
export * from './welcome/welcome'
|
||||
@@ -1,234 +0,0 @@
|
||||
import React, { type HTMLAttributes, type ReactNode } from 'react'
|
||||
import { Copy } from 'lucide-react'
|
||||
import ReactMarkdown from 'react-markdown'
|
||||
import remarkGfm from 'remark-gfm'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'
|
||||
|
||||
export function LinkWithPreview({ href, children }: { href: string; children: React.ReactNode }) {
|
||||
return (
|
||||
<Tooltip delayDuration={300}>
|
||||
<TooltipTrigger asChild>
|
||||
<a
|
||||
href={href}
|
||||
className='text-blue-600 hover:underline dark:text-blue-400'
|
||||
target='_blank'
|
||||
rel='noopener noreferrer'
|
||||
>
|
||||
{children}
|
||||
</a>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side='top' align='center' sideOffset={5} className='max-w-sm p-3'>
|
||||
<span className='truncate font-medium text-xs'>{href}</span>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
|
||||
export default function CopilotMarkdownRenderer({
|
||||
content,
|
||||
customLinkComponent,
|
||||
}: {
|
||||
content: string
|
||||
customLinkComponent?: typeof LinkWithPreview
|
||||
}) {
|
||||
const LinkComponent = customLinkComponent || LinkWithPreview
|
||||
|
||||
const customComponents = {
|
||||
// Paragraph
|
||||
p: ({ children }: React.HTMLAttributes<HTMLParagraphElement>) => (
|
||||
<p className='mb-1 font-geist-sans text-base text-gray-800 leading-relaxed last:mb-0 dark:text-gray-200'>
|
||||
{children}
|
||||
</p>
|
||||
),
|
||||
|
||||
// Headings
|
||||
h1: ({ children }: React.HTMLAttributes<HTMLHeadingElement>) => (
|
||||
<h1 className='mt-10 mb-5 font-geist-sans font-semibold text-2xl text-gray-900 dark:text-gray-100'>
|
||||
{children}
|
||||
</h1>
|
||||
),
|
||||
h2: ({ children }: React.HTMLAttributes<HTMLHeadingElement>) => (
|
||||
<h2 className='mt-8 mb-4 font-geist-sans font-semibold text-gray-900 text-xl dark:text-gray-100'>
|
||||
{children}
|
||||
</h2>
|
||||
),
|
||||
h3: ({ children }: React.HTMLAttributes<HTMLHeadingElement>) => (
|
||||
<h3 className='mt-7 mb-3 font-geist-sans font-semibold text-gray-900 text-lg dark:text-gray-100'>
|
||||
{children}
|
||||
</h3>
|
||||
),
|
||||
h4: ({ children }: React.HTMLAttributes<HTMLHeadingElement>) => (
|
||||
<h4 className='mt-5 mb-2 font-geist-sans font-semibold text-base text-gray-900 dark:text-gray-100'>
|
||||
{children}
|
||||
</h4>
|
||||
),
|
||||
|
||||
// Lists
|
||||
ul: ({ children }: React.HTMLAttributes<HTMLUListElement>) => (
|
||||
<ul
|
||||
className='mt-1 mb-1 space-y-1 pl-6 font-geist-sans text-gray-800 dark:text-gray-200'
|
||||
style={{ listStyleType: 'disc' }}
|
||||
>
|
||||
{children}
|
||||
</ul>
|
||||
),
|
||||
ol: ({ children }: React.HTMLAttributes<HTMLOListElement>) => (
|
||||
<ol
|
||||
className='mt-1 mb-1 space-y-1 pl-6 font-geist-sans text-gray-800 dark:text-gray-200'
|
||||
style={{ listStyleType: 'decimal' }}
|
||||
>
|
||||
{children}
|
||||
</ol>
|
||||
),
|
||||
li: ({
|
||||
children,
|
||||
ordered,
|
||||
...props
|
||||
}: React.LiHTMLAttributes<HTMLLIElement> & { ordered?: boolean }) => (
|
||||
<li
|
||||
className='font-geist-sans text-gray-800 dark:text-gray-200'
|
||||
style={{ display: 'list-item' }}
|
||||
>
|
||||
{children}
|
||||
</li>
|
||||
),
|
||||
|
||||
// Code blocks
|
||||
pre: ({ children }: HTMLAttributes<HTMLPreElement>) => {
|
||||
let codeProps: HTMLAttributes<HTMLElement> = {}
|
||||
let codeContent: ReactNode = children
|
||||
let language = 'code'
|
||||
|
||||
if (
|
||||
React.isValidElement<{ className?: string; children?: ReactNode }>(children) &&
|
||||
children.type === 'code'
|
||||
) {
|
||||
const childElement = children as React.ReactElement<{
|
||||
className?: string
|
||||
children?: ReactNode
|
||||
}>
|
||||
codeProps = { className: childElement.props.className }
|
||||
codeContent = childElement.props.children
|
||||
language = childElement.props.className?.replace('language-', '') || 'code'
|
||||
}
|
||||
|
||||
return (
|
||||
<div className='my-6 rounded-md bg-gray-900 text-sm dark:bg-black'>
|
||||
<div className='flex items-center justify-between border-gray-700 border-b px-4 py-1.5 dark:border-gray-800'>
|
||||
<span className='font-geist-sans text-gray-400 text-xs'>{language}</span>
|
||||
<Button
|
||||
variant='ghost'
|
||||
size='sm'
|
||||
className='h-4 w-4 p-0 opacity-70 hover:opacity-100'
|
||||
onClick={() => {
|
||||
if (typeof codeContent === 'string') {
|
||||
navigator.clipboard.writeText(codeContent)
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Copy className='h-3 w-3 text-gray-400' />
|
||||
</Button>
|
||||
</div>
|
||||
<pre className='overflow-x-auto p-4 font-mono text-gray-200 dark:text-gray-100'>
|
||||
{codeContent}
|
||||
</pre>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
|
||||
// Inline code
|
||||
code: ({
|
||||
inline,
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: React.HTMLAttributes<HTMLElement> & { className?: string; inline?: boolean }) => {
|
||||
if (inline) {
|
||||
return (
|
||||
<code
|
||||
className='rounded bg-gray-200 px-1 py-0.5 font-mono text-[0.9em] text-gray-800 dark:bg-gray-700 dark:text-gray-200'
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</code>
|
||||
)
|
||||
}
|
||||
return (
|
||||
<code className={className} {...props}>
|
||||
{children}
|
||||
</code>
|
||||
)
|
||||
},
|
||||
|
||||
// Blockquotes
|
||||
blockquote: ({ children }: React.HTMLAttributes<HTMLQuoteElement>) => (
|
||||
<blockquote className='my-4 border-gray-300 border-l-4 py-1 pl-4 font-geist-sans text-gray-700 italic dark:border-gray-600 dark:text-gray-300'>
|
||||
{children}
|
||||
</blockquote>
|
||||
),
|
||||
|
||||
// Horizontal rule
|
||||
hr: () => <hr className='my-8 border-gray-500/[.07] border-t dark:border-gray-400/[.07]' />,
|
||||
|
||||
// Links
|
||||
a: ({ href, children, ...props }: React.AnchorHTMLAttributes<HTMLAnchorElement>) => (
|
||||
<LinkComponent href={href || '#'} {...props}>
|
||||
{children}
|
||||
</LinkComponent>
|
||||
),
|
||||
|
||||
// Tables
|
||||
table: ({ children }: React.TableHTMLAttributes<HTMLTableElement>) => (
|
||||
<div className='my-4 w-full overflow-x-auto'>
|
||||
<table className='min-w-full table-auto border border-gray-300 font-geist-sans text-sm dark:border-gray-700'>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
),
|
||||
thead: ({ children }: React.HTMLAttributes<HTMLTableSectionElement>) => (
|
||||
<thead className='bg-gray-100 text-left dark:bg-gray-800'>{children}</thead>
|
||||
),
|
||||
tbody: ({ children }: React.HTMLAttributes<HTMLTableSectionElement>) => (
|
||||
<tbody className='divide-y divide-gray-200 bg-white dark:divide-gray-700 dark:bg-gray-900'>
|
||||
{children}
|
||||
</tbody>
|
||||
),
|
||||
tr: ({ children }: React.HTMLAttributes<HTMLTableRowElement>) => (
|
||||
<tr className='border-gray-200 border-b transition-colors hover:bg-gray-50 dark:border-gray-700 dark:hover:bg-gray-800/60'>
|
||||
{children}
|
||||
</tr>
|
||||
),
|
||||
th: ({ children }: React.ThHTMLAttributes<HTMLTableCellElement>) => (
|
||||
<th className='border-gray-300 border-r px-4 py-2 font-medium text-gray-700 last:border-r-0 dark:border-gray-700 dark:text-gray-300'>
|
||||
{children}
|
||||
</th>
|
||||
),
|
||||
td: ({ children }: React.TdHTMLAttributes<HTMLTableCellElement>) => (
|
||||
<td className='break-words border-gray-300 border-r px-4 py-2 text-gray-800 last:border-r-0 dark:border-gray-700 dark:text-gray-200'>
|
||||
{children}
|
||||
</td>
|
||||
),
|
||||
|
||||
// Images
|
||||
img: ({ src, alt, ...props }: React.ImgHTMLAttributes<HTMLImageElement>) => (
|
||||
<img
|
||||
src={src}
|
||||
alt={alt || 'Image'}
|
||||
className='my-3 h-auto max-w-full rounded-md'
|
||||
{...props}
|
||||
/>
|
||||
),
|
||||
}
|
||||
|
||||
// Pre-process content to fix common issues
|
||||
const processedContent = content.trim()
|
||||
|
||||
return (
|
||||
<div className='space-y-4 break-words font-geist-sans text-[#0D0D0D] text-base leading-relaxed dark:text-gray-100'>
|
||||
<ReactMarkdown remarkPlugins={[remarkGfm]} components={customComponents}>
|
||||
{processedContent}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import { Button } from '@/components/ui/button'
|
||||
import { Textarea } from '@/components/ui/textarea'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
interface ProfessionalInputProps {
|
||||
interface UserInputProps {
|
||||
onSubmit: (message: string) => void
|
||||
onAbort?: () => void
|
||||
disabled?: boolean
|
||||
@@ -20,7 +20,7 @@ interface ProfessionalInputProps {
|
||||
onChange?: (value: string) => void // Callback when value changes
|
||||
}
|
||||
|
||||
const ProfessionalInput: FC<ProfessionalInputProps> = ({
|
||||
const UserInput: FC<UserInputProps> = ({
|
||||
onSubmit,
|
||||
onAbort,
|
||||
disabled = false,
|
||||
@@ -167,4 +167,4 @@ const ProfessionalInput: FC<ProfessionalInputProps> = ({
|
||||
)
|
||||
}
|
||||
|
||||
export { ProfessionalInput }
|
||||
export { UserInput }
|
||||
@@ -4,14 +4,16 @@ import { forwardRef, useCallback, useEffect, useImperativeHandle, useRef, useSta
|
||||
import { LoadingAgent } from '@/components/ui/loading-agent'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import {
|
||||
CheckpointPanel,
|
||||
CopilotMessage,
|
||||
CopilotWelcome,
|
||||
UserInput,
|
||||
} from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/copilot/components'
|
||||
import { COPILOT_TOOL_IDS } from '@/stores/copilot/constants'
|
||||
import { usePreviewStore } from '@/stores/copilot/preview-store'
|
||||
import { useCopilotStore } from '@/stores/copilot/store'
|
||||
import { useWorkflowRegistry } from '@/stores/workflows/registry/store'
|
||||
import { CheckpointPanel } from './components/checkpoint-panel'
|
||||
import { ProfessionalInput } from './components/professional-input/professional-input'
|
||||
import { ProfessionalMessage } from './components/professional-message/professional-message'
|
||||
import { CopilotWelcome } from './components/welcome/welcome'
|
||||
|
||||
const logger = createLogger('Copilot')
|
||||
|
||||
@@ -25,8 +27,7 @@ interface CopilotRef {
|
||||
|
||||
export const Copilot = forwardRef<CopilotRef, CopilotProps>(({ panelWidth }, ref) => {
|
||||
const scrollAreaRef = useRef<HTMLDivElement>(null)
|
||||
const [showCheckpoints, setShowCheckpoints] = useState(false)
|
||||
const scannedChatRef = useRef<string | null>(null)
|
||||
const [showCheckpoints] = useState(false)
|
||||
const [isInitialized, setIsInitialized] = useState(false)
|
||||
const lastWorkflowIdRef = useRef<string | null>(null)
|
||||
const hasMountedRef = useRef(false)
|
||||
@@ -34,12 +35,11 @@ export const Copilot = forwardRef<CopilotRef, CopilotProps>(({ panelWidth }, ref
|
||||
const { activeWorkflowId } = useWorkflowRegistry()
|
||||
|
||||
// Use preview store to track seen previews
|
||||
const { scanAndMarkExistingPreviews, isToolCallSeen, markToolCallAsSeen } = usePreviewStore()
|
||||
const { isToolCallSeen, markToolCallAsSeen } = usePreviewStore()
|
||||
|
||||
// Use the new copilot store
|
||||
const {
|
||||
messages,
|
||||
isLoading,
|
||||
isLoadingChats,
|
||||
isSendingMessage,
|
||||
isAborting,
|
||||
@@ -48,7 +48,6 @@ export const Copilot = forwardRef<CopilotRef, CopilotProps>(({ panelWidth }, ref
|
||||
sendMessage,
|
||||
abortMessage,
|
||||
createNewChat,
|
||||
clearMessages,
|
||||
setMode,
|
||||
setInputValue,
|
||||
chatsLoadedForWorkflow,
|
||||
@@ -187,14 +186,6 @@ export const Copilot = forwardRef<CopilotRef, CopilotProps>(({ panelWidth }, ref
|
||||
[isSendingMessage, activeWorkflowId, sendMessage]
|
||||
)
|
||||
|
||||
// Handle modal message sending
|
||||
const handleModalSendMessage = useCallback(
|
||||
async (message: string) => {
|
||||
await handleSubmit(message)
|
||||
},
|
||||
[handleSubmit]
|
||||
)
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className='flex h-full flex-col overflow-hidden'>
|
||||
@@ -224,7 +215,7 @@ export const Copilot = forwardRef<CopilotRef, CopilotProps>(({ panelWidth }, ref
|
||||
</div>
|
||||
) : (
|
||||
messages.map((message) => (
|
||||
<ProfessionalMessage
|
||||
<CopilotMessage
|
||||
key={message.id}
|
||||
message={message}
|
||||
isStreaming={
|
||||
@@ -239,7 +230,7 @@ export const Copilot = forwardRef<CopilotRef, CopilotProps>(({ panelWidth }, ref
|
||||
|
||||
{/* Input area with integrated mode selector */}
|
||||
{!showCheckpoints && (
|
||||
<ProfessionalInput
|
||||
<UserInput
|
||||
onSubmit={handleSubmit}
|
||||
onAbort={abortMessage}
|
||||
disabled={!activeWorkflowId}
|
||||
|
||||
@@ -5,8 +5,8 @@ import { CheckCircle, ChevronDown, ChevronRight, Loader2, Settings, XCircle } fr
|
||||
import { Badge } from '@/components/ui/badge'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from '@/components/ui/collapsible'
|
||||
import type { ToolCallGroup, ToolCallState } from '@/lib/copilot/types'
|
||||
import { cn } from '@/lib/utils'
|
||||
import type { ToolCallGroup, ToolCallState } from '@/types/tool-call'
|
||||
|
||||
interface ToolCallProps {
|
||||
toolCall: ToolCallState
|
||||
|
||||
110
apps/sim/lib/copilot/auth.ts
Normal file
110
apps/sim/lib/copilot/auth.ts
Normal file
@@ -0,0 +1,110 @@
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { db } from '@/db'
|
||||
import { apiKey as apiKeyTable } from '@/db/schema'
|
||||
|
||||
export type { NotificationStatus } from '@/lib/copilot/types'
|
||||
|
||||
/**
|
||||
* Authentication result for copilot API routes
|
||||
*/
|
||||
export interface CopilotAuthResult {
|
||||
userId: string | null
|
||||
isAuthenticated: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Standard error response helpers for copilot API routes
|
||||
*/
|
||||
export function createUnauthorizedResponse(): NextResponse {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
export function createBadRequestResponse(message: string): NextResponse {
|
||||
return NextResponse.json({ error: message }, { status: 400 })
|
||||
}
|
||||
|
||||
export function createNotFoundResponse(message: string): NextResponse {
|
||||
return NextResponse.json({ error: message }, { status: 404 })
|
||||
}
|
||||
|
||||
export function createInternalServerErrorResponse(message: string): NextResponse {
|
||||
return NextResponse.json({ error: message }, { status: 500 })
|
||||
}
|
||||
|
||||
/**
|
||||
* Request tracking helpers for copilot API routes
|
||||
*/
|
||||
export function createRequestId(): string {
|
||||
return crypto.randomUUID()
|
||||
}
|
||||
|
||||
export function createShortRequestId(): string {
|
||||
return crypto.randomUUID().slice(0, 8)
|
||||
}
|
||||
|
||||
export interface RequestTracker {
|
||||
requestId: string
|
||||
startTime: number
|
||||
getDuration(): number
|
||||
}
|
||||
|
||||
export function createRequestTracker(short = true): RequestTracker {
|
||||
const requestId = short ? createShortRequestId() : createRequestId()
|
||||
const startTime = Date.now()
|
||||
|
||||
return {
|
||||
requestId,
|
||||
startTime,
|
||||
getDuration(): number {
|
||||
return Date.now() - startTime
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Authenticate request using session or API key fallback
|
||||
* Returns userId if authenticated, null otherwise
|
||||
*/
|
||||
export async function authenticateCopilotRequest(req: NextRequest): Promise<CopilotAuthResult> {
|
||||
// Try session authentication first
|
||||
const session = await getSession()
|
||||
let userId: string | null = session?.user?.id || null
|
||||
|
||||
// If no session, check for API key auth
|
||||
if (!userId) {
|
||||
const apiKeyHeader = req.headers.get('x-api-key')
|
||||
if (apiKeyHeader) {
|
||||
// Verify API key
|
||||
const [apiKeyRecord] = await db
|
||||
.select({ userId: apiKeyTable.userId })
|
||||
.from(apiKeyTable)
|
||||
.where(eq(apiKeyTable.key, apiKeyHeader))
|
||||
.limit(1)
|
||||
|
||||
if (apiKeyRecord) {
|
||||
userId = apiKeyRecord.userId
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
userId,
|
||||
isAuthenticated: userId !== null,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Authenticate request using session only (no API key fallback)
|
||||
* Returns userId if authenticated, null otherwise
|
||||
*/
|
||||
export async function authenticateCopilotRequestSessionOnly(): Promise<CopilotAuthResult> {
|
||||
const session = await getSession()
|
||||
const userId = session?.user?.id || null
|
||||
|
||||
return {
|
||||
userId,
|
||||
isAuthenticated: userId !== null,
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import type {
|
||||
ToolExecutionOptions,
|
||||
ToolMetadata,
|
||||
ToolState,
|
||||
} from './types'
|
||||
} from '@/lib/copilot/tools/types'
|
||||
|
||||
export abstract class BaseTool implements Tool {
|
||||
// Static property for tool ID - must be overridden by each tool
|
||||
@@ -2,16 +2,16 @@
|
||||
* Run Workflow Tool
|
||||
*/
|
||||
|
||||
import { executeWorkflowWithFullLogging } from '@/app/workspace/[workspaceId]/w/[workflowId]/lib/workflow-execution-utils'
|
||||
import { useExecutionStore } from '@/stores/execution/store'
|
||||
import { useWorkflowRegistry } from '@/stores/workflows/registry/store'
|
||||
import { BaseTool } from '../base-tool'
|
||||
import { BaseTool } from '@/lib/copilot/tools/base-tool'
|
||||
import type {
|
||||
CopilotToolCall,
|
||||
ToolExecuteResult,
|
||||
ToolExecutionOptions,
|
||||
ToolMetadata,
|
||||
} from '../types'
|
||||
} from '@/lib/copilot/tools/types'
|
||||
import { executeWorkflowWithFullLogging } from '@/app/workspace/[workspaceId]/w/[workflowId]/lib/workflow-execution-utils'
|
||||
import { useExecutionStore } from '@/stores/execution/store'
|
||||
import { useWorkflowRegistry } from '@/stores/workflows/registry/store'
|
||||
|
||||
interface RunWorkflowParams {
|
||||
workflowId?: string
|
||||
@@ -8,11 +8,11 @@
|
||||
import { useState } from 'react'
|
||||
import { Loader2 } from 'lucide-react'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { notifyServerTool } from '@/lib/copilot/tools/notification-utils'
|
||||
import { toolRegistry } from '@/lib/copilot/tools/registry'
|
||||
import { renderToolStateIcon, toolRequiresInterrupt } from '@/lib/copilot/tools/utils'
|
||||
import { useCopilotStore } from '@/stores/copilot/store'
|
||||
import type { CopilotToolCall } from '@/stores/copilot/types'
|
||||
import { notifyServerTool } from './notification-utils'
|
||||
import { toolRegistry } from './registry'
|
||||
import { renderToolStateIcon, toolRequiresInterrupt } from './utils'
|
||||
|
||||
interface InlineToolCallProps {
|
||||
toolCall: CopilotToolCall
|
||||
@@ -3,8 +3,8 @@
|
||||
* Handles notifications and state messages for tools
|
||||
*/
|
||||
|
||||
import { toolRegistry } from './registry'
|
||||
import type { NotificationStatus, ToolState } from './types'
|
||||
import { toolRegistry } from '@/lib/copilot/tools/registry'
|
||||
import type { NotificationStatus, ToolState } from '@/lib/copilot/tools/types'
|
||||
|
||||
/**
|
||||
* Send a notification for a tool state change
|
||||
@@ -8,11 +8,9 @@
|
||||
* It also provides metadata for server-side tools for display purposes
|
||||
*/
|
||||
|
||||
// Import client tool implementations
|
||||
import { RunWorkflowTool } from './client-tools/run-workflow'
|
||||
// Import server tool definitions
|
||||
import { SERVER_TOOL_METADATA } from './server-tools/definitions'
|
||||
import type { Tool, ToolMetadata } from './types'
|
||||
import { RunWorkflowTool } from '@/lib/copilot/tools/client-tools/run-workflow'
|
||||
import { SERVER_TOOL_METADATA } from '@/lib/copilot/tools/server-tools/definitions'
|
||||
import type { Tool, ToolMetadata } from '@/lib/copilot/tools/types'
|
||||
|
||||
/**
|
||||
* Tool Registry class that manages all available tools
|
||||
@@ -3,7 +3,7 @@
|
||||
* These tools execute on the server and their results are displayed in the UI
|
||||
*/
|
||||
|
||||
import type { ToolMetadata } from '../types'
|
||||
import type { ToolMetadata } from '@/lib/copilot/tools/types'
|
||||
|
||||
// Tool IDs for server tools
|
||||
export const SERVER_TOOL_IDS = {
|
||||
@@ -8,10 +8,10 @@
|
||||
import { useState } from 'react'
|
||||
import { Loader2 } from 'lucide-react'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { notifyServerTool } from './notification-utils'
|
||||
import { toolRegistry } from './registry'
|
||||
import type { CopilotToolCall } from './types'
|
||||
import { executeToolWithStateManagement } from './utils'
|
||||
import { notifyServerTool } from '@/lib/copilot/tools/notification-utils'
|
||||
import { toolRegistry } from '@/lib/copilot/tools/registry'
|
||||
import type { CopilotToolCall } from '@/lib/copilot/tools/types'
|
||||
import { executeToolWithStateManagement } from '@/lib/copilot/tools/utils'
|
||||
|
||||
interface ToolConfirmationProps {
|
||||
toolCall: CopilotToolCall
|
||||
4
apps/sim/lib/copilot/tools/types.ts
Normal file
4
apps/sim/lib/copilot/tools/types.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
/**
|
||||
* Re-export all copilot types from the consolidated types file
|
||||
*/
|
||||
export * from '@/lib/copilot/types'
|
||||
@@ -34,8 +34,8 @@ import {
|
||||
XCircle,
|
||||
Zap,
|
||||
} from 'lucide-react'
|
||||
import { toolRegistry } from './registry'
|
||||
import type { CopilotToolCall, ToolState } from './types'
|
||||
import { toolRegistry } from '@/lib/copilot/tools/registry'
|
||||
import type { CopilotToolCall, ToolState } from '@/lib/copilot/tools/types'
|
||||
|
||||
/**
|
||||
* Map icon identifiers to Lucide icon components
|
||||
@@ -1,11 +1,72 @@
|
||||
/**
|
||||
* Copilot Tools Type Definitions
|
||||
* Clean architecture for client-side tool management
|
||||
* Copilot Types - Consolidated from various locations
|
||||
* This file contains all copilot-related type definitions
|
||||
*/
|
||||
|
||||
// Tool call state types (from apps/sim/types/tool-call.ts)
|
||||
export interface ToolCallState {
|
||||
id: string
|
||||
name: string
|
||||
displayName?: string
|
||||
parameters?: Record<string, any>
|
||||
state:
|
||||
| 'detecting'
|
||||
| 'pending'
|
||||
| 'executing'
|
||||
| 'completed'
|
||||
| 'error'
|
||||
| 'rejected'
|
||||
| 'applied'
|
||||
| 'ready_for_review'
|
||||
| 'aborted'
|
||||
| 'skipped'
|
||||
| 'background'
|
||||
startTime?: number
|
||||
endTime?: number
|
||||
duration?: number
|
||||
result?: any
|
||||
error?: string
|
||||
progress?: string
|
||||
}
|
||||
|
||||
export interface ToolCallGroup {
|
||||
id: string
|
||||
toolCalls: ToolCallState[]
|
||||
status: 'pending' | 'in_progress' | 'completed' | 'error'
|
||||
startTime?: number
|
||||
endTime?: number
|
||||
summary?: string
|
||||
}
|
||||
|
||||
export interface InlineContent {
|
||||
type: 'text' | 'tool_call'
|
||||
content: string
|
||||
toolCall?: ToolCallState
|
||||
}
|
||||
|
||||
export interface ParsedMessageContent {
|
||||
textContent: string
|
||||
toolCalls: ToolCallState[]
|
||||
toolGroups: ToolCallGroup[]
|
||||
inlineContent?: InlineContent[]
|
||||
}
|
||||
|
||||
export interface ToolCallIndicator {
|
||||
type: 'status' | 'thinking' | 'execution'
|
||||
content: string
|
||||
toolNames?: string[]
|
||||
}
|
||||
|
||||
// Copilot Tools Type Definitions (from workspace copilot lib)
|
||||
import type { CopilotToolCall, ToolState } from '@/stores/copilot/types'
|
||||
|
||||
export type NotificationStatus = 'success' | 'error' | 'accepted' | 'rejected' | 'background'
|
||||
export type NotificationStatus =
|
||||
| 'pending'
|
||||
| 'success'
|
||||
| 'error'
|
||||
| 'accepted'
|
||||
| 'rejected'
|
||||
| 'background'
|
||||
|
||||
// Export the consolidated types
|
||||
export type { CopilotToolCall, ToolState }
|
||||
@@ -3,8 +3,8 @@
|
||||
import { create } from 'zustand'
|
||||
import { devtools } from 'zustand/middleware'
|
||||
import { type CopilotChat, sendStreamingMessage } from '@/lib/copilot/api'
|
||||
import { toolRegistry } from '@/lib/copilot/tools'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { toolRegistry } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/copilot/lib/tools'
|
||||
import { COPILOT_TOOL_DISPLAY_NAMES } from '@/stores/constants'
|
||||
import { COPILOT_TOOL_IDS } from './constants'
|
||||
import type { CopilotMessage, CopilotStore, WorkflowCheckpoint } from './types'
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
export interface ToolCallState {
|
||||
id: string
|
||||
name: string
|
||||
displayName?: string
|
||||
parameters?: Record<string, any>
|
||||
state:
|
||||
| 'detecting'
|
||||
| 'pending'
|
||||
| 'executing'
|
||||
| 'completed'
|
||||
| 'error'
|
||||
| 'rejected'
|
||||
| 'applied'
|
||||
| 'ready_for_review'
|
||||
| 'aborted'
|
||||
| 'skipped'
|
||||
| 'background'
|
||||
startTime?: number
|
||||
endTime?: number
|
||||
duration?: number
|
||||
result?: any
|
||||
error?: string
|
||||
progress?: string
|
||||
}
|
||||
|
||||
export interface ToolCallGroup {
|
||||
id: string
|
||||
toolCalls: ToolCallState[]
|
||||
status: 'pending' | 'in_progress' | 'completed' | 'error'
|
||||
startTime?: number
|
||||
endTime?: number
|
||||
summary?: string
|
||||
}
|
||||
|
||||
export interface InlineContent {
|
||||
type: 'text' | 'tool_call'
|
||||
content: string
|
||||
toolCall?: ToolCallState
|
||||
}
|
||||
|
||||
export interface ParsedMessageContent {
|
||||
textContent: string
|
||||
toolCalls: ToolCallState[]
|
||||
toolGroups: ToolCallGroup[]
|
||||
inlineContent?: InlineContent[]
|
||||
}
|
||||
|
||||
export interface ToolCallIndicator {
|
||||
type: 'status' | 'thinking' | 'execution'
|
||||
content: string
|
||||
toolNames?: string[]
|
||||
}
|
||||
Reference in New Issue
Block a user