mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-09 23:17:59 -05:00
fix(sockets): implement longer token expiration for OTT, preventitive token refresh with retries (#566)
* fix(sockets): implement longer token expiration for OTT, preventitive token refresh with retries * cleanup tests * make websocket first choice transport * fix lint --------- Co-authored-by: Vikhyath Mondreti <vikhyathmondreti@vikhyaths-air.lan>
This commit is contained in:
278
apps/sim/contexts/socket-context.test.tsx
Normal file
278
apps/sim/contexts/socket-context.test.tsx
Normal file
@@ -0,0 +1,278 @@
|
||||
/**
|
||||
* @vitest-environment jsdom
|
||||
*/
|
||||
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { io } from 'socket.io-client'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { SocketProvider, useSocket } from './socket-context'
|
||||
|
||||
vi.mock('socket.io-client')
|
||||
const mockIo = vi.mocked(io)
|
||||
|
||||
global.fetch = vi.fn()
|
||||
const mockFetch = vi.mocked(fetch)
|
||||
|
||||
vi.mock('@/lib/logs/console-logger', () => ({
|
||||
createLogger: () => ({
|
||||
info: vi.fn(),
|
||||
error: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('SocketContext Token Refresh', () => {
|
||||
let mockSocket: any
|
||||
let eventHandlers: Record<string, any>
|
||||
|
||||
beforeEach(() => {
|
||||
eventHandlers = {}
|
||||
mockSocket = {
|
||||
id: 'test-socket-id',
|
||||
connected: true,
|
||||
io: { engine: { transport: { name: 'websocket' } } },
|
||||
auth: { token: 'initial-token' },
|
||||
on: vi.fn((event, handler) => {
|
||||
eventHandlers[event] = handler
|
||||
}),
|
||||
connect: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
emit: vi.fn(),
|
||||
close: vi.fn(),
|
||||
}
|
||||
|
||||
mockIo.mockReturnValue(mockSocket)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ token: 'fresh-token' }),
|
||||
} as Response)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
const renderSocketProvider = async (user = { id: 'test-user', name: 'Test User' }) => {
|
||||
const result = renderHook(() => useSocket(), {
|
||||
wrapper: ({ children }) => <SocketProvider user={user}>{children}</SocketProvider>,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSocket.on).toHaveBeenCalledWith('connect_error', expect.any(Function))
|
||||
})
|
||||
|
||||
vi.clearAllMocks()
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ token: 'fresh-token' }),
|
||||
} as Response)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
describe('Token Refresh on Connection Error', () => {
|
||||
it('should refresh token on authentication failure', async () => {
|
||||
const { result } = await renderSocketProvider()
|
||||
|
||||
const error = { message: 'Token validation failed' }
|
||||
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith('/api/auth/socket-token', {
|
||||
method: 'POST',
|
||||
credentials: 'include',
|
||||
})
|
||||
|
||||
// Should update socket auth and reconnect
|
||||
expect(mockSocket.auth.token).toBe('fresh-token')
|
||||
expect(mockSocket.connect).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should limit token refresh attempts to 3', async () => {
|
||||
const { result } = await renderSocketProvider()
|
||||
|
||||
const error = { message: 'Token validation failed' }
|
||||
|
||||
for (let i = 0; i < 4; i++) {
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
}
|
||||
|
||||
// Should only call fetch 3 times (max attempts)
|
||||
expect(mockFetch).toHaveBeenCalledTimes(3)
|
||||
expect(mockSocket.connect).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('should prevent concurrent token refresh attempts', async () => {
|
||||
const { result } = await renderSocketProvider()
|
||||
|
||||
let resolveTokenFetch!: (value: {
|
||||
ok: boolean
|
||||
json: () => Promise<{ token: string }>
|
||||
}) => void
|
||||
const slowTokenPromise = new Promise((resolve) => {
|
||||
resolveTokenFetch = resolve
|
||||
})
|
||||
|
||||
mockFetch.mockReturnValue(slowTokenPromise as any)
|
||||
|
||||
const error = { message: 'Authentication failed' }
|
||||
|
||||
// Start two concurrent refresh attempts
|
||||
const promise1 = act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
const promise2 = act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
// Resolve the slow fetch
|
||||
resolveTokenFetch({
|
||||
ok: true,
|
||||
json: async () => ({ token: 'fresh-token' }),
|
||||
})
|
||||
|
||||
await Promise.all([promise1, promise2])
|
||||
|
||||
// Should only call fetch once (concurrent protection)
|
||||
expect(mockFetch).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should reset retry counter on successful connection', async () => {
|
||||
const { result } = await renderSocketProvider()
|
||||
|
||||
const error = { message: 'Token validation failed' }
|
||||
|
||||
// Use up 2 retry attempts
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledTimes(2)
|
||||
|
||||
// Simulate successful connection (resets counter)
|
||||
await act(async () => {
|
||||
eventHandlers.connect()
|
||||
})
|
||||
|
||||
// Should be able to retry again (counter reset)
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('should handle token refresh failure gracefully', async () => {
|
||||
const { result } = await renderSocketProvider()
|
||||
|
||||
// Mock failed token refresh after initialization
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: false,
|
||||
status: 401,
|
||||
} as Response)
|
||||
|
||||
const error = { message: 'Token validation failed' }
|
||||
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
// Should attempt refresh but not update auth or reconnect
|
||||
expect(mockFetch).toHaveBeenCalled()
|
||||
expect(mockSocket.auth.token).toBe('initial-token') // unchanged
|
||||
expect(mockSocket.connect).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle fetch errors gracefully', async () => {
|
||||
const { result } = await renderSocketProvider()
|
||||
|
||||
// Mock fetch error after initialization
|
||||
mockFetch.mockRejectedValue(new Error('Network error'))
|
||||
|
||||
const error = { message: 'Authentication failed' }
|
||||
|
||||
// Should not throw error
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
expect(mockFetch).toHaveBeenCalled()
|
||||
expect(mockSocket.connect).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should only refresh token on authentication-related errors', async () => {
|
||||
const { result } = await renderSocketProvider()
|
||||
|
||||
// Non-authentication error
|
||||
const networkError = { message: 'Network timeout' }
|
||||
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(networkError)
|
||||
})
|
||||
|
||||
// Should not attempt token refresh
|
||||
expect(mockFetch).not.toHaveBeenCalled()
|
||||
expect(mockSocket.connect).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Interaction with Socket.IO Reconnection', () => {
|
||||
it('should work with Socket.IO built-in reconnection attempts', async () => {
|
||||
const { result } = await renderSocketProvider()
|
||||
|
||||
// Simulate Socket.IO reconnection cycle
|
||||
await act(async () => {
|
||||
// Reconnection attempt starts
|
||||
eventHandlers.reconnect_attempt(1)
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
// Fails with auth error
|
||||
await eventHandlers.connect_error({ message: 'Token validation failed' })
|
||||
})
|
||||
|
||||
// Should refresh token and attempt reconnection
|
||||
expect(mockFetch).toHaveBeenCalled()
|
||||
expect(mockSocket.connect).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should reset counters on successful reconnect', async () => {
|
||||
const { result } = await renderSocketProvider()
|
||||
|
||||
// Use up retry attempts
|
||||
const error = { message: 'Authentication failed' }
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledTimes(2)
|
||||
|
||||
// Simulate successful reconnection
|
||||
await act(async () => {
|
||||
eventHandlers.reconnect(1)
|
||||
})
|
||||
|
||||
// Should reset and allow new attempts
|
||||
await act(async () => {
|
||||
await eventHandlers.connect_error(error)
|
||||
})
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -87,6 +87,8 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
|
||||
|
||||
// Connection state tracking
|
||||
const reconnectCount = useRef(0)
|
||||
const tokenRefreshAttempts = useRef(0)
|
||||
const isRefreshingToken = useRef(false)
|
||||
|
||||
// Use refs to store event handlers to avoid stale closures
|
||||
const eventHandlers = useRef<{
|
||||
@@ -138,7 +140,9 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
|
||||
const socketInstance = io(socketUrl, {
|
||||
transports: ['websocket', 'polling'], // Keep polling fallback for reliability
|
||||
withCredentials: true,
|
||||
reconnectionAttempts: 5, // Back to original conservative setting
|
||||
reconnectionAttempts: 5, // Socket.IO handles base reconnection
|
||||
reconnectionDelay: 1000, // Start with 1 second delay
|
||||
reconnectionDelayMax: 5000, // Max 5 second delay
|
||||
timeout: 10000, // Back to original timeout
|
||||
auth: {
|
||||
token, // Send one-time token for authentication
|
||||
@@ -150,6 +154,7 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
|
||||
setIsConnected(true)
|
||||
setIsConnecting(false)
|
||||
reconnectCount.current = 0
|
||||
tokenRefreshAttempts.current = 0
|
||||
|
||||
logger.info('Socket connected successfully', {
|
||||
socketId: socketInstance.id,
|
||||
@@ -172,7 +177,7 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
|
||||
setPresenceUsers([])
|
||||
})
|
||||
|
||||
socketInstance.on('connect_error', (error: any) => {
|
||||
socketInstance.on('connect_error', async (error: any) => {
|
||||
setIsConnecting(false)
|
||||
logger.error('Socket connection error:', {
|
||||
message: error.message,
|
||||
@@ -181,11 +186,54 @@ export function SocketProvider({ children, user }: SocketProviderProps) {
|
||||
type: error.type,
|
||||
transport: error.transport,
|
||||
})
|
||||
|
||||
if (
|
||||
error.message?.includes('Token validation failed') ||
|
||||
error.message?.includes('Authentication failed')
|
||||
) {
|
||||
// Prevent infinite loops - limit refresh attempts
|
||||
if (tokenRefreshAttempts.current >= 3) {
|
||||
logger.warn('Max token refresh attempts reached - user needs to refresh page')
|
||||
return
|
||||
}
|
||||
|
||||
// Prevent concurrent refresh attempts
|
||||
if (isRefreshingToken.current) {
|
||||
logger.info('Token refresh already in progress, skipping...')
|
||||
return
|
||||
}
|
||||
|
||||
isRefreshingToken.current = true
|
||||
tokenRefreshAttempts.current++
|
||||
logger.info(`Token expired, attempting refresh (${tokenRefreshAttempts.current}/3)...`)
|
||||
|
||||
try {
|
||||
const tokenResponse = await fetch('/api/auth/socket-token', {
|
||||
method: 'POST',
|
||||
credentials: 'include',
|
||||
})
|
||||
|
||||
if (tokenResponse.ok) {
|
||||
const { token } = await tokenResponse.json()
|
||||
socketInstance.auth = { ...socketInstance.auth, token }
|
||||
logger.info('Token refreshed successfully, reconnecting...')
|
||||
socketInstance.connect()
|
||||
} else {
|
||||
logger.warn('Failed to refresh token - user may need to refresh page')
|
||||
}
|
||||
} catch (refreshError) {
|
||||
logger.error('Token refresh failed:', refreshError)
|
||||
} finally {
|
||||
isRefreshingToken.current = false
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Add reconnection logging
|
||||
socketInstance.on('reconnect', (attemptNumber) => {
|
||||
reconnectCount.current = attemptNumber
|
||||
// Reset token refresh attempts on successful reconnection
|
||||
tokenRefreshAttempts.current = 0
|
||||
logger.info('Socket reconnected', {
|
||||
attemptNumber,
|
||||
})
|
||||
|
||||
@@ -187,7 +187,7 @@ export const auth = betterAuth({
|
||||
plugins: [
|
||||
nextCookies(),
|
||||
oneTimeToken({
|
||||
expiresIn: 10, // 10 minutes - enough time for socket connection
|
||||
expiresIn: 30, // 30 minutes - covers typical work sessions
|
||||
}),
|
||||
emailOTP({
|
||||
sendVerificationOTP: async (data: {
|
||||
|
||||
@@ -36,7 +36,7 @@ export function createSocketIOServer(httpServer: HttpServer): Server {
|
||||
allowedHeaders: ['Content-Type', 'Authorization', 'Cookie', 'socket.io'],
|
||||
credentials: true, // Enable credentials to accept cookies
|
||||
},
|
||||
transports: ['polling', 'websocket'], // Keep both transports for reliability
|
||||
transports: ['websocket', 'polling'], // WebSocket first, polling as fallback
|
||||
allowEIO3: true, // Keep legacy support for compatibility
|
||||
pingTimeout: 60000, // Back to original conservative setting
|
||||
pingInterval: 25000, // Back to original interval
|
||||
@@ -52,7 +52,7 @@ export function createSocketIOServer(httpServer: HttpServer): Server {
|
||||
|
||||
logger.info('Socket.IO server configured with:', {
|
||||
allowedOrigins: allowedOrigins.length,
|
||||
transports: ['polling', 'websocket'],
|
||||
transports: ['websocket', 'polling'],
|
||||
pingTimeout: 60000,
|
||||
pingInterval: 25000,
|
||||
maxHttpBufferSize: 1e6,
|
||||
|
||||
Reference in New Issue
Block a user