Refactor OAuth state cleanup after token exchange

Updates session token exchange logic to clear the oauth_state_id reference and immediately delete the associated OAuth state per OAuth 2.1 best practices. Adjusts database migration and model to set foreign key ondelete to SET NULL, ensuring referential integrity. Also updates frontend components to use formatHrZoneLabel utility for heart rate zone chart labels.
This commit is contained in:
João Vitória Silva
2025-12-22 16:29:18 +00:00
parent 7c8d4ccccc
commit 0bc739904c
6 changed files with 54 additions and 22 deletions

View File

@@ -363,13 +363,20 @@ def transform_activity_streams(activity_stream, activity, db):
def transform_activity_streams_hr(activity_stream, activity, db):
"""
Transforms an activity stream by calculating the percentage of time spent in each heart rate zone based on user details.
Transforms an activity stream by calculating the percentage of time spent
in each heart rate zone based on user details.
Args:
activity_stream: The activity stream object containing waypoints with heart rate data.
activity: The activity object associated with the stream, used to retrieve the user ID.
activity_stream: The activity stream object containing waypoints with
heart rate data.
activity: The activity object associated with the stream, used to
retrieve the user ID.
db: The database session or connection used to fetch user details.
Returns:
The activity stream object with an added 'hr_zone_percentages' attribute, which contains the percentage of time spent in each heart rate zone and their respective HR boundaries. If waypoi[...]
The activity stream object with an added 'hr_zone_percentages'
attribute, which contains the percentage of time spent in each heart
rate zone and their respective HR boundaries.
If waypoints or user details are missing, returns the original activity
stream unchanged.
Notes:
- Heart rate zones are calculated using the formula: max_heart_rate = 220 - age.
- The function expects waypoints to be a list of dicts with an "hr" key.
@@ -423,20 +430,13 @@ def transform_activity_streams_hr(activity_stream, activity, db):
np.sum((hr_values >= zone_3) & (hr_values < zone_4)),
np.sum(hr_values >= zone_4),
]
zone_percentages = [
round((count / total) * 100, 2) for count in zone_counts
]
zone_percentages = [round((count / total) * 100, 2) for count in zone_counts]
# Calculate time in seconds for each zone using the percentage
# of total_timer_time
has_timer_time = (
hasattr(activity, "total_timer_time")
and activity.total_timer_time
)
# Calculate time in seconds for each zone using the percentage of total_timer_time
has_timer_time = hasattr(activity, "total_timer_time") and activity.total_timer_time
if has_timer_time:
total_time_seconds = activity.total_timer_time
zone_time_seconds = [
int((percent / 100) * total_time_seconds)
int((percent / 100) * float(activity.total_timer_time))
for percent in zone_percentages
]
else:

View File

@@ -195,7 +195,12 @@ def upgrade() -> None:
unique=True,
)
op.create_foreign_key(
None, "users_sessions", "oauth_states", ["oauth_state_id"], ["id"]
"users_sessions_oauth_state_id_fkey",
"users_sessions",
"oauth_states",
["oauth_state_id"],
["id"],
ondelete="SET NULL",
)
# Create rotated_refresh_tokens table
@@ -308,7 +313,9 @@ def downgrade() -> None:
op.drop_index(op.f("ix_mfa_backup_codes_used"), table_name="mfa_backup_codes")
op.drop_index("idx_user_unused_codes", table_name="mfa_backup_codes")
op.drop_table("mfa_backup_codes")
op.drop_constraint(None, "users_sessions", type_="foreignkey")
op.drop_constraint(
"users_sessions_oauth_state_id_fkey", "users_sessions", type_="foreignkey"
)
op.drop_index(
op.f("ix_users_sessions_token_family_id"), table_name="users_sessions"
)

View File

@@ -210,10 +210,15 @@ def mark_tokens_exchanged(session_id: str, db: Session) -> None:
"""
Atomically mark tokens as exchanged for a session to prevent duplicate mobile token exchanges.
This function sets the tokens_exchanged flag to True for a specific session.
This function sets the tokens_exchanged flag to True for a specific session,
clears the oauth_state_id reference, and deletes the associated OAuth state.
Prevents replay attacks where multiple token exchange requests could be made
for the same session.
Per OAuth 2.1 best practices, the OAuth state parameter is ephemeral and should
be deleted immediately after successful token exchange. The session maintains
its own security mechanisms (refresh tokens, CSRF tokens) independently.
Args:
session_id (str): The unique identifier of the session.
db (Session): The SQLAlchemy database session.
@@ -234,9 +239,28 @@ def mark_tokens_exchanged(session_id: str, db: Session) -> None:
if not db_session:
raise SessionNotFoundError(f"Session {session_id} not found")
# Mark tokens as exchanged
# Store oauth_state_id before clearing (for cleanup)
oauth_state_id_to_delete = db_session.oauth_state_id
# Mark tokens as exchanged and clear OAuth state reference
# Per OAuth 2.1: state is ephemeral, only needed during authorization flow
db_session.tokens_exchanged = True
db_session.oauth_state_id = None
db.commit()
# Delete the OAuth state now that tokens are exchanged
# The state has served its CSRF protection purpose
if oauth_state_id_to_delete:
try:
oauth_state_crud.delete_oauth_state(oauth_state_id_to_delete, db)
except Exception as err:
# Log but don't fail - cleanup job will handle orphaned states
core_logger.print_to_log(
f"Failed to delete OAuth state {oauth_state_id_to_delete[:8]}... "
f"after token exchange: {err}",
"warning",
exc=err,
)
except SessionNotFoundError as err:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(err)

View File

@@ -68,7 +68,7 @@ class UsersSessions(Base):
)
oauth_state_id = Column(
String(64),
ForeignKey("oauth_states.id"),
ForeignKey("oauth_states.id", ondelete="SET NULL"),
nullable=True,
index=True,
comment="Link to OAuth state for PKCE validation",

View File

@@ -286,7 +286,7 @@ import BarChartComponent from '@/components/GeneralComponents/BarChartComponent.
// Import Notivue push
import { push } from 'notivue'
// Import the utils
import { getHrBarChartData } from '@/utils/chartUtils'
import { getHrBarChartData, formatHrZoneLabel } from '@/utils/chartUtils'
import {
formatPaceMetric,
formatPaceImperial,

View File

@@ -125,7 +125,8 @@
:barColors="hrChartData.barColors"
:timeSeconds="hrChartData.timeSeconds"
:datalabelsFormatter="
(value, context) => formatHrZoneLabel(value, hrChartData.timeSeconds[context.dataIndex])
(value, context) =>
formatHrZoneLabel(value, hrChartData.timeSeconds[context.dataIndex])
"
:title="$t('activityMandAbovePillsComponent.labelHRZones')"
/>