feat(oauth): clean up oauth authorization flow

This commit is contained in:
Waleed Latif
2025-03-05 16:08:23 -08:00
parent 3df326f8ba
commit 2cb724eacc
6 changed files with 283 additions and 474 deletions

View File

@@ -1,9 +1,6 @@
import { useCallback, useEffect, useState } from 'react'
import { and, eq } from 'drizzle-orm'
import { getSession } from '@/lib/auth'
import { useSession } from '@/lib/auth-client'
import { db } from '@/db'
import { account } from '@/db/schema'
'use client'
import { useCallback, useState } from 'react'
import { OAuthProvider } from '@/tools/types'
/**
@@ -18,382 +15,70 @@ export interface OAuthRequiredError {
}
/**
* Check if the user has authorized the required OAuth provider with necessary scopes (server-side)
*
* @param provider - The OAuth provider to check
* @param requiredScopes - Optional scopes to check
* @returns Boolean indicating if the provider is authorized with required scopes
*/
export async function hasAuthorizedProviderServer(
provider: OAuthProvider,
requiredScopes?: string[]
): Promise<boolean> {
try {
// Get the session
const session = await getSession()
// If not authenticated, return false
if (!session?.user?.id) {
return false
}
// Determine the appropriate feature type based on the scopes
let featureType = 'default'
if (requiredScopes && requiredScopes.length > 0) {
if (requiredScopes.some((scope) => scope.includes('repo'))) {
featureType = 'repo'
} else if (requiredScopes.some((scope) => scope.includes('workflow'))) {
featureType = 'workflow'
} else if (
requiredScopes.some((scope) => scope.includes('gmail') || scope.includes('mail'))
) {
featureType = 'email'
} else if (requiredScopes.some((scope) => scope.includes('calendar'))) {
featureType = 'calendar'
} else if (requiredScopes.some((scope) => scope.includes('drive'))) {
featureType = 'drive'
} else if (requiredScopes.some((scope) => scope.includes('write'))) {
featureType = 'write'
} else if (requiredScopes.some((scope) => scope.includes('read'))) {
featureType = 'read'
}
}
// We check the appropriate provider ID based on the feature type
const providerId = `${provider}-${featureType}`
// Check if there's an account for this provider for the user
const accounts = await db
.select()
.from(account)
.where(and(eq(account.userId, session.user.id), eq(account.providerId, providerId)))
return accounts.length > 0
} catch (error) {
console.error('Error checking provider authorization:', error)
return false
}
}
/**
* Check if a tool requires OAuth and if the user has authorized it (server-side)
*
* @param tool - The tool configuration
* @returns Object indicating if OAuth is required and if the user has authorized it
*/
export async function checkOAuthRequirementServer(tool: any): Promise<{
requiresAuth: boolean
isAuthorized: boolean
provider?: OAuthProvider
requiredScopes?: string[]
}> {
// Skip if no OAuth config or not required
if (!tool.oauth || !tool.oauth.required) {
return { requiresAuth: false, isAuthorized: false }
}
const provider = tool.oauth.provider
const additionalScopes = tool.oauth.additionalScopes || []
// Check if the user has authorized this provider with required scopes
const isAuthorized = await hasAuthorizedProviderServer(provider, additionalScopes)
return {
requiresAuth: true,
isAuthorized,
provider,
requiredScopes: additionalScopes,
}
}
/**
* Verify OAuth requirements before executing a tool (server-side)
* Throws an error if OAuth is required but not authorized
*
* @param tool - The tool configuration
* @throws Error with JSON.stringify(OAuthRequiredError)
*/
export async function verifyOAuthBeforeExecutionServer(tool: any): Promise<void> {
const { requiresAuth, isAuthorized, provider, requiredScopes } =
await checkOAuthRequirementServer(tool)
if (requiresAuth && !isAuthorized && provider) {
// Throw a structured error that can be caught and handled
throw new Error(
JSON.stringify({
type: 'oauth_required',
provider,
toolId: tool.id,
toolName: tool.name,
requiredScopes,
})
)
}
}
/**
* Get OAuth tokens for a provider if the user has authorized it
*
* @param userId - The user's ID
* @param provider - The OAuth provider to get tokens for
* @returns The OAuth tokens or null if not authorized
*/
export async function getOAuthTokens(
userId: string,
provider: OAuthProvider
): Promise<{
accessToken: string
refreshToken?: string
expiresAt?: Date
} | null> {
try {
// Query the account table for this user and provider
const accounts = await db
.select()
.from(account)
.where(and(eq(account.userId, userId), eq(account.providerId, provider)))
.limit(1)
if (!accounts.length || !accounts[0].accessToken) {
return null
}
const userAccount = accounts[0]
// Check if the token is expired
if (
userAccount.accessTokenExpiresAt &&
new Date(userAccount.accessTokenExpiresAt) < new Date()
) {
// In a production app, we would use the refresh token to get a new access token here
// But for simplicity, we'll just return null for expired tokens
console.warn(`Token for ${provider} is expired and needs refresh`)
return null
}
// Ensure accessToken is not null using the type guard we did earlier
const accessToken = userAccount.accessToken as string
return {
accessToken,
refreshToken: userAccount.refreshToken || undefined,
expiresAt: userAccount.accessTokenExpiresAt
? new Date(userAccount.accessTokenExpiresAt)
: undefined,
}
} catch (error) {
console.error('Error getting OAuth tokens:', error)
return null
}
}
/**
* Get OAuth tokens for a specific tool if the user has authorized it
*
* @param userId - The user's ID
* @param tool - The tool configuration
* @returns The OAuth tokens or null if not required or not authorized
*/
export async function getOAuthTokensForTool(
userId: string,
tool: any
): Promise<{
accessToken: string
refreshToken?: string
expiresAt?: Date
} | null> {
// Skip if no OAuth config or not required
if (!tool.oauth || !tool.oauth.required) {
return null
}
// Get tokens for the provider
return getOAuthTokens(userId, tool.oauth.provider)
}
/**
* Custom hook to check if a user has authorized an OAuth provider
*
* @param provider - The OAuth provider to check
* @param requiredScopes - Optional array of scopes required for the operation
* @returns An object with authorization status and loading state
*/
export function useProviderAuthorization(provider: OAuthProvider, requiredScopes?: string[]) {
const { data: session, isPending } = useSession()
const [isAuthorized, setIsAuthorized] = useState<boolean>(false)
useEffect(() => {
if (isPending || !session?.user) {
setIsAuthorized(false)
return
}
// Check if the user has provider accounts in their session
// This is a client-side check, so it may not be as accurate as the server-side check
const checkAuthorization = async () => {
try {
// We'll use an API endpoint to check authorization status
const response = await fetch(
`/api/auth/oauth/check?provider=${provider}${
requiredScopes ? `&scopes=${requiredScopes.join(',')}` : ''
}`
)
if (response.ok) {
const data = await response.json()
setIsAuthorized(data.isAuthorized || false)
} else {
setIsAuthorized(false)
}
} catch (error) {
console.error('Error checking OAuth authorization:', error)
setIsAuthorized(false)
}
}
checkAuthorization()
}, [session, isPending, provider, requiredScopes])
return {
isAuthorized,
isLoading: isPending,
isLoggedIn: !!session?.user,
}
}
/**
* Check if a tool requires OAuth and if the user has authorized it
* This function must be used in a client component
*
* @param tool - The tool configuration
* @returns Object indicating if OAuth is required and if the user has the necessary authorization
*/
export function useToolOAuthRequirement(tool: any) {
// Skip if no OAuth config or not required
if (!tool.oauth || !tool.oauth.required) {
return {
requiresAuth: false,
isAuthorized: true,
isLoading: false,
}
}
const provider = tool.oauth.provider
const additionalScopes = tool.oauth.additionalScopes || []
// Use the provider authorization hook
const { isAuthorized, isLoading, isLoggedIn } = useProviderAuthorization(
provider,
additionalScopes
)
return {
requiresAuth: true,
isAuthorized: isAuthorized,
isLoading,
isLoggedIn,
provider,
requiredScopes: additionalScopes,
}
}
/**
* Verify OAuth requirements before executing a tool
* Throws an error if OAuth is required but not authorized
* This function must be used in a client component
*
* @param toolOAuthStatus - The result from useToolOAuthRequirement
* @param tool - The tool configuration
* @throws Error with JSON stringified OAuthRequiredError
*/
export function verifyOAuthBeforeExecution(
toolOAuthStatus: ReturnType<typeof useToolOAuthRequirement>,
tool: any
): void {
const { requiresAuth, isAuthorized, isLoading, provider, requiredScopes } = toolOAuthStatus
// Don't verify while loading
if (isLoading) {
return
}
if (requiresAuth && !isAuthorized && provider) {
// Throw a structured error that the frontend can catch and handle
throw new Error(
JSON.stringify({
type: 'oauth_required',
provider,
toolId: tool.id,
toolName: tool.name,
requiredScopes,
} as OAuthRequiredError)
)
}
}
/**
* Hook for handling OAuth errors during tool execution
* Provides a modal state and error handler function
* Custom hook to handle OAuth errors during workflow execution
*/
export function useOAuthErrorHandler() {
const [oauthModalState, setOAuthModalState] = useState<{
const [modalState, setModalState] = useState<{
isOpen: boolean
provider: OAuthProvider
provider: OAuthProvider | null
toolName: string
requiredScopes?: string[]
}>({
isOpen: false,
provider: 'github',
provider: null,
toolName: '',
})
/**
* Handle an error that might be an OAuth required error
* Returns true if it was handled as an OAuth error, false otherwise
*/
const handleError = useCallback((error: any): boolean => {
if (!error) return false
const handleOAuthError = useCallback((error: any) => {
// Check if the error is an OAuth required error
try {
// Try to parse error message as JSON
let errorObj
if (typeof error === 'string' && error.includes('oauth_required')) {
const errorData: OAuthRequiredError = JSON.parse(error)
if (typeof error === 'string') {
errorObj = JSON.parse(error)
} else if (error instanceof Error && error.message) {
try {
errorObj = JSON.parse(error.message)
} catch {
return false
if (errorData.type === 'oauth_required' && errorData.provider) {
setModalState({
isOpen: true,
provider: errorData.provider,
toolName: errorData.toolName || 'this tool',
requiredScopes: errorData.requiredScopes,
})
return true
}
} else {
return false
}
} else if (
error?.message &&
typeof error.message === 'string' &&
error.message.includes('oauth_required')
) {
try {
const errorData: OAuthRequiredError = JSON.parse(error.message)
// Check if it's an OAuth required error
if (errorObj?.type === 'oauth_required') {
setOAuthModalState({
isOpen: true,
provider: errorObj.provider,
toolName: errorObj.toolName,
})
return true
if (errorData.type === 'oauth_required' && errorData.provider) {
setModalState({
isOpen: true,
provider: errorData.provider,
toolName: errorData.toolName || 'this tool',
requiredScopes: errorData.requiredScopes,
})
return true
}
} catch (parseError) {
console.error('Error parsing OAuth error message:', parseError)
}
}
} catch (parseError) {
// Not a JSON error or not an OAuth error
return false
} catch (e) {
console.error('Error handling OAuth error:', e)
}
return false
}, [])
const closeModal = useCallback(() => {
setOAuthModalState((prev) => ({ ...prev, isOpen: false }))
setModalState((prev) => ({ ...prev, isOpen: false }))
}, [])
return {
oauthModalState,
handleOAuthError: handleError,
closeOAuthModal: closeModal,
modalState,
handleOAuthError,
closeModal,
}
}