From 4168642d6e92416687d431a29e8516ebf9543dfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Vit=C3=B3ria=20Silva?= Date: Tue, 6 Jan 2026 23:32:31 +0000 Subject: [PATCH] Refactor WebSocket manager and authentication flow Replaces websocket.schema with websocket.manager for managing WebSocket connections, introducing a singleton WebSocketManager and updating all imports and usages accordingly. Adds token-based authentication for WebSocket connections, requiring access_token as a query parameter and validating it server-side. Updates FastAPI WebSocket endpoint to use authenticated user ID, and modifies frontend to connect using the access token. Removes obsolete schema.py and improves error handling and logging for WebSocket events. --- backend/app/activities/activity/crud.py | 4 +- backend/app/activities/activity/router.py | 14 ++-- backend/app/activities/activity/utils.py | 12 +-- backend/app/auth/security.py | 69 ++++++++++++++++- backend/app/core/routes.py | 4 - backend/app/followers/crud.py | 10 +-- backend/app/followers/router.py | 10 +-- backend/app/garmin/activity_utils.py | 8 +- backend/app/garmin/router.py | 10 +-- backend/app/garmin/utils.py | 8 +- backend/app/notifications/utils.py | 20 +++-- backend/app/profile/import_service.py | 4 +- backend/app/profile/router.py | 6 +- backend/app/sign_up_tokens/router.py | 8 +- backend/app/strava/activity_utils.py | 12 +-- backend/app/strava/router.py | 6 +- backend/app/websocket/__init__.py | 10 +++ backend/app/websocket/manager.py | 92 +++++++++++++++++++++++ backend/app/websocket/router.py | 48 ++++++++---- backend/app/websocket/schema.py | 34 --------- backend/app/websocket/utils.py | 40 ++++++---- frontend/app/src/main.js | 2 + frontend/app/src/stores/authStore.js | 14 +++- 23 files changed, 311 insertions(+), 134 deletions(-) create mode 100644 backend/app/websocket/manager.py delete mode 100644 backend/app/websocket/schema.py diff --git a/backend/app/activities/activity/crud.py b/backend/app/activities/activity/crud.py index ebd364aba..3236aea63 100644 --- a/backend/app/activities/activity/crud.py +++ b/backend/app/activities/activity/crud.py @@ -14,7 +14,7 @@ import notifications.utils as notifications_utils import server_settings.utils as server_settings_utils -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager from fastapi import HTTPException, status from pydantic import BaseModel @@ -1190,7 +1190,7 @@ def get_activities_if_contains_name(name: str, user_id: int, db: Session): async def create_activity( activity: activities_schema.Activity, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, create_notification: bool = True, ) -> activities_schema.Activity: diff --git a/backend/app/activities/activity/router.py b/backend/app/activities/activity/router.py index 7923c2878..781fb4b5f 100644 --- a/backend/app/activities/activity/router.py +++ b/backend/app/activities/activity/router.py @@ -20,7 +20,7 @@ import auth.security as auth_security import users.user.dependencies as users_dependencies import garmin.activity_utils as garmin_activity_utils import strava.activity_utils as strava_activity_utils -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager from fastapi import ( APIRouter, Depends, @@ -513,8 +513,8 @@ async def read_activities_user_activities_refresh( Depends(core_database.get_db), ], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], ): # Set the activities to empty list @@ -618,8 +618,8 @@ async def create_activity_with_uploaded_file( Callable, Security(auth_security.check_scopes, scopes=["activities:write"]) ], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], db: Annotated[ Session, @@ -653,8 +653,8 @@ async def create_activity_with_bulk_import( Callable, Security(auth_security.check_scopes, scopes=["activities:write"]) ], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], ): try: diff --git a/backend/app/activities/activity/utils.py b/backend/app/activities/activity/utils.py index 19fc6ffee..77f295592 100644 --- a/backend/app/activities/activity/utils.py +++ b/backend/app/activities/activity/utils.py @@ -37,7 +37,7 @@ import activities.activity_streams.schema as activity_streams_schema import activities.activity_workout_steps.crud as activity_workout_steps_crud -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager import gpx.utils as gpx_utils import tcx.utils as tcx_utils @@ -389,7 +389,7 @@ def handle_gzipped_file( async def parse_and_store_activity_from_file( token_user_id: int, file_path: str, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, from_garmin: bool = False, garminconnect_gear: dict | None = None, @@ -540,7 +540,7 @@ async def parse_and_store_activity_from_file( async def parse_and_store_activity_from_uploaded_file( token_user_id: int, file: UploadFile, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ): # Validate filename exists @@ -742,7 +742,9 @@ def parse_file( async def store_activity( - parsed_info: dict, websocket_manager: websocket_schema.WebSocketManager, db: Session + parsed_info: dict,websocket_manager + websocket_manager: websocket_manager.WebSocketManager, + db: Session, ): # create the activity in the database created_activity = await activities_crud.create_activity( @@ -1193,7 +1195,7 @@ def set_activity_name_based_on_activity_type(activity_type_id: int) -> str: def process_all_files_sync( user_id: int, file_paths: list[str], - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, ): """ Process all files sequentially in single thread. diff --git a/backend/app/auth/security.py b/backend/app/auth/security.py index 49387823c..b2a97f9ed 100644 --- a/backend/app/auth/security.py +++ b/backend/app/auth/security.py @@ -1,5 +1,5 @@ from typing import Annotated, Union -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, Query, status, WebSocket, WebSocketException from fastapi.security import ( OAuth2PasswordBearer, SecurityScopes, @@ -618,3 +618,70 @@ def check_scopes_for_browser_redirect( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during scope validation.", ) from err + + +## WEBSOCKET TOKEN VALIDATION +async def validate_websocket_access_token( + websocket: WebSocket, + access_token: str = Query(..., alias="access_token"), + token_manager: auth_token_manager.TokenManager = Depends( + auth_token_manager.get_token_manager + ), +) -> int: + """ + Validate access token for WebSocket connections. + + WebSocket connections cannot use Authorization headers during + the handshake, so tokens are passed via query parameters. + + Args: + websocket: The WebSocket connection instance. + access_token: Access token from query parameter. + token_manager: Token manager for validation. + + Returns: + The authenticated user ID from the token. + + Raises: + WebSocketException: If token is missing, invalid, or expired. + """ + try: + # Validate token expiration + token_manager.validate_token_expiration(access_token) + + # Get user ID from token + token_user_id = token_manager.get_token_claim(access_token, "sub") + + if token_user_id is None or isinstance(token_user_id, list): + core_logger.print_to_log( + "WebSocket token validation failed: invalid sub claim", + "warning", + ) + raise WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, + reason="Invalid token: missing user ID", + ) + + return int(token_user_id) + except WebSocketException as ws_err: + raise ws_err + except HTTPException as http_err: + core_logger.print_to_log( + f"WebSocket token validation failed: {http_err.detail}", + "warning", + exc=http_err, + ) + raise WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, + reason="Invalid or expired token", + ) from http_err + except Exception as err: + core_logger.print_to_log( + f"Unexpected error during WebSocket token validation: {err}", + "error", + exc=err, + ) + raise WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, + reason="Token validation failed", + ) from err diff --git a/backend/app/core/routes.py b/backend/app/core/routes.py index 85963fd9e..624c51767 100644 --- a/backend/app/core/routes.py +++ b/backend/app/core/routes.py @@ -247,10 +247,6 @@ router.include_router( websocket_router.router, prefix=core_config.ROOT_PATH + "/ws", tags=["websocket"], - # dependencies=[ - # Depends(auth_security.validate_access_token), - # Security(auth_security.check_scopes, scopes=["profile"]), - # ], ) # PUBLIC ROUTES (alphabetical order) diff --git a/backend/app/followers/crud.py b/backend/app/followers/crud.py index 874f458c4..7109d96ae 100644 --- a/backend/app/followers/crud.py +++ b/backend/app/followers/crud.py @@ -9,7 +9,7 @@ import core.logger as core_logger import notifications.utils as notifications_utils -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager def get_all_followers_by_user_id(user_id: int, db: Session): @@ -163,7 +163,7 @@ def get_follower_for_user_id_and_target_user_id( async def create_follower( user_id: int, target_user_id: int, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ): try: @@ -175,7 +175,7 @@ async def create_follower( # Add the new follow relationship to the database db.add(new_follow) db.commit() - + await notifications_utils.create_new_follower_request_notification( user_id, target_user_id, websocket_manager, db ) @@ -199,7 +199,7 @@ async def create_follower( async def accept_follower( user_id: int, target_user_id: int, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ): try: @@ -226,7 +226,7 @@ async def accept_follower( # Commit the transaction db.commit() - + await notifications_utils.create_accepted_follower_request_notification( user_id, target_user_id, websocket_manager, db ) diff --git a/backend/app/followers/router.py b/backend/app/followers/router.py index da472c84a..098f6bb94 100644 --- a/backend/app/followers/router.py +++ b/backend/app/followers/router.py @@ -14,7 +14,7 @@ import auth.security as auth_security import core.database as core_database -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager # Define the API router router = APIRouter() @@ -207,8 +207,8 @@ async def create_follow( Callable, Security(auth_security.check_scopes, scopes=["profile"]) ], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], db: Annotated[ Session, @@ -239,8 +239,8 @@ async def accept_follow( Callable, Security(auth_security.check_scopes, scopes=["profile"]) ], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], db: Annotated[ Session, diff --git a/backend/app/garmin/activity_utils.py b/backend/app/garmin/activity_utils.py index 1c30a9568..86cd2916c 100644 --- a/backend/app/garmin/activity_utils.py +++ b/backend/app/garmin/activity_utils.py @@ -16,7 +16,7 @@ import activities.activity.crud as activities_crud import users.user.crud as users_crud -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager from core.database import SessionLocal @@ -26,7 +26,7 @@ async def fetch_and_process_activities_by_dates( start_date: datetime, end_date: datetime, user_id: int, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ) -> list[activities_schema.Activity] | None: try: @@ -125,7 +125,7 @@ async def fetch_and_process_activities_by_dates( async def retrieve_garminconnect_users_activities_for_days(days: int): - websocket_manager = websocket_schema.get_websocket_manager() + websocket_manager = websocket_manager.get_websocket_manager() # Create a new database session using context manager with SessionLocal() as db: @@ -195,7 +195,7 @@ async def get_user_garminconnect_activities_by_dates( start_date: datetime, end_date: datetime, user_id: int, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ) -> list[activities_schema.Activity] | None: try: diff --git a/backend/app/garmin/router.py b/backend/app/garmin/router.py index 39ee59193..6e99afde5 100644 --- a/backend/app/garmin/router.py +++ b/backend/app/garmin/router.py @@ -13,7 +13,7 @@ import garmin.activity_utils as garmin_activity_utils import garmin.health_utils as garmin_health_utils import garmin.gear_utils as garmin_gear_utils -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager import core.logger as core_logger @@ -37,8 +37,8 @@ async def garminconnect_link( garmin_schema.MFACodeStore, Depends(garmin_schema.get_mfa_store) ], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], ): # Link Garmin Connect account @@ -84,8 +84,8 @@ async def garminconnect_retrieve_activities_days( ], db: Annotated[Session, Depends(core_database.get_db)], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], background_tasks: BackgroundTasks, ): diff --git a/backend/app/garmin/utils.py b/backend/app/garmin/utils.py index b930996d5..166cfea3d 100644 --- a/backend/app/garmin/utils.py +++ b/backend/app/garmin/utils.py @@ -16,7 +16,7 @@ import core.cryptography as core_cryptography import users.user_integrations.schema as user_integrations_schema import users.user_integrations.crud as user_integrations_crud -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager import websocket.utils as websocket_utils import garmin.schema as garmin_schema @@ -27,7 +27,7 @@ import core.logger as core_logger async def get_mfa( user_id: int, mfa_codes: garmin_schema.MFACodeStore, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, ) -> str: # Notify frontend that MFA is required await notify_frontend_mfa_required(user_id, websocket_manager) @@ -42,7 +42,7 @@ async def get_mfa( async def notify_frontend_mfa_required( - user_id: int, websocket_manager: websocket_schema.WebSocketManager + user_id: int, websocket_manager: websocket_manager.WebSocketManager ): try: json_data = {"message": "MFA_REQUIRED", "user_id": user_id} @@ -57,7 +57,7 @@ async def link_garminconnect( password: str, db: Session, mfa_codes: garmin_schema.MFACodeStore, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, ): # Define MFA callback as a coroutine async def async_mfa_callback(): diff --git a/backend/app/notifications/utils.py b/backend/app/notifications/utils.py index 7aeaab35b..f8c5a7633 100644 --- a/backend/app/notifications/utils.py +++ b/backend/app/notifications/utils.py @@ -13,7 +13,7 @@ import users.user.models as users_models import users.user.utils as users_utils import websocket.utils as websocket_utils -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager def serialize_notification(notification: notifications_schema.Notification): @@ -25,7 +25,9 @@ def serialize_notification(notification: notifications_schema.Notification): async def create_new_activity_notification( - user_id: int, activity_id: int, websocket_manager: websocket_schema.WebSocketManager + user_id: int, + activity_id: int, + websocket_manager: websocket_manager.WebSocketManager, ): # Create a new database session using context manager with SessionLocal() as db: @@ -64,7 +66,9 @@ async def create_new_activity_notification( async def create_new_duplicate_start_time_activity_notification( - user_id: int, activity_id: int, websocket_manager: websocket_schema.WebSocketManager + user_id: int, + activity_id: int, + websocket_manager: websocket_manager.WebSocketManager, ): # Create a new database session using context manager with SessionLocal() as db: @@ -105,9 +109,9 @@ async def create_new_duplicate_start_time_activity_notification( async def create_new_follower_request_notification( - user_id: int, + user_id: int,websocket_manager target_user_id: int, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ): try: @@ -160,9 +164,9 @@ async def create_new_follower_request_notification( async def create_accepted_follower_request_notification( - user_id: int, + user_id: int,websocket_manager target_user_id: int, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ): try: @@ -214,7 +218,7 @@ async def create_accepted_follower_request_notification( async def create_admin_new_sign_up_approval_request_notification( user: users_models.User, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ): try: diff --git a/backend/app/profile/import_service.py b/backend/app/profile/import_service.py index 18b1b3e54..2c7cf7ef2 100644 --- a/backend/app/profile/import_service.py +++ b/backend/app/profile/import_service.py @@ -73,7 +73,7 @@ import health.health_weight.schema as health_weight_schema import health.health_targets.crud as health_targets_crud import health.health_targets.schema as health_targets_schema -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager class ImportPerformanceConfig(profile_utils.BasePerformanceConfig): @@ -165,7 +165,7 @@ class ImportService: self, user_id: int, db: Session, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, performance_config: ImportPerformanceConfig | None = None, ): self.user_id = user_id diff --git a/backend/app/profile/router.py b/backend/app/profile/router.py index 8d2693f4c..f9da497d3 100644 --- a/backend/app/profile/router.py +++ b/backend/app/profile/router.py @@ -49,7 +49,7 @@ import core.logger as core_logger import core.rate_limit as core_rate_limit import core.config as core_config -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager # Define the API router router = APIRouter() @@ -458,8 +458,8 @@ async def import_profile_data( ], db: Annotated[Session, Depends(core_database.get_db)], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], ): """ diff --git a/backend/app/sign_up_tokens/router.py b/backend/app/sign_up_tokens/router.py index 024acb49d..5814f899f 100644 --- a/backend/app/sign_up_tokens/router.py +++ b/backend/app/sign_up_tokens/router.py @@ -24,7 +24,7 @@ import server_settings.utils as server_settings_utils import core.database as core_database import core.apprise as core_apprise -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager # Define the API router router = APIRouter() @@ -54,7 +54,7 @@ async def signup( - user (users_schema.UserSignup): The payload containing the user's sign-up information. - email_service (core_apprise.AppriseService): Injected email service used to send verification and admin approval emails. - - websocket_manager (websocket_schema.WebSocketManager): Injected manager used to send + - websocket_manager (websocket_manager.WebSocketManager): Injected manager used to send real-time notifications (e.g., admin approval requests). - password_hasher (auth_password_hasher.PasswordHasher): Injected password hasher used to hash user passwords. - db (Session): Database session/connection used to create the user and related records. @@ -156,8 +156,8 @@ async def verify_email( Depends(core_apprise.get_email_service), ], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], db: Annotated[ Session, diff --git a/backend/app/strava/activity_utils.py b/backend/app/strava/activity_utils.py index 27a7d4eb1..2991515d1 100644 --- a/backend/app/strava/activity_utils.py +++ b/backend/app/strava/activity_utils.py @@ -30,7 +30,7 @@ import gears.gear.crud as gears_crud import strava.utils as strava_utils -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager from core.database import SessionLocal @@ -41,7 +41,7 @@ async def fetch_and_process_activities( end_date: datetime, user_id: int, user_integrations: user_integrations_schema.UsersIntegrations, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, is_startup: bool = False, ) -> int: @@ -369,7 +369,7 @@ async def save_activity_streams_laps( activity: activities_schema.Activity, stream_data: list, laps: dict, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ) -> activities_schema.Activity: # Create the activity and get the ID @@ -411,7 +411,7 @@ async def process_activity( user_privacy_settings: users_privacy_settings_schema.UsersPrivacySettings, strava_client: Client, user_integrations: user_integrations_schema.UsersIntegrations, - websocket_manager: websocket_schema.WebSocketManager, + websocket_manager: websocket_manager.WebSocketManager, db: Session, ): # Get the activity by Strava ID from the user @@ -769,7 +769,7 @@ async def get_user_strava_activities_by_dates( start_date: datetime, end_date: datetime, user_id: int, - websocket_manager: websocket_schema.WebSocketManager = None, + websocket_manager: websocket_manager.WebSocketManager = None, db: Session = None, is_startup: bool = False, ) -> list[activities_schema.Activity] | None: @@ -781,7 +781,7 @@ async def get_user_strava_activities_by_dates( if websocket_manager is None: # Get the websocket manager instance - websocket_manager = websocket_schema.get_websocket_manager() + websocket_manager = websocket_manager.get_websocket_manager() try: # Get the user integrations by user ID diff --git a/backend/app/strava/router.py b/backend/app/strava/router.py index b2162ded4..276e5589c 100644 --- a/backend/app/strava/router.py +++ b/backend/app/strava/router.py @@ -25,7 +25,7 @@ import core.cryptography as core_cryptography import core.logger as core_logger import core.database as core_database -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager # Define the API router router = APIRouter() @@ -121,8 +121,8 @@ async def strava_retrieve_activities_days( Depends(auth_security.get_sub_from_access_token), ], websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), ], # db: Annotated[Session, Depends(core_database.get_db)], background_tasks: BackgroundTasks, diff --git a/backend/app/websocket/__init__.py b/backend/app/websocket/__init__.py index e69de29bb..46a3098c0 100644 --- a/backend/app/websocket/__init__.py +++ b/backend/app/websocket/__init__.py @@ -0,0 +1,10 @@ +"""WebSocket module for real-time notifications and communication.""" + +from websocket.manager import WebSocketManager, get_websocket_manager +from websocket.utils import notify_frontend + +__all__ = [ + "WebSocketManager", + "get_websocket_manager", + "notify_frontend", +] diff --git a/backend/app/websocket/manager.py b/backend/app/websocket/manager.py new file mode 100644 index 000000000..e4b02e3be --- /dev/null +++ b/backend/app/websocket/manager.py @@ -0,0 +1,92 @@ +from functools import lru_cache + +from fastapi import WebSocket + +import core.logger as core_logger + + +class WebSocketManager: + """ + Manage active WebSocket connections per user. + + Maintains a registry of authenticated WebSocket connections + indexed by user ID, enabling message broadcasting and + targeted notifications. + + Attributes: + active_connections: Maps user IDs to their WebSocket. + """ + + def __init__(self) -> None: + """Initialize the WebSocket manager with empty connections.""" + self.active_connections: dict[int, WebSocket] = {} + + async def connect(self, user_id: int, websocket: WebSocket) -> None: + """ + Accept and register a new WebSocket connection. + + Args: + user_id: The user's unique identifier. + websocket: The WebSocket connection to register. + """ + await websocket.accept() + self.active_connections[user_id] = websocket + core_logger.print_to_log(f"WebSocket connected for user {user_id}", "info") + + def disconnect(self, user_id: int) -> None: + """ + Remove a user's WebSocket connection. + + Args: + user_id: The user's unique identifier. + """ + if self.active_connections.pop(user_id, None): + core_logger.print_to_log( + f"WebSocket disconnected for user {user_id}", + "info", + ) + + async def send_message(self, user_id: int, message: dict) -> None: + """ + Send a JSON message to a specific user. + + Args: + user_id: The user's unique identifier. + message: JSON-serializable data to send. + """ + websocket = self.active_connections.get(user_id) + if websocket: + await websocket.send_json(message) + + async def broadcast(self, message: dict) -> None: + """ + Send a JSON message to all connected users. + + Args: + message: JSON-serializable data to broadcast. + """ + for websocket in self.active_connections.values(): + await websocket.send_json(message) + + def get_connection(self, user_id: int) -> WebSocket | None: + """ + Retrieve a user's WebSocket connection. + + Args: + user_id: The user's unique identifier. + + Returns: + The WebSocket connection or None if not found. + """ + return self.active_connections.get(user_id) + + +@lru_cache(maxsize=1) +def get_websocket_manager() -> WebSocketManager: + """ + Get the singleton WebSocket manager instance. + + Returns: + The shared WebSocketManager instance. + """ + return WebSocketManager() diff --git a/backend/app/websocket/router.py b/backend/app/websocket/router.py index 0760b5eb0..67b3b1a9c 100644 --- a/backend/app/websocket/router.py +++ b/backend/app/websocket/router.py @@ -1,28 +1,50 @@ from typing import Annotated from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect -import websocket.schema as websocket_schema + +import auth.security as auth_security +import websocket.manager as websocket_manager + +import core.logger as core_logger # Define the API router router = APIRouter() -@router.websocket("/{user_id}") +@router.websocket("") async def websocket_endpoint( - user_id: int, websocket: WebSocket, - websocket_manager: Annotated[ - websocket_schema.WebSocketManager, - Depends(websocket_schema.get_websocket_manager), + token_user_id: Annotated[ + int, Depends(auth_security.validate_websocket_access_token) ], -): - # Connect and manage WebSocket connections using the manager - await websocket_manager.connect(user_id, websocket) + websocket_manager: Annotated[ + websocket_manager.WebSocketManager, + Depends(websocket_manager.get_websocket_manager), + ], +) -> None: + """ + Handle WebSocket connections for real-time notifications. + + Establishes authenticated WebSocket connection for receiving + real-time notifications, MFA requests, and activity updates. + + Args: + websocket: The WebSocket connection instance. + user_id: Authenticated user ID from access token. + websocket_manager: Manager for WebSocket connections. + """ + await websocket_manager.connect(token_user_id, websocket) try: while True: - # Handle incoming messages if necessary (currently just keeping connection alive) - data = await websocket.receive_json() + try: + # Keep connection alive, handle incoming messages + await websocket.receive_json() + except ValueError: + # Log malformed JSON, keep connection alive + core_logger.print_to_log( + f"Received malformed JSON from user {token_user_id}", + "warning", + ) except WebSocketDisconnect: - # Disconnect using the manager - websocket_manager.disconnect(user_id) + websocket_manager.disconnect(token_user_id) diff --git a/backend/app/websocket/schema.py b/backend/app/websocket/schema.py deleted file mode 100644 index d552d5134..000000000 --- a/backend/app/websocket/schema.py +++ /dev/null @@ -1,34 +0,0 @@ -from fastapi import WebSocket -from typing import Dict - - -class WebSocketManager: - def __init__(self): - self.active_connections: Dict[int, WebSocket] = {} - - async def connect(self, user_id: int, websocket: WebSocket): - await websocket.accept() - self.active_connections[user_id] = websocket - - def disconnect(self, user_id: int): - if user_id in self.active_connections: - del self.active_connections[user_id] - - async def send_message(self, user_id: int, message: str): - if user_id in self.active_connections: - websocket = self.active_connections[user_id] - await websocket.send_json(message) - - async def broadcast(self, message: str): - for websocket in self.active_connections.values(): - await websocket.send_json(message) - - def get_connection(self, user_id: int) -> WebSocket | None: - return self.active_connections.get(user_id) - - -def get_websocket_manager(): - return websocket_manager - - -websocket_manager = WebSocketManager() diff --git a/backend/app/websocket/utils.py b/backend/app/websocket/utils.py index 84429b9a8..d2fd5722b 100644 --- a/backend/app/websocket/utils.py +++ b/backend/app/websocket/utils.py @@ -1,30 +1,38 @@ from fastapi import HTTPException, status -import websocket.schema as websocket_schema +import websocket.manager as websocket_manager async def notify_frontend( - user_id: int, websocket_manager: websocket_schema.WebSocketManager, json_data: dict -): + user_id: int, + websocket_manager: websocket_manager.WebSocketManager, + json_data: dict, +) -> bool: """ - Sends a JSON message to the frontend via an active WebSocket connection for a specific user. + Send a JSON message to a user's WebSocket connection. + + Attempts to send data to the user's WebSocket. For MFA + verification, raises an exception if no connection exists. Args: - user_id (int): The ID of the user to notify. - websocket_manager (websocket_schema.WebSocketManager): The manager handling WebSocket connections. - json_data (dict): The JSON-serializable data to send to the frontend. + user_id: The target user's identifier. + websocket_manager: The WebSocket connection manager. + json_data: JSON-serializable data to send. + + Returns: + True if message was sent, False if no connection. Raises: - HTTPException: If there is no active WebSocket connection for the specified user. + HTTPException: If MFA_REQUIRED but no connection exists. """ - # Check if the user has an active WebSocket connection - websocket = websocket_manager.get_connection(user_id) if websocket: await websocket.send_json(json_data) - else: - if json_data.get("message") == "MFA_REQUIRED": - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"No active WebSocket connection for user {user_id}", - ) + return True + + if json_data.get("message") == "MFA_REQUIRED": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"No WebSocket connection for user {user_id}", + ) + return False diff --git a/frontend/app/src/main.js b/frontend/app/src/main.js index 72b94557e..b6f2b58d4 100644 --- a/frontend/app/src/main.js +++ b/frontend/app/src/main.js @@ -73,6 +73,8 @@ async function initApp() { if (authStore.isAuthenticated && !authStore.getAccessToken()) { try { await authStore.refreshAccessToken() + // Set up WebSocket after token is available + authStore.setUserWebsocket() } catch (error) { // If refresh fails, clear user and redirect to login console.error('Failed to restore session on app init:', error) diff --git a/frontend/app/src/stores/authStore.js b/frontend/app/src/stores/authStore.js index e5ba5f90c..36a500f60 100644 --- a/frontend/app/src/stores/authStore.js +++ b/frontend/app/src/stores/authStore.js @@ -140,7 +140,7 @@ export const useAuthStore = defineStore('auth', { this.user = JSON.parse(storedUser) this.isAuthenticated = true this.setLocale(this.user.preferred_language, locale) - this.setUserWebsocket() + // WebSocket setup deferred until access token is available this.session_id = localStorage.getItem('session_id') } }, @@ -159,11 +159,11 @@ export const useAuthStore = defineStore('auth', { setUserWebsocket() { const urlSplit = API_URL.split('://') const protocol = urlSplit[0] === 'http' ? 'ws' : 'wss' - const websocketURL = `${protocol}://${urlSplit[1]}ws/${this.user.id}` + const websocketURL = `${protocol}://${urlSplit[1]}ws?access_token=${this.accessToken}` try { this.user_websocket = new WebSocket(websocketURL) this.user_websocket.onopen = () => { - console.log(`WebSocket connection established using ${websocketURL}.`) + console.log(`WebSocket connection established for user ID ${this.user.id}`) } this.user_websocket.onerror = (error) => { console.error('WebSocket error:', error) @@ -195,6 +195,14 @@ export const useAuthStore = defineStore('auth', { if (response.csrf_token) { this.csrfToken = response.csrf_token } + // Reconnect WebSocket with new token if currently open + if ( + this.user_websocket && + this.user_websocket.readyState === WebSocket.OPEN + ) { + this.user_websocket.close() + this.setUserWebsocket() + } return response.access_token } throw new Error('No access token in refresh response')