diff --git a/apps/sim/app/api/a2a/agents/[agentId]/route.ts b/apps/sim/app/api/a2a/agents/[agentId]/route.ts index 3d08ceba19..9f51b986d8 100644 --- a/apps/sim/app/api/a2a/agents/[agentId]/route.ts +++ b/apps/sim/app/api/a2a/agents/[agentId]/route.ts @@ -45,7 +45,6 @@ export async function GET(request: NextRequest, { params }: { params: Promise }) { const { agentId } = await params - // Try Redis cache first const redis = getRedisClient() const cacheKey = `a2a:agent:${agentId}:card` @@ -113,7 +112,6 @@ export async function GET(request: NextRequest, { params }: { params: Promise = { 'Content-Type': 'application/json' } if (apiKey) headers['X-API-Key'] = apiKey @@ -345,7 +333,6 @@ async function handleMessageSend( logger.info(`Executing workflow ${agent.workflowId} for A2A task ${taskId}`) try { - // Extract text content from the Message for workflow consumption const messageText = extractTextContent(message) const response = await fetch(executeUrl, { @@ -360,10 +347,8 @@ async function handleMessageSend( const executeResult = await response.json() - // Determine final state const finalState: TaskState = response.ok ? 'completed' : 'failed' - // Create agent response message const agentContent = executeResult.output?.content || (typeof executeResult.output === 'object' @@ -371,15 +356,12 @@ async function handleMessageSend( : String(executeResult.output || executeResult.error || 'Task completed')) const agentMessage = createAgentMessage(agentContent) - // Add taskId and contextId to the response message agentMessage.taskId = taskId if (contextId) agentMessage.contextId = contextId history.push(agentMessage) - // Extract artifacts if present const artifacts = executeResult.output?.artifacts || [] - // Update task with result await db .update(a2aTask) .set({ @@ -392,7 +374,6 @@ async function handleMessageSend( }) .where(eq(a2aTask.id, taskId)) - // Trigger push notification (fire and forget) if (isTerminalState(finalState)) { notifyTaskStateChange(taskId, finalState).catch((err) => { logger.error('Failed to trigger push notification', { taskId, error: err }) @@ -412,7 +393,6 @@ async function handleMessageSend( } catch (error) { logger.error(`Error executing workflow for task ${taskId}:`, error) - // Mark task as failed const errorMessage = error instanceof Error ? error.message : 'Workflow execution failed' await db @@ -424,7 +404,6 @@ async function handleMessageSend( }) .where(eq(a2aTask.id, taskId)) - // Trigger push notification for failure (fire and forget) notifyTaskStateChange(taskId, 'failed').catch((err) => { logger.error('Failed to trigger push notification for failure', { taskId, error: err }) }) @@ -463,7 +442,6 @@ async function handleMessageStream( const message = params.message const contextId = message.contextId || uuidv4() // Generate contextId if not provided - // Get existing task or prepare for new one let history: Message[] = [] let existingTask: typeof a2aTask.$inferSelect | null = null @@ -490,7 +468,6 @@ async function handleMessageStream( const taskId = message.taskId || generateTaskId() history.push(message) - // Create or update task record if (existingTask) { await db .update(a2aTask) @@ -513,7 +490,6 @@ async function handleMessageStream( }) } - // Create SSE stream const encoder = new TextEncoder() const stream = new ReadableStream({ @@ -526,7 +502,6 @@ async function handleMessageStream( } } - // Send initial status update (v0.3 format) sendEvent('status', { kind: 'status', taskId, @@ -535,7 +510,6 @@ async function handleMessageStream( }) try { - // Execute workflow with streaming const executeUrl = `${getBaseUrl()}/api/workflows/${agent.workflowId}/execute` const headers: Record = { 'Content-Type': 'application/json', @@ -543,7 +517,6 @@ async function handleMessageStream( } if (apiKey) headers['X-API-Key'] = apiKey - // Extract text content from the Message for workflow consumption const messageText = extractTextContent(message) const response = await fetch(executeUrl, { @@ -568,13 +541,11 @@ async function handleMessageStream( throw new Error(errorMessage) } - // Check content type to determine response handling const contentType = response.headers.get('content-type') || '' const isStreamingResponse = contentType.includes('text/event-stream') || contentType.includes('text/plain') if (response.body && isStreamingResponse) { - // Handle streaming response - forward chunks const reader = response.body.getReader() const decoder = new TextDecoder() let fullContent = '' @@ -586,7 +557,6 @@ async function handleMessageStream( const chunk = decoder.decode(value, { stream: true }) fullContent += chunk - // Forward chunk as message event (v0.3 format) sendEvent('message', { kind: 'message', taskId, @@ -597,13 +567,11 @@ async function handleMessageStream( }) } - // Create final agent message const agentMessage = createAgentMessage(fullContent || 'Task completed') agentMessage.taskId = taskId if (contextId) agentMessage.contextId = contextId history.push(agentMessage) - // Update task await db .update(a2aTask) .set({ @@ -614,7 +582,6 @@ async function handleMessageStream( }) .where(eq(a2aTask.id, taskId)) - // Trigger push notification (fire and forget) notifyTaskStateChange(taskId, 'completed').catch((err) => { logger.error('Failed to trigger push notification', { taskId, error: err }) }) @@ -627,7 +594,6 @@ async function handleMessageStream( final: true, }) } else { - // Handle JSON response (non-streaming workflow) const result = await response.json() const content = @@ -636,7 +602,6 @@ async function handleMessageStream( ? JSON.stringify(result.output) : String(result.output || 'Task completed')) - // Send the complete content as a final message sendEvent('message', { kind: 'message', taskId, @@ -653,7 +618,6 @@ async function handleMessageStream( const artifacts = (result.output?.artifacts as Artifact[]) || [] - // Update task with result await db .update(a2aTask) .set({ @@ -666,7 +630,6 @@ async function handleMessageStream( }) .where(eq(a2aTask.id, taskId)) - // Trigger push notification (fire and forget) notifyTaskStateChange(taskId, 'completed').catch((err) => { logger.error('Failed to trigger push notification', { taskId, error: err }) }) @@ -691,7 +654,6 @@ async function handleMessageStream( }) .where(eq(a2aTask.id, taskId)) - // Trigger push notification for failure (fire and forget) notifyTaskStateChange(taskId, 'failed').catch((err) => { logger.error('Failed to trigger push notification for failure', { taskId, error: err }) }) @@ -725,7 +687,6 @@ async function handleTaskGet(id: string | number, params: TaskIdParams): Promise ) } - // Validate historyLength if provided const historyLength = params.historyLength !== undefined && params.historyLength >= 0 ? params.historyLength @@ -742,7 +703,7 @@ async function handleTaskGet(id: string | number, params: TaskIdParams): Promise const taskResponse: Task = { kind: 'task', id: task.id, - contextId: task.sessionId || task.id, // Use task ID as fallback contextId + contextId: task.sessionId || task.id, status: createTaskStatus(task.status as TaskState), history: task.messages as Message[], artifacts: (task.artifacts as Artifact[]) || [], @@ -779,7 +740,6 @@ async function handleTaskCancel(id: string | number, params: TaskIdParams): Prom ) } - // Cancel running workflow execution if exists if (task.executionId) { try { await markExecutionCancelled(task.executionId) @@ -805,7 +765,6 @@ async function handleTaskCancel(id: string | number, params: TaskIdParams): Prom }) .where(eq(a2aTask.id, params.id)) - // Trigger push notification for cancellation (fire and forget) notifyTaskStateChange(params.id, 'canceled').catch((err) => { logger.error('Failed to trigger push notification for cancellation', { taskId: params.id, @@ -816,7 +775,7 @@ async function handleTaskCancel(id: string | number, params: TaskIdParams): Prom const canceledTask: Task = { kind: 'task', id: task.id, - contextId: task.sessionId || task.id, // Use task ID as fallback contextId + contextId: task.sessionId || task.id, status: createTaskStatus('canceled'), history: task.messages as Message[], artifacts: (task.artifacts as Artifact[]) || [], @@ -849,11 +808,10 @@ async function handleTaskResubscribe( } if (isTerminalState(task.status as TaskState)) { - // Task already completed - return final state as regular response const completedTask: Task = { kind: 'task', id: task.id, - contextId: task.sessionId || task.id, // Use task ID as fallback contextId + contextId: task.sessionId || task.id, status: createTaskStatus(task.status as TaskState), history: task.messages as Message[], artifacts: (task.artifacts as Artifact[]) || [], @@ -861,7 +819,6 @@ async function handleTaskResubscribe( return NextResponse.json(createResponse(id, completedTask)) } - // Create SSE stream for ongoing task updates const encoder = new TextEncoder() let isCancelled = false let pollTimeoutId: ReturnType | null = null @@ -888,7 +845,6 @@ async function handleTaskResubscribe( } } - // Send current status (v0.3 format) if ( !sendEvent('status', { kind: 'status', @@ -901,7 +857,6 @@ async function handleTaskResubscribe( return } - // Poll for updates until task completes const pollInterval = 3000 // 3 seconds (reduced from 1s to lower DB load) const maxPolls = 100 // 5 minutes max (100 * 3s = 300s) @@ -962,7 +917,6 @@ async function handleTaskResubscribe( } if (isTerminalState(updatedTask.status as TaskState)) { - // Send final message if available const messages = updatedTask.messages as Message[] const lastMessage = messages[messages.length - 1] if (lastMessage && lastMessage.role === 'agent') { @@ -1039,7 +993,6 @@ async function handlePushNotificationSet( ) } - // Validate URL is HTTPS (security requirement) try { const url = new URL(params.pushNotificationConfig.url) if (url.protocol !== 'https:') { @@ -1063,7 +1016,6 @@ async function handlePushNotificationSet( }) } - // Check if config already exists const [existingConfig] = await db .select() .from(a2aPushNotificationConfig) diff --git a/apps/sim/app/api/memory/[id]/route.ts b/apps/sim/app/api/memory/[id]/route.ts index 617979ef16..2f5b5ae1cc 100644 --- a/apps/sim/app/api/memory/[id]/route.ts +++ b/apps/sim/app/api/memory/[id]/route.ts @@ -1,11 +1,12 @@ import { db } from '@sim/db' -import { memory, permissions, workspace } from '@sim/db/schema' +import { memory } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { and, eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { checkHybridAuth } from '@/lib/auth/hybrid' import { generateRequestId } from '@/lib/core/utils/request' +import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils' const logger = createLogger('MemoryByIdAPI') @@ -29,46 +30,6 @@ const memoryPutBodySchema = z.object({ workspaceId: z.string().uuid('Invalid workspace ID format'), }) -async function checkWorkspaceAccess( - workspaceId: string, - userId: string -): Promise<{ hasAccess: boolean; canWrite: boolean }> { - const [workspaceRow] = await db - .select({ ownerId: workspace.ownerId }) - .from(workspace) - .where(eq(workspace.id, workspaceId)) - .limit(1) - - if (!workspaceRow) { - return { hasAccess: false, canWrite: false } - } - - if (workspaceRow.ownerId === userId) { - return { hasAccess: true, canWrite: true } - } - - const [permissionRow] = await db - .select({ permissionType: permissions.permissionType }) - .from(permissions) - .where( - and( - eq(permissions.userId, userId), - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workspaceId) - ) - ) - .limit(1) - - if (!permissionRow) { - return { hasAccess: false, canWrite: false } - } - - return { - hasAccess: true, - canWrite: permissionRow.permissionType === 'write' || permissionRow.permissionType === 'admin', - } -} - async function validateMemoryAccess( request: NextRequest, workspaceId: string, @@ -86,8 +47,8 @@ async function validateMemoryAccess( } } - const { hasAccess, canWrite } = await checkWorkspaceAccess(workspaceId, authResult.userId) - if (!hasAccess) { + const access = await checkWorkspaceAccess(workspaceId, authResult.userId) + if (!access.exists || !access.hasAccess) { return { error: NextResponse.json( { success: false, error: { message: 'Workspace not found' } }, @@ -96,7 +57,7 @@ async function validateMemoryAccess( } } - if (action === 'write' && !canWrite) { + if (action === 'write' && !access.canWrite) { return { error: NextResponse.json( { success: false, error: { message: 'Write access denied' } }, diff --git a/apps/sim/app/api/memory/route.ts b/apps/sim/app/api/memory/route.ts index fe159b9664..072756c7a6 100644 --- a/apps/sim/app/api/memory/route.ts +++ b/apps/sim/app/api/memory/route.ts @@ -1,56 +1,17 @@ import { db } from '@sim/db' -import { memory, permissions, workspace } from '@sim/db/schema' +import { memory } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { and, eq, isNull, like } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { checkHybridAuth } from '@/lib/auth/hybrid' import { generateRequestId } from '@/lib/core/utils/request' +import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils' const logger = createLogger('MemoryAPI') export const dynamic = 'force-dynamic' export const runtime = 'nodejs' -async function checkWorkspaceAccess( - workspaceId: string, - userId: string -): Promise<{ hasAccess: boolean; canWrite: boolean }> { - const [workspaceRow] = await db - .select({ ownerId: workspace.ownerId }) - .from(workspace) - .where(eq(workspace.id, workspaceId)) - .limit(1) - - if (!workspaceRow) { - return { hasAccess: false, canWrite: false } - } - - if (workspaceRow.ownerId === userId) { - return { hasAccess: true, canWrite: true } - } - - const [permissionRow] = await db - .select({ permissionType: permissions.permissionType }) - .from(permissions) - .where( - and( - eq(permissions.userId, userId), - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workspaceId) - ) - ) - .limit(1) - - if (!permissionRow) { - return { hasAccess: false, canWrite: false } - } - - return { - hasAccess: true, - canWrite: permissionRow.permissionType === 'write' || permissionRow.permissionType === 'admin', - } -} - export async function GET(request: NextRequest) { const requestId = generateRequestId() @@ -76,8 +37,14 @@ export async function GET(request: NextRequest) { ) } - const { hasAccess } = await checkWorkspaceAccess(workspaceId, authResult.userId) - if (!hasAccess) { + const access = await checkWorkspaceAccess(workspaceId, authResult.userId) + if (!access.exists) { + return NextResponse.json( + { success: false, error: { message: 'Workspace not found' } }, + { status: 404 } + ) + } + if (!access.hasAccess) { return NextResponse.json( { success: false, error: { message: 'Access denied to this workspace' } }, { status: 403 } @@ -155,15 +122,21 @@ export async function POST(request: NextRequest) { ) } - const { hasAccess, canWrite } = await checkWorkspaceAccess(workspaceId, authResult.userId) - if (!hasAccess) { + const access = await checkWorkspaceAccess(workspaceId, authResult.userId) + if (!access.exists) { return NextResponse.json( { success: false, error: { message: 'Workspace not found' } }, { status: 404 } ) } + if (!access.hasAccess) { + return NextResponse.json( + { success: false, error: { message: 'Access denied to this workspace' } }, + { status: 403 } + ) + } - if (!canWrite) { + if (!access.canWrite) { return NextResponse.json( { success: false, error: { message: 'Write access denied to this workspace' } }, { status: 403 } @@ -282,15 +255,21 @@ export async function DELETE(request: NextRequest) { ) } - const { hasAccess, canWrite } = await checkWorkspaceAccess(workspaceId, authResult.userId) - if (!hasAccess) { + const access = await checkWorkspaceAccess(workspaceId, authResult.userId) + if (!access.exists) { return NextResponse.json( { success: false, error: { message: 'Workspace not found' } }, { status: 404 } ) } + if (!access.hasAccess) { + return NextResponse.json( + { success: false, error: { message: 'Access denied to this workspace' } }, + { status: 403 } + ) + } - if (!canWrite) { + if (!access.canWrite) { return NextResponse.json( { success: false, error: { message: 'Write access denied to this workspace' } }, { status: 403 } diff --git a/apps/sim/app/api/workflows/route.ts b/apps/sim/app/api/workflows/route.ts index 7c905ab7e6..81d4c885b9 100644 --- a/apps/sim/app/api/workflows/route.ts +++ b/apps/sim/app/api/workflows/route.ts @@ -1,12 +1,12 @@ import { db } from '@sim/db' -import { workflow, workspace } from '@sim/db/schema' +import { workflow } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { getSession } from '@/lib/auth' import { generateRequestId } from '@/lib/core/utils/request' -import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' +import { getUserEntityPermissions, workspaceExists } from '@/lib/workspaces/permissions/utils' import { verifyWorkspaceMembership } from '@/app/api/workflows/utils' const logger = createLogger('WorkflowAPI') @@ -36,13 +36,9 @@ export async function GET(request: Request) { const userId = session.user.id if (workspaceId) { - const workspaceExists = await db - .select({ id: workspace.id }) - .from(workspace) - .where(eq(workspace.id, workspaceId)) - .then((rows) => rows.length > 0) + const wsExists = await workspaceExists(workspaceId) - if (!workspaceExists) { + if (!wsExists) { logger.warn( `[${requestId}] Attempt to fetch workflows for non-existent workspace: ${workspaceId}` ) diff --git a/apps/sim/app/api/workspaces/[id]/api-keys/route.ts b/apps/sim/app/api/workspaces/[id]/api-keys/route.ts index 1232272366..c649972140 100644 --- a/apps/sim/app/api/workspaces/[id]/api-keys/route.ts +++ b/apps/sim/app/api/workspaces/[id]/api-keys/route.ts @@ -1,5 +1,5 @@ import { db } from '@sim/db' -import { apiKey, workspace } from '@sim/db/schema' +import { apiKey } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { and, eq, inArray } from 'drizzle-orm' import { nanoid } from 'nanoid' @@ -9,7 +9,7 @@ import { createApiKey, getApiKeyDisplayFormat } from '@/lib/api-key/auth' import { getSession } from '@/lib/auth' import { PlatformEvents } from '@/lib/core/telemetry' import { generateRequestId } from '@/lib/core/utils/request' -import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' +import { getUserEntityPermissions, getWorkspaceById } from '@/lib/workspaces/permissions/utils' const logger = createLogger('WorkspaceApiKeysAPI') @@ -34,8 +34,8 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{ const userId = session.user.id - const ws = await db.select().from(workspace).where(eq(workspace.id, workspaceId)).limit(1) - if (!ws.length) { + const ws = await getWorkspaceById(workspaceId) + if (!ws) { return NextResponse.json({ error: 'Workspace not found' }, { status: 404 }) } diff --git a/apps/sim/app/api/workspaces/[id]/byok-keys/route.ts b/apps/sim/app/api/workspaces/[id]/byok-keys/route.ts index 84be273d12..82045c0c50 100644 --- a/apps/sim/app/api/workspaces/[id]/byok-keys/route.ts +++ b/apps/sim/app/api/workspaces/[id]/byok-keys/route.ts @@ -1,5 +1,5 @@ import { db } from '@sim/db' -import { workspace, workspaceBYOKKeys } from '@sim/db/schema' +import { workspaceBYOKKeys } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { and, eq } from 'drizzle-orm' import { nanoid } from 'nanoid' @@ -10,7 +10,7 @@ import { isEnterpriseOrgAdminOrOwner } from '@/lib/billing/core/subscription' import { isHosted } from '@/lib/core/config/feature-flags' import { decryptSecret, encryptSecret } from '@/lib/core/security/encryption' import { generateRequestId } from '@/lib/core/utils/request' -import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' +import { getUserEntityPermissions, getWorkspaceById } from '@/lib/workspaces/permissions/utils' const logger = createLogger('WorkspaceBYOKKeysAPI') @@ -48,8 +48,8 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{ const userId = session.user.id - const ws = await db.select().from(workspace).where(eq(workspace.id, workspaceId)).limit(1) - if (!ws.length) { + const ws = await getWorkspaceById(workspaceId) + if (!ws) { return NextResponse.json({ error: 'Workspace not found' }, { status: 404 }) } diff --git a/apps/sim/app/api/workspaces/[id]/environment/route.ts b/apps/sim/app/api/workspaces/[id]/environment/route.ts index 9c1ee4eb04..f11da0ecc9 100644 --- a/apps/sim/app/api/workspaces/[id]/environment/route.ts +++ b/apps/sim/app/api/workspaces/[id]/environment/route.ts @@ -1,5 +1,5 @@ import { db } from '@sim/db' -import { environment, workspace, workspaceEnvironment } from '@sim/db/schema' +import { environment, workspaceEnvironment } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' @@ -7,7 +7,7 @@ import { z } from 'zod' import { getSession } from '@/lib/auth' import { decryptSecret, encryptSecret } from '@/lib/core/security/encryption' import { generateRequestId } from '@/lib/core/utils/request' -import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' +import { getUserEntityPermissions, getWorkspaceById } from '@/lib/workspaces/permissions/utils' const logger = createLogger('WorkspaceEnvironmentAPI') @@ -33,8 +33,8 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{ const userId = session.user.id // Validate workspace exists - const ws = await db.select().from(workspace).where(eq(workspace.id, workspaceId)).limit(1) - if (!ws.length) { + const ws = await getWorkspaceById(workspaceId) + if (!ws) { return NextResponse.json({ error: 'Workspace not found' }, { status: 404 }) } diff --git a/apps/sim/lib/workflows/utils.ts b/apps/sim/lib/workflows/utils.ts index f3bf49ece3..d17744af67 100644 --- a/apps/sim/lib/workflows/utils.ts +++ b/apps/sim/lib/workflows/utils.ts @@ -1,17 +1,14 @@ import { db } from '@sim/db' -import { permissions, userStats, workflow as workflowTable, workspace } from '@sim/db/schema' +import { permissions, userStats, workflow as workflowTable } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import type { InferSelectModel } from 'drizzle-orm' import { and, eq } from 'drizzle-orm' import { NextResponse } from 'next/server' import { getSession } from '@/lib/auth' -import type { PermissionType } from '@/lib/workspaces/permissions/utils' +import { getWorkspaceWithOwner, type PermissionType } from '@/lib/workspaces/permissions/utils' import type { ExecutionResult } from '@/executor/types' const logger = createLogger('WorkflowUtils') -type WorkflowSelection = InferSelectModel - export async function getWorkflowById(id: string) { const rows = await db.select().from(workflowTable).where(eq(workflowTable.id, id)).limit(1) @@ -44,11 +41,7 @@ export async function getWorkflowAccessContext( let workspacePermission: PermissionType | null = null if (workflow.workspaceId) { - const [workspaceRow] = await db - .select({ ownerId: workspace.ownerId }) - .from(workspace) - .where(eq(workspace.id, workflow.workspaceId)) - .limit(1) + const workspaceRow = await getWorkspaceWithOwner(workflow.workspaceId) workspaceOwnerId = workspaceRow?.ownerId ?? null @@ -147,7 +140,6 @@ export const workflowHasResponseBlock = (executionResult: ExecutionResult): bool return responseBlock !== undefined } -// Create a HTTP response from response block export const createHttpResponseFromBlock = (executionResult: ExecutionResult): NextResponse => { const { data = {}, status = 200, headers = {} } = executionResult.output diff --git a/apps/sim/lib/workspaces/permissions/utils.test.ts b/apps/sim/lib/workspaces/permissions/utils.test.ts index 4ec22ce4bf..938937d222 100644 --- a/apps/sim/lib/workspaces/permissions/utils.test.ts +++ b/apps/sim/lib/workspaces/permissions/utils.test.ts @@ -40,11 +40,15 @@ vi.mock('drizzle-orm', () => drizzleOrmMock) import { db } from '@sim/db' import { + checkWorkspaceAccess, getManageableWorkspaces, getUserEntityPermissions, getUsersWithPermissions, + getWorkspaceById, + getWorkspaceWithOwner, hasAdminPermission, hasWorkspaceAdminAccess, + workspaceExists, } from '@/lib/workspaces/permissions/utils' const mockDb = db as any @@ -610,4 +614,209 @@ describe('Permission Utils', () => { expect(result).toEqual([]) }) }) + + describe('getWorkspaceById', () => { + it.concurrent('should return workspace when it exists', async () => { + const chain = createMockChain([{ id: 'workspace123' }]) + mockDb.select.mockReturnValue(chain) + + const result = await getWorkspaceById('workspace123') + + expect(result).toEqual({ id: 'workspace123' }) + }) + + it.concurrent('should return null when workspace does not exist', async () => { + const chain = createMockChain([]) + mockDb.select.mockReturnValue(chain) + + const result = await getWorkspaceById('non-existent') + + expect(result).toBeNull() + }) + + it.concurrent('should handle empty workspace ID', async () => { + const chain = createMockChain([]) + mockDb.select.mockReturnValue(chain) + + const result = await getWorkspaceById('') + + expect(result).toBeNull() + }) + }) + + describe('getWorkspaceWithOwner', () => { + it.concurrent('should return workspace with owner when it exists', async () => { + const chain = createMockChain([{ id: 'workspace123', ownerId: 'owner456' }]) + mockDb.select.mockReturnValue(chain) + + const result = await getWorkspaceWithOwner('workspace123') + + expect(result).toEqual({ id: 'workspace123', ownerId: 'owner456' }) + }) + + it.concurrent('should return null when workspace does not exist', async () => { + const chain = createMockChain([]) + mockDb.select.mockReturnValue(chain) + + const result = await getWorkspaceWithOwner('non-existent') + + expect(result).toBeNull() + }) + + it.concurrent('should handle workspace with null owner ID', async () => { + const chain = createMockChain([{ id: 'workspace123', ownerId: null }]) + mockDb.select.mockReturnValue(chain) + + const result = await getWorkspaceWithOwner('workspace123') + + expect(result).toEqual({ id: 'workspace123', ownerId: null }) + }) + }) + + describe('workspaceExists', () => { + it.concurrent('should return true when workspace exists', async () => { + const chain = createMockChain([{ id: 'workspace123' }]) + mockDb.select.mockReturnValue(chain) + + const result = await workspaceExists('workspace123') + + expect(result).toBe(true) + }) + + it.concurrent('should return false when workspace does not exist', async () => { + const chain = createMockChain([]) + mockDb.select.mockReturnValue(chain) + + const result = await workspaceExists('non-existent') + + expect(result).toBe(false) + }) + + it.concurrent('should handle empty workspace ID', async () => { + const chain = createMockChain([]) + mockDb.select.mockReturnValue(chain) + + const result = await workspaceExists('') + + expect(result).toBe(false) + }) + }) + + describe('checkWorkspaceAccess', () => { + it('should return exists=false when workspace does not exist', async () => { + const chain = createMockChain([]) + mockDb.select.mockReturnValue(chain) + + const result = await checkWorkspaceAccess('non-existent', 'user123') + + expect(result).toEqual({ + exists: false, + hasAccess: false, + canWrite: false, + workspace: null, + }) + }) + + it('should return full access when user is workspace owner', async () => { + const chain = createMockChain([{ id: 'workspace123', ownerId: 'user123' }]) + mockDb.select.mockReturnValue(chain) + + const result = await checkWorkspaceAccess('workspace123', 'user123') + + expect(result).toEqual({ + exists: true, + hasAccess: true, + canWrite: true, + workspace: { id: 'workspace123', ownerId: 'user123' }, + }) + }) + + it('should return hasAccess=false when user has no permissions', async () => { + let callCount = 0 + mockDb.select.mockImplementation(() => { + callCount++ + if (callCount === 1) { + return createMockChain([{ id: 'workspace123', ownerId: 'other-user' }]) + } + return createMockChain([]) // No permissions + }) + + const result = await checkWorkspaceAccess('workspace123', 'user123') + + expect(result.exists).toBe(true) + expect(result.hasAccess).toBe(false) + expect(result.canWrite).toBe(false) + }) + + it('should return canWrite=true when user has admin permission', async () => { + let callCount = 0 + mockDb.select.mockImplementation(() => { + callCount++ + if (callCount === 1) { + return createMockChain([{ id: 'workspace123', ownerId: 'other-user' }]) + } + return createMockChain([{ permissionType: 'admin' }]) + }) + + const result = await checkWorkspaceAccess('workspace123', 'user123') + + expect(result.exists).toBe(true) + expect(result.hasAccess).toBe(true) + expect(result.canWrite).toBe(true) + }) + + it('should return canWrite=true when user has write permission', async () => { + let callCount = 0 + mockDb.select.mockImplementation(() => { + callCount++ + if (callCount === 1) { + return createMockChain([{ id: 'workspace123', ownerId: 'other-user' }]) + } + return createMockChain([{ permissionType: 'write' }]) + }) + + const result = await checkWorkspaceAccess('workspace123', 'user123') + + expect(result.exists).toBe(true) + expect(result.hasAccess).toBe(true) + expect(result.canWrite).toBe(true) + }) + + it('should return canWrite=false when user has read permission', async () => { + let callCount = 0 + mockDb.select.mockImplementation(() => { + callCount++ + if (callCount === 1) { + return createMockChain([{ id: 'workspace123', ownerId: 'other-user' }]) + } + return createMockChain([{ permissionType: 'read' }]) + }) + + const result = await checkWorkspaceAccess('workspace123', 'user123') + + expect(result.exists).toBe(true) + expect(result.hasAccess).toBe(true) + expect(result.canWrite).toBe(false) + }) + + it('should handle empty user ID', async () => { + const chain = createMockChain([]) + mockDb.select.mockReturnValue(chain) + + const result = await checkWorkspaceAccess('workspace123', '') + + expect(result.exists).toBe(false) + expect(result.hasAccess).toBe(false) + }) + + it('should handle empty workspace ID', async () => { + const chain = createMockChain([]) + mockDb.select.mockReturnValue(chain) + + const result = await checkWorkspaceAccess('', 'user123') + + expect(result.exists).toBe(false) + expect(result.hasAccess).toBe(false) + }) + }) }) diff --git a/apps/sim/lib/workspaces/permissions/utils.ts b/apps/sim/lib/workspaces/permissions/utils.ts index c34b7840b0..0a517791f4 100644 --- a/apps/sim/lib/workspaces/permissions/utils.ts +++ b/apps/sim/lib/workspaces/permissions/utils.ts @@ -3,6 +3,112 @@ import { permissions, type permissionTypeEnum, user, workspace } from '@sim/db/s import { and, eq } from 'drizzle-orm' export type PermissionType = (typeof permissionTypeEnum.enumValues)[number] +export interface WorkspaceBasic { + id: string +} + +export interface WorkspaceWithOwner { + id: string + ownerId: string +} + +export interface WorkspaceAccess { + exists: boolean + hasAccess: boolean + canWrite: boolean + workspace: WorkspaceWithOwner | null +} + +/** + * Get a workspace by ID (basic existence check) + * + * @param workspaceId - The workspace ID to look up + * @returns The workspace if found, null otherwise + */ +export async function getWorkspaceById(workspaceId: string): Promise { + const [ws] = await db + .select({ id: workspace.id }) + .from(workspace) + .where(eq(workspace.id, workspaceId)) + .limit(1) + + return ws || null +} + +/** + * Get a workspace with owner info by ID + * + * @param workspaceId - The workspace ID to look up + * @returns The workspace with owner info if found, null otherwise + */ +export async function getWorkspaceWithOwner( + workspaceId: string +): Promise { + const [ws] = await db + .select({ id: workspace.id, ownerId: workspace.ownerId }) + .from(workspace) + .where(eq(workspace.id, workspaceId)) + .limit(1) + + return ws || null +} + +/** + * Check if a workspace exists + * + * @param workspaceId - The workspace ID to check + * @returns True if the workspace exists, false otherwise + */ +export async function workspaceExists(workspaceId: string): Promise { + const ws = await getWorkspaceById(workspaceId) + return ws !== null +} + +/** + * Check workspace access for a user + * + * Verifies the workspace exists and the user has access to it. + * Returns access level (read/write) based on ownership and permissions. + * + * @param workspaceId - The workspace ID to check + * @param userId - The user ID to check access for + * @returns WorkspaceAccess object with exists, hasAccess, canWrite, and workspace data + */ +export async function checkWorkspaceAccess( + workspaceId: string, + userId: string +): Promise { + const ws = await getWorkspaceWithOwner(workspaceId) + + if (!ws) { + return { exists: false, hasAccess: false, canWrite: false, workspace: null } + } + + if (ws.ownerId === userId) { + return { exists: true, hasAccess: true, canWrite: true, workspace: ws } + } + + const [permissionRow] = await db + .select({ permissionType: permissions.permissionType }) + .from(permissions) + .where( + and( + eq(permissions.userId, userId), + eq(permissions.entityType, 'workspace'), + eq(permissions.entityId, workspaceId) + ) + ) + .limit(1) + + if (!permissionRow) { + return { exists: true, hasAccess: false, canWrite: false, workspace: ws } + } + + const canWrite = + permissionRow.permissionType === 'write' || permissionRow.permissionType === 'admin' + + return { exists: true, hasAccess: true, canWrite, workspace: ws } +} /** * Get the highest permission level a user has for a specific entity @@ -111,17 +217,13 @@ export async function hasWorkspaceAdminAccess( userId: string, workspaceId: string ): Promise { - const workspaceResult = await db - .select({ ownerId: workspace.ownerId }) - .from(workspace) - .where(eq(workspace.id, workspaceId)) - .limit(1) + const ws = await getWorkspaceWithOwner(workspaceId) - if (workspaceResult.length === 0) { + if (!ws) { return false } - if (workspaceResult[0].ownerId === userId) { + if (ws.ownerId === userId) { return true }