feat(oauth): add multiple credentials per provider

This commit is contained in:
Waleed Latif
2025-03-06 20:19:03 -08:00
parent 0af9d05b67
commit f244c96f9a
6 changed files with 252 additions and 104 deletions

View File

@@ -1,10 +1,16 @@
import { NextRequest, NextResponse } from 'next/server'
import { and, eq, like } from 'drizzle-orm'
import { eq } from 'drizzle-orm'
import { jwtDecode } from 'jwt-decode'
import { getSession } from '@/lib/auth'
import { db } from '@/db'
import { account } from '@/db/schema'
import { OAuthProvider } from '@/tools/types'
interface GoogleIdToken {
email?: string
sub?: string
}
// Valid OAuth providers
const VALID_PROVIDERS = ['google', 'github', 'twitter']
@@ -32,14 +38,47 @@ export async function GET(request: NextRequest) {
const [provider, featureType = 'default'] = acc.providerId.split('-')
if (provider && VALID_PROVIDERS.includes(provider)) {
connections.push({
provider: provider as OAuthProvider,
featureType,
isConnected: true,
scopes: acc.scope ? acc.scope.split(' ') : [],
lastConnected: acc.updatedAt.toISOString(),
accountId: acc.id,
})
// Get the account name (try to get email for Google accounts)
let name = acc.accountId
if (provider === 'google' && acc.idToken) {
try {
const decoded = jwtDecode<GoogleIdToken>(acc.idToken)
if (decoded.email) {
name = decoded.email
}
} catch (error) {
console.error('Error decoding ID token:', error)
}
}
// Find existing connection for this provider and feature type
const existingConnection = connections.find(
(conn) => conn.provider === provider && conn.featureType === featureType
)
if (existingConnection) {
// Add account to existing connection
existingConnection.accounts = existingConnection.accounts || []
existingConnection.accounts.push({
id: acc.id,
name,
})
} else {
// Create new connection
connections.push({
provider: provider as OAuthProvider,
featureType,
isConnected: true,
scopes: acc.scope ? acc.scope.split(' ') : [],
lastConnected: acc.updatedAt.toISOString(),
accounts: [
{
id: acc.id,
name,
},
],
})
}
}
})

View File

@@ -1,10 +1,16 @@
import { NextRequest, NextResponse } from 'next/server'
import { and, eq, like } from 'drizzle-orm'
import { jwtDecode } from 'jwt-decode'
import { getSession } from '@/lib/auth'
import { db } from '@/db'
import { account } from '@/db/schema'
import { OAuthProvider } from '@/tools/types'
interface GoogleIdToken {
email?: string
sub?: string
}
/**
* Get credentials for a specific provider
*/
@@ -33,18 +39,33 @@ export async function GET(request: NextRequest) {
.where(and(eq(account.userId, session.user.id), like(account.providerId, `${provider}-%`)))
// Transform accounts into credentials
const credentials = accounts.map((acc) => {
// Extract the feature type from providerId (e.g., 'google-default' -> 'default')
const [_, featureType = 'default'] = acc.providerId.split('-')
const credentials = await Promise.all(
accounts.map(async (acc) => {
// Extract the feature type from providerId (e.g., 'google-default' -> 'default')
const [_, featureType = 'default'] = acc.providerId.split('-')
return {
id: acc.id,
name: `${provider.charAt(0).toUpperCase() + provider.slice(1)} ${featureType !== 'default' ? featureType : ''}`.trim(),
provider,
lastUsed: acc.updatedAt.toISOString(),
isDefault: featureType === 'default',
}
})
// For Google accounts, try to get the email from the ID token
let name = acc.accountId
if (provider === 'google' && acc.idToken) {
try {
const decoded = jwtDecode<GoogleIdToken>(acc.idToken)
if (decoded.email) {
name = decoded.email
}
} catch (error) {
console.error('Error decoding ID token:', error)
}
}
return {
id: acc.id,
name,
provider,
lastUsed: acc.updatedAt.toISOString(),
isDefault: featureType === 'default',
}
})
)
return NextResponse.json({ credentials }, { status: 200 })
} catch (error) {

View File

@@ -1,6 +1,6 @@
'use client'
import { useEffect, useState } from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { Check, ChevronDown, ExternalLink, Key, RefreshCw } from 'lucide-react'
import { GoogleIcon } from '@/components/icons'
import { Button } from '@/components/ui/button'
@@ -46,47 +46,65 @@ export function CredentialSelector({
}: CredentialSelectorProps) {
const [open, setOpen] = useState(false)
const [credentials, setCredentials] = useState<Credential[]>([])
const [isLoading, setIsLoading] = useState(true)
const [isLoading, setIsLoading] = useState(false)
const [showOAuthModal, setShowOAuthModal] = useState(false)
const [selectedId, setSelectedId] = useState(value)
// Fetch available credentials for this provider
useEffect(() => {
const fetchCredentials = async () => {
setIsLoading(true)
try {
const response = await fetch(`/api/auth/oauth/credentials?provider=${provider}`)
if (response.ok) {
const data = await response.json()
setCredentials(data.credentials)
const fetchCredentials = useCallback(async () => {
if (!open) return
setIsLoading(true)
try {
const response = await fetch(`/api/auth/oauth/credentials?provider=${provider}`)
if (response.ok) {
const data = await response.json()
setCredentials(data.credentials)
// If we have a value but it's not in the credentials, reset it
if (value && !data.credentials.some((cred: Credential) => cred.id === value)) {
onChange('')
}
// If we have a value but it's not in the credentials, reset it
if (selectedId && !data.credentials.some((cred: Credential) => cred.id === selectedId)) {
setSelectedId('')
onChange('')
}
// If we have no value but have a default credential, select it
if (!value && data.credentials.length > 0) {
const defaultCred = data.credentials.find((cred: Credential) => cred.isDefault)
if (defaultCred) {
onChange(defaultCred.id)
} else if (data.credentials.length === 1) {
// If only one credential, select it
onChange(data.credentials[0].id)
}
// If we have no value but have a default credential, select it
if (!selectedId && data.credentials.length > 0) {
const defaultCred = data.credentials.find((cred: Credential) => cred.isDefault)
if (defaultCred) {
setSelectedId(defaultCred.id)
onChange(defaultCred.id)
} else if (data.credentials.length === 1) {
// If only one credential, select it
setSelectedId(data.credentials[0].id)
onChange(data.credentials[0].id)
}
}
} catch (error) {
console.error('Error fetching credentials:', error)
} finally {
setIsLoading(false)
}
} catch (error) {
console.error('Error fetching credentials:', error)
} finally {
setIsLoading(false)
}
}, [open, provider, onChange, selectedId])
// Only fetch credentials when opening the popover
useEffect(() => {
fetchCredentials()
}, [provider, onChange, value])
}, [open, fetchCredentials])
// Update local state when external value changes
useEffect(() => {
setSelectedId(value)
}, [value])
// Get the selected credential
const selectedCredential = credentials.find((cred) => cred.id === value)
const selectedCredential = credentials.find((cred) => cred.id === selectedId)
// Handle selection
const handleSelect = (credentialId: string) => {
setSelectedId(credentialId)
onChange(credentialId)
setOpen(false)
}
// Determine the appropriate service ID based on provider and scopes
const getServiceId = (): string => {
@@ -223,16 +241,13 @@ export function CredentialSelector({
<CommandItem
key={credential.id}
value={credential.id}
onSelect={() => {
onChange(credential.id)
setOpen(false)
}}
onSelect={() => handleSelect(credential.id)}
>
<div className="flex items-center gap-2">
{getProviderIcon(credential.provider)}
<span>{credential.name}</span>
</div>
{credential.id === value && <Check className="ml-auto h-4 w-4" />}
{credential.id === selectedId && <Check className="ml-auto h-4 w-4" />}
</CommandItem>
))}
</CommandGroup>

View File

@@ -2,7 +2,7 @@
import { useEffect, useState } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
import { Check, ExternalLink, RefreshCw } from 'lucide-react'
import { Check, ExternalLink, Plus, RefreshCw } from 'lucide-react'
import { GithubIcon, GoogleDriveIcon, xIcon as XIcon } from '@/components/icons'
import { GmailIcon } from '@/components/icons'
import { Button } from '@/components/ui/button'
@@ -28,6 +28,7 @@ interface ServiceInfo {
isConnected: boolean
scopes: string[]
lastConnected?: string
accounts?: { id: string; name: string }[]
}
export function Credentials({ onOpenChange }: CredentialsProps) {
@@ -132,9 +133,10 @@ export function Credentials({ onOpenChange }: CredentialsProps) {
if (connection) {
return {
...service,
isConnected: true,
isConnected: connection.accounts?.length > 0,
scopes: connection.scopes || [],
lastConnected: connection.lastConnected,
accounts: connection.accounts || [],
}
}
@@ -219,17 +221,15 @@ export function Credentials({ onOpenChange }: CredentialsProps) {
const handleConnect = async (service: ServiceInfo) => {
setIsConnecting(service.id)
try {
// Store the current URL to return to after auth
// Store information about the required connection
saveToStorage('auth_return_url', window.location.href)
saveToStorage('pending_service_id', service.id)
// Set a flag to indicate we're in the auth flow
saveToStorage('auth_in_progress', true)
saveToStorage('pending_oauth_provider_id', service.providerId)
// Begin OAuth flow with the appropriate provider
await client.signIn.oauth2({
providerId: service.providerId,
callbackURL: window.location.href, // Return to the current page after auth
callbackURL: window.location.href,
})
} catch (error) {
console.error('OAuth login error:', error)
@@ -237,8 +237,8 @@ export function Credentials({ onOpenChange }: CredentialsProps) {
}
}
const handleDisconnect = async (service: ServiceInfo) => {
setIsConnecting(service.id)
const handleDisconnect = async (service: ServiceInfo, accountId: string) => {
setIsConnecting(`${service.id}-${accountId}`)
try {
// Call your API to disconnect the provider
const response = await fetch('/api/auth/oauth/disconnect', {
@@ -249,15 +249,23 @@ export function Credentials({ onOpenChange }: CredentialsProps) {
body: JSON.stringify({
provider: service.provider,
providerId: service.providerId,
accountId,
}),
})
if (response.ok) {
// Update the local state
// Update the local state by removing the disconnected account
setServices((prev) =>
prev.map((svc) =>
svc.id === service.id ? { ...svc, isConnected: false, scopes: [] } : svc
)
prev.map((svc) => {
if (svc.id === service.id) {
return {
...svc,
accounts: svc.accounts?.filter((acc) => acc.id !== accountId) || [],
isConnected: (svc.accounts?.length || 0) > 1,
}
}
return svc
})
)
}
} catch (error) {
@@ -289,7 +297,7 @@ export function Credentials({ onOpenChange }: CredentialsProps) {
</div>
)}
<div className="space-y-6">
<div className="space-y-4">
{isLoading ? (
<>
<ConnectionSkeleton />
@@ -302,39 +310,91 @@ export function Credentials({ onOpenChange }: CredentialsProps) {
<Card
key={service.id}
className={cn(
'p-5 flex items-center justify-between',
pendingService === service.id && 'border-primary'
'p-6 transition-all hover:shadow-md',
pendingService === service.id && 'border-primary shadow-md'
)}
>
<div className="flex items-center gap-4">
<div className="flex h-10 w-10 items-center justify-center rounded-full bg-muted">
{service.icon}
<div className="flex items-start justify-between gap-4">
<div className="flex items-start gap-4">
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-muted shrink-0">
{service.icon}
</div>
<div className="space-y-1">
<div>
<h4 className="font-medium leading-none">{service.name}</h4>
<p className="text-sm text-muted-foreground mt-1">{service.description}</p>
</div>
{service.accounts && service.accounts.length > 0 && (
<div className="pt-3 space-y-2">
{service.accounts.map((account) => (
<div
key={account.id}
className="flex items-center justify-between gap-2 rounded-md border bg-card/50 p-2"
>
<div className="flex items-center gap-2">
<div className="h-6 w-6 rounded-full bg-green-500/10 flex items-center justify-center">
<Check className="h-3 w-3 text-green-600" />
</div>
<span className="text-sm font-medium">{account.name}</span>
</div>
<Button
variant="ghost"
size="sm"
onClick={() => handleDisconnect(service, account.id)}
disabled={isConnecting === `${service.id}-${account.id}`}
className="h-7 px-2"
>
{isConnecting === `${service.id}-${account.id}` ? (
<RefreshCw className="h-3 w-3 animate-spin" />
) : (
'Disconnect'
)}
</Button>
</div>
))}
<Button
variant="outline"
size="sm"
className="w-full mt-2"
onClick={() => handleConnect(service)}
disabled={isConnecting === service.id}
>
{isConnecting === service.id ? (
<>
<RefreshCw className="h-3 w-3 animate-spin mr-2" />
Connecting...
</>
) : (
<>
<Plus className="h-3 w-3 mr-2" />
Connect Another Account
</>
)}
</Button>
</div>
)}
</div>
</div>
<div>
<h4 className="font-medium">{service.name}</h4>
<p className="text-sm text-muted-foreground">{service.description}</p>
{service.isConnected && (
<p className="text-xs flex items-center gap-1 mt-1 text-green-600">
<Check className="h-3 w-3" />
Connected
</p>
)}
</div>
</div>
<Button
variant={service.isConnected ? 'outline' : 'default'}
size="sm"
onClick={() =>
service.isConnected ? handleDisconnect(service) : handleConnect(service)
}
disabled={isConnecting === service.id}
>
{isConnecting === service.id ? (
<RefreshCw className="h-4 w-4 animate-spin mr-2" />
) : null}
{service.isConnected ? 'Disconnect' : 'Connect'}
</Button>
{!service.accounts?.length && (
<Button
variant="default"
size="sm"
onClick={() => handleConnect(service)}
disabled={isConnecting === service.id}
className="shrink-0"
>
{isConnecting === service.id ? (
<>
<RefreshCw className="h-4 w-4 animate-spin mr-2" />
Connecting...
</>
) : (
'Connect'
)}
</Button>
)}
</div>
</Card>
))
)}
@@ -345,15 +405,17 @@ export function Credentials({ onOpenChange }: CredentialsProps) {
function ConnectionSkeleton() {
return (
<Card className="p-5 flex items-center justify-between">
<div className="flex items-center gap-4">
<Skeleton className="h-10 w-10 rounded-full" />
<div>
<Skeleton className="h-5 w-32 mb-2" />
<Skeleton className="h-4 w-48" />
<Card className="p-6">
<div className="flex items-start justify-between gap-4">
<div className="flex items-start gap-4">
<Skeleton className="h-12 w-12 rounded-lg" />
<div className="space-y-2">
<Skeleton className="h-5 w-32" />
<Skeleton className="h-4 w-48" />
</div>
</div>
<Skeleton className="h-9 w-24 shrink-0" />
</div>
<Skeleton className="h-9 w-24" />
</Card>
)
}

10
package-lock.json generated
View File

@@ -40,6 +40,7 @@
"date-fns": "^3.6.0",
"drizzle-orm": "^0.39.3",
"groq-sdk": "^0.15.0",
"jwt-decode": "^4.0.0",
"lodash.debounce": "^4.0.8",
"lucide-react": "^0.469.0",
"next": "^15.2.0",
@@ -9248,6 +9249,15 @@
"node": ">=6"
}
},
"node_modules/jwt-decode": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/jwt-decode/-/jwt-decode-4.0.0.tgz",
"integrity": "sha512-+KJGIyHgkGuIq3IEBNftfhW/LfWhXUIY6OmyVWjliu5KH1y0fw7VQ8YndE2O4qZdMSd9SqbnC8GOcZEy0Om7sA==",
"license": "MIT",
"engines": {
"node": ">=18"
}
},
"node_modules/keyv": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/keyv/-/keyv-3.1.0.tgz",

View File

@@ -52,6 +52,7 @@
"date-fns": "^3.6.0",
"drizzle-orm": "^0.39.3",
"groq-sdk": "^0.15.0",
"jwt-decode": "^4.0.0",
"lodash.debounce": "^4.0.8",
"lucide-react": "^0.469.0",
"next": "^15.2.0",