diff --git a/backend/app/alembic/versions/v0_16_4_migration.py b/backend/app/alembic/versions/v0_16_4_migration.py index 513641557..ca29b998a 100644 --- a/backend/app/alembic/versions/v0_16_4_migration.py +++ b/backend/app/alembic/versions/v0_16_4_migration.py @@ -94,10 +94,12 @@ def upgrade() -> None: sa.ForeignKeyConstraint( ["idp_id"], ["identity_providers.id"], + ondelete="CASCADE", ), sa.ForeignKeyConstraint( ["user_id"], ["users.id"], + ondelete="CASCADE", ), sa.PrimaryKeyConstraint("id"), ) @@ -158,7 +160,13 @@ def upgrade() -> None: ) # Add last_activity_at column with default value = created_at op.add_column( - "users_sessions", sa.Column("last_activity_at", sa.DateTime(), nullable=True) + "users_sessions", + sa.Column( + "last_activity_at", + sa.DateTime(), + nullable=True, + comment="Last activity timestamp for idle timeout", + ), ) # Backfill existing sessions: set last_activity_at = created_at @@ -167,7 +175,13 @@ def upgrade() -> None: ) # Make column non-nullable after backfill - op.alter_column("users_sessions", "last_activity_at", nullable=False) + op.alter_column( + "users_sessions", + "last_activity_at", + nullable=False, + comment="Last activity timestamp for idle timeout", + existing_type=sa.DateTime(), + ) # ### end Alembic commands ### diff --git a/backend/app/auth/identity_providers/models.py b/backend/app/auth/identity_providers/models.py index aae49910e..0f740cbea 100644 --- a/backend/app/auth/identity_providers/models.py +++ b/backend/app/auth/identity_providers/models.py @@ -29,115 +29,103 @@ class IdentityProvider(Base): created_at (datetime): Timestamp when the provider was created. updated_at (datetime): Timestamp when the provider was last updated. user_identity_providers (list[UserIdentityProvider]): Relationship to user identity providers (many-to-many). + oauth_states (list[OAuthState]): Relationship to OAuth states. """ + __tablename__ = "identity_providers" id = Column(Integer, primary_key=True, index=True) - name = Column( - String(length=100), - nullable=False, - comment="Display name of the IdP" - ) + name = Column(String(length=100), nullable=False, comment="Display name of the IdP") slug = Column( String(length=50), nullable=False, unique=True, index=True, - comment="URL-safe identifier" + comment="URL-safe identifier", ) provider_type = Column( String(length=50), nullable=False, default="oidc", - comment="Type: oidc, oauth2, saml" + comment="Type: oidc, oauth2, saml", ) enabled = Column( Boolean, nullable=False, default=False, index=True, - comment="Whether this provider is enabled" + comment="Whether this provider is enabled", ) client_id = Column( - String(length=512), - nullable=True, - comment="OAuth2/OIDC client ID (encrypted)" + String(length=512), nullable=True, comment="OAuth2/OIDC client ID (encrypted)" ) client_secret = Column( String(length=512), nullable=True, - comment="OAuth2/OIDC client secret (encrypted)" + comment="OAuth2/OIDC client secret (encrypted)", ) issuer_url = Column( - String(length=500), - nullable=True, - comment="OIDC issuer/discovery URL" + String(length=500), nullable=True, comment="OIDC issuer/discovery URL" ) authorization_endpoint = Column( - String(length=500), - nullable=True, - comment="OAuth2/OIDC authorization endpoint" + String(length=500), nullable=True, comment="OAuth2/OIDC authorization endpoint" ) token_endpoint = Column( - String(length=500), - nullable=True, - comment="OAuth2/OIDC token endpoint" + String(length=500), nullable=True, comment="OAuth2/OIDC token endpoint" ) userinfo_endpoint = Column( - String(length=500), - nullable=True, - comment="OIDC userinfo endpoint" + String(length=500), nullable=True, comment="OIDC userinfo endpoint" ) jwks_uri = Column( String(length=500), nullable=True, - comment="OIDC JWKS URI for token verification" + comment="OIDC JWKS URI for token verification", ) scopes = Column( String(length=500), nullable=True, default="openid profile email", - comment="OAuth2/OIDC scopes to request" + comment="OAuth2/OIDC scopes to request", ) icon = Column( String(length=100), nullable=True, - comment="Icon name (FontAwesome) or custom URL" + comment="Icon name (FontAwesome) or custom URL", ) auto_create_users = Column( Boolean, nullable=False, default=True, - comment="Automatically create users on first login" + comment="Automatically create users on first login", ) sync_user_info = Column( - Boolean, - nullable=False, - default=True, - comment="Sync user info on each login" + Boolean, nullable=False, default=True, comment="Sync user info on each login" ) user_mapping = Column( - JSON, - nullable=True, - comment="JSON mapping of IdP claims to user fields" + JSON, nullable=True, comment="JSON mapping of IdP claims to user fields" ) created_at = Column( DateTime, nullable=False, server_default=func.now(), - comment="When this provider was created" + comment="When this provider was created", ) updated_at = Column( DateTime, nullable=False, server_default=func.now(), onupdate=func.now(), - comment="When this provider was last updated" + comment="When this provider was last updated", ) # Relationship to user identity providers (many-to-many through junction table) user_identity_providers = relationship( "UserIdentityProvider", back_populates="identity_providers", - cascade="all, delete-orphan" + cascade="all, delete-orphan", + ) + + # Relationship to OAuth states + oauth_states = relationship( + "OAuthState", back_populates="identity_provider", cascade="all, delete-orphan" ) diff --git a/backend/app/auth/oauth_state/models.py b/backend/app/auth/oauth_state/models.py index d60c449cc..aacdbb718 100644 --- a/backend/app/auth/oauth_state/models.py +++ b/backend/app/auth/oauth_state/models.py @@ -26,6 +26,9 @@ class OAuthState(Base): created_at: Timestamp for expiry calculation. expires_at: Hard expiry at 10 minutes. used: Prevents replay attacks. + identity_provider: Relationship to IdentityProvider model. + user: Relationship to User model (nullable). + users_sessions: Relationship to UsersSessions model. """ __tablename__ = "oauth_states" @@ -99,5 +102,7 @@ class OAuthState(Base): comment="True when state is consumed (prevents replay)", ) - # Relationship to UsersSessions for reverse lookup + # Relationships + identity_provider = relationship("IdentityProvider", back_populates="oauth_states") + user = relationship("User", back_populates="oauth_states") users_sessions = relationship("UsersSessions", back_populates="oauth_state") diff --git a/backend/app/users/user/models.py b/backend/app/users/user/models.py index 92c62b4b8..ff5c946db 100644 --- a/backend/app/users/user/models.py +++ b/backend/app/users/user/models.py @@ -51,6 +51,7 @@ class User(Base): notifications: List of notifications for the user. goals: List of user goals. user_identity_providers: List of identity providers linked to the user. + oauth_states: List of OAuth states for the user (link mode). """ __tablename__ = "users" @@ -266,3 +267,10 @@ class User(Base): back_populates="user", cascade="all, delete-orphan", ) + + # Establish a one-to-many relationship with oauth_states + oauth_states = relationship( + "OAuthState", + back_populates="user", + cascade="all, delete-orphan", + )