fix(oauth): fixed oauth server-side vs client-side checking, fixed x_search tool failing, remove unused route (#251)

* fix(oauth): fixed oauth server-side vs client-side checking, fixed x_search tool failing, remove unused route

* added tests

* removed unnecessary checks for nil request id

* acknowledged PR comments, fixed X OAuth token refresh
This commit is contained in:
Waleed Latif
2025-04-11 17:57:34 -07:00
committed by GitHub
parent 0317691a43
commit bace01fe56
21 changed files with 872 additions and 263 deletions

View File

@@ -0,0 +1,337 @@
/**
* Tests for OAuth token API routes
*
* @vitest-environment node
*/
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockRequest } from '@/app/api/__test-utils__/utils'
describe('OAuth Token API Routes', () => {
const mockGetUserId = vi.fn()
const mockGetCredential = vi.fn()
const mockRefreshTokenIfNeeded = vi.fn()
const mockLogger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
}
const mockUUID = 'mock-uuid-12345678-90ab-cdef-1234-567890abcdef'
const mockRequestId = mockUUID.slice(0, 8)
beforeEach(() => {
vi.resetModules()
vi.stubGlobal('crypto', {
randomUUID: vi.fn().mockReturnValue(mockUUID),
})
vi.doMock('../utils', () => ({
getUserId: mockGetUserId,
getCredential: mockGetCredential,
refreshTokenIfNeeded: mockRefreshTokenIfNeeded,
}))
vi.doMock('@/lib/logs/console-logger', () => ({
createLogger: vi.fn().mockReturnValue(mockLogger),
}))
})
afterEach(() => {
vi.clearAllMocks()
})
/**
* POST route tests
*/
describe('POST handler', () => {
it('should return access token successfully', async () => {
mockGetUserId.mockResolvedValueOnce('test-user-id')
mockGetCredential.mockResolvedValueOnce({
id: 'credential-id',
accessToken: 'test-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000),
providerId: 'google',
})
mockRefreshTokenIfNeeded.mockResolvedValueOnce({
accessToken: 'fresh-token',
refreshed: false,
})
// Create mock request
const req = createMockRequest('POST', {
credentialId: 'credential-id',
})
// Import handler after setting up mocks
const { POST } = await import('./route')
// Call handler
const response = await POST(req)
const data = await response.json()
// Verify request was handled correctly
expect(response.status).toBe(200)
expect(data).toHaveProperty('accessToken', 'fresh-token')
// Verify mocks were called correctly
expect(mockGetUserId).toHaveBeenCalledWith(mockRequestId, undefined)
expect(mockGetCredential).toHaveBeenCalledWith(mockRequestId, 'credential-id', 'test-user-id')
expect(mockRefreshTokenIfNeeded).toHaveBeenCalled()
})
it('should handle workflowId for server-side authentication', async () => {
mockGetUserId.mockResolvedValueOnce('workflow-owner-id')
mockGetCredential.mockResolvedValueOnce({
id: 'credential-id',
accessToken: 'test-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000),
providerId: 'google',
})
mockRefreshTokenIfNeeded.mockResolvedValueOnce({
accessToken: 'fresh-token',
refreshed: false,
})
const req = createMockRequest('POST', {
credentialId: 'credential-id',
workflowId: 'workflow-id',
})
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
expect(response.status).toBe(200)
expect(data).toHaveProperty('accessToken', 'fresh-token')
expect(mockGetUserId).toHaveBeenCalledWith(mockRequestId, 'workflow-id')
expect(mockGetCredential).toHaveBeenCalledWith(
mockRequestId,
'credential-id',
'workflow-owner-id'
)
})
it('should handle missing credentialId', async () => {
const req = createMockRequest('POST', {})
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
expect(response.status).toBe(400)
expect(data).toHaveProperty('error', 'Credential ID is required')
expect(mockLogger.warn).toHaveBeenCalled()
})
it('should handle authentication failure', async () => {
mockGetUserId.mockResolvedValueOnce(undefined)
const req = createMockRequest('POST', {
credentialId: 'credential-id',
})
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
expect(response.status).toBe(401)
expect(data).toHaveProperty('error', 'User not authenticated')
})
it('should handle workflow not found', async () => {
mockGetUserId.mockResolvedValueOnce(undefined)
const req = createMockRequest('POST', {
credentialId: 'credential-id',
workflowId: 'nonexistent-workflow-id',
})
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
expect(response.status).toBe(404)
expect(data).toHaveProperty('error', 'Workflow not found')
})
it('should handle credential not found', async () => {
mockGetUserId.mockResolvedValueOnce('test-user-id')
mockGetCredential.mockResolvedValueOnce(undefined)
const req = createMockRequest('POST', {
credentialId: 'nonexistent-credential-id',
})
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
expect(response.status).toBe(404)
expect(data).toHaveProperty('error', 'Credential not found')
})
it('should handle token refresh failure', async () => {
mockGetUserId.mockResolvedValueOnce('test-user-id')
mockGetCredential.mockResolvedValueOnce({
id: 'credential-id',
accessToken: 'test-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // Expired
providerId: 'google',
})
mockRefreshTokenIfNeeded.mockRejectedValueOnce(new Error('Refresh failure'))
const req = createMockRequest('POST', {
credentialId: 'credential-id',
})
const { POST } = await import('./route')
const response = await POST(req)
const data = await response.json()
expect(response.status).toBe(401)
expect(data).toHaveProperty('error', 'Failed to refresh access token')
})
})
/**
* GET route tests
*/
describe('GET handler', () => {
it('should return access token successfully', async () => {
mockGetUserId.mockResolvedValueOnce('test-user-id')
mockGetCredential.mockResolvedValueOnce({
id: 'credential-id',
accessToken: 'test-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000),
providerId: 'google',
})
mockRefreshTokenIfNeeded.mockResolvedValueOnce({
accessToken: 'fresh-token',
refreshed: false,
})
const req = new Request(
'http://localhost:3000/api/auth/oauth/token?credentialId=credential-id'
)
const { GET } = await import('./route')
const response = await GET(req as any)
const data = await response.json()
expect(response.status).toBe(200)
expect(data).toHaveProperty('accessToken', 'fresh-token')
expect(mockGetUserId).toHaveBeenCalledWith(mockRequestId)
expect(mockGetCredential).toHaveBeenCalledWith(mockRequestId, 'credential-id', 'test-user-id')
expect(mockRefreshTokenIfNeeded).toHaveBeenCalled()
})
it('should handle missing credentialId', async () => {
const req = new Request('http://localhost:3000/api/auth/oauth/token')
const { GET } = await import('./route')
const response = await GET(req as any)
const data = await response.json()
expect(response.status).toBe(400)
expect(data).toHaveProperty('error', 'Credential ID is required')
expect(mockLogger.warn).toHaveBeenCalled()
})
it('should handle authentication failure', async () => {
mockGetUserId.mockResolvedValueOnce(undefined)
const req = new Request(
'http://localhost:3000/api/auth/oauth/token?credentialId=credential-id'
)
const { GET } = await import('./route')
const response = await GET(req as any)
const data = await response.json()
expect(response.status).toBe(401)
expect(data).toHaveProperty('error', 'User not authenticated')
})
it('should handle credential not found', async () => {
mockGetUserId.mockResolvedValueOnce('test-user-id')
mockGetCredential.mockResolvedValueOnce(undefined)
const req = new Request(
'http://localhost:3000/api/auth/oauth/token?credentialId=nonexistent-credential-id'
)
const { GET } = await import('./route')
const response = await GET(req as any)
const data = await response.json()
expect(response.status).toBe(404)
expect(data).toHaveProperty('error', 'Credential not found')
})
it('should handle missing access token', async () => {
mockGetUserId.mockResolvedValueOnce('test-user-id')
mockGetCredential.mockResolvedValueOnce({
id: 'credential-id',
accessToken: null,
refreshToken: 'refresh-token',
providerId: 'google',
})
const req = new Request(
'http://localhost:3000/api/auth/oauth/token?credentialId=credential-id'
)
const { GET } = await import('./route')
const response = await GET(req as any)
const data = await response.json()
expect(response.status).toBe(400)
expect(data).toHaveProperty('error', 'No access token available')
expect(mockLogger.warn).toHaveBeenCalled()
})
it('should handle token refresh failure', async () => {
mockGetUserId.mockResolvedValueOnce('test-user-id')
mockGetCredential.mockResolvedValueOnce({
id: 'credential-id',
accessToken: 'test-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // Expired
providerId: 'google',
})
mockRefreshTokenIfNeeded.mockRejectedValueOnce(new Error('Refresh failure'))
const req = new Request(
'http://localhost:3000/api/auth/oauth/token?credentialId=credential-id'
)
const { GET } = await import('./route')
const response = await GET(req as any)
const data = await response.json()
expect(response.status).toBe(401)
expect(data).toHaveProperty('error', 'Failed to refresh access token')
})
})
})

View File

@@ -1,10 +1,6 @@
import { NextRequest, NextResponse } from 'next/server'
import { and, eq } from 'drizzle-orm'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console-logger'
import { refreshOAuthToken } from '@/lib/oauth'
import { db } from '@/db'
import { account, workflow } from '@/db/schema'
import { getCredential, getUserId, refreshTokenIfNeeded } from '../utils'
const logger = createLogger('OAuthTokenAPI')
@@ -27,97 +23,29 @@ export async function POST(request: NextRequest) {
}
// Determine the user ID based on the context
let userId: string | undefined
const userId = await getUserId(requestId, workflowId)
// If workflowId is provided, this is a server-side request
if (workflowId) {
// Get the workflow to verify the user ID
const workflows = await db
.select({ userId: workflow.userId })
.from(workflow)
.where(eq(workflow.id, workflowId))
.limit(1)
if (!workflows.length) {
logger.warn(`[${requestId}] Workflow not found`)
return NextResponse.json({ error: 'Workflow not found' }, { status: 404 })
}
userId = workflows[0].userId
} else {
// This is a client-side request, use the session
const session = await getSession()
// Check if the user is authenticated
if (!session?.user?.id) {
logger.warn(`[${requestId}] Unauthenticated token request rejected`)
return NextResponse.json({ error: 'User not authenticated' }, { status: 401 })
}
userId = session.user.id
if (!userId) {
return NextResponse.json(
{ error: workflowId ? 'Workflow not found' : 'User not authenticated' },
{ status: workflowId ? 404 : 401 }
)
}
// Get the credential from the database
const credentials = await db
.select()
.from(account)
.where(and(eq(account.id, credentialId), eq(account.userId, userId)))
.limit(1)
const credential = await getCredential(requestId, credentialId, userId)
if (!credentials.length) {
logger.warn(`[${requestId}] Credential not found`)
if (!credential) {
return NextResponse.json({ error: 'Credential not found' }, { status: 404 })
}
const credential = credentials[0]
// Check if we need to refresh the token
const expiresAt = credential.accessTokenExpiresAt
const now = new Date()
const needsRefresh = !expiresAt || expiresAt <= now
if (needsRefresh && credential.refreshToken) {
try {
const refreshResult = await refreshOAuthToken(
credential.providerId,
credential.refreshToken
)
if (!refreshResult) {
throw new Error('Failed to refresh token')
}
const {
accessToken: refreshedToken,
expiresIn,
refreshToken: newRefreshToken,
} = refreshResult
// Prepare update data
const updateData: any = {
accessToken: refreshedToken,
accessTokenExpiresAt: new Date(Date.now() + expiresIn * 1000), // Use provider's expiry
updatedAt: new Date(),
}
// If we received a new refresh token, update it
if (newRefreshToken && newRefreshToken !== credential.refreshToken) {
logger.info(`[${requestId}] Updating refresh token for credential: ${credentialId}`)
updateData.refreshToken = newRefreshToken
}
await db.update(account).set(updateData).where(eq(account.id, credentialId))
logger.info(`[${requestId}] Successfully refreshed access token`)
return NextResponse.json({ accessToken: refreshedToken }, { status: 200 })
} catch (error) {
logger.error(`[${requestId}] Error refreshing token`, error)
return NextResponse.json({ error: 'Failed to refresh access token' }, { status: 500 })
}
try {
// Refresh the token if needed
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
return NextResponse.json({ accessToken }, { status: 200 })
} catch (error) {
return NextResponse.json({ error: 'Failed to refresh access token' }, { status: 401 })
}
logger.info(`[${requestId}] Access token is valid`)
return NextResponse.json({ accessToken: credential.accessToken }, { status: 200 })
} catch (error) {
logger.error(`[${requestId}] Error getting access token`, error)
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
@@ -131,15 +59,6 @@ export async function GET(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8) // Short request ID for correlation
try {
// Get the session
const session = await getSession()
// Check if the user is authenticated
if (!session?.user?.id) {
logger.warn(`[${requestId}] Unauthenticated request rejected`)
return NextResponse.json({ error: 'User not authenticated' }, { status: 401 })
}
// Get the credential ID from the query params
const { searchParams } = new URL(request.url)
const credentialId = searchParams.get('credentialId')
@@ -149,20 +68,18 @@ export async function GET(request: NextRequest) {
return NextResponse.json({ error: 'Credential ID is required' }, { status: 400 })
}
// Get the credential from the database
const credentials = await db.select().from(account).where(eq(account.id, credentialId)).limit(1)
// For GET requests, we only support session-based authentication
const userId = await getUserId(requestId)
if (!credentials.length) {
logger.warn(`[${requestId}] Credential not found`, { credentialId })
return NextResponse.json({ error: 'Credential not found' }, { status: 404 })
if (!userId) {
return NextResponse.json({ error: 'User not authenticated' }, { status: 401 })
}
const credential = credentials[0]
// Get the credential from the database
const credential = await getCredential(requestId, credentialId, userId)
// Check if the credential belongs to the user
if (credential.userId !== session.user.id) {
logger.warn(`[${requestId}] Unauthorized credential access attempt`)
return NextResponse.json({ error: 'Unauthorized' }, { status: 403 })
if (!credential) {
return NextResponse.json({ error: 'Credential not found' }, { status: 404 })
}
// Check if the access token is valid
@@ -171,57 +88,13 @@ export async function GET(request: NextRequest) {
return NextResponse.json({ error: 'No access token available' }, { status: 400 })
}
// Check if the token is expired and refresh if needed
const now = new Date()
const tokenExpiry = credential.accessTokenExpiresAt
let accessToken = credential.accessToken
if (tokenExpiry && tokenExpiry < now && credential.refreshToken) {
logger.info(`[${requestId}] Access token expired, attempting to refresh`)
try {
// Refresh the token using the centralized utility
const refreshResult = await refreshOAuthToken(
credential.providerId,
credential.refreshToken
)
if (!refreshResult) {
throw new Error('Failed to refresh token')
}
const {
accessToken: refreshedToken,
expiresIn,
refreshToken: newRefreshToken,
} = refreshResult
logger.info(`[${requestId}] Token refreshed successfully`)
// Prepare update data
const updateData: any = {
accessToken: refreshedToken,
accessTokenExpiresAt: new Date(Date.now() + expiresIn * 1000), // Use provider's expiry
updatedAt: new Date(),
}
// If we received a new refresh token, update it
if (newRefreshToken && newRefreshToken !== credential.refreshToken) {
logger.info(`[${requestId}] Updating refresh token for credential: ${credentialId}`)
updateData.refreshToken = newRefreshToken
}
// Update the token in the database with the correct expiration time
await db.update(account).set(updateData).where(eq(account.id, credentialId))
accessToken = refreshedToken
} catch (refreshError) {
logger.error(`[${requestId}] Error refreshing token`, refreshError)
return NextResponse.json({ error: 'Failed to refresh access token' }, { status: 401 })
}
try {
// Refresh the token if needed
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
return NextResponse.json({ accessToken }, { status: 200 })
} catch (error) {
return NextResponse.json({ error: 'Failed to refresh access token' }, { status: 401 })
}
// Return the access token
return NextResponse.json({ accessToken }, { status: 200 })
} catch (error) {
logger.error(`[${requestId}] Error fetching access token`, error)
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })

View File

@@ -0,0 +1,292 @@
/**
* Tests for OAuth utility functions
*
* @vitest-environment node
*/
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
describe('OAuth Utils', () => {
const mockSession = { user: { id: 'test-user-id' } }
const mockDb = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
limit: vi.fn().mockReturnValue([]),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
}
const mockRefreshOAuthToken = vi.fn()
const mockLogger = {
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
}
beforeEach(() => {
vi.resetModules()
vi.doMock('@/lib/auth', () => ({
getSession: vi.fn().mockResolvedValue(mockSession),
}))
vi.doMock('@/db', () => ({
db: mockDb,
}))
vi.doMock('@/lib/oauth', () => ({
refreshOAuthToken: mockRefreshOAuthToken,
}))
vi.doMock('@/lib/logs/console-logger', () => ({
createLogger: vi.fn().mockReturnValue(mockLogger),
}))
})
afterEach(() => {
vi.clearAllMocks()
})
describe('getUserId', () => {
it('should get user ID from session when no workflowId is provided', async () => {
const { getUserId } = await import('./utils')
const userId = await getUserId('request-id')
expect(userId).toBe('test-user-id')
})
it('should get user ID from workflow when workflowId is provided', async () => {
mockDb.limit.mockReturnValueOnce([{ userId: 'workflow-owner-id' }])
const { getUserId } = await import('./utils')
const userId = await getUserId('request-id', 'workflow-id')
expect(mockDb.select).toHaveBeenCalled()
expect(mockDb.from).toHaveBeenCalled()
expect(mockDb.where).toHaveBeenCalled()
expect(mockDb.limit).toHaveBeenCalledWith(1)
expect(userId).toBe('workflow-owner-id')
})
it('should return undefined if no session is found', async () => {
vi.doMock('@/lib/auth', () => ({
getSession: vi.fn().mockResolvedValue(null),
}))
const { getUserId } = await import('./utils')
const userId = await getUserId('request-id')
expect(userId).toBeUndefined()
expect(mockLogger.warn).toHaveBeenCalled()
})
it('should return undefined if workflow is not found', async () => {
mockDb.limit.mockReturnValueOnce([])
const { getUserId } = await import('./utils')
const userId = await getUserId('request-id', 'nonexistent-workflow-id')
expect(userId).toBeUndefined()
expect(mockLogger.warn).toHaveBeenCalled()
})
})
describe('getCredential', () => {
it('should return credential when found', async () => {
const mockCredential = { id: 'credential-id', userId: 'test-user-id' }
mockDb.limit.mockReturnValueOnce([mockCredential])
const { getCredential } = await import('./utils')
const credential = await getCredential('request-id', 'credential-id', 'test-user-id')
expect(mockDb.select).toHaveBeenCalled()
expect(mockDb.from).toHaveBeenCalled()
expect(mockDb.where).toHaveBeenCalled()
expect(mockDb.limit).toHaveBeenCalledWith(1)
expect(credential).toEqual(mockCredential)
})
it('should return undefined when credential is not found', async () => {
mockDb.limit.mockReturnValueOnce([])
const { getCredential } = await import('./utils')
const credential = await getCredential('request-id', 'nonexistent-id', 'test-user-id')
expect(credential).toBeUndefined()
expect(mockLogger.warn).toHaveBeenCalled()
})
})
describe('refreshTokenIfNeeded', () => {
it('should return valid token without refresh if not expired', async () => {
const mockCredential = {
id: 'credential-id',
accessToken: 'valid-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000), // 1 hour in the future
providerId: 'google',
}
const { refreshTokenIfNeeded } = await import('./utils')
const result = await refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')
expect(mockRefreshOAuthToken).not.toHaveBeenCalled()
expect(result).toEqual({ accessToken: 'valid-token', refreshed: false })
expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('Access token is valid'))
})
it('should refresh token when expired', async () => {
const mockCredential = {
id: 'credential-id',
accessToken: 'expired-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
providerId: 'google',
}
mockRefreshOAuthToken.mockResolvedValueOnce({
accessToken: 'new-token',
expiresIn: 3600,
refreshToken: 'new-refresh-token',
})
const { refreshTokenIfNeeded } = await import('./utils')
const result = await refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')
expect(mockRefreshOAuthToken).toHaveBeenCalledWith('google', 'refresh-token')
expect(mockDb.update).toHaveBeenCalled()
expect(mockDb.set).toHaveBeenCalled()
expect(result).toEqual({ accessToken: 'new-token', refreshed: true })
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('Successfully refreshed')
)
})
it('should handle refresh token error', async () => {
const mockCredential = {
id: 'credential-id',
accessToken: 'expired-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
providerId: 'google',
}
mockRefreshOAuthToken.mockRejectedValueOnce(new Error('Refresh failed'))
const { refreshTokenIfNeeded } = await import('./utils')
await expect(
refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')
).rejects.toThrow()
expect(mockLogger.error).toHaveBeenCalled()
})
it('should not attempt refresh if no refresh token', async () => {
const mockCredential = {
id: 'credential-id',
accessToken: 'token',
refreshToken: null,
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
providerId: 'google',
}
const { refreshTokenIfNeeded } = await import('./utils')
const result = await refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')
expect(mockRefreshOAuthToken).not.toHaveBeenCalled()
expect(result).toEqual({ accessToken: 'token', refreshed: false })
})
})
describe('refreshAccessTokenIfNeeded', () => {
it('should return valid access token without refresh if not expired', async () => {
const mockCredential = {
id: 'credential-id',
accessToken: 'valid-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() + 3600 * 1000), // 1 hour in the future
providerId: 'google',
userId: 'test-user-id',
}
mockDb.limit.mockReturnValueOnce([mockCredential])
const { refreshAccessTokenIfNeeded } = await import('./utils')
const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')
expect(mockRefreshOAuthToken).not.toHaveBeenCalled()
expect(token).toBe('valid-token')
})
it('should refresh token when expired', async () => {
const mockCredential = {
id: 'credential-id',
accessToken: 'expired-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
providerId: 'google',
userId: 'test-user-id',
}
mockDb.limit.mockReturnValueOnce([mockCredential])
mockRefreshOAuthToken.mockResolvedValueOnce({
accessToken: 'new-token',
expiresIn: 3600,
refreshToken: 'new-refresh-token',
})
const { refreshAccessTokenIfNeeded } = await import('./utils')
const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')
expect(mockRefreshOAuthToken).toHaveBeenCalledWith('google', 'refresh-token')
expect(mockDb.update).toHaveBeenCalled()
expect(mockDb.set).toHaveBeenCalled()
expect(token).toBe('new-token')
})
it('should return null if credential not found', async () => {
mockDb.limit.mockReturnValueOnce([])
const { refreshAccessTokenIfNeeded } = await import('./utils')
const token = await refreshAccessTokenIfNeeded('nonexistent-id', 'test-user-id', 'request-id')
expect(token).toBeNull()
expect(mockLogger.warn).toHaveBeenCalled()
})
it('should return null if refresh fails', async () => {
const mockCredential = {
id: 'credential-id',
accessToken: 'expired-token',
refreshToken: 'refresh-token',
accessTokenExpiresAt: new Date(Date.now() - 3600 * 1000), // 1 hour in the past
providerId: 'google',
userId: 'test-user-id',
}
mockDb.limit.mockReturnValueOnce([mockCredential])
mockRefreshOAuthToken.mockResolvedValueOnce(null)
const { refreshAccessTokenIfNeeded } = await import('./utils')
const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')
expect(token).toBeNull()
expect(mockLogger.error).toHaveBeenCalled()
})
})
})

View File

@@ -1,11 +1,66 @@
import { and, eq } from 'drizzle-orm'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console-logger'
import { refreshOAuthToken } from '@/lib/oauth'
import { db } from '@/db'
import { account } from '@/db/schema'
import { account, workflow } from '@/db/schema'
const logger = createLogger('OAuthUtils')
/**
* Get the user ID based on either a session or a workflow ID
*/
export async function getUserId(
requestId: string,
workflowId?: string
): Promise<string | undefined> {
// If workflowId is provided, this is a server-side request
if (workflowId) {
// Get the workflow to verify the user ID
const workflows = await db
.select({ userId: workflow.userId })
.from(workflow)
.where(eq(workflow.id, workflowId))
.limit(1)
if (!workflows.length) {
logger.warn(`[${requestId}] Workflow not found`)
return undefined
}
return workflows[0].userId
} else {
// This is a client-side request, use the session
const session = await getSession()
// Check if the user is authenticated
if (!session?.user?.id) {
logger.warn(`[${requestId}] Unauthenticated request rejected`)
return undefined
}
return session.user.id
}
}
/**
* Get a credential by ID and verify it belongs to the user
*/
export async function getCredential(requestId: string, credentialId: string, userId: string) {
const credentials = await db
.select()
.from(account)
.where(and(eq(account.id, credentialId), eq(account.userId, userId)))
.limit(1)
if (!credentials.length) {
logger.warn(`[${requestId}] Credential not found`)
return undefined
}
return credentials[0]
}
export async function getOAuthToken(userId: string, providerId: string): Promise<string | null> {
const connections = await db
.select({
@@ -94,28 +149,21 @@ export async function getOAuthToken(userId: string, providerId: string): Promise
* Refreshes an OAuth token if needed based on credential information
* @param credentialId The ID of the credential to check and potentially refresh
* @param userId The user ID who owns the credential (for security verification)
* @param requestId Optional request ID for log correlation
* @param requestId Request ID for log correlation
* @returns The valid access token or null if refresh fails
*/
export async function refreshAccessTokenIfNeeded(
credentialId: string,
userId: string,
requestId?: string
requestId: string
): Promise<string | null> {
// Get the credential from the database
const credentials = await db
.select()
.from(account)
.where(and(eq(account.id, credentialId), eq(account.userId, userId)))
.limit(1)
// Get the credential directly using the getCredential helper
const credential = await getCredential(requestId, credentialId, userId)
if (!credentials.length) {
logger.warn(`[${requestId || ''}] Credential not found: ${credentialId}`)
if (!credential) {
return null
}
const credential = credentials[0]
// Check if we need to refresh the token
const expiresAt = credential.accessTokenExpiresAt
const now = new Date()
@@ -124,22 +172,17 @@ export async function refreshAccessTokenIfNeeded(
let accessToken = credential.accessToken
if (needsRefresh && credential.refreshToken) {
logger.info(
`[${requestId || ''}] Token expired, attempting to refresh for credential: ${credentialId}`
)
logger.info(`[${requestId}] Token expired, attempting to refresh for credential`)
try {
const refreshedToken = await refreshOAuthToken(credential.providerId, credential.refreshToken)
if (!refreshedToken) {
logger.error(
`[${requestId || ''}] Failed to refresh token for credential: ${credentialId}`,
{
credentialId,
providerId: credential.providerId,
userId: credential.userId,
hasRefreshToken: !!credential.refreshToken,
}
)
logger.error(`[${requestId}] Failed to refresh token for credential: ${credentialId}`, {
credentialId,
providerId: credential.providerId,
userId: credential.userId,
hasRefreshToken: !!credential.refreshToken,
})
return null
}
@@ -152,19 +195,17 @@ export async function refreshAccessTokenIfNeeded(
// If we received a new refresh token, update it
if (refreshedToken.refreshToken && refreshedToken.refreshToken !== credential.refreshToken) {
logger.info(`[${requestId || ''}] Updating refresh token for credential: ${credentialId}`)
logger.info(`[${requestId}] Updating refresh token for credential`)
updateData.refreshToken = refreshedToken.refreshToken
}
// Update the token in the database
await db.update(account).set(updateData).where(eq(account.id, credentialId))
logger.info(
`[${requestId || ''}] Successfully refreshed access token for credential: ${credentialId}`
)
logger.info(`[${requestId}] Successfully refreshed access token for credential`)
return refreshedToken.accessToken
} catch (error) {
logger.error(`[${requestId || ''}] Error refreshing token for credential: ${credentialId}`, {
logger.error(`[${requestId}] Error refreshing token for credential`, {
error: error instanceof Error ? error.message : String(error),
stack: error instanceof Error ? error.stack : undefined,
providerId: credential.providerId,
@@ -174,9 +215,62 @@ export async function refreshAccessTokenIfNeeded(
return null
}
} else if (!accessToken) {
logger.error(`[${requestId || ''}] Missing access token for credential: ${credential.id}`)
logger.error(`[${requestId}] Missing access token for credential`)
return null
}
logger.info(`[${requestId}] Access token is valid for credential`)
return accessToken
}
/**
* Enhanced version that returns additional information about the refresh operation
*/
export async function refreshTokenIfNeeded(
requestId: string,
credential: any,
credentialId: string
): Promise<{ accessToken: string; refreshed: boolean }> {
// Check if we need to refresh the token
const expiresAt = credential.accessTokenExpiresAt
const now = new Date()
const needsRefresh = !expiresAt || expiresAt <= now
// If token is still valid, return it directly
if (!needsRefresh || !credential.refreshToken) {
logger.info(`[${requestId}] Access token is valid`)
return { accessToken: credential.accessToken, refreshed: false }
}
try {
const refreshResult = await refreshOAuthToken(credential.providerId, credential.refreshToken)
if (!refreshResult) {
logger.error(`[${requestId}] Failed to refresh token for credential`)
throw new Error('Failed to refresh token')
}
const { accessToken: refreshedToken, expiresIn, refreshToken: newRefreshToken } = refreshResult
// Prepare update data
const updateData: any = {
accessToken: refreshedToken,
accessTokenExpiresAt: new Date(Date.now() + expiresIn * 1000), // Use provider's expiry
updatedAt: new Date(),
}
// If we received a new refresh token, update it
if (newRefreshToken && newRefreshToken !== credential.refreshToken) {
logger.info(`[${requestId}] Updating refresh token`)
updateData.refreshToken = newRefreshToken
}
await db.update(account).set(updateData).where(eq(account.id, credentialId))
logger.info(`[${requestId}] Successfully refreshed access token`)
return { accessToken: refreshedToken, refreshed: true }
} catch (error) {
logger.error(`[${requestId}] Error refreshing token`, error)
throw error
}
}

View File

@@ -1,50 +0,0 @@
import { NextRequest, NextResponse } from 'next/server'
import { eq } from 'drizzle-orm'
import { db } from '@/db'
import { session, user } from '@/db/schema'
export async function GET(request: NextRequest) {
try {
const token = request.nextUrl.searchParams.get('token')
if (!token) {
return NextResponse.json({ error: 'Token is required' }, { status: 400 })
}
// Get session by token
const sessionRecord = await db
.select()
.from(session)
.where(eq(session.id, token))
.limit(1)
.then((rows) => rows[0])
if (!sessionRecord) {
return NextResponse.json({ error: 'Invalid session' }, { status: 401 })
}
// Get user from session
const userRecord = await db
.select()
.from(user)
.where(eq(user.id, sessionRecord.userId))
.limit(1)
.then((rows) => rows[0])
if (!userRecord) {
return NextResponse.json({ error: 'User not found' }, { status: 404 })
}
// Return minimal user info (only what's needed)
return NextResponse.json({
user: {
id: userRecord.id,
email: userRecord.email,
name: userRecord.name,
},
})
} catch (error) {
console.error('Session API error:', error)
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
}
}

View File

@@ -23,14 +23,12 @@ export async function POST(request: NextRequest) {
maxTokens,
apiKey,
responseFormat,
workflowId,
} = body
logger.info(`Provider request received for ${provider} model: ${model}`)
let finalApiKey: string
try {
finalApiKey = getApiKey(provider, model, apiKey)
logger.info(`API key obtained for ${provider} ${model}`)
} catch (error) {
logger.error('Failed to get API key:', error)
return NextResponse.json(
@@ -49,6 +47,7 @@ export async function POST(request: NextRequest) {
maxTokens,
apiKey: finalApiKey,
responseFormat,
workflowId,
})
return NextResponse.json(response)

View File

@@ -156,7 +156,7 @@ export const XBlock: BlockConfig<XResponse> = {
// Convert string values to appropriate types
const parsedParams: Record<string, any> = {
accessToken: credential,
credential: credential,
}
// Add other params

View File

@@ -150,6 +150,7 @@ export class AgentBlockHandler implements BlockHandler {
maxTokens: inputs.maxTokens,
apiKey: inputs.apiKey,
responseFormat,
workflowId: context.workflowId,
}
logger.info(`Provider request prepared`, {
@@ -158,6 +159,7 @@ export class AgentBlockHandler implements BlockHandler {
hasContext: !!providerRequest.context,
hasTools: !!providerRequest.tools,
hasApiKey: !!providerRequest.apiKey,
workflowId: providerRequest.workflowId,
})
// Get the app URL from environment variable or use default

View File

@@ -367,6 +367,7 @@ export const auth = betterAuth({
pkce: true,
responseType: 'code',
prompt: 'consent',
authentication: 'basic',
redirectURI: `${process.env.NEXT_PUBLIC_APP_URL}/api/auth/oauth2/callback/x`,
getUserInfo: async (tokens) => {
try {

View File

@@ -353,6 +353,7 @@ export async function refreshOAuthToken(
tokenEndpoint = 'https://api.x.com/2/oauth2/token'
clientId = process.env.X_CLIENT_ID
clientSecret = process.env.X_CLIENT_SECRET
useBasicAuth = true
break
case 'confluence':
tokenEndpoint = 'https://auth.atlassian.com/oauth/token'
@@ -407,6 +408,15 @@ export async function refreshOAuthToken(
} else {
throw new Error('Both client ID and client secret are required for Airtable OAuth')
}
} else if (provider === 'x') {
// Handle X differently
// Confidential client - use Basic Auth
const authString = `${clientId}:${clientSecret}`
const basicAuth = Buffer.from(authString).toString('base64')
headers['Authorization'] = `Basic ${basicAuth}`
// When using Basic Auth, don't include client_id in body
delete bodyParams.client_id
} else {
// For other providers, use the general approach
if (useBasicAuth) {
@@ -429,15 +439,20 @@ export async function refreshOAuthToken(
if (!response.ok) {
const errorText = await response.text()
let errorData = errorText
// Try to parse the error as JSON for better diagnostics
try {
errorData = JSON.parse(errorText)
} catch (e) {
// Not JSON, keep as text
}
logger.error('Token refresh failed:', {
status: response.status,
error: errorText,
parsedError: errorData,
provider,
headers: JSON.stringify(headers, null, 2).replace(
/"Authorization":"[^"]*"/,
'"Authorization":"[REDACTED]"'
),
bodyParams: JSON.stringify(bodyParams),
})
throw new Error(`Failed to refresh token: ${response.status} ${errorText}`)
}

View File

@@ -254,7 +254,11 @@ ${fieldDescriptions}
// Execute the tool
const toolCallStartTime = Date.now()
const mergedArgs = { ...tool.params, ...toolArgs }
const mergedArgs = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
}
const result = await executeTool(toolName, mergedArgs)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime

View File

@@ -170,7 +170,11 @@ export const cerebrasProvider: ProviderConfig = {
// Execute the tool
const toolCallStartTime = Date.now()
const mergedArgs = { ...tool.params, ...toolArgs }
const mergedArgs = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
}
const result = await executeTool(toolName, mergedArgs)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime

View File

@@ -144,7 +144,11 @@ export const deepseekProvider: ProviderConfig = {
// Execute the tool
const toolCallStartTime = Date.now()
const mergedArgs = { ...tool.params, ...toolArgs }
const mergedArgs = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
}
const result = await executeTool(toolName, mergedArgs)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime

View File

@@ -214,7 +214,11 @@ export const googleProvider: ProviderConfig = {
// Execute the tool
const toolCallStartTime = Date.now()
const mergedArgs = { ...tool.params, ...toolArgs }
const mergedArgs = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
}
const result = await executeTool(toolName, mergedArgs)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime

View File

@@ -145,7 +145,11 @@ export const groqProvider: ProviderConfig = {
// Execute the tool
const toolCallStartTime = Date.now()
const mergedArgs = { ...tool.params, ...toolArgs }
const mergedArgs = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
}
const result = await executeTool(toolName, mergedArgs)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime

View File

@@ -163,7 +163,11 @@ export const ollamaProvider: ProviderConfig = {
// Execute the tool
const toolCallStartTime = Date.now()
const mergedArgs = { ...tool.params, ...toolArgs }
const mergedArgs = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
}
const result = await executeTool(toolName, mergedArgs)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime

View File

@@ -160,7 +160,11 @@ export const openaiProvider: ProviderConfig = {
// Execute the tool
const toolCallStartTime = Date.now()
const mergedArgs = { ...tool.params, ...toolArgs }
const mergedArgs = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
}
const result = await executeTool(toolName, mergedArgs)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime

View File

@@ -138,6 +138,7 @@ export interface ProviderRequest {
strict?: boolean
}
local_execution?: boolean
workflowId?: string // Optional workflow ID for authentication context
}
// Map of provider IDs to their configurations

View File

@@ -143,7 +143,11 @@ export const xAIProvider: ProviderConfig = {
if (!tool) continue
const toolCallStartTime = Date.now()
const mergedArgs = { ...tool.params, ...toolArgs }
const mergedArgs = {
...tool.params,
...toolArgs,
...(request.workflowId ? { _context: { workflowId: request.workflowId } } : {}),
}
const result = await executeTool(toolName, mergedArgs)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime

View File

@@ -332,9 +332,11 @@ export async function executeTool(
try {
const tool = getTool(toolId)
// Ensure context is preserved if it exists
const contextParams = { ...params }
// Validate the tool and its parameters
validateToolRequest(toolId, tool, params)
validateToolRequest(toolId, tool, contextParams)
// After validation, we know tool exists
if (!tool) {
@@ -344,7 +346,7 @@ export async function executeTool(
// For any tool with direct execution capability, try it first
if (tool.directExecution) {
try {
const directResult = await tool.directExecution(params)
const directResult = await tool.directExecution(contextParams)
if (directResult) {
// Add timing data to the result
const endTime = new Date()
@@ -354,7 +356,11 @@ export async function executeTool(
// Apply post-processing if available and not skipped
if (tool.postProcess && directResult.success && !skipPostProcess) {
try {
const postProcessResult = await tool.postProcess(directResult, params, executeTool)
const postProcessResult = await tool.postProcess(
directResult,
contextParams,
executeTool
)
return {
...postProcessResult,
timing: {
@@ -394,12 +400,12 @@ export async function executeTool(
// For internal routes or when skipProxy is true, call the API directly
if (tool.request.isInternalRoute || skipProxy) {
const result = await handleInternalRequest(toolId, tool, params)
const result = await handleInternalRequest(toolId, tool, contextParams)
// Apply post-processing if available and not skipped
if (tool.postProcess && result.success && !skipPostProcess) {
try {
const postProcessResult = await tool.postProcess(result, params, executeTool)
const postProcessResult = await tool.postProcess(result, contextParams, executeTool)
// Add timing data to the post-processed result
const endTime = new Date()
@@ -446,12 +452,12 @@ export async function executeTool(
}
// For external APIs, use the proxy
const result = await handleProxyRequest(toolId, params)
const result = await handleProxyRequest(toolId, contextParams)
// Apply post-processing if available and not skipped
if (tool.postProcess && result.success && !skipPostProcess) {
try {
const postProcessResult = await tool.postProcess(result, params, executeTool)
const postProcessResult = await tool.postProcess(result, contextParams, executeTool)
// Add timing data to the post-processed result
const endTime = new Date()

View File

@@ -63,7 +63,11 @@ export const searchTool: ToolConfig<XSearchParams, XSearchResponse> = {
'user.fields': 'name,username,description,profile_image_url,verified,public_metrics',
})
if (params.maxResults) queryParams.append('max_results', params.maxResults.toString())
if (params.maxResults && params.maxResults < 10) {
queryParams.append('max_results', '10')
} else if (params.maxResults) {
queryParams.append('max_results', params.maxResults.toString())
}
if (params.startTime) queryParams.append('start_time', params.startTime)
if (params.endTime) queryParams.append('end_time', params.endTime)
if (params.sortOrder) queryParams.append('sort_order', params.sortOrder)
@@ -127,12 +131,15 @@ export const searchTool: ToolConfig<XSearchParams, XSearchResponse> = {
},
transformError: (error) => {
// Log the full error object for debugging
console.error('X Search API Error:', JSON.stringify(error, null, 2))
if (error.title === 'Unauthorized') {
return 'Invalid or expired access token. Please reconnect your X account.'
}
if (error.title === 'Invalid Request') {
return 'Invalid search query. Please check your search parameters.'
}
return error.detail || 'An error occurred while searching X'
return error.detail || `An error occurred while searching X`
},
}