feat(registration): allow self-hosted users to disable registration altogether (#2365)

* feat(registration): allow self-hosted users to disable registration altogether

* updated tests

* fix build
This commit is contained in:
Waleed
2025-12-13 17:34:53 -08:00
committed by GitHub
parent 746ff68a2e
commit 95b9ca4670
66 changed files with 332 additions and 154 deletions

View File

@@ -4,10 +4,13 @@ DATABASE_URL="postgresql://postgres:password@localhost:5432/postgres"
# PostgreSQL Port (Optional) - defaults to 5432 if not specified # PostgreSQL Port (Optional) - defaults to 5432 if not specified
# POSTGRES_PORT=5432 # POSTGRES_PORT=5432
# Authentication (Required) # Authentication (Required unless DISABLE_AUTH=true)
BETTER_AUTH_SECRET=your_secret_key # Use `openssl rand -hex 32` to generate, or visit https://www.better-auth.com/docs/installation BETTER_AUTH_SECRET=your_secret_key # Use `openssl rand -hex 32` to generate, or visit https://www.better-auth.com/docs/installation
BETTER_AUTH_URL=http://localhost:3000 BETTER_AUTH_URL=http://localhost:3000
# Authentication Bypass (Optional - for self-hosted deployments behind private networks)
# DISABLE_AUTH=true # Uncomment to bypass authentication entirely. Creates an anonymous session for all requests.
# NextJS (Required) # NextJS (Required)
NEXT_PUBLIC_APP_URL=http://localhost:3000 NEXT_PUBLIC_APP_URL=http://localhost:3000

View File

@@ -1,7 +1,7 @@
'use server' 'use server'
import { env } from '@/lib/core/config/env' import { env } from '@/lib/core/config/env'
import { isProd } from '@/lib/core/config/environment' import { isProd } from '@/lib/core/config/feature-flags'
export async function getOAuthProviderStatus() { export async function getOAuthProviderStatus() {
const githubAvailable = !!(env.GITHUB_CLIENT_ID && env.GITHUB_CLIENT_SECRET) const githubAvailable = !!(env.GITHUB_CLIENT_ID && env.GITHUB_CLIENT_SECRET)

View File

@@ -1,7 +1,6 @@
import { getOAuthProviderStatus } from '@/app/(auth)/components/oauth-provider-checker' import { getOAuthProviderStatus } from '@/app/(auth)/components/oauth-provider-checker'
import LoginForm from '@/app/(auth)/login/login-form' import LoginForm from '@/app/(auth)/login/login-form'
// Force dynamic rendering to avoid prerender errors with search params
export const dynamic = 'force-dynamic' export const dynamic = 'force-dynamic'
export default async function LoginPage() { export default async function LoginPage() {

View File

@@ -1,16 +1,16 @@
import { env, isTruthy } from '@/lib/core/config/env' import { isRegistrationDisabled } from '@/lib/core/config/feature-flags'
import { getOAuthProviderStatus } from '@/app/(auth)/components/oauth-provider-checker' import { getOAuthProviderStatus } from '@/app/(auth)/components/oauth-provider-checker'
import SignupForm from '@/app/(auth)/signup/signup-form' import SignupForm from '@/app/(auth)/signup/signup-form'
export const dynamic = 'force-dynamic' export const dynamic = 'force-dynamic'
export default async function SignupPage() { export default async function SignupPage() {
const { githubAvailable, googleAvailable, isProduction } = await getOAuthProviderStatus() if (isRegistrationDisabled) {
if (isTruthy(env.DISABLE_REGISTRATION)) {
return <div>Registration is disabled, please contact your admin.</div> return <div>Registration is disabled, please contact your admin.</div>
} }
const { githubAvailable, googleAvailable, isProduction } = await getOAuthProviderStatus()
return ( return (
<SignupForm <SignupForm
githubAvailable={githubAvailable} githubAvailable={githubAvailable}

View File

@@ -1,4 +1,4 @@
import { isEmailVerificationEnabled, isProd } from '@/lib/core/config/environment' import { isEmailVerificationEnabled, isProd } from '@/lib/core/config/feature-flags'
import { hasEmailService } from '@/lib/messaging/email/mailer' import { hasEmailService } from '@/lib/messaging/email/mailer'
import { VerifyContent } from '@/app/(auth)/verify/verify-content' import { VerifyContent } from '@/app/(auth)/verify/verify-content'

View File

@@ -13,7 +13,7 @@ import {
SelectValue, SelectValue,
} from '@/components/ui/select' } from '@/components/ui/select'
import { Textarea } from '@/components/ui/textarea' import { Textarea } from '@/components/ui/textarea'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { cn } from '@/lib/core/utils/cn' import { cn } from '@/lib/core/utils/cn'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { quickValidateEmail } from '@/lib/messaging/email/validation' import { quickValidateEmail } from '@/lib/messaging/email/validation'

View File

@@ -1,6 +1,6 @@
'use client' 'use client'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { soehne } from '@/app/_styles/fonts/soehne/soehne' import { soehne } from '@/app/_styles/fonts/soehne/soehne'
import Footer from '@/app/(landing)/components/footer/footer' import Footer from '@/app/(landing)/components/footer/footer'
import Nav from '@/app/(landing)/components/nav/nav' import Nav from '@/app/(landing)/components/nav/nav'

View File

@@ -7,7 +7,7 @@ import Link from 'next/link'
import { useRouter } from 'next/navigation' import { useRouter } from 'next/navigation'
import { GithubIcon } from '@/components/icons' import { GithubIcon } from '@/components/icons'
import { useBrandConfig } from '@/lib/branding/branding' import { useBrandConfig } from '@/lib/branding/branding'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { soehne } from '@/app/_styles/fonts/soehne/soehne' import { soehne } from '@/app/_styles/fonts/soehne/soehne'
import { getFormattedGitHubStars } from '@/app/(landing)/actions/github' import { getFormattedGitHubStars } from '@/app/(landing)/actions/github'

View File

@@ -1,6 +1,23 @@
import { toNextJsHandler } from 'better-auth/next-js' import { toNextJsHandler } from 'better-auth/next-js'
import { type NextRequest, NextResponse } from 'next/server'
import { auth } from '@/lib/auth' import { auth } from '@/lib/auth'
import { createAnonymousSession, ensureAnonymousUserExists } from '@/lib/auth/anonymous'
import { isAuthDisabled } from '@/lib/core/config/feature-flags'
export const dynamic = 'force-dynamic' export const dynamic = 'force-dynamic'
export const { GET, POST } = toNextJsHandler(auth.handler) const { GET: betterAuthGET, POST: betterAuthPOST } = toNextJsHandler(auth.handler)
export async function GET(request: NextRequest) {
const url = new URL(request.url)
const path = url.pathname.replace('/api/auth/', '')
if (path === 'get-session' && isAuthDisabled) {
await ensureAnonymousUserExists()
return NextResponse.json(createAnonymousSession())
}
return betterAuthGET(request)
}
export const POST = betterAuthPOST

View File

@@ -1,9 +1,14 @@
import { headers } from 'next/headers' import { headers } from 'next/headers'
import { NextResponse } from 'next/server' import { NextResponse } from 'next/server'
import { auth } from '@/lib/auth' import { auth } from '@/lib/auth'
import { isAuthDisabled } from '@/lib/core/config/feature-flags'
export async function POST() { export async function POST() {
try { try {
if (isAuthDisabled) {
return NextResponse.json({ token: 'anonymous-socket-token' })
}
const hdrs = await headers() const hdrs = await headers()
const response = await auth.api.generateOneTimeToken({ const response = await auth.api.generateOneTimeToken({
headers: hdrs, headers: hdrs,

View File

@@ -1,14 +1,14 @@
import { db, ssoProvider } from '@sim/db' import { db, ssoProvider } from '@sim/db'
import { eq } from 'drizzle-orm' import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server' import { NextResponse } from 'next/server'
import { auth } from '@/lib/auth' import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('SSO-Providers') const logger = createLogger('SSO-Providers')
export async function GET(req: NextRequest) { export async function GET() {
try { try {
const session = await auth.api.getSession({ headers: req.headers }) const session = await getSession()
let providers let providers
if (session?.user?.id) { if (session?.user?.id) {
@@ -38,8 +38,6 @@ export async function GET(req: NextRequest) {
: ('oidc' as 'oidc' | 'saml'), : ('oidc' as 'oidc' | 'saml'),
})) }))
} else { } else {
// Unauthenticated users can only see basic info (domain only)
// This is needed for SSO login flow to check if a domain has SSO enabled
const results = await db const results = await db
.select({ .select({
domain: ssoProvider.domain, domain: ssoProvider.domain,

View File

@@ -5,7 +5,7 @@ import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod' import { z } from 'zod'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing' import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { checkInternalApiKey } from '@/lib/copilot/utils' import { checkInternalApiKey } from '@/lib/copilot/utils'
import { isBillingEnabled } from '@/lib/core/config/environment' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { generateRequestId } from '@/lib/core/utils/request' import { generateRequestId } from '@/lib/core/utils/request'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'

View File

@@ -6,6 +6,12 @@ import { NextRequest } from 'next/server'
*/ */
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
vi.mock('@/lib/core/config/feature-flags', () => ({
isDev: true,
isHosted: false,
isProd: false,
}))
describe('Chat Edit API Route', () => { describe('Chat Edit API Route', () => {
const mockSelect = vi.fn() const mockSelect = vi.fn()
const mockFrom = vi.fn() const mockFrom = vi.fn()
@@ -24,7 +30,6 @@ describe('Chat Edit API Route', () => {
beforeEach(() => { beforeEach(() => {
vi.resetModules() vi.resetModules()
// Set default return values
mockLimit.mockResolvedValue([]) mockLimit.mockResolvedValue([])
mockSelect.mockReturnValue({ from: mockFrom }) mockSelect.mockReturnValue({ from: mockFrom })
mockFrom.mockReturnValue({ where: mockWhere }) mockFrom.mockReturnValue({ where: mockWhere })
@@ -77,10 +82,6 @@ describe('Chat Edit API Route', () => {
getEmailDomain: vi.fn().mockReturnValue('localhost:3000'), getEmailDomain: vi.fn().mockReturnValue('localhost:3000'),
})) }))
vi.doMock('@/lib/core/config/environment', () => ({
isDev: true,
}))
vi.doMock('@/app/api/chat/utils', () => ({ vi.doMock('@/app/api/chat/utils', () => ({
checkChatAccess: mockCheckChatAccess, checkChatAccess: mockCheckChatAccess,
})) }))
@@ -254,7 +255,6 @@ describe('Chat Edit API Route', () => {
mockCheckChatAccess.mockResolvedValue({ hasAccess: true, chat: mockChat }) mockCheckChatAccess.mockResolvedValue({ hasAccess: true, chat: mockChat })
// Reset and reconfigure mockLimit to return the conflict
mockLimit.mockReset() mockLimit.mockReset()
mockLimit.mockResolvedValue([{ id: 'other-chat-id', identifier: 'new-identifier' }]) mockLimit.mockResolvedValue([{ id: 'other-chat-id', identifier: 'new-identifier' }])
mockWhere.mockReturnValue({ limit: mockLimit }) mockWhere.mockReturnValue({ limit: mockLimit })
@@ -291,7 +291,7 @@ describe('Chat Edit API Route', () => {
const req = new NextRequest('http://localhost:3000/api/chat/manage/chat-123', { const req = new NextRequest('http://localhost:3000/api/chat/manage/chat-123', {
method: 'PATCH', method: 'PATCH',
body: JSON.stringify({ authType: 'password' }), // No password provided body: JSON.stringify({ authType: 'password' }),
}) })
const { PATCH } = await import('@/app/api/chat/manage/[id]/route') const { PATCH } = await import('@/app/api/chat/manage/[id]/route')
const response = await PATCH(req, { params: Promise.resolve({ id: 'chat-123' }) }) const response = await PATCH(req, { params: Promise.resolve({ id: 'chat-123' }) })
@@ -316,9 +316,8 @@ describe('Chat Edit API Route', () => {
workflowId: 'workflow-123', workflowId: 'workflow-123',
} }
// User doesn't own chat but has workspace admin access
mockCheckChatAccess.mockResolvedValue({ hasAccess: true, chat: mockChat }) mockCheckChatAccess.mockResolvedValue({ hasAccess: true, chat: mockChat })
mockLimit.mockResolvedValueOnce([]) // No identifier conflict mockLimit.mockResolvedValueOnce([])
const req = new NextRequest('http://localhost:3000/api/chat/manage/chat-123', { const req = new NextRequest('http://localhost:3000/api/chat/manage/chat-123', {
method: 'PATCH', method: 'PATCH',
@@ -399,7 +398,6 @@ describe('Chat Edit API Route', () => {
}), }),
})) }))
// User doesn't own chat but has workspace admin access
mockCheckChatAccess.mockResolvedValue({ hasAccess: true }) mockCheckChatAccess.mockResolvedValue({ hasAccess: true })
mockWhere.mockResolvedValue(undefined) mockWhere.mockResolvedValue(undefined)

View File

@@ -4,7 +4,7 @@ import { eq } from 'drizzle-orm'
import type { NextRequest } from 'next/server' import type { NextRequest } from 'next/server'
import { z } from 'zod' import { z } from 'zod'
import { getSession } from '@/lib/auth' import { getSession } from '@/lib/auth'
import { isDev } from '@/lib/core/config/environment' import { isDev } from '@/lib/core/config/feature-flags'
import { encryptSecret } from '@/lib/core/security/encryption' import { encryptSecret } from '@/lib/core/security/encryption'
import { getEmailDomain } from '@/lib/core/utils/urls' import { getEmailDomain } from '@/lib/core/utils/urls'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'

View File

@@ -5,7 +5,7 @@ import type { NextRequest } from 'next/server'
import { v4 as uuidv4 } from 'uuid' import { v4 as uuidv4 } from 'uuid'
import { z } from 'zod' import { z } from 'zod'
import { getSession } from '@/lib/auth' import { getSession } from '@/lib/auth'
import { isDev } from '@/lib/core/config/environment' import { isDev } from '@/lib/core/config/feature-flags'
import { encryptSecret } from '@/lib/core/security/encryption' import { encryptSecret } from '@/lib/core/security/encryption'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'

View File

@@ -44,6 +44,12 @@ vi.mock('@/lib/core/utils/request', () => ({
generateRequestId: vi.fn(), generateRequestId: vi.fn(),
})) }))
vi.mock('@/lib/core/config/feature-flags', () => ({
isDev: true,
isHosted: false,
isProd: false,
}))
describe('Chat API Utils', () => { describe('Chat API Utils', () => {
beforeEach(() => { beforeEach(() => {
vi.doMock('@/lib/logs/console/logger', () => ({ vi.doMock('@/lib/logs/console/logger', () => ({
@@ -62,11 +68,6 @@ describe('Chat API Utils', () => {
NODE_ENV: 'development', NODE_ENV: 'development',
}, },
}) })
vi.doMock('@/lib/core/config/environment', () => ({
isDev: true,
isHosted: false,
}))
}) })
afterEach(() => { afterEach(() => {

View File

@@ -3,7 +3,7 @@ import { db } from '@sim/db'
import { chat, workflow } from '@sim/db/schema' import { chat, workflow } from '@sim/db/schema'
import { eq } from 'drizzle-orm' import { eq } from 'drizzle-orm'
import type { NextRequest, NextResponse } from 'next/server' import type { NextRequest, NextResponse } from 'next/server'
import { isDev } from '@/lib/core/config/environment' import { isDev } from '@/lib/core/config/feature-flags'
import { decryptSecret } from '@/lib/core/security/encryption' import { decryptSecret } from '@/lib/core/security/encryption'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { hasAdminPermission } from '@/lib/workspaces/permissions/utils' import { hasAdminPermission } from '@/lib/workspaces/permissions/utils'
@@ -282,8 +282,8 @@ export async function validateChatAuth(
return { authorized: false, error: 'Email not authorized for SSO access' } return { authorized: false, error: 'Email not authorized for SSO access' }
} }
const { auth } = await import('@/lib/auth') const { getSession } = await import('@/lib/auth')
const session = await auth.api.getSession({ headers: request.headers }) const session = await getSession()
if (!session || !session.user) { if (!session || !session.user) {
return { authorized: false, error: 'auth_required_sso' } return { authorized: false, error: 'auth_required_sso' }

View File

@@ -2,7 +2,7 @@ import { db } from '@sim/db'
import { settings } from '@sim/db/schema' import { settings } from '@sim/db/schema'
import { eq } from 'drizzle-orm' import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server' import { type NextRequest, NextResponse } from 'next/server'
import { auth } from '@/lib/auth' import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('CopilotAutoAllowedToolsAPI') const logger = createLogger('CopilotAutoAllowedToolsAPI')
@@ -10,9 +10,9 @@ const logger = createLogger('CopilotAutoAllowedToolsAPI')
/** /**
* GET - Fetch user's auto-allowed integration tools * GET - Fetch user's auto-allowed integration tools
*/ */
export async function GET(request: NextRequest) { export async function GET() {
try { try {
const session = await auth.api.getSession({ headers: request.headers }) const session = await getSession()
if (!session?.user?.id) { if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
@@ -31,7 +31,6 @@ export async function GET(request: NextRequest) {
return NextResponse.json({ autoAllowedTools }) return NextResponse.json({ autoAllowedTools })
} }
// If no settings record exists, create one with empty array
await db.insert(settings).values({ await db.insert(settings).values({
id: userId, id: userId,
userId, userId,
@@ -50,7 +49,7 @@ export async function GET(request: NextRequest) {
*/ */
export async function POST(request: NextRequest) { export async function POST(request: NextRequest) {
try { try {
const session = await auth.api.getSession({ headers: request.headers }) const session = await getSession()
if (!session?.user?.id) { if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
@@ -65,13 +64,11 @@ export async function POST(request: NextRequest) {
const toolId = body.toolId const toolId = body.toolId
// Get existing settings
const [existing] = await db.select().from(settings).where(eq(settings.userId, userId)).limit(1) const [existing] = await db.select().from(settings).where(eq(settings.userId, userId)).limit(1)
if (existing) { if (existing) {
const currentTools = (existing.copilotAutoAllowedTools as string[]) || [] const currentTools = (existing.copilotAutoAllowedTools as string[]) || []
// Add tool if not already present
if (!currentTools.includes(toolId)) { if (!currentTools.includes(toolId)) {
const updatedTools = [...currentTools, toolId] const updatedTools = [...currentTools, toolId]
await db await db
@@ -89,7 +86,6 @@ export async function POST(request: NextRequest) {
return NextResponse.json({ success: true, autoAllowedTools: currentTools }) return NextResponse.json({ success: true, autoAllowedTools: currentTools })
} }
// Create new settings record with the tool
await db.insert(settings).values({ await db.insert(settings).values({
id: userId, id: userId,
userId, userId,
@@ -109,7 +105,7 @@ export async function POST(request: NextRequest) {
*/ */
export async function DELETE(request: NextRequest) { export async function DELETE(request: NextRequest) {
try { try {
const session = await auth.api.getSession({ headers: request.headers }) const session = await getSession()
if (!session?.user?.id) { if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
@@ -123,7 +119,6 @@ export async function DELETE(request: NextRequest) {
return NextResponse.json({ error: 'toolId query parameter is required' }, { status: 400 }) return NextResponse.json({ error: 'toolId query parameter is required' }, { status: 400 })
} }
// Get existing settings
const [existing] = await db.select().from(settings).where(eq(settings.userId, userId)).limit(1) const [existing] = await db.select().from(settings).where(eq(settings.userId, userId)).limit(1)
if (existing) { if (existing) {

View File

@@ -1,6 +1,6 @@
import { eq } from 'drizzle-orm' import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server' import { type NextRequest, NextResponse } from 'next/server'
import { auth } from '@/lib/auth' import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/../../packages/db' import { db } from '@/../../packages/db'
import { settings } from '@/../../packages/db/schema' import { settings } from '@/../../packages/db/schema'
@@ -32,7 +32,7 @@ const DEFAULT_ENABLED_MODELS: Record<string, boolean> = {
// GET - Fetch user's enabled models // GET - Fetch user's enabled models
export async function GET(request: NextRequest) { export async function GET(request: NextRequest) {
try { try {
const session = await auth.api.getSession({ headers: request.headers }) const session = await getSession()
if (!session?.user?.id) { if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
@@ -40,7 +40,6 @@ export async function GET(request: NextRequest) {
const userId = session.user.id const userId = session.user.id
// Try to fetch existing settings record
const [userSettings] = await db const [userSettings] = await db
.select() .select()
.from(settings) .from(settings)
@@ -50,13 +49,11 @@ export async function GET(request: NextRequest) {
if (userSettings) { if (userSettings) {
const userModelsMap = (userSettings.copilotEnabledModels as Record<string, boolean>) || {} const userModelsMap = (userSettings.copilotEnabledModels as Record<string, boolean>) || {}
// Merge: start with defaults, then override with user's existing preferences
const mergedModels = { ...DEFAULT_ENABLED_MODELS } const mergedModels = { ...DEFAULT_ENABLED_MODELS }
for (const [modelId, enabled] of Object.entries(userModelsMap)) { for (const [modelId, enabled] of Object.entries(userModelsMap)) {
mergedModels[modelId] = enabled mergedModels[modelId] = enabled
} }
// If we added any new models, update the database
const hasNewModels = Object.keys(DEFAULT_ENABLED_MODELS).some( const hasNewModels = Object.keys(DEFAULT_ENABLED_MODELS).some(
(key) => !(key in userModelsMap) (key) => !(key in userModelsMap)
) )
@@ -76,7 +73,6 @@ export async function GET(request: NextRequest) {
}) })
} }
// If no settings record exists, create one with defaults
await db.insert(settings).values({ await db.insert(settings).values({
id: userId, id: userId,
userId, userId,
@@ -97,7 +93,7 @@ export async function GET(request: NextRequest) {
// PUT - Update user's enabled models // PUT - Update user's enabled models
export async function PUT(request: NextRequest) { export async function PUT(request: NextRequest) {
try { try {
const session = await auth.api.getSession({ headers: request.headers }) const session = await getSession()
if (!session?.user?.id) { if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
@@ -110,11 +106,9 @@ export async function PUT(request: NextRequest) {
return NextResponse.json({ error: 'enabledModels must be an object' }, { status: 400 }) return NextResponse.json({ error: 'enabledModels must be an object' }, { status: 400 })
} }
// Check if settings record exists
const [existing] = await db.select().from(settings).where(eq(settings.userId, userId)).limit(1) const [existing] = await db.select().from(settings).where(eq(settings.userId, userId)).limit(1)
if (existing) { if (existing) {
// Update existing record
await db await db
.update(settings) .update(settings)
.set({ .set({
@@ -123,7 +117,6 @@ export async function PUT(request: NextRequest) {
}) })
.where(eq(settings.userId, userId)) .where(eq(settings.userId, userId))
} else { } else {
// Create new settings record
await db.insert(settings).values({ await db.insert(settings).values({
id: userId, id: userId,
userId, userId,

View File

@@ -1,6 +1,6 @@
import { createContext, Script } from 'vm' import { createContext, Script } from 'vm'
import { type NextRequest, NextResponse } from 'next/server' import { type NextRequest, NextResponse } from 'next/server'
import { env, isTruthy } from '@/lib/core/config/env' import { isE2bEnabled } from '@/lib/core/config/feature-flags'
import { validateProxyUrl } from '@/lib/core/security/input-validation' import { validateProxyUrl } from '@/lib/core/security/input-validation'
import { generateRequestId } from '@/lib/core/utils/request' import { generateRequestId } from '@/lib/core/utils/request'
import { executeInE2B } from '@/lib/execution/e2b' import { executeInE2B } from '@/lib/execution/e2b'
@@ -701,7 +701,6 @@ export async function POST(req: NextRequest) {
resolvedCode = codeResolution.resolvedCode resolvedCode = codeResolution.resolvedCode
const contextVariables = codeResolution.contextVariables const contextVariables = codeResolution.contextVariables
const e2bEnabled = isTruthy(env.E2B_ENABLED)
const lang = isValidCodeLanguage(language) ? language : DEFAULT_CODE_LANGUAGE const lang = isValidCodeLanguage(language) ? language : DEFAULT_CODE_LANGUAGE
// Extract imports once for JavaScript code (reuse later to avoid double extraction) // Extract imports once for JavaScript code (reuse later to avoid double extraction)
@@ -722,14 +721,14 @@ export async function POST(req: NextRequest) {
} }
// Python always requires E2B // Python always requires E2B
if (lang === CodeLanguage.Python && !e2bEnabled) { if (lang === CodeLanguage.Python && !isE2bEnabled) {
throw new Error( throw new Error(
'Python execution requires E2B to be enabled. Please contact your administrator to enable E2B, or use JavaScript instead.' 'Python execution requires E2B to be enabled. Please contact your administrator to enable E2B, or use JavaScript instead.'
) )
} }
// JavaScript with imports requires E2B // JavaScript with imports requires E2B
if (lang === CodeLanguage.JavaScript && hasImports && !e2bEnabled) { if (lang === CodeLanguage.JavaScript && hasImports && !isE2bEnabled) {
throw new Error( throw new Error(
'JavaScript code with import statements requires E2B to be enabled. Please remove the import statements, or contact your administrator to enable E2B.' 'JavaScript code with import statements requires E2B to be enabled. Please remove the import statements, or contact your administrator to enable E2B.'
) )
@@ -740,13 +739,13 @@ export async function POST(req: NextRequest) {
// - Not a custom tool AND // - Not a custom tool AND
// - (Python OR JavaScript with imports) // - (Python OR JavaScript with imports)
const useE2B = const useE2B =
e2bEnabled && isE2bEnabled &&
!isCustomTool && !isCustomTool &&
(lang === CodeLanguage.Python || (lang === CodeLanguage.JavaScript && hasImports)) (lang === CodeLanguage.Python || (lang === CodeLanguage.JavaScript && hasImports))
if (useE2B) { if (useE2B) {
logger.info(`[${requestId}] E2B status`, { logger.info(`[${requestId}] E2B status`, {
enabled: e2bEnabled, enabled: isE2bEnabled,
hasApiKey: Boolean(process.env.E2B_API_KEY), hasApiKey: Boolean(process.env.E2B_API_KEY),
language: lang, language: lang,
}) })

View File

@@ -6,7 +6,7 @@ import { z } from 'zod'
import { getSession } from '@/lib/auth' import { getSession } from '@/lib/auth'
import { getPlanPricing } from '@/lib/billing/core/billing' import { getPlanPricing } from '@/lib/billing/core/billing'
import { requireStripeClient } from '@/lib/billing/stripe-client' import { requireStripeClient } from '@/lib/billing/stripe-client'
import { isBillingEnabled } from '@/lib/core/config/environment' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('OrganizationSeatsAPI') const logger = createLogger('OrganizationSeatsAPI')

View File

@@ -3,7 +3,7 @@ import { NextResponse } from 'next/server'
import { z } from 'zod' import { z } from 'zod'
import { checkHybridAuth } from '@/lib/auth/hybrid' import { checkHybridAuth } from '@/lib/auth/hybrid'
import { generateInternalToken } from '@/lib/auth/internal' import { generateInternalToken } from '@/lib/auth/internal'
import { isDev } from '@/lib/core/config/environment' import { isDev } from '@/lib/core/config/feature-flags'
import { createPinnedUrl, validateUrlWithDNS } from '@/lib/core/security/input-validation' import { createPinnedUrl, validateUrlWithDNS } from '@/lib/core/security/input-validation'
import { generateRequestId } from '@/lib/core/utils/request' import { generateRequestId } from '@/lib/core/utils/request'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'

View File

@@ -42,11 +42,11 @@ describe('Scheduled Workflow Execution API Route', () => {
executeScheduleJob: mockExecuteScheduleJob, executeScheduleJob: mockExecuteScheduleJob,
})) }))
vi.doMock('@/lib/core/config/env', () => ({ vi.doMock('@/lib/core/config/feature-flags', () => ({
env: { isTriggerDevEnabled: false,
TRIGGER_DEV_ENABLED: false, isHosted: false,
}, isProd: false,
isTruthy: vi.fn(() => false), isDev: true,
})) }))
vi.doMock('drizzle-orm', () => ({ vi.doMock('drizzle-orm', () => ({
@@ -119,11 +119,11 @@ describe('Scheduled Workflow Execution API Route', () => {
}, },
})) }))
vi.doMock('@/lib/core/config/env', () => ({ vi.doMock('@/lib/core/config/feature-flags', () => ({
env: { isTriggerDevEnabled: true,
TRIGGER_DEV_ENABLED: true, isHosted: false,
}, isProd: false,
isTruthy: vi.fn(() => true), isDev: true,
})) }))
vi.doMock('drizzle-orm', () => ({ vi.doMock('drizzle-orm', () => ({
@@ -191,11 +191,11 @@ describe('Scheduled Workflow Execution API Route', () => {
executeScheduleJob: vi.fn().mockResolvedValue(undefined), executeScheduleJob: vi.fn().mockResolvedValue(undefined),
})) }))
vi.doMock('@/lib/core/config/env', () => ({ vi.doMock('@/lib/core/config/feature-flags', () => ({
env: { isTriggerDevEnabled: false,
TRIGGER_DEV_ENABLED: false, isHosted: false,
}, isProd: false,
isTruthy: vi.fn(() => false), isDev: true,
})) }))
vi.doMock('drizzle-orm', () => ({ vi.doMock('drizzle-orm', () => ({
@@ -250,11 +250,11 @@ describe('Scheduled Workflow Execution API Route', () => {
executeScheduleJob: vi.fn().mockResolvedValue(undefined), executeScheduleJob: vi.fn().mockResolvedValue(undefined),
})) }))
vi.doMock('@/lib/core/config/env', () => ({ vi.doMock('@/lib/core/config/feature-flags', () => ({
env: { isTriggerDevEnabled: false,
TRIGGER_DEV_ENABLED: false, isHosted: false,
}, isProd: false,
isTruthy: vi.fn(() => false), isDev: true,
})) }))
vi.doMock('drizzle-orm', () => ({ vi.doMock('drizzle-orm', () => ({

View File

@@ -3,7 +3,7 @@ import { tasks } from '@trigger.dev/sdk'
import { and, eq, isNull, lt, lte, not, or } from 'drizzle-orm' import { and, eq, isNull, lt, lte, not, or } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server' import { type NextRequest, NextResponse } from 'next/server'
import { verifyCronAuth } from '@/lib/auth/internal' import { verifyCronAuth } from '@/lib/auth/internal'
import { env, isTruthy } from '@/lib/core/config/env' import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
import { generateRequestId } from '@/lib/core/utils/request' import { generateRequestId } from '@/lib/core/utils/request'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { executeScheduleJob } from '@/background/schedule-execution' import { executeScheduleJob } from '@/background/schedule-execution'
@@ -54,9 +54,7 @@ export async function GET(request: NextRequest) {
logger.debug(`[${requestId}] Successfully queried schedules: ${dueSchedules.length} found`) logger.debug(`[${requestId}] Successfully queried schedules: ${dueSchedules.length} found`)
logger.info(`[${requestId}] Processing ${dueSchedules.length} due scheduled workflows`) logger.info(`[${requestId}] Processing ${dueSchedules.length} due scheduled workflows`)
const useTrigger = isTruthy(env.TRIGGER_DEV_ENABLED) if (isTriggerDevEnabled) {
if (useTrigger) {
const triggerPromises = dueSchedules.map(async (schedule) => { const triggerPromises = dueSchedules.map(async (schedule) => {
const queueTime = schedule.lastQueuedAt ?? queuedAt const queueTime = schedule.lastQueuedAt ?? queuedAt

View File

@@ -1,6 +1,6 @@
import { type NextRequest, NextResponse } from 'next/server' import { type NextRequest, NextResponse } from 'next/server'
import { env } from '@/lib/core/config/env' import { env } from '@/lib/core/config/env'
import { isProd } from '@/lib/core/config/environment' import { isProd } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('TelemetryAPI') const logger = createLogger('TelemetryAPI')

View File

@@ -1,5 +1,7 @@
import type { NextRequest } from 'next/server' import type { NextRequest } from 'next/server'
import { authenticateApiKeyFromHeader, updateApiKeyLastUsed } from '@/lib/api-key/service' import { authenticateApiKeyFromHeader, updateApiKeyLastUsed } from '@/lib/api-key/service'
import { ANONYMOUS_USER_ID } from '@/lib/auth/constants'
import { isAuthDisabled } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('V1Auth') const logger = createLogger('V1Auth')
@@ -13,6 +15,14 @@ export interface AuthResult {
} }
export async function authenticateV1Request(request: NextRequest): Promise<AuthResult> { export async function authenticateV1Request(request: NextRequest): Promise<AuthResult> {
if (isAuthDisabled) {
return {
authenticated: true,
userId: ANONYMOUS_USER_ID,
keyType: 'personal',
}
}
const apiKey = request.headers.get('x-api-key') const apiKey = request.headers.get('x-api-key')
if (!apiKey) { if (!apiKey) {

View File

@@ -5,7 +5,7 @@ import { type NextRequest, NextResponse } from 'next/server'
import OpenAI, { AzureOpenAI } from 'openai' import OpenAI, { AzureOpenAI } from 'openai'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing' import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { env } from '@/lib/core/config/env' import { env } from '@/lib/core/config/env'
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/environment' import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
import { generateRequestId } from '@/lib/core/utils/request' import { generateRequestId } from '@/lib/core/utils/request'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { getModelPricing } from '@/providers/utils' import { getModelPricing } from '@/providers/utils'

View File

@@ -3,7 +3,7 @@ import { type NextRequest, NextResponse } from 'next/server'
import { validate as uuidValidate, v4 as uuidv4 } from 'uuid' import { validate as uuidValidate, v4 as uuidv4 } from 'uuid'
import { z } from 'zod' import { z } from 'zod'
import { checkHybridAuth } from '@/lib/auth/hybrid' import { checkHybridAuth } from '@/lib/auth/hybrid'
import { env, isTruthy } from '@/lib/core/config/env' import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
import { generateRequestId } from '@/lib/core/utils/request' import { generateRequestId } from '@/lib/core/utils/request'
import { SSE_HEADERS } from '@/lib/core/utils/sse' import { SSE_HEADERS } from '@/lib/core/utils/sse'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
@@ -236,9 +236,8 @@ type AsyncExecutionParams = {
*/ */
async function handleAsyncExecution(params: AsyncExecutionParams): Promise<NextResponse> { async function handleAsyncExecution(params: AsyncExecutionParams): Promise<NextResponse> {
const { requestId, workflowId, userId, input, triggerType } = params const { requestId, workflowId, userId, input, triggerType } = params
const useTrigger = isTruthy(env.TRIGGER_DEV_ENABLED)
if (!useTrigger) { if (!isTriggerDevEnabled) {
logger.warn(`[${requestId}] Async mode requested but TRIGGER_DEV_ENABLED is false`) logger.warn(`[${requestId}] Async mode requested but TRIGGER_DEV_ENABLED is false`)
return NextResponse.json( return NextResponse.json(
{ error: 'Async execution is not enabled. Set TRIGGER_DEV_ENABLED=true to use async mode.' }, { error: 'Async execution is not enabled. Set TRIGGER_DEV_ENABLED=true to use async mode.' },

View File

@@ -16,6 +16,7 @@ import {
} from '@/components/emcn' } from '@/components/emcn'
import { Input, Skeleton } from '@/components/ui' import { Input, Skeleton } from '@/components/ui'
import { signOut, useSession } from '@/lib/auth/auth-client' import { signOut, useSession } from '@/lib/auth/auth-client'
import { ANONYMOUS_USER_ID } from '@/lib/auth/constants'
import { useBrandConfig } from '@/lib/branding/branding' import { useBrandConfig } from '@/lib/branding/branding'
import { getEnv, isTruthy } from '@/lib/core/config/env' import { getEnv, isTruthy } from '@/lib/core/config/env'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
@@ -59,6 +60,7 @@ export function General({ onOpenChange }: GeneralProps) {
const isLoading = isProfileLoading || isSettingsLoading const isLoading = isProfileLoading || isSettingsLoading
const isTrainingEnabled = isTruthy(getEnv('NEXT_PUBLIC_COPILOT_TRAINING_ENABLED')) const isTrainingEnabled = isTruthy(getEnv('NEXT_PUBLIC_COPILOT_TRAINING_ENABLED'))
const isAuthDisabled = session?.user?.id === ANONYMOUS_USER_ID
const [isSuperUser, setIsSuperUser] = useState(false) const [isSuperUser, setIsSuperUser] = useState(false)
const [loadingSuperUser, setLoadingSuperUser] = useState(true) const [loadingSuperUser, setLoadingSuperUser] = useState(true)
@@ -461,10 +463,12 @@ export function General({ onOpenChange }: GeneralProps) {
</div> </div>
)} )}
<div className='mt-auto flex items-center gap-[8px]'> {!isAuthDisabled && (
<Button onClick={handleSignOut}>Sign out</Button> <div className='mt-auto flex items-center gap-[8px]'>
<Button onClick={() => setShowResetPasswordModal(true)}>Reset password</Button> <Button onClick={handleSignOut}>Sign out</Button>
</div> <Button onClick={() => setShowResetPasswordModal(true)}>Reset password</Button>
</div>
)}
{/* Password Reset Confirmation Modal */} {/* Password Reset Confirmation Modal */}
<Modal open={showResetPasswordModal} onOpenChange={setShowResetPasswordModal}> <Modal open={showResetPasswordModal} onOpenChange={setShowResetPasswordModal}>

View File

@@ -6,7 +6,7 @@ import { Button, Combobox, Input, Switch, Textarea } from '@/components/emcn'
import { Skeleton } from '@/components/ui' import { Skeleton } from '@/components/ui'
import { useSession } from '@/lib/auth/auth-client' import { useSession } from '@/lib/auth/auth-client'
import { getSubscriptionStatus } from '@/lib/billing/client/utils' import { getSubscriptionStatus } from '@/lib/billing/client/utils'
import { isBillingEnabled } from '@/lib/core/config/environment' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { cn } from '@/lib/core/utils/cn' import { cn } from '@/lib/core/utils/cn'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'

View File

@@ -26,7 +26,7 @@ import { McpIcon } from '@/components/icons'
import { useSession } from '@/lib/auth/auth-client' import { useSession } from '@/lib/auth/auth-client'
import { getSubscriptionStatus } from '@/lib/billing/client' import { getSubscriptionStatus } from '@/lib/billing/client'
import { getEnv, isTruthy } from '@/lib/core/config/env' import { getEnv, isTruthy } from '@/lib/core/config/env'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { getUserRole } from '@/lib/workspaces/organization' import { getUserRole } from '@/lib/workspaces/organization'
import { import {
ApiKeys, ApiKeys,

View File

@@ -1,5 +1,5 @@
import { AgentIcon } from '@/components/icons' import { AgentIcon } from '@/components/icons'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import type { BlockConfig } from '@/blocks/types' import type { BlockConfig } from '@/blocks/types'
import { AuthMode } from '@/blocks/types' import { AuthMode } from '@/blocks/types'

View File

@@ -1,5 +1,5 @@
import { ChartBarIcon } from '@/components/icons' import { ChartBarIcon } from '@/components/icons'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import type { BlockConfig, ParamType } from '@/blocks/types' import type { BlockConfig, ParamType } from '@/blocks/types'
import type { ProviderId } from '@/providers/types' import type { ProviderId } from '@/providers/types'

View File

@@ -1,5 +1,5 @@
import { ShieldCheckIcon } from '@/components/icons' import { ShieldCheckIcon } from '@/components/icons'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import type { BlockConfig } from '@/blocks/types' import type { BlockConfig } from '@/blocks/types'
import { getHostedModels, getProviderIcon } from '@/providers/utils' import { getHostedModels, getProviderIcon } from '@/providers/utils'
import { useProvidersStore } from '@/stores/providers/store' import { useProvidersStore } from '@/stores/providers/store'

View File

@@ -1,5 +1,5 @@
import { ConnectIcon } from '@/components/icons' import { ConnectIcon } from '@/components/icons'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { AuthMode, type BlockConfig } from '@/blocks/types' import { AuthMode, type BlockConfig } from '@/blocks/types'
import type { ProviderId } from '@/providers/types' import type { ProviderId } from '@/providers/types'
import { import {

View File

@@ -1,5 +1,5 @@
import { TranslateIcon } from '@/components/icons' import { TranslateIcon } from '@/components/icons'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { AuthMode, type BlockConfig } from '@/blocks/types' import { AuthMode, type BlockConfig } from '@/blocks/types'
import { getHostedModels, getProviderIcon, providers } from '@/providers/utils' import { getHostedModels, getProviderIcon, providers } from '@/providers/utils'
import { useProvidersStore } from '@/stores/providers/store' import { useProvidersStore } from '@/stores/providers/store'

View File

@@ -1,6 +1,6 @@
import { Container, Img, Link, Section, Text } from '@react-email/components' import { Container, Img, Link, Section, Text } from '@react-email/components'
import { getBrandConfig } from '@/lib/branding/branding' import { getBrandConfig } from '@/lib/branding/branding'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
interface UnsubscribeOptions { interface UnsubscribeOptions {

View File

@@ -1,5 +1,4 @@
import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest' import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest'
import { isHosted } from '@/lib/core/config/environment'
import { getAllBlocks } from '@/blocks' import { getAllBlocks } from '@/blocks'
import { BlockType } from '@/executor/constants' import { BlockType } from '@/executor/constants'
import { AgentBlockHandler } from '@/executor/handlers/agent/agent-handler' import { AgentBlockHandler } from '@/executor/handlers/agent/agent-handler'
@@ -11,11 +10,11 @@ import { executeTool } from '@/tools'
process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000' process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000'
vi.mock('@/lib/core/config/environment', () => ({ vi.mock('@/lib/core/config/feature-flags', () => ({
isHosted: vi.fn().mockReturnValue(false), isHosted: false,
isProd: vi.fn().mockReturnValue(false), isProd: false,
isDev: vi.fn().mockReturnValue(true), isDev: true,
isTest: vi.fn().mockReturnValue(false), isTest: false,
getCostMultiplier: vi.fn().mockReturnValue(1), getCostMultiplier: vi.fn().mockReturnValue(1),
isEmailVerificationEnabled: false, isEmailVerificationEnabled: false,
isBillingEnabled: false, isBillingEnabled: false,
@@ -65,7 +64,6 @@ global.fetch = Object.assign(vi.fn(), { preconnect: vi.fn() }) as typeof fetch
const mockGetAllBlocks = getAllBlocks as Mock const mockGetAllBlocks = getAllBlocks as Mock
const mockExecuteTool = executeTool as Mock const mockExecuteTool = executeTool as Mock
const mockIsHosted = isHosted as unknown as Mock
const mockGetProviderFromModel = getProviderFromModel as Mock const mockGetProviderFromModel = getProviderFromModel as Mock
const mockTransformBlockTool = transformBlockTool as Mock const mockTransformBlockTool = transformBlockTool as Mock
const mockFetch = global.fetch as unknown as Mock const mockFetch = global.fetch as unknown as Mock
@@ -120,7 +118,6 @@ describe('AgentBlockHandler', () => {
loops: {}, loops: {},
} as SerializedWorkflow, } as SerializedWorkflow,
} }
mockIsHosted.mockReturnValue(false)
mockGetProviderFromModel.mockReturnValue('mock-provider') mockGetProviderFromModel.mockReturnValue('mock-provider')
mockFetch.mockImplementation(() => { mockFetch.mockImplementation(() => {
@@ -552,8 +549,6 @@ describe('AgentBlockHandler', () => {
}) })
it('should not require API key for gpt-4o on hosted version', async () => { it('should not require API key for gpt-4o on hosted version', async () => {
mockIsHosted.mockReturnValue(true)
const inputs = { const inputs = {
model: 'gpt-4o', model: 'gpt-4o',
systemPrompt: 'You are a helpful assistant.', systemPrompt: 'You are a helpful assistant.',

View File

@@ -1,5 +1,5 @@
import { keepPreviousData, useMutation, useQuery, useQueryClient } from '@tanstack/react-query' import { keepPreviousData, useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('CopilotKeysQuery') const logger = createLogger('CopilotKeysQuery')

View File

@@ -0,0 +1,104 @@
import { db } from '@sim/db'
import * as schema from '@sim/db/schema'
import { eq } from 'drizzle-orm'
import { createLogger } from '@/lib/logs/console/logger'
import { ANONYMOUS_USER, ANONYMOUS_USER_ID } from './constants'
const logger = createLogger('AnonymousAuth')
let anonymousUserEnsured = false
/**
* Ensures the anonymous user and their stats record exist in the database.
* Called when DISABLE_AUTH is enabled to ensure DB operations work.
*/
export async function ensureAnonymousUserExists(): Promise<void> {
if (anonymousUserEnsured) return
try {
const existingUser = await db.query.user.findFirst({
where: eq(schema.user.id, ANONYMOUS_USER_ID),
})
if (!existingUser) {
const now = new Date()
await db.insert(schema.user).values({
...ANONYMOUS_USER,
createdAt: now,
updatedAt: now,
})
logger.info('Created anonymous user for DISABLE_AUTH mode')
}
const existingStats = await db.query.userStats.findFirst({
where: eq(schema.userStats.userId, ANONYMOUS_USER_ID),
})
if (!existingStats) {
await db.insert(schema.userStats).values({
id: crypto.randomUUID(),
userId: ANONYMOUS_USER_ID,
currentUsageLimit: '10000000000',
})
logger.info('Created anonymous user stats for DISABLE_AUTH mode')
}
anonymousUserEnsured = true
} catch (error) {
if (
error instanceof Error &&
(error.message.includes('unique') || error.message.includes('duplicate'))
) {
anonymousUserEnsured = true
return
}
logger.error('Failed to ensure anonymous user exists', { error })
throw error
}
}
export interface AnonymousSession {
user: {
id: string
name: string
email: string
emailVerified: boolean
image: null
createdAt: Date
updatedAt: Date
}
session: {
id: string
userId: string
expiresAt: Date
createdAt: Date
updatedAt: Date
token: string
ipAddress: null
userAgent: null
}
}
/**
* Creates an anonymous session for when auth is disabled.
*/
export function createAnonymousSession(): AnonymousSession {
const now = new Date()
return {
user: {
...ANONYMOUS_USER,
createdAt: now,
updatedAt: now,
},
session: {
id: 'anonymous-session',
userId: ANONYMOUS_USER_ID,
expiresAt: new Date(Date.now() + 365 * 24 * 60 * 60 * 1000), // 1 year
createdAt: now,
updatedAt: now,
token: 'anonymous-token',
ipAddress: null,
userAgent: null,
},
}
}

View File

@@ -10,7 +10,7 @@ import {
import { createAuthClient } from 'better-auth/react' import { createAuthClient } from 'better-auth/react'
import type { auth } from '@/lib/auth' import type { auth } from '@/lib/auth'
import { env } from '@/lib/core/config/env' import { env } from '@/lib/core/config/env'
import { isBillingEnabled } from '@/lib/core/config/environment' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
import { SessionContext, type SessionHookResult } from '@/app/_shell/providers/session-provider' import { SessionContext, type SessionHookResult } from '@/app/_shell/providers/session-provider'

View File

@@ -38,13 +38,19 @@ import {
handleSubscriptionCreated, handleSubscriptionCreated,
handleSubscriptionDeleted, handleSubscriptionDeleted,
} from '@/lib/billing/webhooks/subscription' } from '@/lib/billing/webhooks/subscription'
import { env, isTruthy } from '@/lib/core/config/env' import { env } from '@/lib/core/config/env'
import { isBillingEnabled, isEmailVerificationEnabled } from '@/lib/core/config/environment' import {
isAuthDisabled,
isBillingEnabled,
isEmailVerificationEnabled,
isRegistrationDisabled,
} from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { sendEmail } from '@/lib/messaging/email/mailer' import { sendEmail } from '@/lib/messaging/email/mailer'
import { getFromEmailAddress } from '@/lib/messaging/email/utils' import { getFromEmailAddress } from '@/lib/messaging/email/utils'
import { quickValidateEmail } from '@/lib/messaging/email/validation' import { quickValidateEmail } from '@/lib/messaging/email/validation'
import { createAnonymousSession, ensureAnonymousUserExists } from './anonymous'
import { SSO_TRUSTED_PROVIDERS } from './sso/constants' import { SSO_TRUSTED_PROVIDERS } from './sso/constants'
const logger = createLogger('Auth') const logger = createLogger('Auth')
@@ -270,7 +276,7 @@ export const auth = betterAuth({
}, },
hooks: { hooks: {
before: createAuthMiddleware(async (ctx) => { before: createAuthMiddleware(async (ctx) => {
if (ctx.path.startsWith('/sign-up') && isTruthy(env.DISABLE_REGISTRATION)) if (ctx.path.startsWith('/sign-up') && isRegistrationDisabled)
throw new Error('Registration is disabled, please contact your admin.') throw new Error('Registration is disabled, please contact your admin.')
if ( if (
@@ -2185,6 +2191,11 @@ export const auth = betterAuth({
}) })
export async function getSession() { export async function getSession() {
if (isAuthDisabled) {
await ensureAnonymousUserExists()
return createAnonymousSession()
}
const hdrs = await headers() const hdrs = await headers()
return await auth.api.getSession({ return await auth.api.getSession({
headers: hdrs, headers: hdrs,

View File

@@ -0,0 +1,10 @@
/** Anonymous user ID used when DISABLE_AUTH is enabled */
export const ANONYMOUS_USER_ID = '00000000-0000-0000-0000-000000000000'
export const ANONYMOUS_USER = {
id: ANONYMOUS_USER_ID,
name: 'Anonymous',
email: 'anonymous@localhost',
emailVerified: true,
image: null,
} as const

View File

@@ -1 +1,4 @@
export type { AnonymousSession } from './anonymous'
export { createAnonymousSession, ensureAnonymousUserExists } from './anonymous'
export { auth, getSession, signIn, signUp } from './auth' export { auth, getSession, signIn, signUp } from './auth'
export { ANONYMOUS_USER, ANONYMOUS_USER_ID } from './constants'

View File

@@ -2,7 +2,7 @@ import { db } from '@sim/db'
import { member, organization, userStats } from '@sim/db/schema' import { member, organization, userStats } from '@sim/db/schema'
import { and, eq, inArray } from 'drizzle-orm' import { and, eq, inArray } from 'drizzle-orm'
import { getUserUsageLimit } from '@/lib/billing/core/usage' import { getUserUsageLimit } from '@/lib/billing/core/usage'
import { isBillingEnabled } from '@/lib/core/config/environment' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('UsageMonitor') const logger = createLogger('UsageMonitor')

View File

@@ -9,7 +9,7 @@ import {
getPerUserMinimumLimit, getPerUserMinimumLimit,
} from '@/lib/billing/subscriptions/utils' } from '@/lib/billing/subscriptions/utils'
import type { UserSubscriptionState } from '@/lib/billing/types' import type { UserSubscriptionState } from '@/lib/billing/types'
import { isProd } from '@/lib/core/config/environment' import { isProd } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'

View File

@@ -14,7 +14,7 @@ import {
getPlanPricing, getPlanPricing,
} from '@/lib/billing/subscriptions/utils' } from '@/lib/billing/subscriptions/utils'
import type { BillingData, UsageData, UsageLimitInfo } from '@/lib/billing/types' import type { BillingData, UsageData, UsageLimitInfo } from '@/lib/billing/types'
import { isBillingEnabled } from '@/lib/core/config/environment' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { sendEmail } from '@/lib/messaging/email/mailer' import { sendEmail } from '@/lib/messaging/email/mailer'

View File

@@ -13,7 +13,7 @@ import {
import { organization, subscription, userStats } from '@sim/db/schema' import { organization, subscription, userStats } from '@sim/db/schema'
import { eq } from 'drizzle-orm' import { eq } from 'drizzle-orm'
import { getEnv } from '@/lib/core/config/env' import { getEnv } from '@/lib/core/config/env'
import { isBillingEnabled } from '@/lib/core/config/environment' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('StorageLimits') const logger = createLogger('StorageLimits')

View File

@@ -7,7 +7,7 @@
import { db } from '@sim/db' import { db } from '@sim/db'
import { organization, userStats } from '@sim/db/schema' import { organization, userStats } from '@sim/db/schema'
import { eq, sql } from 'drizzle-orm' import { eq, sql } from 'drizzle-orm'
import { isBillingEnabled } from '@/lib/core/config/environment' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('StorageTracking') const logger = createLogger('StorageTracking')

View File

@@ -20,6 +20,7 @@ export const env = createEnv({
BETTER_AUTH_URL: z.string().url(), // Base URL for Better Auth service BETTER_AUTH_URL: z.string().url(), // Base URL for Better Auth service
BETTER_AUTH_SECRET: z.string().min(32), // Secret key for Better Auth JWT signing BETTER_AUTH_SECRET: z.string().min(32), // Secret key for Better Auth JWT signing
DISABLE_REGISTRATION: z.boolean().optional(), // Flag to disable new user registration DISABLE_REGISTRATION: z.boolean().optional(), // Flag to disable new user registration
DISABLE_AUTH: z.boolean().optional(), // Bypass authentication entirely (self-hosted only, creates anonymous session)
ALLOWED_LOGIN_EMAILS: z.string().optional(), // Comma-separated list of allowed email addresses for login ALLOWED_LOGIN_EMAILS: z.string().optional(), // Comma-separated list of allowed email addresses for login
ALLOWED_LOGIN_DOMAINS: z.string().optional(), // Comma-separated list of allowed email domains for login ALLOWED_LOGIN_DOMAINS: z.string().optional(), // Comma-separated list of allowed email domains for login
ENCRYPTION_KEY: z.string().min(32), // Key for encrypting sensitive data ENCRYPTION_KEY: z.string().min(32), // Key for encrypting sensitive data

View File

@@ -35,6 +35,31 @@ export const isBillingEnabled = isTruthy(env.BILLING_ENABLED)
*/ */
export const isEmailVerificationEnabled = isTruthy(env.EMAIL_VERIFICATION_ENABLED) export const isEmailVerificationEnabled = isTruthy(env.EMAIL_VERIFICATION_ENABLED)
/**
* Is authentication disabled (for self-hosted deployments behind private networks)
*/
export const isAuthDisabled = isTruthy(env.DISABLE_AUTH)
/**
* Is user registration disabled
*/
export const isRegistrationDisabled = isTruthy(env.DISABLE_REGISTRATION)
/**
* Is Trigger.dev enabled for async job processing
*/
export const isTriggerDevEnabled = isTruthy(env.TRIGGER_DEV_ENABLED)
/**
* Is SSO enabled for enterprise authentication
*/
export const isSsoEnabled = isTruthy(env.SSO_ENABLED)
/**
* Is E2B enabled for remote code execution
*/
export const isE2bEnabled = isTruthy(env.E2B_ENABLED)
/** /**
* Get cost multiplier based on environment * Get cost multiplier based on environment
*/ */

View File

@@ -1,5 +1,5 @@
import { getEnv } from '@/lib/core/config/env' import { getEnv } from '@/lib/core/config/env'
import { isProd } from '@/lib/core/config/environment' import { isProd } from '@/lib/core/config/feature-flags'
/** /**
* Returns the base URL of the application from NEXT_PUBLIC_APP_URL * Returns the base URL of the application from NEXT_PUBLIC_APP_URL

View File

@@ -6,7 +6,7 @@ import {
} from '@sim/db/schema' } from '@sim/db/schema'
import { and, eq, or, sql } from 'drizzle-orm' import { and, eq, or, sql } from 'drizzle-orm'
import { v4 as uuidv4 } from 'uuid' import { v4 as uuidv4 } from 'uuid'
import { env, isTruthy } from '@/lib/core/config/env' import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import type { WorkflowExecutionLog } from '@/lib/logs/types' import type { WorkflowExecutionLog } from '@/lib/logs/types'
import { import {
@@ -140,9 +140,7 @@ export async function emitWorkflowExecutionCompleted(log: WorkflowExecutionLog):
alertConfig: alertConfig || undefined, alertConfig: alertConfig || undefined,
} }
const useTrigger = isTruthy(env.TRIGGER_DEV_ENABLED) if (isTriggerDevEnabled) {
if (useTrigger) {
await workspaceNotificationDeliveryTask.trigger(payload) await workspaceNotificationDeliveryTask.trigger(payload)
logger.info( logger.info(
`Enqueued ${subscription.notificationType} notification ${deliveryId} via Trigger.dev` `Enqueued ${subscription.notificationType} notification ${deliveryId} via Trigger.dev`

View File

@@ -15,7 +15,7 @@ import {
maybeSendUsageThresholdEmail, maybeSendUsageThresholdEmail,
} from '@/lib/billing/core/usage' } from '@/lib/billing/core/usage'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing' import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { isBillingEnabled } from '@/lib/core/config/environment' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { redactApiKeys } from '@/lib/core/security/redaction' import { redactApiKeys } from '@/lib/core/security/redaction'
import { filterForDisplay } from '@/lib/core/utils/display-filters' import { filterForDisplay } from '@/lib/core/utils/display-filters'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'

View File

@@ -5,7 +5,7 @@
import { db } from '@sim/db' import { db } from '@sim/db'
import { mcpServers } from '@sim/db/schema' import { mcpServers } from '@sim/db/schema'
import { and, eq, isNull } from 'drizzle-orm' import { and, eq, isNull } from 'drizzle-orm'
import { isTest } from '@/lib/core/config/environment' import { isTest } from '@/lib/core/config/feature-flags'
import { generateRequestId } from '@/lib/core/utils/request' import { generateRequestId } from '@/lib/core/utils/request'
import { getEffectiveDecryptedEnv } from '@/lib/environment/utils' import { getEffectiveDecryptedEnv } from '@/lib/environment/utils'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'

View File

@@ -7,7 +7,7 @@ import {
} from '@sim/db/schema' } from '@sim/db/schema'
import { and, eq, gte, sql } from 'drizzle-orm' import { and, eq, gte, sql } from 'drizzle-orm'
import { v4 as uuidv4 } from 'uuid' import { v4 as uuidv4 } from 'uuid'
import { env, isTruthy } from '@/lib/core/config/env' import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { import {
executeNotificationDelivery, executeNotificationDelivery,
@@ -118,9 +118,7 @@ async function checkWorkflowInactivity(
alertConfig, alertConfig,
} }
const useTrigger = isTruthy(env.TRIGGER_DEV_ENABLED) if (isTriggerDevEnabled) {
if (useTrigger) {
await workspaceNotificationDeliveryTask.trigger(payload) await workspaceNotificationDeliveryTask.trigger(payload)
} else { } else {
void executeNotificationDelivery(payload).catch((error) => { void executeNotificationDelivery(payload).catch((error) => {

View File

@@ -3,7 +3,7 @@ import { tasks } from '@trigger.dev/sdk'
import { and, eq } from 'drizzle-orm' import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server' import { type NextRequest, NextResponse } from 'next/server'
import { v4 as uuidv4 } from 'uuid' import { v4 as uuidv4 } from 'uuid'
import { env, isTruthy } from '@/lib/core/config/env' import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
import { preprocessExecution } from '@/lib/execution/preprocessing' import { preprocessExecution } from '@/lib/execution/preprocessing'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { convertSquareBracketsToTwiML } from '@/lib/webhooks/utils' import { convertSquareBracketsToTwiML } from '@/lib/webhooks/utils'
@@ -707,9 +707,7 @@ export async function queueWebhookExecution(
...(credentialId ? { credentialId } : {}), ...(credentialId ? { credentialId } : {}),
} }
const useTrigger = isTruthy(env.TRIGGER_DEV_ENABLED) if (isTriggerDevEnabled) {
if (useTrigger) {
const handle = await tasks.trigger('webhook-execution', payload) const handle = await tasks.trigger('webhook-execution', payload)
logger.info( logger.info(
`[${options.requestId}] Queued ${options.testMode ? 'TEST ' : ''}webhook execution task ${ `[${options.requestId}] Queued ${options.testMode ? 'TEST ' : ''}webhook execution task ${

View File

@@ -1,6 +1,6 @@
import type { NextConfig } from 'next' import type { NextConfig } from 'next'
import { env, getEnv, isTruthy } from './lib/core/config/env' import { env, getEnv, isTruthy } from './lib/core/config/env'
import { isDev, isHosted } from './lib/core/config/environment' import { isDev, isHosted } from './lib/core/config/feature-flags'
import { getMainCSPPolicy, getWorkflowExecutionCSPPolicy } from './lib/core/security/csp' import { getMainCSPPolicy, getWorkflowExecutionCSPPolicy } from './lib/core/security/csp'
const nextConfig: NextConfig = { const nextConfig: NextConfig = {

View File

@@ -1,4 +1,4 @@
import { getCostMultiplier } from '@/lib/core/config/environment' import { getCostMultiplier } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types' import type { StreamingExecution } from '@/executor/types'
import type { ProviderRequest, ProviderResponse } from '@/providers/types' import type { ProviderRequest, ProviderResponse } from '@/providers/types'

View File

@@ -1,5 +1,5 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import * as environmentModule from '@/lib/core/config/environment' import * as environmentModule from '@/lib/core/config/feature-flags'
import { import {
calculateCost, calculateCost,
extractAndParseJSON, extractAndParseJSON,

View File

@@ -1,5 +1,5 @@
import { getEnv, isTruthy } from '@/lib/core/config/env' import { getEnv, isTruthy } from '@/lib/core/config/env'
import { isHosted } from '@/lib/core/config/environment' import { isHosted } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import { anthropicProvider } from '@/providers/anthropic' import { anthropicProvider } from '@/providers/anthropic'
import { azureOpenAIProvider } from '@/providers/azure-openai' import { azureOpenAIProvider } from '@/providers/azure-openai'

View File

@@ -1,6 +1,6 @@
import { getSessionCookie } from 'better-auth/cookies' import { getSessionCookie } from 'better-auth/cookies'
import { type NextRequest, NextResponse } from 'next/server' import { type NextRequest, NextResponse } from 'next/server'
import { isHosted } from './lib/core/config/environment' import { isAuthDisabled, isHosted } from './lib/core/config/feature-flags'
import { generateRuntimeCSP } from './lib/core/security/csp' import { generateRuntimeCSP } from './lib/core/security/csp'
import { createLogger } from './lib/logs/console/logger' import { createLogger } from './lib/logs/console/logger'
@@ -135,7 +135,7 @@ export async function proxy(request: NextRequest) {
const url = request.nextUrl const url = request.nextUrl
const sessionCookie = getSessionCookie(request) const sessionCookie = getSessionCookie(request)
const hasActiveSession = !!sessionCookie const hasActiveSession = isAuthDisabled || !!sessionCookie
const redirect = handleRootPathRedirects(request, hasActiveSession) const redirect = handleRootPathRedirects(request, hasActiveSession)
if (redirect) return redirect if (redirect) return redirect

View File

@@ -5,7 +5,7 @@ import { db } from '@sim/db'
import { docsEmbeddings } from '@sim/db/schema' import { docsEmbeddings } from '@sim/db/schema'
import { sql } from 'drizzle-orm' import { sql } from 'drizzle-orm'
import { type DocChunk, DocsChunker } from '@/lib/chunkers' import { type DocChunk, DocsChunker } from '@/lib/chunkers'
import { isDev } from '@/lib/core/config/environment' import { isDev } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('ProcessDocs') const logger = createLogger('ProcessDocs')

View File

@@ -1,7 +1,7 @@
import type { Server as HttpServer } from 'http' import type { Server as HttpServer } from 'http'
import { Server } from 'socket.io' import { Server } from 'socket.io'
import { env } from '@/lib/core/config/env' import { env } from '@/lib/core/config/env'
import { isProd } from '@/lib/core/config/environment' import { isProd } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'

View File

@@ -1,10 +1,14 @@
import type { Socket } from 'socket.io' import type { Socket } from 'socket.io'
import { auth } from '@/lib/auth' import { auth } from '@/lib/auth'
import { ANONYMOUS_USER, ANONYMOUS_USER_ID } from '@/lib/auth/constants'
import { isAuthDisabled } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('SocketAuth') const logger = createLogger('SocketAuth')
// Extend Socket interface to include user data /**
* Authenticated socket with user data attached.
*/
export interface AuthenticatedSocket extends Socket { export interface AuthenticatedSocket extends Socket {
userId?: string userId?: string
userName?: string userName?: string
@@ -13,9 +17,21 @@ export interface AuthenticatedSocket extends Socket {
userImage?: string | null userImage?: string | null
} }
// Enhanced authentication middleware /**
* Socket.IO authentication middleware.
* Handles both anonymous mode (DISABLE_AUTH=true) and normal token-based auth.
*/
export async function authenticateSocket(socket: AuthenticatedSocket, next: any) { export async function authenticateSocket(socket: AuthenticatedSocket, next: any) {
try { try {
if (isAuthDisabled) {
socket.userId = ANONYMOUS_USER_ID
socket.userName = ANONYMOUS_USER.name
socket.userEmail = ANONYMOUS_USER.email
socket.userImage = ANONYMOUS_USER.image
logger.debug(`Socket ${socket.id} authenticated as anonymous`)
return next()
}
// Extract authentication data from socket handshake // Extract authentication data from socket handshake
const token = socket.handshake.auth?.token const token = socket.handshake.auth?.token
const origin = socket.handshake.headers.origin const origin = socket.handshake.headers.origin

View File

@@ -1,4 +1,4 @@
import { isTest } from '@/lib/core/config/environment' import { isTest } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
import { createLogger } from '@/lib/logs/console/logger' import { createLogger } from '@/lib/logs/console/logger'
import type { TableRow } from '@/tools/types' import type { TableRow } from '@/tools/types'