Refactor WebSocket manager and authentication flow

Replaces websocket.schema with websocket.manager for managing WebSocket connections, introducing a singleton WebSocketManager and updating all imports and usages accordingly. Adds token-based authentication for WebSocket connections, requiring access_token as a query parameter and validating it server-side. Updates FastAPI WebSocket endpoint to use authenticated user ID, and modifies frontend to connect using the access token. Removes obsolete schema.py and improves error handling and logging for WebSocket events.
This commit is contained in:
João Vitória Silva
2026-01-06 23:32:31 +00:00
parent 5d32f8c649
commit 4168642d6e
23 changed files with 311 additions and 134 deletions

View File

@@ -14,7 +14,7 @@ import notifications.utils as notifications_utils
import server_settings.utils as server_settings_utils
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
from fastapi import HTTPException, status
from pydantic import BaseModel
@@ -1190,7 +1190,7 @@ def get_activities_if_contains_name(name: str, user_id: int, db: Session):
async def create_activity(
activity: activities_schema.Activity,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
create_notification: bool = True,
) -> activities_schema.Activity:

View File

@@ -20,7 +20,7 @@ import auth.security as auth_security
import users.user.dependencies as users_dependencies
import garmin.activity_utils as garmin_activity_utils
import strava.activity_utils as strava_activity_utils
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
from fastapi import (
APIRouter,
Depends,
@@ -513,8 +513,8 @@ async def read_activities_user_activities_refresh(
Depends(core_database.get_db),
],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
):
# Set the activities to empty list
@@ -618,8 +618,8 @@ async def create_activity_with_uploaded_file(
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
db: Annotated[
Session,
@@ -653,8 +653,8 @@ async def create_activity_with_bulk_import(
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
):
try:

View File

@@ -37,7 +37,7 @@ import activities.activity_streams.schema as activity_streams_schema
import activities.activity_workout_steps.crud as activity_workout_steps_crud
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
import gpx.utils as gpx_utils
import tcx.utils as tcx_utils
@@ -389,7 +389,7 @@ def handle_gzipped_file(
async def parse_and_store_activity_from_file(
token_user_id: int,
file_path: str,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
from_garmin: bool = False,
garminconnect_gear: dict | None = None,
@@ -540,7 +540,7 @@ async def parse_and_store_activity_from_file(
async def parse_and_store_activity_from_uploaded_file(
token_user_id: int,
file: UploadFile,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
):
# Validate filename exists
@@ -742,7 +742,9 @@ def parse_file(
async def store_activity(
parsed_info: dict, websocket_manager: websocket_schema.WebSocketManager, db: Session
parsed_info: dict,websocket_manager
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
):
# create the activity in the database
created_activity = await activities_crud.create_activity(
@@ -1193,7 +1195,7 @@ def set_activity_name_based_on_activity_type(activity_type_id: int) -> str:
def process_all_files_sync(
user_id: int,
file_paths: list[str],
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
):
"""
Process all files sequentially in single thread.

View File

@@ -1,5 +1,5 @@
from typing import Annotated, Union
from fastapi import Depends, HTTPException, status
from fastapi import Depends, HTTPException, Query, status, WebSocket, WebSocketException
from fastapi.security import (
OAuth2PasswordBearer,
SecurityScopes,
@@ -618,3 +618,70 @@ def check_scopes_for_browser_redirect(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred during scope validation.",
) from err
## WEBSOCKET TOKEN VALIDATION
async def validate_websocket_access_token(
websocket: WebSocket,
access_token: str = Query(..., alias="access_token"),
token_manager: auth_token_manager.TokenManager = Depends(
auth_token_manager.get_token_manager
),
) -> int:
"""
Validate access token for WebSocket connections.
WebSocket connections cannot use Authorization headers during
the handshake, so tokens are passed via query parameters.
Args:
websocket: The WebSocket connection instance.
access_token: Access token from query parameter.
token_manager: Token manager for validation.
Returns:
The authenticated user ID from the token.
Raises:
WebSocketException: If token is missing, invalid, or expired.
"""
try:
# Validate token expiration
token_manager.validate_token_expiration(access_token)
# Get user ID from token
token_user_id = token_manager.get_token_claim(access_token, "sub")
if token_user_id is None or isinstance(token_user_id, list):
core_logger.print_to_log(
"WebSocket token validation failed: invalid sub claim",
"warning",
)
raise WebSocketException(
code=status.WS_1008_POLICY_VIOLATION,
reason="Invalid token: missing user ID",
)
return int(token_user_id)
except WebSocketException as ws_err:
raise ws_err
except HTTPException as http_err:
core_logger.print_to_log(
f"WebSocket token validation failed: {http_err.detail}",
"warning",
exc=http_err,
)
raise WebSocketException(
code=status.WS_1008_POLICY_VIOLATION,
reason="Invalid or expired token",
) from http_err
except Exception as err:
core_logger.print_to_log(
f"Unexpected error during WebSocket token validation: {err}",
"error",
exc=err,
)
raise WebSocketException(
code=status.WS_1008_POLICY_VIOLATION,
reason="Token validation failed",
) from err

View File

@@ -247,10 +247,6 @@ router.include_router(
websocket_router.router,
prefix=core_config.ROOT_PATH + "/ws",
tags=["websocket"],
# dependencies=[
# Depends(auth_security.validate_access_token),
# Security(auth_security.check_scopes, scopes=["profile"]),
# ],
)
# PUBLIC ROUTES (alphabetical order)

View File

@@ -9,7 +9,7 @@ import core.logger as core_logger
import notifications.utils as notifications_utils
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
def get_all_followers_by_user_id(user_id: int, db: Session):
@@ -163,7 +163,7 @@ def get_follower_for_user_id_and_target_user_id(
async def create_follower(
user_id: int,
target_user_id: int,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
):
try:
@@ -175,7 +175,7 @@ async def create_follower(
# Add the new follow relationship to the database
db.add(new_follow)
db.commit()
await notifications_utils.create_new_follower_request_notification(
user_id, target_user_id, websocket_manager, db
)
@@ -199,7 +199,7 @@ async def create_follower(
async def accept_follower(
user_id: int,
target_user_id: int,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
):
try:
@@ -226,7 +226,7 @@ async def accept_follower(
# Commit the transaction
db.commit()
await notifications_utils.create_accepted_follower_request_notification(
user_id, target_user_id, websocket_manager, db
)

View File

@@ -14,7 +14,7 @@ import auth.security as auth_security
import core.database as core_database
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
# Define the API router
router = APIRouter()
@@ -207,8 +207,8 @@ async def create_follow(
Callable, Security(auth_security.check_scopes, scopes=["profile"])
],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
db: Annotated[
Session,
@@ -239,8 +239,8 @@ async def accept_follow(
Callable, Security(auth_security.check_scopes, scopes=["profile"])
],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
db: Annotated[
Session,

View File

@@ -16,7 +16,7 @@ import activities.activity.crud as activities_crud
import users.user.crud as users_crud
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
from core.database import SessionLocal
@@ -26,7 +26,7 @@ async def fetch_and_process_activities_by_dates(
start_date: datetime,
end_date: datetime,
user_id: int,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
) -> list[activities_schema.Activity] | None:
try:
@@ -125,7 +125,7 @@ async def fetch_and_process_activities_by_dates(
async def retrieve_garminconnect_users_activities_for_days(days: int):
websocket_manager = websocket_schema.get_websocket_manager()
websocket_manager = websocket_manager.get_websocket_manager()
# Create a new database session using context manager
with SessionLocal() as db:
@@ -195,7 +195,7 @@ async def get_user_garminconnect_activities_by_dates(
start_date: datetime,
end_date: datetime,
user_id: int,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
) -> list[activities_schema.Activity] | None:
try:

View File

@@ -13,7 +13,7 @@ import garmin.activity_utils as garmin_activity_utils
import garmin.health_utils as garmin_health_utils
import garmin.gear_utils as garmin_gear_utils
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
import core.logger as core_logger
@@ -37,8 +37,8 @@ async def garminconnect_link(
garmin_schema.MFACodeStore, Depends(garmin_schema.get_mfa_store)
],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
):
# Link Garmin Connect account
@@ -84,8 +84,8 @@ async def garminconnect_retrieve_activities_days(
],
db: Annotated[Session, Depends(core_database.get_db)],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
background_tasks: BackgroundTasks,
):

View File

@@ -16,7 +16,7 @@ import core.cryptography as core_cryptography
import users.user_integrations.schema as user_integrations_schema
import users.user_integrations.crud as user_integrations_crud
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
import websocket.utils as websocket_utils
import garmin.schema as garmin_schema
@@ -27,7 +27,7 @@ import core.logger as core_logger
async def get_mfa(
user_id: int,
mfa_codes: garmin_schema.MFACodeStore,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
) -> str:
# Notify frontend that MFA is required
await notify_frontend_mfa_required(user_id, websocket_manager)
@@ -42,7 +42,7 @@ async def get_mfa(
async def notify_frontend_mfa_required(
user_id: int, websocket_manager: websocket_schema.WebSocketManager
user_id: int, websocket_manager: websocket_manager.WebSocketManager
):
try:
json_data = {"message": "MFA_REQUIRED", "user_id": user_id}
@@ -57,7 +57,7 @@ async def link_garminconnect(
password: str,
db: Session,
mfa_codes: garmin_schema.MFACodeStore,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
):
# Define MFA callback as a coroutine
async def async_mfa_callback():

View File

@@ -13,7 +13,7 @@ import users.user.models as users_models
import users.user.utils as users_utils
import websocket.utils as websocket_utils
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
def serialize_notification(notification: notifications_schema.Notification):
@@ -25,7 +25,9 @@ def serialize_notification(notification: notifications_schema.Notification):
async def create_new_activity_notification(
user_id: int, activity_id: int, websocket_manager: websocket_schema.WebSocketManager
user_id: int,
activity_id: int,
websocket_manager: websocket_manager.WebSocketManager,
):
# Create a new database session using context manager
with SessionLocal() as db:
@@ -64,7 +66,9 @@ async def create_new_activity_notification(
async def create_new_duplicate_start_time_activity_notification(
user_id: int, activity_id: int, websocket_manager: websocket_schema.WebSocketManager
user_id: int,
activity_id: int,
websocket_manager: websocket_manager.WebSocketManager,
):
# Create a new database session using context manager
with SessionLocal() as db:
@@ -105,9 +109,9 @@ async def create_new_duplicate_start_time_activity_notification(
async def create_new_follower_request_notification(
user_id: int,
user_id: int,websocket_manager
target_user_id: int,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
):
try:
@@ -160,9 +164,9 @@ async def create_new_follower_request_notification(
async def create_accepted_follower_request_notification(
user_id: int,
user_id: int,websocket_manager
target_user_id: int,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
):
try:
@@ -214,7 +218,7 @@ async def create_accepted_follower_request_notification(
async def create_admin_new_sign_up_approval_request_notification(
user: users_models.User,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
):
try:

View File

@@ -73,7 +73,7 @@ import health.health_weight.schema as health_weight_schema
import health.health_targets.crud as health_targets_crud
import health.health_targets.schema as health_targets_schema
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
class ImportPerformanceConfig(profile_utils.BasePerformanceConfig):
@@ -165,7 +165,7 @@ class ImportService:
self,
user_id: int,
db: Session,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
performance_config: ImportPerformanceConfig | None = None,
):
self.user_id = user_id

View File

@@ -49,7 +49,7 @@ import core.logger as core_logger
import core.rate_limit as core_rate_limit
import core.config as core_config
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
# Define the API router
router = APIRouter()
@@ -458,8 +458,8 @@ async def import_profile_data(
],
db: Annotated[Session, Depends(core_database.get_db)],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
):
"""

View File

@@ -24,7 +24,7 @@ import server_settings.utils as server_settings_utils
import core.database as core_database
import core.apprise as core_apprise
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
# Define the API router
router = APIRouter()
@@ -54,7 +54,7 @@ async def signup(
- user (users_schema.UserSignup): The payload containing the user's sign-up information.
- email_service (core_apprise.AppriseService): Injected email service used to send
verification and admin approval emails.
- websocket_manager (websocket_schema.WebSocketManager): Injected manager used to send
- websocket_manager (websocket_manager.WebSocketManager): Injected manager used to send
real-time notifications (e.g., admin approval requests).
- password_hasher (auth_password_hasher.PasswordHasher): Injected password hasher used to hash user passwords.
- db (Session): Database session/connection used to create the user and related records.
@@ -156,8 +156,8 @@ async def verify_email(
Depends(core_apprise.get_email_service),
],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
db: Annotated[
Session,

View File

@@ -30,7 +30,7 @@ import gears.gear.crud as gears_crud
import strava.utils as strava_utils
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
from core.database import SessionLocal
@@ -41,7 +41,7 @@ async def fetch_and_process_activities(
end_date: datetime,
user_id: int,
user_integrations: user_integrations_schema.UsersIntegrations,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
is_startup: bool = False,
) -> int:
@@ -369,7 +369,7 @@ async def save_activity_streams_laps(
activity: activities_schema.Activity,
stream_data: list,
laps: dict,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
) -> activities_schema.Activity:
# Create the activity and get the ID
@@ -411,7 +411,7 @@ async def process_activity(
user_privacy_settings: users_privacy_settings_schema.UsersPrivacySettings,
strava_client: Client,
user_integrations: user_integrations_schema.UsersIntegrations,
websocket_manager: websocket_schema.WebSocketManager,
websocket_manager: websocket_manager.WebSocketManager,
db: Session,
):
# Get the activity by Strava ID from the user
@@ -769,7 +769,7 @@ async def get_user_strava_activities_by_dates(
start_date: datetime,
end_date: datetime,
user_id: int,
websocket_manager: websocket_schema.WebSocketManager = None,
websocket_manager: websocket_manager.WebSocketManager = None,
db: Session = None,
is_startup: bool = False,
) -> list[activities_schema.Activity] | None:
@@ -781,7 +781,7 @@ async def get_user_strava_activities_by_dates(
if websocket_manager is None:
# Get the websocket manager instance
websocket_manager = websocket_schema.get_websocket_manager()
websocket_manager = websocket_manager.get_websocket_manager()
try:
# Get the user integrations by user ID

View File

@@ -25,7 +25,7 @@ import core.cryptography as core_cryptography
import core.logger as core_logger
import core.database as core_database
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
# Define the API router
router = APIRouter()
@@ -121,8 +121,8 @@ async def strava_retrieve_activities_days(
Depends(auth_security.get_sub_from_access_token),
],
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
# db: Annotated[Session, Depends(core_database.get_db)],
background_tasks: BackgroundTasks,

View File

@@ -0,0 +1,10 @@
"""WebSocket module for real-time notifications and communication."""
from websocket.manager import WebSocketManager, get_websocket_manager
from websocket.utils import notify_frontend
__all__ = [
"WebSocketManager",
"get_websocket_manager",
"notify_frontend",
]

View File

@@ -0,0 +1,92 @@
from functools import lru_cache
from fastapi import WebSocket
import core.logger as core_logger
class WebSocketManager:
"""
Manage active WebSocket connections per user.
Maintains a registry of authenticated WebSocket connections
indexed by user ID, enabling message broadcasting and
targeted notifications.
Attributes:
active_connections: Maps user IDs to their WebSocket.
"""
def __init__(self) -> None:
"""Initialize the WebSocket manager with empty connections."""
self.active_connections: dict[int, WebSocket] = {}
async def connect(self, user_id: int, websocket: WebSocket) -> None:
"""
Accept and register a new WebSocket connection.
Args:
user_id: The user's unique identifier.
websocket: The WebSocket connection to register.
"""
await websocket.accept()
self.active_connections[user_id] = websocket
core_logger.print_to_log(f"WebSocket connected for user {user_id}", "info")
def disconnect(self, user_id: int) -> None:
"""
Remove a user's WebSocket connection.
Args:
user_id: The user's unique identifier.
"""
if self.active_connections.pop(user_id, None):
core_logger.print_to_log(
f"WebSocket disconnected for user {user_id}",
"info",
)
async def send_message(self, user_id: int, message: dict) -> None:
"""
Send a JSON message to a specific user.
Args:
user_id: The user's unique identifier.
message: JSON-serializable data to send.
"""
websocket = self.active_connections.get(user_id)
if websocket:
await websocket.send_json(message)
async def broadcast(self, message: dict) -> None:
"""
Send a JSON message to all connected users.
Args:
message: JSON-serializable data to broadcast.
"""
for websocket in self.active_connections.values():
await websocket.send_json(message)
def get_connection(self, user_id: int) -> WebSocket | None:
"""
Retrieve a user's WebSocket connection.
Args:
user_id: The user's unique identifier.
Returns:
The WebSocket connection or None if not found.
"""
return self.active_connections.get(user_id)
@lru_cache(maxsize=1)
def get_websocket_manager() -> WebSocketManager:
"""
Get the singleton WebSocket manager instance.
Returns:
The shared WebSocketManager instance.
"""
return WebSocketManager()

View File

@@ -1,28 +1,50 @@
from typing import Annotated
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
import websocket.schema as websocket_schema
import auth.security as auth_security
import websocket.manager as websocket_manager
import core.logger as core_logger
# Define the API router
router = APIRouter()
@router.websocket("/{user_id}")
@router.websocket("")
async def websocket_endpoint(
user_id: int,
websocket: WebSocket,
websocket_manager: Annotated[
websocket_schema.WebSocketManager,
Depends(websocket_schema.get_websocket_manager),
token_user_id: Annotated[
int, Depends(auth_security.validate_websocket_access_token)
],
):
# Connect and manage WebSocket connections using the manager
await websocket_manager.connect(user_id, websocket)
websocket_manager: Annotated[
websocket_manager.WebSocketManager,
Depends(websocket_manager.get_websocket_manager),
],
) -> None:
"""
Handle WebSocket connections for real-time notifications.
Establishes authenticated WebSocket connection for receiving
real-time notifications, MFA requests, and activity updates.
Args:
websocket: The WebSocket connection instance.
user_id: Authenticated user ID from access token.
websocket_manager: Manager for WebSocket connections.
"""
await websocket_manager.connect(token_user_id, websocket)
try:
while True:
# Handle incoming messages if necessary (currently just keeping connection alive)
data = await websocket.receive_json()
try:
# Keep connection alive, handle incoming messages
await websocket.receive_json()
except ValueError:
# Log malformed JSON, keep connection alive
core_logger.print_to_log(
f"Received malformed JSON from user {token_user_id}",
"warning",
)
except WebSocketDisconnect:
# Disconnect using the manager
websocket_manager.disconnect(user_id)
websocket_manager.disconnect(token_user_id)

View File

@@ -1,34 +0,0 @@
from fastapi import WebSocket
from typing import Dict
class WebSocketManager:
def __init__(self):
self.active_connections: Dict[int, WebSocket] = {}
async def connect(self, user_id: int, websocket: WebSocket):
await websocket.accept()
self.active_connections[user_id] = websocket
def disconnect(self, user_id: int):
if user_id in self.active_connections:
del self.active_connections[user_id]
async def send_message(self, user_id: int, message: str):
if user_id in self.active_connections:
websocket = self.active_connections[user_id]
await websocket.send_json(message)
async def broadcast(self, message: str):
for websocket in self.active_connections.values():
await websocket.send_json(message)
def get_connection(self, user_id: int) -> WebSocket | None:
return self.active_connections.get(user_id)
def get_websocket_manager():
return websocket_manager
websocket_manager = WebSocketManager()

View File

@@ -1,30 +1,38 @@
from fastapi import HTTPException, status
import websocket.schema as websocket_schema
import websocket.manager as websocket_manager
async def notify_frontend(
user_id: int, websocket_manager: websocket_schema.WebSocketManager, json_data: dict
):
user_id: int,
websocket_manager: websocket_manager.WebSocketManager,
json_data: dict,
) -> bool:
"""
Sends a JSON message to the frontend via an active WebSocket connection for a specific user.
Send a JSON message to a user's WebSocket connection.
Attempts to send data to the user's WebSocket. For MFA
verification, raises an exception if no connection exists.
Args:
user_id (int): The ID of the user to notify.
websocket_manager (websocket_schema.WebSocketManager): The manager handling WebSocket connections.
json_data (dict): The JSON-serializable data to send to the frontend.
user_id: The target user's identifier.
websocket_manager: The WebSocket connection manager.
json_data: JSON-serializable data to send.
Returns:
True if message was sent, False if no connection.
Raises:
HTTPException: If there is no active WebSocket connection for the specified user.
HTTPException: If MFA_REQUIRED but no connection exists.
"""
# Check if the user has an active WebSocket connection
websocket = websocket_manager.get_connection(user_id)
if websocket:
await websocket.send_json(json_data)
else:
if json_data.get("message") == "MFA_REQUIRED":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"No active WebSocket connection for user {user_id}",
)
return True
if json_data.get("message") == "MFA_REQUIRED":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"No WebSocket connection for user {user_id}",
)
return False

View File

@@ -73,6 +73,8 @@ async function initApp() {
if (authStore.isAuthenticated && !authStore.getAccessToken()) {
try {
await authStore.refreshAccessToken()
// Set up WebSocket after token is available
authStore.setUserWebsocket()
} catch (error) {
// If refresh fails, clear user and redirect to login
console.error('Failed to restore session on app init:', error)

View File

@@ -140,7 +140,7 @@ export const useAuthStore = defineStore('auth', {
this.user = JSON.parse(storedUser)
this.isAuthenticated = true
this.setLocale(this.user.preferred_language, locale)
this.setUserWebsocket()
// WebSocket setup deferred until access token is available
this.session_id = localStorage.getItem('session_id')
}
},
@@ -159,11 +159,11 @@ export const useAuthStore = defineStore('auth', {
setUserWebsocket() {
const urlSplit = API_URL.split('://')
const protocol = urlSplit[0] === 'http' ? 'ws' : 'wss'
const websocketURL = `${protocol}://${urlSplit[1]}ws/${this.user.id}`
const websocketURL = `${protocol}://${urlSplit[1]}ws?access_token=${this.accessToken}`
try {
this.user_websocket = new WebSocket(websocketURL)
this.user_websocket.onopen = () => {
console.log(`WebSocket connection established using ${websocketURL}.`)
console.log(`WebSocket connection established for user ID ${this.user.id}`)
}
this.user_websocket.onerror = (error) => {
console.error('WebSocket error:', error)
@@ -195,6 +195,14 @@ export const useAuthStore = defineStore('auth', {
if (response.csrf_token) {
this.csrfToken = response.csrf_token
}
// Reconnect WebSocket with new token if currently open
if (
this.user_websocket &&
this.user_websocket.readyState === WebSocket.OPEN
) {
this.user_websocket.close()
this.setUserWebsocket()
}
return response.access_token
}
throw new Error('No access token in refresh response')