mirror of
https://github.com/joaovitoriasilva/endurain.git
synced 2026-01-10 08:17:59 -05:00
Refactor auth and identity provider modules
Moved authentication and identity provider logic from 'session' and 'identity_providers' modules into a new 'auth' package. Updated all relevant imports and references throughout the backend to use the new structure. Added language-specific coding standards files for Python and JavaScript/TypeScript in .github/instructions/. Minor bugfixes and code style improvements in Alembic migrations and routers. Profile export and import logic now also includes notifications and user idps info
This commit is contained in:
256
.github/copilot-instructions.md
vendored
256
.github/copilot-instructions.md
vendored
@@ -2,18 +2,18 @@
|
||||
|
||||
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
|
||||
|
||||
**Note:** Language-specific coding standards are in separate instruction files:
|
||||
- **Python/Backend:** `.github/instructions/python.instructions.md`
|
||||
- **TypeScript/JavaScript/Vue/Frontend:** `.github/instructions/javatsscript.instructions.md`
|
||||
|
||||
---
|
||||
|
||||
## AI / Copilot Behavior Guidelines
|
||||
|
||||
- Always follow instructions in this file before inferring new patterns.
|
||||
- Do **not** suggest changes to Dockerfiles or CI/CD workflows unless they clearly violate instructions here.
|
||||
- For Vue components:
|
||||
- Use `<script setup lang="ts">` syntax.
|
||||
- Follow the 10-section component structure.
|
||||
- Prefer utilities from `/utils` and constants from `/constants`.
|
||||
- Do not alter port numbers, environment variables, or framework versions unless explicitly instructed.
|
||||
- Prefer using existing validation, constants, and type files rather than creating new ones.
|
||||
- Prefer using existing patterns, utilities, and files rather than creating new ones.
|
||||
- Use the timing benchmarks in this document to evaluate build success or performance anomalies.
|
||||
- **Documentation files:** When creating new development documentation files (e.g., `BACKEND_AUTH_DEVELOPMENT_LOG.md`, `OBSERVABILITY_STRATEGY.md`), store them in the `devdocs/` folder. This folder is gitignored and used for local development documentation that should not be committed to the repository.
|
||||
- **Development/helper scripts:** When creating new development/helper scripts, store them in the `devscripts/` folder. This folder is gitignored and used for local development scripts that should not be committed to the repository.
|
||||
@@ -25,66 +25,74 @@ Always reference these instructions first and fallback to search or bash command
|
||||
|
||||
---
|
||||
|
||||
## Working Effectively
|
||||
## Project Overview
|
||||
|
||||
Endurain is a self-hosted fitness tracking application built with Vue.js frontend, Python FastAPI backend, and PostgreSQL database. The primary development workflow uses Docker, but frontend-only development is supported for faster UI iteration.
|
||||
Endurain is a self-hosted fitness tracking application with:
|
||||
- **Frontend:** Vue.js 3 + TypeScript + Vite + Bootstrap 5
|
||||
- **Backend:** Python 3.13 + FastAPI + SQLAlchemy + Alembic
|
||||
- **Database:** PostgreSQL
|
||||
- **Integrations:** Strava, Garmin Connect
|
||||
- **File Import Support:** .gpx, .tcx, .fit, .gz
|
||||
- **Authentication:** JWT with 15-minute access tokens, 7 days refresh tokens
|
||||
- **Deployment:** Docker multi-stage builds, multi-architecture images (amd64, arm64)
|
||||
|
||||
### Prerequisites and Environment Setup
|
||||
---
|
||||
|
||||
- **Node.js:** v20.19.4 (for frontend)
|
||||
- **Python:** v3.13 (backend)
|
||||
- **Docker:** required for full-stack development and CI/CD builds
|
||||
- **Poetry:** for backend dependency management (when not using Docker)
|
||||
## Development Workflows
|
||||
|
||||
### Quick Start Development Setup
|
||||
### Prerequisites
|
||||
- **Node.js:** v20.19.4 (for frontend development)
|
||||
- **Python:** v3.13 (for backend development)
|
||||
- **Docker:** Required for full-stack development and CI/CD builds
|
||||
- **Poetry:** For backend dependency management (when not using Docker)
|
||||
|
||||
### Quick Start
|
||||
|
||||
1. Clone repository: `git clone https://github.com/joaovitoriasilva/endurain.git`
|
||||
2. Navigate to the project root
|
||||
3. Choose development approach:
|
||||
- **Frontend Only** – see _Frontend Development_
|
||||
- **Full Stack** – use _Docker Development Setup_
|
||||
- **Frontend Only** – see _Frontend Development_ below
|
||||
- **Full Stack** – use _Docker Development Setup_ below
|
||||
|
||||
### Frontend Development (Recommended for UI changes)
|
||||
|
||||
Fast iteration workflow for frontend-only development:
|
||||
|
||||
- Navigate: `cd frontend/app`
|
||||
- Install dependencies: `npm install` (≈20 seconds)
|
||||
- Start dev server: `npm run dev` (port 5173 or 5174 if occupied)
|
||||
- Build frontend: `npm run build` (≈9 seconds)
|
||||
- Format code: `npm run format` (≈5 seconds)
|
||||
- **Note:** ESLint configuration pending migration to flat config format (lint fails currently)
|
||||
- **Note:** Unit tests not yet implemented (`npm run test:unit` exits with “No test files found”)
|
||||
|
||||
**Notes:**
|
||||
- ESLint configuration pending migration to flat config format (lint fails currently)
|
||||
- Unit tests not yet implemented (`npm run test:unit` exits with "No test files found")
|
||||
|
||||
**Pre-commit validation:**
|
||||
- Run `npm run format` before commits
|
||||
- Confirm successful `npm run build`
|
||||
- Ensure `npm run dev` runs without warnings/errors
|
||||
|
||||
### Docker Development (Full Stack)
|
||||
|
||||
Complete environment for frontend + backend + database:
|
||||
|
||||
- Build unified image: `docker build -f docker/Dockerfile -t unified-image .`
|
||||
- **Caution:** Docker builds may take 15–20 minutes. Avoid canceling unless hung for 30+ minutes.
|
||||
- **CI Caveat:** SSL certificate errors can occur during CI builds; document but don’t bypass validation.
|
||||
- Create docker-compose.yml from the provided example.
|
||||
- **CI Caveat:** SSL certificate errors can occur during CI builds; document but don't bypass validation.
|
||||
- Create `docker-compose.yml` from the provided example
|
||||
- Start services: `docker compose up -d`
|
||||
- Stop services: `docker compose down`
|
||||
|
||||
### Backend Development (Advanced)
|
||||
|
||||
- Python 3.13 backend managed by Poetry
|
||||
- Codebase in `backend/` with `pyproject.toml`
|
||||
- Use Docker if system Python < 3.13
|
||||
Python development without Docker (requires Python 3.13):
|
||||
|
||||
- Navigate: `cd backend`
|
||||
- Install Poetry: `pip install poetry`
|
||||
- Install dependencies: `poetry install`
|
||||
|
||||
#### Backend Code Quality Standards
|
||||
|
||||
- **Modern Python Syntax (Python 3.13):**
|
||||
- Use union types with `|` operator: `str | None` instead of `Optional[str]`
|
||||
- Use built-in generics: `list[str]` instead of `List[str]`, `dict[str, int]` instead of `Dict[str, int]`
|
||||
- No imports from `typing` module for `Optional`, `List`, `Dict`, `Tuple`, `Set` – use native syntax
|
||||
- Example: `def get_user(user_id: int) -> User | None:`
|
||||
- Example: `async def get_activities(limit: int = 10) -> list[Activity]:`
|
||||
- Use type hints (`def foo(x: int) -> str:`)
|
||||
- Standard FastAPI project layout: `routers/`, `schemas/`, `services/`, `models/`
|
||||
- Public functions/classes must have docstrings
|
||||
- Core logic should live in `services/`, not routers
|
||||
- Format with `black` and `isort`
|
||||
- Include at least one unit test per router or service
|
||||
- Backend codebase in `backend/app/` with `pyproject.toml`
|
||||
- **Use Docker if system Python < 3.13**
|
||||
|
||||
---
|
||||
|
||||
@@ -102,12 +110,6 @@ Endurain is a self-hosted fitness tracking application built with Vue.js fronten
|
||||
- Document SSL issues but complete functional validation
|
||||
- Ensure built container runs successfully (even if CI SSL fails)
|
||||
|
||||
### Pre-commit Validation
|
||||
|
||||
- Run `npm run format` before commits
|
||||
- Confirm successful `npm run build`
|
||||
- Ensure `npm run dev` runs without warnings/errors
|
||||
|
||||
---
|
||||
|
||||
## Common Tasks
|
||||
@@ -163,172 +165,4 @@ Repository root:
|
||||
- `release.yml`: publishes Docker images to GitHub Container Registry
|
||||
- Use workflow dispatch for manual triggers
|
||||
|
||||
---
|
||||
|
||||
## Code Quality Standards (10/10 Quality)
|
||||
|
||||
### Frontend Component Standards
|
||||
|
||||
#### TypeScript
|
||||
|
||||
- Always use `<script setup lang="ts">`
|
||||
- **Modern Type Inference:**
|
||||
- Use `ref<T>()` with generic parameter: `const user = ref<User | null>(null)`
|
||||
- **Avoid** redundant `Ref<T>` annotations: ~~`const user: Ref<User | null> = ref(null)`~~
|
||||
- Let TypeScript infer types when obvious: `const count = ref(0)` (infers `Ref<number>`)
|
||||
- Let `computed()` infer return types from callback
|
||||
- **Avoid** redundant `ComputedRef<T>` annotations
|
||||
- Explicit typing for function parameters and return types
|
||||
- Type imports for complex types: `Router`, `RouteLocationNormalizedLoaded`
|
||||
- No implicit `any` types
|
||||
- Centralized imports from `/types/index.ts`
|
||||
|
||||
#### Documentation
|
||||
|
||||
- Each component must include a clear, purposeful JSDoc overview
|
||||
- Document complex logic; skip redundant auto-generated comments
|
||||
|
||||
#### Component Structure (10 Sections)
|
||||
|
||||
1. Fileoverview JSDoc
|
||||
2. Imports
|
||||
3. Composables & Stores
|
||||
4. Reactive State
|
||||
5. Computed Properties
|
||||
6. UI Interaction Handlers
|
||||
7. Validation Logic
|
||||
8. Main Logic
|
||||
9. Lifecycle Hooks
|
||||
10. Component Definition
|
||||
|
||||
#### Centralized Architecture
|
||||
|
||||
- **Validation utilities:** `/utils/validationUtils.ts`
|
||||
- `isValidPassword()`, `passwordsMatch()`, `isValidEmail()`, `sanitizeInput()`
|
||||
- Password strength analysis functions
|
||||
- **Constants:** `/constants/httpConstants.ts`
|
||||
- `HTTP_STATUS` enum for status codes
|
||||
- `extractStatusCode()` for error response parsing
|
||||
- `QUERY_PARAM_TRUE` for URL parameters
|
||||
- **Type definitions:** `/types/index.ts`
|
||||
- `ErrorWithResponse`, `NotificationType`, `ActionButtonType`
|
||||
- **Bootstrap modals:** `/composables/useBootstrapModal.ts`
|
||||
- Modal lifecycle management
|
||||
|
||||
#### UI/UX Standards
|
||||
|
||||
- Use Bootstrap 5 `form-floating` classes
|
||||
- Accessibility:
|
||||
- All interactive elements must have `aria-label`
|
||||
- Use `aria-live="polite"` for validation messages
|
||||
- Ensure full keyboard navigation
|
||||
- Responsive across mobile, tablet, desktop
|
||||
- Always include loading states and graceful error handling
|
||||
|
||||
#### Accessibility Testing Checklist
|
||||
|
||||
- Verify tab navigation for all forms
|
||||
- Check color contrast meets WCAG AA
|
||||
- Validate `aria-label` coverage
|
||||
- Confirm focus outlines visible and consistent
|
||||
- Test screen reader compatibility (NVDA/VoiceOver)
|
||||
|
||||
#### Reference Implementations (10/10 Quality)
|
||||
|
||||
Study these files as templates when creating/refactoring components:
|
||||
|
||||
- **LoginView.vue** (437 lines) - Authentication with MFA support
|
||||
- **SignUpView.vue** (611 lines) - Registration with optional fields
|
||||
- **ResetPasswordView.vue** (~320 lines) - Password reset with token validation
|
||||
- **ModalComponentEmailInput.vue** - RFC 5322 email validation
|
||||
|
||||
### Backend Component Standards - Python Code Style Requirements
|
||||
|
||||
#### Modern Python Syntax (Python 3.13+)
|
||||
- Use modern type hint syntax: `int | None`, `list[str]`,
|
||||
`dict[str, Any]`
|
||||
- Do NOT use `typing.Optional`, `typing.List`, `typing.Dict`, etc.
|
||||
- Target Python 3.13+ features and syntax
|
||||
|
||||
#### PEP 8 Line Limits
|
||||
- Code lines: **79 characters maximum**
|
||||
- Comments and docstrings: **72 characters maximum**
|
||||
- Enforce strictly - no exceptions
|
||||
|
||||
#### Docstring Standard (PEP 257)
|
||||
- **Always follow PEP 257** with Args/Returns/Raises sections
|
||||
- **Format**: One-line summary, blank line, then
|
||||
Args/Returns/Raises sections
|
||||
- **Always include Args/Returns/Raises** even when parameters seem
|
||||
obvious
|
||||
- **NO examples** in docstrings - keep in external docs or tests
|
||||
- **NO extended explanations** - one-line summary + sections only
|
||||
- **Keep concise** - describe what, not how
|
||||
|
||||
**Function docstring format:**
|
||||
```python
|
||||
def function(param: str) -> int:
|
||||
"""
|
||||
One-line summary of what this does.
|
||||
|
||||
Args:
|
||||
param: Description of param.
|
||||
|
||||
Returns:
|
||||
Description of return value.
|
||||
|
||||
Raises:
|
||||
ValueError: When param is invalid.
|
||||
"""
|
||||
```
|
||||
|
||||
**Class docstring format:**
|
||||
```python
|
||||
class MyClass:
|
||||
"""
|
||||
One-line summary of the class.
|
||||
|
||||
Attributes:
|
||||
attr: Description of attribute.
|
||||
"""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Code Review Checklist
|
||||
|
||||
- ✅ TypeScript with explicit types
|
||||
- ✅ Meaningful JSDoc documentation
|
||||
- ✅ 10-section component structure
|
||||
- ✅ Centralized use of validation/constants/types
|
||||
- ✅ Bootstrap 5 form-floating usage
|
||||
- ✅ Accessibility attributes verified
|
||||
- ✅ Code formatted successfully (`npm run format`)
|
||||
- ✅ Build passes (`npm run build`)
|
||||
- ✅ Manual browser validation complete
|
||||
- ✅ No console warnings/errors
|
||||
|
||||
---
|
||||
|
||||
## Development Workflow
|
||||
|
||||
1. Edit code in `frontend/app/src/`
|
||||
2. Run `npm run dev` to test locally
|
||||
3. Run `npm run format` before commit
|
||||
4. Build production output with `npm run build`
|
||||
5. For backend changes, use Docker setup
|
||||
6. Validate full UI flow manually
|
||||
|
||||
---
|
||||
|
||||
## Architecture Notes
|
||||
|
||||
- **Frontend:** Vue.js 3, Vite, Bootstrap 5, Chart.js, Leaflet
|
||||
- **Backend:** FastAPI, SQLAlchemy, Alembic
|
||||
- **Database:** PostgreSQL
|
||||
- **Integrations:** Strava, Garmin Connect
|
||||
- **File Imports:** .gpx, .tcx, .fit
|
||||
- **Auth:** JWT with 15-minute access tokens
|
||||
- **Deployment:** Docker multi-stage builds, multi-architecture images
|
||||
|
||||
---
|
||||
130
.github/instructions/javatsscript.instructions.md
vendored
Normal file
130
.github/instructions/javatsscript.instructions.md
vendored
Normal file
@@ -0,0 +1,130 @@
|
||||
---
|
||||
applyTo: '**/*.ts,**/*.js,**/*.vue'
|
||||
---
|
||||
# Project Context
|
||||
- **Framework:** Vue.js 3 with TypeScript support
|
||||
- **Build Tool:** Vite
|
||||
- **UI Framework:** Bootstrap 5
|
||||
- **Additional Libraries:** Chart.js, Leaflet
|
||||
- **Project Structure:** All frontend code in `frontend/app/src/`
|
||||
|
||||
# Development Setup
|
||||
- **Navigate:** `cd frontend/app`
|
||||
- **Install dependencies:** `npm install` (≈20 seconds)
|
||||
- **Dev server:** `npm run dev` (port 5173 or 5174 if occupied)
|
||||
- **Build:** `npm run build` (≈9 seconds)
|
||||
- **Format:** `npm run format` (≈5 seconds)
|
||||
|
||||
# TypeScript Standards
|
||||
|
||||
## Modern Type Inference
|
||||
- Always use `<script setup lang="ts">` syntax
|
||||
- Use `ref<T>()` with generic parameter:
|
||||
`const user = ref<User | null>(null)`
|
||||
- **AVOID** redundant `Ref<T>` annotations:
|
||||
~~`const user: Ref<User | null> = ref(null)`~~
|
||||
- Let TypeScript infer types when obvious:
|
||||
`const count = ref(0)` (infers `Ref<number>`)
|
||||
- Let `computed()` infer return types from callback
|
||||
- **AVOID** redundant `ComputedRef<T>` annotations
|
||||
|
||||
## Type Safety
|
||||
- Explicit typing for function parameters and return types
|
||||
- Type imports for complex types: `Router`,
|
||||
`RouteLocationNormalizedLoaded`
|
||||
- No implicit `any` types
|
||||
- Centralized imports from `/types/index.ts`
|
||||
|
||||
# Component Structure (10 Sections)
|
||||
|
||||
All Vue components must follow this exact structure:
|
||||
|
||||
1. **Fileoverview JSDoc** - Clear, purposeful component
|
||||
description
|
||||
2. **Imports** - All dependencies
|
||||
3. **Composables & Stores** - Router, stores, composables
|
||||
4. **Reactive State** - ref() declarations
|
||||
5. **Computed Properties** - computed() declarations
|
||||
6. **UI Interaction Handlers** - Button clicks, form events
|
||||
7. **Validation Logic** - Form validation functions
|
||||
8. **Main Logic** - Core business logic
|
||||
9. **Lifecycle Hooks** - onMounted, onBeforeUnmount, etc.
|
||||
10. **Component Definition** - defineExpose if needed
|
||||
|
||||
# Documentation Standards
|
||||
- Each component must include clear, purposeful JSDoc overview
|
||||
- Document complex logic
|
||||
- Skip redundant auto-generated comments
|
||||
- Focus on "why" not "what" for non-obvious code
|
||||
|
||||
# Centralized Architecture
|
||||
|
||||
## Validation Utilities (`/utils/validationUtils.ts`)
|
||||
- `isValidPassword()` - Password validation
|
||||
- `passwordsMatch()` - Password confirmation
|
||||
- `isValidEmail()` - RFC 5322 email validation
|
||||
- `sanitizeInput()` - Input sanitization
|
||||
- Password strength analysis functions
|
||||
|
||||
## Constants (`/constants/httpConstants.ts`)
|
||||
- `HTTP_STATUS` enum for HTTP status codes
|
||||
- `extractStatusCode()` for error response parsing
|
||||
- `QUERY_PARAM_TRUE` for URL parameters
|
||||
|
||||
## Type Definitions (`/types/index.ts`)
|
||||
- `ErrorWithResponse` - Error handling type
|
||||
- `NotificationType` - Notification types
|
||||
- `ActionButtonType` - Button action types
|
||||
|
||||
## Bootstrap Modals (`/composables/useBootstrapModal.ts`)
|
||||
- Modal lifecycle management
|
||||
- Centralized modal control
|
||||
|
||||
# UI/UX Standards
|
||||
|
||||
## Bootstrap 5
|
||||
- Use `form-floating` classes for all form inputs
|
||||
- Follow Bootstrap 5 component patterns
|
||||
- Maintain consistent spacing and layout
|
||||
|
||||
## Accessibility Requirements
|
||||
- **ARIA labels:** All interactive elements must have
|
||||
`aria-label`
|
||||
- **Live regions:** Use `aria-live="polite"` for validation
|
||||
messages
|
||||
- **Keyboard navigation:** Ensure full keyboard navigation
|
||||
- **Focus management:** Visible and consistent focus outlines
|
||||
- **Screen readers:** Test with NVDA/VoiceOver
|
||||
|
||||
## Responsive Design
|
||||
- Support mobile, tablet, and desktop viewports
|
||||
- Test across different screen sizes
|
||||
- Use Bootstrap responsive utilities
|
||||
|
||||
## User Feedback
|
||||
- Always include loading states for async operations
|
||||
- Graceful error handling with user-friendly messages
|
||||
- Clear validation feedback
|
||||
- Appropriate use of notifications
|
||||
|
||||
# Reference Implementations (10/10 Quality)
|
||||
|
||||
Study these files as templates for new components:
|
||||
|
||||
- **`LoginView.vue`** (437 lines) - Authentication with MFA
|
||||
- **`SignUpView.vue`** (611 lines) - Registration with optional
|
||||
fields
|
||||
- **`ResetPasswordView.vue`** (~320 lines) - Password reset with
|
||||
token validation
|
||||
- **`ModalComponentEmailInput.vue`** - RFC 5322 email validation
|
||||
|
||||
# Pre-commit Checklist
|
||||
- ✅ Run `npm run format` before commits
|
||||
- ✅ Confirm `npm run build` succeeds
|
||||
- ✅ Ensure `npm run dev` runs without warnings/errors
|
||||
- ✅ TypeScript types correct (no implicit any)
|
||||
- ✅ Accessibility attributes verified
|
||||
- ✅ Component follows 10-section structure
|
||||
- ✅ Uses centralized utilities/constants/types
|
||||
- ✅ Bootstrap 5 classes applied correctly
|
||||
- ✅ Manual browser validation complete
|
||||
65
.github/instructions/python.instructions.md
vendored
Normal file
65
.github/instructions/python.instructions.md
vendored
Normal file
@@ -0,0 +1,65 @@
|
||||
---
|
||||
applyTo: '**/*.py'
|
||||
---
|
||||
# Project Context
|
||||
- **Python Version:** 3.13+ (required)
|
||||
- **Framework:** FastAPI with SQLAlchemy ORM and Alembic migrations
|
||||
- **Dependency Management:** Poetry (see `backend/pyproject.toml`)
|
||||
- **Project Structure:** All backend code in `backend/app/`
|
||||
|
||||
# Development Setup
|
||||
- **Install Poetry:** `pip install poetry`
|
||||
- **Install dependencies:** `poetry install` (in `backend/`
|
||||
directory)
|
||||
- **Use Docker:** If system Python < 3.13, use Docker for
|
||||
development
|
||||
|
||||
# Modern Python Syntax (Python 3.13+)
|
||||
- Use modern type hint syntax: `int | None`, `list[str]`,
|
||||
`dict[str, Any]`
|
||||
- Do NOT use `typing.Optional`, `typing.List`, `typing.Dict`, etc.
|
||||
- Target Python 3.13+ features and syntax
|
||||
- Always prioritize readability and clarity
|
||||
|
||||
# PEP 8 Line Limits
|
||||
- Code lines: **79 characters maximum**
|
||||
- Comments and docstrings: **72 characters maximum**
|
||||
- Enforce strictly - no exceptions
|
||||
|
||||
# Docstring Standard (PEP 257)
|
||||
- **Always follow PEP 257** with Args/Returns/Raises sections
|
||||
- **Format**: One-line summary, blank line, then
|
||||
Args/Returns/Raises sections
|
||||
- **Always include Args/Returns/Raises** even when parameters seem
|
||||
obvious
|
||||
- **NO examples** in docstrings - keep in external docs or tests
|
||||
- **NO extended explanations** - one-line summary + sections only
|
||||
- **Keep concise** - describe what, not how
|
||||
|
||||
**Function docstring format:**
|
||||
```python
|
||||
def function(param: str) -> int:
|
||||
"""
|
||||
One-line summary of what this does.
|
||||
|
||||
Args:
|
||||
param: Description of param.
|
||||
|
||||
Returns:
|
||||
Description of return value.
|
||||
|
||||
Raises:
|
||||
ValueError: When param is invalid.
|
||||
"""
|
||||
```
|
||||
|
||||
**Class docstring format:**
|
||||
```python
|
||||
class MyClass:
|
||||
"""
|
||||
One-line summary of the class.
|
||||
|
||||
Attributes:
|
||||
attr: Description of attribute.
|
||||
"""
|
||||
```
|
||||
@@ -13,7 +13,7 @@ import core.dependencies as core_dependencies
|
||||
import core.logger as core_logger
|
||||
import core.config as core_config
|
||||
import gears.gear.dependencies as gears_dependencies
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
import users.user.dependencies as users_dependencies
|
||||
import garmin.activity_utils as garmin_activity_utils
|
||||
import strava.activity_utils as strava_activity_utils
|
||||
@@ -48,11 +48,11 @@ async def read_activities_user_activities_week(
|
||||
Callable, Depends(activities_dependencies.validate_week_number)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -94,11 +94,11 @@ async def read_activities_user_activities_this_week_distances(
|
||||
Callable, Depends(users_dependencies.validate_user_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -135,11 +135,11 @@ async def read_activities_user_activities_this_month_distances(
|
||||
Callable, Depends(users_dependencies.validate_user_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -182,11 +182,11 @@ async def read_activities_user_activities_this_month_number(
|
||||
Callable, Depends(users_dependencies.validate_user_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
Callable,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -229,11 +229,11 @@ async def read_activities_gear_activities(
|
||||
Callable, Depends(gears_dependencies.validate_gear_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
Callable,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -256,11 +256,11 @@ async def read_activities_gear_activities_number(
|
||||
Callable, Depends(gears_dependencies.validate_gear_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
Callable,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -291,11 +291,11 @@ async def read_activities_gear_activities_with_pagination(
|
||||
Callable, Depends(core_dependencies.validate_pagination_values)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
Callable,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -314,11 +314,11 @@ async def read_activities_gear_activities_with_pagination(
|
||||
)
|
||||
async def read_activities_user_activities_number(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -358,11 +358,11 @@ async def read_activities_user_activities_number(
|
||||
)
|
||||
async def read_activities_types(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -387,11 +387,11 @@ async def read_activities_user_activities_pagination(
|
||||
Callable, Depends(core_dependencies.validate_pagination_values)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -450,7 +450,7 @@ async def read_activities_followed_user_activities_pagination(
|
||||
Callable, Depends(core_dependencies.validate_pagination_values)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -473,7 +473,7 @@ async def read_activities_followed_user_activities_number(
|
||||
Callable, Depends(users_dependencies.validate_user_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -497,11 +497,11 @@ async def read_activities_followed_user_activities_number(
|
||||
)
|
||||
async def read_activities_user_activities_refresh(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -559,11 +559,11 @@ async def read_activities_activity_from_id(
|
||||
Callable, Depends(activities_dependencies.validate_activity_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -583,11 +583,11 @@ async def read_activities_activity_from_id(
|
||||
async def read_activities_contain_name(
|
||||
name: str,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -606,11 +606,11 @@ async def read_activities_contain_name(
|
||||
async def create_activity_with_uploaded_file(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
file: UploadFile,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
@@ -642,10 +642,10 @@ async def create_activity_with_uploaded_file(
|
||||
async def create_activity_with_bulk_import(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -721,11 +721,11 @@ async def create_activity_with_bulk_import(
|
||||
async def edit_activity(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
activity_attributes: activities_schema.ActivityEdit,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -749,10 +749,10 @@ async def edit_activity_visibility(
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -775,11 +775,11 @@ async def delete_activity(
|
||||
Callable, Depends(activities_dependencies.validate_activity_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
import activities.activity_exercise_titles.schema as activity_exercise_titles_schema
|
||||
import activities.activity_exercise_titles.crud as activity_exercise_titles_crud
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
@@ -20,7 +20,7 @@ router = APIRouter()
|
||||
)
|
||||
async def read_activities_exercise_titles_all(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -28,4 +28,4 @@ async def read_activities_exercise_titles_all(
|
||||
],
|
||||
):
|
||||
# Get the exercise titles from the database and return them
|
||||
return activity_exercise_titles_crud.get_activity_exercise_titles(db)
|
||||
return activity_exercise_titles_crud.get_activity_exercise_titles(db)
|
||||
|
||||
@@ -8,7 +8,7 @@ import activities.activity_laps.crud as activity_laps_crud
|
||||
|
||||
import activities.activity.dependencies as activities_dependencies
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
@@ -26,11 +26,11 @@ async def read_activities_laps_for_activity_all(
|
||||
Callable, Depends(activities_dependencies.validate_activity_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -38,4 +38,4 @@ async def read_activities_laps_for_activity_all(
|
||||
],
|
||||
):
|
||||
# Get the activity laps from the database and return them
|
||||
return activity_laps_crud.get_activity_laps(activity_id, token_user_id, db)
|
||||
return activity_laps_crud.get_activity_laps(activity_id, token_user_id, db)
|
||||
|
||||
@@ -11,7 +11,7 @@ import activities.activity_media.dependencies as activities_media_dependencies
|
||||
import activities.activity_media.crud as activity_media_crud
|
||||
import activities.activity_media.schema as activity_media_schema
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.config as core_config
|
||||
import core.logger as core_logger
|
||||
@@ -32,11 +32,11 @@ async def read_activities_media_user(
|
||||
Callable, Depends(activities_dependencies.validate_activity_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -58,7 +58,7 @@ async def upload_media(
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["activities:write"]),
|
||||
Security(auth_security.check_scopes, scopes=["activities:write"]),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -88,7 +88,7 @@ async def upload_media(
|
||||
|
||||
# Raise an HTTPException with a 500 Internal Server Error status code
|
||||
raise err
|
||||
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{media_id}",
|
||||
@@ -99,11 +99,11 @@ async def delete_activity_media(
|
||||
Callable, Depends(activities_media_dependencies.validate_media_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -8,7 +8,7 @@ import activities.activity_sets.crud as activity_sets_crud
|
||||
|
||||
import activities.activity.dependencies as activities_dependencies
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
@@ -26,11 +26,11 @@ async def read_activities_sets_for_activity_all(
|
||||
Callable, Depends(activities_dependencies.validate_activity_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -38,4 +38,4 @@ async def read_activities_sets_for_activity_all(
|
||||
],
|
||||
):
|
||||
# Get the activity sets from the database and return them
|
||||
return activity_sets_crud.get_activity_sets(activity_id, token_user_id, db)
|
||||
return activity_sets_crud.get_activity_sets(activity_id, token_user_id, db)
|
||||
|
||||
@@ -9,7 +9,7 @@ import activities.activity_streams.dependencies as activity_streams_dependencies
|
||||
|
||||
import activities.activity.dependencies as activities_dependencies
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
@@ -27,11 +27,11 @@ async def read_activities_streams_for_activity_all(
|
||||
Callable, Depends(activities_dependencies.validate_activity_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -56,11 +56,11 @@ async def read_activities_streams_for_activity_stream_type(
|
||||
Callable, Depends(activity_streams_dependencies.validate_activity_stream_type)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Callable, Union
|
||||
from datetime import date, datetime, timezone
|
||||
|
||||
import core.database as core_database
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
import activities.activity.dependencies as activities_dependencies
|
||||
import activities.activity_summaries.crud as activities_summary_crud
|
||||
import activities.activity_summaries.schema as activities_summary_schema
|
||||
@@ -28,11 +28,11 @@ async def read_activity_summary(
|
||||
Callable, Depends(activities_summary_dependencies.validate_view_type)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -8,7 +8,7 @@ import activities.activity_workout_steps.crud as activity_workout_steps_crud
|
||||
|
||||
import activities.activity.dependencies as activities_dependencies
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
@@ -26,11 +26,11 @@ async def read_activities_workout_steps_for_activity_all(
|
||||
Callable, Depends(activities_dependencies.validate_activity_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["activities:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["activities:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -18,7 +18,7 @@ import gears.gear.models
|
||||
import gears.gear_components.models
|
||||
import health_data.models
|
||||
import health_targets.models
|
||||
import identity_providers.models
|
||||
import auth.identity_providers.models
|
||||
import migrations.models
|
||||
import notifications.models
|
||||
import password_reset_tokens.models
|
||||
@@ -41,7 +41,7 @@ config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.attributes.get('configure_logger', True):
|
||||
if config.attributes.get("configure_logger", True):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
|
||||
@@ -275,22 +275,24 @@ def upgrade() -> None:
|
||||
unique=False,
|
||||
)
|
||||
# Add the new entry to the migrations table
|
||||
op.execute("""
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO migrations (id, name, description, executed) VALUES
|
||||
(6, 'v0.15.0', 'Lowercase user usernames', false);
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the entry from the migrations table
|
||||
op.execute("""
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM migrations
|
||||
WHERE id = 6;
|
||||
""")
|
||||
# Drop sign up tokens table
|
||||
op.drop_index(
|
||||
op.f("ix_sign_up_tokens_user_id"), table_name="sign_up_tokens"
|
||||
"""
|
||||
)
|
||||
# Drop sign up tokens table
|
||||
op.drop_index(op.f("ix_sign_up_tokens_user_id"), table_name="sign_up_tokens")
|
||||
op.drop_table("sign_up_tokens")
|
||||
# Remove columns from gear_components table
|
||||
op.add_column(
|
||||
|
||||
@@ -190,7 +190,7 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"idp_refresh_token",
|
||||
sa.String(length=2000),
|
||||
nullable=False,
|
||||
nullable=True,
|
||||
comment="Encrypted refresh token",
|
||||
),
|
||||
sa.Column(
|
||||
|
||||
0
backend/app/auth/__init__.py
Normal file
0
backend/app/auth/__init__.py
Normal file
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
from typing import Final, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Final
|
||||
|
||||
import core.config as core_config
|
||||
|
||||
@@ -44,27 +43,25 @@ SERVER_SETTINGS_ADMIN_SCOPE: Final[tuple[str, ...]] = (
|
||||
"server_settings:write",
|
||||
)
|
||||
|
||||
SCOPE_DICT: Final[Mapping[str, str]] = MappingProxyType(
|
||||
{
|
||||
"profile": "Privileges over user's own profile",
|
||||
"users:read": "Read privileges over users",
|
||||
"users:write": "Write privileges over users",
|
||||
"sessions:read": "Read privileges over sessions",
|
||||
"sessions:write": "Create/edit/delete privileges over sessions",
|
||||
"gears:read": "Read privileges over gears",
|
||||
"gears:write": "Write privileges over gears",
|
||||
"activities:read": "Read privileges over activities",
|
||||
"activities:write": "Write privileges over activities",
|
||||
"health:read": "Read privileges over health data",
|
||||
"health:write": "Write privileges over health data",
|
||||
"health_targets:read": "Read privileges over health targets data",
|
||||
"health_targets:write": "Write privileges over health targets data",
|
||||
"server_settings:read": "Read privileges over server settings",
|
||||
"server_settings:write": "Write privileges over server settings",
|
||||
"idp:read": "Read privileges over identity providers",
|
||||
"idp:write": "Write privileges over identity providers",
|
||||
}
|
||||
)
|
||||
SCOPE_DICT: Final[dict[str, str]] = {
|
||||
"profile": "Privileges over user's own profile",
|
||||
"users:read": "Read privileges over users",
|
||||
"users:write": "Write privileges over users",
|
||||
"sessions:read": "Read privileges over sessions",
|
||||
"sessions:write": "Create/edit/delete privileges over sessions",
|
||||
"gears:read": "Read privileges over gears",
|
||||
"gears:write": "Write privileges over gears",
|
||||
"activities:read": "Read privileges over activities",
|
||||
"activities:write": "Write privileges over activities",
|
||||
"health:read": "Read privileges over health data",
|
||||
"health:write": "Write privileges over health data",
|
||||
"health_targets:read": "Read privileges over health targets data",
|
||||
"health_targets:write": "Write privileges over health targets data",
|
||||
"server_settings:read": "Read privileges over server settings",
|
||||
"server_settings:write": "Write privileges over server settings",
|
||||
"idp:read": "Read privileges over identity providers",
|
||||
"idp:write": "Write privileges over identity providers",
|
||||
}
|
||||
|
||||
REGULAR_ACCESS_SCOPE: Final[tuple[str, ...]] = (
|
||||
USERS_REGULAR_SCOPE
|
||||
@@ -4,8 +4,8 @@ from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
import identity_providers.models as idp_models
|
||||
import identity_providers.schema as idp_schema
|
||||
import auth.identity_providers.models as idp_models
|
||||
import auth.identity_providers.schema as idp_schema
|
||||
import core.cryptography as core_cryptography
|
||||
import core.logger as core_logger
|
||||
import users.user_identity_providers.crud as user_identity_providers_crud
|
||||
@@ -355,7 +355,11 @@ def delete_identity_provider(idp_id: int, db: Session) -> None:
|
||||
)
|
||||
|
||||
# Check if any users are linked to this provider
|
||||
db_user_idp = user_identity_providers_crud.get_idp_has_user_links(idp_id, db)
|
||||
db_user_idp = (
|
||||
user_identity_providers_crud.check_user_identity_providers_by_idp_id(
|
||||
idp_id, db
|
||||
)
|
||||
)
|
||||
if db_user_idp:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -6,13 +6,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
import core.database as core_database
|
||||
import core.rate_limit as core_rate_limit
|
||||
import session.password_hasher as session_password_hasher
|
||||
import session.token_manager as session_token_manager
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
import auth.token_manager as auth_token_manager
|
||||
import auth.utils as auth_utils
|
||||
import session.utils as session_utils
|
||||
import session.crud as session_crud
|
||||
import identity_providers.crud as idp_crud
|
||||
import identity_providers.schema as idp_schema
|
||||
import identity_providers.service as idp_service
|
||||
import auth.identity_providers.crud as idp_crud
|
||||
import auth.identity_providers.schema as idp_schema
|
||||
import auth.identity_providers.service as idp_service
|
||||
import users.user.schema as users_schema
|
||||
import core.config as core_config
|
||||
import core.logger as core_logger
|
||||
@@ -58,7 +58,7 @@ async def initiate_login(
|
||||
):
|
||||
"""
|
||||
Initiates the login process for a given identity provider using OAuth.
|
||||
|
||||
|
||||
Rate Limit: 10 requests per minute per IP
|
||||
Args:
|
||||
idp_slug (str): The slug identifier for the identity provider.
|
||||
@@ -88,20 +88,20 @@ async def initiate_login(
|
||||
@router.get("/callback/{idp_slug}", status_code=status.HTTP_307_TEMPORARY_REDIRECT)
|
||||
@core_rate_limit.limiter.limit(core_rate_limit.OAUTH_CALLBACK_LIMIT)
|
||||
async def handle_callback(
|
||||
request: Request,
|
||||
response: Response,
|
||||
idp_slug: str,
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
code: str = Query(..., description="Authorization code from IdP"),
|
||||
state: str = Query(..., description="State parameter for CSRF protection"),
|
||||
request: Request = None,
|
||||
response: Response = None,
|
||||
):
|
||||
"""
|
||||
Handle OAuth callback from an identity provider.
|
||||
@@ -109,13 +109,13 @@ async def handle_callback(
|
||||
It supports two modes: login mode (default) and link mode (for linking IdP to existing account).
|
||||
Args:
|
||||
idp_slug (str): The slug identifier of the identity provider.
|
||||
password_hasher (session_password_hasher.PasswordHasher): Password hasher dependency for session management.
|
||||
token_manager (session_token_manager.TokenManager): Token manager dependency for creating session tokens.
|
||||
password_hasher (auth_password_hasher.PasswordHasher): Password hasher dependency for session management.
|
||||
token_manager (auth_token_manager.TokenManager): Token manager dependency for creating session tokens.
|
||||
db (Session): Database session dependency.
|
||||
code (str): Authorization code received from the identity provider.
|
||||
state (str): State parameter used for CSRF protection.
|
||||
request (Request, optional): The incoming HTTP request. Defaults to None.
|
||||
response (Response, optional): The HTTP response object. Defaults to None.
|
||||
request (Request | None): The incoming HTTP request. Defaults to None.
|
||||
response (Response | None): The HTTP response object. Defaults to None.
|
||||
Returns:
|
||||
RedirectResponse: A redirect response to either:
|
||||
- Settings page (link mode): /settings/security with success parameters
|
||||
@@ -149,12 +149,14 @@ async def handle_callback(
|
||||
# Handle link mode differently - redirect to settings without creating new session
|
||||
if is_link_mode:
|
||||
frontend_url = core_config.ENDURAIN_HOST
|
||||
redirect_url = f"{frontend_url}/settings/security?idp_link=success&idp_name={idp.name}"
|
||||
|
||||
redirect_url = (
|
||||
f"{frontend_url}/settings/security?idp_link=success&idp_name={idp.name}"
|
||||
)
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"IdP link successful for user {user.username}, IdP {idp.name}", "info"
|
||||
)
|
||||
|
||||
|
||||
return RedirectResponse(
|
||||
url=redirect_url,
|
||||
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
|
||||
@@ -172,7 +174,7 @@ async def handle_callback(
|
||||
refresh_token_exp,
|
||||
refresh_token,
|
||||
csrf_token,
|
||||
) = session_utils.create_tokens(user_read, token_manager)
|
||||
) = auth_utils.create_tokens(user_read, token_manager)
|
||||
|
||||
# Create the session and store it in the database
|
||||
session_utils.create_session(
|
||||
@@ -180,7 +182,7 @@ async def handle_callback(
|
||||
)
|
||||
|
||||
# Set authentication cookies
|
||||
response = session_utils.create_response_with_tokens(
|
||||
response = auth_utils.create_response_with_tokens(
|
||||
response,
|
||||
access_token,
|
||||
refresh_token,
|
||||
@@ -9,10 +9,10 @@ from fastapi import (
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import core.database as core_database
|
||||
import session.security as session_security
|
||||
import identity_providers.crud as idp_crud
|
||||
import identity_providers.schema as idp_schema
|
||||
import identity_providers.utils as idp_utils
|
||||
import auth.security as auth_security
|
||||
import auth.identity_providers.crud as idp_crud
|
||||
import auth.identity_providers.schema as idp_schema
|
||||
import auth.identity_providers.utils as idp_utils
|
||||
import users.user.schema as users_schema
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ async def list_identity_providers(
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
_check_scopes: Annotated[
|
||||
users_schema.UserRead,
|
||||
Security(session_security.check_scopes, scopes=["identity_providers:read"]),
|
||||
Security(auth_security.check_scopes, scopes=["identity_providers:read"]),
|
||||
],
|
||||
):
|
||||
"""
|
||||
@@ -53,7 +53,7 @@ async def list_identity_providers(
|
||||
async def list_idp_templates(
|
||||
_check_scopes: Annotated[
|
||||
users_schema.UserRead,
|
||||
Security(session_security.check_scopes, scopes=["identity_providers:read"]),
|
||||
Security(auth_security.check_scopes, scopes=["identity_providers:read"]),
|
||||
],
|
||||
):
|
||||
"""
|
||||
@@ -72,7 +72,7 @@ async def create_identity_provider(
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
_check_scopes: Annotated[
|
||||
users_schema.UserRead,
|
||||
Security(session_security.check_scopes, scopes=["identity_providers:write"]),
|
||||
Security(auth_security.check_scopes, scopes=["identity_providers:write"]),
|
||||
],
|
||||
):
|
||||
"""
|
||||
@@ -103,7 +103,7 @@ async def update_identity_provider(
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
_check_scopes: Annotated[
|
||||
users_schema.UserRead,
|
||||
Security(session_security.check_scopes, scopes=["identity_providers:write"]),
|
||||
Security(auth_security.check_scopes, scopes=["identity_providers:write"]),
|
||||
],
|
||||
):
|
||||
"""
|
||||
@@ -129,7 +129,7 @@ async def delete_identity_provider(
|
||||
idp_id: int,
|
||||
_check_scopes: Annotated[
|
||||
users_schema.UserRead,
|
||||
Security(session_security.check_scopes, scopes=["identity_providers:write"]),
|
||||
Security(auth_security.check_scopes, scopes=["identity_providers:write"]),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
@@ -22,15 +22,15 @@ from joserfc.errors import (
|
||||
import core.config as core_config
|
||||
import core.cryptography as core_cryptography
|
||||
import core.logger as core_logger
|
||||
import identity_providers.models as idp_models
|
||||
import identity_providers.crud as idp_crud
|
||||
import auth.identity_providers.models as idp_models
|
||||
import auth.identity_providers.crud as idp_crud
|
||||
import users.user.crud as users_crud
|
||||
import users.user.schema as users_schema
|
||||
import users.user.models as users_models
|
||||
import users.user.utils as users_utils
|
||||
import users.user_identity_providers.crud as user_idp_crud
|
||||
import users.user_identity_providers.models as user_idp_models
|
||||
import session.password_hasher as session_password_hasher
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
import server_settings.schema as server_settings_schema
|
||||
|
||||
|
||||
@@ -131,87 +131,78 @@ class IdentityProviderService:
|
||||
cached_data = self._jwks_cache[jwks_uri]
|
||||
cached_at = cached_data.get("cached_at")
|
||||
if cached_at and (now - cached_at) < self._cache_ttl:
|
||||
core_logger.print_to_log(
|
||||
f"Using cached JWKS for {jwks_uri}", "debug"
|
||||
)
|
||||
core_logger.print_to_log(f"Using cached JWKS for {jwks_uri}", "debug")
|
||||
return cached_data["jwks"]
|
||||
|
||||
# Fetch JWKS from IdP
|
||||
try:
|
||||
client = await self._get_http_client()
|
||||
core_logger.print_to_log(
|
||||
f"Fetching JWKS from {jwks_uri}", "debug"
|
||||
)
|
||||
|
||||
core_logger.print_to_log(f"Fetching JWKS from {jwks_uri}", "debug")
|
||||
|
||||
response = await client.get(jwks_uri)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
jwks = response.json()
|
||||
|
||||
|
||||
# Validate JWKS structure
|
||||
if not isinstance(jwks, dict) or "keys" not in jwks:
|
||||
core_logger.print_to_log(
|
||||
f"Invalid JWKS format from {jwks_uri}: missing 'keys' array",
|
||||
"error"
|
||||
"error",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Identity provider returned invalid JWKS format"
|
||||
detail="Identity provider returned invalid JWKS format",
|
||||
)
|
||||
|
||||
|
||||
# Cache the JWKS with timestamp
|
||||
self._jwks_cache[jwks_uri] = {
|
||||
"jwks": jwks,
|
||||
"cached_at": now
|
||||
}
|
||||
|
||||
self._jwks_cache[jwks_uri] = {"jwks": jwks, "cached_at": now}
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Successfully fetched and cached JWKS from {jwks_uri} "
|
||||
f"({len(jwks.get('keys', []))} keys)",
|
||||
"debug"
|
||||
"debug",
|
||||
)
|
||||
|
||||
|
||||
return jwks
|
||||
|
||||
|
||||
except httpx.TimeoutException as err:
|
||||
core_logger.print_to_log(
|
||||
f"Timeout fetching JWKS from {jwks_uri}: {err}",
|
||||
"error",
|
||||
exc=err
|
||||
f"Timeout fetching JWKS from {jwks_uri}: {err}", "error", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="Timeout retrieving signing keys from identity provider"
|
||||
detail="Timeout retrieving signing keys from identity provider",
|
||||
)
|
||||
except httpx.HTTPStatusError as err:
|
||||
core_logger.print_to_log(
|
||||
f"HTTP error fetching JWKS from {jwks_uri}: {err.response.status_code}",
|
||||
"error",
|
||||
exc=err
|
||||
exc=err,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Identity provider JWKS endpoint returned error: {err.response.status_code}"
|
||||
detail=f"Identity provider JWKS endpoint returned error: {err.response.status_code}",
|
||||
)
|
||||
except json.JSONDecodeError as err:
|
||||
core_logger.print_to_log(
|
||||
f"Invalid JSON in JWKS response from {jwks_uri}: {err}",
|
||||
"error",
|
||||
exc=err
|
||||
exc=err,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail="Identity provider returned invalid JWKS JSON"
|
||||
detail="Identity provider returned invalid JWKS JSON",
|
||||
)
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Unexpected error fetching JWKS from {jwks_uri}: {err}",
|
||||
"error",
|
||||
exc=err
|
||||
exc=err,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve signing keys from identity provider"
|
||||
detail="Failed to retrieve signing keys from identity provider",
|
||||
)
|
||||
|
||||
async def _verify_id_token(
|
||||
@@ -220,7 +211,7 @@ class IdentityProviderService:
|
||||
jwks_uri: str,
|
||||
expected_issuer: str,
|
||||
expected_audience: str,
|
||||
expected_nonce: str | None = None
|
||||
expected_nonce: str | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Verifies the ID token's signature and claims using JWKS from the identity provider.
|
||||
@@ -261,12 +252,11 @@ class IdentityProviderService:
|
||||
parts = id_token.split(".")
|
||||
if len(parts) != 3:
|
||||
core_logger.print_to_log(
|
||||
f"Invalid JWT format: expected 3 parts, got {len(parts)}",
|
||||
"warning"
|
||||
f"Invalid JWT format: expected 3 parts, got {len(parts)}", "warning"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid ID token format"
|
||||
detail="Invalid ID token format",
|
||||
)
|
||||
|
||||
# Decode header (first part) to get 'kid' and 'alg'
|
||||
@@ -275,37 +265,32 @@ class IdentityProviderService:
|
||||
padding = 4 - len(header_b64) % 4
|
||||
if padding != 4:
|
||||
header_b64 += "=" * padding
|
||||
|
||||
|
||||
header_bytes = base64.urlsafe_b64decode(header_b64)
|
||||
header = json.loads(header_bytes)
|
||||
|
||||
|
||||
kid = header.get("kid")
|
||||
alg = header.get("alg")
|
||||
|
||||
|
||||
if not kid:
|
||||
core_logger.print_to_log(
|
||||
"ID token header missing 'kid' claim",
|
||||
"warning"
|
||||
"ID token header missing 'kid' claim", "warning"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="ID token missing key identifier"
|
||||
)
|
||||
|
||||
if not alg:
|
||||
core_logger.print_to_log(
|
||||
"ID token header missing 'alg' claim",
|
||||
"warning"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="ID token missing algorithm"
|
||||
detail="ID token missing key identifier",
|
||||
)
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"ID token header: kid={kid}, alg={alg}",
|
||||
"debug"
|
||||
)
|
||||
if not alg:
|
||||
core_logger.print_to_log(
|
||||
"ID token header missing 'alg' claim", "warning"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="ID token missing algorithm",
|
||||
)
|
||||
|
||||
core_logger.print_to_log(f"ID token header: kid={kid}, alg={alg}", "debug")
|
||||
|
||||
# Step 2: Fetch JWKS from IdP
|
||||
jwks = await self._fetch_jwks(jwks_uri)
|
||||
@@ -319,22 +304,21 @@ class IdentityProviderService:
|
||||
|
||||
if not matching_key:
|
||||
core_logger.print_to_log(
|
||||
f"No matching key found in JWKS for kid={kid}",
|
||||
"warning"
|
||||
f"No matching key found in JWKS for kid={kid}", "warning"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="ID token signed with unknown key"
|
||||
detail="ID token signed with unknown key",
|
||||
)
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Found matching key in JWKS: kid={kid}, kty={matching_key.get('kty')}",
|
||||
"debug"
|
||||
"debug",
|
||||
)
|
||||
|
||||
# Step 4: Import the key based on type
|
||||
key_type = matching_key.get("kty")
|
||||
|
||||
|
||||
if key_type == "RSA":
|
||||
key = RSAKey.import_key(matching_key)
|
||||
elif key_type == "EC":
|
||||
@@ -343,12 +327,11 @@ class IdentityProviderService:
|
||||
key = OctKey.import_key(matching_key)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"Unsupported key type in JWKS: {key_type}",
|
||||
"warning"
|
||||
f"Unsupported key type in JWKS: {key_type}", "warning"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Unsupported key type: {key_type}"
|
||||
detail=f"Unsupported key type: {key_type}",
|
||||
)
|
||||
|
||||
# Step 5: Verify signature and decode claims
|
||||
@@ -362,15 +345,15 @@ class IdentityProviderService:
|
||||
iss={"essential": True, "value": expected_issuer},
|
||||
aud={"essential": True, "value": expected_audience},
|
||||
exp={"essential": True},
|
||||
iat={"essential": True}
|
||||
iat={"essential": True},
|
||||
)
|
||||
|
||||
|
||||
# Validate all claims
|
||||
claims_request.validate(claims)
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Successfully verified ID token signature for sub={claims.get('sub')}",
|
||||
"debug"
|
||||
"debug",
|
||||
)
|
||||
|
||||
# Step 6: Validate nonce if provided
|
||||
@@ -379,22 +362,21 @@ class IdentityProviderService:
|
||||
token_nonce = claims.get("nonce")
|
||||
if not token_nonce:
|
||||
core_logger.print_to_log(
|
||||
"ID token missing nonce claim but nonce was expected",
|
||||
"warning"
|
||||
"ID token missing nonce claim but nonce was expected", "warning"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="ID token missing nonce"
|
||||
detail="ID token missing nonce",
|
||||
)
|
||||
|
||||
|
||||
if token_nonce != expected_nonce:
|
||||
core_logger.print_to_log(
|
||||
f"ID token nonce mismatch: expected {expected_nonce}, got {token_nonce}",
|
||||
"warning"
|
||||
"warning",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="ID token nonce mismatch"
|
||||
detail="ID token nonce mismatch",
|
||||
)
|
||||
|
||||
# Return verified claims
|
||||
@@ -402,66 +384,51 @@ class IdentityProviderService:
|
||||
|
||||
except BadSignatureError as err:
|
||||
core_logger.print_to_log(
|
||||
f"ID token signature verification failed: {err}",
|
||||
"warning",
|
||||
exc=err
|
||||
f"ID token signature verification failed: {err}", "warning", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="ID token signature is invalid"
|
||||
detail="ID token signature is invalid",
|
||||
)
|
||||
except ExpiredTokenError as err:
|
||||
core_logger.print_to_log(
|
||||
f"ID token has expired: {err}",
|
||||
"warning",
|
||||
exc=err
|
||||
)
|
||||
core_logger.print_to_log(f"ID token has expired: {err}", "warning", exc=err)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="ID token has expired"
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="ID token has expired"
|
||||
)
|
||||
except InvalidClaimError as err:
|
||||
core_logger.print_to_log(
|
||||
f"ID token claim validation failed: {err}",
|
||||
"warning",
|
||||
exc=err
|
||||
f"ID token claim validation failed: {err}", "warning", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"ID token claim validation failed: {err}"
|
||||
detail=f"ID token claim validation failed: {err}",
|
||||
)
|
||||
except MissingClaimError as err:
|
||||
core_logger.print_to_log(
|
||||
f"ID token missing required claim: {err}",
|
||||
"warning",
|
||||
exc=err
|
||||
f"ID token missing required claim: {err}", "warning", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"ID token missing required claim: {err}"
|
||||
detail=f"ID token missing required claim: {err}",
|
||||
)
|
||||
except InvalidPayloadError as err:
|
||||
core_logger.print_to_log(
|
||||
f"ID token payload is invalid: {err}",
|
||||
"warning",
|
||||
exc=err
|
||||
f"ID token payload is invalid: {err}", "warning", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="ID token payload is invalid"
|
||||
detail="ID token payload is invalid",
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions from _fetch_jwks or our own validations
|
||||
raise
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Unexpected error verifying ID token: {err}",
|
||||
"error",
|
||||
exc=err
|
||||
f"Unexpected error verifying ID token: {err}", "error", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to verify ID token"
|
||||
detail="Failed to verify ID token",
|
||||
)
|
||||
|
||||
async def get_oidc_configuration(
|
||||
@@ -491,12 +458,10 @@ class IdentityProviderService:
|
||||
):
|
||||
return self._discovery_cache[idp.id]
|
||||
|
||||
try:
|
||||
# Construct the discovery URL
|
||||
discovery_url = (
|
||||
f"{idp.issuer_url.rstrip('/')}/.well-known/openid-configuration"
|
||||
)
|
||||
# Construct the discovery URL
|
||||
discovery_url = f"{idp.issuer_url.rstrip('/')}/.well-known/openid-configuration"
|
||||
|
||||
try:
|
||||
# Fetch the configuration
|
||||
client = await self._get_http_client()
|
||||
response = await client.get(discovery_url)
|
||||
@@ -585,9 +550,7 @@ class IdentityProviderService:
|
||||
detail=f"Identity provider {idp.name} configuration error. Please contact administrator.",
|
||||
) from err
|
||||
|
||||
async def _resolve_token_endpoint(
|
||||
self, idp: idp_models.IdentityProvider
|
||||
) -> str:
|
||||
async def _resolve_token_endpoint(self, idp: idp_models.IdentityProvider) -> str:
|
||||
"""
|
||||
Resolve the token endpoint URL for an IdP, using OIDC discovery if needed.
|
||||
|
||||
@@ -707,13 +670,11 @@ class IdentityProviderService:
|
||||
state_data = {
|
||||
"random": random_state,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"idp_id": idp.id
|
||||
"idp_id": idp.id,
|
||||
}
|
||||
# Encode state as base64 JSON for URL safety
|
||||
state = base64.urlsafe_b64encode(
|
||||
json.dumps(state_data).encode()
|
||||
).decode()
|
||||
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
|
||||
|
||||
nonce = secrets.token_urlsafe(32)
|
||||
|
||||
# Store in session (using SessionMiddleware)
|
||||
@@ -749,7 +710,11 @@ class IdentityProviderService:
|
||||
) from err
|
||||
|
||||
async def initiate_link(
|
||||
self, idp: idp_models.IdentityProvider, request: Request, user_id: int, db: Session
|
||||
self,
|
||||
idp: idp_models.IdentityProvider,
|
||||
request: Request,
|
||||
user_id: int,
|
||||
db: Session,
|
||||
) -> str:
|
||||
"""
|
||||
Initiates the OAuth/OIDC authorization flow for linking an identity provider to an existing user account.
|
||||
@@ -766,7 +731,7 @@ class IdentityProviderService:
|
||||
Returns:
|
||||
str: The authorization URL to redirect the user to for identity provider authentication.
|
||||
Raises:
|
||||
HTTPException:
|
||||
HTTPException:
|
||||
- 500 status code if the identity provider is not properly configured (missing
|
||||
authorization endpoint).
|
||||
- 500 status code if any unexpected error occurs during the OAuth flow initiation.
|
||||
@@ -806,10 +771,8 @@ class IdentityProviderService:
|
||||
"user_id": user_id, # Ensures callback links to correct user
|
||||
}
|
||||
# Encode state as base64 JSON for URL safety
|
||||
state = base64.urlsafe_b64encode(
|
||||
json.dumps(state_data).encode()
|
||||
).decode()
|
||||
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
|
||||
|
||||
nonce = secrets.token_urlsafe(32)
|
||||
|
||||
# Store in session (using SessionMiddleware)
|
||||
@@ -851,7 +814,7 @@ class IdentityProviderService:
|
||||
code: str,
|
||||
state: str,
|
||||
request: Request,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -866,7 +829,7 @@ class IdentityProviderService:
|
||||
state (str): The state parameter for CSRF protection, containing JSON with
|
||||
timestamp, mode, and optional user_id.
|
||||
request (Request): The FastAPI/Starlette request object containing session data.
|
||||
password_hasher (session_password_hasher.PasswordHasher): Password hasher instance
|
||||
password_hasher (auth_password_hasher.PasswordHasher): Password hasher instance
|
||||
for user authentication operations.
|
||||
db (Session): SQLAlchemy database session.
|
||||
Returns:
|
||||
@@ -899,45 +862,43 @@ class IdentityProviderService:
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid state parameter",
|
||||
)
|
||||
|
||||
|
||||
# Decode and validate state timestamp (10-minute expiry)
|
||||
try:
|
||||
state_json = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
state_data = json.loads(state_json)
|
||||
|
||||
|
||||
# Validate timestamp exists
|
||||
if "timestamp" not in state_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="State parameter missing timestamp",
|
||||
)
|
||||
|
||||
|
||||
# Parse timestamp and check expiry
|
||||
state_timestamp = datetime.fromisoformat(state_data["timestamp"])
|
||||
now = datetime.now(timezone.utc)
|
||||
age = now - state_timestamp
|
||||
|
||||
|
||||
# Reject states older than 10 minutes (CSRF protection)
|
||||
if age > timedelta(minutes=10):
|
||||
core_logger.print_to_log(
|
||||
f"Expired state detected for IdP {idp.name}: age={age.total_seconds():.1f}s",
|
||||
"warning"
|
||||
"warning",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="State parameter expired. Please try logging in again.",
|
||||
)
|
||||
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"State validation successful for IdP {idp.name}: age={age.total_seconds():.1f}s",
|
||||
"debug"
|
||||
"debug",
|
||||
)
|
||||
|
||||
|
||||
except (json.JSONDecodeError, ValueError, KeyError) as err:
|
||||
core_logger.print_to_log(
|
||||
f"Failed to decode state parameter: {err}",
|
||||
"error",
|
||||
exc=err
|
||||
f"Failed to decode state parameter: {err}", "error", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -947,27 +908,27 @@ class IdentityProviderService:
|
||||
# Detect link mode from state data
|
||||
is_link_mode = state_data.get("mode") == "link"
|
||||
link_user_id = None
|
||||
|
||||
|
||||
if is_link_mode:
|
||||
# Validate link mode state
|
||||
link_user_id = state_data.get("user_id")
|
||||
session_link_user_id = request.session.get("oauth_link_user_id")
|
||||
|
||||
|
||||
if not link_user_id or not session_link_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid link mode state - missing user ID",
|
||||
)
|
||||
|
||||
|
||||
if link_user_id != session_link_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User ID mismatch - possible session hijacking attempt",
|
||||
)
|
||||
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Link mode detected for IdP {idp.name}, user_id={link_user_id}",
|
||||
"debug"
|
||||
"debug",
|
||||
)
|
||||
|
||||
# Decrypt credentials and resolve endpoints using helper methods
|
||||
@@ -979,7 +940,7 @@ class IdentityProviderService:
|
||||
userinfo_endpoint = idp.userinfo_endpoint
|
||||
jwks_uri = None
|
||||
expected_issuer = None
|
||||
|
||||
|
||||
if idp.issuer_url:
|
||||
try:
|
||||
config = await self.get_oidc_configuration(idp)
|
||||
@@ -987,17 +948,17 @@ class IdentityProviderService:
|
||||
# Get userinfo endpoint if not manually configured
|
||||
if not userinfo_endpoint:
|
||||
userinfo_endpoint = config.get("userinfo_endpoint")
|
||||
|
||||
|
||||
# Get JWKS URI for ID token verification
|
||||
jwks_uri = config.get("jwks_uri")
|
||||
expected_issuer = config.get("issuer")
|
||||
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"OIDC discovery complete for {idp.name}: "
|
||||
f"userinfo={bool(userinfo_endpoint)}, "
|
||||
f"jwks_uri={bool(jwks_uri)}, "
|
||||
f"issuer={bool(expected_issuer)}",
|
||||
"debug"
|
||||
"debug",
|
||||
)
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
@@ -1005,7 +966,7 @@ class IdentityProviderService:
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
|
||||
|
||||
# Retrieve nonce from session for ID token verification
|
||||
expected_nonce = request.session.get(f"oauth_nonce_{idp.id}")
|
||||
|
||||
@@ -1045,7 +1006,7 @@ class IdentityProviderService:
|
||||
detail = f"Identity provider {idp.name} rejected the authentication request. Please contact administrator."
|
||||
else:
|
||||
detail = f"Identity provider {idp.name} returned an error. Please try again later."
|
||||
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=detail,
|
||||
@@ -1079,7 +1040,7 @@ class IdentityProviderService:
|
||||
jwks_uri=jwks_uri,
|
||||
expected_issuer=expected_issuer,
|
||||
expected_audience=client_id,
|
||||
expected_nonce=expected_nonce
|
||||
expected_nonce=expected_nonce,
|
||||
)
|
||||
|
||||
# Extract subject (unique user identifier)
|
||||
@@ -1095,9 +1056,9 @@ class IdentityProviderService:
|
||||
)
|
||||
|
||||
# Handle link mode differently from login mode
|
||||
if is_link_mode:
|
||||
if is_link_mode and link_user_id:
|
||||
# LINK MODE: Associate IdP with existing authenticated user
|
||||
|
||||
|
||||
# Verify user exists
|
||||
user = users_crud.get_user_by_id(link_user_id, db)
|
||||
if not user:
|
||||
@@ -1105,9 +1066,13 @@ class IdentityProviderService:
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
|
||||
# Check if this IdP subject is already linked to ANY user
|
||||
existing_link = user_idp_crud.get_user_by_idp(idp.id, subject, db)
|
||||
existing_link = (
|
||||
user_idp_crud.get_user_identity_provider_by_subject_and_idp_id(
|
||||
idp.id, subject, db
|
||||
)
|
||||
)
|
||||
if existing_link:
|
||||
# Check if it's already linked to THIS user
|
||||
if existing_link.user_id == link_user_id:
|
||||
@@ -1121,37 +1086,34 @@ class IdentityProviderService:
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"This {idp.name} account is already linked to another user",
|
||||
)
|
||||
|
||||
|
||||
# Create the link
|
||||
user_idp_crud.create_user_idp_link(
|
||||
user_id=link_user_id,
|
||||
idp_id=idp.id,
|
||||
idp_subject=subject,
|
||||
db=db
|
||||
user_idp_crud.create_user_identity_provider(
|
||||
user_id=link_user_id, idp_id=idp.id, idp_subject=subject, db=db
|
||||
)
|
||||
|
||||
|
||||
# Store IdP tokens for future use
|
||||
await self._store_idp_tokens(link_user_id, idp.id, token_response, db)
|
||||
|
||||
|
||||
# Clean up session data
|
||||
request.session.pop(f"oauth_state_{idp.id}", None)
|
||||
request.session.pop(f"oauth_nonce_{idp.id}", None)
|
||||
request.session.pop("oauth_idp_id", None)
|
||||
request.session.pop("oauth_link_user_id", None)
|
||||
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"User {user.username} (id={link_user_id}) linked IdP {idp.name} (subject={subject})",
|
||||
"info"
|
||||
"info",
|
||||
)
|
||||
|
||||
|
||||
# Return special structure for link mode (no new session created)
|
||||
return {
|
||||
"user": user,
|
||||
"token_data": token_response,
|
||||
"userinfo": userinfo,
|
||||
"mode": "link" # Indicate this was a link operation
|
||||
"mode": "link", # Indicate this was a link operation
|
||||
}
|
||||
|
||||
|
||||
else:
|
||||
# LOGIN MODE: Find or create user and establish session
|
||||
user = await self._find_or_create_user(
|
||||
@@ -1170,7 +1132,11 @@ class IdentityProviderService:
|
||||
f"User {user.username} authenticated via IdP {idp.name}", "info"
|
||||
)
|
||||
|
||||
return {"user": user, "token_data": token_response, "userinfo": userinfo}
|
||||
return {
|
||||
"user": user,
|
||||
"token_data": token_response,
|
||||
"userinfo": userinfo,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions as-is (already have proper status codes and messages)
|
||||
@@ -1248,7 +1214,9 @@ class IdentityProviderService:
|
||||
)
|
||||
except httpx.TimeoutException as err:
|
||||
core_logger.print_to_log(
|
||||
f"Timeout fetching userinfo from endpoint: {err}", "warning", exc=err
|
||||
f"Timeout fetching userinfo from endpoint: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
except httpx.HTTPStatusError as err:
|
||||
core_logger.print_to_log(
|
||||
@@ -1278,14 +1246,14 @@ class IdentityProviderService:
|
||||
jwks_uri=jwks_uri,
|
||||
expected_issuer=expected_issuer,
|
||||
expected_audience=expected_audience,
|
||||
expected_nonce=expected_nonce
|
||||
expected_nonce=expected_nonce,
|
||||
)
|
||||
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Successfully verified ID token for sub={id_token_claims.get('sub')}",
|
||||
"debug"
|
||||
f"Successfully verified ID token for sub={id_token_claims.get('sub')}",
|
||||
"debug",
|
||||
)
|
||||
|
||||
|
||||
# If we got userinfo from endpoint, merge with ID token claims
|
||||
# ID token claims take precedence for standard claims (sub, iss, aud)
|
||||
if userinfo_claims:
|
||||
@@ -1294,38 +1262,36 @@ class IdentityProviderService:
|
||||
merged_claims = {**userinfo_claims, **id_token_claims}
|
||||
core_logger.print_to_log(
|
||||
"Merged userinfo endpoint data with verified ID token claims",
|
||||
"debug"
|
||||
"debug",
|
||||
)
|
||||
return merged_claims
|
||||
else:
|
||||
# Only ID token available, return verified claims
|
||||
return id_token_claims
|
||||
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise verification errors (signature failed, expired, etc.)
|
||||
# These are security-critical and should not be ignored
|
||||
raise
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Unexpected error verifying ID token: {err}",
|
||||
"error",
|
||||
exc=err
|
||||
f"Unexpected error verifying ID token: {err}", "error", exc=err
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to verify ID token"
|
||||
detail="Failed to verify ID token",
|
||||
)
|
||||
|
||||
|
||||
# If we got userinfo from endpoint but no ID token, return userinfo
|
||||
if userinfo_claims:
|
||||
return userinfo_claims
|
||||
|
||||
|
||||
# If ID token exists but we're missing JWKS/issuer info, log warning
|
||||
if id_token and (not jwks_uri or not expected_issuer):
|
||||
core_logger.print_to_log(
|
||||
"ID token present but cannot verify: missing JWKS URI or issuer. "
|
||||
"Configure issuer_url for OIDC discovery.",
|
||||
"warning"
|
||||
"warning",
|
||||
)
|
||||
|
||||
# If we get here, we couldn't retrieve or verify any user information
|
||||
@@ -1405,7 +1371,7 @@ class IdentityProviderService:
|
||||
)
|
||||
|
||||
# Store encrypted token and metadata in database
|
||||
user_idp_crud.store_idp_tokens(
|
||||
user_idp_crud.store_user_identity_provider_tokens(
|
||||
user_id=user_id,
|
||||
idp_id=idp_id,
|
||||
encrypted_refresh_token=encrypted_refresh,
|
||||
@@ -1472,7 +1438,7 @@ class IdentityProviderService:
|
||||
idp: idp_models.IdentityProvider,
|
||||
subject: str,
|
||||
userinfo: Dict[str, Any],
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> users_models.User:
|
||||
"""
|
||||
@@ -1487,7 +1453,7 @@ class IdentityProviderService:
|
||||
idp (idp_models.IdentityProvider): The identity provider instance.
|
||||
subject (str): The unique subject identifier from the IdP.
|
||||
userinfo (Dict[str, Any]): User information/claims from the IdP.
|
||||
password_hasher (session_password_hasher.PasswordHasher): The password hasher instance.
|
||||
password_hasher (auth_password_hasher.PasswordHasher): The password hasher instance.
|
||||
db (Session): Database session.
|
||||
|
||||
Returns:
|
||||
@@ -1497,12 +1463,16 @@ class IdentityProviderService:
|
||||
HTTPException: If user creation is disabled for the identity provider and no existing user is found.
|
||||
"""
|
||||
# Try to find existing user by IdP link
|
||||
link = user_idp_crud.get_user_idp_link_by_subject(idp.id, subject, db)
|
||||
link = user_idp_crud.get_user_identity_provider_by_subject_and_idp_id(
|
||||
idp.id, subject, db
|
||||
)
|
||||
|
||||
if link:
|
||||
user = link.user
|
||||
# Update last login timestamp
|
||||
user_idp_crud.update_user_idp_last_login(link.user_id, idp.id, db)
|
||||
user_idp_crud.update_user_identity_provider_last_login(
|
||||
link.user_id, idp.id, db
|
||||
)
|
||||
|
||||
# Update user info if sync is enabled
|
||||
if idp.sync_user_info:
|
||||
@@ -1517,7 +1487,9 @@ class IdentityProviderService:
|
||||
user = users_crud.get_user_by_email(email, db)
|
||||
if user:
|
||||
# Link existing account to IdP
|
||||
user_idp_crud.create_user_idp_link(user.id, idp.id, subject, db)
|
||||
user_idp_crud.create_user_identity_provider(
|
||||
user.id, idp.id, subject, db
|
||||
)
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Linked existing user {user.username} to IdP {idp.name}", "info"
|
||||
@@ -1543,7 +1515,7 @@ class IdentityProviderService:
|
||||
idp: idp_models.IdentityProvider,
|
||||
subject: str,
|
||||
mapped_data: Dict[str, Any],
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> users_models.User:
|
||||
"""
|
||||
@@ -1558,7 +1530,7 @@ class IdentityProviderService:
|
||||
idp (idp_models.IdentityProvider): The identity provider instance.
|
||||
subject (str): The unique subject identifier from the IdP.
|
||||
mapped_data (Dict[str, Any]): User data mapped from the IdP (e.g., username, email, name).
|
||||
password_hasher (session_password_hasher.PasswordHasher): The password hasher instance.
|
||||
password_hasher (auth_password_hasher.PasswordHasher): The password hasher instance.
|
||||
db (Session): The database session.
|
||||
|
||||
Returns:
|
||||
@@ -1575,7 +1547,6 @@ class IdentityProviderService:
|
||||
username = base_username
|
||||
while users_crud.get_user_by_username(username, db):
|
||||
username = f"{base_username}_{str(random.randint(100000, 999999))}"
|
||||
|
||||
|
||||
# Create user signup schema
|
||||
user_signup = users_schema.UserSignup(
|
||||
@@ -1609,7 +1580,9 @@ class IdentityProviderService:
|
||||
users_utils.create_user_default_data(created_user.id, db)
|
||||
|
||||
# Create the IdP link
|
||||
user_idp_crud.create_user_idp_link(created_user.id, idp.id, subject, db)
|
||||
user_idp_crud.create_user_identity_provider(
|
||||
created_user.id, idp.id, subject, db
|
||||
)
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Created new user {created_user.username} from IdP {idp.name}", "info"
|
||||
@@ -1700,7 +1673,7 @@ class IdentityProviderService:
|
||||
)
|
||||
|
||||
# Get the encrypted refresh token from database
|
||||
encrypted_refresh_token = user_idp_crud.get_idp_refresh_token(
|
||||
encrypted_refresh_token = user_idp_crud.get_user_identity_provider_refresh_token_by_user_id_and_idp_id(
|
||||
user_id, idp_id, db
|
||||
)
|
||||
|
||||
@@ -1726,7 +1699,9 @@ class IdentityProviderService:
|
||||
exc=err,
|
||||
)
|
||||
# Clear corrupted token
|
||||
user_idp_crud.clear_idp_refresh_token(user_id, idp_id, db)
|
||||
user_idp_crud.clear_user_identity_provider_refresh_token_by_user_id_and_idp_id(
|
||||
user_id, idp_id, db
|
||||
)
|
||||
return None
|
||||
|
||||
# Resolve endpoints and credentials using helper methods
|
||||
@@ -1777,7 +1752,9 @@ class IdentityProviderService:
|
||||
exc=err,
|
||||
)
|
||||
# Clear invalid token from database
|
||||
user_idp_crud.clear_idp_refresh_token(user_id, idp_id, db)
|
||||
user_idp_crud.clear_user_identity_provider_refresh_token_by_user_id_and_idp_id(
|
||||
user_id, idp_id, db
|
||||
)
|
||||
return None
|
||||
else:
|
||||
# Other HTTP errors (5xx) - don't clear token
|
||||
@@ -1862,7 +1839,7 @@ class IdentityProviderService:
|
||||
return False
|
||||
|
||||
# Get the encrypted refresh token from database
|
||||
encrypted_refresh_token = user_idp_crud.get_idp_refresh_token(
|
||||
encrypted_refresh_token = user_idp_crud.get_user_identity_provider_refresh_token_by_user_id_and_idp_id(
|
||||
user_id, idp_id, db
|
||||
)
|
||||
|
||||
@@ -2069,12 +2046,12 @@ class IdentityProviderService:
|
||||
- SKIP if token is still valid and not close to expiry
|
||||
|
||||
Example usage:
|
||||
link = user_idp_crud.get_user_idp_link(user_id, idp_id, db)
|
||||
link = user_idp_crud.get_user_identity_provider_by_user_id_and_idp_id(user_id, idp_id, db)
|
||||
action = self._should_refresh_idp_token(link)
|
||||
if action == TokenAction.REFRESH:
|
||||
await self.refresh_idp_session(user_id, idp_id, db)
|
||||
elif action == TokenAction.CLEAR:
|
||||
user_idp_crud.clear_idp_refresh_token(user_id, idp_id, db)
|
||||
user_idp_crud.clear_user_identity_provider_refresh_token_by_user_id_and_idp_id(user_id, idp_id, db)
|
||||
"""
|
||||
# Check if refresh token exists
|
||||
if not link or not link.idp_refresh_token:
|
||||
305
backend/app/auth/identity_providers/utils.py
Normal file
305
backend/app/auth/identity_providers/utils.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Identity Provider utility functions and templates"""
|
||||
|
||||
from typing import Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import auth.identity_providers.schema as idp_schema
|
||||
import auth.identity_providers.service as idp_service
|
||||
|
||||
import users.user_identity_providers.crud as user_idp_crud
|
||||
|
||||
import core.logger as core_logger
|
||||
|
||||
|
||||
# Pre-configured templates for common IdPs
|
||||
IDP_TEMPLATES = {
|
||||
"keycloak": {
|
||||
"name": "Keycloak",
|
||||
"provider_type": "oidc",
|
||||
"issuer_url": "https://{your-keycloak-domain}/realms/{realm}",
|
||||
"scopes": "openid profile email",
|
||||
"icon": "keycloak",
|
||||
"user_mapping": {
|
||||
"username": ["preferred_username", "username", "email"],
|
||||
"email": ["email", "mail"],
|
||||
"name": ["name", "display_name", "full_name"],
|
||||
},
|
||||
"description": "Keycloak - Open Source Identity and Access Management",
|
||||
"configuration_notes": "Replace {your-keycloak-domain} with your Keycloak server domain (e.g., keycloak.example.com) and {realm} with your realm name. Create an OIDC client in Keycloak admin console.",
|
||||
},
|
||||
"authentik": {
|
||||
"name": "Authentik",
|
||||
"provider_type": "oidc",
|
||||
"issuer_url": "https://{your-authentik-domain}/application/o/{slug}/",
|
||||
"scopes": "openid profile email",
|
||||
"icon": "authentik",
|
||||
"user_mapping": {
|
||||
"username": ["preferred_username", "username", "email"],
|
||||
"email": ["email", "mail"],
|
||||
"name": ["name", "display_name"],
|
||||
},
|
||||
"description": "Authentik - Open-source Identity Provider",
|
||||
"configuration_notes": "Replace {your-authentik-domain} with your Authentik server domain (e.g., authentik.example.com) and {slug} with your application slug. Create an OAuth2/OIDC provider in Authentik.",
|
||||
},
|
||||
"authelia": {
|
||||
"name": "Authelia",
|
||||
"provider_type": "oidc",
|
||||
"issuer_url": "https://{your-authelia-domain}",
|
||||
"scopes": "openid profile email",
|
||||
"icon": "authelia",
|
||||
"user_mapping": {
|
||||
"username": ["preferred_username", "username", "email"],
|
||||
"email": ["email"],
|
||||
"name": ["name"],
|
||||
},
|
||||
"description": "Authelia - Open-source authentication and authorization server",
|
||||
"configuration_notes": "Replace {your-authelia-domain} with your Authelia server domain (e.g., auth.example.com). Configure an OIDC client in your Authelia configuration file.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_idp_templates() -> list[idp_schema.IdentityProviderTemplate]:
|
||||
"""
|
||||
Retrieve a list of identity provider templates, excluding specific providers.
|
||||
|
||||
Returns:
|
||||
list[idp_schema.IdentityProviderTemplate]:
|
||||
A list of IdentityProviderTemplate objects for all identity providers.
|
||||
"""
|
||||
templates = []
|
||||
for template_id, template_data in IDP_TEMPLATES.items():
|
||||
templates.append(
|
||||
idp_schema.IdentityProviderTemplate(
|
||||
template_id=template_id, **template_data
|
||||
)
|
||||
)
|
||||
return templates
|
||||
|
||||
|
||||
def get_idp_template(template_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve an identity provider template by its template ID.
|
||||
|
||||
Args:
|
||||
template_id (str): The unique identifier of the identity provider template.
|
||||
|
||||
Returns:
|
||||
dict[str, Any] | None: The template dictionary if found, otherwise None.
|
||||
"""
|
||||
return IDP_TEMPLATES.get(template_id)
|
||||
|
||||
|
||||
async def refresh_idp_tokens_if_needed(user_id: int, db: Session) -> None:
|
||||
"""
|
||||
Refreshes identity provider (IdP) tokens for a user if needed based on token expiration policies.
|
||||
|
||||
This function retrieves all IdP links associated with a user and evaluates each token's
|
||||
state to determine the appropriate action: refresh if nearing expiry, clear if maximum
|
||||
age is exceeded, or skip if still valid.
|
||||
|
||||
The function is designed to be non-blocking and opportunistic - errors during token
|
||||
refresh or clearing are logged but do not raise exceptions, allowing the application
|
||||
to continue normal operation even if IdP token management fails.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user whose IdP tokens should be checked and refreshed.
|
||||
db (Session): SQLAlchemy database session for performing database operations.
|
||||
|
||||
Returns:
|
||||
None: This function performs side effects (token refresh/clearing) but returns nothing.
|
||||
|
||||
Raises:
|
||||
Does not raise exceptions. All errors are caught, logged, and suppressed to ensure
|
||||
IdP token management does not disrupt normal application flow.
|
||||
|
||||
Notes:
|
||||
- If a user has no IdP links, the function returns early without performing any operations.
|
||||
- Token refresh attempts that fail are logged but the user session remains valid.
|
||||
- Tokens exceeding maximum age are cleared for security, requiring user re-authentication.
|
||||
- Individual IdP operation failures do not prevent checking other IdP links.
|
||||
"""
|
||||
try:
|
||||
# Get all IdP links for this user
|
||||
idp_links = user_idp_crud.get_user_identity_providers_by_user_id(user_id, db)
|
||||
|
||||
if not idp_links:
|
||||
# User has no IdP links - nothing to refresh
|
||||
return
|
||||
|
||||
# Check each IdP link and take appropriate action
|
||||
for link in idp_links:
|
||||
try:
|
||||
# Determine what action to take for this IdP token (policy-based)
|
||||
action = idp_service.idp_service._should_refresh_idp_token(link)
|
||||
|
||||
if action == idp_service.TokenAction.REFRESH:
|
||||
# Token is close to expiry - attempt to refresh
|
||||
core_logger.print_to_log(
|
||||
f"Attempting to refresh IdP token for user {user_id}, idp {link.idp_id}",
|
||||
"debug",
|
||||
)
|
||||
|
||||
# Attempt to refresh the IdP session
|
||||
result = await idp_service.idp_service.refresh_idp_session(
|
||||
user_id, link.idp_id, db
|
||||
)
|
||||
|
||||
if result:
|
||||
core_logger.print_to_log(
|
||||
f"Successfully refreshed IdP token for user {user_id}, idp {link.idp_id}",
|
||||
"debug",
|
||||
)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"IdP token refresh failed for user {user_id}, idp {link.idp_id}. "
|
||||
"User may need to re-authenticate with IdP later.",
|
||||
"debug",
|
||||
)
|
||||
|
||||
elif action == idp_service.TokenAction.CLEAR:
|
||||
# Token has exceeded maximum age - clear it for security
|
||||
core_logger.print_to_log(
|
||||
f"Clearing expired IdP token (max age exceeded) for user {user_id}, idp {link.idp_id}",
|
||||
"info",
|
||||
)
|
||||
|
||||
success = user_idp_crud.clear_user_identity_provider_refresh_token_by_user_id_and_idp_id(
|
||||
user_id, link.idp_id, db
|
||||
)
|
||||
|
||||
if success:
|
||||
core_logger.print_to_log(
|
||||
f"Successfully cleared expired IdP token for user {user_id}, idp {link.idp_id}. "
|
||||
"User will need to re-authenticate with IdP.",
|
||||
"info",
|
||||
)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"Failed to clear expired IdP token for user {user_id}, idp {link.idp_id}",
|
||||
"warning",
|
||||
)
|
||||
|
||||
else: # idp_service.TokenAction.SKIP
|
||||
# Token is still valid and not close to expiry - no action needed
|
||||
pass
|
||||
|
||||
except Exception as err:
|
||||
# Log individual IdP operation failure but continue with other IdPs
|
||||
core_logger.print_to_log(
|
||||
f"Error checking/refreshing IdP token for user {user_id}, idp {link.idp_id}: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
# Continue to next IdP link
|
||||
|
||||
except Exception as err:
|
||||
# Catch-all for unexpected errors (e.g., database query failure)
|
||||
core_logger.print_to_log(
|
||||
f"Error retrieving IdP links for user {user_id}: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
# Don't raise - IdP token refresh is opportunistic and non-blocking
|
||||
|
||||
|
||||
async def clear_all_idp_tokens(
|
||||
user_id: int, db: Session, revoke_at_idp: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Clear all IdP (Identity Provider) refresh tokens for a user.
|
||||
|
||||
This function retrieves all IdP links associated with a user and clears their
|
||||
refresh tokens. It supports optional revocation at the IdP level before clearing
|
||||
tokens locally.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user whose IdP tokens should be cleared.
|
||||
db (Session): The database session to use for queries.
|
||||
revoke_at_idp (bool, optional): If True, attempts to revoke tokens at the
|
||||
IdP provider level (RFC 7009) before clearing locally. Defaults to False.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
This function does not raise exceptions. All errors are logged and handled
|
||||
gracefully to ensure logout processes are not interrupted.
|
||||
|
||||
Notes:
|
||||
- If no IdP links exist for the user, the function returns early.
|
||||
- Token revocation at the IdP is best-effort; local clearing always proceeds
|
||||
regardless of revocation success or failure.
|
||||
- Individual IdP token clearing failures do not prevent clearing tokens for
|
||||
other IdPs.
|
||||
- All errors are logged with appropriate severity levels (debug, info, warning).
|
||||
"""
|
||||
try:
|
||||
# Get all IdP links for this user
|
||||
idp_links = user_idp_crud.get_user_identity_providers_by_user_id(user_id, db)
|
||||
|
||||
if not idp_links:
|
||||
# User has no IdP links - nothing to clear
|
||||
return
|
||||
|
||||
# Clear tokens for each IdP link
|
||||
for link in idp_links:
|
||||
try:
|
||||
# Optionally attempt to revoke token at IdP first (RFC 7009)
|
||||
if revoke_at_idp:
|
||||
try:
|
||||
revoked = await idp_service.idp_service.revoke_idp_token(
|
||||
user_id, link.idp_id, db
|
||||
)
|
||||
if revoked:
|
||||
core_logger.print_to_log(
|
||||
f"Revoked IdP token at provider for user {user_id}, idp {link.idp_id}",
|
||||
"info",
|
||||
)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"IdP token revocation not supported or failed for user {user_id}, idp {link.idp_id}. "
|
||||
"Will clear locally.",
|
||||
"debug",
|
||||
)
|
||||
except Exception as revoke_err:
|
||||
# Log revocation failure but continue with local clearing
|
||||
core_logger.print_to_log(
|
||||
f"Error revoking IdP token for user {user_id}, idp {link.idp_id}: {revoke_err}. "
|
||||
"Will clear locally.",
|
||||
"warning",
|
||||
exc=revoke_err,
|
||||
)
|
||||
|
||||
# Always clear locally regardless of revocation result
|
||||
success = user_idp_crud.clear_user_identity_provider_refresh_token_by_user_id_and_idp_id(
|
||||
user_id, link.idp_id, db
|
||||
)
|
||||
|
||||
if success:
|
||||
core_logger.print_to_log(
|
||||
f"Cleared IdP refresh token for user {user_id}, idp {link.idp_id} on logout",
|
||||
"debug",
|
||||
)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"No IdP refresh token to clear for user {user_id}, idp {link.idp_id}",
|
||||
"debug",
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
# Log individual IdP token clearing failure but continue with other IdPs
|
||||
core_logger.print_to_log(
|
||||
f"Error clearing IdP token for user {user_id}, idp {link.idp_id}: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
# Continue to next IdP link
|
||||
|
||||
except Exception as err:
|
||||
# Catch-all for unexpected errors (e.g., database query failure)
|
||||
core_logger.print_to_log(
|
||||
f"Error retrieving IdP links for user {user_id} during logout: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
# Don't raise - IdP token clearing is a best-effort security measure
|
||||
409
backend/app/auth/router.py
Normal file
409
backend/app/auth/router.py
Normal file
@@ -0,0 +1,409 @@
|
||||
from typing import Annotated, Callable
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
Response,
|
||||
Request,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import session.utils as session_utils
|
||||
import auth.security as auth_security
|
||||
import auth.utils as auth_utils
|
||||
import session.crud as session_crud
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
import auth.token_manager as auth_token_manager
|
||||
import auth.schema as auth_schema
|
||||
|
||||
import auth.identity_providers.utils as idp_utils
|
||||
|
||||
import users.user.crud as users_crud
|
||||
import users.user.utils as users_utils
|
||||
import profile.utils as profile_utils
|
||||
|
||||
import core.database as core_database
|
||||
import core.rate_limit as core_rate_limit
|
||||
|
||||
# Define the API router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
@core_rate_limit.limiter.limit(core_rate_limit.SESSION_LOGIN_LIMIT)
|
||||
async def login_for_access_token(
|
||||
response: Response,
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
client_type: Annotated[str, Depends(auth_security.header_client_type_scheme)],
|
||||
pending_mfa_store: Annotated[
|
||||
auth_schema.PendingMFALogin, Depends(auth_schema.get_pending_mfa_store)
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
token_manager: Annotated[
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
"""
|
||||
Handles user login and access token generation, including Multi-Factor Authentication (MFA) flow.
|
||||
|
||||
Rate Limit: 5 requests per minute per IP
|
||||
|
||||
This endpoint authenticates a user using provided credentials, checks if the user is active,
|
||||
and determines if MFA is required. If MFA is enabled for the user, it stores the pending login
|
||||
and returns an MFA-required response. Otherwise, it completes the login process and returns
|
||||
the required information.
|
||||
|
||||
Args:
|
||||
response: The HTTP response object
|
||||
request: The HTTP request object
|
||||
form_data: Form data containing username and password
|
||||
client_type: The type of client making the request ("web" or "mobile")
|
||||
pending_mfa_store: Store for pending MFA logins
|
||||
password_hasher: The password hasher instance used for verifying passwords
|
||||
token_manager: The token manager instance used for token operations
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Union[auth_schema.MFARequiredResponse, dict, str]:
|
||||
- If MFA is required, returns an MFA-required response (schema or dict depending on client type)
|
||||
- If MFA is not required, proceeds with normal login via auth_utils.complete_login()
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails or the user is inactive
|
||||
"""
|
||||
user = auth_utils.authenticate_user(
|
||||
form_data.username, form_data.password, password_hasher, db
|
||||
)
|
||||
|
||||
# Check if the user is active
|
||||
users_utils.check_user_is_active(user)
|
||||
|
||||
# Check if MFA is enabled for this user
|
||||
if profile_utils.is_mfa_enabled_for_user(user.id, db):
|
||||
# Store the user for pending MFA verification
|
||||
pending_mfa_store.add_pending_login(form_data.username, user.id)
|
||||
|
||||
# Return MFA required response
|
||||
if client_type == "web":
|
||||
response.status_code = status.HTTP_202_ACCEPTED
|
||||
return auth_schema.MFARequiredResponse(
|
||||
mfa_required=True,
|
||||
username=form_data.username,
|
||||
message="MFA verification required",
|
||||
)
|
||||
if client_type == "mobile":
|
||||
return {
|
||||
"mfa_required": True,
|
||||
"username": form_data.username,
|
||||
"message": "MFA verification required",
|
||||
}
|
||||
|
||||
# If no MFA required, proceed with normal login
|
||||
return auth_utils.complete_login(
|
||||
response, request, user, client_type, password_hasher, token_manager, db
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mfa/verify")
|
||||
async def verify_mfa_and_login(
|
||||
response: Response,
|
||||
request: Request,
|
||||
mfa_request: auth_schema.MFALoginRequest,
|
||||
client_type: Annotated[str, Depends(auth_security.header_client_type_scheme)],
|
||||
pending_mfa_store: Annotated[
|
||||
auth_schema.PendingMFALogin, Depends(auth_schema.get_pending_mfa_store)
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
token_manager: Annotated[
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
"""
|
||||
Verify MFA code and complete login process.
|
||||
|
||||
This endpoint verifies the MFA code for a pending login and completes
|
||||
the authentication process if the code is valid.
|
||||
|
||||
Args:
|
||||
response: The HTTP response object
|
||||
request: The HTTP request object
|
||||
mfa_request: MFA login request containing username and MFA code
|
||||
client_type: The type of client making the request ("web" or "mobile")
|
||||
pending_mfa_store: Store for pending MFA logins
|
||||
password_hasher: The password hasher instance used for verifying passwords
|
||||
token_manager: The token manager instance used for token operations
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Result from auth_utils.complete_login()
|
||||
|
||||
Raises:
|
||||
HTTPException: If no pending login found, MFA code is invalid, or user not found
|
||||
"""
|
||||
# Check if there's a pending MFA login for this username
|
||||
user_id = pending_mfa_store.get_pending_login(mfa_request.username)
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No pending MFA login found for this username",
|
||||
)
|
||||
|
||||
# Verify the MFA code
|
||||
if not profile_utils.verify_user_mfa(user_id, mfa_request.mfa_code, db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid MFA code"
|
||||
)
|
||||
|
||||
# Get the user and complete login
|
||||
user = users_crud.get_user_by_id(user_id, db)
|
||||
if not user:
|
||||
pending_mfa_store.delete_pending_login(mfa_request.username)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
# Check if the user is still active
|
||||
users_utils.check_user_is_active(user)
|
||||
|
||||
# Clean up pending login
|
||||
pending_mfa_store.delete_pending_login(mfa_request.username)
|
||||
|
||||
# Complete the login
|
||||
return auth_utils.complete_login(
|
||||
response, request, user, client_type, password_hasher, token_manager, db
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(
|
||||
response: Response,
|
||||
request: Request,
|
||||
_validate_refresh_token: Annotated[
|
||||
Callable, Depends(auth_security.validate_refresh_token)
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(auth_security.get_sub_from_refresh_token),
|
||||
],
|
||||
token_session_id: Annotated[
|
||||
str,
|
||||
Depends(auth_security.get_sid_from_refresh_token),
|
||||
],
|
||||
refresh_token_value: Annotated[
|
||||
str,
|
||||
Depends(auth_security.get_and_return_refresh_token),
|
||||
],
|
||||
client_type: Annotated[str, Depends(auth_security.header_client_type_scheme)],
|
||||
password_hasher: Annotated[
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
token_manager: Annotated[
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
"""
|
||||
Handles the refresh token process for user sessions.
|
||||
|
||||
This endpoint validates the provided refresh token, checks session and user status,
|
||||
and issues new access, refresh, and CSRF tokens. The response format depends on the client type.
|
||||
|
||||
Args:
|
||||
response (Response): The HTTP response object.
|
||||
request (Request): The HTTP request object.
|
||||
_validate_refresh_token (Callable): Dependency to validate the refresh token.
|
||||
token_user_id (int): User ID extracted from the refresh token.
|
||||
token_session_id (str): Session ID extracted from the refresh token.
|
||||
refresh_token_value (str): The raw refresh token value.
|
||||
client_type (str): The type of client ("web" or "mobile").
|
||||
password_hasher (PasswordHasher): Utility for verifying token hashes.
|
||||
token_manager (TokenManager): Utility for creating tokens.
|
||||
db (Session): Database session.
|
||||
|
||||
Returns:
|
||||
Union[str, dict]: For "web" clients, returns the session ID.
|
||||
For "mobile" clients, returns a dictionary with new tokens and session ID.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the session is not found, the refresh token is invalid,
|
||||
the user is inactive, or the client type is invalid.
|
||||
"""
|
||||
# Get the session from the database
|
||||
session = session_crud.get_session_by_id(token_session_id, db)
|
||||
|
||||
# Check if the session was found
|
||||
if session is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
is_valid = password_hasher.verify(refresh_token_value, session.refresh_token)
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# get user
|
||||
user = users_crud.get_user_by_id(token_user_id, db)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Check if the user is active
|
||||
users_utils.check_user_is_active(user)
|
||||
|
||||
# Create the tokens
|
||||
(
|
||||
session_id,
|
||||
new_access_token_exp,
|
||||
new_access_token,
|
||||
_new_refresh_token_exp,
|
||||
new_refresh_token,
|
||||
new_csrf_token,
|
||||
) = auth_utils.create_tokens(user, token_manager, session.id)
|
||||
|
||||
# Edit the session and store it in the database
|
||||
session_utils.edit_session(session, request, new_refresh_token, password_hasher, db)
|
||||
|
||||
# Opportunistically refresh IdP tokens for all linked identity providers
|
||||
await idp_utils.refresh_idp_tokens_if_needed(user.id, db)
|
||||
|
||||
if client_type == "web":
|
||||
response = auth_utils.create_response_with_tokens(
|
||||
response, new_access_token, new_refresh_token, new_csrf_token
|
||||
)
|
||||
|
||||
# Return session ID
|
||||
return {
|
||||
"session_id": session_id,
|
||||
}
|
||||
if client_type == "mobile":
|
||||
# Return the tokens
|
||||
return {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
"session_id": session_id,
|
||||
"token_type": "bearer",
|
||||
"expires_in": int(new_access_token_exp.timestamp()),
|
||||
}
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid client type",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
response: Response,
|
||||
_validate_access_token: Annotated[
|
||||
Callable, Depends(auth_security.validate_access_token)
|
||||
],
|
||||
token_session_id: Annotated[
|
||||
str,
|
||||
Depends(auth_security.get_sid_from_access_token),
|
||||
],
|
||||
refresh_token_value: Annotated[
|
||||
str,
|
||||
Depends(auth_security.get_and_return_refresh_token),
|
||||
],
|
||||
client_type: Annotated[str, Depends(auth_security.header_client_type_scheme)],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(auth_security.get_sub_from_refresh_token),
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
"""
|
||||
Logs out a user by validating and deleting their session, and clearing authentication cookies for web clients.
|
||||
Parameters:
|
||||
response (Response): The response object to modify cookies.
|
||||
_validate_access_token (Callable): Dependency to validate the access token.
|
||||
token_session_id (str): The session ID extracted from the access token.
|
||||
refresh_token_value (str): The refresh token value from the request.
|
||||
client_type (str): The type of client ("web" or "mobile").
|
||||
token_user_id (int): The user ID extracted from the refresh token.
|
||||
password_hasher (PasswordHasher): Utility for verifying the refresh token.
|
||||
db (Session): Database session for CRUD operations.
|
||||
Returns:
|
||||
dict: A message indicating successful logout.
|
||||
Raises:
|
||||
HTTPException: If the refresh token is invalid (401 Unauthorized).
|
||||
HTTPException: If the client type is invalid (403 Forbidden).
|
||||
"""
|
||||
# Get the session from the database
|
||||
session = session_crud.get_session_by_id(token_session_id, db)
|
||||
|
||||
# Check if the session was found
|
||||
if session is not None:
|
||||
# Verify the refresh token
|
||||
is_valid = password_hasher.verify(refresh_token_value, session.refresh_token)
|
||||
|
||||
# If the refresh token is not valid, raise an exception
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Delete the session from the database
|
||||
session_crud.delete_session(session.id, token_user_id, db)
|
||||
|
||||
# Clear all IdP refresh tokens for security
|
||||
await idp_utils.clear_all_idp_tokens(token_user_id, db)
|
||||
|
||||
if client_type == "web":
|
||||
# Clear the cookies by setting their expiration to the past
|
||||
response.delete_cookie(key="endurain_access_token", path="/")
|
||||
response.delete_cookie(key="endurain_refresh_token", path="/")
|
||||
response.delete_cookie(key="endurain_csrf_token", path="/")
|
||||
return {"message": "Logout successful"}
|
||||
if client_type == "mobile":
|
||||
return {"message": "Logout successful"}
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid client type",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
144
backend/app/auth/schema.py
Normal file
144
backend/app/auth/schema.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""
|
||||
Schema for login requests containing username and password.
|
||||
|
||||
Attributes:
|
||||
username (str): The username of the user. Must be between 1 and 250 characters.
|
||||
password (str): The user's password. Must be at least 8 characters long.
|
||||
"""
|
||||
|
||||
username: str = Field(..., min_length=1, max_length=250)
|
||||
password: str = Field(..., min_length=8)
|
||||
|
||||
|
||||
class MFALoginRequest(BaseModel):
|
||||
"""
|
||||
Schema for Multi-Factor Authentication (MFA) login request.
|
||||
|
||||
Attributes:
|
||||
username (str): The username of the user attempting to log in. Must be between 1 and 250 characters.
|
||||
mfa_code (str): The 6-digit MFA code provided by the user. Must match the pattern: six consecutive digits.
|
||||
"""
|
||||
|
||||
username: str = Field(..., min_length=1, max_length=250)
|
||||
mfa_code: str = Field(..., pattern=r"^\d{6}$")
|
||||
|
||||
|
||||
class MFARequiredResponse(BaseModel):
|
||||
"""
|
||||
Represents a response indicating that Multi-Factor Authentication (MFA) is required.
|
||||
|
||||
Attributes:
|
||||
mfa_required (bool): Indicates whether MFA is required. Defaults to True.
|
||||
username (str): The username for which MFA is required.
|
||||
message (str): A message describing the requirement. Defaults to "MFA verification required".
|
||||
"""
|
||||
|
||||
mfa_required: bool = True
|
||||
username: str
|
||||
message: str = "MFA verification required"
|
||||
|
||||
|
||||
class PendingMFALogin:
|
||||
"""
|
||||
A class to manage pending Multi-Factor Authentication (MFA) login sessions.
|
||||
|
||||
This class provides methods to add, retrieve, delete, and check pending login entries
|
||||
for users who are in the process of MFA authentication. It uses an internal dictionary
|
||||
to store the mapping between usernames and their associated user IDs.
|
||||
|
||||
Attributes:
|
||||
_store (dict): Internal storage mapping usernames to user IDs for pending logins.
|
||||
|
||||
Methods:
|
||||
add_pending_login(username: str, user_id: int):
|
||||
Adds a pending login entry for the specified username and user ID.
|
||||
|
||||
get_pending_login(username: str):
|
||||
Retrieves the user ID associated with the given username's pending login entry.
|
||||
|
||||
delete_pending_login(username: str):
|
||||
Removes the pending login entry for the specified username.
|
||||
|
||||
has_pending_login(username: str):
|
||||
Checks if the specified username has a pending login entry.
|
||||
|
||||
clear_all():
|
||||
Clears all pending login entries from the internal store.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._store = {}
|
||||
|
||||
def add_pending_login(self, username: str, user_id: int):
|
||||
"""
|
||||
Adds a pending login entry for a user.
|
||||
|
||||
Stores the provided username and associated user ID in the internal store,
|
||||
marking the user as pending login.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user to add.
|
||||
user_id (int): The unique identifier of the user.
|
||||
|
||||
"""
|
||||
self._store[username] = user_id
|
||||
|
||||
def get_pending_login(self, username: str):
|
||||
"""
|
||||
Retrieve the pending login information for a given username.
|
||||
|
||||
Args:
|
||||
username (str): The username to look up.
|
||||
|
||||
Returns:
|
||||
Any: The pending login information associated with the username, or None if not found.
|
||||
"""
|
||||
return self._store.get(username)
|
||||
|
||||
def delete_pending_login(self, username: str):
|
||||
"""
|
||||
Removes the pending login entry for the specified username from the internal store.
|
||||
|
||||
Args:
|
||||
username (str): The username whose pending login entry should be deleted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if username in self._store:
|
||||
del self._store[username]
|
||||
|
||||
def has_pending_login(self, username: str):
|
||||
"""
|
||||
Checks if the given username has a pending login session.
|
||||
|
||||
Args:
|
||||
username (str): The username to check for a pending login.
|
||||
|
||||
Returns:
|
||||
bool: True if the username has a pending login session, False otherwise.
|
||||
"""
|
||||
return username in self._store
|
||||
|
||||
def clear_all(self):
|
||||
"""
|
||||
Removes all items from the internal store, effectively resetting it to an empty state.
|
||||
"""
|
||||
self._store.clear()
|
||||
|
||||
|
||||
def get_pending_mfa_store():
|
||||
"""
|
||||
Retrieve the current pending MFA (Multi-Factor Authentication) store.
|
||||
|
||||
Returns:
|
||||
dict: The pending MFA store containing MFA-related data.
|
||||
"""
|
||||
return pending_mfa_store
|
||||
|
||||
|
||||
pending_mfa_store = PendingMFALogin()
|
||||
@@ -7,15 +7,15 @@ from fastapi.security import (
|
||||
APIKeyCookie,
|
||||
)
|
||||
|
||||
import session.constants as session_constants
|
||||
import session.token_manager as session_token_manager
|
||||
import auth.constants as auth_constants
|
||||
import auth.token_manager as auth_token_manager
|
||||
|
||||
import core.logger as core_logger
|
||||
|
||||
# Define the OAuth2 scheme for handling bearer tokens
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="token",
|
||||
scopes=session_constants.SCOPE_DICT,
|
||||
scopes=auth_constants.SCOPE_DICT,
|
||||
auto_error=False,
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ def get_token(
|
||||
cookie_token: Union[str, None],
|
||||
client_type: str,
|
||||
token_type: str,
|
||||
) -> str:
|
||||
) -> str | None:
|
||||
"""
|
||||
Retrieves the authentication token based on client type and available sources.
|
||||
|
||||
@@ -78,7 +78,7 @@ def get_access_token(
|
||||
non_cookie_access_token: Annotated[Union[str, None], Depends(oauth2_scheme)],
|
||||
cookie_access_token: Union[str, None] = Depends(cookie_access_token_scheme),
|
||||
client_type: str = Depends(header_client_type_scheme),
|
||||
) -> str:
|
||||
) -> str | None:
|
||||
"""
|
||||
Retrieves the access token from either the Authorization header or a cookie, depending on the client type.
|
||||
|
||||
@@ -102,8 +102,8 @@ def validate_access_token(
|
||||
# access_token: Annotated[str, Depends(get_access_token_from_cookies)]
|
||||
access_token: Annotated[str, Depends(get_access_token)],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
) -> None:
|
||||
"""
|
||||
@@ -115,7 +115,7 @@ def validate_access_token(
|
||||
|
||||
Args:
|
||||
access_token (str): The access token to be validated.
|
||||
token_manager (session_token_manager.TokenManager): The token manager instance used for validation.
|
||||
token_manager (auth_token_manager.TokenManager): The token manager instance used for validation.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the token is expired, invalid, or an unexpected error occurs during validation.
|
||||
@@ -147,8 +147,8 @@ def validate_access_token(
|
||||
def get_sub_from_access_token(
|
||||
access_token: Annotated[str, Depends(get_access_token)],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
) -> int:
|
||||
"""
|
||||
@@ -156,7 +156,7 @@ def get_sub_from_access_token(
|
||||
|
||||
Args:
|
||||
access_token (str): The access token from which to extract the claim.
|
||||
token_manager (session_token_manager.TokenManager): The token manager instance used to decode and validate the token.
|
||||
token_manager (auth_token_manager.TokenManager): The token manager instance used to decode and validate the token.
|
||||
|
||||
Returns:
|
||||
int: The user ID associated with the access token.
|
||||
@@ -165,22 +165,28 @@ def get_sub_from_access_token(
|
||||
Exception: If the token is invalid or the 'sub' claim is missing.
|
||||
"""
|
||||
# Return the user ID associated with the token
|
||||
return token_manager.get_token_claim(access_token, "sub")
|
||||
sub = token_manager.get_token_claim(access_token, "sub")
|
||||
if not isinstance(sub, int):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token: 'sub' claim must be an integer",
|
||||
)
|
||||
return sub
|
||||
|
||||
|
||||
def get_sid_from_access_token(
|
||||
access_token: Annotated[str, Depends(get_access_token)],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
) -> int:
|
||||
) -> str:
|
||||
"""
|
||||
Retrieves the session ID ('sid') from the provided access token.
|
||||
|
||||
Args:
|
||||
access_token (str): The access token from which to extract the session ID.
|
||||
token_manager (session_token_manager.TokenManager): The token manager used to validate and extract claims from the token.
|
||||
token_manager (auth_token_manager.TokenManager): The token manager used to validate and extract claims from the token.
|
||||
|
||||
Returns:
|
||||
int: The session ID ('sid') associated with the access token.
|
||||
@@ -188,8 +194,14 @@ def get_sid_from_access_token(
|
||||
Raises:
|
||||
Exception: If the token is invalid or the 'sid' claim is not present.
|
||||
"""
|
||||
# Return the user ID associated with the token
|
||||
return token_manager.get_token_claim(access_token, "sid")
|
||||
# Return the session ID associated with the token
|
||||
sid = token_manager.get_token_claim(access_token, "sid")
|
||||
if not isinstance(sid, str):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token: 'sid' claim must be a string",
|
||||
)
|
||||
return sid
|
||||
|
||||
|
||||
def get_and_return_access_token(
|
||||
@@ -213,7 +225,7 @@ def get_refresh_token(
|
||||
non_cookie_refresh_token: Annotated[Union[str, None], Depends(oauth2_scheme)],
|
||||
cookie_refresh_token: Union[str, None] = Depends(cookie_refresh_token_scheme),
|
||||
client_type: str = Depends(header_client_type_scheme),
|
||||
) -> str:
|
||||
) -> str | None:
|
||||
"""
|
||||
Retrieves the refresh token from either the Authorization header or a cookie, depending on the client type.
|
||||
|
||||
@@ -237,8 +249,8 @@ def validate_refresh_token(
|
||||
# access_token: Annotated[str, Depends(get_access_token_from_cookies)]
|
||||
refresh_token: Annotated[str, Depends(get_refresh_token)],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
) -> None:
|
||||
"""
|
||||
@@ -246,7 +258,7 @@ def validate_refresh_token(
|
||||
|
||||
Args:
|
||||
refresh_token (str): The refresh token to be validated, extracted via dependency injection.
|
||||
token_manager (session_token_manager.TokenManager): The token manager instance used to validate the token, injected via dependency.
|
||||
token_manager (auth_token_manager.TokenManager): The token manager instance used to validate the token, injected via dependency.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the refresh token is expired or invalid, or if an unexpected error occurs during validation.
|
||||
@@ -281,8 +293,8 @@ def validate_refresh_token(
|
||||
def get_sub_from_refresh_token(
|
||||
refresh_token: Annotated[str, Depends(get_refresh_token)],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
) -> int:
|
||||
"""
|
||||
@@ -290,7 +302,7 @@ def get_sub_from_refresh_token(
|
||||
|
||||
Args:
|
||||
refresh_token (str): The refresh token from which to extract the user ID.
|
||||
token_manager (session_token_manager.TokenManager): The token manager instance used to validate and parse the token.
|
||||
token_manager (auth_token_manager.TokenManager): The token manager instance used to validate and parse the token.
|
||||
|
||||
Returns:
|
||||
int: The user ID associated with the provided refresh token.
|
||||
@@ -299,14 +311,20 @@ def get_sub_from_refresh_token(
|
||||
Exception: If the token is invalid or the 'sub' claim is not found.
|
||||
"""
|
||||
# Return the user ID associated with the token
|
||||
return token_manager.get_token_claim(refresh_token, "sub")
|
||||
sub = token_manager.get_token_claim(refresh_token, "sub")
|
||||
if not isinstance(sub, int):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token: 'sub' claim must be an integer",
|
||||
)
|
||||
return sub
|
||||
|
||||
|
||||
def get_sid_from_refresh_token(
|
||||
refresh_token: Annotated[str, Depends(get_refresh_token)],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
) -> str:
|
||||
"""
|
||||
@@ -314,7 +332,7 @@ def get_sid_from_refresh_token(
|
||||
|
||||
Args:
|
||||
refresh_token (str): The refresh token from which to extract the session ID.
|
||||
token_manager (session_token_manager.TokenManager): The token manager used to validate and extract claims from the token.
|
||||
token_manager (auth_token_manager.TokenManager): The token manager used to validate and extract claims from the token.
|
||||
|
||||
Returns:
|
||||
str: The session ID associated with the provided refresh token.
|
||||
@@ -323,7 +341,13 @@ def get_sid_from_refresh_token(
|
||||
Exception: If the token is invalid or the 'sid' claim is not present.
|
||||
"""
|
||||
# Return the session ID associated with the token
|
||||
return token_manager.get_token_claim(refresh_token, "sid")
|
||||
sid = token_manager.get_token_claim(refresh_token, "sid")
|
||||
if not isinstance(sid, str):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token: 'sid' claim must be a string",
|
||||
)
|
||||
return sid
|
||||
|
||||
|
||||
def get_and_return_refresh_token(
|
||||
@@ -345,8 +369,8 @@ def get_and_return_refresh_token(
|
||||
def check_scopes(
|
||||
access_token: Annotated[str, Depends(get_access_token)],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
auth_token_manager.TokenManager,
|
||||
Depends(auth_token_manager.get_token_manager),
|
||||
],
|
||||
security_scopes: SecurityScopes,
|
||||
) -> None:
|
||||
@@ -355,7 +379,7 @@ def check_scopes(
|
||||
|
||||
Args:
|
||||
access_token (str): The access token extracted from the request.
|
||||
token_manager (session_token_manager.TokenManager): Instance responsible for managing and validating tokens.
|
||||
token_manager (auth_token_manager.TokenManager): Instance responsible for managing and validating tokens.
|
||||
security_scopes (SecurityScopes): Required scopes for the endpoint.
|
||||
|
||||
Raises:
|
||||
@@ -368,6 +392,14 @@ def check_scopes(
|
||||
# Get the scope from the token
|
||||
scope = token_manager.get_token_claim(access_token, "scope")
|
||||
|
||||
# Ensure the scope is a list
|
||||
if not isinstance(scope, list):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Unauthorized Access - Invalid scope format",
|
||||
headers={"WWW-Authenticate": f'Bearer scope="{security_scopes.scopes}"'},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use set operations to find missing scope
|
||||
missing_scopes = set(security_scopes.scopes) - set(scope)
|
||||
@@ -17,7 +17,7 @@ from joserfc.errors import (
|
||||
)
|
||||
from joserfc.jwk import OctKey
|
||||
|
||||
import session.constants as session_constants
|
||||
import auth.constants as auth_constants
|
||||
|
||||
import users.user.schema as users_schema
|
||||
|
||||
@@ -278,16 +278,16 @@ class TokenManager:
|
||||
"""
|
||||
# Check user access level and set scope accordingly
|
||||
if user.access_type == users_schema.UserAccessType.REGULAR:
|
||||
scope = session_constants.REGULAR_ACCESS_SCOPE
|
||||
scope = auth_constants.REGULAR_ACCESS_SCOPE
|
||||
else:
|
||||
scope = session_constants.ADMIN_ACCESS_SCOPE
|
||||
scope = auth_constants.ADMIN_ACCESS_SCOPE
|
||||
|
||||
exp = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=session_constants.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
minutes=auth_constants.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
if token_type == TokenType.REFRESH:
|
||||
exp = datetime.now(timezone.utc) + timedelta(
|
||||
days=session_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
days=auth_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
|
||||
# Set now
|
||||
@@ -338,6 +338,13 @@ def get_token_manager() -> TokenManager:
|
||||
return token_manager
|
||||
|
||||
|
||||
# Validate required configuration before creating token manager
|
||||
if auth_constants.JWT_SECRET_KEY is None:
|
||||
raise ValueError("JWT_SECRET_KEY must be set in environment variables")
|
||||
|
||||
if auth_constants.JWT_ALGORITHM is None:
|
||||
raise ValueError("JWT_ALGORITHM must be set in environment variables")
|
||||
|
||||
token_manager = TokenManager(
|
||||
session_constants.JWT_SECRET_KEY, session_constants.JWT_ALGORITHM
|
||||
auth_constants.JWT_SECRET_KEY, auth_constants.JWT_ALGORITHM
|
||||
)
|
||||
249
backend/app/auth/utils.py
Normal file
249
backend/app/auth/utils.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import os
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
from fastapi import (
|
||||
HTTPException,
|
||||
status,
|
||||
Response,
|
||||
Request,
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import auth.constants as auth_constants
|
||||
import session.schema as session_schema
|
||||
import session.crud as session_crud
|
||||
import session.utils as session_utils
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
import auth.token_manager as auth_token_manager
|
||||
|
||||
import users.user.crud as users_crud
|
||||
import users.user.schema as users_schema
|
||||
import users.user_identity_providers.crud as user_idp_crud
|
||||
|
||||
import auth.identity_providers.service as idp_service
|
||||
import core.logger as core_logger
|
||||
|
||||
|
||||
def authenticate_user(
|
||||
username: str,
|
||||
password: str,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> users_schema.UserRead:
|
||||
"""
|
||||
Authenticates a user by verifying the provided username and password.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user attempting to authenticate.
|
||||
password (str): The plaintext password provided by the user.
|
||||
password_hasher (auth_password_hasher.PasswordHasher): An instance of the password hasher for verifying and updating password hashes.
|
||||
db (Session): The database session used for querying and updating user data.
|
||||
|
||||
Returns:
|
||||
users_schema.UserRead: The authenticated user object if authentication is successful.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the username does not exist or the password is invalid.
|
||||
"""
|
||||
# Get the user from the database
|
||||
user = users_crud.authenticate_user(username, db)
|
||||
|
||||
# Check if the user exists and if the password is correct
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Verify password and get updated hash if applicable
|
||||
is_password_valid, updated_hash = password_hasher.verify_and_update(
|
||||
password, user.password
|
||||
)
|
||||
if not is_password_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Update user hash if applicable
|
||||
if updated_hash:
|
||||
users_crud.edit_user_password(
|
||||
user.id, updated_hash, password_hasher, db, is_hashed=True
|
||||
)
|
||||
|
||||
# Return the user if the password is correct
|
||||
return user
|
||||
|
||||
|
||||
def create_tokens(
|
||||
user: users_schema.UserRead,
|
||||
token_manager: auth_token_manager.TokenManager,
|
||||
session_id: str | None = None,
|
||||
) -> Tuple[str, datetime, str, datetime, str, str]:
|
||||
"""
|
||||
Generates session tokens for a user, including access token, refresh token, and CSRF token.
|
||||
|
||||
Args:
|
||||
user (users_schema.UserRead): The user object for whom the tokens are being created.
|
||||
token_manager (auth_token_manager.TokenManager): The token manager responsible for token creation.
|
||||
session_id (str | None, optional): An optional session ID. If not provided, a new unique session ID is generated.
|
||||
|
||||
Returns:
|
||||
Tuple[str, datetime, str, datetime, str, str]:
|
||||
A tuple containing:
|
||||
- session_id (str): The session identifier.
|
||||
- access_token_exp (datetime): Expiration datetime of the access token.
|
||||
- access_token (str): The access token string.
|
||||
- refresh_token_exp (datetime): Expiration datetime of the refresh token.
|
||||
- refresh_token (str): The refresh token string.
|
||||
- csrf_token (str): The CSRF token string.
|
||||
"""
|
||||
if session_id is None:
|
||||
# Generate a unique session ID
|
||||
session_id = str(uuid4())
|
||||
|
||||
# Create the access, refresh tokens and csrf token
|
||||
access_token_exp, access_token = token_manager.create_token(
|
||||
session_id, user, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
refresh_token_exp, refresh_token = token_manager.create_token(
|
||||
session_id, user, auth_token_manager.TokenType.REFRESH
|
||||
)
|
||||
|
||||
csrf_token = token_manager.create_csrf_token()
|
||||
|
||||
return (
|
||||
session_id,
|
||||
access_token_exp,
|
||||
access_token,
|
||||
refresh_token_exp,
|
||||
refresh_token,
|
||||
csrf_token,
|
||||
)
|
||||
|
||||
|
||||
def create_response_with_tokens(
|
||||
response: Response, access_token: str, refresh_token: str, csrf_token: str
|
||||
) -> Response:
|
||||
"""
|
||||
Sets access, refresh, and CSRF tokens as cookies on the given response object.
|
||||
|
||||
Args:
|
||||
response (Response): The response object to set cookies on.
|
||||
access_token (str): The JWT access token to be set as a cookie.
|
||||
refresh_token (str): The JWT refresh token to be set as a cookie.
|
||||
csrf_token (str): The CSRF token to be set as a cookie.
|
||||
|
||||
Returns:
|
||||
Response: The response object with the tokens set as cookies.
|
||||
"""
|
||||
secure = os.environ.get("FRONTEND_PROTOCOL") == "https"
|
||||
|
||||
# Set the cookies with the tokens
|
||||
response.set_cookie(
|
||||
key="endurain_access_token",
|
||||
value=access_token,
|
||||
expires=datetime.now(timezone.utc)
|
||||
+ timedelta(minutes=auth_constants.JWT_ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
httponly=True,
|
||||
path="/",
|
||||
secure=secure,
|
||||
samesite="lax",
|
||||
)
|
||||
response.set_cookie(
|
||||
key="endurain_refresh_token",
|
||||
value=refresh_token,
|
||||
expires=datetime.now(timezone.utc)
|
||||
+ timedelta(days=auth_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
|
||||
httponly=True,
|
||||
path="/",
|
||||
secure=secure,
|
||||
samesite="lax",
|
||||
)
|
||||
response.set_cookie(
|
||||
key="endurain_csrf_token",
|
||||
value=csrf_token,
|
||||
httponly=False,
|
||||
path="/",
|
||||
secure=secure,
|
||||
samesite="lax",
|
||||
)
|
||||
|
||||
# Return the response
|
||||
return response
|
||||
|
||||
|
||||
def complete_login(
|
||||
response: Response,
|
||||
request: Request,
|
||||
user: users_schema.UserRead,
|
||||
client_type: str,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
token_manager: auth_token_manager.TokenManager,
|
||||
db: Session,
|
||||
) -> dict | str:
|
||||
"""
|
||||
Handles the completion of the login process by generating session and authentication tokens,
|
||||
storing the session in the database, and returning appropriate responses based on client type.
|
||||
|
||||
Args:
|
||||
response (Response): The HTTP response object to set cookies for web clients.
|
||||
request (Request): The HTTP request object containing client information.
|
||||
user (users_schema.UserRead): The authenticated user object.
|
||||
client_type (str): The type of client ("web" or "mobile").
|
||||
password_hasher (auth_password_hasher.PasswordHasher): Utility for password hashing.
|
||||
token_manager (auth_token_manager.TokenManager): Utility for token generation and management.
|
||||
db (Session): Database session for storing session information.
|
||||
|
||||
Returns:
|
||||
dict | str: For web clients, returns the session ID as a string.
|
||||
For mobile clients, returns a dictionary containing tokens and session info.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the client type is invalid, raises a 403 Forbidden error.
|
||||
"""
|
||||
# Create the tokens
|
||||
(
|
||||
session_id,
|
||||
access_token_exp,
|
||||
access_token,
|
||||
_refresh_token_exp,
|
||||
refresh_token,
|
||||
csrf_token,
|
||||
) = create_tokens(user, token_manager)
|
||||
|
||||
# Create the session and store it in the database
|
||||
session_utils.create_session(
|
||||
session_id, user, request, refresh_token, password_hasher, db
|
||||
)
|
||||
|
||||
if client_type == "web":
|
||||
# Set response cookies with tokens
|
||||
create_response_with_tokens(response, access_token, refresh_token, csrf_token)
|
||||
|
||||
# Return the session_id
|
||||
return {
|
||||
"session_id": session_id,
|
||||
}
|
||||
if client_type == "mobile":
|
||||
# Return the tokens directly (no cookies for mobile)
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"session_id": session_id,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int(access_token_exp.timestamp()),
|
||||
}
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid client type",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
75
backend/app/core/middleware.py
Normal file
75
backend/app/core/middleware.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from fastapi import Request, HTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware for CSRF protection in FastAPI applications.
|
||||
|
||||
This middleware checks for a valid CSRF token in requests from web clients to prevent cross-site request forgery attacks.
|
||||
It exempts specific API paths from CSRF checks and only enforces validation for POST, PUT, DELETE, and PATCH requests.
|
||||
|
||||
Attributes:
|
||||
exempt_paths (list): List of URL paths that are exempt from CSRF protection.
|
||||
|
||||
Methods:
|
||||
dispatch(request, call_next):
|
||||
Processes incoming requests, enforcing CSRF checks for web clients on non-exempt paths and applicable HTTP methods.
|
||||
Raises HTTPException with status code 403 if CSRF token is missing or invalid.
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
# Define paths that don't need CSRF protection
|
||||
self.exempt_paths = [
|
||||
"/api/v1/token",
|
||||
"/api/v1/refresh",
|
||||
"/api/v1/mfa/verify",
|
||||
"/api/v1/password-reset/request",
|
||||
"/api/v1/password-reset/confirm",
|
||||
"/api/v1/sign-up/request",
|
||||
"/api/v1/sign-up/confirm",
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
Middleware method to enforce CSRF protection for web clients.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming HTTP request.
|
||||
call_next (Callable): The next middleware or endpoint handler.
|
||||
|
||||
Returns:
|
||||
Response: The HTTP response after CSRF validation.
|
||||
|
||||
Behavior:
|
||||
- Skips CSRF checks for non-web clients (determined by "X-Client-Type" header).
|
||||
- Skips CSRF checks for exempt paths.
|
||||
- For web clients and non-exempt paths, validates CSRF token for POST, PUT, DELETE, and PATCH requests:
|
||||
- Requires both "endurain_csrf_token" cookie and "X-CSRF-Token" header.
|
||||
- Raises HTTPException 403 if tokens are missing or do not match.
|
||||
"""
|
||||
# Get client type from header
|
||||
client_type = request.headers.get("X-Client-Type")
|
||||
|
||||
# Skip CSRF checks for not web clients
|
||||
if client_type != "web":
|
||||
return await call_next(request)
|
||||
|
||||
# Skip CSRF check for exempt paths
|
||||
if request.url.path in self.exempt_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Check for CSRF token in POST, PUT, DELETE, and PATCH requests
|
||||
if request.method in ["POST", "PUT", "DELETE", "PATCH"]:
|
||||
csrf_cookie = request.cookies.get("endurain_csrf_token")
|
||||
csrf_header = request.headers.get("X-CSRF-Token")
|
||||
|
||||
if not csrf_cookie or not csrf_header:
|
||||
raise HTTPException(status_code=403, detail="CSRF token missing")
|
||||
|
||||
if csrf_cookie != csrf_header:
|
||||
raise HTTPException(status_code=403, detail="CSRF token invalid")
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
@@ -15,6 +15,7 @@ import activities.activity_streams.public_router as activity_streams_public_rout
|
||||
import activities.activity_summaries.router as activity_summaries_router
|
||||
import activities.activity_workout_steps.router as activity_workout_steps_router
|
||||
import activities.activity_workout_steps.public_router as activity_workout_steps_public_router
|
||||
import auth.router as auth_router
|
||||
import core.config as core_config
|
||||
import core.router as core_router
|
||||
import followers.router as followers_router
|
||||
@@ -23,15 +24,15 @@ import gears.gear.router as gears_router
|
||||
import gears.gear_components.router as gear_components_router
|
||||
import health_data.router as health_data_router
|
||||
import health_targets.router as health_targets_router
|
||||
import identity_providers.router as identity_providers_router
|
||||
import identity_providers.public_router as identity_providers_public_router
|
||||
import auth.identity_providers.router as identity_providers_router
|
||||
import auth.identity_providers.public_router as identity_providers_public_router
|
||||
import notifications.router as notifications_router
|
||||
import password_reset_tokens.router as password_reset_tokens_router
|
||||
import profile.router as profile_router
|
||||
import server_settings.public_router as server_settings_public_router
|
||||
import server_settings.router as server_settings_router
|
||||
import session.router as session_router
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
import sign_up_tokens.router as sign_up_tokens_router
|
||||
import strava.router as strava_router
|
||||
import users.user.router as users_router
|
||||
@@ -49,88 +50,93 @@ router.include_router(
|
||||
activities_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/activities",
|
||||
tags=["activities"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
activity_exercise_titles_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/activities_exercise_titles",
|
||||
tags=["activity_exercise_titles"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
activity_laps_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/activities_laps",
|
||||
tags=["activity_laps"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
activity_media_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/activities_media",
|
||||
tags=["activity_media"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
activity_sets_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/activities_sets",
|
||||
tags=["activity_sets"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
activity_streams_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/activities_streams",
|
||||
tags=["activity_streams"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
activity_workout_steps_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/activities_workout_steps",
|
||||
tags=["activity_workout_steps"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
activity_summaries_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/activities_summaries",
|
||||
tags=["summaries"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
auth_router.router,
|
||||
prefix=core_config.ROOT_PATH,
|
||||
tags=["auth"],
|
||||
)
|
||||
router.include_router(
|
||||
followers_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/followers",
|
||||
tags=["followers"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
garmin_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/garminconnect",
|
||||
tags=["garminconnect"],
|
||||
dependencies=[
|
||||
Depends(session_security.validate_access_token),
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Depends(auth_security.validate_access_token),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
)
|
||||
router.include_router(
|
||||
gear_components_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/gear_components",
|
||||
tags=["gear_components"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
gears_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/gears",
|
||||
tags=["gears"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
health_data_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/health",
|
||||
tags=["health"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
health_targets_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/health_targets",
|
||||
tags=["health_targets"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
identity_providers_router.router,
|
||||
@@ -142,8 +148,8 @@ router.include_router(
|
||||
prefix=core_config.ROOT_PATH + "/notifications",
|
||||
tags=["notifications"],
|
||||
dependencies=[
|
||||
Depends(session_security.validate_access_token),
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Depends(auth_security.validate_access_token),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
)
|
||||
router.include_router(
|
||||
@@ -156,20 +162,21 @@ router.include_router(
|
||||
prefix=core_config.ROOT_PATH + "/profile",
|
||||
tags=["profile"],
|
||||
dependencies=[
|
||||
Depends(session_security.validate_access_token),
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Depends(auth_security.validate_access_token),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
)
|
||||
router.include_router(
|
||||
server_settings_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/server_settings",
|
||||
tags=["server_settings"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
session_router.router,
|
||||
prefix=core_config.ROOT_PATH,
|
||||
prefix=core_config.ROOT_PATH + "/sessions",
|
||||
tags=["sessions"],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
sign_up_tokens_router.router,
|
||||
@@ -186,8 +193,8 @@ router.include_router(
|
||||
prefix=core_config.ROOT_PATH + "/profile/default_gear",
|
||||
tags=["profile"],
|
||||
dependencies=[
|
||||
Depends(session_security.validate_access_token),
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Depends(auth_security.validate_access_token),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
)
|
||||
router.include_router(
|
||||
@@ -195,29 +202,29 @@ router.include_router(
|
||||
prefix=core_config.ROOT_PATH + "/profile/goals",
|
||||
tags=["profile"],
|
||||
dependencies=[
|
||||
Depends(session_security.validate_access_token),
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Depends(auth_security.validate_access_token),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
)
|
||||
router.include_router(
|
||||
user_identity_providers_router.router,
|
||||
prefix=core_config.ROOT_PATH,
|
||||
tags=["user_identity_providers"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
users_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/users",
|
||||
tags=["users"],
|
||||
dependencies=[Depends(session_security.validate_access_token)],
|
||||
dependencies=[Depends(auth_security.validate_access_token)],
|
||||
)
|
||||
router.include_router(
|
||||
websocket_router.router,
|
||||
prefix=core_config.ROOT_PATH + "/ws",
|
||||
tags=["websocket"],
|
||||
# dependencies=[
|
||||
# Depends(session_security.validate_access_token),
|
||||
# Security(session_security.check_scopes, scopes=["profile"]),
|
||||
# Depends(auth_security.validate_access_token),
|
||||
# Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
# ],
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import followers.crud as followers_crud
|
||||
|
||||
import users.user.dependencies as users_dependencies
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
@@ -28,7 +28,7 @@ async def get_user_follower_all(
|
||||
user_id: int,
|
||||
validate_user_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -47,7 +47,7 @@ async def get_user_follower_count_all(
|
||||
user_id: int,
|
||||
validate_user_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -73,7 +73,7 @@ async def get_user_follower_count(
|
||||
user_id: int,
|
||||
validate_user_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -99,7 +99,7 @@ async def get_user_following_all(
|
||||
user_id: int,
|
||||
validate_user_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -118,7 +118,7 @@ async def get_user_following_count_all(
|
||||
user_id: int,
|
||||
validate_user_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -144,7 +144,7 @@ async def get_user_following_count(
|
||||
user_id: int,
|
||||
validate_user_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -174,7 +174,7 @@ async def read_followers_user_specific_user(
|
||||
Callable, Depends(users_dependencies.validate_target_user_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -201,10 +201,10 @@ async def create_follow(
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["profile"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["profile"])
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
@@ -233,10 +233,10 @@ async def accept_follow(
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["profile"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["profile"])
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
@@ -268,10 +268,10 @@ async def delete_follower(
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["profile"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["profile"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -297,10 +297,10 @@ async def delete_following(
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["profile"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["profile"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime, timezone, date
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import users.user_integrations.crud as user_integrations_crud
|
||||
|
||||
@@ -31,7 +31,7 @@ async def garminconnect_link(
|
||||
garmin_user: garmin_schema.GarminLogin,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
mfa_codes: Annotated[
|
||||
@@ -61,7 +61,7 @@ async def garminconnect_mfa_code(
|
||||
mfa_request: garmin_schema.MFARequest,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
mfa_codes: Annotated[
|
||||
garmin_schema.MFACodeStore, Depends(garmin_schema.get_mfa_store)
|
||||
@@ -81,7 +81,7 @@ async def garminconnect_retrieve_activities_days(
|
||||
end_date: date,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
websocket_manager: Annotated[
|
||||
@@ -90,7 +90,9 @@ async def garminconnect_retrieve_activities_days(
|
||||
],
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
start_datetime = datetime.combine(start_date, datetime.min.time(), tzinfo=timezone.utc)
|
||||
start_datetime = datetime.combine(
|
||||
start_date, datetime.min.time(), tzinfo=timezone.utc
|
||||
)
|
||||
end_datetime = datetime.combine(end_date, datetime.max.time(), tzinfo=timezone.utc)
|
||||
|
||||
# Process Garmin Connect activities in the background
|
||||
@@ -116,7 +118,7 @@ async def garminconnect_retrieve_activities_days(
|
||||
async def garminconnect_retrieve_gear(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
@@ -144,12 +146,14 @@ async def garminconnect_retrieve_health_days(
|
||||
end_date: date,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
# db: Annotated[Session, Depends(core_database.get_db)],
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
start_datetime = datetime.combine(start_date, datetime.min.time(), tzinfo=timezone.utc)
|
||||
start_datetime = datetime.combine(
|
||||
start_date, datetime.min.time(), tzinfo=timezone.utc
|
||||
)
|
||||
end_datetime = datetime.combine(end_date, datetime.max.time(), tzinfo=timezone.utc)
|
||||
|
||||
# Process Garmin Connect activities in the background
|
||||
@@ -173,7 +177,7 @@ async def garminconnect_retrieve_health_days(
|
||||
async def garminconnect_unlink(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Annotated, Callable
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import gears.gear.schema as gears_schema
|
||||
import gears.gear.crud as gears_crud
|
||||
@@ -21,11 +21,9 @@ router = APIRouter()
|
||||
)
|
||||
async def read_gears(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
# Return the gear
|
||||
@@ -40,11 +38,9 @@ async def read_gear_id(
|
||||
gear_id: int,
|
||||
validate_gear_id: Annotated[Callable, Depends(gears_dependencies.validate_gear_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
# Return the gear
|
||||
@@ -59,11 +55,9 @@ async def read_gear_user_pagination(
|
||||
page_number: int,
|
||||
num_records: int,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -81,11 +75,9 @@ async def read_gear_user_pagination(
|
||||
)
|
||||
async def read_gear_user_number(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -109,11 +101,9 @@ async def read_gear_user_number(
|
||||
async def read_gear_user_contains_nickname(
|
||||
nickname: str,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -130,11 +120,9 @@ async def read_gear_user_contains_nickname(
|
||||
async def read_gear_user_by_nickname(
|
||||
nickname: str,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -152,11 +140,9 @@ async def read_gear_user_by_type(
|
||||
gear_type: int,
|
||||
validate_type: Annotated[Callable, Depends(gears_dependencies.validate_gear_type)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -174,11 +160,9 @@ async def read_gear_user_by_type(
|
||||
async def create_gear(
|
||||
gear: gears_schema.Gear,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -194,11 +178,9 @@ async def edit_gear(
|
||||
validate_id: Annotated[Callable, Depends(gears_dependencies.validate_gear_id)],
|
||||
gear: gears_schema.Gear,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -232,11 +214,9 @@ async def delete_gear(
|
||||
gear_id: int,
|
||||
validate_id: Annotated[Callable, Depends(gears_dependencies.validate_gear_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
|
||||
@@ -3,7 +3,7 @@ from urllib.parse import unquote
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
import core.database as core_database
|
||||
|
||||
import gears.gear.models as gears_models
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Annotated, Callable
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import gears.gear_components.schema as gears_components_schema
|
||||
import gears.gear_components.crud as gears_components_crud
|
||||
@@ -22,11 +22,9 @@ router = APIRouter()
|
||||
)
|
||||
async def read_gear_components(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
# Return the gear_components
|
||||
@@ -41,11 +39,9 @@ async def read_gear_components_gear_id(
|
||||
gear_id: int,
|
||||
validate_gear_id: Annotated[Callable, Depends(gears_dependencies.validate_gear_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:read"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
# Return the gear
|
||||
@@ -62,14 +58,12 @@ async def read_gear_components_gear_id(
|
||||
async def create_gear_component(
|
||||
gear_component: gears_components_schema.GearComponents,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
verify_gear_type: Annotated[
|
||||
Callable, Security(gears_components_dependencies.validate_gear_component_type)
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -85,11 +79,9 @@ async def create_gear_component(
|
||||
async def edit_gear_component(
|
||||
gear_component: gears_components_schema.GearComponents,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -139,11 +131,9 @@ async def delete_component_gear(
|
||||
Callable, Depends(gears_components_dependencies.validate_gear_component_id)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
Callable, Security(auth_security.check_scopes, scopes=["gears:write"])
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
|
||||
import health_data.schema as health_data_schema
|
||||
import health_data.crud as health_data_crud
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
import core.dependencies as core_dependencies
|
||||
@@ -23,11 +23,11 @@ router = APIRouter()
|
||||
)
|
||||
async def read_health_data_number(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["health:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["health:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -44,11 +44,11 @@ async def read_health_data_number(
|
||||
)
|
||||
async def read_health_data_all(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["health:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["health:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -67,14 +67,14 @@ async def read_health_data_all_pagination(
|
||||
page_number: int,
|
||||
num_records: int,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["health:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["health:read"])
|
||||
],
|
||||
validate_pagination_values: Annotated[
|
||||
Callable, Depends(core_dependencies.validate_pagination_values)
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -91,11 +91,11 @@ async def read_health_data_all_pagination(
|
||||
async def create_health_data(
|
||||
health_data: health_data_schema.HealthData,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["health:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["health:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -120,11 +120,11 @@ async def create_health_data(
|
||||
async def edit_health_data(
|
||||
health_data: health_data_schema.HealthData,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["health:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["health:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -139,11 +139,11 @@ async def edit_health_data(
|
||||
async def delete_health_data(
|
||||
health_data_id: int,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["health:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["health:write"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
import health_targets.schema as health_targets_schema
|
||||
import health_targets.crud as health_targets_crud
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
@@ -20,11 +20,11 @@ router = APIRouter()
|
||||
)
|
||||
async def read_health_data_all_pagination(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["health:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["health:read"])
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Identity Provider utility functions and templates"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
import identity_providers.schema as idp_schema
|
||||
|
||||
|
||||
# Pre-configured templates for common IdPs
|
||||
IDP_TEMPLATES = {
|
||||
"keycloak": {
|
||||
"name": "Keycloak",
|
||||
"provider_type": "oidc",
|
||||
"issuer_url": "https://{your-keycloak-domain}/realms/{realm}",
|
||||
"scopes": "openid profile email",
|
||||
"icon": "keycloak",
|
||||
"user_mapping": {
|
||||
"username": ["preferred_username", "username", "email"],
|
||||
"email": ["email", "mail"],
|
||||
"name": ["name", "display_name", "full_name"],
|
||||
},
|
||||
"description": "Keycloak - Open Source Identity and Access Management",
|
||||
"configuration_notes": "Replace {your-keycloak-domain} with your Keycloak server domain (e.g., keycloak.example.com) and {realm} with your realm name. Create an OIDC client in Keycloak admin console.",
|
||||
},
|
||||
"authentik": {
|
||||
"name": "Authentik",
|
||||
"provider_type": "oidc",
|
||||
"issuer_url": "https://{your-authentik-domain}/application/o/{slug}/",
|
||||
"scopes": "openid profile email",
|
||||
"icon": "authentik",
|
||||
"user_mapping": {
|
||||
"username": ["preferred_username", "username", "email"],
|
||||
"email": ["email", "mail"],
|
||||
"name": ["name", "display_name"],
|
||||
},
|
||||
"description": "Authentik - Open-source Identity Provider",
|
||||
"configuration_notes": "Replace {your-authentik-domain} with your Authentik server domain (e.g., authentik.example.com) and {slug} with your application slug. Create an OAuth2/OIDC provider in Authentik.",
|
||||
},
|
||||
"authelia": {
|
||||
"name": "Authelia",
|
||||
"provider_type": "oidc",
|
||||
"issuer_url": "https://{your-authelia-domain}",
|
||||
"scopes": "openid profile email",
|
||||
"icon": "authelia",
|
||||
"user_mapping": {
|
||||
"username": ["preferred_username", "username", "email"],
|
||||
"email": ["email"],
|
||||
"name": ["name"],
|
||||
},
|
||||
"description": "Authelia - Open-source authentication and authorization server",
|
||||
"configuration_notes": "Replace {your-authelia-domain} with your Authelia server domain (e.g., auth.example.com). Configure an OIDC client in your Authelia configuration file.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_idp_templates() -> List[idp_schema.IdentityProviderTemplate]:
|
||||
"""
|
||||
Retrieve a list of identity provider templates, excluding specific providers.
|
||||
|
||||
Returns:
|
||||
List[idp_schema.IdentityProviderTemplate]:
|
||||
A list of IdentityProviderTemplate objects for all identity providers.
|
||||
"""
|
||||
templates = []
|
||||
for template_id, template_data in IDP_TEMPLATES.items():
|
||||
templates.append(
|
||||
idp_schema.IdentityProviderTemplate(
|
||||
template_id=template_id, **template_data
|
||||
)
|
||||
)
|
||||
return templates
|
||||
|
||||
|
||||
def get_idp_template(template_id: str) -> Dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve an identity provider template by its template ID.
|
||||
|
||||
Args:
|
||||
template_id (str): The unique identifier of the identity provider template.
|
||||
|
||||
Returns:
|
||||
dict[str, Any] | None: The template dictionary if found, otherwise None.
|
||||
"""
|
||||
return IDP_TEMPLATES.get(template_id)
|
||||
@@ -12,14 +12,13 @@ import core.logger as core_logger
|
||||
import core.config as core_config
|
||||
import core.scheduler as core_scheduler
|
||||
import core.tracing as core_tracing
|
||||
import core.middleware as core_middleware
|
||||
import core.migrations as core_migrations
|
||||
import core.rate_limit as core_rate_limit
|
||||
|
||||
import garmin.activity_utils as garmin_activity_utils
|
||||
import garmin.health_utils as garmin_health_utils
|
||||
|
||||
import session.schema as session_schema
|
||||
|
||||
import strava.activity_utils as strava_activity_utils
|
||||
import strava.utils as strava_utils
|
||||
|
||||
@@ -131,13 +130,12 @@ def create_app() -> FastAPI:
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
fastapi_app.add_middleware(session_schema.CSRFMiddleware)
|
||||
fastapi_app.add_middleware(core_middleware.CSRFMiddleware)
|
||||
|
||||
# Add rate limiting
|
||||
fastapi_app.state.limiter = core_rate_limit.limiter
|
||||
fastapi_app.add_exception_handler(
|
||||
core_rate_limit.RateLimitExceeded,
|
||||
core_rate_limit.rate_limit_exceeded_handler
|
||||
core_rate_limit.RateLimitExceeded, core_rate_limit.rate_limit_exceeded_handler
|
||||
)
|
||||
|
||||
# Router files
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Annotated, Callable
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import notifications.dependencies as notifications_dependencies
|
||||
import notifications.crud as notifications_crud
|
||||
@@ -21,9 +21,7 @@ router = APIRouter()
|
||||
response_model=int,
|
||||
)
|
||||
async def read_notifications_number(
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -54,9 +52,7 @@ async def read_notifications_by_id(
|
||||
validate_notification_id: Annotated[
|
||||
Callable, Depends(notifications_dependencies.validate_notification_id)
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -88,9 +84,7 @@ async def read_notifications_user_pagination(
|
||||
validate_pagination_values: Annotated[
|
||||
Callable, Depends(core_dependencies.validate_pagination_values)
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -122,9 +116,7 @@ async def mark_notification_as_read(
|
||||
validate_notification_id: Annotated[
|
||||
Callable, Depends(notifications_dependencies.validate_notification_id)
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
|
||||
import password_reset_tokens.schema as password_reset_tokens_schema
|
||||
import password_reset_tokens.utils as password_reset_tokens_utils
|
||||
|
||||
import session.password_hasher as session_password_hasher
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
import core.database as core_database
|
||||
import core.apprise as core_apprise
|
||||
@@ -92,8 +92,8 @@ async def request_password_reset(
|
||||
async def confirm_password_reset(
|
||||
confirm_data: password_reset_tokens_schema.PasswordResetConfirm,
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -104,11 +104,11 @@ async def confirm_password_reset(
|
||||
Confirms a password reset using the provided token and new password.
|
||||
|
||||
Args:
|
||||
confirm_data (password_reset_tokens_schema.PasswordResetConfirm):
|
||||
confirm_data (password_reset_tokens_schema.PasswordResetConfirm):
|
||||
Data containing the password reset token and the new password.
|
||||
password_hasher (session_password_hasher.PasswordHasher):
|
||||
password_hasher (auth_password_hasher.PasswordHasher):
|
||||
An instance of the password hasher to use for hashing the new password.
|
||||
db (Session):
|
||||
db (Session):
|
||||
Database session dependency.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -14,7 +14,7 @@ import password_reset_tokens.crud as password_reset_tokens_crud
|
||||
|
||||
import users.user.crud as users_crud
|
||||
|
||||
import session.password_hasher as session_password_hasher
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
import core.apprise as core_apprise
|
||||
import core.logger as core_logger
|
||||
@@ -176,7 +176,7 @@ async def send_password_reset_email(
|
||||
def use_password_reset_token(
|
||||
token: str,
|
||||
new_password: str,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
):
|
||||
"""
|
||||
@@ -195,7 +195,7 @@ def use_password_reset_token(
|
||||
function will hash it before database lookup.
|
||||
- new_password (str): The new plain-text password to set for the user. Password
|
||||
validation/hashing is expected to be handled by the underlying users_crud.
|
||||
- password_hasher (session_password_hasher.PasswordHasher): An instance of the
|
||||
- password_hasher (auth_password_hasher.PasswordHasher): An instance of the
|
||||
password hasher to use when updating the user's password.
|
||||
- db (Session): An active SQLAlchemy Session (or equivalent) used for DB operations.
|
||||
Transaction management (commit/rollback) is expected to be handled by the caller
|
||||
|
||||
@@ -28,9 +28,11 @@ import gears.gear.crud as gear_crud
|
||||
import gears.gear_components.crud as gear_components_crud
|
||||
import health_data.crud as health_data_crud
|
||||
import health_targets.crud as health_targets_crud
|
||||
import notifications.crud as notifications_crud
|
||||
import users.user_default_gear.crud as user_default_gear_crud
|
||||
import users.user_integrations.crud as user_integrations_crud
|
||||
import users.user_goals.crud as user_goals_crud
|
||||
import users.user_identity_providers.crud as user_identity_providers_crud
|
||||
import users.user_integrations.crud as user_integrations_crud
|
||||
import users.user_privacy_settings.crud as users_privacy_settings_crud
|
||||
|
||||
|
||||
@@ -721,6 +723,41 @@ class ExportService:
|
||||
f"Failed to collect health data: {err}"
|
||||
) from err
|
||||
|
||||
def collect_notifications_data(self, zipf: zipfile.ZipFile) -> None:
|
||||
try:
|
||||
try:
|
||||
notifications = notifications_crud.get_user_notifications(
|
||||
self.user_id, self.db
|
||||
)
|
||||
if notifications:
|
||||
notifications_dicts = [
|
||||
profile_utils.sqlalchemy_obj_to_dict(n) for n in notifications
|
||||
]
|
||||
profile_utils.write_json_to_zip(
|
||||
zipf,
|
||||
"data/notifications.json",
|
||||
notifications_dicts,
|
||||
self.counts,
|
||||
)
|
||||
else:
|
||||
profile_utils.write_json_to_zip(
|
||||
zipf, "data/notifications.json", [], self.counts
|
||||
)
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Failed to collect user notifications: {err}", "warning", exc=err
|
||||
)
|
||||
profile_utils.write_json_to_zip(
|
||||
zipf, "data/notifications.json", [], self.counts
|
||||
)
|
||||
except SQLAlchemyError as err:
|
||||
core_logger.print_to_log(
|
||||
f"Database error collecting user notifications: {err}", "error", exc=err
|
||||
)
|
||||
raise DatabaseConnectionError(
|
||||
f"Failed to collect user notifications: {err}"
|
||||
) from err
|
||||
|
||||
def collect_user_settings_data(self, zipf: zipfile.ZipFile) -> None:
|
||||
"""
|
||||
Collect and write user settings to ZIP.
|
||||
@@ -785,6 +822,38 @@ class ExportService:
|
||||
zipf, "data/user_goals.json", [], self.counts
|
||||
)
|
||||
|
||||
# Collect and write user identity providers
|
||||
try:
|
||||
user_identity_providers = (
|
||||
user_identity_providers_crud.get_user_identity_providers_by_user_id(
|
||||
self.user_id, self.db
|
||||
)
|
||||
)
|
||||
if user_identity_providers:
|
||||
identity_providers_dict = [
|
||||
profile_utils.sqlalchemy_obj_to_dict(uidp)
|
||||
for uidp in user_identity_providers
|
||||
]
|
||||
profile_utils.write_json_to_zip(
|
||||
zipf,
|
||||
"data/user_identity_providers.json",
|
||||
identity_providers_dict,
|
||||
self.counts,
|
||||
)
|
||||
else:
|
||||
profile_utils.write_json_to_zip(
|
||||
zipf, "data/user_identity_providers.json", [], self.counts
|
||||
)
|
||||
except Exception as err:
|
||||
core_logger.print_to_log(
|
||||
f"Failed to collect user identity providers: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
profile_utils.write_json_to_zip(
|
||||
zipf, "data/user_identity_providers.json", [], self.counts
|
||||
)
|
||||
|
||||
# Collect and write user integrations
|
||||
try:
|
||||
user_integrations = (
|
||||
@@ -1168,6 +1237,15 @@ class ExportService:
|
||||
)
|
||||
self.collect_health_data(zipf)
|
||||
|
||||
# Collect and write notifications progressively
|
||||
profile_utils.check_timeout(
|
||||
timeout_seconds, start_time, ExportTimeoutError, "Export"
|
||||
)
|
||||
core_logger.print_to_log(
|
||||
"Collecting and writing notifications data...", "info"
|
||||
)
|
||||
self.collect_notifications_data(zipf)
|
||||
|
||||
# Collect and write settings data progressively
|
||||
profile_utils.check_timeout(
|
||||
timeout_seconds, start_time, ExportTimeoutError, "Export"
|
||||
|
||||
@@ -23,15 +23,18 @@ import profile.utils as profile_utils
|
||||
import users.user.crud as users_crud
|
||||
import users.user.schema as users_schema
|
||||
|
||||
import users.user_integrations.crud as user_integrations_crud
|
||||
import users.user_integrations.schema as users_integrations_schema
|
||||
|
||||
import users.user_default_gear.crud as user_default_gear_crud
|
||||
import users.user_default_gear.schema as user_default_gear_schema
|
||||
|
||||
import users.user_goals.crud as user_goals_crud
|
||||
import users.user_goals.schema as user_goals_schema
|
||||
|
||||
import users.user_identity_providers.crud as user_identity_providers_crud
|
||||
import users.user_identity_providers.schema as user_identity_providers_schema
|
||||
|
||||
import users.user_integrations.crud as user_integrations_crud
|
||||
import users.user_integrations.schema as users_integrations_schema
|
||||
|
||||
import users.user_privacy_settings.crud as users_privacy_settings_crud
|
||||
import users.user_privacy_settings.schema as users_privacy_settings_schema
|
||||
|
||||
@@ -61,6 +64,9 @@ import gears.gear.schema as gear_schema
|
||||
import gears.gear_components.crud as gear_components_crud
|
||||
import gears.gear_components.schema as gear_components_schema
|
||||
|
||||
import notifications.crud as notifications_crud
|
||||
import notifications.schema as notifications_schema
|
||||
|
||||
import health_data.crud as health_data_crud
|
||||
import health_data.schema as health_data_schema
|
||||
|
||||
@@ -252,10 +258,13 @@ class ImportService:
|
||||
user_default_gear_data = self._load_single_json(
|
||||
zipf, "data/user_default_gear.json"
|
||||
)
|
||||
user_goals_data = self._load_single_json(zipf, "data/user_goals.json")
|
||||
user_identity_providers_data = self._load_single_json(
|
||||
zipf, "data/user_identity_providers.json"
|
||||
)
|
||||
user_integrations_data = self._load_single_json(
|
||||
zipf, "data/user_integrations.json"
|
||||
)
|
||||
user_goals_data = self._load_single_json(zipf, "data/user_goals.json")
|
||||
user_privacy_settings_data = self._load_single_json(
|
||||
zipf, "data/user_privacy_settings.json"
|
||||
)
|
||||
@@ -263,8 +272,9 @@ class ImportService:
|
||||
await self.collect_and_import_user_data(
|
||||
user_data,
|
||||
user_default_gear_data,
|
||||
user_integrations_data,
|
||||
user_goals_data,
|
||||
user_identity_providers_data,
|
||||
user_integrations_data,
|
||||
user_privacy_settings_data,
|
||||
gears_id_mapping,
|
||||
)
|
||||
@@ -287,6 +297,17 @@ class ImportService:
|
||||
)
|
||||
)
|
||||
|
||||
# Load and import notifications
|
||||
profile_utils.check_timeout(
|
||||
timeout_seconds, start_time, ImportTimeoutError, "Import"
|
||||
)
|
||||
notifications_data = self._load_single_json(
|
||||
zipf, "data/notifications.json"
|
||||
)
|
||||
|
||||
await self.collect_and_import_notifications_data(notifications_data)
|
||||
del notifications_data
|
||||
|
||||
# Load and import health data
|
||||
profile_utils.check_timeout(
|
||||
timeout_seconds, start_time, ImportTimeoutError, "Import"
|
||||
@@ -432,8 +453,9 @@ class ImportService:
|
||||
self,
|
||||
user_data: list[Any],
|
||||
user_default_gear_data: list[Any],
|
||||
user_integrations_data: list[Any],
|
||||
user_goals_data: list[Any],
|
||||
user_identity_providers_data: list[Any],
|
||||
user_integrations_data: list[Any],
|
||||
user_privacy_settings_data: list[Any],
|
||||
gears_id_mapping: dict[int, int],
|
||||
) -> None:
|
||||
@@ -443,8 +465,9 @@ class ImportService:
|
||||
Args:
|
||||
user_data: User profile data.
|
||||
user_default_gear_data: Default gear settings.
|
||||
user_integrations_data: Integration settings.
|
||||
user_goals_data: User goals data.
|
||||
user_identity_providers_data: Identity providers data.
|
||||
user_integrations_data: Integration settings.
|
||||
user_privacy_settings_data: Privacy settings.
|
||||
gears_id_mapping: Mapping of old to new gear IDs.
|
||||
"""
|
||||
@@ -470,8 +493,11 @@ class ImportService:
|
||||
await self.collect_and_import_user_default_gear(
|
||||
user_default_gear_data, gears_id_mapping
|
||||
)
|
||||
await self.collect_and_import_user_integrations(user_integrations_data)
|
||||
await self.collect_and_import_user_goals(user_goals_data)
|
||||
await self.collect_and_import_user_identity_providers(
|
||||
user_identity_providers_data
|
||||
)
|
||||
await self.collect_and_import_user_integrations(user_integrations_data)
|
||||
await self.collect_and_import_user_privacy_settings(user_privacy_settings_data)
|
||||
|
||||
async def collect_and_import_user_default_gear(
|
||||
@@ -585,6 +611,29 @@ class ImportService:
|
||||
f"Imported {self.counts['user_goals']} user goals", "info"
|
||||
)
|
||||
|
||||
async def collect_and_import_user_identity_providers(
|
||||
self, user_identity_providers_data: list[Any]
|
||||
) -> None:
|
||||
if not user_identity_providers_data:
|
||||
core_logger.print_to_log(
|
||||
"No user identity providers data to import", "info"
|
||||
)
|
||||
return
|
||||
|
||||
for provider_data in user_identity_providers_data:
|
||||
provider_data.pop("id", None)
|
||||
provider_data.pop("user_id", None)
|
||||
|
||||
user_identity_providers_crud.create_user_identity_provider(
|
||||
self.user_id, provider_data.id, provider_data.idp_subject, self.db
|
||||
)
|
||||
self.counts["user_identity_providers"] += 1
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Imported {self.counts['user_identity_providers']} user identity providers",
|
||||
"info",
|
||||
)
|
||||
|
||||
async def collect_and_import_user_privacy_settings(
|
||||
self, user_privacy_settings_data: list[Any]
|
||||
) -> None:
|
||||
@@ -987,6 +1036,25 @@ class ImportService:
|
||||
|
||||
return all_components
|
||||
|
||||
async def collect_and_import_notifications_data(
|
||||
self, notifications_data: list[Any]
|
||||
) -> None:
|
||||
if not notifications_data:
|
||||
core_logger.print_to_log("No notifications data to import", "info")
|
||||
return
|
||||
|
||||
for notification_data in notifications_data:
|
||||
notification_data["user_id"] = self.user_id
|
||||
notification_data.pop("id", None)
|
||||
|
||||
notification = notifications_schema.Notification(**notification_data)
|
||||
notifications_crud.create_notification(notification, self.db)
|
||||
self.counts["notifications"] += 1
|
||||
|
||||
core_logger.print_to_log(
|
||||
f"Imported {self.counts['notifications']} notifications", "info"
|
||||
)
|
||||
|
||||
async def collect_and_import_health_data(
|
||||
self, health_data_data: list[Any], health_targets_data: list[Any]
|
||||
) -> None:
|
||||
|
||||
@@ -11,8 +11,8 @@ import users.user.utils as users_utils
|
||||
import users.user_identity_providers.crud as user_idp_crud
|
||||
import users.user_identity_providers.schema as user_idp_schema
|
||||
|
||||
import identity_providers.crud as idp_crud
|
||||
import identity_providers.service as idp_service
|
||||
import auth.identity_providers.crud as idp_crud
|
||||
import auth.identity_providers.service as idp_service
|
||||
|
||||
import users.user_integrations.crud as user_integrations_crud
|
||||
|
||||
@@ -25,9 +25,9 @@ import profile.export_service as profile_export_service
|
||||
import profile.import_service as profile_import_service
|
||||
import profile.exceptions as profile_exceptions
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
import session.crud as session_crud
|
||||
import session.password_hasher as session_password_hasher
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
import core.database as core_database
|
||||
import core.logger as core_logger
|
||||
@@ -48,7 +48,7 @@ file_validator = FileValidator()
|
||||
async def read_users_me(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -116,7 +116,7 @@ async def read_users_me(
|
||||
async def read_sessions_me(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -146,7 +146,7 @@ async def upload_profile_image(
|
||||
file: UploadFile,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -184,7 +184,7 @@ async def edit_user(
|
||||
user_attributtes: users_schema.UserRead,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -214,7 +214,7 @@ async def edit_profile_privacy_settings(
|
||||
user_privacy_settings: users_privacy_settings_schema.UsersPrivacySettings,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -246,11 +246,11 @@ async def edit_profile_password(
|
||||
user_attributtes: users_schema.UserEditPassword,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -263,7 +263,7 @@ async def edit_profile_password(
|
||||
Args:
|
||||
user_attributtes (users_schema.UserEditPassword): Schema containing the new password.
|
||||
token_user_id (int): ID of the user extracted from the access token.
|
||||
password_hasher (session_password_hasher.PasswordHasher): Password hasher dependency.
|
||||
password_hasher (auth_password_hasher.PasswordHasher): Password hasher dependency.
|
||||
db (Session): Database session dependency.
|
||||
|
||||
Returns:
|
||||
@@ -282,7 +282,7 @@ async def edit_profile_password(
|
||||
async def delete_profile_photo(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -311,7 +311,7 @@ async def delete_profile_session(
|
||||
session_id: str,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -340,7 +340,7 @@ async def delete_profile_session(
|
||||
async def export_profile_data(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -426,7 +426,7 @@ async def import_profile_data(
|
||||
file: UploadFile,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
websocket_manager: Annotated[
|
||||
@@ -562,7 +562,7 @@ async def import_profile_data(
|
||||
async def get_mfa_status(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
@@ -584,7 +584,7 @@ async def get_mfa_status(
|
||||
async def setup_mfa(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
mfa_secret_store: Annotated[
|
||||
@@ -615,7 +615,7 @@ async def enable_mfa(
|
||||
request: profile_schema.MFASetupRequest,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
mfa_secret_store: Annotated[
|
||||
@@ -661,7 +661,7 @@ async def disable_mfa(
|
||||
request: profile_schema.MFADisableRequest,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
@@ -685,7 +685,7 @@ async def verify_mfa(
|
||||
request: profile_schema.MFARequest,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
@@ -720,7 +720,7 @@ async def verify_mfa(
|
||||
async def get_my_identity_providers(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -755,7 +755,7 @@ async def get_my_identity_providers(
|
||||
HTTPException: May raise authentication/authorization errors via the dependency injection.
|
||||
"""
|
||||
# Get user's IdP links
|
||||
idp_links = user_idp_crud.get_user_idp_links(token_user_id, db)
|
||||
idp_links = user_idp_crud.get_user_identity_providers_by_user_id(token_user_id, db)
|
||||
|
||||
# Enrich with IDP details (reuse logic from admin endpoint)
|
||||
enriched_links = []
|
||||
@@ -792,7 +792,7 @@ async def delete_my_identity_provider(
|
||||
idp_id: int,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -837,7 +837,9 @@ async def delete_my_identity_provider(
|
||||
)
|
||||
|
||||
# Check if link exists for this user
|
||||
link = user_idp_crud.get_user_idp_link(token_user_id, idp_id, db)
|
||||
link = user_idp_crud.get_user_identity_provider_by_user_id_and_idp_id(
|
||||
token_user_id, idp_id, db
|
||||
)
|
||||
if not link:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -849,7 +851,9 @@ async def delete_my_identity_provider(
|
||||
user = users_crud.get_user_by_id(token_user_id, db)
|
||||
|
||||
# Count remaining IdP links after deletion
|
||||
all_idp_links = user_idp_crud.get_user_idp_links(token_user_id, db)
|
||||
all_idp_links = user_idp_crud.get_user_identity_providers_by_user_id(
|
||||
token_user_id, db
|
||||
)
|
||||
remaining_idp_count = len(all_idp_links) - 1
|
||||
|
||||
# User must have either:
|
||||
@@ -864,7 +868,7 @@ async def delete_my_identity_provider(
|
||||
)
|
||||
|
||||
# Proceed with deletion
|
||||
success = user_idp_crud.delete_user_idp_link(token_user_id, idp_id, db)
|
||||
success = user_idp_crud.delete_user_identity_provider(token_user_id, idp_id, db)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
@@ -890,7 +894,7 @@ async def link_identity_provider(
|
||||
request: Request,
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -930,7 +934,9 @@ async def link_identity_provider(
|
||||
)
|
||||
|
||||
# Check if already linked
|
||||
existing_link = user_idp_crud.get_user_idp_link(token_user_id, idp_id, db)
|
||||
existing_link = user_idp_crud.get_user_identity_provider_by_user_id_and_idp_id(
|
||||
token_user_id, idp_id, db
|
||||
)
|
||||
if existing_link:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
|
||||
@@ -498,11 +498,13 @@ def initialize_operation_counts(include_user_count: bool = False) -> Dict[str, i
|
||||
"gear_components": 0,
|
||||
"health_data": 0,
|
||||
"health_targets": 0,
|
||||
"notifications": 0,
|
||||
"user_images": 0,
|
||||
"user": 1 if include_user_count else 0,
|
||||
"user_default_gear": 0,
|
||||
"user_integrations": 0,
|
||||
"user_goals": 0,
|
||||
"user_identity_providers": 0,
|
||||
"user_integrations": 0,
|
||||
"user_privacy_settings": 0,
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import server_settings.schema as server_settings_schema
|
||||
import server_settings.crud as server_settings_crud
|
||||
import server_settings.utils as server_settings_utils
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
import core.logger as core_logger
|
||||
@@ -22,7 +22,7 @@ router = APIRouter()
|
||||
async def read_server_settings(
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["server_settings:read"]),
|
||||
Security(auth_security.check_scopes, scopes=["server_settings:read"]),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -38,7 +38,7 @@ async def edit_server_settings(
|
||||
server_settings_attributtes: server_settings_schema.ServerSettingsEdit,
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["server_settings:write"]),
|
||||
Security(auth_security.check_scopes, scopes=["server_settings:write"]),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -57,7 +57,7 @@ async def upload_login_photo(
|
||||
file: UploadFile,
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["server_settings:write"]),
|
||||
Security(auth_security.check_scopes, scopes=["server_settings:write"]),
|
||||
],
|
||||
):
|
||||
try:
|
||||
@@ -88,7 +88,7 @@ async def upload_login_photo(
|
||||
async def delete_login_photo(
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["server_settings:write"]),
|
||||
Security(auth_security.check_scopes, scopes=["server_settings:write"]),
|
||||
],
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -3,411 +3,24 @@ from typing import Annotated, Callable
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
Response,
|
||||
Request,
|
||||
Security,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import session.utils as session_utils
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
import session.crud as session_crud
|
||||
import session.schema as session_schema
|
||||
import session.password_hasher as session_password_hasher
|
||||
import session.token_manager as session_token_manager
|
||||
|
||||
import users.user.crud as users_crud
|
||||
import users.user.utils as users_utils
|
||||
import profile.utils as profile_utils
|
||||
|
||||
import core.database as core_database
|
||||
import core.rate_limit as core_rate_limit
|
||||
|
||||
# Define the API router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
@core_rate_limit.limiter.limit(core_rate_limit.SESSION_LOGIN_LIMIT)
|
||||
async def login_for_access_token(
|
||||
response: Response,
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
client_type: Annotated[str, Depends(session_security.header_client_type_scheme)],
|
||||
pending_mfa_store: Annotated[
|
||||
session_schema.PendingMFALogin, Depends(session_schema.get_pending_mfa_store)
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
"""
|
||||
Handles user login and access token generation, including Multi-Factor Authentication (MFA) flow.
|
||||
|
||||
Rate Limit: 5 requests per minute per IP
|
||||
|
||||
This endpoint authenticates a user using provided credentials, checks if the user is active,
|
||||
and determines if MFA is required. If MFA is enabled for the user, it stores the pending login
|
||||
and returns an MFA-required response. Otherwise, it completes the login process and returns
|
||||
the required information.
|
||||
|
||||
Args:
|
||||
response: The HTTP response object
|
||||
request: The HTTP request object
|
||||
form_data: Form data containing username and password
|
||||
client_type: The type of client making the request ("web" or "mobile")
|
||||
pending_mfa_store: Store for pending MFA logins
|
||||
password_hasher: The password hasher instance used for verifying passwords
|
||||
token_manager: The token manager instance used for token operations
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Union[session_schema.MFARequiredResponse, dict, str]:
|
||||
- If MFA is required, returns an MFA-required response (schema or dict depending on client type)
|
||||
- If MFA is not required, proceeds with normal login via session_utils.complete_login()
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails or the user is inactive
|
||||
"""
|
||||
user = session_utils.authenticate_user(
|
||||
form_data.username, form_data.password, password_hasher, db
|
||||
)
|
||||
|
||||
# Check if the user is active
|
||||
users_utils.check_user_is_active(user)
|
||||
|
||||
# Check if MFA is enabled for this user
|
||||
if profile_utils.is_mfa_enabled_for_user(user.id, db):
|
||||
# Store the user for pending MFA verification
|
||||
pending_mfa_store.add_pending_login(form_data.username, user.id)
|
||||
|
||||
# Return MFA required response
|
||||
if client_type == "web":
|
||||
response.status_code = status.HTTP_202_ACCEPTED
|
||||
return session_schema.MFARequiredResponse(
|
||||
mfa_required=True,
|
||||
username=form_data.username,
|
||||
message="MFA verification required",
|
||||
)
|
||||
if client_type == "mobile":
|
||||
return {
|
||||
"mfa_required": True,
|
||||
"username": form_data.username,
|
||||
"message": "MFA verification required",
|
||||
}
|
||||
|
||||
# If no MFA required, proceed with normal login
|
||||
return session_utils.complete_login(
|
||||
response, request, user, client_type, password_hasher, token_manager, db
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mfa/verify")
|
||||
async def verify_mfa_and_login(
|
||||
response: Response,
|
||||
request: Request,
|
||||
mfa_request: session_schema.MFALoginRequest,
|
||||
client_type: Annotated[str, Depends(session_security.header_client_type_scheme)],
|
||||
pending_mfa_store: Annotated[
|
||||
session_schema.PendingMFALogin, Depends(session_schema.get_pending_mfa_store)
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
"""
|
||||
Verify MFA code and complete login process.
|
||||
|
||||
This endpoint verifies the MFA code for a pending login and completes
|
||||
the authentication process if the code is valid.
|
||||
|
||||
Args:
|
||||
response: The HTTP response object
|
||||
request: The HTTP request object
|
||||
mfa_request: MFA login request containing username and MFA code
|
||||
client_type: The type of client making the request ("web" or "mobile")
|
||||
pending_mfa_store: Store for pending MFA logins
|
||||
password_hasher: The password hasher instance used for verifying passwords
|
||||
token_manager: The token manager instance used for token operations
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Result from session_utils.complete_login()
|
||||
|
||||
Raises:
|
||||
HTTPException: If no pending login found, MFA code is invalid, or user not found
|
||||
"""
|
||||
# Check if there's a pending MFA login for this username
|
||||
user_id = pending_mfa_store.get_pending_login(mfa_request.username)
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No pending MFA login found for this username",
|
||||
)
|
||||
|
||||
# Verify the MFA code
|
||||
if not profile_utils.verify_user_mfa(user_id, mfa_request.mfa_code, db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid MFA code"
|
||||
)
|
||||
|
||||
# Get the user and complete login
|
||||
user = users_crud.get_user_by_id(user_id, db)
|
||||
if not user:
|
||||
pending_mfa_store.delete_pending_login(mfa_request.username)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
# Check if the user is still active
|
||||
users_utils.check_user_is_active(user)
|
||||
|
||||
# Clean up pending login
|
||||
pending_mfa_store.delete_pending_login(mfa_request.username)
|
||||
|
||||
# Complete the login
|
||||
return session_utils.complete_login(
|
||||
response, request, user, client_type, password_hasher, token_manager, db
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(
|
||||
response: Response,
|
||||
request: Request,
|
||||
_validate_refresh_token: Annotated[
|
||||
Callable, Depends(session_security.validate_refresh_token)
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_refresh_token),
|
||||
],
|
||||
token_session_id: Annotated[
|
||||
str,
|
||||
Depends(session_security.get_sid_from_refresh_token),
|
||||
],
|
||||
refresh_token_value: Annotated[
|
||||
str,
|
||||
Depends(session_security.get_and_return_refresh_token),
|
||||
],
|
||||
client_type: Annotated[str, Depends(session_security.header_client_type_scheme)],
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
],
|
||||
token_manager: Annotated[
|
||||
session_token_manager.TokenManager,
|
||||
Depends(session_token_manager.get_token_manager),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
"""
|
||||
Handles the refresh token process for user sessions.
|
||||
|
||||
This endpoint validates the provided refresh token, checks session and user status,
|
||||
and issues new access, refresh, and CSRF tokens. The response format depends on the client type.
|
||||
|
||||
Args:
|
||||
response (Response): The HTTP response object.
|
||||
request (Request): The HTTP request object.
|
||||
_validate_refresh_token (Callable): Dependency to validate the refresh token.
|
||||
token_user_id (int): User ID extracted from the refresh token.
|
||||
token_session_id (str): Session ID extracted from the refresh token.
|
||||
refresh_token_value (str): The raw refresh token value.
|
||||
client_type (str): The type of client ("web" or "mobile").
|
||||
password_hasher (PasswordHasher): Utility for verifying token hashes.
|
||||
token_manager (TokenManager): Utility for creating tokens.
|
||||
db (Session): Database session.
|
||||
|
||||
Returns:
|
||||
Union[str, dict]: For "web" clients, returns the session ID.
|
||||
For "mobile" clients, returns a dictionary with new tokens and session ID.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the session is not found, the refresh token is invalid,
|
||||
the user is inactive, or the client type is invalid.
|
||||
"""
|
||||
# Get the session from the database
|
||||
session = session_crud.get_session_by_id(token_session_id, db)
|
||||
|
||||
# Check if the session was found
|
||||
if session is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
is_valid = password_hasher.verify(refresh_token_value, session.refresh_token)
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# get user
|
||||
user = users_crud.get_user_by_id(token_user_id, db)
|
||||
|
||||
# Check if the user is active
|
||||
users_utils.check_user_is_active(user)
|
||||
|
||||
# Create the tokens
|
||||
(
|
||||
session_id,
|
||||
new_access_token_exp,
|
||||
new_access_token,
|
||||
_new_refresh_token_exp,
|
||||
new_refresh_token,
|
||||
new_csrf_token,
|
||||
) = session_utils.create_tokens(user, token_manager, session.id)
|
||||
|
||||
# Edit the session and store it in the database
|
||||
session_utils.edit_session(session, request, new_refresh_token, password_hasher, db)
|
||||
|
||||
# Opportunistically refresh IdP tokens for all linked identity providers
|
||||
await session_utils.refresh_idp_tokens_if_needed(user.id, db)
|
||||
|
||||
if client_type == "web":
|
||||
response = session_utils.create_response_with_tokens(
|
||||
response, new_access_token, new_refresh_token, new_csrf_token
|
||||
)
|
||||
|
||||
# Return session ID
|
||||
return {
|
||||
"session_id": session_id,
|
||||
}
|
||||
if client_type == "mobile":
|
||||
# Return the tokens
|
||||
return {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
"session_id": session_id,
|
||||
"token_type": "bearer",
|
||||
"expires_in": int(new_access_token_exp.timestamp()),
|
||||
}
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid client type",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
response: Response,
|
||||
_validate_access_token: Annotated[
|
||||
Callable, Depends(session_security.validate_access_token)
|
||||
],
|
||||
token_session_id: Annotated[
|
||||
str,
|
||||
Depends(session_security.get_sid_from_access_token),
|
||||
],
|
||||
refresh_token_value: Annotated[
|
||||
str,
|
||||
Depends(session_security.get_and_return_refresh_token),
|
||||
],
|
||||
client_type: Annotated[str, Depends(session_security.header_client_type_scheme)],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_refresh_token),
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
],
|
||||
):
|
||||
"""
|
||||
Logs out a user by validating and deleting their session, and clearing authentication cookies for web clients.
|
||||
Parameters:
|
||||
response (Response): The response object to modify cookies.
|
||||
_validate_access_token (Callable): Dependency to validate the access token.
|
||||
token_session_id (str): The session ID extracted from the access token.
|
||||
refresh_token_value (str): The refresh token value from the request.
|
||||
client_type (str): The type of client ("web" or "mobile").
|
||||
token_user_id (int): The user ID extracted from the refresh token.
|
||||
password_hasher (PasswordHasher): Utility for verifying the refresh token.
|
||||
db (Session): Database session for CRUD operations.
|
||||
Returns:
|
||||
dict: A message indicating successful logout.
|
||||
Raises:
|
||||
HTTPException: If the refresh token is invalid (401 Unauthorized).
|
||||
HTTPException: If the client type is invalid (403 Forbidden).
|
||||
"""
|
||||
# Get the session from the database
|
||||
session = session_crud.get_session_by_id(token_session_id, db)
|
||||
|
||||
# Check if the session was found
|
||||
if session is not None:
|
||||
# Verify the refresh token
|
||||
is_valid = password_hasher.verify(refresh_token_value, session.refresh_token)
|
||||
|
||||
# If the refresh token is not valid, raise an exception
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Delete the session from the database
|
||||
session_crud.delete_session(session.id, token_user_id, db)
|
||||
|
||||
# Clear all IdP refresh tokens for security
|
||||
await session_utils.clear_all_idp_tokens(token_user_id, db)
|
||||
|
||||
if client_type == "web":
|
||||
# Clear the cookies by setting their expiration to the past
|
||||
response.delete_cookie(key="endurain_access_token", path="/")
|
||||
response.delete_cookie(key="endurain_refresh_token", path="/")
|
||||
response.delete_cookie(key="endurain_csrf_token", path="/")
|
||||
return {"message": "Logout successful"}
|
||||
if client_type == "mobile":
|
||||
return {"message": "Logout successful"}
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid client type",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions/user/{user_id}")
|
||||
@router.get("/user/{user_id}")
|
||||
async def read_sessions_user(
|
||||
user_id: int,
|
||||
_validate_access_token: Annotated[
|
||||
Callable, Depends(session_security.validate_access_token)
|
||||
],
|
||||
__check_scope: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["sessions:read"])
|
||||
_check_scope: Annotated[
|
||||
Callable, Security(auth_security.check_scopes, scopes=["sessions:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -430,15 +43,12 @@ async def read_sessions_user(
|
||||
return session_crud.get_user_sessions(user_id, db)
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}/user/{user_id}")
|
||||
@router.delete("/{session_id}/user/{user_id}")
|
||||
async def delete_session_user(
|
||||
session_id: str,
|
||||
user_id: int,
|
||||
_validate_access_token: Annotated[
|
||||
Callable, Depends(session_security.validate_access_token)
|
||||
],
|
||||
__check_scope: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["sessions:write"])
|
||||
_check_scope: Annotated[
|
||||
Callable, Security(auth_security.check_scopes, scopes=["sessions:write"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from datetime import datetime
|
||||
from fastapi import Request, HTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
|
||||
class UsersSessions(BaseModel):
|
||||
@@ -42,215 +40,3 @@ class UsersSessions(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True, extra="forbid", validate_assignment=True
|
||||
)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""
|
||||
Schema for login requests containing username and password.
|
||||
|
||||
Attributes:
|
||||
username (str): The username of the user. Must be between 1 and 250 characters.
|
||||
password (str): The user's password. Must be at least 8 characters long.
|
||||
"""
|
||||
username: str = Field(..., min_length=1, max_length=250)
|
||||
password: str = Field(..., min_length=8)
|
||||
|
||||
|
||||
class MFALoginRequest(BaseModel):
|
||||
"""
|
||||
Schema for Multi-Factor Authentication (MFA) login request.
|
||||
|
||||
Attributes:
|
||||
username (str): The username of the user attempting to log in. Must be between 1 and 250 characters.
|
||||
mfa_code (str): The 6-digit MFA code provided by the user. Must match the pattern: six consecutive digits.
|
||||
"""
|
||||
username: str = Field(..., min_length=1, max_length=250)
|
||||
mfa_code: str = Field(..., pattern=r'^\d{6}$')
|
||||
|
||||
|
||||
class MFARequiredResponse(BaseModel):
|
||||
"""
|
||||
Represents a response indicating that Multi-Factor Authentication (MFA) is required.
|
||||
|
||||
Attributes:
|
||||
mfa_required (bool): Indicates whether MFA is required. Defaults to True.
|
||||
username (str): The username for which MFA is required.
|
||||
message (str): A message describing the requirement. Defaults to "MFA verification required".
|
||||
"""
|
||||
mfa_required: bool = True
|
||||
username: str
|
||||
message: str = "MFA verification required"
|
||||
|
||||
|
||||
class PendingMFALogin:
|
||||
"""
|
||||
A class to manage pending Multi-Factor Authentication (MFA) login sessions.
|
||||
|
||||
This class provides methods to add, retrieve, delete, and check pending login entries
|
||||
for users who are in the process of MFA authentication. It uses an internal dictionary
|
||||
to store the mapping between usernames and their associated user IDs.
|
||||
|
||||
Attributes:
|
||||
_store (dict): Internal storage mapping usernames to user IDs for pending logins.
|
||||
|
||||
Methods:
|
||||
add_pending_login(username: str, user_id: int):
|
||||
Adds a pending login entry for the specified username and user ID.
|
||||
|
||||
get_pending_login(username: str):
|
||||
Retrieves the user ID associated with the given username's pending login entry.
|
||||
|
||||
delete_pending_login(username: str):
|
||||
Removes the pending login entry for the specified username.
|
||||
|
||||
has_pending_login(username: str):
|
||||
Checks if the specified username has a pending login entry.
|
||||
|
||||
clear_all():
|
||||
Clears all pending login entries from the internal store.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._store = {}
|
||||
|
||||
def add_pending_login(self, username: str, user_id: int):
|
||||
"""
|
||||
Adds a pending login entry for a user.
|
||||
|
||||
Stores the provided username and associated user ID in the internal store,
|
||||
marking the user as pending login.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user to add.
|
||||
user_id (int): The unique identifier of the user.
|
||||
|
||||
"""
|
||||
self._store[username] = user_id
|
||||
|
||||
def get_pending_login(self, username: str):
|
||||
"""
|
||||
Retrieve the pending login information for a given username.
|
||||
|
||||
Args:
|
||||
username (str): The username to look up.
|
||||
|
||||
Returns:
|
||||
Any: The pending login information associated with the username, or None if not found.
|
||||
"""
|
||||
return self._store.get(username)
|
||||
|
||||
def delete_pending_login(self, username: str):
|
||||
"""
|
||||
Removes the pending login entry for the specified username from the internal store.
|
||||
|
||||
Args:
|
||||
username (str): The username whose pending login entry should be deleted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if username in self._store:
|
||||
del self._store[username]
|
||||
|
||||
def has_pending_login(self, username: str):
|
||||
"""
|
||||
Checks if the given username has a pending login session.
|
||||
|
||||
Args:
|
||||
username (str): The username to check for a pending login.
|
||||
|
||||
Returns:
|
||||
bool: True if the username has a pending login session, False otherwise.
|
||||
"""
|
||||
return username in self._store
|
||||
|
||||
def clear_all(self):
|
||||
"""
|
||||
Removes all items from the internal store, effectively resetting it to an empty state.
|
||||
"""
|
||||
self._store.clear()
|
||||
|
||||
|
||||
def get_pending_mfa_store():
|
||||
"""
|
||||
Retrieve the current pending MFA (Multi-Factor Authentication) store.
|
||||
|
||||
Returns:
|
||||
dict: The pending MFA store containing MFA-related data.
|
||||
"""
|
||||
return pending_mfa_store
|
||||
|
||||
|
||||
pending_mfa_store = PendingMFALogin()
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware for CSRF protection in FastAPI applications.
|
||||
|
||||
This middleware checks for a valid CSRF token in requests from web clients to prevent cross-site request forgery attacks.
|
||||
It exempts specific API paths from CSRF checks and only enforces validation for POST, PUT, DELETE, and PATCH requests.
|
||||
|
||||
Attributes:
|
||||
exempt_paths (list): List of URL paths that are exempt from CSRF protection.
|
||||
|
||||
Methods:
|
||||
dispatch(request, call_next):
|
||||
Processes incoming requests, enforcing CSRF checks for web clients on non-exempt paths and applicable HTTP methods.
|
||||
Raises HTTPException with status code 403 if CSRF token is missing or invalid.
|
||||
"""
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
# Define paths that don't need CSRF protection
|
||||
self.exempt_paths = [
|
||||
"/api/v1/token",
|
||||
"/api/v1/refresh",
|
||||
"/api/v1/mfa/verify",
|
||||
"/api/v1/password-reset/request",
|
||||
"/api/v1/password-reset/confirm",
|
||||
"/api/v1/sign-up/request",
|
||||
"/api/v1/sign-up/confirm"
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
Middleware method to enforce CSRF protection for web clients.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming HTTP request.
|
||||
call_next (Callable): The next middleware or endpoint handler.
|
||||
|
||||
Returns:
|
||||
Response: The HTTP response after CSRF validation.
|
||||
|
||||
Behavior:
|
||||
- Skips CSRF checks for non-web clients (determined by "X-Client-Type" header).
|
||||
- Skips CSRF checks for exempt paths.
|
||||
- For web clients and non-exempt paths, validates CSRF token for POST, PUT, DELETE, and PATCH requests:
|
||||
- Requires both "endurain_csrf_token" cookie and "X-CSRF-Token" header.
|
||||
- Raises HTTPException 403 if tokens are missing or do not match.
|
||||
"""
|
||||
# Get client type from header
|
||||
client_type = request.headers.get("X-Client-Type")
|
||||
|
||||
# Skip CSRF checks for not web clients
|
||||
if client_type != "web":
|
||||
return await call_next(request)
|
||||
|
||||
# Skip CSRF check for exempt paths
|
||||
if request.url.path in self.exempt_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Check for CSRF token in POST, PUT, DELETE, and PATCH requests
|
||||
if request.method in ["POST", "PUT", "DELETE", "PATCH"]:
|
||||
csrf_cookie = request.cookies.get("endurain_csrf_token")
|
||||
csrf_header = request.headers.get("X-CSRF-Token")
|
||||
|
||||
if not csrf_cookie or not csrf_header:
|
||||
raise HTTPException(status_code=403, detail="CSRF token missing")
|
||||
|
||||
if csrf_cookie != csrf_header:
|
||||
raise HTTPException(status_code=403, detail="CSRF token invalid")
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
@@ -1,33 +1,21 @@
|
||||
import os
|
||||
"""Session utility functions and classes"""
|
||||
|
||||
from enum import Enum
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
from fastapi import (
|
||||
HTTPException,
|
||||
status,
|
||||
Response,
|
||||
Request,
|
||||
)
|
||||
from user_agents import parse
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import session.constants as session_constants
|
||||
import auth.constants as auth_constants
|
||||
import session.schema as session_schema
|
||||
import session.crud as session_crud
|
||||
import session.password_hasher as session_password_hasher
|
||||
import session.token_manager as session_token_manager
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
import users.user.crud as users_crud
|
||||
import users.user.schema as users_schema
|
||||
import users.user_identity_providers.crud as user_idp_crud
|
||||
|
||||
import identity_providers.service as idp_service
|
||||
from identity_providers.service import TokenAction
|
||||
import core.logger as core_logger
|
||||
|
||||
|
||||
class DeviceType(Enum):
|
||||
@@ -93,7 +81,7 @@ def create_session_object(
|
||||
user_id=user.id,
|
||||
refresh_token=hashed_refresh_token,
|
||||
ip_address=get_ip_address(request),
|
||||
device_type=device_info.device_type,
|
||||
device_type=device_info.device_type.value,
|
||||
operating_system=device_info.operating_system,
|
||||
operating_system_version=device_info.operating_system_version,
|
||||
browser=device_info.browser,
|
||||
@@ -129,7 +117,7 @@ def edit_session_object(
|
||||
user_id=session.user_id,
|
||||
refresh_token=hashed_refresh_token,
|
||||
ip_address=get_ip_address(request),
|
||||
device_type=device_info.device_type,
|
||||
device_type=device_info.device_type.value,
|
||||
operating_system=device_info.operating_system,
|
||||
operating_system_version=device_info.operating_system_version,
|
||||
browser=device_info.browser,
|
||||
@@ -139,164 +127,12 @@ def edit_session_object(
|
||||
)
|
||||
|
||||
|
||||
def authenticate_user(
|
||||
username: str,
|
||||
password: str,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> users_schema.UserRead:
|
||||
"""
|
||||
Authenticates a user by verifying the provided username and password.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user attempting to authenticate.
|
||||
password (str): The plaintext password provided by the user.
|
||||
password_hasher (session_password_hasher.PasswordHasher): An instance of the password hasher for verifying and updating password hashes.
|
||||
db (Session): The database session used for querying and updating user data.
|
||||
|
||||
Returns:
|
||||
users_schema.UserRead: The authenticated user object if authentication is successful.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the username does not exist or the password is invalid.
|
||||
"""
|
||||
# Get the user from the database
|
||||
user = users_crud.authenticate_user(username, db)
|
||||
|
||||
# Check if the user exists and if the password is correct
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Verify password and get updated hash if applicable
|
||||
is_password_valid, updated_hash = password_hasher.verify_and_update(
|
||||
password, user.password
|
||||
)
|
||||
if not is_password_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Update user hash if applicable
|
||||
if updated_hash:
|
||||
users_crud.edit_user_password(
|
||||
user.id, updated_hash, password_hasher, db, is_hashed=True
|
||||
)
|
||||
|
||||
# Return the user if the password is correct
|
||||
return user
|
||||
|
||||
|
||||
def create_tokens(
|
||||
user: users_schema.UserRead,
|
||||
token_manager: session_token_manager.TokenManager,
|
||||
session_id: str | None = None,
|
||||
) -> Tuple[str, datetime, str, datetime, str, str]:
|
||||
"""
|
||||
Generates session tokens for a user, including access token, refresh token, and CSRF token.
|
||||
|
||||
Args:
|
||||
user (users_schema.UserRead): The user object for whom the tokens are being created.
|
||||
token_manager (session_token_manager.TokenManager): The token manager responsible for token creation.
|
||||
session_id (str | None, optional): An optional session ID. If not provided, a new unique session ID is generated.
|
||||
|
||||
Returns:
|
||||
Tuple[str, datetime, str, datetime, str, str]:
|
||||
A tuple containing:
|
||||
- session_id (str): The session identifier.
|
||||
- access_token_exp (datetime): Expiration datetime of the access token.
|
||||
- access_token (str): The access token string.
|
||||
- refresh_token_exp (datetime): Expiration datetime of the refresh token.
|
||||
- refresh_token (str): The refresh token string.
|
||||
- csrf_token (str): The CSRF token string.
|
||||
"""
|
||||
if session_id is None:
|
||||
# Generate a unique session ID
|
||||
session_id = str(uuid4())
|
||||
|
||||
# Create the access, refresh tokens and csrf token
|
||||
access_token_exp, access_token = token_manager.create_token(
|
||||
session_id, user, session_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
refresh_token_exp, refresh_token = token_manager.create_token(
|
||||
session_id, user, session_token_manager.TokenType.REFRESH
|
||||
)
|
||||
|
||||
csrf_token = token_manager.create_csrf_token()
|
||||
|
||||
return (
|
||||
session_id,
|
||||
access_token_exp,
|
||||
access_token,
|
||||
refresh_token_exp,
|
||||
refresh_token,
|
||||
csrf_token,
|
||||
)
|
||||
|
||||
|
||||
def create_response_with_tokens(
|
||||
response: Response, access_token: str, refresh_token: str, csrf_token: str
|
||||
) -> Response:
|
||||
"""
|
||||
Sets access, refresh, and CSRF tokens as cookies on the given response object.
|
||||
|
||||
Args:
|
||||
response (Response): The response object to set cookies on.
|
||||
access_token (str): The JWT access token to be set as a cookie.
|
||||
refresh_token (str): The JWT refresh token to be set as a cookie.
|
||||
csrf_token (str): The CSRF token to be set as a cookie.
|
||||
|
||||
Returns:
|
||||
Response: The response object with the tokens set as cookies.
|
||||
"""
|
||||
secure = os.environ.get("FRONTEND_PROTOCOL") == "https"
|
||||
|
||||
# Set the cookies with the tokens
|
||||
response.set_cookie(
|
||||
key="endurain_access_token",
|
||||
value=access_token,
|
||||
expires=datetime.now(timezone.utc)
|
||||
+ timedelta(minutes=session_constants.JWT_ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
httponly=True,
|
||||
path="/",
|
||||
secure=secure,
|
||||
samesite="Lax",
|
||||
)
|
||||
response.set_cookie(
|
||||
key="endurain_refresh_token",
|
||||
value=refresh_token,
|
||||
expires=datetime.now(timezone.utc)
|
||||
+ timedelta(days=session_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
|
||||
httponly=True,
|
||||
path="/",
|
||||
secure=secure,
|
||||
samesite="Lax",
|
||||
)
|
||||
response.set_cookie(
|
||||
key="endurain_csrf_token",
|
||||
value=csrf_token,
|
||||
httponly=False,
|
||||
path="/",
|
||||
secure=secure,
|
||||
samesite="Lax",
|
||||
)
|
||||
|
||||
# Return the response
|
||||
return response
|
||||
|
||||
|
||||
def create_session(
|
||||
session_id: str,
|
||||
user: users_schema.UserRead,
|
||||
request: Request,
|
||||
refresh_token: str,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -307,7 +143,7 @@ def create_session(
|
||||
user (users_schema.UserRead): The user for whom the session is being created.
|
||||
request (Request): The incoming HTTP request object.
|
||||
refresh_token (str): The refresh token to be associated with the session.
|
||||
password_hasher (session_password_hasher.PasswordHasher): Utility to hash the refresh token.
|
||||
password_hasher (auth_password_hasher.PasswordHasher): Utility to hash the refresh token.
|
||||
db (Session): Database session for storing the session.
|
||||
|
||||
Returns:
|
||||
@@ -315,7 +151,7 @@ def create_session(
|
||||
"""
|
||||
# Calculate the refresh token expiration date
|
||||
exp = datetime.now(timezone.utc) + timedelta(
|
||||
days=session_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
days=auth_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
|
||||
# Create a new session
|
||||
@@ -335,7 +171,7 @@ def edit_session(
|
||||
session: session_schema.UsersSessions,
|
||||
request: Request,
|
||||
new_refresh_token: str,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -345,7 +181,7 @@ def edit_session(
|
||||
session (session_schema.UsersSessions): The current user session object to be edited.
|
||||
request (Request): The incoming request object containing session context.
|
||||
new_refresh_token (str): The new refresh token to be set for the session.
|
||||
password_hasher (session_password_hasher.PasswordHasher): Utility for hashing the refresh token.
|
||||
password_hasher (auth_password_hasher.PasswordHasher): Utility for hashing the refresh token.
|
||||
db (Session): Database session for committing changes.
|
||||
|
||||
Returns:
|
||||
@@ -353,7 +189,7 @@ def edit_session(
|
||||
"""
|
||||
# Calculate the refresh token expiration date
|
||||
exp = datetime.now(timezone.utc) + timedelta(
|
||||
days=session_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
days=auth_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
|
||||
# Update the session
|
||||
@@ -368,73 +204,6 @@ def edit_session(
|
||||
session_crud.edit_session(updated_session, db)
|
||||
|
||||
|
||||
def complete_login(
|
||||
response: Response,
|
||||
request: Request,
|
||||
user: users_schema.UserRead,
|
||||
client_type: str,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
token_manager: session_token_manager.TokenManager,
|
||||
db: Session,
|
||||
) -> dict | str:
|
||||
"""
|
||||
Handles the completion of the login process by generating session and authentication tokens,
|
||||
storing the session in the database, and returning appropriate responses based on client type.
|
||||
|
||||
Args:
|
||||
response (Response): The HTTP response object to set cookies for web clients.
|
||||
request (Request): The HTTP request object containing client information.
|
||||
user (users_schema.UserRead): The authenticated user object.
|
||||
client_type (str): The type of client ("web" or "mobile").
|
||||
password_hasher (session_password_hasher.PasswordHasher): Utility for password hashing.
|
||||
token_manager (session_token_manager.TokenManager): Utility for token generation and management.
|
||||
db (Session): Database session for storing session information.
|
||||
|
||||
Returns:
|
||||
dict | str: For web clients, returns the session ID as a string.
|
||||
For mobile clients, returns a dictionary containing tokens and session info.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the client type is invalid, raises a 403 Forbidden error.
|
||||
"""
|
||||
# Create the tokens
|
||||
(
|
||||
session_id,
|
||||
access_token_exp,
|
||||
access_token,
|
||||
_refresh_token_exp,
|
||||
refresh_token,
|
||||
csrf_token,
|
||||
) = create_tokens(user, token_manager)
|
||||
|
||||
# Create the session and store it in the database
|
||||
create_session(session_id, user, request, refresh_token, password_hasher, db)
|
||||
|
||||
if client_type == "web":
|
||||
# Set response cookies with tokens
|
||||
create_response_with_tokens(response, access_token, refresh_token, csrf_token)
|
||||
|
||||
# Return the session_id
|
||||
return {
|
||||
"session_id": session_id,
|
||||
}
|
||||
if client_type == "mobile":
|
||||
# Return the tokens directly (no cookies for mobile)
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"session_id": session_id,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int(access_token_exp.timestamp()),
|
||||
}
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid client type",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def get_user_agent(request: Request) -> str:
|
||||
"""
|
||||
Extracts the 'User-Agent' string from the request headers.
|
||||
@@ -511,219 +280,3 @@ def parse_user_agent(user_agent: str) -> DeviceInfo:
|
||||
browser=ua.browser.family or "Unknown",
|
||||
browser_version=ua.browser.version_string or "Unknown",
|
||||
)
|
||||
|
||||
|
||||
async def refresh_idp_tokens_if_needed(user_id: int, db: Session) -> None:
|
||||
"""
|
||||
Refreshes identity provider (IdP) tokens for a user if needed based on token expiration policies.
|
||||
|
||||
This function retrieves all IdP links associated with a user and evaluates each token's
|
||||
state to determine the appropriate action: refresh if nearing expiry, clear if maximum
|
||||
age is exceeded, or skip if still valid.
|
||||
|
||||
The function is designed to be non-blocking and opportunistic - errors during token
|
||||
refresh or clearing are logged but do not raise exceptions, allowing the application
|
||||
to continue normal operation even if IdP token management fails.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user whose IdP tokens should be checked and refreshed.
|
||||
db (Session): SQLAlchemy database session for performing database operations.
|
||||
|
||||
Returns:
|
||||
None: This function performs side effects (token refresh/clearing) but returns nothing.
|
||||
|
||||
Raises:
|
||||
Does not raise exceptions. All errors are caught, logged, and suppressed to ensure
|
||||
IdP token management does not disrupt normal application flow.
|
||||
|
||||
Notes:
|
||||
- If a user has no IdP links, the function returns early without performing any operations.
|
||||
- Token refresh attempts that fail are logged but the user session remains valid.
|
||||
- Tokens exceeding maximum age are cleared for security, requiring user re-authentication.
|
||||
- Individual IdP operation failures do not prevent checking other IdP links.
|
||||
"""
|
||||
try:
|
||||
# Get all IdP links for this user
|
||||
idp_links = user_idp_crud.get_user_idp_links(user_id, db)
|
||||
|
||||
if not idp_links:
|
||||
# User has no IdP links - nothing to refresh
|
||||
return
|
||||
|
||||
# Check each IdP link and take appropriate action
|
||||
for link in idp_links:
|
||||
try:
|
||||
# Determine what action to take for this IdP token (policy-based)
|
||||
action = idp_service.idp_service._should_refresh_idp_token(link)
|
||||
|
||||
if action == TokenAction.REFRESH:
|
||||
# Token is close to expiry - attempt to refresh
|
||||
core_logger.print_to_log(
|
||||
f"Attempting to refresh IdP token for user {user_id}, idp {link.idp_id}",
|
||||
"debug",
|
||||
)
|
||||
|
||||
# Attempt to refresh the IdP session
|
||||
result = await idp_service.idp_service.refresh_idp_session(
|
||||
user_id, link.idp_id, db
|
||||
)
|
||||
|
||||
if result:
|
||||
core_logger.print_to_log(
|
||||
f"Successfully refreshed IdP token for user {user_id}, idp {link.idp_id}",
|
||||
"debug",
|
||||
)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"IdP token refresh failed for user {user_id}, idp {link.idp_id}. "
|
||||
"User may need to re-authenticate with IdP later.",
|
||||
"debug",
|
||||
)
|
||||
|
||||
elif action == TokenAction.CLEAR:
|
||||
# Token has exceeded maximum age - clear it for security
|
||||
core_logger.print_to_log(
|
||||
f"Clearing expired IdP token (max age exceeded) for user {user_id}, idp {link.idp_id}",
|
||||
"info",
|
||||
)
|
||||
|
||||
success = user_idp_crud.clear_idp_refresh_token(
|
||||
user_id, link.idp_id, db
|
||||
)
|
||||
|
||||
if success:
|
||||
core_logger.print_to_log(
|
||||
f"Successfully cleared expired IdP token for user {user_id}, idp {link.idp_id}. "
|
||||
"User will need to re-authenticate with IdP.",
|
||||
"info",
|
||||
)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"Failed to clear expired IdP token for user {user_id}, idp {link.idp_id}",
|
||||
"warning",
|
||||
)
|
||||
|
||||
else: # TokenAction.SKIP
|
||||
# Token is still valid and not close to expiry - no action needed
|
||||
pass
|
||||
|
||||
except Exception as err:
|
||||
# Log individual IdP operation failure but continue with other IdPs
|
||||
core_logger.print_to_log(
|
||||
f"Error checking/refreshing IdP token for user {user_id}, idp {link.idp_id}: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
# Continue to next IdP link
|
||||
|
||||
except Exception as err:
|
||||
# Catch-all for unexpected errors (e.g., database query failure)
|
||||
core_logger.print_to_log(
|
||||
f"Error retrieving IdP links for user {user_id}: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
# Don't raise - IdP token refresh is opportunistic and non-blocking
|
||||
|
||||
|
||||
async def clear_all_idp_tokens(
|
||||
user_id: int, db: Session, revoke_at_idp: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Clear all IdP (Identity Provider) refresh tokens for a user.
|
||||
|
||||
This function retrieves all IdP links associated with a user and clears their
|
||||
refresh tokens. It supports optional revocation at the IdP level before clearing
|
||||
tokens locally.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user whose IdP tokens should be cleared.
|
||||
db (Session): The database session to use for queries.
|
||||
revoke_at_idp (bool, optional): If True, attempts to revoke tokens at the
|
||||
IdP provider level (RFC 7009) before clearing locally. Defaults to False.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
This function does not raise exceptions. All errors are logged and handled
|
||||
gracefully to ensure logout processes are not interrupted.
|
||||
|
||||
Notes:
|
||||
- If no IdP links exist for the user, the function returns early.
|
||||
- Token revocation at the IdP is best-effort; local clearing always proceeds
|
||||
regardless of revocation success or failure.
|
||||
- Individual IdP token clearing failures do not prevent clearing tokens for
|
||||
other IdPs.
|
||||
- All errors are logged with appropriate severity levels (debug, info, warning).
|
||||
"""
|
||||
try:
|
||||
# Get all IdP links for this user
|
||||
idp_links = user_idp_crud.get_user_idp_links(user_id, db)
|
||||
|
||||
if not idp_links:
|
||||
# User has no IdP links - nothing to clear
|
||||
return
|
||||
|
||||
# Clear tokens for each IdP link
|
||||
for link in idp_links:
|
||||
try:
|
||||
# Optionally attempt to revoke token at IdP first (RFC 7009)
|
||||
if revoke_at_idp:
|
||||
try:
|
||||
revoked = await idp_service.idp_service.revoke_idp_token(
|
||||
user_id, link.idp_id, db
|
||||
)
|
||||
if revoked:
|
||||
core_logger.print_to_log(
|
||||
f"Revoked IdP token at provider for user {user_id}, idp {link.idp_id}",
|
||||
"info",
|
||||
)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"IdP token revocation not supported or failed for user {user_id}, idp {link.idp_id}. "
|
||||
"Will clear locally.",
|
||||
"debug",
|
||||
)
|
||||
except Exception as revoke_err:
|
||||
# Log revocation failure but continue with local clearing
|
||||
core_logger.print_to_log(
|
||||
f"Error revoking IdP token for user {user_id}, idp {link.idp_id}: {revoke_err}. "
|
||||
"Will clear locally.",
|
||||
"warning",
|
||||
exc=revoke_err,
|
||||
)
|
||||
|
||||
# Always clear locally regardless of revocation result
|
||||
success = user_idp_crud.clear_idp_refresh_token(
|
||||
user_id, link.idp_id, db
|
||||
)
|
||||
|
||||
if success:
|
||||
core_logger.print_to_log(
|
||||
f"Cleared IdP refresh token for user {user_id}, idp {link.idp_id} on logout",
|
||||
"debug",
|
||||
)
|
||||
else:
|
||||
core_logger.print_to_log(
|
||||
f"No IdP refresh token to clear for user {user_id}, idp {link.idp_id}",
|
||||
"debug",
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
# Log individual IdP token clearing failure but continue with other IdPs
|
||||
core_logger.print_to_log(
|
||||
f"Error clearing IdP token for user {user_id}, idp {link.idp_id}: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
# Continue to next IdP link
|
||||
|
||||
except Exception as err:
|
||||
# Catch-all for unexpected errors (e.g., database query failure)
|
||||
core_logger.print_to_log(
|
||||
f"Error retrieving IdP links for user {user_id} during logout: {err}",
|
||||
"warning",
|
||||
exc=err,
|
||||
)
|
||||
# Don't raise - IdP token clearing is a best-effort security measure
|
||||
|
||||
@@ -19,10 +19,10 @@ class SignUpToken(Base):
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="User ID that the sign-up token belongs to",
|
||||
comment="User ID that the sign up token belongs to",
|
||||
)
|
||||
token_hash = Column(
|
||||
String(length=128), nullable=False, comment="Hashed sign-up token"
|
||||
String(length=128), nullable=False, comment="Hashed sign up token"
|
||||
)
|
||||
created_at = Column(
|
||||
DateTime, nullable=False, comment="Token creation date (datetime)"
|
||||
|
||||
@@ -17,7 +17,7 @@ import notifications.utils as notifications_utils
|
||||
import sign_up_tokens.utils as sign_up_tokens_utils
|
||||
import sign_up_tokens.schema as sign_up_tokens_schema
|
||||
|
||||
import session.password_hasher as session_password_hasher
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
import server_settings.utils as server_settings_utils
|
||||
|
||||
@@ -38,8 +38,8 @@ async def signup(
|
||||
Depends(core_apprise.get_email_service),
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -56,7 +56,7 @@ async def signup(
|
||||
verification and admin approval emails.
|
||||
- websocket_manager (websocket_schema.WebSocketManager): Injected manager used to send
|
||||
real-time notifications (e.g., admin approval requests).
|
||||
- password_hasher (session_password_hasher.PasswordHasher): Injected password hasher used to hash user passwords.
|
||||
- password_hasher (auth_password_hasher.PasswordHasher): Injected password hasher used to hash user passwords.
|
||||
- db (Session): Database session/connection used to create the user and related records.
|
||||
|
||||
Behavior and side effects
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from stravalib.exc import AccessUnauthorized
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import users.user_integrations.crud as user_integrations_crud
|
||||
|
||||
@@ -109,15 +109,15 @@ async def strava_retrieve_activities_days(
|
||||
days: int,
|
||||
validate_access_token: Annotated[
|
||||
Callable,
|
||||
Depends(session_security.validate_access_token),
|
||||
Depends(auth_security.validate_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
websocket_manager: Annotated[
|
||||
websocket_schema.WebSocketManager,
|
||||
@@ -149,15 +149,15 @@ async def strava_retrieve_activities_days(
|
||||
async def strava_retrieve_gear(
|
||||
validate_access_token: Annotated[
|
||||
Callable,
|
||||
Depends(session_security.validate_access_token),
|
||||
Depends(auth_security.validate_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
@@ -180,7 +180,7 @@ async def strava_retrieve_gear(
|
||||
async def import_bikes_from_strava_export(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -236,7 +236,7 @@ async def import_bikes_from_strava_export(
|
||||
async def import_shoes_from_strava_export(
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sid_from_access_token),
|
||||
Depends(auth_security.get_sid_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -293,15 +293,15 @@ async def strava_set_user_client(
|
||||
client: strava_schema.StravaClient,
|
||||
validate_access_token: Annotated[
|
||||
Callable,
|
||||
Depends(session_security.validate_access_token),
|
||||
Depends(auth_security.validate_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
@@ -321,15 +321,15 @@ async def strava_set_user_unique_state(
|
||||
state: str | None,
|
||||
validate_access_token: Annotated[
|
||||
Callable,
|
||||
Depends(session_security.validate_access_token),
|
||||
Depends(auth_security.validate_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
@@ -344,15 +344,15 @@ async def strava_set_user_unique_state(
|
||||
async def strava_unlink(
|
||||
validate_access_token: Annotated[
|
||||
Callable,
|
||||
Depends(session_security.validate_access_token),
|
||||
Depends(auth_security.validate_access_token),
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable,
|
||||
Security(session_security.check_scopes, scopes=["profile"]),
|
||||
Security(auth_security.check_scopes, scopes=["profile"]),
|
||||
],
|
||||
token_user_id: Annotated[
|
||||
int,
|
||||
Depends(session_security.get_sub_from_access_token),
|
||||
Depends(auth_security.get_sub_from_access_token),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -4,8 +4,8 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from urllib.parse import unquote
|
||||
|
||||
import session.security as session_security
|
||||
import session.password_hasher as session_password_hasher
|
||||
import auth.security as auth_security
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
import users.user.schema as users_schema
|
||||
import users.user.utils as users_utils
|
||||
@@ -86,7 +86,9 @@ def get_users_with_pagination(db: Session, page_number: int = 1, num_records: in
|
||||
|
||||
# Enrich users with IDP count
|
||||
for user in users:
|
||||
idp_links = user_idp_crud.get_user_idp_links(user.id, db)
|
||||
idp_links = user_idp_crud.get_user_identity_providers_by_user_id(
|
||||
user.id, db
|
||||
)
|
||||
user.external_auth_count = len(idp_links)
|
||||
|
||||
# Return the users
|
||||
@@ -288,7 +290,7 @@ def get_users_admin(db: Session):
|
||||
|
||||
def create_user(
|
||||
user: users_schema.UserCreate,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
):
|
||||
try:
|
||||
@@ -566,7 +568,7 @@ def verify_user_email(
|
||||
def edit_user_password(
|
||||
user_id: int,
|
||||
password: str,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
is_hashed: bool = False,
|
||||
):
|
||||
@@ -805,7 +807,7 @@ def disable_user_mfa(user_id: int, db: Session):
|
||||
def create_signup_user(
|
||||
user: users_schema.UserSignup,
|
||||
server_settings: server_settings_schema.ServerSettingsRead,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
db: Session,
|
||||
):
|
||||
"""
|
||||
@@ -814,7 +816,7 @@ def create_signup_user(
|
||||
Args:
|
||||
user (users_schema.UserSignup): The user signup data containing user details.
|
||||
server_settings (server_settings_schema.ServerSettingsRead): Server settings used to determine if email verification or admin approval is required.
|
||||
password_hasher (session_password_hasher.PasswordHasher): Password hasher used to hash the user's password.
|
||||
password_hasher (auth_password_hasher.PasswordHasher): Password hasher used to hash the user's password.
|
||||
db (Session): SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -17,8 +17,8 @@ import users.user_privacy_settings.crud as users_privacy_settings_crud
|
||||
import health_targets.crud as health_targets_crud
|
||||
|
||||
import sign_up_tokens.utils as sign_up_tokens_utils
|
||||
import session.security as session_security
|
||||
import session.password_hasher as session_password_hasher
|
||||
import auth.security as auth_security
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
import core.apprise as core_apprise
|
||||
import core.database as core_database
|
||||
@@ -31,7 +31,7 @@ router = APIRouter()
|
||||
@router.get("/number", response_model=int)
|
||||
async def read_users_number(
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -52,7 +52,7 @@ async def read_users_all_pagination(
|
||||
Callable, Depends(core_dependencies.validate_pagination_values)
|
||||
],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -70,7 +70,7 @@ async def read_users_all_pagination(
|
||||
async def read_users_contain_username(
|
||||
username: str,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -88,7 +88,7 @@ async def read_users_contain_username(
|
||||
async def read_users_username(
|
||||
username: str,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -106,7 +106,7 @@ async def read_users_username(
|
||||
async def read_users_email(
|
||||
email: str,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -122,7 +122,7 @@ async def read_users_id(
|
||||
user_id: int,
|
||||
validate_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:read"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:read"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -136,12 +136,12 @@ async def read_users_id(
|
||||
@router.post("", response_model=users_schema.UserRead, status_code=201)
|
||||
async def create_user(
|
||||
user: users_schema.UserCreate,
|
||||
__check_scope: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:write"])
|
||||
_check_scope: Annotated[
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:write"])
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -168,7 +168,7 @@ async def upload_user_image(
|
||||
validate_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
file: UploadFile,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:write"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -184,7 +184,7 @@ async def edit_user(
|
||||
validate_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
user_attributtes: users_schema.UserRead,
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:write"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -203,7 +203,7 @@ async def approve_user(
|
||||
user_id: int,
|
||||
validate_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:write"])
|
||||
],
|
||||
email_service: Annotated[
|
||||
core_apprise.AppriseService,
|
||||
@@ -229,12 +229,12 @@ async def edit_user_password(
|
||||
user_id: int,
|
||||
_validate_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
user_attributes: users_schema.UserEditPassword,
|
||||
__check_scope: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:write"])
|
||||
_check_scope: Annotated[
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:write"])
|
||||
],
|
||||
password_hasher: Annotated[
|
||||
session_password_hasher.PasswordHasher,
|
||||
Depends(session_password_hasher.get_password_hasher),
|
||||
auth_password_hasher.PasswordHasher,
|
||||
Depends(auth_password_hasher.get_password_hasher),
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -255,7 +255,7 @@ async def delete_user_photo(
|
||||
user_id: int,
|
||||
validate_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:write"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
@@ -274,7 +274,7 @@ async def delete_user(
|
||||
user_id: int,
|
||||
validate_id: Annotated[Callable, Depends(users_dependencies.validate_user_id)],
|
||||
_check_scopes: Annotated[
|
||||
Callable, Security(session_security.check_scopes, scopes=["users:write"])
|
||||
Callable, Security(auth_security.check_scopes, scopes=["users:write"])
|
||||
],
|
||||
db: Annotated[
|
||||
Session,
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
import shutil
|
||||
|
||||
import session.password_hasher as session_password_hasher
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
|
||||
import users.user.crud as users_crud
|
||||
import users.user.schema as users_schema
|
||||
@@ -36,13 +36,13 @@ def create_user_default_data(user_id: int, db: Session) -> None:
|
||||
|
||||
def check_password_and_hash(
|
||||
password: str,
|
||||
password_hasher: session_password_hasher.PasswordHasher,
|
||||
password_hasher: auth_password_hasher.PasswordHasher,
|
||||
min_length: int = 8,
|
||||
) -> str:
|
||||
# Check if password meets requirements
|
||||
try:
|
||||
password_hasher.validate_password(password, min_length)
|
||||
except session_password_hasher.PasswordPolicyError as err:
|
||||
except auth_password_hasher.PasswordPolicyError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(err),
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
import users.user_default_gear.schema as user_default_gear_schema
|
||||
import users.user_default_gear.crud as user_default_gear_crud
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
@@ -16,9 +16,7 @@ router = APIRouter()
|
||||
|
||||
@router.get("", response_model=user_default_gear_schema.UserDefaultGear)
|
||||
async def read_user_default_gear(
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
@@ -31,9 +29,7 @@ async def read_user_default_gear(
|
||||
@router.put("", response_model=user_default_gear_schema.UserDefaultGear)
|
||||
async def edit_user_default_gear(
|
||||
user_default_gear: user_default_gear_schema.UserDefaultGear,
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[
|
||||
Session,
|
||||
Depends(core_database.get_db),
|
||||
|
||||
@@ -9,7 +9,7 @@ import users.user_goals.schema as user_goals_schema
|
||||
import users.user_goals.crud as user_goals_crud
|
||||
import users.user_goals.utils as user_goals_utils
|
||||
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
|
||||
import core.database as core_database
|
||||
|
||||
@@ -19,9 +19,7 @@ router = APIRouter()
|
||||
|
||||
@router.get("", response_model=List[user_goals_schema.UserGoalRead] | None)
|
||||
async def get_user_goals(
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
"""
|
||||
@@ -39,9 +37,7 @@ async def get_user_goals(
|
||||
|
||||
@router.get("/results", response_model=List[user_goals_schema.UserGoalProgress] | None)
|
||||
async def get_user_goals_results(
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
"""
|
||||
@@ -60,9 +56,7 @@ async def get_user_goals_results(
|
||||
@router.post("", response_model=user_goals_schema.UserGoalRead, status_code=201)
|
||||
async def create_user_goal(
|
||||
user_goal: user_goals_schema.UserGoalCreate,
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
"""
|
||||
@@ -87,9 +81,7 @@ async def update_user_goal(
|
||||
goal_id: int,
|
||||
validate_id: Annotated[Callable, Depends(user_goals_dependencies.validate_goal_id)],
|
||||
user_goal: user_goals_schema.UserGoalEdit,
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
"""
|
||||
@@ -115,9 +107,7 @@ async def update_user_goal(
|
||||
async def delete_user_goal(
|
||||
goal_id: int,
|
||||
validate_id: Annotated[Callable, Depends(user_goals_dependencies.validate_goal_id)],
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -5,7 +5,7 @@ from sqlalchemy.sql import func
|
||||
from users.user_identity_providers import models as user_idp_models
|
||||
|
||||
|
||||
def get_idp_has_user_links(idp_id: int, db: Session) -> bool:
|
||||
def check_user_identity_providers_by_idp_id(idp_id: int, db: Session) -> bool:
|
||||
"""
|
||||
Checks if there are any user links associated with a given identity provider ID.
|
||||
|
||||
@@ -21,7 +21,7 @@ def get_idp_has_user_links(idp_id: int, db: Session) -> bool:
|
||||
).scalar()
|
||||
|
||||
|
||||
def get_user_idp_link(
|
||||
def get_user_identity_provider_by_user_id_and_idp_id(
|
||||
user_id: int, idp_id: int, db: Session
|
||||
) -> user_idp_models.UserIdentityProvider | None:
|
||||
"""
|
||||
@@ -45,7 +45,7 @@ def get_user_idp_link(
|
||||
)
|
||||
|
||||
|
||||
def get_user_idp_link_by_subject(
|
||||
def get_user_identity_provider_by_subject_and_idp_id(
|
||||
idp_id: int, idp_subject: str, db: Session
|
||||
) -> user_idp_models.UserIdentityProvider | None:
|
||||
"""
|
||||
@@ -69,7 +69,7 @@ def get_user_idp_link_by_subject(
|
||||
)
|
||||
|
||||
|
||||
def get_user_idp_links(
|
||||
def get_user_identity_providers_by_user_id(
|
||||
user_id: int, db: Session
|
||||
) -> list[user_idp_models.UserIdentityProvider]:
|
||||
"""
|
||||
@@ -89,7 +89,35 @@ def get_user_idp_links(
|
||||
)
|
||||
|
||||
|
||||
def create_user_idp_link(
|
||||
def get_user_identity_provider_refresh_token_by_user_id_and_idp_id(
|
||||
user_id: int, idp_id: int, db: Session
|
||||
) -> str | None:
|
||||
"""
|
||||
Get the encrypted refresh token for a user-IdP link.
|
||||
|
||||
This function retrieves the encrypted refresh token. The caller is responsible
|
||||
for decrypting it using Fernet before use.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
idp_id (int): The ID of the identity provider.
|
||||
db (Session): The SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
str | None: The encrypted refresh token string if found, otherwise None.
|
||||
|
||||
Security Note:
|
||||
- Returns the encrypted token (not plaintext)
|
||||
- Caller must decrypt using Fernet
|
||||
- Returns None if link doesn't exist or token is not set
|
||||
"""
|
||||
db_link = get_user_identity_provider_by_user_id_and_idp_id(user_id, idp_id, db)
|
||||
if db_link:
|
||||
return db_link.idp_refresh_token
|
||||
return None
|
||||
|
||||
|
||||
def create_user_identity_provider(
|
||||
user_id: int, idp_id: int, idp_subject: str, db: Session
|
||||
) -> user_idp_models.UserIdentityProvider:
|
||||
"""
|
||||
@@ -113,7 +141,7 @@ def create_user_idp_link(
|
||||
return db_link
|
||||
|
||||
|
||||
def update_user_idp_last_login(
|
||||
def update_user_identity_provider_last_login(
|
||||
user_id: int, idp_id: int, db: Session
|
||||
) -> user_idp_models.UserIdentityProvider | None:
|
||||
"""
|
||||
@@ -127,7 +155,7 @@ def update_user_idp_last_login(
|
||||
Returns:
|
||||
user_idp_models.UserIdentityProvider | None: The updated UserIdentityProvider link if found, otherwise None.
|
||||
"""
|
||||
db_link = get_user_idp_link(user_id, idp_id, db)
|
||||
db_link = get_user_identity_provider_by_user_id_and_idp_id(user_id, idp_id, db)
|
||||
if db_link:
|
||||
db_link.last_login = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
@@ -135,7 +163,7 @@ def update_user_idp_last_login(
|
||||
return db_link
|
||||
|
||||
|
||||
def store_idp_tokens(
|
||||
def store_user_identity_provider_tokens(
|
||||
user_id: int,
|
||||
idp_id: int,
|
||||
encrypted_refresh_token: str,
|
||||
@@ -162,7 +190,7 @@ def store_idp_tokens(
|
||||
The refresh_token parameter must be pre-encrypted with Fernet before calling this function.
|
||||
Never pass plaintext tokens to this function.
|
||||
"""
|
||||
db_link = get_user_idp_link(user_id, idp_id, db)
|
||||
db_link = get_user_identity_provider_by_user_id_and_idp_id(user_id, idp_id, db)
|
||||
if db_link:
|
||||
db_link.idp_refresh_token = encrypted_refresh_token
|
||||
db_link.idp_access_token_expires_at = access_token_expires_at
|
||||
@@ -172,35 +200,9 @@ def store_idp_tokens(
|
||||
return db_link
|
||||
|
||||
|
||||
def get_idp_refresh_token(
|
||||
def clear_user_identity_provider_refresh_token_by_user_id_and_idp_id(
|
||||
user_id: int, idp_id: int, db: Session
|
||||
) -> str | None:
|
||||
"""
|
||||
Get the encrypted refresh token for a user-IdP link.
|
||||
|
||||
This function retrieves the encrypted refresh token. The caller is responsible
|
||||
for decrypting it using Fernet before use.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
idp_id (int): The ID of the identity provider.
|
||||
db (Session): The SQLAlchemy database session.
|
||||
|
||||
Returns:
|
||||
str | None: The encrypted refresh token string if found, otherwise None.
|
||||
|
||||
Security Note:
|
||||
- Returns the encrypted token (not plaintext)
|
||||
- Caller must decrypt using Fernet
|
||||
- Returns None if link doesn't exist or token is not set
|
||||
"""
|
||||
db_link = get_user_idp_link(user_id, idp_id, db)
|
||||
if db_link:
|
||||
return db_link.idp_refresh_token
|
||||
return None
|
||||
|
||||
|
||||
def clear_idp_refresh_token(user_id: int, idp_id: int, db: Session) -> bool:
|
||||
) -> bool:
|
||||
"""
|
||||
Clear the IdP refresh token and related metadata.
|
||||
|
||||
@@ -218,7 +220,7 @@ def clear_idp_refresh_token(user_id: int, idp_id: int, db: Session) -> bool:
|
||||
Returns:
|
||||
bool: True if the token was cleared, False if the link was not found.
|
||||
"""
|
||||
db_link = get_user_idp_link(user_id, idp_id, db)
|
||||
db_link = get_user_identity_provider_by_user_id_and_idp_id(user_id, idp_id, db)
|
||||
if db_link:
|
||||
db_link.idp_refresh_token = None
|
||||
db_link.idp_access_token_expires_at = None
|
||||
@@ -228,7 +230,7 @@ def clear_idp_refresh_token(user_id: int, idp_id: int, db: Session) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def delete_user_idp_link(user_id: int, idp_id: int, db: Session) -> bool:
|
||||
def delete_user_identity_provider(user_id: int, idp_id: int, db: Session) -> bool:
|
||||
"""
|
||||
Deletes the link between a user and an identity provider (IDP) from the database.
|
||||
|
||||
@@ -246,7 +248,7 @@ def delete_user_idp_link(user_id: int, idp_id: int, db: Session) -> bool:
|
||||
Security Note:
|
||||
Sensitive token data is explicitly cleared before deletion as a defense-in-depth measure.
|
||||
"""
|
||||
db_link = get_user_idp_link(user_id, idp_id, db)
|
||||
db_link = get_user_identity_provider_by_user_id_and_idp_id(user_id, idp_id, db)
|
||||
if db_link:
|
||||
# Clear sensitive data first (defense in depth)
|
||||
db_link.idp_refresh_token = None
|
||||
|
||||
@@ -10,18 +10,19 @@ Security:
|
||||
- Does NOT expose refresh tokens (security)
|
||||
- Audit logging handled by CRUD layer
|
||||
"""
|
||||
|
||||
from typing import Annotated, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import core.database as core_database
|
||||
import core.logger as core_logger
|
||||
import session.security as session_security
|
||||
import auth.security as auth_security
|
||||
import users.user_identity_providers.crud as user_idp_crud
|
||||
import users.user_identity_providers.schema as user_idp_schema
|
||||
import users.user.schema as users_schema
|
||||
import users.user.crud as users_crud
|
||||
import identity_providers.crud as idp_crud
|
||||
import auth.identity_providers.crud as idp_crud
|
||||
|
||||
|
||||
# Define the API router
|
||||
@@ -38,7 +39,7 @@ async def get_user_identity_providers(
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
_check_scopes: Annotated[
|
||||
users_schema.UserRead,
|
||||
Security(session_security.check_scopes, scopes=["sessions:read"]),
|
||||
Security(auth_security.check_scopes, scopes=["sessions:read"]),
|
||||
],
|
||||
):
|
||||
"""
|
||||
@@ -88,12 +89,12 @@ async def get_user_identity_providers(
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User with id {user_id} not found"
|
||||
detail=f"User with id {user_id} not found",
|
||||
)
|
||||
|
||||
|
||||
# Get user's identity provider links
|
||||
idp_links = user_idp_crud.get_user_idp_links(user_id, db)
|
||||
|
||||
idp_links = user_idp_crud.get_user_identity_providers_by_user_id(user_id, db)
|
||||
|
||||
# Enrich with IDP details for frontend display
|
||||
enriched_links = []
|
||||
for link in idp_links:
|
||||
@@ -108,7 +109,7 @@ async def get_user_identity_providers(
|
||||
"idp_access_token_expires_at": link.idp_access_token_expires_at,
|
||||
"idp_refresh_token_updated_at": link.idp_refresh_token_updated_at,
|
||||
}
|
||||
|
||||
|
||||
# Fetch IDP details for display
|
||||
idp = idp_crud.get_identity_provider(link.idp_id, db)
|
||||
if idp:
|
||||
@@ -116,9 +117,9 @@ async def get_user_identity_providers(
|
||||
link_dict["idp_slug"] = idp.slug
|
||||
link_dict["idp_icon"] = idp.icon
|
||||
link_dict["idp_provider_type"] = idp.provider_type
|
||||
|
||||
|
||||
enriched_links.append(link_dict)
|
||||
|
||||
|
||||
return enriched_links
|
||||
|
||||
|
||||
@@ -129,12 +130,10 @@ async def get_user_identity_providers(
|
||||
async def delete_user_identity_provider(
|
||||
user_id: int,
|
||||
idp_id: int,
|
||||
token_user_id: Annotated[
|
||||
int, Depends(session_security.get_sub_from_access_token)
|
||||
],
|
||||
token_user_id: Annotated[int, Depends(auth_security.get_sub_from_access_token)],
|
||||
_check_scopes: Annotated[
|
||||
users_schema.UserRead,
|
||||
Security(session_security.check_scopes, scopes=["sessions:write"]),
|
||||
Security(auth_security.check_scopes, scopes=["sessions:write"]),
|
||||
],
|
||||
db: Annotated[Session, Depends(core_database.get_db)],
|
||||
):
|
||||
@@ -164,31 +163,31 @@ async def delete_user_identity_provider(
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User with id {user_id} not found"
|
||||
detail=f"User with id {user_id} not found",
|
||||
)
|
||||
|
||||
|
||||
# Validate IDP exists
|
||||
idp = idp_crud.get_identity_provider(idp_id, db)
|
||||
if idp is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Identity provider with id {idp_id} not found"
|
||||
detail=f"Identity provider with id {idp_id} not found",
|
||||
)
|
||||
|
||||
|
||||
# Attempt to delete the link
|
||||
success = user_idp_crud.delete_user_idp_link(user_id, idp_id, db)
|
||||
|
||||
success = user_idp_crud.delete_user_identity_provider(user_id, idp_id, db)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Identity provider link not found for user {user_id} and IDP {idp_id}"
|
||||
detail=f"Identity provider link not found for user {user_id} and IDP {idp_id}",
|
||||
)
|
||||
|
||||
|
||||
# Audit logging
|
||||
core_logger.print_to_log(
|
||||
f"Admin user {token_user_id} deleted IDP link: "
|
||||
f"user_id={user_id}, idp_id={idp_id} ({idp.name})"
|
||||
)
|
||||
|
||||
|
||||
# Return 204 No Content (successful deletion)
|
||||
return None
|
||||
|
||||
@@ -5,7 +5,7 @@ from pwdlib import PasswordHash
|
||||
from pwdlib.hashers.argon2 import Argon2Hasher
|
||||
from pwdlib.hashers.bcrypt import BcryptHasher
|
||||
|
||||
from session.password_hasher import PasswordHasher, PasswordPolicyError
|
||||
from auth.password_hasher import PasswordHasher, PasswordPolicyError
|
||||
|
||||
|
||||
class TestPasswordHasherSecurity:
|
||||
@@ -3,7 +3,7 @@ from datetime import datetime, timezone
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import session.token_manager as session_token_manager
|
||||
import authtoken_manager as auth_token_manager
|
||||
|
||||
|
||||
class TestTokenManagerSecurity:
|
||||
@@ -36,7 +36,7 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-123"
|
||||
_, token = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
sub_claim = token_manager.get_token_claim(token, "sub")
|
||||
@@ -54,7 +54,7 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
_, token = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -78,7 +78,7 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
_, token = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
payload = token_manager.decode_token(token)
|
||||
@@ -109,7 +109,7 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
_, token = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
payload = token_manager.decode_token(token)
|
||||
@@ -151,15 +151,15 @@ class TestTokenManagerSecurity:
|
||||
This test creates a token using one instance of TokenManager with a specific secret key, then attempts to decode the token using another TokenManager instance with a different secret key. It asserts that an HTTPException with status code 401 is raised, indicating unauthorized access due to the wrong secret.
|
||||
"""
|
||||
# Create token with one manager
|
||||
manager1 = session_token_manager.TokenManager(
|
||||
manager1 = auth_token_manager.TokenManager(
|
||||
secret_key="secret-key-one-min-32-characters-long"
|
||||
)
|
||||
_, token = manager1.create_token(
|
||||
"session-id", sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
# Try to decode with different manager
|
||||
manager2 = session_token_manager.TokenManager(
|
||||
manager2 = auth_token_manager.TokenManager(
|
||||
secret_key="secret-key-two-min-32-characters-long"
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -178,7 +178,7 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
_, token = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
# Should not raise an exception
|
||||
@@ -223,7 +223,7 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
exp_time, token = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
assert token is not None, "Token should not be None"
|
||||
@@ -254,7 +254,7 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
exp_time, token = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.REFRESH
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.REFRESH
|
||||
)
|
||||
|
||||
assert token is not None, "Token should not be None"
|
||||
@@ -276,10 +276,10 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
access_exp, _ = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
refresh_exp, _ = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.REFRESH
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.REFRESH
|
||||
)
|
||||
|
||||
assert (
|
||||
@@ -335,10 +335,10 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
_, token1 = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
_, token2 = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
assert token1 != token2, "Tokens should be unique even for the same user"
|
||||
@@ -361,7 +361,7 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
_, token = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
# Tamper with the token
|
||||
@@ -418,10 +418,10 @@ class TestTokenManagerSecurity:
|
||||
The tokens generated for different session IDs are not equal.
|
||||
"""
|
||||
_, token1 = token_manager.create_token(
|
||||
"session-id-1", sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
"session-id-1", sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
_, token2 = token_manager.create_token(
|
||||
"session-id-2", sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
"session-id-2", sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
assert token1 != token2, "Different session IDs should produce different tokens"
|
||||
@@ -435,7 +435,7 @@ class TestTokenManagerSecurity:
|
||||
"""
|
||||
session_id = "test-session-id"
|
||||
exp_time, _ = token_manager.create_token(
|
||||
session_id, sample_user_read, session_token_manager.TokenType.ACCESS
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
assert (
|
||||
@@ -18,8 +18,8 @@ load_dotenv(dotenv_path=env_test_path)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "app"))
|
||||
|
||||
import session.router as session_router
|
||||
import session.password_hasher as session_password_hasher
|
||||
import session.token_manager as session_token_manager
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
import auth.token_manager as auth_token_manager
|
||||
import users.user.schema as user_schema
|
||||
|
||||
# Variables and constants
|
||||
@@ -29,25 +29,25 @@ DEFAULT_ROUTER_MODULES = [
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def password_hasher() -> session_password_hasher.PasswordHasher:
|
||||
def password_hasher() -> auth_password_hasher.PasswordHasher:
|
||||
"""
|
||||
Creates and returns an instance of session_password_hasher.PasswordHasher using the get_password_hasher function.
|
||||
Creates and returns an instance of auth_password_hasher.PasswordHasher using the get_password_hasher function.
|
||||
|
||||
Returns:
|
||||
session_password_hasher.PasswordHasher: An instance of the password hasher utility.
|
||||
auth_password_hasher.PasswordHasher: An instance of the password hasher utility.
|
||||
"""
|
||||
return session_password_hasher.get_password_hasher()
|
||||
return auth_password_hasher.get_password_hasher()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_manager() -> session_token_manager.TokenManager:
|
||||
def token_manager() -> auth_token_manager.TokenManager:
|
||||
"""
|
||||
Creates and returns a session_token_manager.TokenManager instance configured with a test secret key.
|
||||
Creates and returns a auth_token_manager.TokenManager instance configured with a test secret key.
|
||||
|
||||
Returns:
|
||||
session_token_manager.TokenManager: An instance of session_token_manager.TokenManager initialized with a test secret key for use in testing.
|
||||
auth_token_manager.TokenManager: An instance of auth_token_manager.TokenManager initialized with a test secret key for use in testing.
|
||||
"""
|
||||
return session_token_manager.TokenManager(
|
||||
return auth_token_manager.TokenManager(
|
||||
secret_key="test-secret-key-for-testing-only-min-32-chars"
|
||||
)
|
||||
|
||||
@@ -329,64 +329,64 @@ def fast_api_app(password_hasher, token_manager, mock_db) -> FastAPI:
|
||||
|
||||
try:
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.header_client_type_scheme
|
||||
session_router.auth_security.header_client_type_scheme
|
||||
] = _client_type_override
|
||||
app.dependency_overrides[
|
||||
session_router.session_schema.get_pending_mfa_store
|
||||
] = lambda: fake_store
|
||||
|
||||
# Override security dependencies for authenticated endpoint testing
|
||||
app.dependency_overrides[session_router.auth_security.validate_access_token] = (
|
||||
_mock_validate_access_token
|
||||
)
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.validate_access_token
|
||||
] = _mock_validate_access_token
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.validate_refresh_token
|
||||
session_router.auth_security.validate_refresh_token
|
||||
] = _mock_validate_refresh_token
|
||||
app.dependency_overrides[session_router.auth_security.get_access_token] = (
|
||||
_mock_get_access_token
|
||||
)
|
||||
app.dependency_overrides[session_router.auth_security.get_refresh_token] = (
|
||||
_mock_get_refresh_token
|
||||
)
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.get_access_token
|
||||
] = _mock_get_access_token
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.get_refresh_token
|
||||
] = _mock_get_refresh_token
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.get_sub_from_access_token
|
||||
session_router.auth_security.get_sub_from_access_token
|
||||
] = _mock_get_sub_from_access_token
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.get_sid_from_access_token
|
||||
session_router.auth_security.get_sid_from_access_token
|
||||
] = _mock_get_sid_from_access_token
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.get_sub_from_refresh_token
|
||||
session_router.auth_security.get_sub_from_refresh_token
|
||||
] = _mock_get_sub_from_refresh_token
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.get_sid_from_refresh_token
|
||||
session_router.auth_security.get_sid_from_refresh_token
|
||||
] = _mock_get_sid_from_refresh_token
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.get_and_return_access_token
|
||||
session_router.auth_security.get_and_return_access_token
|
||||
] = _mock_get_and_return_access_token
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.get_and_return_refresh_token
|
||||
session_router.auth_security.get_and_return_refresh_token
|
||||
] = _mock_get_and_return_refresh_token
|
||||
app.dependency_overrides[
|
||||
session_router.session_security.check_scopes
|
||||
] = _mock_check_scopes
|
||||
app.dependency_overrides[session_router.auth_security.check_scopes] = (
|
||||
_mock_check_scopes
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Generic overrides
|
||||
_override_if_exists(
|
||||
app, "session.password_hasher", "get_password_hasher", lambda: password_hasher
|
||||
app, "auth.password_hasher", "get_password_hasher", lambda: password_hasher
|
||||
)
|
||||
_override_if_exists(
|
||||
app,
|
||||
"session.session_password_hasher",
|
||||
"session.auth_password_hasher",
|
||||
"get_password_hasher",
|
||||
lambda: password_hasher,
|
||||
)
|
||||
_override_if_exists(
|
||||
app, "session.token_manager", "get_token_manager", lambda: token_manager
|
||||
app, "auth.token_manager", "get_token_manager", lambda: token_manager
|
||||
)
|
||||
_override_if_exists(
|
||||
app, "session.session_token_manager", "get_token_manager", lambda: token_manager
|
||||
app, "session.auth_token_manager", "get_token_manager", lambda: token_manager
|
||||
)
|
||||
_override_if_exists(
|
||||
app, "core.database", "get_db", lambda: mock_db
|
||||
|
||||
@@ -52,12 +52,12 @@ class TestLoginEndpointSecurity:
|
||||
returns_tokens: Boolean indicating if the endpoint should return tokens or just a session ID.
|
||||
"""
|
||||
fast_api_app.state._client_type = client_type
|
||||
with patch(
|
||||
"session.router.session_utils.authenticate_user"
|
||||
) as mock_auth, patch("session.router.users_utils.check_user_is_active"), patch(
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth, patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.profile_utils.is_mfa_enabled_for_user"
|
||||
) as mock_mfa, patch(
|
||||
"session.router.session_utils.complete_login"
|
||||
"session.router.auth_utils.complete_login"
|
||||
) as mock_complete:
|
||||
mock_auth.return_value = sample_user_read
|
||||
mock_mfa.return_value = False
|
||||
@@ -127,11 +127,9 @@ class TestLoginEndpointSecurity:
|
||||
- The fake_store in the app state records the correct call.
|
||||
"""
|
||||
fast_api_app.state._client_type = client_type
|
||||
with patch(
|
||||
"session.router.session_utils.authenticate_user"
|
||||
) as mock_auth, patch("session.router.users_utils.check_user_is_active"), patch(
|
||||
"session.router.profile_utils.is_mfa_enabled_for_user"
|
||||
) as mock_mfa:
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth, patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch("session.router.profile_utils.is_mfa_enabled_for_user") as mock_mfa:
|
||||
mock_auth.return_value = sample_user_read
|
||||
mock_mfa.return_value = True
|
||||
|
||||
@@ -166,12 +164,12 @@ class TestLoginEndpointSecurity:
|
||||
sample_user_read: A sample user object returned by the authentication mock.
|
||||
"""
|
||||
fast_api_app.state._client_type = "desktop"
|
||||
with patch(
|
||||
"session.router.session_utils.authenticate_user"
|
||||
) as mock_auth, patch("session.router.users_utils.check_user_is_active"), patch(
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth, patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.profile_utils.is_mfa_enabled_for_user"
|
||||
) as mock_mfa, patch(
|
||||
"session.router.session_utils.create_tokens"
|
||||
"session.router.auth_utils.create_tokens"
|
||||
) as mock_create_tokens, patch(
|
||||
"session.router.session_utils.create_session"
|
||||
) as mock_create_session:
|
||||
@@ -204,7 +202,7 @@ class TestLoginEndpointSecurity:
|
||||
to simulate authentication failure and verifies that the exception is raised
|
||||
with the correct status code and detail.
|
||||
"""
|
||||
with patch("session.router.session_utils.authenticate_user") as mock_auth:
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth:
|
||||
mock_auth.side_effect = HTTPException(
|
||||
status_code=401, detail="Invalid username"
|
||||
)
|
||||
@@ -223,7 +221,7 @@ class TestLoginEndpointSecurity:
|
||||
the scenario where a user is found but is inactive. It asserts that the correct
|
||||
exception is raised with the expected status code.
|
||||
"""
|
||||
with patch("session.router.session_utils.authenticate_user") as mock_auth:
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth:
|
||||
with patch("session.router.users_utils.check_user_is_active") as mock_check:
|
||||
mock_auth.return_value = sample_inactive_user
|
||||
mock_check.side_effect = HTTPException(
|
||||
@@ -283,7 +281,7 @@ class TestMFAVerifyEndpoint:
|
||||
returns_tokens: Boolean indicating if tokens should be returned.
|
||||
"""
|
||||
fast_api_app.state._client_type = client_type
|
||||
|
||||
|
||||
# Setup pending MFA login
|
||||
pending_store = fast_api_app.state.fake_store
|
||||
pending_store._store = {"testuser": sample_user_read.id}
|
||||
@@ -295,7 +293,7 @@ class TestMFAVerifyEndpoint:
|
||||
) as mock_get_user, patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.session_utils.complete_login"
|
||||
"session.router.auth_utils.complete_login"
|
||||
) as mock_complete:
|
||||
mock_verify_mfa.return_value = True
|
||||
mock_get_user.return_value = sample_user_read
|
||||
@@ -488,7 +486,9 @@ class TestRefreshTokenEndpoint:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password("refresh_token_value")
|
||||
mock_session.refresh_token = password_hasher.hash_password(
|
||||
"refresh_token_value"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
@@ -497,11 +497,11 @@ class TestRefreshTokenEndpoint:
|
||||
), patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.session_utils.create_tokens"
|
||||
"session.router.auth_utils.create_tokens"
|
||||
) as mock_create_tokens, patch(
|
||||
"session.router.session_utils.edit_session"
|
||||
), patch(
|
||||
"session.router.session_utils.create_response_with_tokens",
|
||||
"session.router.auth_utils.create_response_with_tokens",
|
||||
side_effect=lambda r, a, rf, c: r,
|
||||
):
|
||||
# Set up proper mock for create_tokens with timestamp
|
||||
@@ -516,10 +516,10 @@ class TestRefreshTokenEndpoint:
|
||||
"new_refresh_token",
|
||||
"new_csrf_token",
|
||||
)
|
||||
|
||||
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/refresh",
|
||||
headers={"X-Client-Type": client_type},
|
||||
@@ -547,12 +547,10 @@ class TestRefreshTokenEndpoint:
|
||||
fast_api_app.state.mock_session_id = "nonexistent-session"
|
||||
fast_api_app.state.mock_refresh_token = "refresh_token_value"
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=None
|
||||
):
|
||||
with patch("session.router.session_crud.get_session_by_id", return_value=None):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/refresh",
|
||||
headers={"X-Client-Type": "web"},
|
||||
@@ -584,7 +582,7 @@ class TestRefreshTokenEndpoint:
|
||||
):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "wrong_token_value")
|
||||
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/refresh",
|
||||
headers={"X-Client-Type": "web"},
|
||||
@@ -609,7 +607,9 @@ class TestRefreshTokenEndpoint:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password("refresh_token_value")
|
||||
mock_session.refresh_token = password_hasher.hash_password(
|
||||
"refresh_token_value"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
@@ -624,7 +624,7 @@ class TestRefreshTokenEndpoint:
|
||||
):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/refresh",
|
||||
headers={"X-Client-Type": "web"},
|
||||
@@ -648,7 +648,9 @@ class TestRefreshTokenEndpoint:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password("refresh_token_value")
|
||||
mock_session.refresh_token = password_hasher.hash_password(
|
||||
"refresh_token_value"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
@@ -657,7 +659,7 @@ class TestRefreshTokenEndpoint:
|
||||
), patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.session_utils.create_tokens"
|
||||
"session.router.auth_utils.create_tokens"
|
||||
) as mock_create_tokens, patch(
|
||||
"session.router.session_utils.edit_session"
|
||||
):
|
||||
@@ -673,10 +675,10 @@ class TestRefreshTokenEndpoint:
|
||||
"new_refresh_token",
|
||||
"new_csrf_token",
|
||||
)
|
||||
|
||||
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/refresh",
|
||||
headers={"X-Client-Type": "desktop"},
|
||||
@@ -735,17 +737,17 @@ class TestLogoutEndpoint:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password("refresh_token_value")
|
||||
mock_session.refresh_token = password_hasher.hash_password(
|
||||
"refresh_token_value"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
), patch(
|
||||
"session.router.session_crud.delete_session"
|
||||
) as mock_delete:
|
||||
), patch("session.router.session_crud.delete_session") as mock_delete:
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_access_token", "access_token")
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/logout",
|
||||
headers={"X-Client-Type": client_type},
|
||||
@@ -784,7 +786,7 @@ class TestLogoutEndpoint:
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_access_token", "access_token")
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "wrong_token_value")
|
||||
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/logout",
|
||||
headers={"X-Client-Type": "web"},
|
||||
@@ -807,13 +809,11 @@ class TestLogoutEndpoint:
|
||||
fast_api_app.state.mock_user_id = 1
|
||||
fast_api_app.state.mock_refresh_token = "refresh_token_value"
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=None
|
||||
):
|
||||
with patch("session.router.session_crud.get_session_by_id", return_value=None):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_access_token", "access_token")
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/logout",
|
||||
headers={"X-Client-Type": "web"},
|
||||
@@ -844,7 +844,7 @@ class TestLogoutEndpoint:
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_access_token", "access_token")
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/logout",
|
||||
headers={"X-Client-Type": "desktop"},
|
||||
@@ -921,9 +921,7 @@ class TestSessionsEndpoints:
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_user_sessions", return_value=[]
|
||||
):
|
||||
with patch("session.router.session_crud.get_user_sessions", return_value=[]):
|
||||
resp = fast_api_client.get(
|
||||
f"/sessions/user/{sample_user_read.id}",
|
||||
headers={
|
||||
|
||||
@@ -7,7 +7,8 @@ from fastapi import HTTPException, Response
|
||||
from pwdlib.hashers.bcrypt import BcryptHasher
|
||||
from pwdlib import PasswordHash
|
||||
|
||||
import session.password_hasher as session_password_hasher
|
||||
import auth.password_hasher as auth_password_hasher
|
||||
import auth.utils as auth_utils
|
||||
import session.utils as session_utils
|
||||
|
||||
|
||||
@@ -144,17 +145,16 @@ class TestAuthenticationSecurity:
|
||||
), "Operating system should be set"
|
||||
assert updated_session.browser is not None, "Browser should be set"
|
||||
|
||||
|
||||
def test_authenticate_user_with_valid_credentials(
|
||||
self, password_hasher, mock_db, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test that the `session_utils.authenticate_user` function successfully authenticates a user with valid credentials.
|
||||
Test that the `auth_utils.authenticate_user` function successfully authenticates a user with valid credentials.
|
||||
This test:
|
||||
- Hashes a sample password using the provided password hasher.
|
||||
- Mocks a user ORM object with the hashed password and sample user data.
|
||||
- Patches the `session_utils.authenticate_user` function in the users CRUD utility to return the mocked user.
|
||||
- Calls the actual `session_utils.authenticate_user` function with valid credentials.
|
||||
- Patches the `auth_utils.authenticate_user` function in the users CRUD utility to return the mocked user.
|
||||
- Calls the actual `auth_utils.authenticate_user` function with valid credentials.
|
||||
- Asserts that authentication succeeds and the returned user matches the expected sample user.
|
||||
Args:
|
||||
password_hasher: Fixture or mock for password hashing utilities.
|
||||
@@ -177,7 +177,7 @@ class TestAuthenticationSecurity:
|
||||
with patch("session.utils.users_crud.authenticate_user") as mock_auth:
|
||||
mock_auth.return_value = mock_user_orm
|
||||
|
||||
result = session_utils.authenticate_user(
|
||||
result = auth_utils.authenticate_user(
|
||||
"testuser", password, password_hasher, mock_db
|
||||
)
|
||||
|
||||
@@ -187,7 +187,7 @@ class TestAuthenticationSecurity:
|
||||
|
||||
def test_authenticate_user_with_invalid_username(self, password_hasher, mock_db):
|
||||
"""
|
||||
Test that the `session_utils.authenticate_user` function raises an HTTPException with status code 401
|
||||
Test that the `auth_utils.authenticate_user` function raises an HTTPException with status code 401
|
||||
when provided with an invalid (nonexistent) username. Ensures that the exception detail
|
||||
contains information about the username.
|
||||
"""
|
||||
@@ -195,7 +195,7 @@ class TestAuthenticationSecurity:
|
||||
mock_auth.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
session_utils.authenticate_user(
|
||||
auth_utils.authenticate_user(
|
||||
"nonexistent", "password", password_hasher, mock_db
|
||||
)
|
||||
|
||||
@@ -206,7 +206,7 @@ class TestAuthenticationSecurity:
|
||||
self, password_hasher, mock_db, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test that the `session_utils.authenticate_user` function raises an HTTPException with status code 401
|
||||
Test that the `auth_utils.authenticate_user` function raises an HTTPException with status code 401
|
||||
when an incorrect password is provided for an existing user.
|
||||
This test mocks the user retrieval and password hashing process, simulating a scenario
|
||||
where the user exists but the provided password does not match the stored hash.
|
||||
@@ -225,7 +225,7 @@ class TestAuthenticationSecurity:
|
||||
mock_auth.return_value = mock_user_orm
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
session_utils.authenticate_user(
|
||||
auth_utils.authenticate_user(
|
||||
"testuser", wrong_password, password_hasher, mock_db
|
||||
)
|
||||
|
||||
@@ -236,7 +236,7 @@ class TestAuthenticationSecurity:
|
||||
self, password_hasher, mock_db, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test that the `session_utils.authenticate_user` function updates the user's password hash if the current hash is outdated.
|
||||
Test that the `auth_utils.authenticate_user` function updates the user's password hash if the current hash is outdated.
|
||||
This test simulates a scenario where a user's password is hashed with an old hasher. It mocks the authentication and password update functions to verify that authentication succeeds and that the system is prepared to update the password hash if necessary.
|
||||
Args:
|
||||
self: The test case instance.
|
||||
@@ -248,9 +248,7 @@ class TestAuthenticationSecurity:
|
||||
"""
|
||||
password = "TestPassword123!"
|
||||
|
||||
old_hasher = session_password_hasher.PasswordHasher(
|
||||
PasswordHash([BcryptHasher()])
|
||||
)
|
||||
old_hasher = auth_password_hasher.PasswordHasher(PasswordHash([BcryptHasher()]))
|
||||
|
||||
# Create ORM-like object with password attribute
|
||||
mock_user_orm = MagicMock()
|
||||
@@ -262,7 +260,7 @@ class TestAuthenticationSecurity:
|
||||
with patch("session.utils.users_crud.edit_user_password") as _mock_edit:
|
||||
mock_auth.return_value = mock_user_orm
|
||||
|
||||
result = session_utils.authenticate_user(
|
||||
result = auth_utils.authenticate_user(
|
||||
"testuser", password, password_hasher, mock_db
|
||||
)
|
||||
|
||||
@@ -297,7 +295,7 @@ class TestAuthenticationSecurity:
|
||||
|
||||
for username in malicious_usernames:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
session_utils.authenticate_user(
|
||||
auth_utils.authenticate_user(
|
||||
username, "password", password_hasher, mock_db
|
||||
)
|
||||
assert exc_info.value.status_code == 401
|
||||
@@ -323,15 +321,11 @@ class TestAuthenticationSecurity:
|
||||
mock_auth.return_value = mock_user_orm
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
session_utils.authenticate_user(
|
||||
"testuser", "", password_hasher, mock_db
|
||||
)
|
||||
auth_utils.authenticate_user("testuser", "", password_hasher, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_empty_username_authentication(
|
||||
self, password_hasher, mock_db
|
||||
):
|
||||
def test_empty_username_authentication(self, password_hasher, mock_db):
|
||||
"""
|
||||
Test that authentication fails with an empty username.
|
||||
|
||||
@@ -345,7 +339,7 @@ class TestAuthenticationSecurity:
|
||||
mock_auth.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
session_utils.authenticate_user(
|
||||
auth_utils.authenticate_user(
|
||||
"", "RealPassword123!", password_hasher, mock_db
|
||||
)
|
||||
|
||||
@@ -379,7 +373,7 @@ class TestAuthenticationSecurity:
|
||||
mock_auth.return_value = mock_user_orm
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
session_utils.authenticate_user(
|
||||
auth_utils.authenticate_user(
|
||||
"testuser", " ", password_hasher, mock_db
|
||||
)
|
||||
|
||||
@@ -389,7 +383,7 @@ class TestAuthenticationSecurity:
|
||||
self, token_manager, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test that the `session_utils.create_tokens` function generates all required tokens and their expirations.
|
||||
Test that the `auth_utils.create_tokens` function generates all required tokens and their expirations.
|
||||
|
||||
This test verifies that:
|
||||
- A session ID is generated.
|
||||
@@ -412,7 +406,7 @@ class TestAuthenticationSecurity:
|
||||
refresh_token_exp,
|
||||
refresh_token,
|
||||
csrf_token,
|
||||
) = session_utils.create_tokens(sample_user_read, token_manager)
|
||||
) = auth_utils.create_tokens(sample_user_read, token_manager)
|
||||
|
||||
assert session_id is not None, "Session ID should be generated"
|
||||
assert access_token is not None, "Access token should be generated"
|
||||
@@ -429,14 +423,14 @@ class TestAuthenticationSecurity:
|
||||
self, token_manager, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test that the `session_utils.create_tokens` function uses the provided session ID when one is supplied.
|
||||
Test that the `auth_utils.create_tokens` function uses the provided session ID when one is supplied.
|
||||
|
||||
Args:
|
||||
token_manager: The token manager fixture or mock used to generate tokens.
|
||||
sample_user_read: A sample user object used for token creation.
|
||||
|
||||
Asserts:
|
||||
The returned session ID from `session_utils.create_tokens` matches the provided session ID.
|
||||
The returned session ID from `auth_utils.create_tokens` matches the provided session ID.
|
||||
"""
|
||||
provided_session_id = "custom-session-id-123"
|
||||
|
||||
@@ -447,7 +441,7 @@ class TestAuthenticationSecurity:
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = session_utils.create_tokens(
|
||||
) = auth_utils.create_tokens(
|
||||
sample_user_read, token_manager, session_id=provided_session_id
|
||||
)
|
||||
|
||||
@@ -457,9 +451,9 @@ class TestAuthenticationSecurity:
|
||||
self, token_manager, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test that the `session_utils.create_tokens` function generates unique session IDs for each invocation.
|
||||
Test that the `auth_utils.create_tokens` function generates unique session IDs for each invocation.
|
||||
|
||||
This test calls `session_utils.create_tokens` multiple times with the same user and token manager,
|
||||
This test calls `auth_utils.create_tokens` multiple times with the same user and token manager,
|
||||
collects the returned session IDs, and asserts that all session IDs are unique.
|
||||
|
||||
Args:
|
||||
@@ -472,7 +466,7 @@ class TestAuthenticationSecurity:
|
||||
"""
|
||||
session_ids = set()
|
||||
for _ in range(10):
|
||||
session_id, _, _, _, _, _ = session_utils.create_tokens(
|
||||
session_id, _, _, _, _, _ = auth_utils.create_tokens(
|
||||
sample_user_read, token_manager
|
||||
)
|
||||
session_ids.add(session_id)
|
||||
@@ -496,7 +490,7 @@ class TestAuthenticationSecurity:
|
||||
mock_request.headers["X-Client-Type"] = "web"
|
||||
|
||||
with patch("session.utils.create_session") as mock_create_session:
|
||||
result = session_utils.complete_login(
|
||||
result = auth_utils.complete_login(
|
||||
response,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
@@ -532,7 +526,7 @@ class TestAuthenticationSecurity:
|
||||
mock_request.headers["X-Client-Type"] = "mobile"
|
||||
|
||||
with patch("session.utils.create_session") as mock_create_session:
|
||||
result = session_utils.complete_login(
|
||||
result = auth_utils.complete_login(
|
||||
response,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
@@ -573,7 +567,7 @@ class TestAuthenticationSecurity:
|
||||
|
||||
with patch("session.utils.create_session"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
session_utils.complete_login(
|
||||
auth_utils.complete_login(
|
||||
response,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
@@ -900,7 +894,7 @@ class TestAuthenticationSecurity:
|
||||
refresh_token = "test-refresh-token"
|
||||
|
||||
with patch("session.utils.session_crud.create_session") as mock_create:
|
||||
with patch("session.utils.session_constants") as mock_constants:
|
||||
with patch("session.utils.auth_constants") as mock_constants:
|
||||
# Set the expiration days
|
||||
mock_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS = 30
|
||||
|
||||
@@ -917,9 +911,7 @@ class TestAuthenticationSecurity:
|
||||
session_obj = mock_create.call_args[0][0]
|
||||
|
||||
# Verify expiration is set and is in the future
|
||||
assert (
|
||||
session_obj.expires_at is not None
|
||||
), "Expiration should be set"
|
||||
assert session_obj.expires_at is not None, "Expiration should be set"
|
||||
assert isinstance(
|
||||
session_obj.expires_at, datetime
|
||||
), "Expiration should be a datetime"
|
||||
@@ -1040,8 +1032,8 @@ class TestAuthenticationSecurity:
|
||||
assert (
|
||||
updated_session.expires_at != old_expiration
|
||||
), "Expiration should be updated"
|
||||
assert (
|
||||
updated_session.expires_at > datetime.now(timezone.utc)
|
||||
assert updated_session.expires_at > datetime.now(
|
||||
timezone.utc
|
||||
), "New expiration should be in the future"
|
||||
assert isinstance(
|
||||
updated_session.expires_at, datetime
|
||||
@@ -1140,15 +1132,9 @@ class TestAuthenticationSecurity:
|
||||
updated_session = mock_edit.call_args[0][0]
|
||||
|
||||
# Verify device information is set (updated from mock_request)
|
||||
assert (
|
||||
updated_session.ip_address is not None
|
||||
), "IP address should be set"
|
||||
assert (
|
||||
updated_session.device_type is not None
|
||||
), "Device type should be set"
|
||||
assert updated_session.ip_address is not None, "IP address should be set"
|
||||
assert updated_session.device_type is not None, "Device type should be set"
|
||||
assert (
|
||||
updated_session.operating_system is not None
|
||||
), "Operating system should be set"
|
||||
assert updated_session.browser is not None, "Browser should be set"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user