diff --git a/apps/sim/contexts/socket-context.test.tsx b/apps/sim/contexts/socket-context.test.tsx new file mode 100644 index 000000000..5f1c8df74 --- /dev/null +++ b/apps/sim/contexts/socket-context.test.tsx @@ -0,0 +1,278 @@ +/** + * @vitest-environment jsdom + */ + +import { act, renderHook, waitFor } from '@testing-library/react' +import { io } from 'socket.io-client' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { SocketProvider, useSocket } from './socket-context' + +vi.mock('socket.io-client') +const mockIo = vi.mocked(io) + +global.fetch = vi.fn() +const mockFetch = vi.mocked(fetch) + +vi.mock('@/lib/logs/console-logger', () => ({ + createLogger: () => ({ + info: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + debug: vi.fn(), + }), +})) + +describe('SocketContext Token Refresh', () => { + let mockSocket: any + let eventHandlers: Record + + beforeEach(() => { + eventHandlers = {} + mockSocket = { + id: 'test-socket-id', + connected: true, + io: { engine: { transport: { name: 'websocket' } } }, + auth: { token: 'initial-token' }, + on: vi.fn((event, handler) => { + eventHandlers[event] = handler + }), + connect: vi.fn(), + disconnect: vi.fn(), + emit: vi.fn(), + close: vi.fn(), + } + + mockIo.mockReturnValue(mockSocket) + + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ token: 'fresh-token' }), + } as Response) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + const renderSocketProvider = async (user = { id: 'test-user', name: 'Test User' }) => { + const result = renderHook(() => useSocket(), { + wrapper: ({ children }) => {children}, + }) + + await waitFor(() => { + expect(mockSocket.on).toHaveBeenCalledWith('connect_error', expect.any(Function)) + }) + + vi.clearAllMocks() + + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ token: 'fresh-token' }), + } as Response) + + return result + } + + describe('Token Refresh on Connection Error', () => { + it('should refresh token on authentication failure', async () => { + const { result } = await renderSocketProvider() + + const error = { message: 'Token validation failed' } + + await act(async () => { + await eventHandlers.connect_error(error) + }) + + expect(mockFetch).toHaveBeenCalledWith('/api/auth/socket-token', { + method: 'POST', + credentials: 'include', + }) + + // Should update socket auth and reconnect + expect(mockSocket.auth.token).toBe('fresh-token') + expect(mockSocket.connect).toHaveBeenCalled() + }) + + it('should limit token refresh attempts to 3', async () => { + const { result } = await renderSocketProvider() + + const error = { message: 'Token validation failed' } + + for (let i = 0; i < 4; i++) { + await act(async () => { + await eventHandlers.connect_error(error) + }) + } + + // Should only call fetch 3 times (max attempts) + expect(mockFetch).toHaveBeenCalledTimes(3) + expect(mockSocket.connect).toHaveBeenCalledTimes(3) + }) + + it('should prevent concurrent token refresh attempts', async () => { + const { result } = await renderSocketProvider() + + let resolveTokenFetch!: (value: { + ok: boolean + json: () => Promise<{ token: string }> + }) => void + const slowTokenPromise = new Promise((resolve) => { + resolveTokenFetch = resolve + }) + + mockFetch.mockReturnValue(slowTokenPromise as any) + + const error = { message: 'Authentication failed' } + + // Start two concurrent refresh attempts + const promise1 = act(async () => { + await eventHandlers.connect_error(error) + }) + + const promise2 = act(async () => { + await eventHandlers.connect_error(error) + }) + + // Resolve the slow fetch + resolveTokenFetch({ + ok: true, + json: async () => ({ token: 'fresh-token' }), + }) + + await Promise.all([promise1, promise2]) + + // Should only call fetch once (concurrent protection) + expect(mockFetch).toHaveBeenCalledTimes(1) + }) + + it('should reset retry counter on successful connection', async () => { + const { result } = await renderSocketProvider() + + const error = { message: 'Token validation failed' } + + // Use up 2 retry attempts + await act(async () => { + await eventHandlers.connect_error(error) + }) + await act(async () => { + await eventHandlers.connect_error(error) + }) + + expect(mockFetch).toHaveBeenCalledTimes(2) + + // Simulate successful connection (resets counter) + await act(async () => { + eventHandlers.connect() + }) + + // Should be able to retry again (counter reset) + await act(async () => { + await eventHandlers.connect_error(error) + }) + + expect(mockFetch).toHaveBeenCalledTimes(3) + }) + + it('should handle token refresh failure gracefully', async () => { + const { result } = await renderSocketProvider() + + // Mock failed token refresh after initialization + mockFetch.mockResolvedValue({ + ok: false, + status: 401, + } as Response) + + const error = { message: 'Token validation failed' } + + await act(async () => { + await eventHandlers.connect_error(error) + }) + + // Should attempt refresh but not update auth or reconnect + expect(mockFetch).toHaveBeenCalled() + expect(mockSocket.auth.token).toBe('initial-token') // unchanged + expect(mockSocket.connect).not.toHaveBeenCalled() + }) + + it('should handle fetch errors gracefully', async () => { + const { result } = await renderSocketProvider() + + // Mock fetch error after initialization + mockFetch.mockRejectedValue(new Error('Network error')) + + const error = { message: 'Authentication failed' } + + // Should not throw error + await act(async () => { + await eventHandlers.connect_error(error) + }) + + expect(mockFetch).toHaveBeenCalled() + expect(mockSocket.connect).not.toHaveBeenCalled() + }) + + it('should only refresh token on authentication-related errors', async () => { + const { result } = await renderSocketProvider() + + // Non-authentication error + const networkError = { message: 'Network timeout' } + + await act(async () => { + await eventHandlers.connect_error(networkError) + }) + + // Should not attempt token refresh + expect(mockFetch).not.toHaveBeenCalled() + expect(mockSocket.connect).not.toHaveBeenCalled() + }) + }) + + describe('Interaction with Socket.IO Reconnection', () => { + it('should work with Socket.IO built-in reconnection attempts', async () => { + const { result } = await renderSocketProvider() + + // Simulate Socket.IO reconnection cycle + await act(async () => { + // Reconnection attempt starts + eventHandlers.reconnect_attempt(1) + }) + + await act(async () => { + // Fails with auth error + await eventHandlers.connect_error({ message: 'Token validation failed' }) + }) + + // Should refresh token and attempt reconnection + expect(mockFetch).toHaveBeenCalled() + expect(mockSocket.connect).toHaveBeenCalled() + }) + + it('should reset counters on successful reconnect', async () => { + const { result } = await renderSocketProvider() + + // Use up retry attempts + const error = { message: 'Authentication failed' } + await act(async () => { + await eventHandlers.connect_error(error) + }) + + await act(async () => { + await eventHandlers.connect_error(error) + }) + + expect(mockFetch).toHaveBeenCalledTimes(2) + + // Simulate successful reconnection + await act(async () => { + eventHandlers.reconnect(1) + }) + + // Should reset and allow new attempts + await act(async () => { + await eventHandlers.connect_error(error) + }) + + expect(mockFetch).toHaveBeenCalledTimes(3) + }) + }) +}) diff --git a/apps/sim/contexts/socket-context.tsx b/apps/sim/contexts/socket-context.tsx index a8045e196..cccb7c20f 100644 --- a/apps/sim/contexts/socket-context.tsx +++ b/apps/sim/contexts/socket-context.tsx @@ -87,6 +87,8 @@ export function SocketProvider({ children, user }: SocketProviderProps) { // Connection state tracking const reconnectCount = useRef(0) + const tokenRefreshAttempts = useRef(0) + const isRefreshingToken = useRef(false) // Use refs to store event handlers to avoid stale closures const eventHandlers = useRef<{ @@ -138,7 +140,9 @@ export function SocketProvider({ children, user }: SocketProviderProps) { const socketInstance = io(socketUrl, { transports: ['websocket', 'polling'], // Keep polling fallback for reliability withCredentials: true, - reconnectionAttempts: 5, // Back to original conservative setting + reconnectionAttempts: 5, // Socket.IO handles base reconnection + reconnectionDelay: 1000, // Start with 1 second delay + reconnectionDelayMax: 5000, // Max 5 second delay timeout: 10000, // Back to original timeout auth: { token, // Send one-time token for authentication @@ -150,6 +154,7 @@ export function SocketProvider({ children, user }: SocketProviderProps) { setIsConnected(true) setIsConnecting(false) reconnectCount.current = 0 + tokenRefreshAttempts.current = 0 logger.info('Socket connected successfully', { socketId: socketInstance.id, @@ -172,7 +177,7 @@ export function SocketProvider({ children, user }: SocketProviderProps) { setPresenceUsers([]) }) - socketInstance.on('connect_error', (error: any) => { + socketInstance.on('connect_error', async (error: any) => { setIsConnecting(false) logger.error('Socket connection error:', { message: error.message, @@ -181,11 +186,54 @@ export function SocketProvider({ children, user }: SocketProviderProps) { type: error.type, transport: error.transport, }) + + if ( + error.message?.includes('Token validation failed') || + error.message?.includes('Authentication failed') + ) { + // Prevent infinite loops - limit refresh attempts + if (tokenRefreshAttempts.current >= 3) { + logger.warn('Max token refresh attempts reached - user needs to refresh page') + return + } + + // Prevent concurrent refresh attempts + if (isRefreshingToken.current) { + logger.info('Token refresh already in progress, skipping...') + return + } + + isRefreshingToken.current = true + tokenRefreshAttempts.current++ + logger.info(`Token expired, attempting refresh (${tokenRefreshAttempts.current}/3)...`) + + try { + const tokenResponse = await fetch('/api/auth/socket-token', { + method: 'POST', + credentials: 'include', + }) + + if (tokenResponse.ok) { + const { token } = await tokenResponse.json() + socketInstance.auth = { ...socketInstance.auth, token } + logger.info('Token refreshed successfully, reconnecting...') + socketInstance.connect() + } else { + logger.warn('Failed to refresh token - user may need to refresh page') + } + } catch (refreshError) { + logger.error('Token refresh failed:', refreshError) + } finally { + isRefreshingToken.current = false + } + } }) // Add reconnection logging socketInstance.on('reconnect', (attemptNumber) => { reconnectCount.current = attemptNumber + // Reset token refresh attempts on successful reconnection + tokenRefreshAttempts.current = 0 logger.info('Socket reconnected', { attemptNumber, }) diff --git a/apps/sim/lib/auth.ts b/apps/sim/lib/auth.ts index 68a3125ce..8608da260 100644 --- a/apps/sim/lib/auth.ts +++ b/apps/sim/lib/auth.ts @@ -187,7 +187,7 @@ export const auth = betterAuth({ plugins: [ nextCookies(), oneTimeToken({ - expiresIn: 10, // 10 minutes - enough time for socket connection + expiresIn: 30, // 30 minutes - covers typical work sessions }), emailOTP({ sendVerificationOTP: async (data: { diff --git a/apps/sim/socket-server/config/socket.ts b/apps/sim/socket-server/config/socket.ts index 0ec9f3261..681c3fda9 100644 --- a/apps/sim/socket-server/config/socket.ts +++ b/apps/sim/socket-server/config/socket.ts @@ -36,7 +36,7 @@ export function createSocketIOServer(httpServer: HttpServer): Server { allowedHeaders: ['Content-Type', 'Authorization', 'Cookie', 'socket.io'], credentials: true, // Enable credentials to accept cookies }, - transports: ['polling', 'websocket'], // Keep both transports for reliability + transports: ['websocket', 'polling'], // WebSocket first, polling as fallback allowEIO3: true, // Keep legacy support for compatibility pingTimeout: 60000, // Back to original conservative setting pingInterval: 25000, // Back to original interval @@ -52,7 +52,7 @@ export function createSocketIOServer(httpServer: HttpServer): Server { logger.info('Socket.IO server configured with:', { allowedOrigins: allowedOrigins.length, - transports: ['polling', 'websocket'], + transports: ['websocket', 'polling'], pingTimeout: 60000, pingInterval: 25000, maxHttpBufferSize: 1e6,