mirror of
https://github.com/joaovitoriasilva/endurain.git
synced 2026-01-07 23:13:57 -05:00
Refactor WebSocket manager and authentication flow
Replaces websocket.schema with websocket.manager for managing WebSocket connections, introducing a singleton WebSocketManager and updating all imports and usages accordingly. Adds token-based authentication for WebSocket connections, requiring access_token as a query parameter and validating it server-side. Updates FastAPI WebSocket endpoint to use authenticated user ID, and modifies frontend to connect using the access token. Removes obsolete schema.py and improves error handling and logging for WebSocket events.
This commit is contained in:
@@ -14,7 +14,7 @@ import notifications.utils as notifications_utils
|
||||
|
||||
import server_settings.utils as server_settings_utils
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
@@ -1190,7 +1190,7 @@ def get_activities_if_contains_name(name: str, user_id: int, db: Session):
|
||||
|
||||
async def create_activity(
|
||||
activity: activities_schema.Activity,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
create_notification: bool = True,
|
||||
) -> activities_schema.Activity:
|
||||
|
||||
@@ -20,7 +20,7 @@ import auth.security as auth_security
|
||||
import users.user.dependencies as users_dependencies
|
||||
import garmin.activity_utils as garmin_activity_utils
|
||||
import strava.activity_utils as strava_activity_utils
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
@@ -513,8 +513,8 @@ async def read_activities_user_activities_refresh(
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
):
|
||||
# Set the activities to empty list
|
||||
@@ -618,8 +618,8 @@ async def create_activity_with_uploaded_file(
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -653,8 +653,8 @@ async def create_activity_with_bulk_import(
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -37,7 +37,7 @@ import activities.activity_streams.schema as activity_streams_schema
|
||||
|
||||
import activities.activity_workout_steps.crud as activity_workout_steps_crud
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
import gpx.utils as gpx_utils
|
||||
import tcx.utils as tcx_utils
|
||||
@@ -389,7 +389,7 @@ def handle_gzipped_file(
|
||||
async def parse_and_store_activity_from_file(
|
||||
token_user_id: int,
|
||||
file_path: str,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
from_garmin: bool = False,
|
||||
garminconnect_gear: dict | None = None,
|
||||
@@ -540,7 +540,7 @@ async def parse_and_store_activity_from_file(
|
||||
async def parse_and_store_activity_from_uploaded_file(
|
||||
token_user_id: int,
|
||||
file: UploadFile,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
):
|
||||
# Validate filename exists
|
||||
@@ -742,7 +742,9 @@ def parse_file(
|
||||
|
||||
|
||||
async def store_activity(
|
||||
parsed_info: dict, websocket_manager: websocket_schema.WebSocketManager, db: Session
|
||||
parsed_info: dict,websocket_manager
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
):
|
||||
# create the activity in the database
|
||||
created_activity = await activities_crud.create_activity(
|
||||
@@ -1193,7 +1195,7 @@ def set_activity_name_based_on_activity_type(activity_type_id: int) -> str:
|
||||
def process_all_files_sync(
|
||||
user_id: int,
|
||||
file_paths: list[str],
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
):
|
||||
"""
|
||||
Process all files sequentially in single thread.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Annotated, Union
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import Depends, HTTPException, Query, status, WebSocket, WebSocketException
|
||||
from fastapi.security import (
|
||||
OAuth2PasswordBearer,
|
||||
SecurityScopes,
|
||||
@@ -618,3 +618,70 @@ def check_scopes_for_browser_redirect(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred during scope validation.",
|
||||
) from err
|
||||
|
||||
|
||||
## WEBSOCKET TOKEN VALIDATION
|
||||
async def validate_websocket_access_token(
|
||||
websocket: WebSocket,
|
||||
access_token: str = Query(..., alias="access_token"),
|
||||
token_manager: auth_token_manager.TokenManager = Depends(
|
||||
auth_token_manager.get_token_manager
|
||||
),
|
||||
) -> int:
|
||||
"""
|
||||
Validate access token for WebSocket connections.
|
||||
|
||||
WebSocket connections cannot use Authorization headers during
|
||||
the handshake, so tokens are passed via query parameters.
|
||||
|
||||
Args:
|
||||
websocket: The WebSocket connection instance.
|
||||
access_token: Access token from query parameter.
|
||||
token_manager: Token manager for validation.
|
||||
|
||||
Returns:
|
||||
The authenticated user ID from the token.
|
||||
|
||||
Raises:
|
||||
WebSocketException: If token is missing, invalid, or expired.
|
||||
"""
|
||||
try:
|
||||
# Validate token expiration
|
||||
token_manager.validate_token_expiration(access_token)
|
||||
|
||||
# Get user ID from token
|
||||
token_user_id = token_manager.get_token_claim(access_token, "sub")
|
||||
|
||||
if token_user_id is None or isinstance(token_user_id, list):
|
||||
core_logger.print_to_log(
|
||||
"WebSocket token validation failed: invalid sub claim",
|
||||
"warning",
|
||||
)
|
||||
raise WebSocketException(
|
||||
code=status.WS_1008_POLICY_VIOLATION,
|
||||
reason="Invalid token: missing user ID",
|
||||
)
|
||||
|
||||
return int(token_user_id)
|
||||
except WebSocketException as ws_err:
|
||||
raise ws_err
|
||||
except HTTPException as http_err:
|
||||
core_logger.print_to_log(
|
||||
f"WebSocket token validation failed: {http_err.detail}",
|
||||
"warning",
|
||||
exc=http_err,
|
||||
)
|
||||
raise WebSocketException(
|
||||
code=status.WS_1008_POLICY_VIOLATION,
|
||||
reason="Invalid or expired token",
|
||||
) from http_err
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Unexpected error during WebSocket token validation: {err}",
|
||||
"error",
|
||||
exc=err,
|
||||
)
|
||||
raise WebSocketException(
|
||||
code=status.WS_1008_POLICY_VIOLATION,
|
||||
reason="Token validation failed",
|
||||
) from err
|
||||
|
||||
@@ -247,10 +247,6 @@ router.include_router(
|
||||
websocket_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/ws",
|
||||
tags=["websocket"],
|
||||
# dependencies=[
|
||||
# Depends(auth_security.validate_access_token),
|
||||
# Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
# ],
|
||||
)
|
||||
|
||||
# PUBLIC ROUTES (alphabetical order)
|
||||
|
||||
@@ -9,7 +9,7 @@ import core.logger as core_logger
|
||||
|
||||
import notifications.utils as notifications_utils
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
|
||||
def get_all_followers_by_user_id(user_id: int, db: Session):
|
||||
@@ -163,7 +163,7 @@ def get_follower_for_user_id_and_target_user_id(
|
||||
async def create_follower(
|
||||
user_id: int,
|
||||
target_user_id: int,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
):
|
||||
try:
|
||||
@@ -175,7 +175,7 @@ async def create_follower(
|
||||
# Add the new follow relationship to the database
|
||||
db.add(new_follow)
|
||||
db.commit()
|
||||
|
||||
|
||||
await notifications_utils.create_new_follower_request_notification(
|
||||
user_id, target_user_id, websocket_manager, db
|
||||
)
|
||||
@@ -199,7 +199,7 @@ async def create_follower(
|
||||
async def accept_follower(
|
||||
user_id: int,
|
||||
target_user_id: int,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
):
|
||||
try:
|
||||
@@ -226,7 +226,7 @@ async def accept_follower(
|
||||
|
||||
# Commit the transaction
|
||||
db.commit()
|
||||
|
||||
|
||||
await notifications_utils.create_accepted_follower_request_notification(
|
||||
user_id, target_user_id, websocket_manager, db
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
# Define the API router
|
||||
router = APIRouter()
|
||||
@@ -207,8 +207,8 @@ async def create_follow(
|
||||
Callable, Security(auth_security.check_scopes, scopes=["profile"])
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -239,8 +239,8 @@ async def accept_follow(
|
||||
Callable, Security(auth_security.check_scopes, scopes=["profile"])
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -16,7 +16,7 @@ import activities.activity.crud as activities_crud
|
||||
|
||||
import users.user.crud as users_crud
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
from core.database import SessionLocal
|
||||
|
||||
@@ -26,7 +26,7 @@ async def fetch_and_process_activities_by_dates(
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
user_id: int,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
) -> list[activities_schema.Activity] | None:
|
||||
try:
|
||||
@@ -125,7 +125,7 @@ async def fetch_and_process_activities_by_dates(
|
||||
|
||||
|
||||
async def retrieve_garminconnect_users_activities_for_days(days: int):
|
||||
websocket_manager = websocket_schema.get_websocket_manager()
|
||||
websocket_manager = websocket_manager.get_websocket_manager()
|
||||
|
||||
# Create a new database session using context manager
|
||||
with SessionLocal() as db:
|
||||
@@ -195,7 +195,7 @@ async def get_user_garminconnect_activities_by_dates(
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
user_id: int,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
) -> list[activities_schema.Activity] | None:
|
||||
try:
|
||||
|
||||
@@ -13,7 +13,7 @@ import garmin.activity_utils as garmin_activity_utils
|
||||
import garmin.health_utils as garmin_health_utils
|
||||
import garmin.gear_utils as garmin_gear_utils
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
import core.logger as core_logger
|
||||
|
||||
@@ -37,8 +37,8 @@ async def garminconnect_link(
|
||||
garmin_schema.MFACodeStore, Depends(garmin_schema.get_mfa_store)
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
):
|
||||
# Link Garmin Connect account
|
||||
@@ -84,8 +84,8 @@ async def garminconnect_retrieve_activities_days(
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
|
||||
@@ -16,7 +16,7 @@ import core.cryptography as core_cryptography
|
||||
import users.user_integrations.schema as user_integrations_schema
|
||||
import users.user_integrations.crud as user_integrations_crud
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
import websocket.utils as websocket_utils
|
||||
|
||||
import garmin.schema as garmin_schema
|
||||
@@ -27,7 +27,7 @@ import core.logger as core_logger
|
||||
async def get_mfa(
|
||||
user_id: int,
|
||||
mfa_codes: garmin_schema.MFACodeStore,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
) -> str:
|
||||
# Notify frontend that MFA is required
|
||||
await notify_frontend_mfa_required(user_id, websocket_manager)
|
||||
@@ -42,7 +42,7 @@ async def get_mfa(
|
||||
|
||||
|
||||
async def notify_frontend_mfa_required(
|
||||
user_id: int, websocket_manager: websocket_schema.WebSocketManager
|
||||
user_id: int, websocket_manager: websocket_manager.WebSocketManager
|
||||
):
|
||||
try:
|
||||
json_data = {"message": "MFA_REQUIRED", "user_id": user_id}
|
||||
@@ -57,7 +57,7 @@ async def link_garminconnect(
|
||||
password: str,
|
||||
db: Session,
|
||||
mfa_codes: garmin_schema.MFACodeStore,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
):
|
||||
# Define MFA callback as a coroutine
|
||||
async def async_mfa_callback():
|
||||
|
||||
@@ -13,7 +13,7 @@ import users.user.models as users_models
|
||||
import users.user.utils as users_utils
|
||||
|
||||
import websocket.utils as websocket_utils
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
|
||||
def serialize_notification(notification: notifications_schema.Notification):
|
||||
@@ -25,7 +25,9 @@ def serialize_notification(notification: notifications_schema.Notification):
|
||||
|
||||
|
||||
async def create_new_activity_notification(
|
||||
user_id: int, activity_id: int, websocket_manager: websocket_schema.WebSocketManager
|
||||
user_id: int,
|
||||
activity_id: int,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
):
|
||||
# Create a new database session using context manager
|
||||
with SessionLocal() as db:
|
||||
@@ -64,7 +66,9 @@ async def create_new_activity_notification(
|
||||
|
||||
|
||||
async def create_new_duplicate_start_time_activity_notification(
|
||||
user_id: int, activity_id: int, websocket_manager: websocket_schema.WebSocketManager
|
||||
user_id: int,
|
||||
activity_id: int,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
):
|
||||
# Create a new database session using context manager
|
||||
with SessionLocal() as db:
|
||||
@@ -105,9 +109,9 @@ async def create_new_duplicate_start_time_activity_notification(
|
||||
|
||||
|
||||
async def create_new_follower_request_notification(
|
||||
user_id: int,
|
||||
user_id: int,websocket_manager
|
||||
target_user_id: int,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
):
|
||||
try:
|
||||
@@ -160,9 +164,9 @@ async def create_new_follower_request_notification(
|
||||
|
||||
|
||||
async def create_accepted_follower_request_notification(
|
||||
user_id: int,
|
||||
user_id: int,websocket_manager
|
||||
target_user_id: int,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
):
|
||||
try:
|
||||
@@ -214,7 +218,7 @@ async def create_accepted_follower_request_notification(
|
||||
|
||||
async def create_admin_new_sign_up_approval_request_notification(
|
||||
user: users_models.User,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -73,7 +73,7 @@ import health.health_weight.schema as health_weight_schema
|
||||
import health.health_targets.crud as health_targets_crud
|
||||
import health.health_targets.schema as health_targets_schema
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
|
||||
class ImportPerformanceConfig(profile_utils.BasePerformanceConfig):
|
||||
@@ -165,7 +165,7 @@ class ImportService:
|
||||
self,
|
||||
user_id: int,
|
||||
db: Session,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
performance_config: ImportPerformanceConfig | None = None,
|
||||
):
|
||||
self.user_id = user_id
|
||||
|
||||
@@ -49,7 +49,7 @@ import core.logger as core_logger
|
||||
import core.rate_limit as core_rate_limit
|
||||
import core.config as core_config
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
# Define the API router
|
||||
router = APIRouter()
|
||||
@@ -458,8 +458,8 @@ async def import_profile_data(
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -24,7 +24,7 @@ import server_settings.utils as server_settings_utils
|
||||
import core.database as core_database
|
||||
import core.apprise as core_apprise
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
# Define the API router
|
||||
router = APIRouter()
|
||||
@@ -54,7 +54,7 @@ async def signup(
|
||||
- user (users_schema.UserSignup): The payload containing the user's sign-up information.
|
||||
- email_service (core_apprise.AppriseService): Injected email service used to send
|
||||
verification and admin approval emails.
|
||||
- websocket_manager (websocket_schema.WebSocketManager): Injected manager used to send
|
||||
- websocket_manager (websocket_manager.WebSocketManager): Injected manager used to send
|
||||
real-time notifications (e.g., admin approval requests).
|
||||
- password_hasher (auth_password_hasher.PasswordHasher): Injected password hasher used to hash user passwords.
|
||||
- db (Session): Database session/connection used to create the user and related records.
|
||||
@@ -156,8 +156,8 @@ async def verify_email(
|
||||
Depends(core_apprise.get_email_service),
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -30,7 +30,7 @@ import gears.gear.crud as gears_crud
|
||||
|
||||
import strava.utils as strava_utils
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
from core.database import SessionLocal
|
||||
|
||||
@@ -41,7 +41,7 @@ async def fetch_and_process_activities(
|
||||
end_date: datetime,
|
||||
user_id: int,
|
||||
user_integrations: user_integrations_schema.UsersIntegrations,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
is_startup: bool = False,
|
||||
) -> int:
|
||||
@@ -369,7 +369,7 @@ async def save_activity_streams_laps(
|
||||
activity: activities_schema.Activity,
|
||||
stream_data: list,
|
||||
laps: dict,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
) -> activities_schema.Activity:
|
||||
# Create the activity and get the ID
|
||||
@@ -411,7 +411,7 @@ async def process_activity(
|
||||
user_privacy_settings: users_privacy_settings_schema.UsersPrivacySettings,
|
||||
strava_client: Client,
|
||||
user_integrations: user_integrations_schema.UsersIntegrations,
|
||||
websocket_manager: websocket_schema.WebSocketManager,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
db: Session,
|
||||
):
|
||||
# Get the activity by Strava ID from the user
|
||||
@@ -769,7 +769,7 @@ async def get_user_strava_activities_by_dates(
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
user_id: int,
|
||||
websocket_manager: websocket_schema.WebSocketManager = None,
|
||||
websocket_manager: websocket_manager.WebSocketManager = None,
|
||||
db: Session = None,
|
||||
is_startup: bool = False,
|
||||
) -> list[activities_schema.Activity] | None:
|
||||
@@ -781,7 +781,7 @@ async def get_user_strava_activities_by_dates(
|
||||
|
||||
if websocket_manager is None:
|
||||
# Get the websocket manager instance
|
||||
websocket_manager = websocket_schema.get_websocket_manager()
|
||||
websocket_manager = websocket_manager.get_websocket_manager()
|
||||
|
||||
try:
|
||||
# Get the user integrations by user ID
|
||||
|
||||
@@ -25,7 +25,7 @@ import core.cryptography as core_cryptography
|
||||
import core.logger as core_logger
|
||||
import core.database as core_database
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
# Define the API router
|
||||
router = APIRouter()
|
||||
@@ -121,8 +121,8 @@ async def strava_retrieve_activities_days(
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
# db: Annotated[Session, Depends(core_database.get_db)],
|
||||
background_tasks: BackgroundTasks,
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
"""WebSocket module for real-time notifications and communication."""
|
||||
|
||||
from websocket.manager import WebSocketManager, get_websocket_manager
|
||||
from websocket.utils import notify_frontend
|
||||
|
||||
__all__ = [
|
||||
"WebSocketManager",
|
||||
"get_websocket_manager",
|
||||
"notify_frontend",
|
||||
]
|
||||
|
||||
92
backend/app/websocket/manager.py
Normal file
92
backend/app/websocket/manager.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from functools import lru_cache
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
import core.logger as core_logger
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""
|
||||
Manage active WebSocket connections per user.
|
||||
|
||||
Maintains a registry of authenticated WebSocket connections
|
||||
indexed by user ID, enabling message broadcasting and
|
||||
targeted notifications.
|
||||
|
||||
Attributes:
|
||||
active_connections: Maps user IDs to their WebSocket.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the WebSocket manager with empty connections."""
|
||||
self.active_connections: dict[int, WebSocket] = {}
|
||||
|
||||
async def connect(self, user_id: int, websocket: WebSocket) -> None:
|
||||
"""
|
||||
Accept and register a new WebSocket connection.
|
||||
|
||||
Args:
|
||||
user_id: The user's unique identifier.
|
||||
websocket: The WebSocket connection to register.
|
||||
"""
|
||||
await websocket.accept()
|
||||
self.active_connections[user_id] = websocket
|
||||
core_logger.print_to_log(f"WebSocket connected for user {user_id}", "info")
|
||||
|
||||
def disconnect(self, user_id: int) -> None:
|
||||
"""
|
||||
Remove a user's WebSocket connection.
|
||||
|
||||
Args:
|
||||
user_id: The user's unique identifier.
|
||||
"""
|
||||
if self.active_connections.pop(user_id, None):
|
||||
core_logger.print_to_log(
|
||||
f"WebSocket disconnected for user {user_id}",
|
||||
"info",
|
||||
)
|
||||
|
||||
async def send_message(self, user_id: int, message: dict) -> None:
|
||||
"""
|
||||
Send a JSON message to a specific user.
|
||||
|
||||
Args:
|
||||
user_id: The user's unique identifier.
|
||||
message: JSON-serializable data to send.
|
||||
"""
|
||||
websocket = self.active_connections.get(user_id)
|
||||
if websocket:
|
||||
await websocket.send_json(message)
|
||||
|
||||
async def broadcast(self, message: dict) -> None:
|
||||
"""
|
||||
Send a JSON message to all connected users.
|
||||
|
||||
Args:
|
||||
message: JSON-serializable data to broadcast.
|
||||
"""
|
||||
for websocket in self.active_connections.values():
|
||||
await websocket.send_json(message)
|
||||
|
||||
def get_connection(self, user_id: int) -> WebSocket | None:
|
||||
"""
|
||||
Retrieve a user's WebSocket connection.
|
||||
|
||||
Args:
|
||||
user_id: The user's unique identifier.
|
||||
|
||||
Returns:
|
||||
The WebSocket connection or None if not found.
|
||||
"""
|
||||
return self.active_connections.get(user_id)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_websocket_manager() -> WebSocketManager:
|
||||
"""
|
||||
Get the singleton WebSocket manager instance.
|
||||
|
||||
Returns:
|
||||
The shared WebSocketManager instance.
|
||||
"""
|
||||
return WebSocketManager()
|
||||
@@ -1,28 +1,50 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||
import websocket.schema as websocket_schema
|
||||
|
||||
import auth.security as auth_security
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
import core.logger as core_logger
|
||||
|
||||
# Define the API router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.websocket("/{user_id}")
|
||||
@router.websocket("")
|
||||
async def websocket_endpoint(
|
||||
user_id: int,
|
||||
websocket: WebSocket,
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
Depends(websocket_schema.get_websocket_manager),
|
||||
token_user_id: Annotated[
|
||||
int, Depends(auth_security.validate_websocket_access_token)
|
||||
],
|
||||
):
|
||||
# Connect and manage WebSocket connections using the manager
|
||||
await websocket_manager.connect(user_id, websocket)
|
||||
websocket_manager: Annotated[
|
||||
websocket_manager.WebSocketManager,
|
||||
Depends(websocket_manager.get_websocket_manager),
|
||||
],
|
||||
) -> None:
|
||||
"""
|
||||
Handle WebSocket connections for real-time notifications.
|
||||
|
||||
Establishes authenticated WebSocket connection for receiving
|
||||
real-time notifications, MFA requests, and activity updates.
|
||||
|
||||
Args:
|
||||
websocket: The WebSocket connection instance.
|
||||
user_id: Authenticated user ID from access token.
|
||||
websocket_manager: Manager for WebSocket connections.
|
||||
"""
|
||||
await websocket_manager.connect(token_user_id, websocket)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Handle incoming messages if necessary (currently just keeping connection alive)
|
||||
data = await websocket.receive_json()
|
||||
try:
|
||||
# Keep connection alive, handle incoming messages
|
||||
await websocket.receive_json()
|
||||
except ValueError:
|
||||
# Log malformed JSON, keep connection alive
|
||||
core_logger.print_to_log(
|
||||
f"Received malformed JSON from user {token_user_id}",
|
||||
"warning",
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
# Disconnect using the manager
|
||||
websocket_manager.disconnect(user_id)
|
||||
websocket_manager.disconnect(token_user_id)
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
from fastapi import WebSocket
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[int, WebSocket] = {}
|
||||
|
||||
async def connect(self, user_id: int, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections[user_id] = websocket
|
||||
|
||||
def disconnect(self, user_id: int):
|
||||
if user_id in self.active_connections:
|
||||
del self.active_connections[user_id]
|
||||
|
||||
async def send_message(self, user_id: int, message: str):
|
||||
if user_id in self.active_connections:
|
||||
websocket = self.active_connections[user_id]
|
||||
await websocket.send_json(message)
|
||||
|
||||
async def broadcast(self, message: str):
|
||||
for websocket in self.active_connections.values():
|
||||
await websocket.send_json(message)
|
||||
|
||||
def get_connection(self, user_id: int) -> WebSocket | None:
|
||||
return self.active_connections.get(user_id)
|
||||
|
||||
|
||||
def get_websocket_manager():
|
||||
return websocket_manager
|
||||
|
||||
|
||||
websocket_manager = WebSocketManager()
|
||||
@@ -1,30 +1,38 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
import websocket.schema as websocket_schema
|
||||
import websocket.manager as websocket_manager
|
||||
|
||||
|
||||
async def notify_frontend(
|
||||
user_id: int, websocket_manager: websocket_schema.WebSocketManager, json_data: dict
|
||||
):
|
||||
user_id: int,
|
||||
websocket_manager: websocket_manager.WebSocketManager,
|
||||
json_data: dict,
|
||||
) -> bool:
|
||||
"""
|
||||
Sends a JSON message to the frontend via an active WebSocket connection for a specific user.
|
||||
Send a JSON message to a user's WebSocket connection.
|
||||
|
||||
Attempts to send data to the user's WebSocket. For MFA
|
||||
verification, raises an exception if no connection exists.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user to notify.
|
||||
websocket_manager (websocket_schema.WebSocketManager): The manager handling WebSocket connections.
|
||||
json_data (dict): The JSON-serializable data to send to the frontend.
|
||||
user_id: The target user's identifier.
|
||||
websocket_manager: The WebSocket connection manager.
|
||||
json_data: JSON-serializable data to send.
|
||||
|
||||
Returns:
|
||||
True if message was sent, False if no connection.
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is no active WebSocket connection for the specified user.
|
||||
HTTPException: If MFA_REQUIRED but no connection exists.
|
||||
"""
|
||||
# Check if the user has an active WebSocket connection
|
||||
|
||||
websocket = websocket_manager.get_connection(user_id)
|
||||
if websocket:
|
||||
await websocket.send_json(json_data)
|
||||
else:
|
||||
if json_data.get("message") == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No active WebSocket connection for user {user_id}",
|
||||
)
|
||||
return True
|
||||
|
||||
if json_data.get("message") == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No WebSocket connection for user {user_id}",
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -73,6 +73,8 @@ async function initApp() {
|
||||
if (authStore.isAuthenticated && !authStore.getAccessToken()) {
|
||||
try {
|
||||
await authStore.refreshAccessToken()
|
||||
// Set up WebSocket after token is available
|
||||
authStore.setUserWebsocket()
|
||||
} catch (error) {
|
||||
// If refresh fails, clear user and redirect to login
|
||||
console.error('Failed to restore session on app init:', error)
|
||||
|
||||
@@ -140,7 +140,7 @@ export const useAuthStore = defineStore('auth', {
|
||||
this.user = JSON.parse(storedUser)
|
||||
this.isAuthenticated = true
|
||||
this.setLocale(this.user.preferred_language, locale)
|
||||
this.setUserWebsocket()
|
||||
// WebSocket setup deferred until access token is available
|
||||
this.session_id = localStorage.getItem('session_id')
|
||||
}
|
||||
},
|
||||
@@ -159,11 +159,11 @@ export const useAuthStore = defineStore('auth', {
|
||||
setUserWebsocket() {
|
||||
const urlSplit = API_URL.split('://')
|
||||
const protocol = urlSplit[0] === 'http' ? 'ws' : 'wss'
|
||||
const websocketURL = `${protocol}://${urlSplit[1]}ws/${this.user.id}`
|
||||
const websocketURL = `${protocol}://${urlSplit[1]}ws?access_token=${this.accessToken}`
|
||||
try {
|
||||
this.user_websocket = new WebSocket(websocketURL)
|
||||
this.user_websocket.onopen = () => {
|
||||
console.log(`WebSocket connection established using ${websocketURL}.`)
|
||||
console.log(`WebSocket connection established for user ID ${this.user.id}`)
|
||||
}
|
||||
this.user_websocket.onerror = (error) => {
|
||||
console.error('WebSocket error:', error)
|
||||
@@ -195,6 +195,14 @@ export const useAuthStore = defineStore('auth', {
|
||||
if (response.csrf_token) {
|
||||
this.csrfToken = response.csrf_token
|
||||
}
|
||||
// Reconnect WebSocket with new token if currently open
|
||||
if (
|
||||
this.user_websocket &&
|
||||
this.user_websocket.readyState === WebSocket.OPEN
|
||||
) {
|
||||
this.user_websocket.close()
|
||||
this.setUserWebsocket()
|
||||
}
|
||||
return response.access_token
|
||||
}
|
||||
throw new Error('No access token in refresh response')
|
||||
|
||||
Reference in New Issue
Block a user