mirror of
https://github.com/joaovitoriasilva/endurain.git
synced 2026-01-09 15:57:59 -05:00
Implements optional session idle and absolute timeout logic, including new environment variables for configuration. Adds last_activity_at to sessions, enforces timeouts on token refresh, and introduces a scheduler job to clean up idle sessions. Also introduces progressive lockout for failed logins and updates documentation and examples accordingly.
408 lines
14 KiB
Python
408 lines
14 KiB
Python
"""CRUD operations for user sessions."""
|
|
|
|
from datetime import datetime, timezone
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy.orm import Session
|
|
|
|
import auth.oauth_state.models as oauth_state_models
|
|
import session.models as session_models
|
|
import session.schema as session_schema
|
|
|
|
import core.logger as core_logger
|
|
|
|
|
|
class SessionNotFoundError(Exception):
    """
    Signal that a session lookup matched no record in the data store.

    The CRUD helpers in this module raise it internally when an update or
    delete targets a session that does not exist, and then translate it into
    an HTTP 404 response for the caller.

    The class adds nothing on top of ``Exception``: the optional explanation
    passed to the constructor is available through ``str(err)`` and
    ``err.args``.
    """
|
|
|
|
|
|
def get_user_sessions(
    user_id: int, db: Session
) -> list[session_models.UsersSessions] | None:
    """
    Retrieve all session records for a given user, newest first.

    Args:
        user_id (int): The ID of the user whose sessions are to be retrieved.
        db (Session): SQLAlchemy database session.

    Returns:
        list[session_models.UsersSessions]: The user's sessions ordered by
            ``created_at`` descending. When the user has no sessions the
            result is an empty list: ``Query.all()`` always returns a list,
            so ``None`` is never actually returned (it is kept in the
            annotation only so the signature stays unchanged for callers).

    Raises:
        HTTPException: If an error occurs during retrieval, raises a
            500 Internal Server Error.
    """
    try:
        # .all() yields [] when nothing matches, so no None check is needed
        return (
            db.query(session_models.UsersSessions)
            .filter(session_models.UsersSessions.user_id == user_id)
            .order_by(session_models.UsersSessions.created_at.desc())
            .all()
        )
    except Exception as err:
        # Log the exception
        core_logger.print_to_log(f"Error in get_user_sessions: {err}", "error", exc=err)

        # Raise an HTTPException with a 500 Internal Server Error status code
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve sessions",
        ) from err
|
|
|
|
|
|
def get_session_by_id(
    session_id: str, db: Session
) -> session_models.UsersSessions | None:
    """
    Look up a user session by its identifier, ignoring expired sessions.

    Args:
        session_id (str): The unique identifier of the session.
        db (Session): The SQLAlchemy database session.

    Returns:
        UsersSessions | None: The matching session when one exists and its
            ``expires_at`` is later than the current UTC time, otherwise None.

    Raises:
        HTTPException: If an error occurs during retrieval, raises a
            500 Internal Server Error.
    """
    try:
        sessions_table = session_models.UsersSessions

        # Match on the id first, then drop anything that has already expired
        unexpired_match = (
            db.query(sessions_table)
            .filter(sessions_table.id == session_id)
            .filter(sessions_table.expires_at > datetime.now(timezone.utc))
        )
        return unexpired_match.first()
    except Exception as err:
        # Record the failure before surfacing it to the client
        core_logger.print_to_log(f"Error in get_session_by_id: {err}", "error", exc=err)

        # Surface the failure as a 500 Internal Server Error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve session",
        ) from err
|
|
|
|
|
|
def get_session_with_oauth_state(
    session_id: str, db: Session
) -> tuple[session_models.UsersSessions, oauth_state_models.OAuthState | None] | None:
    """
    Fetch a non-expired session together with its linked OAuthState, if any.

    The session is queried first; the OAuth state is only queried when the
    session carries a truthy ``oauth_state_id``.

    Args:
        session_id (str): The unique identifier of the session.
        db (Session): The SQLAlchemy database session.

    Returns:
        tuple[UsersSessions, OAuthState | None] | None: ``(session, oauth_state)``
            where ``oauth_state`` is None when no state is linked or the linked
            row cannot be found. Returns None when no unexpired session matches
            ``session_id``.

    Raises:
        HTTPException: If an error occurs during retrieval (500).
    """
    try:
        sessions_table = session_models.UsersSessions
        found_session = (
            db.query(sessions_table)
            .filter(sessions_table.id == session_id)
            .filter(sessions_table.expires_at > datetime.now(timezone.utc))
            .first()
        )

        # Nothing to pair an OAuth state with
        if not found_session:
            return None

        # Session exists but was not created through a linked OAuth flow
        if not found_session.oauth_state_id:
            return (found_session, None)

        states_table = oauth_state_models.OAuthState
        linked_state = (
            db.query(states_table)
            .filter(states_table.id == found_session.oauth_state_id)
            .first()
        )
        return (found_session, linked_state)
    except Exception as err:
        # Record the failure before surfacing it to the client
        core_logger.print_to_log(
            f"Error in get_session_with_oauth_state: {err}", "error", exc=err
        )

        # Surface the failure as a 500 Internal Server Error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve session with OAuth state",
        ) from err
|
|
|
|
|
|
def create_session(
    session: session_schema.UsersSessions, db: Session
) -> session_models.UsersSessions:
    """
    Persist a new user session and return the stored row.

    Args:
        session (session_schema.UsersSessions): The session data to be created.
        db (Session): The SQLAlchemy database session.

    Returns:
        session_models.UsersSessions: The newly created session object,
            refreshed from the database after the commit.

    Raises:
        HTTPException: If an error occurs during session creation, raises a
            500 Internal Server Error (the transaction is rolled back first).
    """
    try:
        # Build the ORM row straight from the schema's dumped fields
        new_row = session_models.UsersSessions(**session.model_dump())

        db.add(new_row)
        db.commit()
        # Reload the row so the returned object reflects what was stored
        db.refresh(new_row)
    except Exception as err:
        # Undo the partial transaction
        db.rollback()

        # Record the failure before surfacing it to the client
        core_logger.print_to_log(f"Error in create_session: {err}", "error", exc=err)

        # Surface the failure as a 500 Internal Server Error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create session",
        ) from err
    else:
        return new_row
|
|
|
|
|
|
def mark_tokens_exchanged(session_id: str, db: Session) -> None:
    """
    Flag a session as having had its tokens exchanged.

    Sets ``tokens_exchanged`` to True on the session identified by
    ``session_id`` and commits, so that later token exchange requests for the
    same session can be recognised as duplicates.

    NOTE(review): the row is loaded and then modified in Python rather than
    updated with a single conditional UPDATE statement - confirm that callers
    do not rely on this being atomic under concurrent requests.

    Args:
        session_id (str): The unique identifier of the session.
        db (Session): The SQLAlchemy database session.

    Raises:
        HTTPException: 404 if the session does not exist (raised from the
            internal SessionNotFoundError), or 500 if any other error occurs
            during the update.
    """
    try:
        target = (
            db.query(session_models.UsersSessions)
            .filter(session_models.UsersSessions.id == session_id)
            .first()
        )

        if target:
            # Flip the flag and persist it
            target.tokens_exchanged = True
            db.commit()
        else:
            raise SessionNotFoundError(f"Session {session_id} not found")
    except SessionNotFoundError as missing:
        # Nothing was modified, so no rollback is performed on this path
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(missing)
        ) from missing
    except Exception as err:
        # Undo the partial transaction
        db.rollback()

        # Record the failure before surfacing it to the client
        core_logger.print_to_log(
            f"Error in mark_tokens_exchanged: {err}", "error", exc=err
        )

        # Surface the failure as a 500 Internal Server Error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to mark tokens as exchanged",
        ) from err
|
|
|
|
|
|
def edit_session(session: session_schema.UsersSessions, db: Session) -> None:
    """
    Apply updated field values to an existing user session.

    The session row is looked up by ``session.id``; every field that was
    explicitly set on the incoming schema object is copied onto the row and
    the change is committed.

    Args:
        session (session_schema.UsersSessions): The session data containing updated fields.
        db (Session): The SQLAlchemy database session.

    Raises:
        HTTPException: If the session is not found (404) or if an error occurs
            during update (500, after rolling back the transaction and logging).
    """
    try:
        target = (
            db.query(session_models.UsersSessions)
            .filter(session_models.UsersSessions.id == session.id)
            .first()
        )

        # Bail out early when there is no row to update
        if not target:
            raise SessionNotFoundError(f"Session {session.id} not found")

        # exclude_unset limits the dump to fields that were explicitly set on
        # the schema object, so untouched columns keep their stored values
        for field_name, new_value in session.model_dump(exclude_unset=True).items():
            setattr(target, field_name, new_value)

        db.commit()
    except SessionNotFoundError as missing:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(missing)
        ) from missing
    except Exception as err:
        # Undo the partial transaction
        db.rollback()

        # Record the failure before surfacing it to the client
        core_logger.print_to_log(f"Error in edit_session: {err}", "error", exc=err)

        # Surface the failure as a 500 Internal Server Error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update session",
        ) from err
|
|
|
|
|
|
def delete_session(session_id: str, user_id: int, db: Session) -> None:
    """
    Remove one session belonging to a specific user.

    The delete only matches a row whose id equals ``session_id`` and whose
    ``user_id`` equals the supplied ``user_id``.

    Args:
        session_id (str): The unique identifier of the session to delete.
        user_id (int): The ID of the user associated with the session.
        db (Session): The SQLAlchemy database session.

    Raises:
        HTTPException: If the session is not found (404) or if an error occurs during deletion (500).

    Notes:
        - The transaction is committed only when at least one row was deleted.
        - On an unexpected exception the transaction is rolled back and the
          error is logged before the 500 is raised.
    """
    try:
        owned_session = db.query(session_models.UsersSessions).filter(
            session_models.UsersSessions.id == session_id,
            session_models.UsersSessions.user_id == user_id,
        )
        deleted_count = owned_session.delete()

        # Zero affected rows means no such session for this user
        if deleted_count == 0:
            raise SessionNotFoundError(
                f"Session {session_id} not found for user {user_id}"
            )

        db.commit()
    except SessionNotFoundError as missing:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(missing)
        ) from missing
    except Exception as err:
        # Undo the partial transaction
        db.rollback()

        # Record the failure before surfacing it to the client
        core_logger.print_to_log(f"Error in delete_session: {err}", "error", exc=err)

        # Surface the failure as a 500 Internal Server Error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete session",
        ) from err
|
|
|
|
|
|
def delete_idle_sessions(cutoff_time: datetime, db: Session) -> int:
    """
    Bulk-delete sessions whose last activity predates a cutoff.

    Every session whose ``last_activity_at`` is earlier than ``cutoff_time``
    is removed in one query and the deletion is committed. Intended for the
    periodic cleanup job that prunes inactive sessions.

    Args:
        cutoff_time (datetime): Sessions with last_activity_at before this time will be deleted.
        db (Session): The SQLAlchemy database session.

    Returns:
        int: The number of sessions deleted.

    Raises:
        HTTPException: If an error occurs during deletion (500); the
            transaction is rolled back and the error logged first.
    """
    try:
        stale_sessions = db.query(session_models.UsersSessions).filter(
            session_models.UsersSessions.last_activity_at < cutoff_time
        )
        removed_count = stale_sessions.delete()

        db.commit()
    except Exception as err:
        # Undo the partial transaction
        db.rollback()

        # Record the failure before surfacing it
        core_logger.print_to_log(
            f"Error in delete_idle_sessions: {err}", "error", exc=err
        )

        # Surface the failure as a 500 Internal Server Error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete idle sessions",
        ) from err
    else:
        return removed_count
|