mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-15 01:47:59 -05:00
fix(oauth): fixed oauth server-side vs client-side checking, fixed x_search tool failing, remove unused route (#251)
* fix(oauth): fixed oauth server-side vs client-side checking, fixed x_search tool failing, remove unused route * added tests * removed unnecessary checks for nil request id * acknowledged PR comments, fixed X OAuth token refresh
This commit is contained in:
337
sim/app/api/auth/oauth/token/route.test.ts
Normal file
337
sim/app/api/auth/oauth/token/route.test.ts
Normal file
@@ -0,0 +1,337 @@
|
||||
/**
|
||||
* Tests for OAuth token API routes
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createMockRequest } from '@/app/api/__test-utils__/utils'
|
||||
|
||||
describe('OAuth Token API Routes', () => {
|
||||
const mockGetUserId = vi.fn()
|
||||
const mockGetCredential = vi.fn()
|
||||
const mockRefreshTokenIfNeeded = vi.fn()
|
||||
|
||||
const mockLogger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
}
|
||||
|
||||
const mockUUID = 'mock-uuid-12345678-90ab-cdef-1234-567890abcdef'
|
||||
const mockRequestId = mockUUID.slice(0, 8)
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
|
||||
vi.stubGlobal('crypto', {
|
||||
randomUUID: vi.fn().mockReturnValue(mockUUID),
|
||||
})
|
||||
|
||||
vi.doMock('../utils', () => ({
|
||||
getUserId: mockGetUserId,
|
||||
getCredential: mockGetCredential,
|
||||
refreshTokenIfNeeded: mockRefreshTokenIfNeeded,
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/logs/console-logger', () => ({
|
||||
createLogger: vi.fn().mockReturnValue(mockLogger),
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
/**
|
||||
* POST route tests
|
||||
*/
|
||||
describe('POST handler', () => {
|
||||
it('should return access token successfully', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce('test-user-id')
|
||||
mockGetCredential.mockResolvedValueOnce({
|
||||
id: 'credential-id',
|
||||
accessToken: 'test-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000),
|
||||
providerId: 'google',
|
||||
})
|
||||
mockRefreshTokenIfNeeded.mockResolvedValueOnce({
|
||||
accessToken: 'fresh-token',
|
||||
refreshed: false,
|
||||
})
|
||||
|
||||
// Create mock request
|
||||
const req = createMockRequest('POST', {
|
||||
credentialId: 'credential-id',
|
||||
})
|
||||
|
||||
// Import handler after setting up mocks
|
||||
const { POST } = await import('./route')
|
||||
|
||||
// Call handler
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
// Verify request was handled correctly
|
||||
expect(response.status).toBe(200)
|
||||
expect(data).toHaveProperty('accessToken', 'fresh-token')
|
||||
|
||||
// Verify mocks were called correctly
|
||||
expect(mockGetUserId).toHaveBeenCalledWith(mockRequestId, undefined)
|
||||
expect(mockGetCredential).toHaveBeenCalledWith(mockRequestId, 'credential-id', 'test-user-id')
|
||||
expect(mockRefreshTokenIfNeeded).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle workflowId for server-side authentication', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce('workflow-owner-id')
|
||||
mockGetCredential.mockResolvedValueOnce({
|
||||
id: 'credential-id',
|
||||
accessToken: 'test-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000),
|
||||
providerId: 'google',
|
||||
})
|
||||
mockRefreshTokenIfNeeded.mockResolvedValueOnce({
|
||||
accessToken: 'fresh-token',
|
||||
refreshed: false,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
credentialId: 'credential-id',
|
||||
workflowId: 'workflow-id',
|
||||
})
|
||||
|
||||
const { POST } = await import('./route')
|
||||
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data).toHaveProperty('accessToken', 'fresh-token')
|
||||
|
||||
expect(mockGetUserId).toHaveBeenCalledWith(mockRequestId, 'workflow-id')
|
||||
expect(mockGetCredential).toHaveBeenCalledWith(
|
||||
mockRequestId,
|
||||
'credential-id',
|
||||
'workflow-owner-id'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle missing credentialId', async () => {
|
||||
const req = createMockRequest('POST', {})
|
||||
|
||||
const { POST } = await import('./route')
|
||||
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data).toHaveProperty('error', 'Credential ID is required')
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle authentication failure', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce(undefined)
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
credentialId: 'credential-id',
|
||||
})
|
||||
|
||||
const { POST } = await import('./route')
|
||||
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
expect(data).toHaveProperty('error', 'User not authenticated')
|
||||
})
|
||||
|
||||
it('should handle workflow not found', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce(undefined)
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
credentialId: 'credential-id',
|
||||
workflowId: 'nonexistent-workflow-id',
|
||||
})
|
||||
|
||||
const { POST } = await import('./route')
|
||||
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data).toHaveProperty('error', 'Workflow not found')
|
||||
})
|
||||
|
||||
it('should handle credential not found', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce('test-user-id')
|
||||
mockGetCredential.mockResolvedValueOnce(undefined)
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
credentialId: 'nonexistent-credential-id',
|
||||
})
|
||||
|
||||
const { POST } = await import('./route')
|
||||
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data).toHaveProperty('error', 'Credential not found')
|
||||
})
|
||||
|
||||
it('should handle token refresh failure', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce('test-user-id')
|
||||
mockGetCredential.mockResolvedValueOnce({
|
||||
id: 'credential-id',
|
||||
accessToken: 'test-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // Expired
|
||||
providerId: 'google',
|
||||
})
|
||||
mockRefreshTokenIfNeeded.mockRejectedValueOnce(new Error('Refresh failure'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
credentialId: 'credential-id',
|
||||
})
|
||||
|
||||
const { POST } = await import('./route')
|
||||
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
expect(data).toHaveProperty('error', 'Failed to refresh access token')
|
||||
})
|
||||
})
|
||||
|
||||
/**
|
||||
* GET route tests
|
||||
*/
|
||||
describe('GET handler', () => {
|
||||
it('should return access token successfully', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce('test-user-id')
|
||||
mockGetCredential.mockResolvedValueOnce({
|
||||
id: 'credential-id',
|
||||
accessToken: 'test-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000),
|
||||
providerId: 'google',
|
||||
})
|
||||
mockRefreshTokenIfNeeded.mockResolvedValueOnce({
|
||||
accessToken: 'fresh-token',
|
||||
refreshed: false,
|
||||
})
|
||||
|
||||
const req = new Request(
|
||||
'http://localhost:3000/api/auth/oauth/token?credentialId=credential-id'
|
||||
)
|
||||
|
||||
const { GET } = await import('./route')
|
||||
|
||||
const response = await GET(req as any)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data).toHaveProperty('accessToken', 'fresh-token')
|
||||
|
||||
expect(mockGetUserId).toHaveBeenCalledWith(mockRequestId)
|
||||
expect(mockGetCredential).toHaveBeenCalledWith(mockRequestId, 'credential-id', 'test-user-id')
|
||||
expect(mockRefreshTokenIfNeeded).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle missing credentialId', async () => {
|
||||
const req = new Request('http://localhost:3000/api/auth/oauth/token')
|
||||
|
||||
const { GET } = await import('./route')
|
||||
|
||||
const response = await GET(req as any)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data).toHaveProperty('error', 'Credential ID is required')
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle authentication failure', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce(undefined)
|
||||
|
||||
const req = new Request(
|
||||
'http://localhost:3000/api/auth/oauth/token?credentialId=credential-id'
|
||||
)
|
||||
|
||||
const { GET } = await import('./route')
|
||||
|
||||
const response = await GET(req as any)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
expect(data).toHaveProperty('error', 'User not authenticated')
|
||||
})
|
||||
|
||||
it('should handle credential not found', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce('test-user-id')
|
||||
mockGetCredential.mockResolvedValueOnce(undefined)
|
||||
|
||||
const req = new Request(
|
||||
'http://localhost:3000/api/auth/oauth/token?credentialId=nonexistent-credential-id'
|
||||
)
|
||||
|
||||
const { GET } = await import('./route')
|
||||
|
||||
const response = await GET(req as any)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data).toHaveProperty('error', 'Credential not found')
|
||||
})
|
||||
|
||||
it('should handle missing access token', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce('test-user-id')
|
||||
mockGetCredential.mockResolvedValueOnce({
|
||||
id: 'credential-id',
|
||||
accessToken: null,
|
||||
refreshToken: 'refresh-token',
|
||||
providerId: 'google',
|
||||
})
|
||||
|
||||
const req = new Request(
|
||||
'http://localhost:3000/api/auth/oauth/token?credentialId=credential-id'
|
||||
)
|
||||
|
||||
const { GET } = await import('./route')
|
||||
|
||||
const response = await GET(req as any)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data).toHaveProperty('error', 'No access token available')
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle token refresh failure', async () => {
|
||||
mockGetUserId.mockResolvedValueOnce('test-user-id')
|
||||
mockGetCredential.mockResolvedValueOnce({
|
||||
id: 'credential-id',
|
||||
accessToken: 'test-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // Expired
|
||||
providerId: 'google',
|
||||
})
|
||||
mockRefreshTokenIfNeeded.mockRejectedValueOnce(new Error('Refresh failure'))
|
||||
|
||||
const req = new Request(
|
||||
'http://localhost:3000/api/auth/oauth/token?credentialId=credential-id'
|
||||
)
|
||||
|
||||
const { GET } = await import('./route')
|
||||
|
||||
const response = await GET(req as any)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
expect(data).toHaveProperty('error', 'Failed to refresh access token')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,10 +1,6 @@
|
||||
import { NextRequest, NextResponse } from 'next/server'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console-logger'
|
||||
import { refreshOAuthToken } from '@/lib/oauth'
|
||||
import { db } from '@/db'
|
||||
import { account, workflow } from '@/db/schema'
|
||||
import { getCredential, getUserId, refreshTokenIfNeeded } from '../utils'
|
||||
|
||||
const logger = createLogger('OAuthTokenAPI')
|
||||
|
||||
@@ -27,97 +23,29 @@ export async function POST(request: NextRequest) {
|
||||
}
|
||||
|
||||
// Determine the user ID based on the context
|
||||
let userId: string | undefined
|
||||
const userId = await getUserId(requestId, workflowId)
|
||||
|
||||
// If workflowId is provided, this is a server-side request
|
||||
if (workflowId) {
|
||||
// Get the workflow to verify the user ID
|
||||
const workflows = await db
|
||||
.select({ userId: workflow.userId })
|
||||
.from(workflow)
|
||||
.where(eq(workflow.id, workflowId))
|
||||
.limit(1)
|
||||
|
||||
if (!workflows.length) {
|
||||
logger.warn(`[${requestId}] Workflow not found`)
|
||||
return NextResponse.json({ error: 'Workflow not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
userId = workflows[0].userId
|
||||
} else {
|
||||
// This is a client-side request, use the session
|
||||
const session = await getSession()
|
||||
|
||||
// Check if the user is authenticated
|
||||
if (!session?.user?.id) {
|
||||
logger.warn(`[${requestId}] Unauthenticated token request rejected`)
|
||||
return NextResponse.json({ error: 'User not authenticated' }, { status: 401 })
|
||||
}
|
||||
|
||||
userId = session.user.id
|
||||
if (!userId) {
|
||||
return NextResponse.json(
|
||||
{ error: workflowId ? 'Workflow not found' : 'User not authenticated' },
|
||||
{ status: workflowId ? 404 : 401 }
|
||||
)
|
||||
}
|
||||
|
||||
// Get the credential from the database
|
||||
const credentials = await db
|
||||
.select()
|
||||
.from(account)
|
||||
.where(and(eq(account.id, credentialId), eq(account.userId, userId)))
|
||||
.limit(1)
|
||||
const credential = await getCredential(requestId, credentialId, userId)
|
||||
|
||||
if (!credentials.length) {
|
||||
logger.warn(`[${requestId}] Credential not found`)
|
||||
if (!credential) {
|
||||
return NextResponse.json({ error: 'Credential not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
const credential = credentials[0]
|
||||
|
||||
// Check if we need to refresh the token
|
||||
const expiresAt = credential.accessTokenExpiresAt
|
||||
const now = new Date()
|
||||
const needsRefresh = !expiresAt || expiresAt <= now
|
||||
|
||||
if (needsRefresh && credential.refreshToken) {
|
||||
try {
|
||||
const refreshResult = await refreshOAuthToken(
|
||||
credential.providerId,
|
||||
credential.refreshToken
|
||||
)
|
||||
|
||||
if (!refreshResult) {
|
||||
throw new Error('Failed to refresh token')
|
||||
}
|
||||
|
||||
const {
|
||||
accessToken: refreshedToken,
|
||||
expiresIn,
|
||||
refreshToken: newRefreshToken,
|
||||
} = refreshResult
|
||||
|
||||
// Prepare update data
|
||||
const updateData: any = {
|
||||
accessToken: refreshedToken,
|
||||
accessTokenExpiresAt: new Date(Date.now() + expiresIn * 1000), // Use provider's expiry
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
|
||||
// If we received a new refresh token, update it
|
||||
if (newRefreshToken && newRefreshToken !== credential.refreshToken) {
|
||||
logger.info(`[${requestId}] Updating refresh token for credential: ${credentialId}`)
|
||||
updateData.refreshToken = newRefreshToken
|
||||
}
|
||||
|
||||
await db.update(account).set(updateData).where(eq(account.id, credentialId))
|
||||
|
||||
logger.info(`[${requestId}] Successfully refreshed access token`)
|
||||
return NextResponse.json({ accessToken: refreshedToken }, { status: 200 })
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error refreshing token`, error)
|
||||
return NextResponse.json({ error: 'Failed to refresh access token' }, { status: 500 })
|
||||
}
|
||||
try {
|
||||
// Refresh the token if needed
|
||||
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
|
||||
return NextResponse.json({ accessToken }, { status: 200 })
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to refresh access token' }, { status: 401 })
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Access token is valid`)
|
||||
return NextResponse.json({ accessToken: credential.accessToken }, { status: 200 })
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error getting access token`, error)
|
||||
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
|
||||
@@ -131,15 +59,6 @@ export async function GET(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8) // Short request ID for correlation
|
||||
|
||||
try {
|
||||
// Get the session
|
||||
const session = await getSession()
|
||||
|
||||
// Check if the user is authenticated
|
||||
if (!session?.user?.id) {
|
||||
logger.warn(`[${requestId}] Unauthenticated request rejected`)
|
||||
return NextResponse.json({ error: 'User not authenticated' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Get the credential ID from the query params
|
||||
const { searchParams } = new URL(request.url)
|
||||
const credentialId = searchParams.get('credentialId')
|
||||
@@ -149,20 +68,18 @@ export async function GET(request: NextRequest) {
|
||||
return NextResponse.json({ error: 'Credential ID is required' }, { status: 400 })
|
||||
}
|
||||
|
||||
// Get the credential from the database
|
||||
const credentials = await db.select().from(account).where(eq(account.id, credentialId)).limit(1)
|
||||
// For GET requests, we only support session-based authentication
|
||||
const userId = await getUserId(requestId)
|
||||
|
||||
if (!credentials.length) {
|
||||
logger.warn(`[${requestId}] Credential not found`, { credentialId })
|
||||
return NextResponse.json({ error: 'Credential not found' }, { status: 404 })
|
||||
if (!userId) {
|
||||
return NextResponse.json({ error: 'User not authenticated' }, { status: 401 })
|
||||
}
|
||||
|
||||
const credential = credentials[0]
|
||||
// Get the credential from the database
|
||||
const credential = await getCredential(requestId, credentialId, userId)
|
||||
|
||||
// Check if the credential belongs to the user
|
||||
if (credential.userId !== session.user.id) {
|
||||
logger.warn(`[${requestId}] Unauthorized credential access attempt`)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 403 })
|
||||
if (!credential) {
|
||||
return NextResponse.json({ error: 'Credential not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
// Check if the access token is valid
|
||||
@@ -171,57 +88,13 @@ export async function GET(request: NextRequest) {
|
||||
return NextResponse.json({ error: 'No access token available' }, { status: 400 })
|
||||
}
|
||||
|
||||
// Check if the token is expired and refresh if needed
|
||||
const now = new Date()
|
||||
const tokenExpiry = credential.accessTokenExpiresAt
|
||||
let accessToken = credential.accessToken
|
||||
|
||||
if (tokenExpiry && tokenExpiry < now && credential.refreshToken) {
|
||||
logger.info(`[${requestId}] Access token expired, attempting to refresh`)
|
||||
|
||||
try {
|
||||
// Refresh the token using the centralized utility
|
||||
const refreshResult = await refreshOAuthToken(
|
||||
credential.providerId,
|
||||
credential.refreshToken
|
||||
)
|
||||
|
||||
if (!refreshResult) {
|
||||
throw new Error('Failed to refresh token')
|
||||
}
|
||||
|
||||
const {
|
||||
accessToken: refreshedToken,
|
||||
expiresIn,
|
||||
refreshToken: newRefreshToken,
|
||||
} = refreshResult
|
||||
logger.info(`[${requestId}] Token refreshed successfully`)
|
||||
|
||||
// Prepare update data
|
||||
const updateData: any = {
|
||||
accessToken: refreshedToken,
|
||||
accessTokenExpiresAt: new Date(Date.now() + expiresIn * 1000), // Use provider's expiry
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
|
||||
// If we received a new refresh token, update it
|
||||
if (newRefreshToken && newRefreshToken !== credential.refreshToken) {
|
||||
logger.info(`[${requestId}] Updating refresh token for credential: ${credentialId}`)
|
||||
updateData.refreshToken = newRefreshToken
|
||||
}
|
||||
|
||||
// Update the token in the database with the correct expiration time
|
||||
await db.update(account).set(updateData).where(eq(account.id, credentialId))
|
||||
|
||||
accessToken = refreshedToken
|
||||
} catch (refreshError) {
|
||||
logger.error(`[${requestId}] Error refreshing token`, refreshError)
|
||||
return NextResponse.json({ error: 'Failed to refresh access token' }, { status: 401 })
|
||||
}
|
||||
try {
|
||||
// Refresh the token if needed
|
||||
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
|
||||
return NextResponse.json({ accessToken }, { status: 200 })
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to refresh access token' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Return the access token
|
||||
return NextResponse.json({ accessToken }, { status: 200 })
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error fetching access token`, error)
|
||||
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
|
||||
|
||||
292
sim/app/api/auth/oauth/utils.test.ts
Normal file
292
sim/app/api/auth/oauth/utils.test.ts
Normal file
@@ -0,0 +1,292 @@
|
||||
/**
|
||||
* Tests for OAuth utility functions
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
describe('OAuth Utils', () => {
|
||||
const mockSession = { user: { id: 'test-user-id' } }
|
||||
const mockDb = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockReturnValue([]),
|
||||
update: vi.fn().mockReturnThis(),
|
||||
set: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
}
|
||||
const mockRefreshOAuthToken = vi.fn()
|
||||
const mockLogger = {
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
|
||||
vi.doMock('@/lib/auth', () => ({
|
||||
getSession: vi.fn().mockResolvedValue(mockSession),
|
||||
}))
|
||||
|
||||
vi.doMock('@/db', () => ({
|
||||
db: mockDb,
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/oauth', () => ({
|
||||
refreshOAuthToken: mockRefreshOAuthToken,
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/logs/console-logger', () => ({
|
||||
createLogger: vi.fn().mockReturnValue(mockLogger),
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getUserId', () => {
|
||||
it('should get user ID from session when no workflowId is provided', async () => {
|
||||
const { getUserId } = await import('./utils')
|
||||
|
||||
const userId = await getUserId('request-id')
|
||||
|
||||
expect(userId).toBe('test-user-id')
|
||||
})
|
||||
|
||||
it('should get user ID from workflow when workflowId is provided', async () => {
|
||||
mockDb.limit.mockReturnValueOnce([{ userId: 'workflow-owner-id' }])
|
||||
|
||||
const { getUserId } = await import('./utils')
|
||||
|
||||
const userId = await getUserId('request-id', 'workflow-id')
|
||||
|
||||
expect(mockDb.select).toHaveBeenCalled()
|
||||
expect(mockDb.from).toHaveBeenCalled()
|
||||
expect(mockDb.where).toHaveBeenCalled()
|
||||
expect(mockDb.limit).toHaveBeenCalledWith(1)
|
||||
expect(userId).toBe('workflow-owner-id')
|
||||
})
|
||||
|
||||
it('should return undefined if no session is found', async () => {
|
||||
vi.doMock('@/lib/auth', () => ({
|
||||
getSession: vi.fn().mockResolvedValue(null),
|
||||
}))
|
||||
|
||||
const { getUserId } = await import('./utils')
|
||||
|
||||
const userId = await getUserId('request-id')
|
||||
|
||||
expect(userId).toBeUndefined()
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return undefined if workflow is not found', async () => {
|
||||
mockDb.limit.mockReturnValueOnce([])
|
||||
|
||||
const { getUserId } = await import('./utils')
|
||||
|
||||
const userId = await getUserId('request-id', 'nonexistent-workflow-id')
|
||||
|
||||
expect(userId).toBeUndefined()
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getCredential', () => {
|
||||
it('should return credential when found', async () => {
|
||||
const mockCredential = { id: 'credential-id', userId: 'test-user-id' }
|
||||
mockDb.limit.mockReturnValueOnce([mockCredential])
|
||||
|
||||
const { getCredential } = await import('./utils')
|
||||
|
||||
const credential = await getCredential('request-id', 'credential-id', 'test-user-id')
|
||||
|
||||
expect(mockDb.select).toHaveBeenCalled()
|
||||
expect(mockDb.from).toHaveBeenCalled()
|
||||
expect(mockDb.where).toHaveBeenCalled()
|
||||
expect(mockDb.limit).toHaveBeenCalledWith(1)
|
||||
|
||||
expect(credential).toEqual(mockCredential)
|
||||
})
|
||||
|
||||
it('should return undefined when credential is not found', async () => {
|
||||
mockDb.limit.mockReturnValueOnce([])
|
||||
|
||||
const { getCredential } = await import('./utils')
|
||||
|
||||
const credential = await getCredential('request-id', 'nonexistent-id', 'test-user-id')
|
||||
|
||||
expect(credential).toBeUndefined()
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('refreshTokenIfNeeded', () => {
|
||||
it('should return valid token without refresh if not expired', async () => {
|
||||
const mockCredential = {
|
||||
id: 'credential-id',
|
||||
accessToken: 'valid-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000), // 1 hour in the future
|
||||
providerId: 'google',
|
||||
}
|
||||
|
||||
const { refreshTokenIfNeeded } = await import('./utils')
|
||||
|
||||
const result = await refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')
|
||||
|
||||
expect(mockRefreshOAuthToken).not.toHaveBeenCalled()
|
||||
expect(result).toEqual({ accessToken: 'valid-token', refreshed: false })
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('Access token is valid'))
|
||||
})
|
||||
|
||||
it('should refresh token when expired', async () => {
|
||||
const mockCredential = {
|
||||
id: 'credential-id',
|
||||
accessToken: 'expired-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
|
||||
providerId: 'google',
|
||||
}
|
||||
|
||||
mockRefreshOAuthToken.mockResolvedValueOnce({
|
||||
accessToken: 'new-token',
|
||||
expiresIn: 3600,
|
||||
refreshToken: 'new-refresh-token',
|
||||
})
|
||||
|
||||
const { refreshTokenIfNeeded } = await import('./utils')
|
||||
|
||||
const result = await refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')
|
||||
|
||||
expect(mockRefreshOAuthToken).toHaveBeenCalledWith('google', 'refresh-token')
|
||||
expect(mockDb.update).toHaveBeenCalled()
|
||||
expect(mockDb.set).toHaveBeenCalled()
|
||||
expect(result).toEqual({ accessToken: 'new-token', refreshed: true })
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Successfully refreshed')
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle refresh token error', async () => {
|
||||
const mockCredential = {
|
||||
id: 'credential-id',
|
||||
accessToken: 'expired-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
|
||||
providerId: 'google',
|
||||
}
|
||||
|
||||
mockRefreshOAuthToken.mockRejectedValueOnce(new Error('Refresh failed'))
|
||||
|
||||
const { refreshTokenIfNeeded } = await import('./utils')
|
||||
|
||||
await expect(
|
||||
refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')
|
||||
).rejects.toThrow()
|
||||
|
||||
expect(mockLogger.error).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not attempt refresh if no refresh token', async () => {
|
||||
const mockCredential = {
|
||||
id: 'credential-id',
|
||||
accessToken: 'token',
|
||||
refreshToken: null,
|
||||
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
|
||||
providerId: 'google',
|
||||
}
|
||||
|
||||
const { refreshTokenIfNeeded } = await import('./utils')
|
||||
|
||||
const result = await refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')
|
||||
|
||||
expect(mockRefreshOAuthToken).not.toHaveBeenCalled()
|
||||
expect(result).toEqual({ accessToken: 'token', refreshed: false })
|
||||
})
|
||||
})
|
||||
|
||||
describe('refreshAccessTokenIfNeeded', () => {
|
||||
it('should return valid access token without refresh if not expired', async () => {
|
||||
const mockCredential = {
|
||||
id: 'credential-id',
|
||||
accessToken: 'valid-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000), // 1 hour in the future
|
||||
providerId: 'google',
|
||||
userId: 'test-user-id',
|
||||
}
|
||||
mockDb.limit.mockReturnValueOnce([mockCredential])
|
||||
|
||||
const { refreshAccessTokenIfNeeded } = await import('./utils')
|
||||
|
||||
const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')
|
||||
|
||||
expect(mockRefreshOAuthToken).not.toHaveBeenCalled()
|
||||
expect(token).toBe('valid-token')
|
||||
})
|
||||
|
||||
it('should refresh token when expired', async () => {
|
||||
const mockCredential = {
|
||||
id: 'credential-id',
|
||||
accessToken: 'expired-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
|
||||
providerId: 'google',
|
||||
userId: 'test-user-id',
|
||||
}
|
||||
mockDb.limit.mockReturnValueOnce([mockCredential])
|
||||
|
||||
mockRefreshOAuthToken.mockResolvedValueOnce({
|
||||
accessToken: 'new-token',
|
||||
expiresIn: 3600,
|
||||
refreshToken: 'new-refresh-token',
|
||||
})
|
||||
|
||||
const { refreshAccessTokenIfNeeded } = await import('./utils')
|
||||
|
||||
const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')
|
||||
|
||||
expect(mockRefreshOAuthToken).toHaveBeenCalledWith('google', 'refresh-token')
|
||||
expect(mockDb.update).toHaveBeenCalled()
|
||||
expect(mockDb.set).toHaveBeenCalled()
|
||||
expect(token).toBe('new-token')
|
||||
})
|
||||
|
||||
it('should return null if credential not found', async () => {
|
||||
mockDb.limit.mockReturnValueOnce([])
|
||||
|
||||
const { refreshAccessTokenIfNeeded } = await import('./utils')
|
||||
|
||||
const token = await refreshAccessTokenIfNeeded('nonexistent-id', 'test-user-id', 'request-id')
|
||||
|
||||
expect(token).toBeNull()
|
||||
expect(mockLogger.warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return null if refresh fails', async () => {
|
||||
const mockCredential = {
|
||||
id: 'credential-id',
|
||||
accessToken: 'expired-token',
|
||||
refreshToken: 'refresh-token',
|
||||
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
|
||||
providerId: 'google',
|
||||
userId: 'test-user-id',
|
||||
}
|
||||
mockDb.limit.mockReturnValueOnce([mockCredential])
|
||||
|
||||
mockRefreshOAuthToken.mockResolvedValueOnce(null)
|
||||
|
||||
const { refreshAccessTokenIfNeeded } = await import('./utils')
|
||||
|
||||
const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')
|
||||
|
||||
expect(token).toBeNull()
|
||||
expect(mockLogger.error).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,11 +1,66 @@
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console-logger'
|
||||
import { refreshOAuthToken } from '@/lib/oauth'
|
||||
import { db } from '@/db'
|
||||
import { account } from '@/db/schema'
|
||||
import { account, workflow } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('OAuthUtils')
|
||||
|
||||
/**
|
||||
* Get the user ID based on either a session or a workflow ID
|
||||
*/
|
||||
export async function getUserId(
|
||||
requestId: string,
|
||||
workflowId?: string
|
||||
): Promise<string | undefined> {
|
||||
// If workflowId is provided, this is a server-side request
|
||||
if (workflowId) {
|
||||
// Get the workflow to verify the user ID
|
||||
const workflows = await db
|
||||
.select({ userId: workflow.userId })
|
||||
.from(workflow)
|
||||
.where(eq(workflow.id, workflowId))
|
||||
.limit(1)
|
||||
|
||||
if (!workflows.length) {
|
||||
logger.warn(`[${requestId}] Workflow not found`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
return workflows[0].userId
|
||||
} else {
|
||||
// This is a client-side request, use the session
|
||||
const session = await getSession()
|
||||
|
||||
// Check if the user is authenticated
|
||||
if (!session?.user?.id) {
|
||||
logger.warn(`[${requestId}] Unauthenticated request rejected`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
return session.user.id
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a credential by ID and verify it belongs to the user
|
||||
*/
|
||||
export async function getCredential(requestId: string, credentialId: string, userId: string) {
|
||||
const credentials = await db
|
||||
.select()
|
||||
.from(account)
|
||||
.where(and(eq(account.id, credentialId), eq(account.userId, userId)))
|
||||
.limit(1)
|
||||
|
||||
if (!credentials.length) {
|
||||
logger.warn(`[${requestId}] Credential not found`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
return credentials[0]
|
||||
}
|
||||
|
||||
export async function getOAuthToken(userId: string, providerId: string): Promise<string | null> {
|
||||
const connections = await db
|
||||
.select({
|
||||
@@ -94,28 +149,21 @@ export async function getOAuthToken(userId: string, providerId: string): Promise
|
||||
* Refreshes an OAuth token if needed based on credential information
|
||||
* @param credentialId The ID of the credential to check and potentially refresh
|
||||
* @param userId The user ID who owns the credential (for security verification)
|
||||
* @param requestId Optional request ID for log correlation
|
||||
* @param requestId Request ID for log correlation
|
||||
* @returns The valid access token or null if refresh fails
|
||||
*/
|
||||
export async function refreshAccessTokenIfNeeded(
|
||||
credentialId: string,
|
||||
userId: string,
|
||||
requestId?: string
|
||||
requestId: string
|
||||
): Promise<string | null> {
|
||||
// Get the credential from the database
|
||||
const credentials = await db
|
||||
.select()
|
||||
.from(account)
|
||||
.where(and(eq(account.id, credentialId), eq(account.userId, userId)))
|
||||
.limit(1)
|
||||
// Get the credential directly using the getCredential helper
|
||||
const credential = await getCredential(requestId, credentialId, userId)
|
||||
|
||||
if (!credentials.length) {
|
||||
logger.warn(`[${requestId || ''}] Credential not found: ${credentialId}`)
|
||||
if (!credential) {
|
||||
return null
|
||||
}
|
||||
|
||||
const credential = credentials[0]
|
||||
|
||||
// Check if we need to refresh the token
|
||||
const expiresAt = credential.accessTokenExpiresAt
|
||||
const now = new Date()
|
||||
@@ -124,22 +172,17 @@ export async function refreshAccessTokenIfNeeded(
|
||||
let accessToken = credential.accessToken
|
||||
|
||||
if (needsRefresh && credential.refreshToken) {
|
||||
logger.info(
|
||||
`[${requestId || ''}] Token expired, attempting to refresh for credential: ${credentialId}`
|
||||
)
|
||||
logger.info(`[${requestId}] Token expired, attempting to refresh for credential`)
|
||||
try {
|
||||
const refreshedToken = await refreshOAuthToken(credential.providerId, credential.refreshToken)
|
||||
|
||||
if (!refreshedToken) {
|
||||
logger.error(
|
||||
`[${requestId || ''}] Failed to refresh token for credential: ${credentialId}`,
|
||||
{
|
||||
credentialId,
|
||||
providerId: credential.providerId,
|
||||
userId: credential.userId,
|
||||
hasRefreshToken: !!credential.refreshToken,
|
||||
}
|
||||
)
|
||||
logger.error(`[${requestId}] Failed to refresh token for credential: ${credentialId}`, {
|
||||
credentialId,
|
||||
providerId: credential.providerId,
|
||||
userId: credential.userId,
|
||||
hasRefreshToken: !!credential.refreshToken,
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
@@ -152,19 +195,17 @@ export async function refreshAccessTokenIfNeeded(
|
||||
|
||||
// If we received a new refresh token, update it
|
||||
if (refreshedToken.refreshToken && refreshedToken.refreshToken !== credential.refreshToken) {
|
||||
logger.info(`[${requestId || ''}] Updating refresh token for credential: ${credentialId}`)
|
||||
logger.info(`[${requestId}] Updating refresh token for credential`)
|
||||
updateData.refreshToken = refreshedToken.refreshToken
|
||||
}
|
||||
|
||||
// Update the token in the database
|
||||
await db.update(account).set(updateData).where(eq(account.id, credentialId))
|
||||
|
||||
logger.info(
|
||||
`[${requestId || ''}] Successfully refreshed access token for credential: ${credentialId}`
|
||||
)
|
||||
logger.info(`[${requestId}] Successfully refreshed access token for credential`)
|
||||
return refreshedToken.accessToken
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId || ''}] Error refreshing token for credential: ${credentialId}`, {
|
||||
logger.error(`[${requestId}] Error refreshing token for credential`, {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
stack: error instanceof Error ? error.stack : undefined,
|
||||
providerId: credential.providerId,
|
||||
@@ -174,9 +215,62 @@ export async function refreshAccessTokenIfNeeded(
|
||||
return null
|
||||
}
|
||||
} else if (!accessToken) {
|
||||
logger.error(`[${requestId || ''}] Missing access token for credential: ${credential.id}`)
|
||||
logger.error(`[${requestId}] Missing access token for credential`)
|
||||
return null
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Access token is valid for credential`)
|
||||
return accessToken
|
||||
}
|
||||
|
||||
/**
|
||||
* Enhanced version that returns additional information about the refresh operation
|
||||
*/
|
||||
export async function refreshTokenIfNeeded(
|
||||
requestId: string,
|
||||
credential: any,
|
||||
credentialId: string
|
||||
): Promise<{ accessToken: string; refreshed: boolean }> {
|
||||
// Check if we need to refresh the token
|
||||
const expiresAt = credential.accessTokenExpiresAt
|
||||
const now = new Date()
|
||||
const needsRefresh = !expiresAt || expiresAt <= now
|
||||
|
||||
// If token is still valid, return it directly
|
||||
if (!needsRefresh || !credential.refreshToken) {
|
||||
logger.info(`[${requestId}] Access token is valid`)
|
||||
return { accessToken: credential.accessToken, refreshed: false }
|
||||
}
|
||||
|
||||
try {
|
||||
const refreshResult = await refreshOAuthToken(credential.providerId, credential.refreshToken)
|
||||
|
||||
if (!refreshResult) {
|
||||
logger.error(`[${requestId}] Failed to refresh token for credential`)
|
||||
throw new Error('Failed to refresh token')
|
||||
}
|
||||
|
||||
const { accessToken: refreshedToken, expiresIn, refreshToken: newRefreshToken } = refreshResult
|
||||
|
||||
// Prepare update data
|
||||
const updateData: any = {
|
||||
accessToken: refreshedToken,
|
||||
accessTokenExpiresAt: new Date(Date.now() + expiresIn * 1000), // Use provider's expiry
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
|
||||
// If we received a new refresh token, update it
|
||||
if (newRefreshToken && newRefreshToken !== credential.refreshToken) {
|
||||
logger.info(`[${requestId}] Updating refresh token`)
|
||||
updateData.refreshToken = newRefreshToken
|
||||
}
|
||||
|
||||
await db.update(account).set(updateData).where(eq(account.id, credentialId))
|
||||
|
||||
logger.info(`[${requestId}] Successfully refreshed access token`)
|
||||
return { accessToken: refreshedToken, refreshed: true }
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error refreshing token`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
import { NextRequest, NextResponse } from 'next/server'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { db } from '@/db'
|
||||
import { session, user } from '@/db/schema'
|
||||
|
||||
export async function GET(request: NextRequest) {
|
||||
try {
|
||||
const token = request.nextUrl.searchParams.get('token')
|
||||
|
||||
if (!token) {
|
||||
return NextResponse.json({ error: 'Token is required' }, { status: 400 })
|
||||
}
|
||||
|
||||
// Get session by token
|
||||
const sessionRecord = await db
|
||||
.select()
|
||||
.from(session)
|
||||
.where(eq(session.id, token))
|
||||
.limit(1)
|
||||
.then((rows) => rows[0])
|
||||
|
||||
if (!sessionRecord) {
|
||||
return NextResponse.json({ error: 'Invalid session' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Get user from session
|
||||
const userRecord = await db
|
||||
.select()
|
||||
.from(user)
|
||||
.where(eq(user.id, sessionRecord.userId))
|
||||
.limit(1)
|
||||
.then((rows) => rows[0])
|
||||
|
||||
if (!userRecord) {
|
||||
return NextResponse.json({ error: 'User not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
// Return minimal user info (only what's needed)
|
||||
return NextResponse.json({
|
||||
user: {
|
||||
id: userRecord.id,
|
||||
email: userRecord.email,
|
||||
name: userRecord.name,
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('Session API error:', error)
|
||||
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
@@ -23,14 +23,12 @@ export async function POST(request: NextRequest) {
|
||||
maxTokens,
|
||||
apiKey,
|
||||
responseFormat,
|
||||
workflowId,
|
||||
} = body
|
||||
|
||||
logger.info(`Provider request received for ${provider} model: ${model}`)
|
||||
|
||||
let finalApiKey: string
|
||||
try {
|
||||
finalApiKey = getApiKey(provider, model, apiKey)
|
||||
logger.info(`API key obtained for ${provider} ${model}`)
|
||||
} catch (error) {
|
||||
logger.error('Failed to get API key:', error)
|
||||
return NextResponse.json(
|
||||
@@ -49,6 +47,7 @@ export async function POST(request: NextRequest) {
|
||||
maxTokens,
|
||||
apiKey: finalApiKey,
|
||||
responseFormat,
|
||||
workflowId,
|
||||
})
|
||||
|
||||
return NextResponse.json(response)
|
||||
|
||||
@@ -156,7 +156,7 @@ export const XBlock: BlockConfig<XResponse> = {
|
||||
|
||||
// Convert string values to appropriate types
|
||||
const parsedParams: Record<string, any> = {
|
||||
accessToken: credential,
|
||||
credential: credential,
|
||||
}
|
||||
|
||||
// Add other params
|
||||
|
||||
@@ -150,6 +150,7 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
maxTokens: inputs.maxTokens,
|
||||
apiKey: inputs.apiKey,
|
||||
responseFormat,
|
||||
workflowId: context.workflowId,
|
||||
}
|
||||
|
||||
logger.info(`Provider request prepared`, {
|
||||
@@ -158,6 +159,7 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
hasContext: !!providerRequest.context,
|
||||
hasTools: !!providerRequest.tools,
|
||||
hasApiKey: !!providerRequest.apiKey,
|
||||
workflowId: providerRequest.workflowId,
|
||||
})
|
||||
|
||||
// Get the app URL from environment variable or use default
|
||||
|
||||
@@ -367,6 +367,7 @@ export const auth = betterAuth({
|
||||
pkce: true,
|
||||
responseType: 'code',
|
||||
prompt: 'consent',
|
||||
authentication: 'basic',
|
||||
redirectURI: `${process.env.NEXT_PUBLIC_APP_URL}/api/auth/oauth2/callback/x`,
|
||||
getUserInfo: async (tokens) => {
|
||||
try {
|
||||
|
||||
@@ -353,6 +353,7 @@ export async function refreshOAuthToken(
|
||||
tokenEndpoint = 'https://api.x.com/2/oauth2/token'
|
||||
clientId = process.env.X_CLIENT_ID
|
||||
clientSecret = process.env.X_CLIENT_SECRET
|
||||
useBasicAuth = true
|
||||
break
|
||||
case 'confluence':
|
||||
tokenEndpoint = 'https://auth.atlassian.com/oauth/token'
|
||||
@@ -407,6 +408,15 @@ export async function refreshOAuthToken(
|
||||
} else {
|
||||
throw new Error('Both client ID and client secret are required for Airtable OAuth')
|
||||
}
|
||||
} else if (provider === 'x') {
|
||||
// Handle X differently
|
||||
// Confidential client - use Basic Auth
|
||||
const authString = `${clientId}:${clientSecret}`
|
||||
const basicAuth = Buffer.from(authString).toString('base64')
|
||||
headers['Authorization'] = `Basic ${basicAuth}`
|
||||
|
||||
// When using Basic Auth, don't include client_id in body
|
||||
delete bodyParams.client_id
|
||||
} else {
|
||||
// For other providers, use the general approach
|
||||
if (useBasicAuth) {
|
||||
@@ -429,15 +439,20 @@ export async function refreshOAuthToken(
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
let errorData = errorText
|
||||
|
||||
// Try to parse the error as JSON for better diagnostics
|
||||
try {
|
||||
errorData = JSON.parse(errorText)
|
||||
} catch (e) {
|
||||
// Not JSON, keep as text
|
||||
}
|
||||
|
||||
logger.error('Token refresh failed:', {
|
||||
status: response.status,
|
||||
error: errorText,
|
||||
parsedError: errorData,
|
||||
provider,
|
||||
headers: JSON.stringify(headers, null, 2).replace(
|
||||
/"Authorization":"[^"]*"/,
|
||||
'"Authorization":"[REDACTED]"'
|
||||
),
|
||||
bodyParams: JSON.stringify(bodyParams),
|
||||
})
|
||||
throw new Error(`Failed to refresh token: ${response.status} ${errorText}`)
|
||||
}
|
||||
|
||||
@@ -254,7 +254,11 @@ ${fieldDescriptions}
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
const mergedArgs = { ...tool.params, ...toolArgs }
|
||||
const mergedArgs = {
|
||||
...tool.params,
|
||||
...toolArgs,
|
||||
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
|
||||
}
|
||||
const result = await executeTool(toolName, mergedArgs)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
@@ -170,7 +170,11 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
const mergedArgs = { ...tool.params, ...toolArgs }
|
||||
const mergedArgs = {
|
||||
...tool.params,
|
||||
...toolArgs,
|
||||
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
|
||||
}
|
||||
const result = await executeTool(toolName, mergedArgs)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
@@ -144,7 +144,11 @@ export const deepseekProvider: ProviderConfig = {
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
const mergedArgs = { ...tool.params, ...toolArgs }
|
||||
const mergedArgs = {
|
||||
...tool.params,
|
||||
...toolArgs,
|
||||
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
|
||||
}
|
||||
const result = await executeTool(toolName, mergedArgs)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
@@ -214,7 +214,11 @@ export const googleProvider: ProviderConfig = {
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
const mergedArgs = { ...tool.params, ...toolArgs }
|
||||
const mergedArgs = {
|
||||
...tool.params,
|
||||
...toolArgs,
|
||||
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
|
||||
}
|
||||
const result = await executeTool(toolName, mergedArgs)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
@@ -145,7 +145,11 @@ export const groqProvider: ProviderConfig = {
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
const mergedArgs = { ...tool.params, ...toolArgs }
|
||||
const mergedArgs = {
|
||||
...tool.params,
|
||||
...toolArgs,
|
||||
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
|
||||
}
|
||||
const result = await executeTool(toolName, mergedArgs)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
@@ -163,7 +163,11 @@ export const ollamaProvider: ProviderConfig = {
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
const mergedArgs = { ...tool.params, ...toolArgs }
|
||||
const mergedArgs = {
|
||||
...tool.params,
|
||||
...toolArgs,
|
||||
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
|
||||
}
|
||||
const result = await executeTool(toolName, mergedArgs)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
@@ -160,7 +160,11 @@ export const openaiProvider: ProviderConfig = {
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
const mergedArgs = { ...tool.params, ...toolArgs }
|
||||
const mergedArgs = {
|
||||
...tool.params,
|
||||
...toolArgs,
|
||||
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
|
||||
}
|
||||
const result = await executeTool(toolName, mergedArgs)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
@@ -138,6 +138,7 @@ export interface ProviderRequest {
|
||||
strict?: boolean
|
||||
}
|
||||
local_execution?: boolean
|
||||
workflowId?: string // Optional workflow ID for authentication context
|
||||
}
|
||||
|
||||
// Map of provider IDs to their configurations
|
||||
|
||||
@@ -143,7 +143,11 @@ export const xAIProvider: ProviderConfig = {
|
||||
if (!tool) continue
|
||||
|
||||
const toolCallStartTime = Date.now()
|
||||
const mergedArgs = { ...tool.params, ...toolArgs }
|
||||
const mergedArgs = {
|
||||
...tool.params,
|
||||
...toolArgs,
|
||||
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
|
||||
}
|
||||
const result = await executeTool(toolName, mergedArgs)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
@@ -332,9 +332,11 @@ export async function executeTool(
|
||||
|
||||
try {
|
||||
const tool = getTool(toolId)
|
||||
// Ensure context is preserved if it exists
|
||||
const contextParams = { ...params }
|
||||
|
||||
// Validate the tool and its parameters
|
||||
validateToolRequest(toolId, tool, params)
|
||||
validateToolRequest(toolId, tool, contextParams)
|
||||
|
||||
// After validation, we know tool exists
|
||||
if (!tool) {
|
||||
@@ -344,7 +346,7 @@ export async function executeTool(
|
||||
// For any tool with direct execution capability, try it first
|
||||
if (tool.directExecution) {
|
||||
try {
|
||||
const directResult = await tool.directExecution(params)
|
||||
const directResult = await tool.directExecution(contextParams)
|
||||
if (directResult) {
|
||||
// Add timing data to the result
|
||||
const endTime = new Date()
|
||||
@@ -354,7 +356,11 @@ export async function executeTool(
|
||||
// Apply post-processing if available and not skipped
|
||||
if (tool.postProcess && directResult.success && !skipPostProcess) {
|
||||
try {
|
||||
const postProcessResult = await tool.postProcess(directResult, params, executeTool)
|
||||
const postProcessResult = await tool.postProcess(
|
||||
directResult,
|
||||
contextParams,
|
||||
executeTool
|
||||
)
|
||||
return {
|
||||
...postProcessResult,
|
||||
timing: {
|
||||
@@ -394,12 +400,12 @@ export async function executeTool(
|
||||
|
||||
// For internal routes or when skipProxy is true, call the API directly
|
||||
if (tool.request.isInternalRoute || skipProxy) {
|
||||
const result = await handleInternalRequest(toolId, tool, params)
|
||||
const result = await handleInternalRequest(toolId, tool, contextParams)
|
||||
|
||||
// Apply post-processing if available and not skipped
|
||||
if (tool.postProcess && result.success && !skipPostProcess) {
|
||||
try {
|
||||
const postProcessResult = await tool.postProcess(result, params, executeTool)
|
||||
const postProcessResult = await tool.postProcess(result, contextParams, executeTool)
|
||||
|
||||
// Add timing data to the post-processed result
|
||||
const endTime = new Date()
|
||||
@@ -446,12 +452,12 @@ export async function executeTool(
|
||||
}
|
||||
|
||||
// For external APIs, use the proxy
|
||||
const result = await handleProxyRequest(toolId, params)
|
||||
const result = await handleProxyRequest(toolId, contextParams)
|
||||
|
||||
// Apply post-processing if available and not skipped
|
||||
if (tool.postProcess && result.success && !skipPostProcess) {
|
||||
try {
|
||||
const postProcessResult = await tool.postProcess(result, params, executeTool)
|
||||
const postProcessResult = await tool.postProcess(result, contextParams, executeTool)
|
||||
|
||||
// Add timing data to the post-processed result
|
||||
const endTime = new Date()
|
||||
|
||||
@@ -63,7 +63,11 @@ export const searchTool: ToolConfig<XSearchParams, XSearchResponse> = {
|
||||
'user.fields': 'name,username,description,profile_image_url,verified,public_metrics',
|
||||
})
|
||||
|
||||
if (params.maxResults) queryParams.append('max_results', params.maxResults.toString())
|
||||
if (params.maxResults && params.maxResults < 10) {
|
||||
queryParams.append('max_results', '10')
|
||||
} else if (params.maxResults) {
|
||||
queryParams.append('max_results', params.maxResults.toString())
|
||||
}
|
||||
if (params.startTime) queryParams.append('start_time', params.startTime)
|
||||
if (params.endTime) queryParams.append('end_time', params.endTime)
|
||||
if (params.sortOrder) queryParams.append('sort_order', params.sortOrder)
|
||||
@@ -127,12 +131,15 @@ export const searchTool: ToolConfig<XSearchParams, XSearchResponse> = {
|
||||
},
|
||||
|
||||
transformError: (error) => {
|
||||
// Log the full error object for debugging
|
||||
console.error('X Search API Error:', JSON.stringify(error, null, 2))
|
||||
|
||||
if (error.title === 'Unauthorized') {
|
||||
return 'Invalid or expired access token. Please reconnect your X account.'
|
||||
}
|
||||
if (error.title === 'Invalid Request') {
|
||||
return 'Invalid search query. Please check your search parameters.'
|
||||
}
|
||||
return error.detail || 'An error occurred while searching X'
|
||||
return error.detail || `An error occurred while searching X`
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user