mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Merge branch 'main' into feat/model-manager-queue-redesign
This commit is contained in:
17
Makefile
17
Makefile
@@ -16,20 +16,20 @@ help:
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "wheel Build the wheel for the current version"
|
||||
@echo "frontend-prettier Format the frontend using lint:prettier"
|
||||
@echo "wheel Build the wheel for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
|
||||
@echo "docs Serve the mkdocs site with live reload"
|
||||
|
||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||
ruff:
|
||||
ruff check . --fix
|
||||
ruff format .
|
||||
cd invokeai && uv tool run ruff@0.11.2 format
|
||||
|
||||
# Runs ruff, fixing all errors it can fix and formatting
|
||||
ruff-unsafe:
|
||||
ruff check . --fix --unsafe-fixes
|
||||
ruff format .
|
||||
ruff format
|
||||
|
||||
# Runs mypy, using the config in pyproject.toml
|
||||
mypy:
|
||||
@@ -64,6 +64,13 @@ frontend-dev:
|
||||
frontend-typegen:
|
||||
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
|
||||
frontend-lint:
|
||||
cd invokeai/frontend/web/src && \
|
||||
pnpm lint:tsc && \
|
||||
pnpm lint:dpdm && \
|
||||
pnpm lint:eslint --fix && \
|
||||
pnpm lint:prettier --write
|
||||
|
||||
# Tag the release
|
||||
wheel:
|
||||
cd scripts && ./build_wheel.sh
|
||||
@@ -79,4 +86,4 @@ openapi:
|
||||
# Serve the mkdocs site w/ live reload
|
||||
.PHONY: docs
|
||||
docs:
|
||||
mkdocs serve
|
||||
mkdocs serve
|
||||
|
||||
@@ -52,7 +52,7 @@ The Unified Canvas is a fully integrated canvas implementation with support for
|
||||
|
||||
### Workflows & Nodes
|
||||
|
||||
Invoke offers a fully featured workflow management solution, enabling users to combine the power of node-based workflows with the easy of a UI. This allows for customizable generation pipelines to be developed and shared by users looking to create specific workflows to support their production use-cases.
|
||||
Invoke offers a fully featured workflow management solution, enabling users to combine the power of node-based workflows with the ease of a UI. This allows for customizable generation pipelines to be developed and shared by users looking to create specific workflows to support their production use-cases.
|
||||
|
||||
### Board & Gallery Management
|
||||
|
||||
|
||||
169
USER_ISOLATION_IMPLEMENTATION.md
Normal file
169
USER_ISOLATION_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,169 @@
|
||||
# User Isolation Implementation Summary
|
||||
|
||||
This document describes the implementation of user isolation features in the InvokeAI session queue and processing system to address issues identified in the enhancement request.
|
||||
|
||||
## Issues Addressed
|
||||
|
||||
### 1. Cross-User Image/Preview Visibility
|
||||
**Problem:** When two users are logged in simultaneously and one initiates a generation, the generation preview shows up in both users' browsers and the generated image gets saved to both users' image boards.
|
||||
|
||||
**Solution:** Implemented socket-level event filtering based on user authentication:
|
||||
|
||||
#### Backend Changes (`invokeai/app/api/sockets.py`):
|
||||
- Added socket authentication middleware in `_handle_connect()` method
|
||||
- Extracts JWT token from socket auth data or HTTP headers
|
||||
- Verifies token using existing `verify_token()` function
|
||||
- Stores `user_id` and `is_admin` in socket session for later use
|
||||
- Modified `_handle_queue_event()` to filter events by user:
|
||||
- For `QueueItemEventBase` events, only emit to:
|
||||
- The user who owns the queue item (`user_id` matches)
|
||||
- Admin users (`is_admin` is True)
|
||||
- For general queue events, emit to all subscribers
|
||||
|
||||
#### Event System Changes (`invokeai/app/services/events/events_common.py`):
|
||||
- Added `user_id` field to `QueueItemEventBase` class
|
||||
- Updated all event builders to include `user_id` from queue items:
|
||||
- `InvocationStartedEvent.build()`
|
||||
- `InvocationProgressEvent.build()`
|
||||
- `InvocationCompleteEvent.build()`
|
||||
- `InvocationErrorEvent.build()`
|
||||
- `QueueItemStatusChangedEvent.build()`
|
||||
|
||||
### 2. Batch Field Values Privacy
|
||||
**Problem:** Users can see batch field values from generation processes launched by other users.
|
||||
|
||||
**Solution:** Implemented field value sanitization at the API level:
|
||||
|
||||
#### API Router Changes (`invokeai/app/api/routers/session_queue.py`):
|
||||
- Created `sanitize_queue_item_for_user()` helper function
|
||||
- Clears `field_values` for non-admin users viewing other users' items
|
||||
- Admins and item owners can see all field values
|
||||
- Updated endpoints to require authentication and sanitize responses:
|
||||
- `list_all_queue_items()` - Added `CurrentUser` dependency
|
||||
- `get_queue_items_by_item_ids()` - Added `CurrentUser` dependency
|
||||
- `get_queue_item()` - Added `CurrentUser` dependency
|
||||
|
||||
### 3. Queue Updates Across Browser Windows
|
||||
**Problem:** When the job queue tab is open in multiple browsers and a generation is begun in one browser window, the queue does not update in the other window.
|
||||
|
||||
**Status:** This issue is likely resolved by the socket authentication and event filtering changes. The existing socket subscription mechanism (`subscribe_queue` event) already supports multiple connections per user. Testing is required to confirm this works correctly with the new authentication flow.
|
||||
|
||||
### 4. User Information Display
|
||||
**Problem:** Queue table lacks user identification, making it difficult to know who launched which job.
|
||||
|
||||
**Solution:** Added user information to queue items and UI:
|
||||
|
||||
#### Database Layer (`invokeai/app/services/session_queue/session_queue_sqlite.py`):
|
||||
- Updated SQL queries to JOIN with `users` table
|
||||
- Modified methods to fetch user information:
|
||||
- `get_queue_item()` - Now selects `display_name` and `email` from users table
|
||||
- `dequeue()` - Includes user info
|
||||
- `get_next()` - Includes user info
|
||||
- `get_current()` - Includes user info
|
||||
- `list_all_queue_items()` - Includes user info
|
||||
|
||||
#### Data Model Changes (`invokeai/app/services/session_queue/session_queue_common.py`):
|
||||
- Added optional fields to `SessionQueueItem`:
|
||||
- `user_display_name: Optional[str]` - Display name from users table
|
||||
- `user_email: Optional[str]` - Email from users table
|
||||
- Note: `user_id` field already existed from Migration 25
|
||||
|
||||
#### Frontend UI Changes:
|
||||
- **Constants** (`constants.ts`): Added `user: '8rem'` column width
|
||||
- **Header** (`QueueListHeader.tsx`): Added "User" column header
|
||||
- **Item Component** (`QueueItemComponent.tsx`):
|
||||
- Added logic to display user information (display_name → email → user_id)
|
||||
- Added user column to queue item row
|
||||
- Added tooltip with full username on hover
|
||||
- Added "Hidden for privacy" message when field_values are null for non-owned items
|
||||
- **Localization** (`en.json`): Added translations:
|
||||
- `"user": "User"`
|
||||
- `"fieldValuesHidden": "Hidden for privacy"`
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Token Verification
|
||||
- Tokens are verified using the existing `verify_token()` function from `invokeai.app.services.auth.token_service`
|
||||
- Invalid or missing tokens default to "system" user with non-admin privileges
|
||||
- Socket connections without valid tokens are still accepted for backward compatibility but have limited access
|
||||
|
||||
### Data Privacy
|
||||
- Field values are only visible to:
|
||||
- The user who created the queue item
|
||||
- Admin users
|
||||
- Non-admin users viewing other users' queue items see "Hidden for privacy" instead of field values
|
||||
|
||||
### Admin Privileges
|
||||
- Admin users can see all queue events and field values across all users
|
||||
- Admin status is determined from the JWT token's `is_admin` field
|
||||
|
||||
## Migration Notes
|
||||
|
||||
No database migration is required. The changes leverage:
|
||||
- Existing `user_id` column in `session_queue` table (added in Migration 25)
|
||||
- Existing `users` table (added in Migration 25)
|
||||
- SQL LEFT JOINs to fetch user information (gracefully handles missing user records)
|
||||
|
||||
## Testing Requirements
|
||||
|
||||
### Backend Testing
|
||||
1. **Socket Authentication:**
|
||||
- Verify valid tokens are accepted and user context is stored
|
||||
- Verify invalid tokens default to system user
|
||||
- Verify expired tokens are rejected
|
||||
|
||||
2. **Event Filtering:**
|
||||
- User A should only receive events for their own queue items
|
||||
- Admin users should receive all events
|
||||
- Non-admin users should not receive events from other users
|
||||
|
||||
3. **Field Value Sanitization:**
|
||||
- Non-admin users should see null field_values for other users' items
|
||||
- Admins should see all field values
|
||||
- Users should see their own field values
|
||||
|
||||
### Frontend Testing
|
||||
1. **UI Display:**
|
||||
- User column should display in queue list
|
||||
- Display name should be shown when available
|
||||
- Email should be shown as fallback when display name is missing
|
||||
- User ID should be shown when both display name and email are missing
|
||||
- Tooltip should show full username on hover
|
||||
|
||||
2. **Field Values Display:**
|
||||
- "Hidden for privacy" message should appear when viewing other users' items
|
||||
- Own items should show field values normally
|
||||
|
||||
3. **Multi-Browser Testing:**
|
||||
- Open queue tab in two browsers with different users
|
||||
- Start generation in one browser
|
||||
- Verify other browser doesn't see the preview/progress
|
||||
- Verify admin user can see all generations
|
||||
|
||||
### Integration Testing
|
||||
1. Multi-user scenarios with simultaneous generations
|
||||
2. Queue updates across multiple browser windows
|
||||
3. Admin vs. non-admin privilege differentiation
|
||||
4. Socket reconnection handling
|
||||
|
||||
## Known Limitations
|
||||
|
||||
1. **TypeScript Types:**
|
||||
- The OpenAPI schema needs to be regenerated to include new fields
|
||||
- Run: `cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen`
|
||||
|
||||
2. **Backward Compatibility:**
|
||||
- System user ("system") entries will not have display name or email
|
||||
- Existing queue items from before Migration 25 will have user_id="system"
|
||||
|
||||
3. **Socket.IO Session Storage:**
|
||||
- Socket.IO's in-memory session storage may not persist across server restarts
|
||||
- Consider implementing persistent session storage if needed for production
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
1. Add user filtering to queue list (show only my items vs. all items)
|
||||
2. Add permission system for queue management operations (cancel, retry, delete)
|
||||
3. Implement queue item ownership transfer for administrative purposes
|
||||
4. Add audit logging for queue operations with user attribution
|
||||
5. Consider implementing user-specific queue limits or quotas
|
||||
BIN
docs/assets/multiuser/admin-add-user-1.png
Normal file
BIN
docs/assets/multiuser/admin-add-user-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 13 KiB |
BIN
docs/assets/multiuser/admin-add-user-2.png
Normal file
BIN
docs/assets/multiuser/admin-add-user-2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
BIN
docs/assets/multiuser/admin-add-user-3.png
Normal file
BIN
docs/assets/multiuser/admin-add-user-3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 18 KiB |
BIN
docs/assets/multiuser/admin-setup.png
Normal file
BIN
docs/assets/multiuser/admin-setup.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 40 KiB |
BIN
docs/assets/multiuser/user-login-1.png
Normal file
BIN
docs/assets/multiuser/user-login-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
@@ -18,7 +18,7 @@ If you’d like to add a Node, please see our [nodes contribution guide](../node
|
||||
|
||||
Helping support other users in [Discord](https://discord.gg/ZmtBAhwWhy) and on Github are valuable forms of contribution that we greatly appreciate.
|
||||
|
||||
We receive many issues and requests for help from users. We're limited in bandwidth relative to our the user base, so providing answers to questions or helping identify causes of issues is very helpful. By doing this, you enable us to spend time on the highest priority work.
|
||||
We receive many issues and requests for help from users. We're limited in bandwidth relative to our user base, so providing answers to questions or helping identify causes of issues is very helpful. By doing this, you enable us to spend time on the highest priority work.
|
||||
|
||||
## Documentation
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ Invoke runs on Windows 10+, macOS 14+ and Linux (Ubuntu 20.04+ is well-tested).
|
||||
|
||||
Hardware requirements vary significantly depending on model and image output size.
|
||||
|
||||
The requirements below are rough guidelines for best performance. GPUs with less VRAM typically still work, if a bit slower. Follow the [Low-VRAM mode guide](./features/low-vram.md) to optimize performance.
|
||||
The requirements below are rough guidelines for best performance. GPUs
|
||||
with less VRAM typically still work, if a bit slower. Follow the
|
||||
[Low-VRAM mode guide](../features/low-vram.md) to optimize performance.
|
||||
|
||||
- All Apple Silicon (M1, M2, etc) Macs work, but 16GB+ memory is recommended.
|
||||
- AMD GPUs are supported on Linux only. The VRAM requirements are the same as Nvidia GPUs.
|
||||
|
||||
876
docs/multiuser/admin_guide.md
Normal file
876
docs/multiuser/admin_guide.md
Normal file
@@ -0,0 +1,876 @@
|
||||
# InvokeAI Multi-User Administrator Guide
|
||||
|
||||
## Overview
|
||||
|
||||
This guide is for administrators managing a multi-user InvokeAI installation. It covers initial setup, user management, security best practices, and troubleshooting.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before enabling multi-user support, ensure you have:
|
||||
|
||||
- InvokeAI installed and running
|
||||
- Access to the server filesystem (for initial setup)
|
||||
- Understanding of your deployment environment
|
||||
- Backup of your existing data (recommended)
|
||||
|
||||
## Initial Setup
|
||||
|
||||
### Activating Multiuser Mode
|
||||
|
||||
To put InvokeAI into multiuser mode, you will need to add the option
|
||||
`multiuser: true` to its configuration file. This file is located at
|
||||
`INVOKEAI_ROOT/invokeai.yaml` With the InvokeAI backend halted, add
|
||||
the new configuration option to the end of the file with a text editor
|
||||
so that it looks like this:
|
||||
|
||||
```yaml
|
||||
# Internal metadata - do not edit:
|
||||
schema_version: 4.0.2
|
||||
|
||||
# Enable/disable multi-user mode
|
||||
multiuser: true
|
||||
```
|
||||
|
||||
Then restart the InvokeAI server backend from the command line or
|
||||
using the launcher.
|
||||
|
||||
!!! note "Reverting to single-user mode"
|
||||
If at any time you wish to revert to single-user mode, simply comment
|
||||
out the `multiuser` line, or change "true" to "false". Then
|
||||
restart the server. Because of the way that browsers cache pages,
|
||||
users with open InvokeAI sessions may need to force-refresh their
|
||||
browsers.
|
||||
|
||||
|
||||
### First Administrator Account
|
||||
|
||||
When InvokeAI starts for the first time in multi-user mode, you'll see the **Administrator Setup** dialog.
|
||||
|
||||
**Setup Steps:**
|
||||
|
||||
1. **Email Address**: Enter a valid email address (this becomes your username)
|
||||
|
||||
* Example: `admin@example.com` or `admin@localhost` for testing
|
||||
* Must be a valid email format
|
||||
* Cannot be changed later without database access
|
||||
|
||||
2. **Display Name**: Enter a friendly name
|
||||
|
||||
* Example: "System Administrator" or your real name
|
||||
* Can be changed later in your profile
|
||||
* Visible to other users in shared contexts
|
||||
|
||||
3. **Password**: Create a strong administrator password
|
||||
|
||||
* **Minimum requirements:**
|
||||
|
||||
* At least 8 characters long
|
||||
* Contains uppercase letters (A-Z)
|
||||
* Contains lowercase letters (a-z)
|
||||
* Contains numbers (0-9)
|
||||
|
||||
* **Recommended:**
|
||||
|
||||
* Use 12+ characters
|
||||
* Include special characters (!@#$%^&*)
|
||||
* Use a password manager to generate and store
|
||||
* Don't reuse passwords from other services
|
||||
|
||||
4. **Confirm Password**: Re-enter the password
|
||||
|
||||
5. Click **Create Administrator Account**
|
||||
|
||||
!!! warning "Important"
|
||||
Store these credentials securely! The
|
||||
first administrator account can reset
|
||||
the password to something new, but cannot
|
||||
retrieve a lost one.
|
||||
|
||||
### Configuration
|
||||
|
||||
InvokeAI can run in single-user or multi-user mode, controlled by the `multiuser` configuration option in `invokeai.yaml`:
|
||||
|
||||
```yaml
|
||||
# Enable/disable multi-user mode
|
||||
multiuser: true # Enable multi-user mode (requires authentication)
|
||||
# multiuser: false # Single-user mode (no authentication required)
|
||||
# If the multiuser option is absent, single-user mode is used
|
||||
|
||||
# Database configuration
|
||||
use_memory_db: false # Use persistent database
|
||||
db_path: databases/invokeai.db # Database location
|
||||
|
||||
# Session configuration (multi-user mode only)
|
||||
jwt_secret_key: "your-secret-key-here" # Auto-generated if not specified
|
||||
jwt_token_expiry_hours: 24 # Default session timeout
|
||||
jwt_remember_me_days: 7 # "Remember me" duration
|
||||
```
|
||||
|
||||
**Single-User Mode** (`multiuser: false` or option absent):
|
||||
- No authentication required
|
||||
- All functionality enabled by default
|
||||
- All boards and images visible in unified view
|
||||
- Ideal for personal use or trusted environments
|
||||
|
||||
**Multi-User Mode** (`multiuser: true`):
|
||||
- Authentication required for access
|
||||
- User isolation for boards, images, and workflows
|
||||
- Role-based permissions enforced
|
||||
- Ideal for shared servers or team environments
|
||||
|
||||
!!! warning "Mode Switching Behavior"
|
||||
**Switching to Single-User Mode:** If boards or images were created in multi-user mode, they will all be combined into a single unified view when switching to single-user mode.
|
||||
|
||||
**Switching to Multi-User Mode:** Legacy boards and images created under single-user mode will be owned by an internal user named "system." Only the Administrator will have access to these legacy assets. A utility to migrate these legacy assets to another user will be part of a future release.
|
||||
|
||||
### Migration from Single-User
|
||||
|
||||
When upgrading from a single-user installation or switching modes:
|
||||
|
||||
1. **Automatic Migration**: The database will automatically migrate to multi-user schema when multi-user mode is first enabled
|
||||
2. **Legacy Data Ownership**: Existing data (boards, images, workflows) created in single-user mode is assigned to an internal user named "system"
|
||||
3. **Administrator Access**: Only administrators will have access to legacy "system"-owned assets when in multi-user mode
|
||||
4. **No Data Loss**: All existing content is preserved
|
||||
|
||||
**Migration Process:**
|
||||
|
||||
```bash
|
||||
# Backup your database first
|
||||
cp databases/invokeai.db databases/invokeai.db.backup
|
||||
|
||||
# Enable multi-user mode in invokeai.yaml
|
||||
# multiuser: true
|
||||
|
||||
# Start InvokeAI (migration happens automatically)
|
||||
invokeai-web
|
||||
|
||||
# Complete the administrator setup dialog
|
||||
# Legacy data will be owned by "system" user
|
||||
```
|
||||
|
||||
!!! note "Legacy Asset Migration"
|
||||
A utility to migrate legacy "system"-owned assets to specific user accounts will be available in a future release. Until then, administrators can access and manage all legacy content.
|
||||
|
||||
## User Management
|
||||
|
||||
### Creating Users
|
||||
|
||||
**Via Web Interface (Coming Soon):**
|
||||
|
||||
!!! info "Web UI for User Management"
|
||||
A web-based user interface that allows administrators to manage users is coming in a future release. Until then, use the command-line scripts described below.
|
||||
|
||||
**Via Command Line Scripts:**
|
||||
|
||||
InvokeAI provides several command-line scripts in the `scripts/` directory for user management:
|
||||
|
||||
**useradd.py** - Add a new user:
|
||||
|
||||
```bash
|
||||
# Interactive mode (prompts for details)
|
||||
python scripts/useradd.py
|
||||
|
||||
# Create a regular user
|
||||
python scripts/useradd.py \
|
||||
--email user@example.com \
|
||||
--password TempPass123 \
|
||||
--name "User Name"
|
||||
|
||||
# Create an administrator
|
||||
python scripts/useradd.py \
|
||||
--email admin@example.com \
|
||||
--password AdminPass123 \
|
||||
--name "Admin Name" \
|
||||
--admin
|
||||
```
|
||||
|
||||
**userlist.py** - List all users:
|
||||
|
||||
```bash
|
||||
# List all users
|
||||
python scripts/userlist.py
|
||||
|
||||
# Show detailed information
|
||||
python scripts/userlist.py --verbose
|
||||
```
|
||||
|
||||
**usermod.py** - Modify an existing user:
|
||||
|
||||
```bash
|
||||
# Change display name
|
||||
python scripts/usermod.py --email user@example.com --name "New Name"
|
||||
|
||||
# Promote to administrator
|
||||
python scripts/usermod.py --email user@example.com --admin
|
||||
|
||||
# Demote from administrator
|
||||
python scripts/usermod.py --email user@example.com --no-admin
|
||||
|
||||
# Deactivate account
|
||||
python scripts/usermod.py --email user@example.com --deactivate
|
||||
|
||||
# Reactivate account
|
||||
python scripts/usermod.py --email user@example.com --activate
|
||||
|
||||
# Change password
|
||||
python scripts/usermod.py --email user@example.com --password NewPassword123
|
||||
```
|
||||
|
||||
**userdel.py** - Delete a user:
|
||||
|
||||
```bash
|
||||
# Delete a user (prompts for confirmation)
|
||||
python scripts/userdel.py --email user@example.com
|
||||
|
||||
# Delete without confirmation
|
||||
python scripts/userdel.py --email user@example.com --force
|
||||
```
|
||||
|
||||
!!! tip "Script Usage"
|
||||
Run any script with `--help` to see all available options:
|
||||
```bash
|
||||
python scripts/useradd.py --help
|
||||
```
|
||||
|
||||
!!! warning "Command Line Management"
|
||||
- These scripts directly modify the database
|
||||
- Always backup your database before making changes
|
||||
- Changes take effect immediately (users may need to log in again)
|
||||
- Deleting a user permanently removes all their content
|
||||
|
||||
### Editing Users
|
||||
|
||||
**Via Command Line:**
|
||||
|
||||
Use `usermod.py` as described above to modify user properties.
|
||||
|
||||
!!! warning "Last Administrator"
|
||||
You cannot remove admin privileges from the last remaining administrator account.
|
||||
|
||||
### Resetting User Passwords
|
||||
|
||||
**Via Web Interface (Coming Soon):**
|
||||
|
||||
Web-based password reset functionality for administrators is coming in a future release.
|
||||
|
||||
**Via Command Line:**
|
||||
|
||||
```bash
|
||||
# Reset a user's password
|
||||
python scripts/usermod.py --email user@example.com --password NewTempPassword123
|
||||
```
|
||||
|
||||
**Security Note:** Never send passwords via email or unsecured channels. Use secure communication methods.
|
||||
|
||||
### Deactivating Users
|
||||
|
||||
**Via Command Line:**
|
||||
|
||||
```bash
|
||||
# Deactivate a user account
|
||||
python scripts/usermod.py --email user@example.com --deactivate
|
||||
|
||||
# Reactivate a user account
|
||||
python scripts/usermod.py --email user@example.com --activate
|
||||
```
|
||||
|
||||
**Effects:**
|
||||
|
||||
- User cannot log in when deactivated
|
||||
- Existing sessions are immediately invalidated
|
||||
- User's data is preserved
|
||||
- Can be reactivated at any time
|
||||
|
||||
### Deleting Users
|
||||
|
||||
**Via Command Line:**
|
||||
|
||||
```bash
|
||||
# Delete a user (prompts for confirmation)
|
||||
python scripts/userdel.py --email user@example.com
|
||||
|
||||
# Delete without confirmation prompt
|
||||
python scripts/userdel.py --email user@example.com --force
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- ⚠️ This action is **permanent**
|
||||
- User's boards, images, and workflows are deleted
|
||||
- Cannot be undone
|
||||
- Consider deactivating instead of deleting
|
||||
|
||||
!!! warning "Data Loss"
|
||||
Deleting a user permanently removes all their content. Back up the database first if recovery might be needed.
|
||||
|
||||
### Viewing User Activity
|
||||
|
||||
**Queue Management:**
|
||||
|
||||
1. Navigate to **Admin** → **Queue Overview**
|
||||
2. View all users' active and pending generations
|
||||
3. Filter by user
|
||||
4. Cancel stuck or problematic tasks
|
||||
|
||||
**User Statistics:**
|
||||
|
||||
- Number of boards created
|
||||
- Number of images generated
|
||||
- Storage usage (if enabled)
|
||||
- Last login time
|
||||
|
||||
## Model Management
|
||||
|
||||
As an administrator, you have full access to model management.
|
||||
|
||||
### Adding Models
|
||||
|
||||
**Via Model Manager UI:**
|
||||
|
||||
1. Go to **Models** tab
|
||||
2. Click **Add Model**
|
||||
3. Choose installation method:
|
||||
- **From URL**: Provide HuggingFace repo or download URL
|
||||
- **From Local Path**: Scan local directories
|
||||
- **Import**: Import model from filesystem
|
||||
|
||||
**Supported Model Types:**
|
||||
|
||||
- Main models (Stable Diffusion, SDXL, FLUX)
|
||||
- LoRA models
|
||||
- ControlNet models
|
||||
- VAE models
|
||||
- Textual Inversions
|
||||
- IP-Adapters
|
||||
|
||||
### Configuring Models
|
||||
|
||||
**Model Settings:**
|
||||
|
||||
- Display name
|
||||
- Description
|
||||
- Default generation settings (CFG, steps, scheduler)
|
||||
- Variant selection (fp16/fp32)
|
||||
- Model thumbnail image
|
||||
|
||||
**Default Settings:**
|
||||
|
||||
Set default parameters that users will start with:
|
||||
|
||||
1. Select a model
|
||||
2. Go to **Default Settings** tab
|
||||
3. Configure:
|
||||
- CFG Scale
|
||||
- Steps
|
||||
- Scheduler
|
||||
- VAE selection
|
||||
4. Save settings
|
||||
|
||||
### Removing Models
|
||||
|
||||
1. Go to **Models** tab
|
||||
2. Select model(s) to remove
|
||||
3. Click **Delete**
|
||||
4. Confirm deletion
|
||||
|
||||
!!! warning "Impact"
|
||||
Removing a model affects all users who may be using it in workflows or saved settings.
|
||||
|
||||
## Shared Boards
|
||||
|
||||
Shared boards enable collaboration between users while maintaining control.
|
||||
|
||||
!!! note "Future Feature"
|
||||
Board sharing will be implemented in a future release.
|
||||
|
||||
### Creating Shared Boards
|
||||
|
||||
1. Log in as administrator
|
||||
2. Create a new board (or use existing board)
|
||||
3. Right-click the board → **Share Board**
|
||||
4. Add users and set permissions
|
||||
5. Click **Save Sharing Settings**
|
||||
|
||||
### Permission Levels
|
||||
|
||||
| Level | View | Add Images | Edit/Delete | Manage Sharing |
|
||||
|-------|------|------------|-------------|----------------|
|
||||
| **Read** | ✅ | ❌ | ❌ | ❌ |
|
||||
| **Write** | ✅ | ✅ | ✅ | ❌ |
|
||||
| **Admin** | ✅ | ✅ | ✅ | ✅ |
|
||||
|
||||
**Permission Recommendations:**
|
||||
|
||||
- **Read**: For viewers who should see but not modify content
|
||||
- **Write**: For active collaborators who add and organize images
|
||||
- **Admin**: For trusted users who help manage the shared board
|
||||
|
||||
### Managing Shared Boards
|
||||
|
||||
**Add Users to Shared Board:**
|
||||
|
||||
1. Right-click shared board → **Manage Sharing**
|
||||
2. Click **Add User**
|
||||
3. Select user from dropdown
|
||||
4. Choose permission level
|
||||
5. Save changes
|
||||
|
||||
**Remove Users from Shared Board:**
|
||||
|
||||
1. Right-click shared board → **Manage Sharing**
|
||||
2. Find user in list
|
||||
3. Click **Remove**
|
||||
4. Confirm removal
|
||||
|
||||
**Change User Permissions:**
|
||||
|
||||
1. Right-click shared board → **Manage Sharing**
|
||||
2. Find user in list
|
||||
3. Change permission dropdown
|
||||
4. Save changes
|
||||
|
||||
### Shared Board Best Practices
|
||||
|
||||
- Give meaningful names to shared boards
|
||||
- Document the board's purpose in the description
|
||||
- Assign minimum necessary permissions
|
||||
- Regularly audit access lists
|
||||
- Remove users who no longer need access
|
||||
|
||||
## Security
|
||||
|
||||
### Password Policies
|
||||
|
||||
**Enforced Requirements:**
|
||||
|
||||
- Minimum 8 characters
|
||||
- Must contain uppercase letters
|
||||
- Must contain lowercase letters
|
||||
- Must contain numbers
|
||||
|
||||
**Recommended Policies:**
|
||||
|
||||
- Require 12+ character passwords
|
||||
- Include special characters
|
||||
- Implement password rotation every 90 days
|
||||
- Prevent password reuse
|
||||
- Use multi-factor authentication (when available)
|
||||
|
||||
### Session Management
|
||||
|
||||
**Session Security and Token Management:**
|
||||
|
||||
This system uses stateless JWT tokens with HMAC signatures to
|
||||
identify users after they provide their initial credentials. The
|
||||
tokens will persist for 24 hours by default, or for 7 days if the user
|
||||
clicks the "Remember me" checkbox at login. Expired tokens are
|
||||
automatically rejected and the user will have to log in again.
|
||||
|
||||
At the client side, tokens are stored in browser localStorage. Logging
|
||||
out clears them. No server-side session storage is required.
|
||||
|
||||
The tokens include the user's ID, email, and admin status, along with
|
||||
an HMAC signature.
|
||||
|
||||
### Secret Key Management
|
||||
|
||||
**Important:** The JWT secret key must be kept confidential.
|
||||
|
||||
To generate tokens, each InvokeAI instance has a distinct secret JWT key that must be
|
||||
kept confidential. The key is stored in the `app_settings` table of
|
||||
the InvokeAI database with in a field value named `jwt_secret`.
|
||||
|
||||
The secret key is automatically generated during database creation or
|
||||
migration. If you wish to change the key, you may generate a
|
||||
replacement using either of these commands:
|
||||
|
||||
|
||||
```bash
|
||||
# Python
|
||||
python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
|
||||
# OpenSSL
|
||||
openssl rand -base64 32
|
||||
```
|
||||
|
||||
Then cut and paste the printed secret into this Sqlite3 command:
|
||||
|
||||
```bash
|
||||
sqlite3 INVOKE_ROOT/databases/invokeai.db 'update app_settings set value="THE_SECRET" where key="jwt_secret"'
|
||||
```
|
||||
|
||||
(replace INVOKE_ROOT with your InvokeAI root directory and THE_SECRET
|
||||
with the new secret).
|
||||
|
||||
After this, restart the server. All logged in users will be logged out
|
||||
and will need to provide their usernames and passwords again.
|
||||
|
||||
### Hosting a Shared InvokeAI Instance
|
||||
|
||||
The multiuser feature allows you to run an InvokeAI backend that can
|
||||
be accessed by your friends and family across your home network. It is
|
||||
also possible to host a backend that is accessible over the Internet.
|
||||
|
||||
By default, InvokeAI runs on `localhost`, IP address `127.0.0.1`,
|
||||
which is only accessible to browsers running on the same machine as
|
||||
the backend. To make the backend accessible to any machine on your
|
||||
home or work LAN, add the line `host: 0.0.0.0` to the InvokeAI
|
||||
configuration file, usually stored at `INVOKE_ROOT/invokeai.yaml`.
|
||||
|
||||
Here is a minimal example.
|
||||
|
||||
```yaml
|
||||
# Internal metadata - do not edit:
|
||||
schema_version: 4.0.2
|
||||
|
||||
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/configuration/:
|
||||
multiuser: true
|
||||
host: 0.0.0.0
|
||||
```
|
||||
|
||||
After relaunching the backend you will be able to reach the server
|
||||
from other machines on the LAN using the server machine's IP address
|
||||
or hostname and port 9090.
|
||||
|
||||
#### Connecting to the Internet
|
||||
|
||||
!!! warning "Use at your own risk"
|
||||
The InvokeAI team has done its best to make the software free of
|
||||
exploitable bugs, but the software has not undergone a rigorous security
|
||||
audit or intrusion testing. Use at your own risk
|
||||
|
||||
It is also possible to create a (semi) public server accessible from
|
||||
the Internet. The details of how to do this depend very much on your
|
||||
home or corporate router/firewall system and are beyond the scope of
|
||||
this document.
|
||||
|
||||
If you expose InvokeAI to the Internet, there are a number of
|
||||
precautions to take. Here is a brief list of recommended network
|
||||
security practices.
|
||||
|
||||
**HTTPS Configuration:**
|
||||
|
||||
For internet deployments, always use HTTPS:
|
||||
|
||||
```yaml
|
||||
# Use a reverse proxy like nginx or Traefik
|
||||
# Example nginx configuration:
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
server_name invoke.example.com;
|
||||
|
||||
ssl_certificate /path/to/cert.pem;
|
||||
ssl_certificate_key /path/to/key.pem;
|
||||
|
||||
location / {
|
||||
proxy_pass http://localhost:9090;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# WebSocket support
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Firewall Rules:**
|
||||
|
||||
It is best to restrict access to trusted networks and remote IP
|
||||
addresses, or use a VPN to connect to your home network. Rate limit
|
||||
connections to InvokeAI's authentication endpoint
|
||||
`http://your.host:9090/login`.
|
||||
|
||||
**Backup and Recovery:**
|
||||
|
||||
It is a good idea to periodically backup your InvokeAI database,
|
||||
images, and possibly models in the event of unauthorized use of a
|
||||
publicly-accessible server.
|
||||
|
||||
**Manual Backup:**
|
||||
|
||||
```bash
|
||||
# Stop InvokeAI
|
||||
# Copy database file
|
||||
cd INVOKE_ROOT
|
||||
cp databases/invokeai.db databases/invokeai.db.$(date +%Y%m%d)
|
||||
|
||||
# Or create compressed backup
|
||||
tar -czf invokeai_backup_$(date +%Y%m%d).tar.gz databases/
|
||||
```
|
||||
|
||||
**Automated Backup Script:**
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# backup_invokeai.sh
|
||||
|
||||
INVOKE_ROOT="/path/to/invoke_root"
|
||||
BACKUP_DIR="/path/to/backups"
|
||||
DB_PATH="$INVOKE_ROOT/databases/invokeai.db"
|
||||
DATE=$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
# Create backup directory
|
||||
mkdir -p "$BACKUP_DIR"
|
||||
|
||||
# Copy database
|
||||
cp "$DB_PATH" "$BACKUP_DIR/invokeai_$DATE.db"
|
||||
|
||||
# Keep only last 30 days
|
||||
find "$BACKUP_DIR" -name "invokeai_*.db" -mtime +30 -delete
|
||||
|
||||
echo "Backup completed: invokeai_$DATE.db"
|
||||
```
|
||||
|
||||
**Schedule with cron:**
|
||||
|
||||
```bash
|
||||
# Edit crontab
|
||||
crontab -e
|
||||
|
||||
# Add daily backup at 2 AM
|
||||
0 2 * * * /path/to/backup_invokeai.sh
|
||||
```
|
||||
|
||||
|
||||
|
||||
```bash
|
||||
# Stop InvokeAI
|
||||
# Replace current database with backup
|
||||
cd INVOKE_ROOT
|
||||
cp databases/invokeai.db databases/invokeai.db.old # Save current
|
||||
cp databases/invokeai_backup.db databases/invokeai.db
|
||||
|
||||
# Restart InvokeAI
|
||||
invokeai-web
|
||||
```
|
||||
|
||||
**Disaster Recover - Complete System Backup:**
|
||||
|
||||
Include these directories/files:
|
||||
|
||||
- `databases/` - All database files
|
||||
- `models/` - Installed models (if locally stored)
|
||||
- `outputs/` - Generated images
|
||||
- `invokeai.yaml` - Configuration file
|
||||
- Any custom scripts or modifications
|
||||
|
||||
**Recovery Process:**
|
||||
|
||||
1. Install InvokeAI on new system
|
||||
2. Restore configuration file
|
||||
3. Restore database directory
|
||||
4. Restore models and outputs
|
||||
5. Verify file permissions
|
||||
6. Start InvokeAI and test
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### User Cannot Login
|
||||
|
||||
**Symptom:** User reports unable to log in
|
||||
|
||||
**Diagnosis:**
|
||||
|
||||
1. Verify account exists and is active
|
||||
```bash
|
||||
sqlite3 databases/invokeai.db "SELECT * FROM users WHERE email = 'user@example.com';"
|
||||
```
|
||||
|
||||
2. Check password (have user try resetting)
|
||||
3. Verify account is active (`is_active = 1`)
|
||||
4. Check for account lockout (if implemented)
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Reset user password
|
||||
- Reactivate disabled account
|
||||
- Verify email address is correct
|
||||
- Check system logs for auth errors
|
||||
|
||||
### Database Locked Errors
|
||||
|
||||
**Symptom:** "Database is locked" errors
|
||||
|
||||
**Causes:**
|
||||
|
||||
- Concurrent write operations
|
||||
- Long-running transactions
|
||||
- Backup process accessing database
|
||||
- File system issues
|
||||
|
||||
**Solutions:**
|
||||
|
||||
```bash
|
||||
# Check for locks
|
||||
fuser databases/invokeai.db
|
||||
|
||||
# Increase timeout (in config)
|
||||
# Or switch to WAL mode:
|
||||
sqlite3 databases/invokeai.db "PRAGMA journal_mode=WAL;"
|
||||
```
|
||||
|
||||
### Forgotten Admin Password
|
||||
|
||||
**Recovery Process:**
|
||||
|
||||
1. Stop InvokeAI
|
||||
2. Direct database access:
|
||||
```bash
|
||||
sqlite3 databases/invokeai.db
|
||||
```
|
||||
|
||||
3. Reset admin password (requires password hash):
|
||||
```sql
|
||||
-- Generate hash first using Python:
|
||||
-- from passlib.context import CryptContext
|
||||
-- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
-- print(pwd_context.hash("NewPassword123"))
|
||||
|
||||
UPDATE users
|
||||
SET password_hash = '$2b$12$...'
|
||||
WHERE email = 'admin@example.com';
|
||||
```
|
||||
|
||||
4. Restart InvokeAI
|
||||
|
||||
**Alternative:** Remove `jwt_secret_key` from config to trigger setup wizard (will create new admin).
|
||||
|
||||
### Performance Issues
|
||||
|
||||
**Symptom:** Slow generation or UI
|
||||
|
||||
**Diagnosis:**
|
||||
|
||||
1. Check active generation count
|
||||
2. Review resource usage (CPU/GPU/RAM)
|
||||
3. Check database size and performance
|
||||
4. Review network latency
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Limit concurrent generations
|
||||
- Increase hardware resources
|
||||
- Optimize database (`VACUUM`, `ANALYZE`)
|
||||
- Add indexes for slow queries
|
||||
- Consider load balancing
|
||||
|
||||
### Migration Failures
|
||||
|
||||
**Symptom:** Database migration fails on upgrade
|
||||
|
||||
**Prevention:**
|
||||
|
||||
- Always backup before upgrading
|
||||
- Test migration on copy of database
|
||||
- Review migration logs
|
||||
|
||||
**Recovery:**
|
||||
|
||||
```bash
|
||||
# Restore backup
|
||||
cp databases/invokeai.db.backup databases/invokeai.db
|
||||
|
||||
# Try migration again with verbose logging
|
||||
invokeai-web --log-level DEBUG
|
||||
```
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### Complete Configuration Example for a Public Site
|
||||
|
||||
```yaml
|
||||
# invokeai.yaml - Multi-user configuration
|
||||
|
||||
# Internal metadata - do not edit:
|
||||
schema_version: 4.0.2
|
||||
|
||||
# Put user settings here
|
||||
multiuser: true
|
||||
|
||||
# Server
|
||||
host: "0.0.0.0"
|
||||
port: 9090
|
||||
|
||||
# Performance
|
||||
enable_partial_loading: true
|
||||
precision: float16
|
||||
pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"
|
||||
hashing_algorithm: blake3_multi
|
||||
```
|
||||
## Frequently Asked Questions
|
||||
|
||||
### How many users can InvokeAI support?
|
||||
|
||||
The backend will support dozens of concurrent users. However, because
|
||||
the image generation queue is single-threaded, image generation tasks
|
||||
are processed on a first-come, first-serve basis. This means that a
|
||||
user may have to wait for all the other users' image generation jobs
|
||||
to complete before their generation job starts to execute.
|
||||
|
||||
A future version of InvokeAI may support concurrent execution on
|
||||
systems with multiple GPUs/graphics cards.
|
||||
|
||||
### Can I integrate with existing authentication systems?
|
||||
|
||||
OAuth2/OpenID Connect support is planned for a future release. Currently, InvokeAI uses its own authentication system.
|
||||
|
||||
### How do I audit user actions?
|
||||
|
||||
Full audit logging is planned for a future release. Currently, you can:
|
||||
|
||||
- Monitor the generation queue
|
||||
- Review database changes
|
||||
- Check application logs
|
||||
|
||||
### Can users have different model access?
|
||||
|
||||
Not in the current release. All users can view and use all installed models. Per-user model access is a possible enhancement.
|
||||
|
||||
### How do I handle user data when they leave?
|
||||
|
||||
Best practice:
|
||||
|
||||
1. Deactivate the account first
|
||||
2. Transfer ownership of shared boards
|
||||
3. After transition period, delete the account
|
||||
4. Or keep the account deactivated for audit purposes
|
||||
|
||||
### What's the licensing impact of multi-user mode?
|
||||
|
||||
InvokeAI remains under its existing license. Multi-user mode does not change licensing terms.
|
||||
|
||||
## Getting Help
|
||||
|
||||
### Support Resources
|
||||
|
||||
- **Documentation**: [InvokeAI Docs](https://invoke-ai.github.io/InvokeAI/)
|
||||
- **Discord**: [Join Community](https://discord.gg/ZmtBAhwWhy)
|
||||
- **GitHub Issues**: [Report Problems](https://github.com/invoke-ai/InvokeAI/issues)
|
||||
- **User Guide**: [For Users](user_guide.md)
|
||||
- **API Guide**: [For Developers](api_guide.md)
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
When reporting administrator issues, include:
|
||||
|
||||
- InvokeAI version
|
||||
- Operating system and version
|
||||
- Database size and user count
|
||||
- Relevant log excerpts
|
||||
- Steps to reproduce
|
||||
- Expected vs actual behavior
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [User Guide](user_guide.md) - For end users
|
||||
- [API Guide](api_guide.md) - For API consumers
|
||||
- [Multiuser Specification](specification.md) - Technical details
|
||||
|
||||
---
|
||||
|
||||
**Need additional assistance?** Visit the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy) or file an issue on [GitHub](https://github.com/invoke-ai/InvokeAI/issues).
|
||||
1224
docs/multiuser/api_guide.md
Normal file
1224
docs/multiuser/api_guide.md
Normal file
File diff suppressed because it is too large
Load Diff
870
docs/multiuser/specification.md
Normal file
870
docs/multiuser/specification.md
Normal file
@@ -0,0 +1,870 @@
|
||||
# InvokeAI Multi-User Support - Detailed Specification
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document provides a comprehensive specification for adding multi-user support to InvokeAI. The feature will enable a single InvokeAI instance to support multiple isolated users, each with their own generation settings, image boards, and workflows, while maintaining administrative controls for model management and system configuration.
|
||||
|
||||
## 2. Overview
|
||||
|
||||
### 2.1 Goals
|
||||
- Enable multiple users to share a single InvokeAI instance
|
||||
- Provide user isolation for personal content (boards, images, workflows, settings)
|
||||
- Maintain centralized model management by administrators
|
||||
- Support shared boards for collaboration
|
||||
- Provide secure authentication and authorization
|
||||
- Minimize impact on existing single-user installations
|
||||
|
||||
### 2.2 Non-Goals
|
||||
- Real-time collaboration features (multiple users editing same workflow simultaneously)
|
||||
- Advanced team management features (in initial release)
|
||||
- Migration of existing multi-user enterprise edition data
|
||||
- Support for external identity providers (in initial release, can be added later)
|
||||
|
||||
## 3. User Roles and Permissions
|
||||
|
||||
### 3.1 Administrator Role
|
||||
**Capabilities:**
|
||||
|
||||
- Full access to all InvokeAI features
|
||||
- Model management (add, delete, configure models)
|
||||
- User management (create, edit, delete users)
|
||||
- View and manage all users' queue sessions
|
||||
- Access system configuration
|
||||
- Create and manage shared boards
|
||||
- Grant/revoke administrative privileges to other users
|
||||
|
||||
**Restrictions:**
|
||||
|
||||
- Cannot delete their own account if they are the last administrator
|
||||
- Cannot revoke their own admin privileges if they are the last administrator
|
||||
|
||||
### 3.2 Regular User Role
|
||||
**Capabilities:**
|
||||
|
||||
- Create, edit, and delete their own image boards
|
||||
- Upload and manage their own assets
|
||||
- Use all image generation tools (linear, canvas, upscale, workflow tabs)
|
||||
- Create, edit, save, and load workflows
|
||||
- Access public/shared workflows
|
||||
- View and manage their own queue sessions
|
||||
- Adjust personal UI preferences (theme, hotkeys, etc.)
|
||||
- Access shared boards (read/write based on permissions)
|
||||
- **View model configurations** (read-only access to model manager)
|
||||
- **View model details, default settings, and metadata**
|
||||
|
||||
**Restrictions:**
|
||||
|
||||
- Cannot add, delete, or edit models
|
||||
- **Can view but cannot modify model manager settings** (read-only access)
|
||||
- Cannot reidentify, convert, or update model paths
|
||||
- Cannot upload or change model thumbnail images
|
||||
- Cannot save changes to model default settings
|
||||
- Cannot perform bulk delete operations on models
|
||||
- Cannot view or modify other users' boards, images, or workflows
|
||||
- Cannot cancel or modify other users' queue sessions
|
||||
- Cannot access system configuration
|
||||
- Cannot manage users or permissions
|
||||
|
||||
### 3.3 Future Role Considerations
|
||||
- **Viewer Role**: Read-only access (future enhancement)
|
||||
- **Team/Group-based Permissions**: Organizational hierarchy (future enhancement)
|
||||
|
||||
## 4. Authentication System
|
||||
|
||||
### 4.1 Authentication Method
|
||||
- **Primary Method**: Username and password authentication with secure password hashing
|
||||
- **Password Hashing**: Use bcrypt or Argon2 for password storage
|
||||
- **Session Management**: JWT tokens or secure session cookies
|
||||
- **Token Expiration**: Configurable session timeout (default: 7 days for "remember me", 24 hours otherwise)
|
||||
|
||||
### 4.2 Initial Administrator Setup
|
||||
**First-time Launch Flow:**
|
||||
|
||||
1. Application detects no administrator account exists
|
||||
2. Displays mandatory setup dialog (cannot be skipped)
|
||||
3. Prompts for:
|
||||
- Administrator username (email format recommended)
|
||||
- Administrator display name
|
||||
- Strong password (minimum requirements enforced)
|
||||
- Password confirmation
|
||||
4. Stores hashed credentials in configuration
|
||||
5. Creates administrator account in database
|
||||
6. Proceeds to normal login screen
|
||||
|
||||
**Reset Capability:**
|
||||
|
||||
- Administrators can be reset by manually editing the config file
|
||||
- Requires access to server filesystem (intentional security measure)
|
||||
- Database maintains user records; config file contains root admin credentials
|
||||
|
||||
### 4.3 Password Requirements
|
||||
- Minimum 8 characters
|
||||
- At least one uppercase letter
|
||||
- At least one lowercase letter
|
||||
- At least one number
|
||||
- At least one special character (optional but recommended)
|
||||
- Not in common password list
|
||||
|
||||
### 4.4 Login Flow
|
||||
|
||||
1. User navigates to InvokeAI URL
|
||||
2. If not authenticated, redirect to login page
|
||||
3. User enters username/email and password
|
||||
4. Optional "Remember me" checkbox for extended session
|
||||
5. Backend validates credentials
|
||||
6. On success: Generate session token, redirect to application
|
||||
7. On failure: Display error, allow retry with rate limiting (prevent brute force)
|
||||
|
||||
### 4.5 Logout Flow
|
||||
- User clicks logout button
|
||||
- Frontend clears session token
|
||||
- Backend invalidates session (if using server-side sessions)
|
||||
- Redirect to login page
|
||||
|
||||
### 4.6 Future Authentication Enhancements
|
||||
- OAuth2/OpenID Connect support
|
||||
- Two-factor authentication (2FA)
|
||||
- SSO integration
|
||||
- API key authentication for programmatic access
|
||||
|
||||
## 5. User Management
|
||||
|
||||
### 5.1 User Creation (Administrator)
|
||||
**Flow:**
|
||||
|
||||
1. Administrator navigates to user management interface
|
||||
2. Clicks "Add User" button
|
||||
3. Enters user information:
|
||||
- Email address (required, used as username)
|
||||
- Display name (optional, defaults to email)
|
||||
- Role (User or Administrator)
|
||||
- Initial password or "Send invitation email"
|
||||
4. System validates email uniqueness
|
||||
5. System creates user account
|
||||
6. If invitation mode:
|
||||
- Generate one-time secure token
|
||||
- Send email with setup link
|
||||
- Link expires after 7 days
|
||||
7. If direct password mode:
|
||||
- Administrator provides initial password
|
||||
- User must change on first login
|
||||
|
||||
**Invitation Email Flow:**
|
||||
|
||||
1. User receives email with unique link
|
||||
2. Link contains secure token
|
||||
3. User clicks link, redirected to setup page
|
||||
4. User enters desired password
|
||||
5. Token validated and consumed (single-use)
|
||||
6. Account activated
|
||||
7. User redirected to login page
|
||||
|
||||
### 5.2 User Profile Management
|
||||
**User Self-Service:**
|
||||
|
||||
- Update display name
|
||||
- Change password (requires current password)
|
||||
- Update email address (requires verification)
|
||||
- Manage UI preferences
|
||||
- View account creation date and last login
|
||||
|
||||
**Administrator Actions:**
|
||||
|
||||
- Edit user information (name, email)
|
||||
- Reset user password (generates reset link)
|
||||
- Toggle administrator privileges
|
||||
- Assign to groups (future feature)
|
||||
- Suspend/unsuspend account
|
||||
- Delete account (with data retention options)
|
||||
|
||||
### 5.3 Password Reset Flow
|
||||
**User-Initiated (Future Enhancement):**
|
||||
|
||||
1. User clicks "Forgot Password" on login page
|
||||
2. Enters email address
|
||||
3. System sends password reset link (if email exists)
|
||||
4. User clicks link, enters new password
|
||||
5. Password updated, user can login
|
||||
|
||||
**Administrator-Initiated:**
|
||||
|
||||
1. Administrator selects user
|
||||
2. Clicks "Send Password Reset"
|
||||
3. System generates reset token and link
|
||||
4. Email sent to user
|
||||
5. User follows same flow as user-initiated reset
|
||||
|
||||
## 6. Data Model and Database Schema
|
||||
|
||||
### 6.1 New Tables
|
||||
|
||||
#### 6.1.1 users
|
||||
```sql
|
||||
CREATE TABLE users (
|
||||
user_id TEXT NOT NULL PRIMARY KEY,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
display_name TEXT,
|
||||
password_hash TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
last_login_at DATETIME
|
||||
);
|
||||
CREATE INDEX idx_users_email ON users(email);
|
||||
CREATE INDEX idx_users_is_admin ON users(is_admin);
|
||||
CREATE INDEX idx_users_is_active ON users(is_active);
|
||||
```
|
||||
|
||||
#### 6.1.2 user_sessions
|
||||
```sql
|
||||
CREATE TABLE user_sessions (
|
||||
session_id TEXT NOT NULL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
last_activity_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
user_agent TEXT,
|
||||
ip_address TEXT,
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX idx_user_sessions_user_id ON user_sessions(user_id);
|
||||
CREATE INDEX idx_user_sessions_expires_at ON user_sessions(expires_at);
|
||||
CREATE INDEX idx_user_sessions_token_hash ON user_sessions(token_hash);
|
||||
```
|
||||
|
||||
#### 6.1.3 user_invitations
|
||||
```sql
|
||||
CREATE TABLE user_invitations (
|
||||
invitation_id TEXT NOT NULL PRIMARY KEY,
|
||||
email TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
invited_by_user_id TEXT NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
used_at DATETIME,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
FOREIGN KEY (invited_by_user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX idx_user_invitations_email ON user_invitations(email);
|
||||
CREATE INDEX idx_user_invitations_token_hash ON user_invitations(token_hash);
|
||||
CREATE INDEX idx_user_invitations_expires_at ON user_invitations(expires_at);
|
||||
```
|
||||
|
||||
#### 6.1.4 shared_boards
|
||||
```sql
|
||||
CREATE TABLE shared_boards (
|
||||
board_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
permission TEXT NOT NULL CHECK(permission IN ('read', 'write', 'admin')),
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
PRIMARY KEY (board_id, user_id),
|
||||
FOREIGN KEY (board_id) REFERENCES boards(board_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX idx_shared_boards_user_id ON shared_boards(user_id);
|
||||
CREATE INDEX idx_shared_boards_board_id ON shared_boards(board_id);
|
||||
```
|
||||
|
||||
### 6.2 Modified Tables
|
||||
|
||||
#### 6.2.1 boards
|
||||
```sql
|
||||
-- Add columns:
|
||||
ALTER TABLE boards ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
ALTER TABLE boards ADD COLUMN is_shared BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
ALTER TABLE boards ADD COLUMN created_by_user_id TEXT;
|
||||
|
||||
-- Add foreign key (requires recreation in SQLite):
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
FOREIGN KEY (created_by_user_id) REFERENCES users(user_id) ON DELETE SET NULL
|
||||
|
||||
-- Add indices:
|
||||
CREATE INDEX idx_boards_user_id ON boards(user_id);
|
||||
CREATE INDEX idx_boards_is_shared ON boards(is_shared);
|
||||
```
|
||||
|
||||
#### 6.2.2 images
|
||||
```sql
|
||||
-- Add column:
|
||||
ALTER TABLE images ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
|
||||
-- Add foreign key:
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
|
||||
-- Add index:
|
||||
CREATE INDEX idx_images_user_id ON images(user_id);
|
||||
```
|
||||
|
||||
#### 6.2.3 workflows
|
||||
```sql
|
||||
-- Add columns:
|
||||
ALTER TABLE workflows ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
ALTER TABLE workflows ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Add foreign key:
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
|
||||
-- Add indices:
|
||||
CREATE INDEX idx_workflows_user_id ON workflows(user_id);
|
||||
CREATE INDEX idx_workflows_is_public ON workflows(is_public);
|
||||
```
|
||||
|
||||
#### 6.2.4 session_queue
|
||||
```sql
|
||||
-- Add column:
|
||||
ALTER TABLE session_queue ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
|
||||
-- Add foreign key:
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
|
||||
-- Add index:
|
||||
CREATE INDEX idx_session_queue_user_id ON session_queue(user_id);
|
||||
```
|
||||
|
||||
#### 6.2.5 style_presets
|
||||
```sql
|
||||
-- Add columns:
|
||||
ALTER TABLE style_presets ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
ALTER TABLE style_presets ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Add foreign key:
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
|
||||
-- Add indices:
|
||||
CREATE INDEX idx_style_presets_user_id ON style_presets(user_id);
|
||||
CREATE INDEX idx_style_presets_is_public ON style_presets(is_public);
|
||||
```
|
||||
|
||||
### 6.3 Migration Strategy
|
||||
|
||||
1. Create new user tables (users, user_sessions, user_invitations, shared_boards)
|
||||
2. Create default 'system' user for backward compatibility
|
||||
3. Update existing data to reference 'system' user
|
||||
4. Add foreign key constraints
|
||||
5. Version as database migration (e.g., migration_25.py)
|
||||
|
||||
### 6.4 Migration for Existing Installations
|
||||
- Single-user installations: Prompt to create admin account on first launch after update
|
||||
- Existing data migration: Administrator can specify an arbitrary user account to hold legacy data (can be the admin account or a separate user)
|
||||
- System provides UI during migration to choose destination user for existing data
|
||||
|
||||
## 7. API Endpoints
|
||||
|
||||
### 7.1 Authentication Endpoints
|
||||
|
||||
#### POST /api/v1/auth/setup
|
||||
- Initialize first administrator account
|
||||
- Only works if no admin exists
|
||||
- Body: `{ email, display_name, password }`
|
||||
- Response: `{ success, user }`
|
||||
|
||||
#### POST /api/v1/auth/login
|
||||
- Authenticate user
|
||||
- Body: `{ email, password, remember_me? }`
|
||||
- Response: `{ token, user, expires_at }`
|
||||
|
||||
#### POST /api/v1/auth/logout
|
||||
- Invalidate current session
|
||||
- Headers: `Authorization: Bearer <token>`
|
||||
- Response: `{ success }`
|
||||
|
||||
#### GET /api/v1/auth/me
|
||||
- Get current user information
|
||||
- Headers: `Authorization: Bearer <token>`
|
||||
- Response: `{ user }`
|
||||
|
||||
#### POST /api/v1/auth/change-password
|
||||
- Change current user's password
|
||||
- Body: `{ current_password, new_password }`
|
||||
- Headers: `Authorization: Bearer <token>`
|
||||
- Response: `{ success }`
|
||||
|
||||
### 7.2 User Management Endpoints (Admin Only)
|
||||
|
||||
#### GET /api/v1/users
|
||||
- List all users (paginated)
|
||||
- Query params: `offset`, `limit`, `search`, `role_filter`
|
||||
- Response: `{ users[], total, offset, limit }`
|
||||
|
||||
#### POST /api/v1/users
|
||||
- Create new user
|
||||
- Body: `{ email, display_name, is_admin, send_invitation?, initial_password? }`
|
||||
- Response: `{ user, invitation_link? }`
|
||||
|
||||
#### GET /api/v1/users/{user_id}
|
||||
- Get user details
|
||||
- Response: `{ user }`
|
||||
|
||||
#### PATCH /api/v1/users/{user_id}
|
||||
- Update user
|
||||
- Body: `{ display_name?, is_admin?, is_active? }`
|
||||
- Response: `{ user }`
|
||||
|
||||
#### DELETE /api/v1/users/{user_id}
|
||||
- Delete user
|
||||
- Query params: `delete_data` (true/false)
|
||||
- Response: `{ success }`
|
||||
|
||||
#### POST /api/v1/users/{user_id}/reset-password
|
||||
- Send password reset email
|
||||
- Response: `{ success, reset_link }`
|
||||
|
||||
### 7.3 Shared Boards Endpoints
|
||||
|
||||
#### POST /api/v1/boards/{board_id}/share
|
||||
- Share board with users
|
||||
- Body: `{ user_ids[], permission: 'read' | 'write' | 'admin' }`
|
||||
- Response: `{ success, shared_with[] }`
|
||||
|
||||
#### GET /api/v1/boards/{board_id}/shares
|
||||
- Get board sharing information
|
||||
- Response: `{ shares[] }`
|
||||
|
||||
#### DELETE /api/v1/boards/{board_id}/share/{user_id}
|
||||
- Remove board sharing
|
||||
- Response: `{ success }`
|
||||
|
||||
### 7.4 Modified Endpoints
|
||||
|
||||
All existing endpoints will be modified to:
|
||||
|
||||
1. Require authentication (except setup/login)
|
||||
2. Filter data by current user (unless admin viewing all)
|
||||
3. Enforce permissions (e.g., model management requires admin)
|
||||
4. Include user context in operations
|
||||
|
||||
Example modifications:
|
||||
- `GET /api/v1/boards` → Returns only user's boards + shared boards
|
||||
- `POST /api/v1/session/queue` → Associates queue item with current user
|
||||
- `GET /api/v1/queue` → Returns all items for admin, only user's items for regular users
|
||||
|
||||
## 8. Frontend Changes
|
||||
|
||||
### 8.1 New Components
|
||||
|
||||
#### LoginPage
|
||||
- Email/password form
|
||||
- "Remember me" checkbox
|
||||
- Login button
|
||||
- Forgot password link (future)
|
||||
- Branding and welcome message
|
||||
|
||||
#### AdministratorSetup
|
||||
- Modal dialog (cannot be dismissed)
|
||||
- Administrator account creation form
|
||||
- Password strength indicator
|
||||
- Terms/welcome message
|
||||
|
||||
#### UserManagementPage (Admin only)
|
||||
- User list table
|
||||
- Add user button
|
||||
- User actions (edit, delete, reset password)
|
||||
- Search and filter
|
||||
- Role toggle
|
||||
|
||||
#### UserProfilePage
|
||||
- Display user information
|
||||
- Change password form
|
||||
- UI preferences
|
||||
- Account details
|
||||
|
||||
#### BoardSharingDialog
|
||||
- User picker/search
|
||||
- Permission selector
|
||||
- Share button
|
||||
- Current shares list
|
||||
|
||||
### 8.2 Modified Components
|
||||
|
||||
#### App Root
|
||||
- Add authentication check
|
||||
- Redirect to login if not authenticated
|
||||
- Handle session expiration
|
||||
- Add global error boundary for auth errors
|
||||
|
||||
#### Navigation/Header
|
||||
- Add user menu with logout
|
||||
- Display current user name
|
||||
- Admin indicator badge
|
||||
|
||||
#### ModelManagerTab
|
||||
- Hide/disable for non-admin users
|
||||
- Show "Admin only" message
|
||||
|
||||
#### QueuePanel
|
||||
- Filter by current user (for non-admin)
|
||||
- Show all with user indicators (for admin)
|
||||
- Disable actions on other users' items (for non-admin)
|
||||
|
||||
#### BoardsPanel
|
||||
- Show personal boards section
|
||||
- Show shared boards section
|
||||
- Add sharing controls to board actions
|
||||
|
||||
### 8.3 State Management
|
||||
|
||||
New Redux slices/zustand stores:
|
||||
- `authSlice`: Current user, authentication status, token
|
||||
- `usersSlice`: User list for admin interface
|
||||
- `sharingSlice`: Board sharing state
|
||||
|
||||
Updated slices:
|
||||
- `boardsSlice`: Include shared boards, ownership info
|
||||
- `queueSlice`: Include user filtering
|
||||
- `workflowsSlice`: Include public/private status
|
||||
|
||||
## 9. Configuration
|
||||
|
||||
### 9.1 New Config Options
|
||||
|
||||
Add to `InvokeAIAppConfig`:
|
||||
|
||||
```python
|
||||
# Authentication
|
||||
auth_enabled: bool = True # Enable/disable multi-user auth
|
||||
session_expiry_hours: int = 24 # Default session expiration
|
||||
session_expiry_hours_remember: int = 168 # "Remember me" expiration (7 days)
|
||||
password_min_length: int = 8 # Minimum password length
|
||||
require_strong_passwords: bool = True # Enforce password complexity
|
||||
|
||||
# Session tracking
|
||||
enable_server_side_sessions: bool = False # Optional server-side session tracking
|
||||
|
||||
# Audit logging
|
||||
audit_log_auth_events: bool = True # Log authentication events
|
||||
audit_log_admin_actions: bool = True # Log administrative actions
|
||||
|
||||
# Email (optional - for invitations and password reset)
|
||||
email_enabled: bool = False
|
||||
smtp_host: str = ""
|
||||
smtp_port: int = 587
|
||||
smtp_username: str = ""
|
||||
smtp_password: str = ""
|
||||
smtp_from_address: str = ""
|
||||
smtp_from_name: str = "InvokeAI"
|
||||
|
||||
# Initial admin (stored as hash)
|
||||
admin_email: Optional[str] = None
|
||||
admin_password_hash: Optional[str] = None
|
||||
```
|
||||
|
||||
### 9.2 Backward Compatibility
|
||||
|
||||
- If `auth_enabled = False`, system runs in legacy single-user mode
|
||||
- All data belongs to implicit "system" user
|
||||
- No authentication required
|
||||
- Smooth upgrade path for existing installations
|
||||
|
||||
## 10. Security Considerations
|
||||
|
||||
### 10.1 Password Security
|
||||
- Never store passwords in plain text
|
||||
- Use bcrypt or Argon2id for password hashing
|
||||
- Implement proper salt generation
|
||||
- Enforce password complexity requirements
|
||||
- Implement rate limiting on login attempts
|
||||
- Consider password breach checking (Have I Been Pwned API)
|
||||
|
||||
### 10.2 Session Security
|
||||
- Use cryptographically secure random tokens
|
||||
- Implement token rotation
|
||||
- Set appropriate cookie flags (HttpOnly, Secure, SameSite)
|
||||
- Implement session timeout and renewal
|
||||
- Invalidate sessions on logout
|
||||
- Clean up expired sessions periodically
|
||||
|
||||
### 10.3 Authorization
|
||||
- Always verify user identity from session token (never trust client)
|
||||
- Check permissions on every API call
|
||||
- Implement principle of least privilege
|
||||
- Validate user ownership of resources before operations
|
||||
- Implement proper error messages (avoid information leakage)
|
||||
|
||||
### 10.4 Data Isolation
|
||||
- Strict separation of user data in database queries
|
||||
- Prevent SQL injection via parameterized queries
|
||||
- Validate all user inputs
|
||||
- Implement proper access control checks
|
||||
- Audit trail for sensitive operations
|
||||
|
||||
### 10.5 API Security
|
||||
- Implement rate limiting on sensitive endpoints
|
||||
- Use HTTPS in production (enforce via config)
|
||||
- Implement CSRF protection
|
||||
- Validate and sanitize all inputs
|
||||
- Implement proper CORS configuration
|
||||
- Add security headers (CSP, X-Frame-Options, etc.)
|
||||
|
||||
### 10.6 Deployment Security
|
||||
- Document secure deployment practices
|
||||
- Recommend reverse proxy configuration (nginx, Apache)
|
||||
- Provide example configurations for HTTPS
|
||||
- Document firewall requirements
|
||||
- Recommend network isolation strategies
|
||||
|
||||
## 11. Email Integration (Optional)
|
||||
|
||||
**Note**: Email/SMTP configuration is optional. Many administrators will not have ready access to an outgoing SMTP server. When email is not configured, the system provides fallback mechanisms by displaying setup links directly in the admin UI.
|
||||
|
||||
### 11.1 Email Templates
|
||||
|
||||
#### User Invitation
|
||||
```
|
||||
Subject: You've been invited to InvokeAI
|
||||
|
||||
Hello,
|
||||
|
||||
You've been invited to join InvokeAI by [Administrator Name].
|
||||
|
||||
Click the link below to set up your account:
|
||||
[Setup Link]
|
||||
|
||||
This link expires in 7 days.
|
||||
|
||||
---
|
||||
InvokeAI
|
||||
```
|
||||
|
||||
#### Password Reset
|
||||
```
|
||||
Subject: Reset your InvokeAI password
|
||||
|
||||
Hello [User Name],
|
||||
|
||||
A password reset was requested for your account.
|
||||
|
||||
Click the link below to reset your password:
|
||||
[Reset Link]
|
||||
|
||||
This link expires in 24 hours.
|
||||
|
||||
If you didn't request this, please ignore this email.
|
||||
|
||||
---
|
||||
InvokeAI
|
||||
```
|
||||
|
||||
### 11.2 Email Service
|
||||
- Support SMTP configuration
|
||||
- Use secure connection (TLS)
|
||||
- Handle email failures gracefully
|
||||
- Implement email queue for reliability
|
||||
- Log email activities (without sensitive data)
|
||||
- Provide fallback for no-email deployments (show links in admin UI)
|
||||
|
||||
## 12. Testing Requirements
|
||||
|
||||
### 12.1 Unit Tests
|
||||
- Authentication service (password hashing, validation)
|
||||
- Authorization checks
|
||||
- Token generation and validation
|
||||
- User management operations
|
||||
- Shared board permissions
|
||||
- Data isolation queries
|
||||
|
||||
### 12.2 Integration Tests
|
||||
- Complete authentication flows
|
||||
- User creation and invitation
|
||||
- Password reset flow
|
||||
- Multi-user data isolation
|
||||
- Shared board access
|
||||
- Session management
|
||||
- Admin operations
|
||||
|
||||
### 12.3 Security Tests
|
||||
- SQL injection prevention
|
||||
- XSS prevention
|
||||
- CSRF protection
|
||||
- Session hijacking prevention
|
||||
- Brute force protection
|
||||
- Authorization bypass attempts
|
||||
|
||||
### 12.4 Performance Tests
|
||||
- Authentication overhead
|
||||
- Query performance with user filters
|
||||
- Concurrent user sessions
|
||||
- Database scalability with many users
|
||||
|
||||
## 13. Documentation Requirements
|
||||
|
||||
### 13.1 User Documentation
|
||||
- Getting started with multi-user InvokeAI
|
||||
- Login and account management
|
||||
- Using shared boards
|
||||
- Understanding permissions
|
||||
- Troubleshooting authentication issues
|
||||
|
||||
### 13.2 Administrator Documentation
|
||||
- Setting up multi-user InvokeAI
|
||||
- User management guide
|
||||
- Creating and managing shared boards
|
||||
- Email configuration
|
||||
- Security best practices
|
||||
- Backup and restore with user data
|
||||
|
||||
### 13.3 Developer Documentation
|
||||
- Authentication architecture
|
||||
- API authentication requirements
|
||||
- Adding new multi-user features
|
||||
- Database schema changes
|
||||
- Testing multi-user features
|
||||
|
||||
### 13.4 Migration Documentation
|
||||
- Upgrading from single-user to multi-user
|
||||
- Data migration strategies
|
||||
- Rollback procedures
|
||||
- Common issues and solutions
|
||||
|
||||
## 14. Future Enhancements
|
||||
|
||||
### 14.1 Phase 2 Features
|
||||
- **OAuth2/OpenID Connect integration** (deferred from initial release to keep scope manageable)
|
||||
- Two-factor authentication
|
||||
- API keys for programmatic access
|
||||
- Enhanced team/group management
|
||||
- Advanced permission system (roles and capabilities)
|
||||
|
||||
### 14.2 Phase 3 Features
|
||||
- SSO integration (SAML, LDAP)
|
||||
- User quotas and limits
|
||||
- Resource usage tracking
|
||||
- Advanced collaboration features
|
||||
- Workflow template library with permissions
|
||||
- Model access controls per user/group
|
||||
|
||||
## 15. Success Metrics
|
||||
|
||||
### 15.1 Functionality Metrics
|
||||
- Successful user authentication rate
|
||||
- Zero unauthorized data access incidents
|
||||
- All tests passing (unit, integration, security)
|
||||
- API response time within acceptable limits
|
||||
|
||||
### 15.2 Usability Metrics
|
||||
- User setup completion time < 2 minutes
|
||||
- Login time < 2 seconds
|
||||
- Clear error messages for all auth failures
|
||||
- Positive user feedback on multi-user features
|
||||
|
||||
### 15.3 Security Metrics
|
||||
- No critical security vulnerabilities identified
|
||||
- CodeQL scan passes
|
||||
- Penetration testing completed
|
||||
- Security best practices followed
|
||||
|
||||
## 16. Risks and Mitigations
|
||||
|
||||
### 16.1 Technical Risks
|
||||
| Risk | Impact | Probability | Mitigation |
|
||||
|------|--------|-------------|------------|
|
||||
| Performance degradation with user filtering | Medium | Low | Index optimization, query caching |
|
||||
| Database migration failures | High | Low | Thorough testing, rollback procedures |
|
||||
| Session management complexity | Medium | Medium | Use proven libraries (PyJWT), extensive testing |
|
||||
| Auth bypass vulnerabilities | High | Low | Security review, penetration testing |
|
||||
|
||||
### 16.2 UX Risks
|
||||
| Risk | Impact | Probability | Mitigation |
|
||||
|------|--------|-------------|------------|
|
||||
| Confusion in migration for existing users | Medium | High | Clear documentation, migration wizard |
|
||||
| Friction from additional login step | Low | High | Remember me option, long session timeout |
|
||||
| Complexity of admin interface | Medium | Medium | Intuitive UI design, user testing |
|
||||
|
||||
### 16.3 Operational Risks
|
||||
| Risk | Impact | Probability | Mitigation |
|
||||
|------|--------|-------------|------------|
|
||||
| Email delivery failures | Low | Medium | Show links in UI, document manual methods |
|
||||
| Lost admin password | High | Low | Document recovery procedure, config reset |
|
||||
| User data conflicts in migration | Medium | Low | Data validation, backup requirements |
|
||||
|
||||
## 17. Implementation Phases
|
||||
|
||||
### Phase 1: Foundation (Weeks 1-2)
|
||||
- Database schema design and migration
|
||||
- Basic authentication service
|
||||
- Password hashing and validation
|
||||
- Session management
|
||||
|
||||
### Phase 2: Backend API (Weeks 3-4)
|
||||
- Authentication endpoints
|
||||
- User management endpoints
|
||||
- Authorization middleware
|
||||
- Update existing endpoints with auth
|
||||
|
||||
### Phase 3: Frontend Auth (Weeks 5-6)
|
||||
- Login page and flow
|
||||
- Administrator setup
|
||||
- Session management
|
||||
- Auth state management
|
||||
|
||||
### Phase 4: Multi-tenancy (Weeks 7-9)
|
||||
- User isolation in all services
|
||||
- Shared boards implementation
|
||||
- Queue permission filtering
|
||||
- Workflow public/private
|
||||
|
||||
### Phase 5: Admin Interface (Weeks 10-11)
|
||||
- User management UI
|
||||
- Board sharing UI
|
||||
- Admin-specific features
|
||||
- User profile page
|
||||
|
||||
### Phase 6: Testing & Polish (Weeks 12-13)
|
||||
- Comprehensive testing
|
||||
- Security audit
|
||||
- Performance optimization
|
||||
- Documentation
|
||||
- Bug fixes
|
||||
|
||||
### Phase 7: Beta & Release (Week 14+)
|
||||
- Beta testing with selected users
|
||||
- Feedback incorporation
|
||||
- Final testing
|
||||
- Release preparation
|
||||
- Documentation finalization
|
||||
|
||||
## 18. Acceptance Criteria
|
||||
|
||||
- [ ] Administrator can set up initial account on first launch
|
||||
- [ ] Users can log in with email and password
|
||||
- [ ] Users can change their password
|
||||
- [ ] Administrators can create, edit, and delete users
|
||||
- [ ] User data is properly isolated (boards, images, workflows)
|
||||
- [ ] Shared boards work correctly with permissions
|
||||
- [ ] Non-admin users cannot access model management
|
||||
- [ ] Queue filtering works correctly for users and admins
|
||||
- [ ] Session management works correctly (expiry, renewal, logout)
|
||||
- [ ] All security tests pass
|
||||
- [ ] API documentation is updated
|
||||
- [ ] User and admin documentation is complete
|
||||
- [ ] Migration from single-user works smoothly
|
||||
- [ ] Performance is acceptable with multiple concurrent users
|
||||
- [ ] Backward compatibility mode works (auth disabled)
|
||||
|
||||
## 19. Design Decisions
|
||||
|
||||
The following design decisions have been approved for implementation:
|
||||
|
||||
1. **OAuth2 Priority**: OAuth2/OpenID Connect integration will be a **future enhancement**. The initial release will focus on username/password authentication to keep scope manageable.
|
||||
|
||||
2. **Email Requirement**: Email/SMTP configuration is **optional**. Many administrators will not have ready access to an outgoing SMTP server. The system will provide fallback mechanisms (showing setup links directly in the admin UI) when email is not configured.
|
||||
|
||||
3. **Data Migration**: During migration from single-user to multi-user mode, the administrator will be given the **option to specify an arbitrary user account** to hold legacy data. The admin account can be used for this purpose if the administrator wishes.
|
||||
|
||||
4. **API Compatibility**: Authentication will be **required on all APIs**, but authentication will not be required if multi-user support is disabled (backward compatibility mode with `auth_enabled: false`).
|
||||
|
||||
5. **Session Storage**: The system will use **JWT tokens with optional server-side session tracking**. This provides scalability while allowing administrators to enable server-side tracking if needed.
|
||||
|
||||
6. **Audit Logging**: The system will **log authentication events and admin actions**. This provides accountability and security monitoring for critical operations.
|
||||
|
||||
## 20. Conclusion
|
||||
|
||||
This specification provides a comprehensive blueprint for implementing multi-user support in InvokeAI. The design prioritizes:
|
||||
|
||||
- **Security**: Proper authentication, authorization, and data isolation
|
||||
- **Usability**: Intuitive UI, smooth migration, minimal friction
|
||||
- **Scalability**: Efficient database design, performant queries
|
||||
- **Maintainability**: Clean architecture, comprehensive testing
|
||||
- **Flexibility**: Future enhancement paths, optional features
|
||||
|
||||
The phased implementation approach allows for iterative development and testing, while the detailed specifications ensure all stakeholders have clear expectations of the final system.
|
||||
399
docs/multiuser/user_guide.md
Normal file
399
docs/multiuser/user_guide.md
Normal file
@@ -0,0 +1,399 @@
|
||||
# InvokeAI Multi-User Guide
|
||||
|
||||
## Overview
|
||||
|
||||
Multi-User mode is a recent feature (introduced in version 6.12), which allows multiple individuals to share a single InvokeAI server while keeping their work separate and organized. Each user has their own username and login password, images, assets, image boards, customization settings and workflows.
|
||||
|
||||
Two types of users are recognized:
|
||||
|
||||
* A user with **Administrator** status can add, remove and modify other users, and can install models. They also have the ability to view the full session queue and pause or kill other users' jobs.
|
||||
* **Non-administrator** users can modify their own profile but not others. They also do not have the ability to install or configure models, but must ask an Administrator to do this task.
|
||||
|
||||
Multiple users can be granted Administrator status.
|
||||
|
||||
***
|
||||
|
||||
## Getting Started
|
||||
|
||||
To activate Multi-User mode, open the `INVOKEAI_ROOT/invokeai.yaml` configuration file in a text editor. Add this line anywhere in the file:
|
||||
```yaml
|
||||
multiuser: true
|
||||
```
|
||||
|
||||
You may also wish to make InvokeAI available to other machines on your local LAN. Add an additional line to `invokeai.yaml`:
|
||||
|
||||
```yaml
|
||||
host: 0.0.0.0
|
||||
```
|
||||
|
||||
Restart the server. It will now be in multi-user mode. If you enabled
|
||||
the `host` option, other users on your home or office LAN will be able
|
||||
to reach it by browsing to the IP address of the machine the backend
|
||||
is running on (`http://host-ip-address:9090`).
|
||||
|
||||
!!! tip "Do not expose InvokeAI to the internet"
|
||||
It is not recommended to expose the InvokeAI host to the internet
|
||||
due to security concerns.
|
||||
|
||||
### Initial Setup (First Time in Multi-User Mode)
|
||||
|
||||
If you're the first person to access a fresh InvokeAI installation in multi-user mode, you'll see the **Administrator Setup** dialog:
|
||||
|
||||

|
||||
|
||||
Now
|
||||
|
||||
1. Enter your email address (this will be your login name)
|
||||
2. Create a display name (this will be the name other users see)
|
||||
3. Choose a strong password that meets the requirements:
|
||||
- At least 8 characters long
|
||||
- Contains uppercase letters
|
||||
- Contains lowercase letters
|
||||
- Contains numbers
|
||||
4. Confirm your password
|
||||
5. Click **Create Administrator Account**
|
||||
|
||||
You'll now be taken to a login screen and can enter the credentials
|
||||
you just created.
|
||||
|
||||
### Adding and Modifying Users
|
||||
|
||||
If you are logged in as Administrator, you can add additional users. Click on the small "person silhouette" icon at the bottom left of the main Invoke screen and select "User Management:"
|
||||
|
||||

|
||||
|
||||
This will take you to the User Management screen...
|
||||
|
||||

|
||||
|
||||
...where you can click "Create User" to add a new user.
|
||||
|
||||

|
||||
|
||||
The User Management screen also allows you to:
|
||||
|
||||
1. Temporarily change a user's status to Inactive, preventing them from logging in to Invoke.
|
||||
2. Edit a user (by clicking on the pencil icon) to change the user's display name or password.
|
||||
3. Permanently delete a user.
|
||||
4. Grant a user Administrator privileges.
|
||||
|
||||
### Command-line User Management Scripts
|
||||
|
||||
Administrators can also use a series of command-line scripts to add, modify, or delete users. If you use the launcher, click the ">" icon to enter the command-line interface. Otherwise, if you are a native command-line user, activate the InvokeAI environment from your terminal.
|
||||
|
||||
The commands are named:
|
||||
|
||||
* **invoke-useradd** -- add a user
|
||||
* **invoke-usermod** -- modify a user
|
||||
* **invoke-userdel** -- delete a user
|
||||
* **invoke-userlist** -- list all users
|
||||
|
||||
Pass the `--help` argument to get the usage of each script. For example:
|
||||
|
||||
```bash
|
||||
> invoke-useradd --help
|
||||
usage: invoke-useradd [-h] [--root ROOT] [--email EMAIL] [--password PASSWORD] [--name NAME] [--admin]
|
||||
|
||||
Add a user to the InvokeAI database
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--root ROOT, -r ROOT Path to the InvokeAI root directory. If omitted, the root is resolved in this order: the $INVOKEAI_ROOT environment
|
||||
variable, the active virtual environment's parent directory, or $HOME/invokeai.
|
||||
--email EMAIL, -e EMAIL
|
||||
User email address
|
||||
--password PASSWORD, -p PASSWORD
|
||||
User password
|
||||
--name NAME, -n NAME User display name (optional)
|
||||
--admin, -a Make user an administrator
|
||||
|
||||
If no arguments are provided, the script will run in interactive mode.
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## Logging in as a Non-Administrative User
|
||||
|
||||
If you are a registered user on the system, enter your email address and password to log in. The Administrator will be able to provide you with the values to use:
|
||||
|
||||

|
||||
|
||||
As an unprivileged user you can do pretty much anything that's allowed under single-user mode -- generating images, using LoRAs, creating and running workflows, creating image boards -- but you are restricted against installing new models, changing low-level server settings, or interfering with other users. More information on user roles is given below.
|
||||
|
||||
### Changing your Profile
|
||||
|
||||
To change your display name or profile, click on the person silhouette icon at the bottom left of the screen and choose "My Profile". This will take you to a screen that lets you change these values. At this time you can change your display name but not your login ID (ordinarily your contact email address).
|
||||
|
||||
***
|
||||
|
||||
## Understanding User Roles
|
||||
|
||||
In single-user mode, you have access to all features without restrictions. In multi-user mode, InvokeAI has two user roles:
|
||||
|
||||
### Regular User
|
||||
|
||||
As a regular user, you can:
|
||||
|
||||
- ✅ Create and manage your own image boards
|
||||
- ✅ Generate images using all AI tools (Linear, Canvas, Upscale, Workflows)
|
||||
- ✅ Create, save, and load your own workflows
|
||||
- ✅ View your own generation queue
|
||||
- ✅ Customize your UI preferences (theme, hotkeys, etc.)
|
||||
- ✅ View available models (read-only access to Model Manager)
|
||||
- ✅ Access shared boards (based on permissions granted to you) (FUTURE FEATURE)
|
||||
- ✅ Access workflows marked as public (FUTURE FEATURE)
|
||||
|
||||
You cannot:
|
||||
|
||||
- ❌ Add, delete, or modify models
|
||||
- ❌ View or modify other users' boards, images, or workflows
|
||||
- ❌ Manage user accounts
|
||||
- ❌ Access system configuration
|
||||
- ❌ View or cancel other users' generation tasks
|
||||
|
||||
!!! tip "The generation queue"
|
||||
When two or more users are accessing InvokeAI at the same time,
|
||||
their image generation jobs will be placed on the session queue on
|
||||
a first-come, first-serve basis. This means that you will have to
|
||||
wait for other users' image rendering jobs to complete before
|
||||
yours will start.
|
||||
|
||||
When another user's job is running, you will see the image
|
||||
generation progress bar and a queue badge that reads `X/Y`, where
|
||||
"X" is the number of jobs you have queued and "Y" is the total
|
||||
number of jobs queued, including your own and others.
|
||||
|
||||
You can also pull up the Queue tab in order to see where your job
|
||||
is in relationship to other queued tasks.
|
||||
|
||||
### Administrator
|
||||
|
||||
Administrators have all regular user capabilities, plus:
|
||||
|
||||
- ✅ Full model management (add, delete, configure models)
|
||||
- ✅ Create and manage user accounts
|
||||
- ✅ View and manage all users' generation queues
|
||||
- ✅ Create and manage shared boards (FUTURE FEATURE)
|
||||
- ✅ Access system configuration
|
||||
- ✅ Grant or revoke admin privileges
|
||||
|
||||
***
|
||||
|
||||
## Working with Your Content in Multi-User Mode
|
||||
|
||||
### Image Boards
|
||||
|
||||
In multi-user model, Image Boards work as before. Each user can create an unlimited number of boards and organize their images and assets as they see fit. Boards are private: you cannot see a board owned by a different user.
|
||||
|
||||
!!! tip "Shared Boards"
|
||||
InvokeAI 6.13 will add support for creating public boards that are accessible to all users.
|
||||
|
||||
The Administrator can see all users Image Boards and their contents.
|
||||
|
||||
### Going From Multi-User to Single-User mode
|
||||
|
||||
If an InvokeAI instance was in multiuser mode and then restarted in single user mode (by setting `multiuser: false` in the configuration file), all users' boards will be consolidated in one place. Any images that were in "Uncategorized" will be merged together into a single Uncategorized board. If, at a later date, the server is restarted in multi-user mode, the boards and images will be separated and restored to their owners.
|
||||
|
||||
### Workflows
|
||||
|
||||
In the current released version (6.12) workflows are always shared among users. Any workflow that you create will be visible to other users and vice-versa, and there is no protection against one user modifying another user's workflow.
|
||||
|
||||
!!! tip "Private and Shared Workflows"
|
||||
InvokeAI 6.13 will provide the ability to create private and shared workflows. A private workflow can only be viewed by the user who created it. At any time, however, the user can designate the workflow *shared*, in which case it can be opened on a read-only basis by all logged-in users.
|
||||
|
||||
|
||||
### The Generation Queue
|
||||
|
||||
The queue shows your pending and running generation tasks.
|
||||
|
||||
**Queue Features:**
|
||||
|
||||
- View your current and completed generations
|
||||
- Cancel pending tasks
|
||||
- Re-run previous generations
|
||||
- Monitor progress in real-time
|
||||
|
||||
**Queue Isolation:**
|
||||
|
||||
- You will see your own queue items, as well as the items generated by
|
||||
either users, but the generation parameters (e.g. prompts) for other
|
||||
users' are hidden for privacy reasons.
|
||||
- Administrators can view all queues for troubleshooting
|
||||
- Your generations won't interfere with other users' tasks
|
||||
|
||||
***
|
||||
|
||||
## Customizing Your Experience
|
||||
|
||||
### Personal Preferences
|
||||
|
||||
Your UI preferences are saved to your account and are restored when you log in:
|
||||
|
||||
- **Theme**: Choose between light and dark modes
|
||||
- **Hotkeys**: Customize keyboard shortcuts
|
||||
- **Canvas Settings**: Default zoom, grid visibility, etc.
|
||||
- **Generation Defaults**: Default values for width, height, steps, etc.
|
||||
|
||||
These settings are stored per-user and won't affect other users.
|
||||
|
||||
***
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Cannot Log In
|
||||
|
||||
**Issue:** Login fails with "Incorrect email or password"
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Verify you're entering the correct email address
|
||||
- Check that Caps Lock is off
|
||||
- Try typing the password slowly to avoid mistakes
|
||||
- Contact your administrator if you've forgotten your password
|
||||
|
||||
**Issue:** Login fails with "Account is disabled"
|
||||
|
||||
**Solution:** Contact your administrator to reactivate your account
|
||||
|
||||
### Session Expired
|
||||
|
||||
**Issue:** You're suddenly logged out and see "Session expired"
|
||||
|
||||
**Explanation:** Sessions expire after 24 hours (or 7 days with "remember me")
|
||||
|
||||
**Solution:** Simply log in again with your credentials
|
||||
|
||||
### Cannot Access Features
|
||||
|
||||
**Issue:** Features like Model Manager show "Admin privileges required"
|
||||
|
||||
**Explanation:** Some features are restricted to administrators
|
||||
|
||||
**Solution:**
|
||||
|
||||
- For model viewing: You can view but not modify models
|
||||
- For user management: Contact an administrator
|
||||
- For system configuration: Contact an administrator
|
||||
|
||||
### Missing Boards or Images
|
||||
|
||||
**Issue:** Boards or images you created are not visible
|
||||
|
||||
**Possible Causes:**
|
||||
|
||||
1. **Filter Applied:** Check if a filter is hiding content
|
||||
2. **Wrong User:** Ensure you're logged in with the correct account
|
||||
3. **Archived Board:** Check the "Show Archived" option
|
||||
|
||||
**Solution:**
|
||||
|
||||
- Clear any active filters
|
||||
- Verify you're logged in as the right user
|
||||
- Check archived items
|
||||
|
||||
### Slow Performance
|
||||
|
||||
**Issue:** Generation or UI feels slower than expected
|
||||
|
||||
**Possible Causes:**
|
||||
|
||||
- Other users generating images simultaneously
|
||||
- Server resource limits
|
||||
- Network latency
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Check the queue to see if others are generating
|
||||
- Wait for current generations to complete
|
||||
- Contact administrator if persistent
|
||||
|
||||
### Generation Stuck in Queue
|
||||
|
||||
**Issue:** Your generation is queued but not starting
|
||||
|
||||
**Possible Causes:**
|
||||
|
||||
- Server is processing other users' generations
|
||||
- Server resources are fully utilized
|
||||
- Technical issue with the server
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Wait for your turn in the queue
|
||||
- Check if your generation is paused
|
||||
- Contact administrator if stuck for extended period
|
||||
|
||||
|
||||
***
|
||||
|
||||
## Frequently Asked Questions
|
||||
|
||||
### Can other users see my images?
|
||||
|
||||
No, unless you add them to a shared board (FUTURE FEATURE). All your personal boards and images are private.
|
||||
|
||||
### Can I share my workflows with others?
|
||||
|
||||
Not directly. Ask your administrator to mark workflows as public if you want to share them.
|
||||
|
||||
### How long do sessions last?
|
||||
|
||||
- 24 hours by default
|
||||
- 7 days if you check "Remember me" during login
|
||||
|
||||
### Can I use the API with multi-user mode?
|
||||
|
||||
Yes, but you'll need to authenticate with a JWT token. See the [API Guide](api_guide.md) for details.
|
||||
|
||||
### What happens if I forget my password?
|
||||
|
||||
Contact your administrator. They can reset your password for you.
|
||||
|
||||
### Can I have multiple sessions?
|
||||
|
||||
Yes, you can log in from multiple devices or browsers simultaneously. All sessions will use the same account and see the same content.
|
||||
|
||||
### Why can't I see the Model Manager "Add Models" tab?
|
||||
|
||||
Regular users can see the Models tab but with read-only access. Check that you're logged in and try refreshing the page.
|
||||
|
||||
### How do I know if I'm an administrator?
|
||||
|
||||
Administrators see an "Admin" badge next to their name in the top-right corner and have access to additional features like User Management.
|
||||
|
||||
### Can I request admin privileges?
|
||||
|
||||
Yes, ask your current administrator to grant you admin
|
||||
privileges. Admin privileges will give you the ability to see all
|
||||
other user's boards and images, as well as to add models and change
|
||||
various server-wide settings.
|
||||
|
||||
## Getting Help
|
||||
|
||||
### Support Channels
|
||||
|
||||
- **Administrator:** Contact your system administrator for account issues
|
||||
- **Documentation:** Check the [FAQ](../faq.md) for common issues
|
||||
- **Community:** Join the [Discord](https://discord.gg/ZmtBAhwWhy) for help
|
||||
- **Bug Reports:** File issues on [GitHub](https://github.com/invoke-ai/InvokeAI/issues)
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
When reporting an issue, include:
|
||||
|
||||
- Your role (regular user or administrator)
|
||||
- What you were trying to do
|
||||
- What happened instead
|
||||
- Any error messages you saw
|
||||
- Your browser and operating system
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Administrator Guide](admin_guide.md) - For administrators managing users and the system
|
||||
- [API Guide](api_guide.md) - For developers using the InvokeAI API
|
||||
- [Multiuser Specification](specification.md) - Technical details about the feature
|
||||
- [InvokeAI Documentation](../index.md) - Main documentation hub
|
||||
|
||||
---
|
||||
|
||||
**Need more help?** Contact your administrator or visit the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).
|
||||
166
invokeai/app/api/auth_dependencies.py
Normal file
166
invokeai/app/api/auth_dependencies.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""FastAPI dependencies for authentication."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.auth.token_service import TokenData, verify_token
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HTTP Bearer token security scheme
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
|
||||
) -> TokenData:
|
||||
"""Get current authenticated user from Bearer token.
|
||||
|
||||
Note: This function accesses ApiDependencies.invoker.services.users directly,
|
||||
which is the established pattern in this codebase. The ApiDependencies.invoker
|
||||
is initialized in the FastAPI lifespan context before any requests are handled.
|
||||
|
||||
Args:
|
||||
credentials: The HTTP authorization credentials containing the Bearer token
|
||||
|
||||
Returns:
|
||||
TokenData containing user information from the token
|
||||
|
||||
Raises:
|
||||
HTTPException: If token is missing, invalid, or expired (401 Unauthorized)
|
||||
"""
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = credentials.credentials
|
||||
token_data = verify_token(token)
|
||||
|
||||
if token_data is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired authentication token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Verify user still exists and is active
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
user = user_service.get(token_data.user_id)
|
||||
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User account is inactive or does not exist",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
return token_data
|
||||
|
||||
|
||||
async def get_current_user_or_default(
|
||||
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
|
||||
) -> TokenData:
|
||||
"""Get current authenticated user from Bearer token, or return a default system user if not authenticated.
|
||||
|
||||
This dependency is useful for endpoints that should work in both single-user and multiuser modes.
|
||||
|
||||
When multiuser mode is disabled (default), this always returns a system user with admin privileges,
|
||||
allowing unrestricted access to all operations.
|
||||
|
||||
When multiuser mode is enabled, authentication is required and this function validates the token,
|
||||
returning authenticated user data or raising 401 Unauthorized if no valid credentials are provided.
|
||||
|
||||
Args:
|
||||
credentials: The HTTP authorization credentials containing the Bearer token
|
||||
|
||||
Returns:
|
||||
TokenData containing user information from the token, or system user in single-user mode
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 Unauthorized if in multiuser mode and credentials are missing, invalid, or user is inactive
|
||||
"""
|
||||
# Get configuration to check if multiuser is enabled
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
|
||||
# In single-user mode (multiuser=False), always return system user with admin privileges
|
||||
if not config.multiuser:
|
||||
return TokenData(user_id="system", email="system@system.invokeai", is_admin=True)
|
||||
|
||||
# Multiuser mode is enabled - validate credentials
|
||||
if credentials is None:
|
||||
# In multiuser mode, authentication is required
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
|
||||
|
||||
token = credentials.credentials
|
||||
token_data = verify_token(token)
|
||||
|
||||
if token_data is None:
|
||||
# Invalid token in multiuser mode - reject
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token")
|
||||
|
||||
# Verify user still exists and is active
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
user = user_service.get(token_data.user_id)
|
||||
|
||||
if user is None or not user.is_active:
|
||||
# User doesn't exist or is inactive in multiuser mode - reject
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
|
||||
|
||||
return token_data
|
||||
|
||||
|
||||
async def require_admin(
|
||||
current_user: Annotated[TokenData, Depends(get_current_user)],
|
||||
) -> TokenData:
|
||||
"""Require admin role for the current user.
|
||||
|
||||
Args:
|
||||
current_user: The current authenticated user's token data
|
||||
|
||||
Returns:
|
||||
The token data if user is an admin
|
||||
|
||||
Raises:
|
||||
HTTPException: If user does not have admin privileges (403 Forbidden)
|
||||
"""
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")
|
||||
return current_user
|
||||
|
||||
|
||||
async def require_admin_or_default(
|
||||
current_user: Annotated[TokenData, Depends(get_current_user_or_default)],
|
||||
) -> TokenData:
|
||||
"""Require admin role for the current user, or return default system admin in single-user mode.
|
||||
|
||||
This dependency is useful for admin-only endpoints that should work in both single-user and multiuser modes.
|
||||
|
||||
When multiuser mode is disabled (default), this always returns a system user with admin privileges.
|
||||
When multiuser mode is enabled, this validates that the authenticated user has admin privileges.
|
||||
|
||||
Args:
|
||||
current_user: The current authenticated user's token data (or default system user)
|
||||
|
||||
Returns:
|
||||
The token data if user is an admin (or system user in single-user mode)
|
||||
|
||||
Raises:
|
||||
HTTPException: If user does not have admin privileges (403 Forbidden) in multiuser mode
|
||||
"""
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")
|
||||
return current_user
|
||||
|
||||
|
||||
# Type aliases for convenient use in route dependencies
|
||||
CurrentUser = Annotated[TokenData, Depends(get_current_user)]
|
||||
CurrentUserOrDefault = Annotated[TokenData, Depends(get_current_user_or_default)]
|
||||
AdminUser = Annotated[TokenData, Depends(require_admin)]
|
||||
AdminUserOrDefault = Annotated[TokenData, Depends(require_admin_or_default)]
|
||||
@@ -5,6 +5,8 @@ from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.app_settings import AppSettingsService
|
||||
from invokeai.app.services.auth.token_service import set_jwt_secret
|
||||
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
||||
from invokeai.app.services.board_images.board_images_default import BoardImagesService
|
||||
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
@@ -40,6 +42,7 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
@@ -101,6 +104,12 @@ class ApiDependencies:
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
# Initialize JWT secret from database
|
||||
app_settings = AppSettingsService(db=db)
|
||||
jwt_secret = app_settings.get_jwt_secret()
|
||||
set_jwt_secret(jwt_secret)
|
||||
logger.info("JWT secret loaded from database")
|
||||
|
||||
configuration = config
|
||||
logger = logger
|
||||
|
||||
@@ -155,6 +164,7 @@ class ApiDependencies:
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
|
||||
client_state_persistence = ClientStatePersistenceSqlite(db=db)
|
||||
users = UserService(db=db)
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@@ -186,6 +196,7 @@ class ApiDependencies:
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
workflow_thumbnails=workflow_thumbnails,
|
||||
client_state_persistence=client_state_persistence,
|
||||
users=users,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.responses import Response
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from starlette.types import Scope
|
||||
|
||||
|
||||
class NoCacheStaticFiles(StaticFiles):
|
||||
@@ -12,6 +14,10 @@ class NoCacheStaticFiles(StaticFiles):
|
||||
|
||||
Static files include the javascript bundles, fonts, locales, and some images. Generated
|
||||
images are not included, as they are served by a router.
|
||||
|
||||
This class also implements proper SPA (Single Page Application) routing by serving index.html
|
||||
for any routes that don't match static files, enabling client-side routing to work correctly
|
||||
in production builds.
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
@@ -26,3 +32,19 @@ class NoCacheStaticFiles(StaticFiles):
|
||||
resp.headers.setdefault("Pragma", self.pragma)
|
||||
resp.headers.setdefault("Expires", self.expires)
|
||||
return resp
|
||||
|
||||
async def get_response(self, path: str, scope: Scope) -> Response:
|
||||
"""
|
||||
Override get_response to implement SPA routing.
|
||||
|
||||
When a file is not found and html mode is enabled, serve index.html instead of raising a 404.
|
||||
This allows client-side routing to work correctly in SPAs.
|
||||
"""
|
||||
try:
|
||||
return await super().get_response(path, scope)
|
||||
except HTTPException as exc:
|
||||
# If the file is not found (404) and html mode is enabled, serve index.html
|
||||
# This allows client-side routing to handle the path
|
||||
if exc.status_code == 404 and self.html:
|
||||
return await super().get_response("index.html", scope)
|
||||
raise
|
||||
|
||||
514
invokeai/app/api/routers/auth.py
Normal file
514
invokeai/app/api/routers/auth.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""Authentication endpoints."""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import timedelta
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, status
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from invokeai.app.api.auth_dependencies import AdminUser, CurrentUser
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.auth.token_service import TokenData, create_access_token
|
||||
from invokeai.app.services.users.users_common import (
|
||||
UserCreateRequest,
|
||||
UserDTO,
|
||||
UserUpdateRequest,
|
||||
validate_email_with_special_domains,
|
||||
)
|
||||
|
||||
auth_router = APIRouter(prefix="/v1/auth", tags=["authentication"])
|
||||
|
||||
# Token expiration constants (in days)
|
||||
TOKEN_EXPIRATION_NORMAL = 1 # 1 day for normal login
|
||||
TOKEN_EXPIRATION_REMEMBER_ME = 7 # 7 days for "remember me" login
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""Request body for user login."""
|
||||
|
||||
email: str = Field(description="User email address")
|
||||
password: str = Field(description="User password")
|
||||
remember_me: bool = Field(default=False, description="Whether to extend session duration")
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, v: str) -> str:
|
||||
"""Validate email address, allowing special-use domains."""
|
||||
return validate_email_with_special_domains(v)
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""Response from successful login."""
|
||||
|
||||
token: str = Field(description="JWT access token")
|
||||
user: UserDTO = Field(description="User information")
|
||||
expires_in: int = Field(description="Token expiration time in seconds")
|
||||
|
||||
|
||||
class SetupRequest(BaseModel):
|
||||
"""Request body for initial admin setup."""
|
||||
|
||||
email: str = Field(description="Admin email address")
|
||||
display_name: str | None = Field(default=None, description="Admin display name")
|
||||
password: str = Field(description="Admin password")
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, v: str) -> str:
|
||||
"""Validate email address, allowing special-use domains."""
|
||||
return validate_email_with_special_domains(v)
|
||||
|
||||
|
||||
class SetupResponse(BaseModel):
|
||||
"""Response from successful admin setup."""
|
||||
|
||||
success: bool = Field(description="Whether setup was successful")
|
||||
user: UserDTO = Field(description="Created admin user information")
|
||||
|
||||
|
||||
class LogoutResponse(BaseModel):
|
||||
"""Response from logout."""
|
||||
|
||||
success: bool = Field(description="Whether logout was successful")
|
||||
|
||||
|
||||
class SetupStatusResponse(BaseModel):
|
||||
"""Response for setup status check."""
|
||||
|
||||
setup_required: bool = Field(description="Whether initial setup is required")
|
||||
multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled")
|
||||
|
||||
|
||||
@auth_router.get("/status", response_model=SetupStatusResponse)
|
||||
async def get_setup_status() -> SetupStatusResponse:
|
||||
"""Check if initial administrator setup is required.
|
||||
|
||||
Returns:
|
||||
SetupStatusResponse indicating whether setup is needed and multiuser mode status
|
||||
"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
|
||||
# If multiuser is disabled, setup is never required
|
||||
if not config.multiuser:
|
||||
return SetupStatusResponse(setup_required=False, multiuser_enabled=False)
|
||||
|
||||
# In multiuser mode, check if an admin exists
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
setup_required = not user_service.has_admin()
|
||||
|
||||
return SetupStatusResponse(setup_required=setup_required, multiuser_enabled=True)
|
||||
|
||||
|
||||
@auth_router.post("/login", response_model=LoginResponse)
|
||||
async def login(
|
||||
request: Annotated[LoginRequest, Body(description="Login credentials")],
|
||||
) -> LoginResponse:
|
||||
"""Authenticate user and return access token.
|
||||
|
||||
Args:
|
||||
request: Login credentials (email and password)
|
||||
|
||||
Returns:
|
||||
LoginResponse containing JWT token and user information
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if credentials are invalid or user is inactive
|
||||
HTTPException: 403 if multiuser mode is disabled
|
||||
"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
|
||||
# Check if multiuser is enabled
|
||||
if not config.multiuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Multiuser mode is disabled. Authentication is not required in single-user mode.",
|
||||
)
|
||||
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
user = user_service.authenticate(request.email, request.password)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User account is disabled")
|
||||
|
||||
# Create token with appropriate expiration
|
||||
expires_delta = timedelta(days=TOKEN_EXPIRATION_REMEMBER_ME if request.remember_me else TOKEN_EXPIRATION_NORMAL)
|
||||
token_data = TokenData(
|
||||
user_id=user.user_id,
|
||||
email=user.email,
|
||||
is_admin=user.is_admin,
|
||||
)
|
||||
token = create_access_token(token_data, expires_delta)
|
||||
|
||||
return LoginResponse(
|
||||
token=token,
|
||||
user=user,
|
||||
expires_in=int(expires_delta.total_seconds()),
|
||||
)
|
||||
|
||||
|
||||
@auth_router.post("/logout", response_model=LogoutResponse)
|
||||
async def logout(
|
||||
current_user: CurrentUser,
|
||||
) -> LogoutResponse:
|
||||
"""Logout current user.
|
||||
|
||||
Currently a no-op since we use stateless JWT tokens. For token invalidation in
|
||||
future implementations, consider:
|
||||
- Token blacklist: Store invalidated tokens in Redis/database with expiration
|
||||
- Token versioning: Add version field to user record, increment on logout
|
||||
- Short-lived tokens: Use refresh token pattern with token rotation
|
||||
- Session storage: Track active sessions server-side for revocation
|
||||
|
||||
Args:
|
||||
current_user: The authenticated user (validates token)
|
||||
|
||||
Returns:
|
||||
LogoutResponse indicating success
|
||||
"""
|
||||
# TODO: Implement token invalidation when server-side session management is added
|
||||
# For now, this is a no-op since we use stateless JWT tokens
|
||||
return LogoutResponse(success=True)
|
||||
|
||||
|
||||
@auth_router.get("/me", response_model=UserDTO)
|
||||
async def get_current_user_info(
|
||||
current_user: CurrentUser,
|
||||
) -> UserDTO:
|
||||
"""Get current authenticated user's information.
|
||||
|
||||
Args:
|
||||
current_user: The authenticated user's token data
|
||||
|
||||
Returns:
|
||||
UserDTO containing user information
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if user is not found (should not happen normally)
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
user = user_service.get(current_user.user_id)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@auth_router.post("/setup", response_model=SetupResponse)
|
||||
async def setup_admin(
|
||||
request: Annotated[SetupRequest, Body(description="Admin account details")],
|
||||
) -> SetupResponse:
|
||||
"""Set up initial administrator account.
|
||||
|
||||
This endpoint can only be called once, when no admin user exists. It creates
|
||||
the first admin user for the system.
|
||||
|
||||
Args:
|
||||
request: Admin account details (email, display_name, password)
|
||||
|
||||
Returns:
|
||||
SetupResponse containing the created admin user
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if admin already exists or password is weak
|
||||
HTTPException: 403 if multiuser mode is disabled
|
||||
"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
|
||||
# Check if multiuser is enabled
|
||||
if not config.multiuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Multiuser mode is disabled. Admin setup is not required in single-user mode.",
|
||||
)
|
||||
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
|
||||
# Check if any admin exists
|
||||
if user_service.has_admin():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Administrator account already configured",
|
||||
)
|
||||
|
||||
# Create admin user - this will validate password strength
|
||||
try:
|
||||
user_data = UserCreateRequest(
|
||||
email=request.email,
|
||||
display_name=request.display_name,
|
||||
password=request.password,
|
||||
is_admin=True,
|
||||
)
|
||||
user = user_service.create_admin(user_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
return SetupResponse(success=True, user=user)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User management models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PASSWORD_ALPHABET = string.ascii_letters + string.digits + string.punctuation
|
||||
|
||||
|
||||
class AdminUserCreateRequest(BaseModel):
|
||||
"""Request body for admin to create a new user."""
|
||||
|
||||
email: str = Field(description="User email address")
|
||||
display_name: str | None = Field(default=None, description="Display name")
|
||||
password: str = Field(description="User password")
|
||||
is_admin: bool = Field(default=False, description="Whether user should have admin privileges")
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, v: str) -> str:
|
||||
"""Validate email address, allowing special-use domains."""
|
||||
return validate_email_with_special_domains(v)
|
||||
|
||||
|
||||
class AdminUserUpdateRequest(BaseModel):
|
||||
"""Request body for admin to update any user."""
|
||||
|
||||
display_name: str | None = Field(default=None, description="Display name")
|
||||
password: str | None = Field(default=None, description="New password")
|
||||
is_admin: bool | None = Field(default=None, description="Whether user should have admin privileges")
|
||||
is_active: bool | None = Field(default=None, description="Whether user account should be active")
|
||||
|
||||
|
||||
class UserProfileUpdateRequest(BaseModel):
|
||||
"""Request body for a user to update their own profile."""
|
||||
|
||||
display_name: str | None = Field(default=None, description="New display name")
|
||||
current_password: str | None = Field(default=None, description="Current password (required when changing password)")
|
||||
new_password: str | None = Field(default=None, description="New password")
|
||||
|
||||
|
||||
class GeneratePasswordResponse(BaseModel):
|
||||
"""Response containing a generated password."""
|
||||
|
||||
password: str = Field(description="Generated strong password")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User management endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@auth_router.get("/generate-password", response_model=GeneratePasswordResponse)
|
||||
async def generate_password(
|
||||
current_user: CurrentUser,
|
||||
) -> GeneratePasswordResponse:
|
||||
"""Generate a strong random password.
|
||||
|
||||
Returns a cryptographically secure random password of 16 characters
|
||||
containing uppercase, lowercase, digits, and punctuation.
|
||||
"""
|
||||
# Ensure the generated password always meets strength requirements:
|
||||
# at least one uppercase, one lowercase, one digit, one special char.
|
||||
while True:
|
||||
password = "".join(secrets.choice(_PASSWORD_ALPHABET) for _ in range(16))
|
||||
if (
|
||||
any(c.isupper() for c in password)
|
||||
and any(c.islower() for c in password)
|
||||
and any(c.isdigit() for c in password)
|
||||
):
|
||||
return GeneratePasswordResponse(password=password)
|
||||
|
||||
|
||||
@auth_router.get("/users", response_model=list[UserDTO])
|
||||
async def list_users(
|
||||
current_user: AdminUser,
|
||||
) -> list[UserDTO]:
|
||||
"""List all users. Requires admin privileges.
|
||||
|
||||
The internal 'system' user (created for backward compatibility) is excluded
|
||||
from the results since it cannot be managed through this interface.
|
||||
|
||||
Returns:
|
||||
List of all real users (system user excluded)
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
return [u for u in user_service.list_users() if u.user_id != "system"]
|
||||
|
||||
|
||||
@auth_router.post("/users", response_model=UserDTO, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
request: Annotated[AdminUserCreateRequest, Body(description="New user details")],
|
||||
current_user: AdminUser,
|
||||
) -> UserDTO:
|
||||
"""Create a new user. Requires admin privileges.
|
||||
|
||||
Args:
|
||||
request: New user details
|
||||
|
||||
Returns:
|
||||
The created user
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if email already exists or password is weak
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
try:
|
||||
user_data = UserCreateRequest(
|
||||
email=request.email,
|
||||
display_name=request.display_name,
|
||||
password=request.password,
|
||||
is_admin=request.is_admin,
|
||||
)
|
||||
return user_service.create(user_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
|
||||
@auth_router.get("/users/{user_id}", response_model=UserDTO)
|
||||
async def get_user(
|
||||
user_id: Annotated[str, Path(description="User ID")],
|
||||
current_user: AdminUser,
|
||||
) -> UserDTO:
|
||||
"""Get a user by ID. Requires admin privileges.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
The user
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if user not found
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
user = user_service.get(user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
return user
|
||||
|
||||
|
||||
@auth_router.patch("/users/{user_id}", response_model=UserDTO)
|
||||
async def update_user(
|
||||
user_id: Annotated[str, Path(description="User ID")],
|
||||
request: Annotated[AdminUserUpdateRequest, Body(description="User fields to update")],
|
||||
current_user: AdminUser,
|
||||
) -> UserDTO:
|
||||
"""Update a user. Requires admin privileges.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
request: Fields to update
|
||||
|
||||
Returns:
|
||||
The updated user
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if password is weak
|
||||
HTTPException: 404 if user not found
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
try:
|
||||
changes = UserUpdateRequest(
|
||||
display_name=request.display_name,
|
||||
password=request.password,
|
||||
is_admin=request.is_admin,
|
||||
is_active=request.is_active,
|
||||
)
|
||||
return user_service.update(user_id, changes)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
|
||||
@auth_router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_user(
|
||||
user_id: Annotated[str, Path(description="User ID")],
|
||||
current_user: AdminUser,
|
||||
) -> None:
|
||||
"""Delete a user. Requires admin privileges.
|
||||
|
||||
Admins can delete any user including other admins, but cannot delete the last
|
||||
remaining admin.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if attempting to delete the last admin
|
||||
HTTPException: 404 if user not found
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
user = user_service.get(user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
# Prevent deleting the last active admin
|
||||
if user.is_admin and user.is_active and user_service.count_admins() <= 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot delete the last administrator",
|
||||
)
|
||||
|
||||
try:
|
||||
user_service.delete(user_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
|
||||
@auth_router.patch("/me", response_model=UserDTO)
|
||||
async def update_current_user(
|
||||
request: Annotated[UserProfileUpdateRequest, Body(description="Profile fields to update")],
|
||||
current_user: CurrentUser,
|
||||
) -> UserDTO:
|
||||
"""Update the current user's own profile.
|
||||
|
||||
To change the password, both ``current_password`` and ``new_password`` must
|
||||
be provided. The current password is verified before the change is applied.
|
||||
|
||||
Args:
|
||||
request: Profile fields to update
|
||||
current_user: The authenticated user
|
||||
|
||||
Returns:
|
||||
The updated user
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if current password is incorrect or new password is weak
|
||||
HTTPException: 404 if user not found
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
|
||||
# Verify current password when attempting a password change
|
||||
if request.new_password is not None:
|
||||
if not request.current_password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is required to set a new password",
|
||||
)
|
||||
|
||||
# Re-authenticate to verify the current password
|
||||
user = user_service.get(current_user.user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
authenticated = user_service.authenticate(user.email, request.current_password)
|
||||
if authenticated is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect",
|
||||
)
|
||||
|
||||
try:
|
||||
changes = UserUpdateRequest(
|
||||
display_name=request.display_name,
|
||||
password=request.new_password,
|
||||
)
|
||||
return user_service.update(current_user.user_id, changes)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
@@ -4,6 +4,7 @@ from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
@@ -32,11 +33,12 @@ class DeleteBoardResult(BaseModel):
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def create_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_name: str = Query(description="The name of the board to create", max_length=300),
|
||||
) -> BoardDTO:
|
||||
"""Creates a board"""
|
||||
"""Creates a board for the current user"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name, user_id=current_user.user_id)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||
@@ -44,16 +46,21 @@ async def create_board(
|
||||
|
||||
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
|
||||
async def get_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Path(description="The id of board to get"),
|
||||
) -> BoardDTO:
|
||||
"""Gets a board"""
|
||||
"""Gets a board (user must have access to it)"""
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if not current_user.is_admin and result.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this board")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@boards_router.patch(
|
||||
"/{board_id}",
|
||||
@@ -67,10 +74,19 @@ async def get_board(
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def update_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Path(description="The id of board to update"),
|
||||
changes: BoardChanges = Body(description="The changes to apply to the board"),
|
||||
) -> BoardDTO:
|
||||
"""Updates a board"""
|
||||
"""Updates a board (user must have access to it)"""
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if not current_user.is_admin and board.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this board")
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
|
||||
return result
|
||||
@@ -80,10 +96,19 @@ async def update_board(
|
||||
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
|
||||
async def delete_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
|
||||
) -> DeleteBoardResult:
|
||||
"""Deletes a board"""
|
||||
"""Deletes a board (user must have access to it)"""
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if not current_user.is_admin and board.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to delete this board")
|
||||
|
||||
try:
|
||||
if include_images is True:
|
||||
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
@@ -120,6 +145,7 @@ async def delete_board(
|
||||
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
|
||||
)
|
||||
async def list_boards(
|
||||
current_user: CurrentUserOrDefault,
|
||||
order_by: BoardRecordOrderBy = Query(default=BoardRecordOrderBy.CreatedAt, description="The attribute to order by"),
|
||||
direction: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The direction to order by"),
|
||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||
@@ -127,11 +153,15 @@ async def list_boards(
|
||||
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
|
||||
include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
|
||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||
"""Gets a list of boards"""
|
||||
"""Gets a list of boards for the current user, including shared boards. Admin users see all boards."""
|
||||
if all:
|
||||
return ApiDependencies.invoker.services.boards.get_all(order_by, direction, include_archived)
|
||||
return ApiDependencies.invoker.services.boards.get_all(
|
||||
current_user.user_id, current_user.is_admin, order_by, direction, include_archived
|
||||
)
|
||||
elif offset is not None and limit is not None:
|
||||
return ApiDependencies.invoker.services.boards.get_many(order_by, direction, offset, limit, include_archived)
|
||||
return ApiDependencies.invoker.services.boards.get_many(
|
||||
current_user.user_id, current_user.is_admin, order_by, direction, offset, limit, include_archived
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -145,12 +175,22 @@ async def list_boards(
|
||||
response_model=list[str],
|
||||
)
|
||||
async def list_all_board_image_names(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Path(description="The id of the board or 'none' for uncategorized images"),
|
||||
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
|
||||
) -> list[str]:
|
||||
"""Gets a list of images for a board"""
|
||||
|
||||
if board_id != "none":
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if not current_user.is_admin and board.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this board")
|
||||
|
||||
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id,
|
||||
categories,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
@@ -13,15 +14,16 @@ client_state_router = APIRouter(prefix="/v1/client_state", tags=["client_state"]
|
||||
response_model=str | None,
|
||||
)
|
||||
async def get_client_state_by_key(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
|
||||
key: str = Query(..., description="Key to get"),
|
||||
) -> str | None:
|
||||
"""Gets the client state"""
|
||||
"""Gets the client state for the current user (or system user if not authenticated)"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
|
||||
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(current_user.user_id, key)
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
raise HTTPException(status_code=500, detail="Error getting client state")
|
||||
|
||||
|
||||
@client_state_router.post(
|
||||
@@ -30,13 +32,14 @@ async def get_client_state_by_key(
|
||||
response_model=str,
|
||||
)
|
||||
async def set_client_state(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
|
||||
key: str = Query(..., description="Key to set"),
|
||||
value: str = Body(..., description="Stringified value to set"),
|
||||
) -> str:
|
||||
"""Sets the client state"""
|
||||
"""Sets the client state for the current user (or system user if not authenticated)"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
|
||||
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(current_user.user_id, key, value)
|
||||
except Exception as e:
|
||||
logging.error(f"Error setting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
@@ -48,11 +51,12 @@ async def set_client_state(
|
||||
responses={204: {"description": "Client state deleted"}},
|
||||
)
|
||||
async def delete_client_state(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
|
||||
) -> None:
|
||||
"""Deletes the client state"""
|
||||
"""Deletes the client state for the current user (or system user if not authenticated)"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
|
||||
ApiDependencies.invoker.services.client_state_persistence.delete(current_user.user_id)
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error deleting client state")
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_image
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
@@ -61,6 +62,7 @@ class ResizeToDimensions(BaseModel):
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def upload_image(
|
||||
current_user: CurrentUserOrDefault,
|
||||
file: UploadFile,
|
||||
request: Request,
|
||||
response: Response,
|
||||
@@ -80,7 +82,7 @@ async def upload_image(
|
||||
embed=True,
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
"""Uploads an image for the current user"""
|
||||
if not file.content_type or not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
@@ -133,6 +135,7 @@ async def upload_image(
|
||||
workflow=extracted_metadata.invokeai_workflow,
|
||||
graph=extracted_metadata.invokeai_graph,
|
||||
is_intermediate=is_intermediate,
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
|
||||
response.status_code = 201
|
||||
@@ -373,6 +376,7 @@ async def get_image_urls(
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_image_dtos(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
@@ -386,10 +390,19 @@ async def list_image_dtos(
|
||||
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of image DTOs"""
|
||||
"""Gets a list of image DTOs for the current user"""
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
|
||||
offset,
|
||||
limit,
|
||||
starred_first,
|
||||
order_dir,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
search_term,
|
||||
current_user.user_id,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
@@ -567,6 +580,7 @@ async def get_bulk_download_item(
|
||||
|
||||
@images_router.get("/names", operation_id="get_image_names")
|
||||
async def get_image_names(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
@@ -589,6 +603,8 @@ async def get_image_names(
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
user_id=current_user.user_id,
|
||||
is_admin=current_user.is_admin,
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
|
||||
@@ -19,6 +19,7 @@ from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.api.auth_dependencies import AdminUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
@@ -229,6 +230,7 @@ async def get_model_record(
|
||||
)
|
||||
async def reidentify_model(
|
||||
key: Annotated[str, Path(description="Key of the model to reidentify.")],
|
||||
current_admin: AdminUserOrDefault,
|
||||
) -> AnyModelConfig:
|
||||
"""Attempt to reidentify a model by re-probing its weights file."""
|
||||
try:
|
||||
@@ -244,11 +246,13 @@ async def reidentify_model(
|
||||
raise InvalidModelException("Unable to identify model format")
|
||||
|
||||
# Retain user-editable fields from the original config
|
||||
result.config.path = config.path
|
||||
result.config.key = config.key
|
||||
result.config.name = config.name
|
||||
result.config.description = config.description
|
||||
result.config.cover_image = config.cover_image
|
||||
result.config.trigger_phrases = config.trigger_phrases
|
||||
if hasattr(config, "trigger_phrases") and hasattr(result.config, "trigger_phrases"):
|
||||
result.config.trigger_phrases = config.trigger_phrases
|
||||
result.config.source = config.source
|
||||
result.config.source_type = config.source_type
|
||||
|
||||
@@ -364,6 +368,7 @@ async def get_hugging_face_models(
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
|
||||
current_admin: AdminUserOrDefault,
|
||||
) -> AnyModelConfig:
|
||||
"""Update a model's config."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@@ -426,6 +431,7 @@ async def get_model_image(
|
||||
async def update_model_image(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
image: UploadFile,
|
||||
current_admin: AdminUserOrDefault,
|
||||
) -> None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
@@ -459,6 +465,7 @@ async def update_model_image(
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_model(
|
||||
current_admin: AdminUserOrDefault,
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
@@ -501,6 +508,7 @@ class BulkDeleteModelsResponse(BaseModel):
|
||||
status_code=200,
|
||||
)
|
||||
async def bulk_delete_models(
|
||||
current_admin: AdminUserOrDefault,
|
||||
request: BulkDeleteModelsRequest = Body(description="List of model keys to delete"),
|
||||
) -> BulkDeleteModelsResponse:
|
||||
"""
|
||||
@@ -542,6 +550,7 @@ async def bulk_delete_models(
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_model_image(
|
||||
current_admin: AdminUserOrDefault,
|
||||
key: str = Path(description="Unique key of model image to remove from model_images directory."),
|
||||
) -> None:
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@@ -567,6 +576,7 @@ async def delete_model_image(
|
||||
status_code=201,
|
||||
)
|
||||
async def install_model(
|
||||
current_admin: AdminUserOrDefault,
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
@@ -637,6 +647,7 @@ async def install_model(
|
||||
response_class=HTMLResponse,
|
||||
)
|
||||
async def install_hugging_face_model(
|
||||
current_admin: AdminUserOrDefault,
|
||||
source: str = Query(description="HuggingFace repo_id to install"),
|
||||
) -> HTMLResponse:
|
||||
"""Install a Hugging Face model using a string identifier."""
|
||||
@@ -809,7 +820,10 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
async def cancel_model_install_job(
|
||||
current_admin: AdminUserOrDefault,
|
||||
id: int = Path(description="Model install job ID"),
|
||||
) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
@@ -910,7 +924,7 @@ async def restart_model_install_file(
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
async def prune_model_install_jobs(current_admin: AdminUserOrDefault) -> Response:
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
@@ -930,6 +944,7 @@ async def prune_model_install_jobs() -> Response:
|
||||
},
|
||||
)
|
||||
async def convert_model(
|
||||
current_admin: AdminUserOrDefault,
|
||||
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
@@ -1111,7 +1126,7 @@ async def get_stats() -> Optional[CacheStats]:
|
||||
operation_id="empty_model_cache",
|
||||
status_code=200,
|
||||
)
|
||||
async def empty_model_cache() -> None:
|
||||
async def empty_model_cache(current_admin: AdminUserOrDefault) -> None:
|
||||
"""Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped."""
|
||||
# Request 1000GB of room in order to force the cache to drop all models.
|
||||
ApiDependencies.invoker.services.logger.info("Emptying model cache.")
|
||||
@@ -1128,11 +1143,11 @@ class HFTokenHelper:
|
||||
@classmethod
|
||||
def get_status(cls) -> HFTokenStatus:
|
||||
try:
|
||||
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
|
||||
# Valid token!
|
||||
return HFTokenStatus.VALID
|
||||
# No token set
|
||||
return HFTokenStatus.INVALID
|
||||
token = huggingface_hub.get_token()
|
||||
if not token:
|
||||
return HFTokenStatus.INVALID
|
||||
huggingface_hub.whoami(token=token)
|
||||
return HFTokenStatus.VALID
|
||||
except Exception:
|
||||
return HFTokenStatus.UNKNOWN
|
||||
|
||||
@@ -1161,6 +1176,7 @@ async def get_hf_login_status() -> HFTokenStatus:
|
||||
|
||||
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
|
||||
async def do_hf_login(
|
||||
current_admin: AdminUserOrDefault,
|
||||
token: str = Body(description="Hugging Face token to use for login", embed=True),
|
||||
) -> HFTokenStatus:
|
||||
HFTokenHelper.set_token(token)
|
||||
@@ -1173,7 +1189,7 @@ async def do_hf_login(
|
||||
|
||||
|
||||
@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
|
||||
async def reset_hf_token() -> HFTokenStatus:
|
||||
async def reset_hf_token(current_admin: AdminUserOrDefault) -> HFTokenStatus:
|
||||
return HFTokenHelper.reset_token()
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.api.auth_dependencies import AdminUserOrDefault, CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
@@ -24,6 +25,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
|
||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||
@@ -36,6 +38,40 @@ class SessionQueueAndProcessorStatus(BaseModel):
|
||||
processor: SessionProcessorStatus
|
||||
|
||||
|
||||
def sanitize_queue_item_for_user(
|
||||
queue_item: SessionQueueItem, current_user_id: str, is_admin: bool
|
||||
) -> SessionQueueItem:
|
||||
"""Sanitize queue item for non-admin users viewing other users' items.
|
||||
|
||||
For non-admin users viewing queue items belonging to other users,
|
||||
the field_values, session graph, and workflow should be hidden/cleared to protect privacy.
|
||||
|
||||
Args:
|
||||
queue_item: The queue item to sanitize
|
||||
current_user_id: The ID of the current user viewing the item
|
||||
is_admin: Whether the current user is an admin
|
||||
|
||||
Returns:
|
||||
The sanitized queue item (sensitive fields cleared if necessary)
|
||||
"""
|
||||
# Admins and item owners can see everything
|
||||
if is_admin or queue_item.user_id == current_user_id:
|
||||
return queue_item
|
||||
|
||||
# For non-admins viewing other users' items, clear sensitive fields
|
||||
# Create a shallow copy to avoid mutating the original
|
||||
sanitized_item = queue_item.model_copy(deep=False)
|
||||
sanitized_item.field_values = None
|
||||
sanitized_item.workflow = None
|
||||
# Clear the session graph by replacing it with an empty graph execution state
|
||||
# This prevents information leakage through the generation graph
|
||||
sanitized_item.session = GraphExecutionState(
|
||||
id=queue_item.session.id,
|
||||
graph=Graph(),
|
||||
)
|
||||
return sanitized_item
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_batch",
|
||||
operation_id="enqueue_batch",
|
||||
@@ -44,14 +80,15 @@ class SessionQueueAndProcessorStatus(BaseModel):
|
||||
},
|
||||
)
|
||||
async def enqueue_batch(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch: Batch = Body(description="Batch to process"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
"""Processes a batch and enqueues the output graphs for execution for the current user."""
|
||||
try:
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
queue_id=queue_id, batch=batch, prepend=prepend, user_id=current_user.user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
|
||||
@@ -65,15 +102,18 @@ async def enqueue_batch(
|
||||
},
|
||||
)
|
||||
async def list_all_queue_items(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
|
||||
items = ApiDependencies.invoker.services.session_queue.list_all_queue_items(
|
||||
queue_id=queue_id,
|
||||
destination=destination,
|
||||
)
|
||||
# Sanitize items for non-admin users
|
||||
return [sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin) for item in items]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
|
||||
|
||||
@@ -102,6 +142,7 @@ async def get_queue_item_ids(
|
||||
responses={200: {"model": list[SessionQueueItem]}},
|
||||
)
|
||||
async def get_queue_items_by_item_ids(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_ids: list[int] = Body(
|
||||
embed=True, description="Object containing list of queue item ids to fetch queue items for"
|
||||
@@ -118,7 +159,9 @@ async def get_queue_items_by_item_ids(
|
||||
queue_item = session_queue_service.get_queue_item(item_id=item_id)
|
||||
if queue_item.queue_id != queue_id: # Auth protection for items from other queues
|
||||
continue
|
||||
queue_items.append(queue_item)
|
||||
# Sanitize item for non-admin users
|
||||
sanitized_item = sanitize_queue_item_for_user(queue_item, current_user.user_id, current_user.is_admin)
|
||||
queue_items.append(sanitized_item)
|
||||
except Exception:
|
||||
# Skip missing queue items - they may have been deleted between item id fetch and queue item fetch
|
||||
continue
|
||||
@@ -134,9 +177,10 @@ async def get_queue_items_by_item_ids(
|
||||
responses={200: {"model": SessionProcessorStatus}},
|
||||
)
|
||||
async def resume(
|
||||
current_user: AdminUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Resumes session processor"""
|
||||
"""Resumes session processor. Admin only."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
except Exception as e:
|
||||
@@ -148,10 +192,11 @@ async def resume(
|
||||
operation_id="pause",
|
||||
responses={200: {"model": SessionProcessorStatus}},
|
||||
)
|
||||
async def Pause(
|
||||
async def pause(
|
||||
current_user: AdminUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Pauses session processor"""
|
||||
"""Pauses session processor. Admin only."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
except Exception as e:
|
||||
@@ -164,11 +209,16 @@ async def Pause(
|
||||
responses={200: {"model": CancelAllExceptCurrentResult}},
|
||||
)
|
||||
async def cancel_all_except_current(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> CancelAllExceptCurrentResult:
|
||||
"""Immediately cancels all queue items except in-processing items"""
|
||||
"""Immediately cancels all queue items except in-processing items. Non-admin users can only cancel their own items."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
|
||||
# Admin users can cancel all items, non-admin users can only cancel their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(
|
||||
queue_id=queue_id, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
|
||||
|
||||
@@ -179,11 +229,16 @@ async def cancel_all_except_current(
|
||||
responses={200: {"model": DeleteAllExceptCurrentResult}},
|
||||
)
|
||||
async def delete_all_except_current(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> DeleteAllExceptCurrentResult:
|
||||
"""Immediately deletes all queue items except in-processing items"""
|
||||
"""Immediately deletes all queue items except in-processing items. Non-admin users can only delete their own items."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
|
||||
# Admin users can delete all items, non-admin users can only delete their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(
|
||||
queue_id=queue_id, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
|
||||
|
||||
@@ -194,13 +249,16 @@ async def delete_all_except_current(
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
)
|
||||
async def cancel_by_batch_ids(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
|
||||
) -> CancelByBatchIDsResult:
|
||||
"""Immediately cancels all queue items from the given batch ids"""
|
||||
"""Immediately cancels all queue items from the given batch ids. Non-admin users can only cancel their own items."""
|
||||
try:
|
||||
# Admin users can cancel all items, non-admin users can only cancel their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
|
||||
queue_id=queue_id, batch_ids=batch_ids
|
||||
queue_id=queue_id, batch_ids=batch_ids, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
|
||||
@@ -212,13 +270,16 @@ async def cancel_by_batch_ids(
|
||||
responses={200: {"model": CancelByDestinationResult}},
|
||||
)
|
||||
async def cancel_by_destination(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
destination: str = Query(description="The destination to cancel all queue items for"),
|
||||
) -> CancelByDestinationResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
"""Immediately cancels all queue items with the given destination. Non-admin users can only cancel their own items."""
|
||||
try:
|
||||
# Admin users can cancel all items, non-admin users can only cancel their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
queue_id=queue_id, destination=destination, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
|
||||
@@ -230,12 +291,28 @@ async def cancel_by_destination(
|
||||
responses={200: {"model": RetryItemsResult}},
|
||||
)
|
||||
async def retry_items_by_id(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_ids: list[int] = Body(description="The queue item ids to retry"),
|
||||
) -> RetryItemsResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
"""Retries the given queue items. Users can only retry their own items unless they are an admin."""
|
||||
try:
|
||||
# Check authorization: user must own all items or be an admin
|
||||
if not current_user.is_admin:
|
||||
for item_id in item_ids:
|
||||
try:
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
if queue_item.user_id != current_user.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail=f"You do not have permission to retry queue item {item_id}"
|
||||
)
|
||||
except SessionQueueItemNotFoundError:
|
||||
# Skip items that don't exist - they will be handled by retry_items_by_id
|
||||
continue
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
|
||||
|
||||
@@ -248,15 +325,25 @@ async def retry_items_by_id(
|
||||
},
|
||||
)
|
||||
async def clear(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> ClearResult:
|
||||
"""Clears the queue entirely, immediately canceling the currently-executing session"""
|
||||
"""Clears the queue entirely. Admin users clear all items; non-admin users only clear their own items. If there's a currently-executing item, users can only cancel it if they own it or are an admin."""
|
||||
try:
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
# Check authorization for canceling the current item
|
||||
if queue_item.user_id != current_user.user_id and not current_user.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You do not have permission to cancel the currently executing queue item"
|
||||
)
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
# Admin users can clear all items, non-admin users can only clear their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id, user_id=user_id)
|
||||
return clear_result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
|
||||
|
||||
@@ -269,11 +356,14 @@ async def clear(
|
||||
},
|
||||
)
|
||||
async def prune(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> PruneResult:
|
||||
"""Prunes all completed or errored queue items"""
|
||||
"""Prunes all completed or errored queue items. Non-admin users can only prune their own items."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
# Admin users can prune all items, non-admin users can only prune their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id, user_id=user_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
|
||||
|
||||
@@ -320,11 +410,12 @@ async def get_next_queue_item(
|
||||
},
|
||||
)
|
||||
async def get_queue_status(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionQueueAndProcessorStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
except Exception as e:
|
||||
@@ -358,6 +449,7 @@ async def get_batch_status(
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def get_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to get"),
|
||||
) -> SessionQueueItem:
|
||||
@@ -366,7 +458,8 @@ async def get_queue_item(
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id=item_id)
|
||||
if queue_item.queue_id != queue_id:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
return queue_item
|
||||
# Sanitize item for non-admin users
|
||||
return sanitize_queue_item_for_user(queue_item, current_user.user_id, current_user.is_admin)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except Exception as e:
|
||||
@@ -378,12 +471,24 @@ async def get_queue_item(
|
||||
operation_id="delete_queue_item",
|
||||
)
|
||||
async def delete_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to delete"),
|
||||
) -> None:
|
||||
"""Deletes a queue item"""
|
||||
"""Deletes a queue item. Users can only delete their own items unless they are an admin."""
|
||||
try:
|
||||
# Get the queue item to check ownership
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
|
||||
# Check authorization: user must own the item or be an admin
|
||||
if queue_item.user_id != current_user.user_id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="You do not have permission to delete this queue item")
|
||||
|
||||
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
|
||||
|
||||
@@ -396,14 +501,24 @@ async def delete_queue_item(
|
||||
},
|
||||
)
|
||||
async def cancel_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to cancel"),
|
||||
) -> SessionQueueItem:
|
||||
"""Deletes a queue item"""
|
||||
"""Cancels a queue item. Users can only cancel their own items unless they are an admin."""
|
||||
try:
|
||||
# Get the queue item to check ownership
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
|
||||
# Check authorization: user must own the item or be an admin
|
||||
if queue_item.user_id != current_user.user_id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="You do not have permission to cancel this queue item")
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
|
||||
|
||||
@@ -432,13 +547,16 @@ async def counts_by_destination(
|
||||
responses={200: {"model": DeleteByDestinationResult}},
|
||||
)
|
||||
async def delete_by_destination(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to query"),
|
||||
destination: str = Path(description="The destination to query"),
|
||||
) -> DeleteByDestinationResult:
|
||||
"""Deletes all items with the given destination"""
|
||||
"""Deletes all items with the given destination. Non-admin users can only delete their own items."""
|
||||
try:
|
||||
# Admin users can delete all items, non-admin users can only delete their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
queue_id=queue_id, destination=destination, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from invokeai.app.services.auth.token_service import verify_token
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
@@ -38,6 +39,9 @@ from invokeai.app.services.events.events_common import (
|
||||
RecallParametersUpdatedEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class QueueSubscriptionEvent(BaseModel):
|
||||
@@ -96,6 +100,13 @@ class SocketIO:
|
||||
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self._app)
|
||||
|
||||
# Track user information for each socket connection
|
||||
self._socket_users: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# Set up authentication middleware
|
||||
self._sio.on("connect", handler=self._handle_connect)
|
||||
self._sio.on("disconnect", handler=self._handle_disconnect)
|
||||
|
||||
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
|
||||
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
|
||||
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
@@ -105,8 +116,83 @@ class SocketIO:
|
||||
register_events(MODEL_EVENTS, self._handle_model_event)
|
||||
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
|
||||
|
||||
async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> bool:
|
||||
"""Handle socket connection and authenticate the user.
|
||||
|
||||
Returns True to accept the connection, False to reject it.
|
||||
Stores user_id in the internal socket users dict for later use.
|
||||
"""
|
||||
# Extract token from auth data or headers
|
||||
token = None
|
||||
if auth and isinstance(auth, dict):
|
||||
token = auth.get("token")
|
||||
|
||||
if not token and environ:
|
||||
# Try to get token from headers
|
||||
headers = environ.get("HTTP_AUTHORIZATION", "")
|
||||
if headers.startswith("Bearer "):
|
||||
token = headers[7:]
|
||||
|
||||
# Verify the token
|
||||
if token:
|
||||
token_data = verify_token(token)
|
||||
if token_data:
|
||||
# Store user_id and is_admin in socket users dict
|
||||
self._socket_users[sid] = {
|
||||
"user_id": token_data.user_id,
|
||||
"is_admin": token_data.is_admin,
|
||||
}
|
||||
logger.info(
|
||||
f"Socket {sid} connected with user_id: {token_data.user_id}, is_admin: {token_data.is_admin}"
|
||||
)
|
||||
return True
|
||||
|
||||
# If no valid token, store system user for backward compatibility
|
||||
self._socket_users[sid] = {
|
||||
"user_id": "system",
|
||||
"is_admin": False,
|
||||
}
|
||||
logger.debug(f"Socket {sid} connected as system user (no valid token)")
|
||||
return True
|
||||
|
||||
async def _handle_disconnect(self, sid: str) -> None:
|
||||
"""Handle socket disconnection and cleanup user info."""
|
||||
if sid in self._socket_users:
|
||||
del self._socket_users[sid]
|
||||
logger.debug(f"Socket {sid} disconnected and cleaned up")
|
||||
|
||||
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
"""Handle queue subscription and add socket to both queue and user-specific rooms."""
|
||||
queue_id = QueueSubscriptionEvent(**data).queue_id
|
||||
|
||||
# Check if we have user info for this socket
|
||||
if sid not in self._socket_users:
|
||||
logger.warning(
|
||||
f"Socket {sid} subscribing to queue {queue_id} but has no user info - need to authenticate via connect event"
|
||||
)
|
||||
# Store as system user temporarily - real auth should happen in connect
|
||||
self._socket_users[sid] = {
|
||||
"user_id": "system",
|
||||
"is_admin": False,
|
||||
}
|
||||
|
||||
user_id = self._socket_users[sid]["user_id"]
|
||||
is_admin = self._socket_users[sid]["is_admin"]
|
||||
|
||||
# Add socket to the queue room
|
||||
await self._sio.enter_room(sid, queue_id)
|
||||
|
||||
# Also add socket to a user-specific room for event filtering
|
||||
user_room = f"user:{user_id}"
|
||||
await self._sio.enter_room(sid, user_room)
|
||||
|
||||
# If admin, also add to admin room to receive all events
|
||||
if is_admin:
|
||||
await self._sio.enter_room(sid, "admin")
|
||||
|
||||
logger.debug(
|
||||
f"Socket {sid} (user_id: {user_id}, is_admin: {is_admin}) subscribed to queue {queue_id} and user room {user_room}"
|
||||
)
|
||||
|
||||
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
@@ -118,7 +204,62 @@ class SocketIO:
|
||||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
|
||||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
|
||||
"""Handle queue events with user isolation.
|
||||
|
||||
Invocation events (progress, started, complete) are private - only emit to owner and admins.
|
||||
Queue item status events are public - emit to all users (field values hidden via API).
|
||||
Other queue events emit to all subscribers.
|
||||
|
||||
IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase
|
||||
inherits from QueueItemEventBase. The order of isinstance checks matters!
|
||||
"""
|
||||
try:
|
||||
event_name, event_data = event
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from invokeai.app.services.events.events_common import InvocationEventBase, QueueItemEventBase
|
||||
|
||||
# Check InvocationEventBase FIRST (before QueueItemEventBase) since it's a subclass
|
||||
# Invocation events (progress, started, complete, error) are private to owner + admins
|
||||
if isinstance(event_data, InvocationEventBase) and hasattr(event_data, "user_id"):
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
|
||||
# Emit to the user's room
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
|
||||
# Also emit to admin room so admins can see all events, but strip image preview data
|
||||
# from InvocationProgressEvent to prevent admins from seeing other users' image content
|
||||
if isinstance(event_data, InvocationProgressEvent):
|
||||
admin_event_data = event_data.model_copy(update={"image": None})
|
||||
await self._sio.emit(event=event_name, data=admin_event_data.model_dump(mode="json"), room="admin")
|
||||
else:
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
|
||||
logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room")
|
||||
|
||||
# Queue item status events are visible to all users (field values masked via API)
|
||||
# This catches QueueItemStatusChangedEvent but NOT InvocationEvents (already handled above)
|
||||
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"):
|
||||
# Emit to all subscribers in the queue
|
||||
await self._sio.emit(
|
||||
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Emitted public queue item event {event_name} to all subscribers in queue {event_data.queue_id}"
|
||||
)
|
||||
|
||||
else:
|
||||
# For other queue events (like QueueClearedEvent, BatchEnqueuedEvent), emit to all subscribers
|
||||
await self._sio.emit(
|
||||
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
|
||||
)
|
||||
logger.info(
|
||||
f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Log any unhandled exceptions in event handling to prevent silent failures
|
||||
logger.error(f"Error handling queue event {event[0]}: {e}", exc_info=True)
|
||||
|
||||
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
|
||||
|
||||
@@ -17,6 +17,7 @@ from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.api.routers import (
|
||||
app_info,
|
||||
auth,
|
||||
board_images,
|
||||
boards,
|
||||
client_state,
|
||||
@@ -122,6 +123,8 @@ app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
|
||||
# Include all routers
|
||||
# Authentication router should be first so it's registered before protected routes
|
||||
app.include_router(auth.auth_router, prefix="/api")
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
|
||||
@@ -45,7 +45,7 @@ KLEIN_MAX_SEQ_LEN = 512
|
||||
title="Prompt - Flux2 Klein",
|
||||
tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class Flux2KleinTextEncoderInvocation(BaseInvocation):
|
||||
@@ -73,140 +73,111 @@ class Flux2KleinTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
qwen3_embeds, pooled_embeds = self._encode_prompt(context)
|
||||
# Open the exitstack here to lock models for the duration of the node
|
||||
with ExitStack() as exit_stack:
|
||||
# Pass the locked stack down to the helper function
|
||||
qwen3_embeds, pooled_embeds = self._encode_prompt(context, exit_stack)
|
||||
|
||||
# Use FLUXConditioningInfo for compatibility with existing Flux denoiser
|
||||
# t5_embeds -> qwen3 stacked embeddings
|
||||
# clip_embeds -> pooled qwen3 embedding
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=pooled_embeds, t5_embeds=qwen3_embeds)]
|
||||
)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=pooled_embeds, t5_embeds=qwen3_embeds)]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput(
|
||||
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
|
||||
)
|
||||
# The models are still locked while we save the data
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput(
|
||||
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
|
||||
)
|
||||
|
||||
def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encode prompt using Qwen3 text encoder with Klein-style layer extraction.
|
||||
|
||||
This matches the diffusers Flux2KleinPipeline._get_qwen3_prompt_embeds() exactly.
|
||||
|
||||
Returns:
|
||||
Tuple of (stacked_embeddings, pooled_embedding):
|
||||
- stacked_embeddings: Hidden states from layers (9, 18, 27) stacked together.
|
||||
Shape: (1, seq_len, hidden_size * 3)
|
||||
- pooled_embedding: Pooled representation for global conditioning.
|
||||
Shape: (1, hidden_size)
|
||||
"""
|
||||
def _encode_prompt(self, context: InvocationContext, exit_stack: ExitStack) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
prompt = self.prompt
|
||||
|
||||
# Reordered loading to prevent the annoying cache drop issue
|
||||
# This prevents it from being evicted while we look up the tokenizer
|
||||
text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
|
||||
(cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
|
||||
|
||||
# Now it is safe to load and lock the tokenizer
|
||||
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
|
||||
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
(cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
|
||||
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
|
||||
device = text_encoder.device
|
||||
|
||||
# you can now define the device, as the text_encoder exists here
|
||||
device = text_encoder.device
|
||||
# Apply LoRA models
|
||||
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=text_encoder,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_T5_PREFIX,
|
||||
dtype=lora_dtype,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply LoRA models to the text encoder
|
||||
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=text_encoder,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_T5_PREFIX, # Reuse T5 prefix for Qwen3 LoRAs
|
||||
dtype=lora_dtype,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
context.util.signal_progress("Running Qwen3 text encoder (Klein)")
|
||||
|
||||
if not isinstance(text_encoder, PreTrainedModel):
|
||||
raise TypeError(
|
||||
f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
|
||||
"The Qwen3 encoder model may be corrupted or incompatible."
|
||||
)
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
||||
raise TypeError(
|
||||
f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
|
||||
"The Qwen3 tokenizer may be corrupted or incompatible."
|
||||
)
|
||||
|
||||
context.util.signal_progress("Running Qwen3 text encoder (Klein)")
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
if not isinstance(text_encoder, PreTrainedModel):
|
||||
raise TypeError(
|
||||
f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
|
||||
"The Qwen3 encoder model may be corrupted or incompatible."
|
||||
)
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
||||
raise TypeError(
|
||||
f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
|
||||
"The Qwen3 tokenizer may be corrupted or incompatible."
|
||||
)
|
||||
text: str = tokenizer.apply_chat_template( # type: ignore[assignment]
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
# Format messages exactly like diffusers Flux2KleinPipeline:
|
||||
# - Only user message, NO system message
|
||||
# - add_generation_prompt=True (adds assistant prefix)
|
||||
# - enable_thinking=False
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.max_seq_len,
|
||||
)
|
||||
|
||||
# Step 1: Apply chat template to get formatted text (tokenize=False)
|
||||
text: str = tokenizer.apply_chat_template( # type: ignore[assignment]
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True, # Adds assistant prefix like diffusers
|
||||
enable_thinking=False, # Disable thinking mode
|
||||
input_ids = inputs["input_ids"].to(device)
|
||||
attention_mask = inputs["attention_mask"].to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
outputs = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
|
||||
raise RuntimeError(
|
||||
"Text encoder did not return hidden_states. "
|
||||
"Ensure output_hidden_states=True is supported by this model."
|
||||
)
|
||||
num_hidden_layers = len(outputs.hidden_states)
|
||||
|
||||
# Step 2: Tokenize the formatted text
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.max_seq_len,
|
||||
)
|
||||
hidden_states_list = []
|
||||
for layer_idx in KLEIN_EXTRACTION_LAYERS:
|
||||
if layer_idx >= num_hidden_layers:
|
||||
layer_idx = num_hidden_layers - 1
|
||||
hidden_states_list.append(outputs.hidden_states[layer_idx])
|
||||
|
||||
input_ids = inputs["input_ids"].to(device)
|
||||
attention_mask = inputs["attention_mask"].to(device)
|
||||
out = torch.stack(hidden_states_list, dim=1)
|
||||
out = out.to(dtype=text_encoder.dtype, device=device)
|
||||
|
||||
# Move to device
|
||||
input_ids = input_ids.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
# Forward pass through the model - matching diffusers exactly
|
||||
# Explicitly move inputs to the same device as the text_encoder
|
||||
outputs = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
# Validate hidden_states output
|
||||
if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
|
||||
raise RuntimeError(
|
||||
"Text encoder did not return hidden_states. "
|
||||
"Ensure output_hidden_states=True is supported by this model."
|
||||
)
|
||||
num_hidden_layers = len(outputs.hidden_states)
|
||||
|
||||
# Extract and stack hidden states - EXACTLY like diffusers:
|
||||
# out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||
# prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
hidden_states_list = []
|
||||
for layer_idx in KLEIN_EXTRACTION_LAYERS:
|
||||
if layer_idx >= num_hidden_layers:
|
||||
layer_idx = num_hidden_layers - 1
|
||||
hidden_states_list.append(outputs.hidden_states[layer_idx])
|
||||
|
||||
# Stack along dim=1, then permute and reshape - exactly like diffusers
|
||||
out = torch.stack(hidden_states_list, dim=1)
|
||||
out = out.to(dtype=text_encoder.dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
# Create pooled embedding for global conditioning
|
||||
# Use mean pooling over the sequence (excluding padding)
|
||||
# This serves a similar role to CLIP's pooled output in standard FLUX
|
||||
last_hidden_state = outputs.hidden_states[-1] # Use last layer for pooling
|
||||
# Expand mask to match hidden state dimensions
|
||||
expanded_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
|
||||
sum_embeds = (last_hidden_state * expanded_mask).sum(dim=1)
|
||||
num_tokens = expanded_mask.sum(dim=1).clamp(min=1)
|
||||
pooled_embeds = sum_embeds / num_tokens
|
||||
last_hidden_state = outputs.hidden_states[-1]
|
||||
expanded_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
|
||||
sum_embeds = (last_hidden_state * expanded_mask).sum(dim=1)
|
||||
num_tokens = expanded_mask.sum(dim=1).clamp(min=1)
|
||||
pooled_embeds = sum_embeds / num_tokens
|
||||
|
||||
return prompt_embeds, pooled_embeds
|
||||
|
||||
|
||||
@@ -9,6 +9,11 @@ def get_app():
|
||||
|
||||
def run_app() -> None:
|
||||
"""The main entrypoint for the app."""
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
# Parse the CLI arguments before doing anything else, which ensures CLI args correctly override settings from other
|
||||
@@ -100,4 +105,41 @@ def run_app() -> None:
|
||||
for hdlr in logger.handlers:
|
||||
uvicorn_logger.addHandler(hdlr)
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
try:
|
||||
loop.run_until_complete(server.serve())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("InvokeAI shutting down...")
|
||||
# Gracefully shut down services (e.g. model download and install managers) so that any
|
||||
# active work is completed or cleanly cancelled before the process exits.
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
# Cancel any pending asyncio tasks (e.g. socket.io ping tasks) so that loop.close() does
|
||||
# not emit "Task was destroyed but it is pending!" warnings for each one.
|
||||
pending = [t for t in asyncio.all_tasks(loop) if not t.done()]
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
|
||||
# Shut down the asyncio default thread executor. asyncio.to_thread() (used e.g. in the
|
||||
# session queue for SQLite operations during generation) creates non-daemon threads via the
|
||||
# event loop's default ThreadPoolExecutor. Without this call those threads remain alive and
|
||||
# cause threading._shutdown() to hang indefinitely after the process's main code finishes.
|
||||
loop.run_until_complete(loop.shutdown_default_executor())
|
||||
loop.close()
|
||||
|
||||
# After graceful shutdown, log any non-daemon threads that are still alive. These are the
|
||||
# threads that will cause Python's threading._shutdown() to block, preventing the process
|
||||
# from exiting cleanly. This helps identify threads that need to be fixed or joined.
|
||||
frames = sys._current_frames()
|
||||
for thread in threading.enumerate():
|
||||
if thread.daemon or thread is threading.main_thread():
|
||||
continue
|
||||
frame = frames.get(thread.ident)
|
||||
stack = "".join(traceback.format_stack(frame)) if frame else "(no frame available)"
|
||||
logger.warning(
|
||||
f"Non-daemon thread still alive after shutdown: {thread.name!r} "
|
||||
f"(ident={thread.ident})\nStack trace:\n{stack}"
|
||||
)
|
||||
|
||||
5
invokeai/app/services/app_settings/__init__.py
Normal file
5
invokeai/app/services/app_settings/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""App settings service exports."""
|
||||
|
||||
from invokeai.app.services.app_settings.app_settings_service import AppSettingsService
|
||||
|
||||
__all__ = ["AppSettingsService"]
|
||||
74
invokeai/app/services/app_settings/app_settings_service.py
Normal file
74
invokeai/app/services/app_settings/app_settings_service.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Service for managing application-level settings stored in the database."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class AppSettingsService:
|
||||
"""Service for accessing application-level settings from the database.
|
||||
|
||||
This service provides a simple key-value store for application-level configuration
|
||||
that needs to be persisted across restarts, such as JWT secrets.
|
||||
"""
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
"""Initialize the app settings service.
|
||||
|
||||
Args:
|
||||
db: The SQLite database instance
|
||||
"""
|
||||
self._db = db
|
||||
|
||||
def get(self, key: str) -> Optional[str]:
|
||||
"""Get a setting value by key.
|
||||
|
||||
Args:
|
||||
key: The setting key
|
||||
|
||||
Returns:
|
||||
The setting value if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute("SELECT value FROM app_settings WHERE key = ?;", (key,))
|
||||
row = cursor.fetchone()
|
||||
return row[0] if row else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def set(self, key: str, value: str) -> None:
|
||||
"""Set a setting value.
|
||||
|
||||
Args:
|
||||
key: The setting key
|
||||
value: The setting value
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO app_settings (key, value)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
value = excluded.value,
|
||||
updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW');
|
||||
""",
|
||||
(key, value),
|
||||
)
|
||||
|
||||
def get_jwt_secret(self) -> str:
|
||||
"""Get the JWT secret key from the database.
|
||||
|
||||
Returns:
|
||||
The JWT secret key
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the JWT secret is not found in the database
|
||||
"""
|
||||
secret = self.get("jwt_secret")
|
||||
if secret is None:
|
||||
raise RuntimeError(
|
||||
"JWT secret not found in database. This should have been created during database migration. "
|
||||
"Please ensure database migrations have been run successfully."
|
||||
)
|
||||
return secret
|
||||
1
invokeai/app/services/auth/__init__.py
Normal file
1
invokeai/app/services/auth/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Authentication service module."""
|
||||
86
invokeai/app/services/auth/password_utils.py
Normal file
86
invokeai/app/services/auth/password_utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Password hashing and validation utilities."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from passlib.context import CryptContext
|
||||
|
||||
# Configure bcrypt context - set truncate_error=False to allow passwords >72 bytes
|
||||
# without raising an error. They will be automatically truncated by bcrypt to 72 bytes.
|
||||
pwd_context = CryptContext(
|
||||
schemes=["bcrypt"],
|
||||
deprecated="auto",
|
||||
bcrypt__truncate_error=False,
|
||||
)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt.
|
||||
|
||||
bcrypt has a maximum password length of 72 bytes. Longer passwords
|
||||
are automatically truncated to comply with this limit.
|
||||
|
||||
Args:
|
||||
password: The plain text password to hash
|
||||
|
||||
Returns:
|
||||
The hashed password
|
||||
"""
|
||||
# bcrypt has a 72 byte limit - encode and truncate if necessary
|
||||
password_bytes = password.encode("utf-8")
|
||||
if len(password_bytes) > 72:
|
||||
# Truncate to 72 bytes and decode back, dropping incomplete UTF-8 sequences
|
||||
password = password_bytes[:72].decode("utf-8", errors="ignore")
|
||||
return cast(str, pwd_context.hash(password))
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash.
|
||||
|
||||
bcrypt has a maximum password length of 72 bytes. Longer passwords
|
||||
are automatically truncated to match hash_password behavior.
|
||||
|
||||
Args:
|
||||
plain_password: The plain text password to verify
|
||||
hashed_password: The hashed password to verify against
|
||||
|
||||
Returns:
|
||||
True if the password matches the hash, False otherwise
|
||||
"""
|
||||
try:
|
||||
# bcrypt has a 72 byte limit - encode and truncate if necessary to match hash_password
|
||||
password_bytes = plain_password.encode("utf-8")
|
||||
if len(password_bytes) > 72:
|
||||
# Truncate to 72 bytes and decode back, dropping incomplete UTF-8 sequences
|
||||
plain_password = password_bytes[:72].decode("utf-8", errors="ignore")
|
||||
return cast(bool, pwd_context.verify(plain_password, hashed_password))
|
||||
except Exception:
|
||||
# Invalid hash format or other error - return False
|
||||
return False
|
||||
|
||||
|
||||
def validate_password_strength(password: str) -> tuple[bool, str]:
|
||||
"""Validate password meets minimum security requirements.
|
||||
|
||||
Password requirements:
|
||||
- At least 8 characters long
|
||||
- Contains at least one uppercase letter
|
||||
- Contains at least one lowercase letter
|
||||
- Contains at least one digit
|
||||
|
||||
Args:
|
||||
password: The password to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (is_valid, error_message). If valid, error_message is empty.
|
||||
"""
|
||||
if len(password) < 8:
|
||||
return False, "Password must be at least 8 characters long"
|
||||
|
||||
has_upper = any(c.isupper() for c in password)
|
||||
has_lower = any(c.islower() for c in password)
|
||||
has_digit = any(c.isdigit() for c in password)
|
||||
|
||||
if not (has_upper and has_lower and has_digit):
|
||||
return False, "Password must contain uppercase, lowercase, and numbers"
|
||||
|
||||
return True, ""
|
||||
105
invokeai/app/services/auth/token_service.py
Normal file
105
invokeai/app/services/auth/token_service.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""JWT token generation and validation."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
DEFAULT_EXPIRATION_HOURS = 24
|
||||
|
||||
# Module-level variable to store the JWT secret. This is set during application initialization
|
||||
# by calling set_jwt_secret(). The secret is loaded from the database where it is stored
|
||||
# securely after being generated during database migration.
|
||||
_jwt_secret: str | None = None
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
"""Data stored in JWT token."""
|
||||
|
||||
user_id: str
|
||||
email: str
|
||||
is_admin: bool
|
||||
|
||||
|
||||
def set_jwt_secret(secret: str) -> None:
|
||||
"""Set the JWT secret key for token signing and verification.
|
||||
|
||||
This should be called once during application initialization with the secret
|
||||
loaded from the database.
|
||||
|
||||
Args:
|
||||
secret: The JWT secret key
|
||||
"""
|
||||
global _jwt_secret
|
||||
_jwt_secret = secret
|
||||
|
||||
|
||||
def get_jwt_secret() -> str:
|
||||
"""Get the JWT secret key.
|
||||
|
||||
Returns:
|
||||
The JWT secret key
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the secret has not been initialized
|
||||
"""
|
||||
if _jwt_secret is None:
|
||||
raise RuntimeError("JWT secret has not been initialized. Call set_jwt_secret() during application startup.")
|
||||
return _jwt_secret
|
||||
|
||||
|
||||
def create_access_token(data: TokenData, expires_delta: timedelta | None = None) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
data: The token data to encode
|
||||
expires_delta: Optional expiration time delta. Defaults to 24 hours.
|
||||
|
||||
Returns:
|
||||
The encoded JWT token
|
||||
"""
|
||||
to_encode = data.model_dump()
|
||||
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(hours=DEFAULT_EXPIRATION_HOURS))
|
||||
to_encode.update({"exp": expire})
|
||||
return cast(str, jwt.encode(to_encode, get_jwt_secret(), algorithm=ALGORITHM))
|
||||
|
||||
|
||||
def verify_token(token: str) -> TokenData | None:
|
||||
"""Verify and decode a JWT token.
|
||||
|
||||
Args:
|
||||
token: The JWT token to verify
|
||||
|
||||
Returns:
|
||||
TokenData if valid, None if invalid or expired
|
||||
"""
|
||||
try:
|
||||
# python-jose 3.5.0 has a bug where exp verification doesn't work properly
|
||||
# We need to manually check expiration, but MUST verify signature first
|
||||
# to prevent accepting tokens with valid payloads but invalid signatures
|
||||
|
||||
# First, verify the signature - this will raise JWTError if signature is invalid
|
||||
# Note: python-jose won't reject expired tokens here due to the bug
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
get_jwt_secret(),
|
||||
algorithms=[ALGORITHM],
|
||||
)
|
||||
|
||||
# Now manually check expiration (because python-jose 3.5.0 doesn't do this properly)
|
||||
if "exp" in payload:
|
||||
exp_timestamp = payload["exp"]
|
||||
current_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
if current_timestamp >= exp_timestamp:
|
||||
# Token is expired
|
||||
return None
|
||||
|
||||
return TokenData(**payload)
|
||||
except JWTError:
|
||||
# Token is invalid (bad signature, malformed, etc.)
|
||||
return None
|
||||
except Exception:
|
||||
# Catch any other exceptions (e.g., Pydantic validation errors)
|
||||
return None
|
||||
@@ -17,8 +17,9 @@ class BoardRecordStorageBase(ABC):
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
user_id: str,
|
||||
) -> BoardRecord:
|
||||
"""Saves a board record."""
|
||||
"""Saves a board record for a specific user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -41,18 +42,25 @@ class BoardRecordStorageBase(ABC):
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
order_by: BoardRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
"""Gets many board records."""
|
||||
"""Gets many board records for a specific user, including shared boards. Admin users see all boards."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
order_by: BoardRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
include_archived: bool = False,
|
||||
) -> list[BoardRecord]:
|
||||
"""Gets all board records."""
|
||||
"""Gets all board records for a specific user, including shared boards. Admin users see all boards."""
|
||||
pass
|
||||
|
||||
@@ -16,6 +16,8 @@ class BoardRecord(BaseModelExcludeNull):
|
||||
"""The unique ID of the board."""
|
||||
board_name: str = Field(description="The name of the board.")
|
||||
"""The name of the board."""
|
||||
user_id: str = Field(description="The user ID of the board owner.")
|
||||
"""The user ID of the board owner."""
|
||||
created_at: Union[datetime, str] = Field(description="The created timestamp of the board.")
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
||||
@@ -35,6 +37,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
|
||||
board_id = board_dict.get("board_id", "unknown")
|
||||
board_name = board_dict.get("board_name", "unknown")
|
||||
# Default to 'system' for backwards compatibility with boards created before multiuser support
|
||||
user_id = board_dict.get("user_id", "system")
|
||||
cover_image_name = board_dict.get("cover_image_name", "unknown")
|
||||
created_at = board_dict.get("created_at", get_iso_timestamp())
|
||||
updated_at = board_dict.get("updated_at", get_iso_timestamp())
|
||||
@@ -44,6 +48,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
return BoardRecord(
|
||||
board_id=board_id,
|
||||
board_name=board_name,
|
||||
user_id=user_id,
|
||||
cover_image_name=cover_image_name,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
|
||||
@@ -38,16 +38,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
user_id: str,
|
||||
) -> BoardRecord:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name, user_id)
|
||||
VALUES (?, ?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
(board_id, board_name, user_id),
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordSaveException from e
|
||||
@@ -121,6 +122,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
order_by: BoardRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
offset: int = 0,
|
||||
@@ -128,74 +131,147 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
# Build base query - admins see all boards, regular users see owned, shared, or public boards
|
||||
if is_admin:
|
||||
base_query = """
|
||||
SELECT DISTINCT boards.*
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
# Determine archived filter condition
|
||||
archived_filter = "WHERE 1=1" if include_archived else "WHERE boards.archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
else:
|
||||
base_query = """
|
||||
SELECT DISTINCT boards.*
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "AND boards.archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (user_id, user_id, limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
# Determine count query - admins count all boards, regular users count accessible boards
|
||||
if is_admin:
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(DISTINCT boards.board_id)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(DISTINCT boards.board_id)
|
||||
FROM boards
|
||||
WHERE archived = 0;
|
||||
WHERE boards.archived = 0;
|
||||
"""
|
||||
cursor.execute(count_query)
|
||||
else:
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(DISTINCT boards.board_id)
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1);
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(DISTINCT boards.board_id)
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
|
||||
AND boards.archived = 0;
|
||||
"""
|
||||
|
||||
# Execute count query
|
||||
cursor.execute(count_query)
|
||||
# Execute count query
|
||||
cursor.execute(count_query, (user_id, user_id))
|
||||
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
order_by: BoardRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
include_archived: bool = False,
|
||||
) -> list[BoardRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
# Build query - admins see all boards, regular users see owned, shared, or public boards
|
||||
if is_admin:
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT DISTINCT boards.*
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(board_name) {direction}
|
||||
ORDER BY LOWER(boards.board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
SELECT *
|
||||
else:
|
||||
base_query = """
|
||||
SELECT DISTINCT boards.*
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
archived_filter = "WHERE 1=1" if include_archived else "WHERE boards.archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
cursor.execute(final_query)
|
||||
cursor.execute(final_query)
|
||||
else:
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT DISTINCT boards.*
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(boards.board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
SELECT DISTINCT boards.*
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
|
||||
archived_filter = "" if include_archived else "AND boards.archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
cursor.execute(final_query, (user_id, user_id))
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
@@ -13,8 +13,9 @@ class BoardServiceABC(ABC):
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
user_id: str,
|
||||
) -> BoardDTO:
|
||||
"""Creates a board."""
|
||||
"""Creates a board for a specific user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -45,18 +46,25 @@ class BoardServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
order_by: BoardRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
"""Gets many boards."""
|
||||
"""Gets many boards for a specific user, including shared boards. Admin users see all boards."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
order_by: BoardRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
include_archived: bool = False,
|
||||
) -> list[BoardDTO]:
|
||||
"""Gets all boards."""
|
||||
"""Gets all boards for a specific user, including shared boards. Admin users see all boards."""
|
||||
pass
|
||||
|
||||
@@ -14,10 +14,16 @@ class BoardDTO(BoardRecord):
|
||||
"""The number of images in the board."""
|
||||
asset_count: int = Field(description="The number of assets in the board.")
|
||||
"""The number of assets in the board."""
|
||||
owner_username: Optional[str] = Field(default=None, description="The username of the board owner (for admin view).")
|
||||
"""The username of the board owner (for admin view)."""
|
||||
|
||||
|
||||
def board_record_to_dto(
|
||||
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int, asset_count: int
|
||||
board_record: BoardRecord,
|
||||
cover_image_name: Optional[str],
|
||||
image_count: int,
|
||||
asset_count: int,
|
||||
owner_username: Optional[str] = None,
|
||||
) -> BoardDTO:
|
||||
"""Converts a board record to a board DTO."""
|
||||
return BoardDTO(
|
||||
@@ -25,4 +31,5 @@ def board_record_to_dto(
|
||||
cover_image_name=cover_image_name,
|
||||
image_count=image_count,
|
||||
asset_count=asset_count,
|
||||
owner_username=owner_username,
|
||||
)
|
||||
|
||||
@@ -15,8 +15,9 @@ class BoardService(BoardServiceABC):
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
user_id: str,
|
||||
) -> BoardDTO:
|
||||
board_record = self.__invoker.services.board_records.save(board_name)
|
||||
board_record = self.__invoker.services.board_records.save(board_name, user_id)
|
||||
return board_record_to_dto(board_record, None, 0, 0)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
@@ -51,6 +52,8 @@ class BoardService(BoardServiceABC):
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
order_by: BoardRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
offset: int = 0,
|
||||
@@ -58,7 +61,7 @@ class BoardService(BoardServiceABC):
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
board_records = self.__invoker.services.board_records.get_many(
|
||||
order_by, direction, offset, limit, include_archived
|
||||
user_id, is_admin, order_by, direction, offset, limit, include_archived
|
||||
)
|
||||
board_dtos = []
|
||||
for r in board_records.items:
|
||||
@@ -70,14 +73,29 @@ class BoardService(BoardServiceABC):
|
||||
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
asset_count = self.__invoker.services.board_image_records.get_asset_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count))
|
||||
|
||||
# For admin users, include owner username
|
||||
owner_username = None
|
||||
if is_admin:
|
||||
owner = self.__invoker.services.users.get(r.user_id)
|
||||
if owner:
|
||||
owner_username = owner.display_name or owner.email
|
||||
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count, owner_username))
|
||||
|
||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
||||
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
order_by: BoardRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
include_archived: bool = False,
|
||||
) -> list[BoardDTO]:
|
||||
board_records = self.__invoker.services.board_records.get_all(order_by, direction, include_archived)
|
||||
board_records = self.__invoker.services.board_records.get_all(
|
||||
user_id, is_admin, order_by, direction, include_archived
|
||||
)
|
||||
board_dtos = []
|
||||
for r in board_records:
|
||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
@@ -88,6 +106,14 @@ class BoardService(BoardServiceABC):
|
||||
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
asset_count = self.__invoker.services.board_image_records.get_asset_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count))
|
||||
|
||||
# For admin users, include owner username
|
||||
owner_username = None
|
||||
if is_admin:
|
||||
owner = self.__invoker.services.users.get(r.user_id)
|
||||
if owner:
|
||||
owner_username = owner.display_name or owner.email
|
||||
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count, owner_username))
|
||||
|
||||
return board_dtos
|
||||
|
||||
@@ -4,15 +4,16 @@ from abc import ABC, abstractmethod
|
||||
class ClientStatePersistenceABC(ABC):
|
||||
"""
|
||||
Base class for client persistence implementations.
|
||||
This class defines the interface for persisting client data.
|
||||
This class defines the interface for persisting client data per user.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
|
||||
def set_by_key(self, user_id: str, key: str, value: str) -> str:
|
||||
"""
|
||||
Set a key-value pair for the client.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID to set state for.
|
||||
key (str): The key to set.
|
||||
value (str): The value to set for the key.
|
||||
|
||||
@@ -22,11 +23,12 @@ class ClientStatePersistenceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_by_key(self, queue_id: str, key: str) -> str | None:
|
||||
def get_by_key(self, user_id: str, key: str) -> str | None:
|
||||
"""
|
||||
Get the value for a specific key of the client.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID to get state for.
|
||||
key (str): The key to retrieve the value for.
|
||||
|
||||
Returns:
|
||||
@@ -35,8 +37,11 @@ class ClientStatePersistenceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, queue_id: str) -> None:
|
||||
def delete(self, user_id: str) -> None:
|
||||
"""
|
||||
Delete all client state.
|
||||
Delete all client state for a user.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID to delete state for.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import json
|
||||
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
@@ -7,59 +5,51 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
|
||||
"""
|
||||
Base class for client persistence implementations.
|
||||
This class defines the interface for persisting client data.
|
||||
SQLite implementation for client state persistence.
|
||||
This class stores client state data per user to prevent data leakage between users.
|
||||
"""
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._default_row_id = 1
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def _get(self) -> dict[str, str] | None:
|
||||
def set_by_key(self, user_id: str, key: str, value: str) -> str:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT data FROM client_state
|
||||
WHERE id = {self._default_row_id}
|
||||
"""
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return json.loads(row[0])
|
||||
|
||||
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
|
||||
state = self._get() or {}
|
||||
state.update({key: value})
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
INSERT INTO client_state (id, data)
|
||||
VALUES ({self._default_row_id}, ?)
|
||||
ON CONFLICT(id) DO UPDATE
|
||||
SET data = excluded.data;
|
||||
INSERT INTO client_state (user_id, key, value)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id, key) DO UPDATE
|
||||
SET value = excluded.value;
|
||||
""",
|
||||
(json.dumps(state),),
|
||||
(user_id, key, value),
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
def get_by_key(self, queue_id: str, key: str) -> str | None:
|
||||
state = self._get()
|
||||
if state is None:
|
||||
return None
|
||||
return state.get(key, None)
|
||||
|
||||
def delete(self, queue_id: str) -> None:
|
||||
def get_by_key(self, user_id: str, key: str) -> str | None:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
DELETE FROM client_state
|
||||
WHERE id = {self._default_row_id}
|
||||
"""
|
||||
SELECT value FROM client_state
|
||||
WHERE user_id = ? AND key = ?
|
||||
""",
|
||||
(user_id, key),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return row[0]
|
||||
|
||||
def delete(self, user_id: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM client_state
|
||||
WHERE user_id = ?
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
@@ -110,6 +110,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
scan_models_on_startup: Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.
|
||||
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
|
||||
allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.
|
||||
multiuser: Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.
|
||||
"""
|
||||
|
||||
_root: Optional[Path] = PrivateAttr(default=None)
|
||||
@@ -203,6 +204,9 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
unsafe_disable_picklescan: bool = Field(default=False, description="UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.")
|
||||
allow_unknown_models: bool = Field(default=True, description="Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.")
|
||||
|
||||
# MULTIUSER
|
||||
multiuser: bool = Field(default=False, description="Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.")
|
||||
|
||||
# fmt: on
|
||||
|
||||
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2023, Lincoln D. Stein
|
||||
# Copyright (c) 2023,2026 Lincoln D. Stein
|
||||
"""Implementation of multithreaded download queue for invokeai."""
|
||||
|
||||
import os
|
||||
@@ -60,7 +60,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""
|
||||
self._app_config = app_config or get_config()
|
||||
self._jobs: Dict[int, DownloadJob] = {}
|
||||
self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
|
||||
self._download_part2parent: Dict[int, MultiFileDownloadJob] = {}
|
||||
self._mfd_pending: Dict[int, list[DownloadJob]] = {}
|
||||
self._mfd_active: Dict[int, DownloadJob] = {}
|
||||
self._next_job_id = 0
|
||||
@@ -88,7 +88,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Stop the download worker threads."""
|
||||
with self._lock:
|
||||
if not self._worker_pool:
|
||||
raise Exception("Attempt to stop the download service before it was started")
|
||||
return
|
||||
self._accept_download_requests = False # reject attempts to add new jobs to queue
|
||||
queued_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.WAITING]
|
||||
active_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.RUNNING]
|
||||
@@ -118,7 +118,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
raise ServiceInactiveException(
|
||||
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
||||
)
|
||||
job.id = self._next_id()
|
||||
if job.id == -1:
|
||||
job.id = self._next_id()
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
@@ -197,12 +198,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
dest=path,
|
||||
access_token=access_token or self._lookup_access_token(url),
|
||||
)
|
||||
job.id = self._next_id() # pre-assign ID so _download_part2parent can be keyed by ID
|
||||
if part.size and part.size > 0:
|
||||
job.total_bytes = part.size
|
||||
job.expected_total_bytes = part.size
|
||||
job.canonical_url = str(url)
|
||||
mfdj.download_parts.add(job)
|
||||
self._download_part2parent[job.source] = mfdj
|
||||
self._download_part2parent[job.id] = mfdj
|
||||
if submit_job:
|
||||
self.submit_multifile_download(mfdj)
|
||||
return mfdj
|
||||
@@ -327,7 +329,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
finally:
|
||||
job.job_ended = get_iso_timestamp()
|
||||
self._job_terminated_event.set() # signal a change to terminal state
|
||||
self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
|
||||
self._download_part2parent.pop(job.id, None) # if this is a subpart of a multipart job, remove it
|
||||
self._queue.task_done()
|
||||
|
||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||
@@ -386,18 +388,23 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if len(candidates) == 1:
|
||||
inferred = candidates[0].with_name(candidates[0].name.removesuffix(".downloading"))
|
||||
job.download_path = inferred
|
||||
resume_from = candidates[0].stat().st_size
|
||||
job.bytes = resume_from
|
||||
self._logger.debug(
|
||||
f"Resume check (dir): inferred in-progress file path={candidates[0]} size={resume_from} bytes"
|
||||
)
|
||||
if resume_from > 0:
|
||||
if job.etag:
|
||||
header["If-Range"] = job.etag
|
||||
elif job.last_modified:
|
||||
header["If-Range"] = job.last_modified
|
||||
header["Range"] = f"bytes={resume_from}-"
|
||||
open_mode = "ab"
|
||||
try:
|
||||
resume_from = candidates[0].stat().st_size
|
||||
except FileNotFoundError:
|
||||
# The .downloading file was renamed/deleted between glob and stat (race condition); skip resume.
|
||||
job.download_path = None
|
||||
else:
|
||||
job.bytes = resume_from
|
||||
self._logger.debug(
|
||||
f"Resume check (dir): inferred in-progress file path={candidates[0]} size={resume_from} bytes"
|
||||
)
|
||||
if resume_from > 0:
|
||||
if job.etag:
|
||||
header["If-Range"] = job.etag
|
||||
elif job.last_modified:
|
||||
header["If-Range"] = job.last_modified
|
||||
header["Range"] = f"bytes={resume_from}-"
|
||||
open_mode = "ab"
|
||||
else:
|
||||
self._logger.debug(
|
||||
"Resume check (dir): no prior download_path available; cannot resume from disk "
|
||||
@@ -622,7 +629,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
self._event_bus.emit_download_cancelled(job)
|
||||
|
||||
# if multifile download, then signal the parent
|
||||
if parent_job := self._download_part2parent.get(job.source, None):
|
||||
if parent_job := self._download_part2parent.get(job.id, None):
|
||||
if not parent_job.in_terminal_state:
|
||||
parent_job.status = DownloadJobStatus.CANCELLED
|
||||
self._execute_cb(parent_job, "on_cancelled")
|
||||
@@ -639,7 +646,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_paused(job)
|
||||
|
||||
if parent_job := self._download_part2parent.get(job.source, None):
|
||||
if parent_job := self._download_part2parent.get(job.id, None):
|
||||
if not parent_job.in_terminal_state:
|
||||
parent_job.status = DownloadJobStatus.PAUSED
|
||||
self._execute_cb(parent_job, "on_cancelled")
|
||||
@@ -669,7 +676,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def _mfd_started(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"File download started: {download_job.source}")
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
mf_job = self._download_part2parent[download_job.id]
|
||||
if mf_job.waiting:
|
||||
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
mf_job.status = DownloadJobStatus.RUNNING
|
||||
@@ -682,7 +689,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
def _mfd_progress(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
mf_job = self._download_part2parent[download_job.id]
|
||||
if mf_job.cancelled:
|
||||
for part in mf_job.download_parts:
|
||||
self.cancel_job(part)
|
||||
@@ -696,7 +703,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
submit_next = False
|
||||
mf_job: Optional[MultiFileDownloadJob] = None
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
mf_job = self._download_part2parent[download_job.id]
|
||||
self._mfd_active.pop(mf_job.id, None)
|
||||
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
mf_job.bytes = sum(x.bytes for x in mf_job.download_parts)
|
||||
@@ -715,7 +722,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
def _mfd_cancelled(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
mf_job = self._download_part2parent[download_job.id]
|
||||
assert mf_job is not None
|
||||
self._mfd_active.pop(mf_job.id, None)
|
||||
|
||||
@@ -735,7 +742,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
mf_job = self._download_part2parent[download_job.id]
|
||||
assert mf_job is not None
|
||||
self._mfd_active.pop(mf_job.id, None)
|
||||
if not mf_job.in_terminal_state:
|
||||
@@ -748,7 +755,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
)
|
||||
for s in [x for x in mf_job.download_parts if x.running]:
|
||||
self.cancel_job(s)
|
||||
self._download_part2parent.pop(download_job.source)
|
||||
self._mfd_pending.pop(mf_job.id, None)
|
||||
self._job_terminated_event.set()
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class QueueItemEventBase(QueueEventBase):
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the queue item")
|
||||
destination: str | None = Field(default=None, description="The destination of the queue item")
|
||||
user_id: str = Field(default="system", description="The ID of the user who created the queue item")
|
||||
|
||||
|
||||
class InvocationEventBase(QueueItemEventBase):
|
||||
@@ -117,6 +118,7 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
user_id=queue_item.user_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -152,6 +154,7 @@ class InvocationProgressEvent(InvocationEventBase):
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
user_id=queue_item.user_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -179,6 +182,7 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
user_id=queue_item.user_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -211,6 +215,7 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
user_id=queue_item.user_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -248,6 +253,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
user_id=queue_item.user_id,
|
||||
session_id=queue_item.session_id,
|
||||
status=queue_item.status,
|
||||
error_type=queue_item.error_type,
|
||||
|
||||
@@ -28,6 +28,10 @@ class FastAPIEventService(EventServiceBase):
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
if self._loop.is_closed():
|
||||
# The event loop was closed during shutdown. Events can no longer be dispatched;
|
||||
# silently drop this one so the generation thread can wind down cleanly.
|
||||
return
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
|
||||
|
||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||
|
||||
@@ -50,8 +50,10 @@ class ImageRecordStorageBase(ABC):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
"""Gets a page of image records. When board_id is 'none', filters by user_id for per-user uncategorized images unless is_admin is True."""
|
||||
pass
|
||||
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
@@ -90,6 +92,7 @@ class ImageRecordStorageBase(ABC):
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
@@ -109,6 +112,8 @@ class ImageRecordStorageBase(ABC):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates."""
|
||||
pass
|
||||
|
||||
@@ -134,6 +134,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
@@ -186,6 +188,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
# For uncategorized images, filter by user_id to ensure per-user isolation
|
||||
# Admin users can see all uncategorized images from all users
|
||||
if user_id is not None and not is_admin:
|
||||
query_conditions += """--sql
|
||||
AND images.user_id = ?
|
||||
"""
|
||||
query_params.append(user_id)
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
@@ -305,6 +314,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> datetime:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
@@ -321,9 +331,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow
|
||||
has_workflow,
|
||||
user_id
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
@@ -337,6 +348,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
user_id or "system",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -386,6 +398,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> ImageNamesResult:
|
||||
with self._db.transaction() as cursor:
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
@@ -417,6 +431,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
# For uncategorized images, filter by user_id to ensure per-user isolation
|
||||
# Admin users can see all uncategorized images from all users
|
||||
if user_id is not None and not is_admin:
|
||||
query_conditions += """--sql
|
||||
AND images.user_id = ?
|
||||
"""
|
||||
query_params.append(user_id)
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
|
||||
@@ -55,6 +55,7 @@ class ImageServiceABC(ABC):
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@@ -125,6 +126,8 @@ class ImageServiceABC(ABC):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs with starred images first when starred_first=True."""
|
||||
pass
|
||||
@@ -159,6 +162,8 @@ class ImageServiceABC(ABC):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates."""
|
||||
pass
|
||||
|
||||
@@ -45,6 +45,7 @@ class ImageService(ImageServiceABC):
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
@@ -72,6 +73,7 @@ class ImageService(ImageServiceABC):
|
||||
node_id=node_id,
|
||||
metadata=metadata,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if board_id is not None:
|
||||
try:
|
||||
@@ -215,6 +217,8 @@ class ImageService(ImageServiceABC):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self.__invoker.services.image_records.get_many(
|
||||
@@ -227,6 +231,8 @@ class ImageService(ImageServiceABC):
|
||||
is_intermediate,
|
||||
board_id,
|
||||
search_term,
|
||||
user_id,
|
||||
is_admin,
|
||||
)
|
||||
|
||||
image_dtos = [
|
||||
@@ -320,6 +326,8 @@ class ImageService(ImageServiceABC):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> ImageNamesResult:
|
||||
try:
|
||||
return self.__invoker.services.image_records.get_image_names(
|
||||
@@ -330,6 +338,8 @@ class ImageService(ImageServiceABC):
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
user_id=user_id,
|
||||
is_admin=is_admin,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image names")
|
||||
|
||||
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.urls.urls_base import UrlServiceBase
|
||||
from invokeai.app.services.users.users_base import UserServiceBase
|
||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_base import WorkflowThumbnailServiceBase
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
@@ -75,6 +76,7 @@ class InvocationServices:
|
||||
style_preset_image_files: "StylePresetImageFileStorageBase",
|
||||
workflow_thumbnails: "WorkflowThumbnailServiceBase",
|
||||
client_state_persistence: "ClientStatePersistenceABC",
|
||||
users: "UserServiceBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@@ -105,3 +107,4 @@ class InvocationServices:
|
||||
self.style_preset_image_files = style_preset_image_files
|
||||
self.workflow_thumbnails = workflow_thumbnails
|
||||
self.client_state_persistence = client_state_persistence
|
||||
self.users = users
|
||||
|
||||
@@ -132,6 +132,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
marker = {
|
||||
"version": INSTALL_MARKER_VERSION,
|
||||
"source": str(job.source),
|
||||
"access_token": (
|
||||
job.source.access_token if isinstance(job.source, (HFModelSource, URLModelSource)) else None
|
||||
),
|
||||
"config_in": job.config_in.model_dump(),
|
||||
"status": (status or job.status).value,
|
||||
"updated_at": get_iso_timestamp(),
|
||||
@@ -186,6 +189,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _restore_incomplete_installs(self) -> None:
|
||||
path = self._app_config.models_path
|
||||
seen_sources: set[str] = set()
|
||||
# Collect sources already tracked by active jobs (including those being downloaded right now).
|
||||
# We must not re-queue these or delete their tmpdirs.
|
||||
with self._lock:
|
||||
active_sources = {str(j.source) for j in self._install_jobs if not j.in_terminal_state}
|
||||
active_sources.update(str(j.source) for j in self._download_cache.values() if not j.in_terminal_state)
|
||||
for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
|
||||
marker = self._read_install_marker(tmpdir)
|
||||
if not marker:
|
||||
@@ -195,13 +203,22 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
continue
|
||||
|
||||
try:
|
||||
source_str = marker["source"]
|
||||
source_str = marker.get("source")
|
||||
if not isinstance(source_str, str):
|
||||
raise ValueError("Missing source in install marker")
|
||||
source = self._guess_source(source_str)
|
||||
access_token = marker.get("access_token")
|
||||
if isinstance(source, (HFModelSource, URLModelSource)) and isinstance(access_token, str):
|
||||
source.access_token = access_token
|
||||
if source_str in active_sources:
|
||||
# This tmpdir belongs to an install already in progress; leave it alone.
|
||||
self._logger.debug(f"Skipping restore for {source_str} - already being tracked")
|
||||
continue
|
||||
if source_str in seen_sources:
|
||||
self._logger.info(f"Removing duplicate temporary directory {tmpdir}")
|
||||
self._safe_rmtree(tmpdir, self._logger)
|
||||
continue
|
||||
seen_sources.add(source_str)
|
||||
source = self._guess_source(source_str)
|
||||
except Exception as e:
|
||||
self._logger.warning(f"Skipping install marker in {tmpdir}: {e}")
|
||||
continue
|
||||
@@ -313,7 +330,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def stop(self, invoker: Optional[Invoker] = None) -> None:
|
||||
"""Stop the installer thread; after this the object can be deleted and garbage collected."""
|
||||
if not self._running:
|
||||
raise Exception("Attempt to stop the install service before it was started")
|
||||
return
|
||||
self._logger.debug("calling stop_event.set()")
|
||||
self._stop_event.set()
|
||||
self._clear_pending_jobs()
|
||||
|
||||
@@ -355,6 +355,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._thread = Thread(
|
||||
name="session_processor",
|
||||
target=self._process,
|
||||
daemon=True,
|
||||
kwargs={
|
||||
"stop_event": self._stop_event,
|
||||
"poll_now_event": self._poll_now_event,
|
||||
@@ -366,6 +367,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self._stop_event.set()
|
||||
# Cancel any in-progress generation so that long-running nodes (e.g. denoising) stop at
|
||||
# the next step boundary instead of running to completion. Without this, the generation
|
||||
# thread may still be executing CUDA operations when Python teardown begins, which can
|
||||
# cause a C++ std::terminate() crash ("terminate called without an active exception").
|
||||
self._cancel_event.set()
|
||||
# Wake the thread if it is sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused).
|
||||
self._poll_now_event.set()
|
||||
self._resume_event.set()
|
||||
|
||||
def _poll_now(self) -> None:
|
||||
self._poll_now_event.set()
|
||||
|
||||
@@ -36,8 +36,10 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
|
||||
"""Enqueues all permutations of a batch for execution."""
|
||||
def enqueue_batch(
|
||||
self, queue_id: str, batch: Batch, prepend: bool, user_id: str = "system"
|
||||
) -> Coroutine[Any, Any, EnqueueBatchResult]:
|
||||
"""Enqueues all permutations of a batch for execution for a specific user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -51,13 +53,13 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
"""Deletes all session queue items"""
|
||||
def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult:
|
||||
"""Deletes all session queue items. If user_id is provided, only clears items owned by that user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
"""Deletes all completed and errored session queue items"""
|
||||
def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult:
|
||||
"""Deletes all completed and errored session queue items. If user_id is provided, only prunes items owned by that user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -71,8 +73,8 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
"""Gets the status of the queue"""
|
||||
def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
|
||||
"""Gets the status of the queue. If user_id is provided, also includes user-specific counts."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -108,18 +110,24 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
"""Cancels all queue items with matching batch IDs"""
|
||||
def cancel_by_batch_ids(
|
||||
self, queue_id: str, batch_ids: list[str], user_id: Optional[str] = None
|
||||
) -> CancelByBatchIDsResult:
|
||||
"""Cancels all queue items with matching batch IDs. If user_id is provided, only cancels items owned by that user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
"""Cancels all queue items with the given batch destination"""
|
||||
def cancel_by_destination(
|
||||
self, queue_id: str, destination: str, user_id: Optional[str] = None
|
||||
) -> CancelByDestinationResult:
|
||||
"""Cancels all queue items with the given batch destination. If user_id is provided, only cancels items owned by that user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
"""Deletes all queue items with the given batch destination"""
|
||||
def delete_by_destination(
|
||||
self, queue_id: str, destination: str, user_id: Optional[str] = None
|
||||
) -> DeleteByDestinationResult:
|
||||
"""Deletes all queue items with the given batch destination. If user_id is provided, only deletes items owned by that user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -128,13 +136,13 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
|
||||
"""Cancels all queue items except in-progress items"""
|
||||
def cancel_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> CancelAllExceptCurrentResult:
|
||||
"""Cancels all queue items except in-progress items. If user_id is provided, only cancels items owned by that user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
"""Deletes all queue items except in-progress items"""
|
||||
def delete_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> DeleteAllExceptCurrentResult:
|
||||
"""Deletes all queue items except in-progress items. If user_id is provided, only deletes items owned by that user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -170,6 +170,7 @@ class Batch(BaseModel):
|
||||
# region Queue Items
|
||||
|
||||
DEFAULT_QUEUE_ID = "default"
|
||||
SYSTEM_USER_ID = "system" # Default user_id for system-generated queue items
|
||||
|
||||
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
||||
|
||||
@@ -243,6 +244,13 @@ class SessionQueueItem(BaseModel):
|
||||
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
|
||||
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
|
||||
queue_id: str = Field(description="The id of the queue with which this item is associated")
|
||||
user_id: str = Field(default="system", description="The id of the user who created this queue item")
|
||||
user_display_name: Optional[str] = Field(
|
||||
default=None, description="The display name of the user who created this queue item, if available"
|
||||
)
|
||||
user_email: Optional[str] = Field(
|
||||
default=None, description="The email of the user who created this queue item, if available"
|
||||
)
|
||||
field_values: Optional[list[NodeFieldValue]] = Field(
|
||||
default=None, description="The field values that were used for this queue item"
|
||||
)
|
||||
@@ -296,6 +304,12 @@ class SessionQueueStatus(BaseModel):
|
||||
failed: int = Field(..., description="Number of queue items with status 'error'")
|
||||
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
|
||||
total: int = Field(..., description="Total number of queue items")
|
||||
user_pending: Optional[int] = Field(
|
||||
default=None, description="Number of queue items with status 'pending' for the current user"
|
||||
)
|
||||
user_in_progress: Optional[int] = Field(
|
||||
default=None, description="Number of queue items with status 'in_progress' for the current user"
|
||||
)
|
||||
|
||||
|
||||
class SessionQueueCountsByDestination(BaseModel):
|
||||
@@ -565,6 +579,7 @@ ValueToInsertTuple: TypeAlias = tuple[
|
||||
str | None, # origin (optional)
|
||||
str | None, # destination (optional)
|
||||
int | None, # retried_from_item_id (optional, this is always None for new items)
|
||||
str, # user_id
|
||||
]
|
||||
"""A type alias for the tuple of values to insert into the session queue table.
|
||||
|
||||
@@ -573,7 +588,7 @@ ValueToInsertTuple: TypeAlias = tuple[
|
||||
|
||||
|
||||
def prepare_values_to_insert(
|
||||
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
|
||||
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int, user_id: str = "system"
|
||||
) -> list[ValueToInsertTuple]:
|
||||
"""
|
||||
Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
|
||||
@@ -584,6 +599,7 @@ def prepare_values_to_insert(
|
||||
batch: The batch to prepare the values for
|
||||
priority: The priority of the queue items
|
||||
max_new_queue_items: The maximum number of queue items to insert
|
||||
user_id: The user ID who is creating these queue items
|
||||
|
||||
Returns:
|
||||
A list of tuples to insert into the session queue table. Each tuple contains the following values:
|
||||
@@ -597,6 +613,7 @@ def prepare_values_to_insert(
|
||||
- origin (optional)
|
||||
- destination (optional)
|
||||
- retried_from_item_id (optional, this is always None for new items)
|
||||
- user_id
|
||||
"""
|
||||
|
||||
# A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
|
||||
@@ -626,6 +643,7 @@ def prepare_values_to_insert(
|
||||
batch.origin,
|
||||
batch.destination,
|
||||
None,
|
||||
user_id,
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
@@ -100,7 +100,9 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
return priority
|
||||
|
||||
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
async def enqueue_batch(
|
||||
self, queue_id: str, batch: Batch, prepend: bool, user_id: str = "system"
|
||||
) -> EnqueueBatchResult:
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
@@ -119,14 +121,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
user_id=user_id,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id, user_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
@@ -155,12 +158,16 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
SELECT
|
||||
sq.*,
|
||||
u.display_name as user_display_name,
|
||||
u.email as user_email
|
||||
FROM session_queue sq
|
||||
LEFT JOIN users u ON sq.user_id = u.user_id
|
||||
WHERE sq.status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
sq.priority DESC,
|
||||
sq.item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
@@ -175,14 +182,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
SELECT
|
||||
sq.*,
|
||||
u.display_name as user_display_name,
|
||||
u.email as user_email
|
||||
FROM session_queue sq
|
||||
LEFT JOIN users u ON sq.user_id = u.user_id
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
sq.queue_id = ?
|
||||
AND sq.status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
sq.priority DESC,
|
||||
sq.created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
@@ -196,11 +207,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
SELECT
|
||||
sq.*,
|
||||
u.display_name as user_display_name,
|
||||
u.email as user_email
|
||||
FROM session_queue sq
|
||||
LEFT JOIN users u ON sq.user_id = u.user_id
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
sq.queue_id = ?
|
||||
AND sq.status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
@@ -277,31 +292,41 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult:
|
||||
with self._db.transaction() as cursor:
|
||||
user_filter = "AND user_id = ?" if user_id is not None else ""
|
||||
where = f"""--sql
|
||||
WHERE queue_id = ?
|
||||
{user_filter}
|
||||
"""
|
||||
params: list[str] = [queue_id]
|
||||
if user_id is not None:
|
||||
params.append(user_id)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
{where}
|
||||
""",
|
||||
(queue_id,),
|
||||
tuple(params),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
f"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
{where}
|
||||
""",
|
||||
(queue_id,),
|
||||
tuple(params),
|
||||
)
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult:
|
||||
with self._db.transaction() as cursor:
|
||||
where = """--sql
|
||||
# Build WHERE clause with optional user_id filter
|
||||
user_filter = "AND user_id = ?" if user_id is not None else ""
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND (
|
||||
@@ -309,14 +334,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
OR status = 'failed'
|
||||
OR status = 'canceled'
|
||||
)
|
||||
{user_filter}
|
||||
"""
|
||||
params = [queue_id]
|
||||
if user_id is not None:
|
||||
params.append(user_id)
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
tuple(params),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
@@ -325,7 +355,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
tuple(params),
|
||||
)
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
@@ -369,10 +399,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
def cancel_by_batch_ids(
|
||||
self, queue_id: str, batch_ids: list[str], user_id: Optional[str] = None
|
||||
) -> CancelByBatchIDsResult:
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
|
||||
# Build WHERE clause with optional user_id filter
|
||||
user_filter = "AND user_id = ?" if user_id is not None else ""
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -382,8 +417,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
{user_filter}
|
||||
"""
|
||||
params = [queue_id] + batch_ids
|
||||
if user_id is not None:
|
||||
params.append(user_id)
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
@@ -402,15 +441,22 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
tuple(params),
|
||||
)
|
||||
|
||||
# Handle current item separately - check ownership if user_id is provided
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
if user_id is None or current_queue_item.user_id == user_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
def cancel_by_destination(
|
||||
self, queue_id: str, destination: str, user_id: Optional[str] = None
|
||||
) -> CancelByDestinationResult:
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = """--sql
|
||||
|
||||
# Build WHERE clause with optional user_id filter
|
||||
user_filter = "AND user_id = ?" if user_id is not None else ""
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND destination == ?
|
||||
@@ -419,15 +465,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
{user_filter}
|
||||
"""
|
||||
params = (queue_id, destination)
|
||||
params = [queue_id, destination]
|
||||
if user_id is not None:
|
||||
params.append(user_id)
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
params,
|
||||
tuple(params),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
@@ -436,55 +486,78 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
params,
|
||||
tuple(params),
|
||||
)
|
||||
|
||||
# Handle current item separately - check ownership if user_id is provided
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
if user_id is None or current_queue_item.user_id == user_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
def delete_by_destination(
|
||||
self, queue_id: str, destination: str, user_id: Optional[str] = None
|
||||
) -> DeleteByDestinationResult:
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
|
||||
# Handle current item separately - check ownership if user_id is provided
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self.cancel_queue_item(current_queue_item.item_id)
|
||||
params = (queue_id, destination)
|
||||
if user_id is None or current_queue_item.user_id == user_id:
|
||||
self.cancel_queue_item(current_queue_item.item_id)
|
||||
|
||||
# Build WHERE clause with optional user_id filter
|
||||
user_filter = "AND user_id = ?" if user_id is not None else ""
|
||||
params = [queue_id, destination]
|
||||
if user_id is not None:
|
||||
params.append(user_id)
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND destination = ?;
|
||||
queue_id == ?
|
||||
AND destination == ?
|
||||
{user_filter}
|
||||
""",
|
||||
params,
|
||||
tuple(params),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
f"""--sql
|
||||
DELETE FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND destination = ?;
|
||||
queue_id == ?
|
||||
AND destination == ?
|
||||
{user_filter}
|
||||
""",
|
||||
params,
|
||||
tuple(params),
|
||||
)
|
||||
return DeleteByDestinationResult(deleted=count)
|
||||
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
def delete_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> DeleteAllExceptCurrentResult:
|
||||
with self._db.transaction() as cursor:
|
||||
where = """--sql
|
||||
# Build WHERE clause with optional user_id filter
|
||||
user_filter = "AND user_id = ?" if user_id is not None else ""
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND status == 'pending'
|
||||
{user_filter}
|
||||
"""
|
||||
params = [queue_id]
|
||||
if user_id is not None:
|
||||
params.append(user_id)
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
tuple(params),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
@@ -493,7 +566,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
tuple(params),
|
||||
)
|
||||
return DeleteAllExceptCurrentResult(deleted=count)
|
||||
|
||||
@@ -532,20 +605,27 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
|
||||
def cancel_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> CancelAllExceptCurrentResult:
|
||||
with self._db.transaction() as cursor:
|
||||
where = """--sql
|
||||
# Build WHERE clause with optional user_id filter
|
||||
user_filter = "AND user_id = ?" if user_id is not None else ""
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND status == 'pending'
|
||||
{user_filter}
|
||||
"""
|
||||
params = [queue_id]
|
||||
if user_id is not None:
|
||||
params.append(user_id)
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
tuple(params),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
@@ -554,7 +634,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
tuple(params),
|
||||
)
|
||||
return CancelAllExceptCurrentResult(canceled=count)
|
||||
|
||||
@@ -562,9 +642,13 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
SELECT
|
||||
sq.*,
|
||||
u.display_name as user_display_name,
|
||||
u.email as user_email
|
||||
FROM session_queue sq
|
||||
LEFT JOIN users u ON sq.user_id = u.user_id
|
||||
WHERE sq.item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
@@ -650,22 +734,26 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
with self._db.transaction() as cursor:
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
SELECT
|
||||
sq.*,
|
||||
u.display_name as user_display_name,
|
||||
u.email as user_email
|
||||
FROM session_queue sq
|
||||
LEFT JOIN users u ON sq.user_id = u.user_id
|
||||
WHERE sq.queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
AND sq.destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
sq.priority DESC,
|
||||
sq.item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor.execute(query, params)
|
||||
@@ -693,8 +781,9 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids))
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
|
||||
with self._db.transaction() as cursor:
|
||||
# Get total counts
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
@@ -706,9 +795,32 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
# Get user-specific counts if user_id is provided (using a single query with CASE)
|
||||
user_counts_result = []
|
||||
if user_id is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ? AND user_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, user_id),
|
||||
)
|
||||
user_counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
|
||||
# Process user-specific counts if available
|
||||
user_pending = None
|
||||
user_in_progress = None
|
||||
if user_id is not None:
|
||||
user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result}
|
||||
user_pending = user_counts.get("pending", 0)
|
||||
user_in_progress = user_counts.get("in_progress", 0)
|
||||
|
||||
return SessionQueueStatus(
|
||||
queue_id=queue_id,
|
||||
item_id=current_item.item_id if current_item else None,
|
||||
@@ -720,6 +832,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
user_pending=user_pending,
|
||||
user_in_progress=user_in_progress,
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
@@ -822,6 +936,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
queue_item.origin,
|
||||
queue_item.destination,
|
||||
retried_from_item_id,
|
||||
queue_item.user_id,
|
||||
)
|
||||
values_to_insert.append(value_to_insert)
|
||||
|
||||
@@ -829,8 +944,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id, user_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
|
||||
@@ -72,7 +72,7 @@ class InvocationContextInterface:
|
||||
|
||||
class BoardsInterface(InvocationContextInterface):
|
||||
def create(self, board_name: str) -> BoardDTO:
|
||||
"""Creates a board.
|
||||
"""Creates a board for the current user.
|
||||
|
||||
Args:
|
||||
board_name: The name of the board to create.
|
||||
@@ -80,7 +80,8 @@ class BoardsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
The created board DTO.
|
||||
"""
|
||||
return self._services.boards.create(board_name)
|
||||
user_id = self._data.queue_item.user_id
|
||||
return self._services.boards.create(board_name, user_id)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
"""Gets a board DTO.
|
||||
@@ -94,13 +95,14 @@ class BoardsInterface(InvocationContextInterface):
|
||||
return self._services.boards.get_dto(board_id)
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
"""Gets all boards.
|
||||
"""Gets all boards accessible to the current user.
|
||||
|
||||
Returns:
|
||||
A list of all boards.
|
||||
A list of all boards accessible to the current user.
|
||||
"""
|
||||
user_id = self._data.queue_item.user_id
|
||||
return self._services.boards.get_all(
|
||||
order_by=BoardRecordOrderBy.CreatedAt, direction=SQLiteDirection.Descending
|
||||
user_id, order_by=BoardRecordOrderBy.CreatedAt, direction=SQLiteDirection.Descending
|
||||
)
|
||||
|
||||
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||
@@ -228,6 +230,7 @@ class ImagesInterface(InvocationContextInterface):
|
||||
graph=graph_,
|
||||
session_id=self._data.queue_item.session_id,
|
||||
node_id=self._data.invocation.id,
|
||||
user_id=self._data.queue_item.user_id,
|
||||
)
|
||||
|
||||
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
|
||||
|
||||
@@ -29,6 +29,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_27 import build_migration_27
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -75,6 +76,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_24(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_25(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_26(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_27())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,366 @@
|
||||
"""Migration 27: Add multi-user support, per-user client state, and app settings.
|
||||
|
||||
This migration adds the database schema for multi-user support, including:
|
||||
- users table for user accounts
|
||||
- user_sessions table for session management
|
||||
- user_invitations table for invitation system
|
||||
- shared_boards table for board sharing
|
||||
- Adding user_id columns to existing tables for data ownership
|
||||
- Restructuring client_state table to support per-user storage
|
||||
- app_settings table for storing JWT secret and other app-level settings
|
||||
"""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration27Callback:
|
||||
"""Migration to add multi-user support, per-user client state, and app settings."""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_users_table(cursor)
|
||||
self._create_user_sessions_table(cursor)
|
||||
self._create_user_invitations_table(cursor)
|
||||
self._create_shared_boards_table(cursor)
|
||||
self._update_boards_table(cursor)
|
||||
self._update_images_table(cursor)
|
||||
self._update_workflows_table(cursor)
|
||||
self._update_session_queue_table(cursor)
|
||||
self._update_style_presets_table(cursor)
|
||||
self._create_system_user(cursor)
|
||||
self._update_client_state_table(cursor)
|
||||
self._create_app_settings_table(cursor)
|
||||
self._generate_jwt_secret(cursor)
|
||||
|
||||
def _create_users_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create users table."""
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
user_id TEXT NOT NULL PRIMARY KEY,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
display_name TEXT,
|
||||
password_hash TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
last_login_at DATETIME
|
||||
);
|
||||
""")
|
||||
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_is_admin ON users(is_admin);")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_is_active ON users(is_active);")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS tg_users_updated_at
|
||||
AFTER UPDATE ON users FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE users SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE user_id = old.user_id;
|
||||
END;
|
||||
""")
|
||||
|
||||
def _create_user_sessions_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create user_sessions table for session management."""
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
session_id TEXT NOT NULL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
last_activity_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id);")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_sessions_token_hash ON user_sessions(token_hash);")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_sessions_expires_at ON user_sessions(expires_at);")
|
||||
|
||||
def _create_user_invitations_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create user_invitations table for invitation system."""
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS user_invitations (
|
||||
invitation_id TEXT NOT NULL PRIMARY KEY,
|
||||
email TEXT NOT NULL,
|
||||
invited_by TEXT NOT NULL,
|
||||
invitation_code TEXT NOT NULL UNIQUE,
|
||||
is_admin BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
expires_at DATETIME NOT NULL,
|
||||
used_at DATETIME,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
FOREIGN KEY (invited_by) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_invitations_email ON user_invitations(email);")
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_user_invitations_invitation_code ON user_invitations(invitation_code);"
|
||||
)
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_invitations_expires_at ON user_invitations(expires_at);")
|
||||
|
||||
def _create_shared_boards_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create shared_boards table for board sharing."""
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS shared_boards (
|
||||
board_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
can_edit BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
shared_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
PRIMARY KEY (board_id, user_id),
|
||||
FOREIGN KEY (board_id) REFERENCES boards(board_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_shared_boards_user_id ON shared_boards(user_id);")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_shared_boards_board_id ON shared_boards(board_id);")
|
||||
|
||||
def _update_boards_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Add user_id and is_public columns to boards table."""
|
||||
# Check if boards table exists
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='boards';")
|
||||
if cursor.fetchone() is None:
|
||||
return
|
||||
|
||||
# Check if user_id column exists
|
||||
cursor.execute("PRAGMA table_info(boards);")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
|
||||
if "user_id" not in columns:
|
||||
cursor.execute("ALTER TABLE boards ADD COLUMN user_id TEXT DEFAULT 'system';")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_boards_user_id ON boards(user_id);")
|
||||
|
||||
if "is_public" not in columns:
|
||||
cursor.execute("ALTER TABLE boards ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_boards_is_public ON boards(is_public);")
|
||||
|
||||
def _update_images_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Add user_id column to images table."""
|
||||
# Check if images table exists
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='images';")
|
||||
if cursor.fetchone() is None:
|
||||
return
|
||||
|
||||
cursor.execute("PRAGMA table_info(images);")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
|
||||
if "user_id" not in columns:
|
||||
cursor.execute("ALTER TABLE images ADD COLUMN user_id TEXT DEFAULT 'system';")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_images_user_id ON images(user_id);")
|
||||
|
||||
def _update_workflows_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Add user_id and is_public columns to workflows table."""
|
||||
# Check if workflows table exists
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='workflows';")
|
||||
if cursor.fetchone() is None:
|
||||
return
|
||||
|
||||
cursor.execute("PRAGMA table_info(workflows);")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
|
||||
if "user_id" not in columns:
|
||||
cursor.execute("ALTER TABLE workflows ADD COLUMN user_id TEXT DEFAULT 'system';")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflows_user_id ON workflows(user_id);")
|
||||
|
||||
if "is_public" not in columns:
|
||||
cursor.execute("ALTER TABLE workflows ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflows_is_public ON workflows(is_public);")
|
||||
|
||||
def _update_session_queue_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Add user_id column to session_queue table."""
|
||||
# Check if session_queue table exists
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='session_queue';")
|
||||
if cursor.fetchone() is None:
|
||||
return
|
||||
|
||||
cursor.execute("PRAGMA table_info(session_queue);")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
|
||||
if "user_id" not in columns:
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN user_id TEXT DEFAULT 'system';")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_session_queue_user_id ON session_queue(user_id);")
|
||||
|
||||
def _update_style_presets_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Add user_id and is_public columns to style_presets table."""
|
||||
# Check if style_presets table exists
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='style_presets';")
|
||||
if cursor.fetchone() is None:
|
||||
return
|
||||
|
||||
cursor.execute("PRAGMA table_info(style_presets);")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
|
||||
if "user_id" not in columns:
|
||||
cursor.execute("ALTER TABLE style_presets ADD COLUMN user_id TEXT DEFAULT 'system';")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_style_presets_user_id ON style_presets(user_id);")
|
||||
|
||||
if "is_public" not in columns:
|
||||
cursor.execute("ALTER TABLE style_presets ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_style_presets_is_public ON style_presets(is_public);")
|
||||
|
||||
def _create_system_user(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create system user for backward compatibility.
|
||||
|
||||
The system user is NOT an admin - it's just used to own existing data
|
||||
from before multi-user support was added. Real admin users should be
|
||||
created through the /auth/setup endpoint.
|
||||
"""
|
||||
cursor.execute("""
|
||||
INSERT OR IGNORE INTO users (user_id, email, display_name, password_hash, is_admin, is_active)
|
||||
VALUES ('system', 'system@system.invokeai', 'System', '', FALSE, TRUE);
|
||||
""")
|
||||
|
||||
def _update_client_state_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Restructure client_state table to support per-user storage."""
|
||||
# Check if client_state table exists
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='client_state';")
|
||||
if cursor.fetchone() is None:
|
||||
# Table doesn't exist, create it with the new schema
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE client_state (
|
||||
user_id TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP),
|
||||
PRIMARY KEY (user_id, key),
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);")
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TRIGGER tg_client_state_updated_at
|
||||
AFTER UPDATE ON client_state
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE client_state
|
||||
SET updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = OLD.user_id AND key = OLD.key;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
return
|
||||
|
||||
# Table exists with old schema - migrate it
|
||||
# Get existing data if the data column is present (it may be absent if an older
|
||||
# version of migration 21 was deployed without the column)
|
||||
cursor.execute("PRAGMA table_info(client_state);")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
existing_data = {}
|
||||
if "data" in columns:
|
||||
cursor.execute("SELECT data FROM client_state WHERE id = 1;")
|
||||
row = cursor.fetchone()
|
||||
if row is not None:
|
||||
try:
|
||||
existing_data = json.loads(row[0])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If data is corrupt, just start fresh
|
||||
pass
|
||||
|
||||
# Drop the old table
|
||||
cursor.execute("DROP TABLE IF EXISTS client_state;")
|
||||
|
||||
# Create new table with per-user schema
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE client_state (
|
||||
user_id TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP),
|
||||
PRIMARY KEY (user_id, key),
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);")
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TRIGGER tg_client_state_updated_at
|
||||
AFTER UPDATE ON client_state
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE client_state
|
||||
SET updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = OLD.user_id AND key = OLD.key;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
# Migrate existing data to 'system' user
|
||||
for key, value in existing_data.items():
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO client_state (user_id, key, value)
|
||||
VALUES ('system', ?, ?);
|
||||
""",
|
||||
(key, value),
|
||||
)
|
||||
|
||||
def _create_app_settings_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create app_settings table for storing application-level configuration."""
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS app_settings (
|
||||
key TEXT NOT NULL PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS tg_app_settings_updated_at
|
||||
AFTER UPDATE ON app_settings
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE app_settings SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE key = OLD.key;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def _generate_jwt_secret(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Generate and store a cryptographically secure JWT secret key.
|
||||
|
||||
The secret is a 64-character hexadecimal string (256 bits of entropy),
|
||||
which is suitable for HS256 JWT signing.
|
||||
"""
|
||||
# Check if JWT secret already exists
|
||||
cursor.execute("SELECT value FROM app_settings WHERE key = 'jwt_secret';")
|
||||
existing_secret = cursor.fetchone()
|
||||
|
||||
if existing_secret is None:
|
||||
# Generate a new cryptographically secure secret (256 bits)
|
||||
jwt_secret = secrets.token_hex(32) # 32 bytes = 256 bits = 64 hex characters
|
||||
|
||||
# Store in database
|
||||
cursor.execute(
|
||||
"INSERT INTO app_settings (key, value) VALUES ('jwt_secret', ?);",
|
||||
(jwt_secret,),
|
||||
)
|
||||
|
||||
|
||||
def build_migration_27() -> Migration:
|
||||
"""Builds the migration object for migrating from version 26 to version 27.
|
||||
|
||||
This migration adds multi-user support, per-user client state, and app settings
|
||||
(including a JWT secret) to the database schema.
|
||||
"""
|
||||
return Migration(
|
||||
from_version=26,
|
||||
to_version=27,
|
||||
callback=Migration27Callback(),
|
||||
)
|
||||
1
invokeai/app/services/users/__init__.py
Normal file
1
invokeai/app/services/users/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""User service module."""
|
||||
135
invokeai/app/services/users/users_base.py
Normal file
135
invokeai/app/services/users/users_base.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Abstract base class for user service."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.users.users_common import UserCreateRequest, UserDTO, UserUpdateRequest
|
||||
|
||||
|
||||
class UserServiceBase(ABC):
|
||||
"""High-level service for user management."""
|
||||
|
||||
@abstractmethod
|
||||
def create(self, user_data: UserCreateRequest) -> UserDTO:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
user_data: User creation data
|
||||
|
||||
Returns:
|
||||
The created user
|
||||
|
||||
Raises:
|
||||
ValueError: If email already exists or password is weak
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, user_id: str) -> UserDTO | None:
|
||||
"""Get user by ID.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
UserDTO if found, None otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_by_email(self, email: str) -> UserDTO | None:
|
||||
"""Get user by email.
|
||||
|
||||
Args:
|
||||
email: The email address
|
||||
|
||||
Returns:
|
||||
UserDTO if found, None otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, user_id: str, changes: UserUpdateRequest) -> UserDTO:
|
||||
"""Update user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
changes: Fields to update
|
||||
|
||||
Returns:
|
||||
The updated user
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found or password is weak
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, user_id: str) -> None:
|
||||
"""Delete user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def authenticate(self, email: str, password: str) -> UserDTO | None:
|
||||
"""Authenticate user credentials.
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
password: User password
|
||||
|
||||
Returns:
|
||||
UserDTO if authentication successful, None otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def has_admin(self) -> bool:
|
||||
"""Check if any admin user exists.
|
||||
|
||||
Returns:
|
||||
True if at least one admin user exists, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_admin(self, user_data: UserCreateRequest) -> UserDTO:
|
||||
"""Create an admin user (for initial setup).
|
||||
|
||||
Args:
|
||||
user_data: User creation data
|
||||
|
||||
Returns:
|
||||
The created admin user
|
||||
|
||||
Raises:
|
||||
ValueError: If admin already exists or password is weak
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
|
||||
"""List all users.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of users to return
|
||||
offset: Number of users to skip
|
||||
|
||||
Returns:
|
||||
List of users
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count_admins(self) -> int:
|
||||
"""Count active admin users.
|
||||
|
||||
Returns:
|
||||
The number of active admin users
|
||||
"""
|
||||
pass
|
||||
114
invokeai/app/services/users/users_common.py
Normal file
114
invokeai/app/services/users/users_common.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Common types and data models for user service."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
|
||||
def validate_email_with_special_domains(email: str) -> str:
|
||||
"""Validate email address, allowing special-use domains like .local for testing.
|
||||
|
||||
This validator first tries standard email validation using email-validator library.
|
||||
If it fails due to special-use domains (like .local, .test, .localhost), it performs
|
||||
a basic syntax check instead. This allows development/testing with non-routable domains
|
||||
while still catching actual typos and malformed emails.
|
||||
|
||||
Args:
|
||||
email: The email address to validate
|
||||
|
||||
Returns:
|
||||
The validated email address (lowercased)
|
||||
|
||||
Raises:
|
||||
PydanticCustomError: If the email format is invalid
|
||||
"""
|
||||
try:
|
||||
# Try standard email validation using email-validator
|
||||
from email_validator import EmailNotValidError, validate_email
|
||||
|
||||
result = validate_email(email, check_deliverability=False)
|
||||
return result.normalized
|
||||
except EmailNotValidError as e:
|
||||
error_msg = str(e)
|
||||
|
||||
# Check if the error is specifically about special-use/reserved domains or localhost
|
||||
if (
|
||||
"special-use" in error_msg.lower()
|
||||
or "reserved" in error_msg.lower()
|
||||
or "should have a period" in error_msg.lower()
|
||||
):
|
||||
# Perform basic email syntax validation
|
||||
email = email.strip().lower()
|
||||
|
||||
if "@" not in email:
|
||||
raise PydanticCustomError(
|
||||
"value_error",
|
||||
"Email address must contain an @ symbol",
|
||||
)
|
||||
|
||||
local_part, domain = email.rsplit("@", 1)
|
||||
|
||||
if not local_part or not domain:
|
||||
raise PydanticCustomError(
|
||||
"value_error",
|
||||
"Email address must have both local and domain parts",
|
||||
)
|
||||
|
||||
# Allow localhost and domains with dots
|
||||
if domain == "localhost" or "." in domain:
|
||||
return email
|
||||
|
||||
raise PydanticCustomError(
|
||||
"value_error",
|
||||
"Email domain must contain a dot or be 'localhost'",
|
||||
)
|
||||
else:
|
||||
# Re-raise other validation errors
|
||||
raise PydanticCustomError(
|
||||
"value_error",
|
||||
f"Invalid email address: {error_msg}",
|
||||
)
|
||||
|
||||
|
||||
class UserDTO(BaseModel):
|
||||
"""User data transfer object."""
|
||||
|
||||
user_id: str = Field(description="Unique user identifier")
|
||||
email: str = Field(description="User email address")
|
||||
display_name: str | None = Field(default=None, description="Display name")
|
||||
is_admin: bool = Field(default=False, description="Whether user has admin privileges")
|
||||
is_active: bool = Field(default=True, description="Whether user account is active")
|
||||
created_at: datetime = Field(description="When the user was created")
|
||||
updated_at: datetime = Field(description="When the user was last updated")
|
||||
last_login_at: datetime | None = Field(default=None, description="When user last logged in")
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, v: str) -> str:
|
||||
"""Validate email address, allowing special-use domains."""
|
||||
return validate_email_with_special_domains(v)
|
||||
|
||||
|
||||
class UserCreateRequest(BaseModel):
|
||||
"""Request to create a new user."""
|
||||
|
||||
email: str = Field(description="User email address")
|
||||
display_name: str | None = Field(default=None, description="Display name")
|
||||
password: str = Field(description="User password")
|
||||
is_admin: bool = Field(default=False, description="Whether user should have admin privileges")
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, v: str) -> str:
|
||||
"""Validate email address, allowing special-use domains."""
|
||||
return validate_email_with_special_domains(v)
|
||||
|
||||
|
||||
class UserUpdateRequest(BaseModel):
|
||||
"""Request to update a user."""
|
||||
|
||||
display_name: str | None = Field(default=None, description="Display name")
|
||||
password: str | None = Field(default=None, description="New password")
|
||||
is_admin: bool | None = Field(default=None, description="Whether user should have admin privileges")
|
||||
is_active: bool | None = Field(default=None, description="Whether user account should be active")
|
||||
258
invokeai/app/services/users/users_default.py
Normal file
258
invokeai/app/services/users/users_default.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""Default SQLite implementation of user service."""
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from invokeai.app.services.auth.password_utils import hash_password, validate_password_strength, verify_password
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_base import UserServiceBase
|
||||
from invokeai.app.services.users.users_common import UserCreateRequest, UserDTO, UserUpdateRequest
|
||||
|
||||
|
||||
class UserService(UserServiceBase):
|
||||
"""SQLite-based user service."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""Initialize user service.
|
||||
|
||||
Args:
|
||||
db: SQLite database instance
|
||||
"""
|
||||
self._db = db
|
||||
|
||||
def create(self, user_data: UserCreateRequest) -> UserDTO:
|
||||
"""Create a new user."""
|
||||
# Validate password strength
|
||||
is_valid, error_msg = validate_password_strength(user_data.password)
|
||||
if not is_valid:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Check if email already exists
|
||||
if self.get_by_email(user_data.email) is not None:
|
||||
raise ValueError(f"User with email {user_data.email} already exists")
|
||||
|
||||
user_id = str(uuid4())
|
||||
password_hash = hash_password(user_data.password)
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO users (user_id, email, display_name, password_hash, is_admin)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, user_data.email, user_data.display_name, password_hash, user_data.is_admin),
|
||||
)
|
||||
except sqlite3.IntegrityError as e:
|
||||
raise ValueError(f"Failed to create user: {e}") from e
|
||||
|
||||
user = self.get(user_id)
|
||||
if user is None:
|
||||
raise RuntimeError("Failed to retrieve created user")
|
||||
return user
|
||||
|
||||
def get(self, user_id: str) -> UserDTO | None:
|
||||
"""Get user by ID."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT user_id, email, display_name, is_admin, is_active, created_at, updated_at, last_login_at
|
||||
FROM users
|
||||
WHERE user_id = ?
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return UserDTO(
|
||||
user_id=row[0],
|
||||
email=row[1],
|
||||
display_name=row[2],
|
||||
is_admin=bool(row[3]),
|
||||
is_active=bool(row[4]),
|
||||
created_at=datetime.fromisoformat(row[5]),
|
||||
updated_at=datetime.fromisoformat(row[6]),
|
||||
last_login_at=datetime.fromisoformat(row[7]) if row[7] else None,
|
||||
)
|
||||
|
||||
def get_by_email(self, email: str) -> UserDTO | None:
|
||||
"""Get user by email."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT user_id, email, display_name, is_admin, is_active, created_at, updated_at, last_login_at
|
||||
FROM users
|
||||
WHERE email = ?
|
||||
""",
|
||||
(email,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return UserDTO(
|
||||
user_id=row[0],
|
||||
email=row[1],
|
||||
display_name=row[2],
|
||||
is_admin=bool(row[3]),
|
||||
is_active=bool(row[4]),
|
||||
created_at=datetime.fromisoformat(row[5]),
|
||||
updated_at=datetime.fromisoformat(row[6]),
|
||||
last_login_at=datetime.fromisoformat(row[7]) if row[7] else None,
|
||||
)
|
||||
|
||||
def update(self, user_id: str, changes: UserUpdateRequest) -> UserDTO:
|
||||
"""Update user."""
|
||||
# Check if user exists
|
||||
user = self.get(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f"User {user_id} not found")
|
||||
|
||||
# Validate password if provided
|
||||
if changes.password is not None:
|
||||
is_valid, error_msg = validate_password_strength(changes.password)
|
||||
if not is_valid:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Build update query dynamically based on provided fields
|
||||
updates: list[str] = []
|
||||
params: list[str | bool | int] = []
|
||||
|
||||
if changes.display_name is not None:
|
||||
updates.append("display_name = ?")
|
||||
params.append(changes.display_name)
|
||||
|
||||
if changes.password is not None:
|
||||
updates.append("password_hash = ?")
|
||||
params.append(hash_password(changes.password))
|
||||
|
||||
if changes.is_admin is not None:
|
||||
updates.append("is_admin = ?")
|
||||
params.append(changes.is_admin)
|
||||
|
||||
if changes.is_active is not None:
|
||||
updates.append("is_active = ?")
|
||||
params.append(changes.is_active)
|
||||
|
||||
if not updates:
|
||||
return user
|
||||
|
||||
params.append(user_id)
|
||||
query = f"UPDATE users SET {', '.join(updates)} WHERE user_id = ?"
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(query, params)
|
||||
|
||||
updated_user = self.get(user_id)
|
||||
if updated_user is None:
|
||||
raise RuntimeError("Failed to retrieve updated user")
|
||||
return updated_user
|
||||
|
||||
def delete(self, user_id: str) -> None:
|
||||
"""Delete user."""
|
||||
user = self.get(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f"User {user_id} not found")
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute("DELETE FROM users WHERE user_id = ?", (user_id,))
|
||||
|
||||
def authenticate(self, email: str, password: str) -> UserDTO | None:
|
||||
"""Authenticate user credentials."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT user_id, email, display_name, password_hash, is_admin, is_active, created_at, updated_at, last_login_at
|
||||
FROM users
|
||||
WHERE email = ?
|
||||
""",
|
||||
(email,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
password_hash = row[3]
|
||||
if not verify_password(password, password_hash):
|
||||
return None
|
||||
|
||||
# Update last login time
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"UPDATE users SET last_login_at = ? WHERE user_id = ?",
|
||||
(datetime.now(timezone.utc).isoformat(), row[0]),
|
||||
)
|
||||
|
||||
return UserDTO(
|
||||
user_id=row[0],
|
||||
email=row[1],
|
||||
display_name=row[2],
|
||||
is_admin=bool(row[4]),
|
||||
is_active=bool(row[5]),
|
||||
created_at=datetime.fromisoformat(row[6]),
|
||||
updated_at=datetime.fromisoformat(row[7]),
|
||||
last_login_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
def has_admin(self) -> bool:
|
||||
"""Check if any admin user exists."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users WHERE is_admin = TRUE AND is_active = TRUE")
|
||||
row = cursor.fetchone()
|
||||
count = row[0] if row else 0
|
||||
return bool(count > 0)
|
||||
|
||||
def create_admin(self, user_data: UserCreateRequest) -> UserDTO:
|
||||
"""Create an admin user (for initial setup)."""
|
||||
if self.has_admin():
|
||||
raise ValueError("Admin user already exists")
|
||||
|
||||
# Force is_admin to True
|
||||
admin_data = UserCreateRequest(
|
||||
email=user_data.email,
|
||||
display_name=user_data.display_name,
|
||||
password=user_data.password,
|
||||
is_admin=True,
|
||||
)
|
||||
return self.create(admin_data)
|
||||
|
||||
def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
|
||||
"""List all users."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT user_id, email, display_name, is_admin, is_active, created_at, updated_at, last_login_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
return [
|
||||
UserDTO(
|
||||
user_id=row[0],
|
||||
email=row[1],
|
||||
display_name=row[2],
|
||||
is_admin=bool(row[3]),
|
||||
is_active=bool(row[4]),
|
||||
created_at=datetime.fromisoformat(row[5]),
|
||||
updated_at=datetime.fromisoformat(row[6]),
|
||||
last_login_at=datetime.fromisoformat(row[7]) if row[7] else None,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def count_admins(self) -> int:
|
||||
"""Count active admin users."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM users WHERE is_admin = TRUE AND is_active = TRUE")
|
||||
row = cursor.fetchone()
|
||||
return int(row[0]) if row else 0
|
||||
579
invokeai/app/util/user_management.py
Normal file
579
invokeai/app/util/user_management.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""User management command entry points for InvokeAI.
|
||||
|
||||
These functions are registered as console scripts in pyproject.toml and can be
|
||||
called from the command line after installing the package:
|
||||
|
||||
invoke-useradd -- add a user
|
||||
invoke-userdel -- delete a user
|
||||
invoke-userlist -- list users
|
||||
invoke-usermod -- modify a user
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
_root_help = (
|
||||
"Path to the InvokeAI root directory. If omitted, the root is resolved in this order: "
|
||||
"the $INVOKEAI_ROOT environment variable, the active virtual environment's parent directory, "
|
||||
"or $HOME/invokeai."
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# useradd
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _add_user_interactive() -> bool:
|
||||
"""Add a user interactively by prompting for details."""
|
||||
from invokeai.app.services.auth.password_utils import validate_password_strength
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_common import UserCreateRequest
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
print("=== Add InvokeAI User ===\n")
|
||||
|
||||
email = input("Email address: ").strip()
|
||||
if not email:
|
||||
print("Error: Email is required")
|
||||
return False
|
||||
|
||||
display_name = input("Display name (optional): ").strip() or None
|
||||
|
||||
while True:
|
||||
password = getpass.getpass("Password: ")
|
||||
password_confirm = getpass.getpass("Confirm password: ")
|
||||
|
||||
if password != password_confirm:
|
||||
print("Error: Passwords do not match. Please try again.\n")
|
||||
continue
|
||||
|
||||
is_valid, error_msg = validate_password_strength(password)
|
||||
if not is_valid:
|
||||
print(f"Error: {error_msg}\n")
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
is_admin_input = input("Make this user an administrator? (y/N): ").strip().lower()
|
||||
is_admin = is_admin_input in ("y", "yes")
|
||||
|
||||
try:
|
||||
config = get_config()
|
||||
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
|
||||
user_service = UserService(db)
|
||||
|
||||
user_data = UserCreateRequest(email=email, display_name=display_name, password=password, is_admin=is_admin)
|
||||
user = user_service.create(user_data)
|
||||
|
||||
print("\n✅ User created successfully!")
|
||||
print(f" User ID: {user.user_id}")
|
||||
print(f" Email: {user.email}")
|
||||
print(f" Display Name: {user.display_name or '(not set)'}")
|
||||
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
|
||||
print(f" Active: {'Yes' if user.is_active else 'No'}")
|
||||
return True
|
||||
|
||||
except ValueError as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"\n❌ Unexpected error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def _add_user_cli(email: str, password: str, display_name: str | None = None, is_admin: bool = False) -> bool:
|
||||
"""Add a user via CLI arguments."""
|
||||
from invokeai.app.services.auth.password_utils import validate_password_strength
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_common import UserCreateRequest
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
is_valid, error_msg = validate_password_strength(password)
|
||||
if not is_valid:
|
||||
print(f"❌ Password validation failed: {error_msg}")
|
||||
return False
|
||||
|
||||
try:
|
||||
config = get_config()
|
||||
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
|
||||
user_service = UserService(db)
|
||||
|
||||
user_data = UserCreateRequest(email=email, display_name=display_name, password=password, is_admin=is_admin)
|
||||
user = user_service.create(user_data)
|
||||
|
||||
print("✅ User created successfully!")
|
||||
print(f" User ID: {user.user_id}")
|
||||
print(f" Email: {user.email}")
|
||||
print(f" Display Name: {user.display_name or '(not set)'}")
|
||||
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
|
||||
print(f" Active: {'Yes' if user.is_active else 'No'}")
|
||||
return True
|
||||
|
||||
except ValueError as e:
|
||||
print(f"❌ Error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Unexpected error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def useradd() -> None:
|
||||
"""Entry point for ``invoke-useradd``."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Add a user to the InvokeAI database",
|
||||
epilog="If no arguments are provided, the script will run in interactive mode.",
|
||||
)
|
||||
parser.add_argument("--root", "-r", help=_root_help)
|
||||
parser.add_argument("--email", "-e", help="User email address")
|
||||
parser.add_argument("--password", "-p", help="User password")
|
||||
parser.add_argument("--name", "-n", help="User display name (optional)")
|
||||
parser.add_argument("--admin", "-a", action="store_true", help="Make user an administrator")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.root:
|
||||
os.environ["INVOKEAI_ROOT"] = args.root
|
||||
|
||||
if args.email or args.password:
|
||||
if not args.email or not args.password:
|
||||
print("❌ Error: Both --email and --password are required when using CLI mode")
|
||||
print(" Run without arguments for interactive mode")
|
||||
sys.exit(1)
|
||||
success = _add_user_cli(args.email, args.password, args.name, args.admin)
|
||||
else:
|
||||
success = _add_user_interactive()
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# userdel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _delete_user_interactive() -> bool:
|
||||
"""Delete a user interactively by prompting for email."""
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
print("=== Delete InvokeAI User ===\n")
|
||||
|
||||
email = input("Email address of user to delete: ").strip()
|
||||
if not email:
|
||||
print("Error: Email is required")
|
||||
return False
|
||||
|
||||
try:
|
||||
config = get_config()
|
||||
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
|
||||
user_service = UserService(db)
|
||||
|
||||
user = user_service.get_by_email(email)
|
||||
if not user:
|
||||
print(f"\n❌ Error: No user found with email '{email}'")
|
||||
return False
|
||||
|
||||
print("\nUser to delete:")
|
||||
print(f" User ID: {user.user_id}")
|
||||
print(f" Email: {user.email}")
|
||||
print(f" Display Name: {user.display_name or '(not set)'}")
|
||||
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
|
||||
print(f" Active: {'Yes' if user.is_active else 'No'}")
|
||||
|
||||
confirm = input("\n⚠️ Are you sure you want to delete this user? (yes/no): ").strip().lower()
|
||||
if confirm not in ("yes", "y"):
|
||||
print("Deletion cancelled.")
|
||||
return False
|
||||
|
||||
user_service.delete(user.user_id)
|
||||
print("\n✅ User deleted successfully!")
|
||||
return True
|
||||
|
||||
except ValueError as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"\n❌ Unexpected error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def _delete_user_cli(email: str, force: bool = False) -> bool:
|
||||
"""Delete a user via CLI arguments."""
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
try:
|
||||
config = get_config()
|
||||
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
|
||||
user_service = UserService(db)
|
||||
|
||||
user = user_service.get_by_email(email)
|
||||
if not user:
|
||||
print(f"❌ Error: No user found with email '{email}'")
|
||||
return False
|
||||
|
||||
if not force:
|
||||
print("User to delete:")
|
||||
print(f" User ID: {user.user_id}")
|
||||
print(f" Email: {user.email}")
|
||||
print(f" Display Name: {user.display_name or '(not set)'}")
|
||||
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
|
||||
print(f" Active: {'Yes' if user.is_active else 'No'}")
|
||||
|
||||
confirm = input("\n⚠️ Are you sure you want to delete this user? (yes/no): ").strip().lower()
|
||||
if confirm not in ("yes", "y"):
|
||||
print("Deletion cancelled.")
|
||||
return False
|
||||
|
||||
user_service.delete(user.user_id)
|
||||
print("✅ User deleted successfully!")
|
||||
return True
|
||||
|
||||
except ValueError as e:
|
||||
print(f"❌ Error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Unexpected error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def userdel() -> None:
|
||||
"""Entry point for ``invoke-userdel``."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Delete a user from the InvokeAI database",
|
||||
epilog="If no arguments are provided, the script will run in interactive mode.",
|
||||
)
|
||||
parser.add_argument("--root", "-r", help=_root_help)
|
||||
parser.add_argument("--email", "-e", help="User email address")
|
||||
parser.add_argument("--force", "-f", action="store_true", help="Delete without confirmation prompt")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.root:
|
||||
os.environ["INVOKEAI_ROOT"] = args.root
|
||||
|
||||
if args.email:
|
||||
success = _delete_user_cli(args.email, args.force)
|
||||
else:
|
||||
success = _delete_user_interactive()
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# userlist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _list_users_table() -> bool:
|
||||
"""List all users in a formatted table."""
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = SqliteDatabase(config.db_path, logger)
|
||||
user_service = UserService(db)
|
||||
|
||||
try:
|
||||
users = user_service.list_users()
|
||||
|
||||
if not users:
|
||||
print("No users found in database.")
|
||||
return True
|
||||
|
||||
print("\n=== InvokeAI Users ===\n")
|
||||
print(f"{'User ID':<36} {'Email':<30} {'Display Name':<20} {'Admin':<8} {'Active':<8}")
|
||||
print("-" * 108)
|
||||
|
||||
for user in users:
|
||||
user_id = user.user_id
|
||||
email = user.email[:29] if len(user.email) > 29 else user.email
|
||||
raw_name = user.display_name or ""
|
||||
name = raw_name[:19] if len(raw_name) > 19 else raw_name
|
||||
is_admin = "Yes" if user.is_admin else "No"
|
||||
is_active = "Yes" if user.is_active else "No"
|
||||
print(f"{user_id:<36} {email:<30} {name:<20} {is_admin:<8} {is_active:<8}")
|
||||
|
||||
print(f"\nTotal users: {len(users)}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error listing users: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _list_users_json() -> bool:
|
||||
"""List all users in JSON format."""
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = SqliteDatabase(config.db_path, logger)
|
||||
user_service = UserService(db)
|
||||
|
||||
try:
|
||||
users = user_service.list_users()
|
||||
|
||||
users_data = [
|
||||
{
|
||||
"id": user.user_id,
|
||||
"email": user.email,
|
||||
"name": user.display_name,
|
||||
"is_admin": user.is_admin,
|
||||
"is_active": user.is_active,
|
||||
}
|
||||
for user in users
|
||||
]
|
||||
|
||||
print(json.dumps(users_data, indent=2))
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f'{{"error": "{e}"}}', file=sys.stderr)
|
||||
return False
|
||||
|
||||
|
||||
def userlist() -> None:
|
||||
"""Entry point for ``invoke-userlist``."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="List users from the InvokeAI database",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
invoke-userlist
|
||||
invoke-userlist --json
|
||||
""",
|
||||
)
|
||||
parser.add_argument("--root", "-r", help=_root_help)
|
||||
parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
help="Output users in JSON format instead of table",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.root:
|
||||
os.environ["INVOKEAI_ROOT"] = args.root
|
||||
|
||||
success = _list_users_json() if args.json else _list_users_table()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# usermod
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _modify_user_interactive() -> bool:
|
||||
"""Modify a user interactively by prompting for details."""
|
||||
from invokeai.app.services.auth.password_utils import validate_password_strength
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_common import UserUpdateRequest
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
print("=== Modify InvokeAI User ===\n")
|
||||
|
||||
email = input("Email address of user to modify: ").strip()
|
||||
if not email:
|
||||
print("Error: Email is required")
|
||||
return False
|
||||
|
||||
try:
|
||||
config = get_config()
|
||||
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
|
||||
user_service = UserService(db)
|
||||
|
||||
user = user_service.get_by_email(email)
|
||||
if not user:
|
||||
print(f"\n❌ Error: No user found with email '{email}'")
|
||||
return False
|
||||
|
||||
print("\nCurrent user details:")
|
||||
print(f" User ID: {user.user_id}")
|
||||
print(f" Email: {user.email}")
|
||||
print(f" Display Name: {user.display_name or '(not set)'}")
|
||||
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
|
||||
print(f" Active: {'Yes' if user.is_active else 'No'}")
|
||||
|
||||
print("\n--- What would you like to change? (leave blank to keep current value) ---\n")
|
||||
|
||||
new_name = input(f"New display name [{user.display_name or '(not set)'}]: ").strip()
|
||||
display_name = new_name if new_name else None
|
||||
|
||||
change_password = input("Change password? (y/N): ").strip().lower()
|
||||
password = None
|
||||
if change_password in ("y", "yes"):
|
||||
while True:
|
||||
password = getpass.getpass("New password: ")
|
||||
if not password:
|
||||
print("Keeping existing password.")
|
||||
password = None
|
||||
break
|
||||
|
||||
password_confirm = getpass.getpass("Confirm new password: ")
|
||||
|
||||
if password != password_confirm:
|
||||
print("Error: Passwords do not match. Please try again.\n")
|
||||
continue
|
||||
|
||||
is_valid, error_msg = validate_password_strength(password)
|
||||
if not is_valid:
|
||||
print(f"Error: {error_msg}\n")
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
change_admin = input("Change admin status? (y/N): ").strip().lower()
|
||||
is_admin = None
|
||||
if change_admin in ("y", "yes"):
|
||||
is_admin_input = (
|
||||
input(f"Make administrator? [current: {'Yes' if user.is_admin else 'No'}] (y/N): ").strip().lower()
|
||||
)
|
||||
is_admin = is_admin_input in ("y", "yes")
|
||||
|
||||
if display_name is None and password is None and is_admin is None:
|
||||
print("\nNo changes requested. User not modified.")
|
||||
return True
|
||||
|
||||
changes = UserUpdateRequest(display_name=display_name, password=password, is_admin=is_admin)
|
||||
updated_user = user_service.update(user.user_id, changes)
|
||||
|
||||
print("\n✅ User updated successfully!")
|
||||
print(f" User ID: {updated_user.user_id}")
|
||||
print(f" Email: {updated_user.email}")
|
||||
print(f" Display Name: {updated_user.display_name or '(not set)'}")
|
||||
print(f" Admin: {'Yes' if updated_user.is_admin else 'No'}")
|
||||
print(f" Active: {'Yes' if updated_user.is_active else 'No'}")
|
||||
return True
|
||||
|
||||
except ValueError as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"\n❌ Unexpected error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def _modify_user_cli(
|
||||
email: str,
|
||||
display_name: str | None = None,
|
||||
password: str | None = None,
|
||||
is_admin: bool | None = None,
|
||||
) -> bool:
|
||||
"""Modify a user via CLI arguments."""
|
||||
from invokeai.app.services.auth.password_utils import validate_password_strength
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_common import UserUpdateRequest
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
if password is not None:
|
||||
is_valid, error_msg = validate_password_strength(password)
|
||||
if not is_valid:
|
||||
print(f"❌ Password validation failed: {error_msg}")
|
||||
return False
|
||||
|
||||
try:
|
||||
config = get_config()
|
||||
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
|
||||
user_service = UserService(db)
|
||||
|
||||
user = user_service.get_by_email(email)
|
||||
if not user:
|
||||
print(f"❌ Error: No user found with email '{email}'")
|
||||
return False
|
||||
|
||||
if display_name is None and password is None and is_admin is None:
|
||||
print("❌ Error: No changes specified. Use --name, --password, --admin, or --no-admin")
|
||||
return False
|
||||
|
||||
changes = UserUpdateRequest(display_name=display_name, password=password, is_admin=is_admin)
|
||||
updated_user = user_service.update(user.user_id, changes)
|
||||
|
||||
print("✅ User updated successfully!")
|
||||
print(f" User ID: {updated_user.user_id}")
|
||||
print(f" Email: {updated_user.email}")
|
||||
print(f" Display Name: {updated_user.display_name or '(not set)'}")
|
||||
print(f" Admin: {'Yes' if updated_user.is_admin else 'No'}")
|
||||
print(f" Active: {'Yes' if updated_user.is_active else 'No'}")
|
||||
return True
|
||||
|
||||
except ValueError as e:
|
||||
print(f"❌ Error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Unexpected error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def usermod() -> None:
|
||||
"""Entry point for ``invoke-usermod``."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Modify a user in the InvokeAI database",
|
||||
epilog="If no arguments are provided, the script will run in interactive mode.",
|
||||
)
|
||||
parser.add_argument("--root", "-r", help=_root_help)
|
||||
parser.add_argument("--email", "-e", help="User email address")
|
||||
parser.add_argument("--name", "-n", help="New display name")
|
||||
parser.add_argument("--password", "-p", help="New password")
|
||||
|
||||
admin_group = parser.add_mutually_exclusive_group()
|
||||
admin_group.add_argument("--admin", "-a", action="store_true", help="Grant administrator privileges")
|
||||
admin_group.add_argument("--no-admin", dest="no_admin", action="store_true", help="Remove administrator privileges")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.root:
|
||||
os.environ["INVOKEAI_ROOT"] = args.root
|
||||
|
||||
is_admin = None
|
||||
if args.admin:
|
||||
is_admin = True
|
||||
elif args.no_admin:
|
||||
is_admin = False
|
||||
|
||||
if args.email:
|
||||
success = _modify_user_cli(args.email, args.name, args.password, is_admin)
|
||||
else:
|
||||
success = _modify_user_interactive()
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -239,6 +239,52 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
|
||||
if in_dim is not None:
|
||||
return in_dim in _FLUX2_VEC_IN_DIMS
|
||||
|
||||
# Kohya format: check transformer block dimensions (hidden_size and MLP ratio).
|
||||
# This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder).
|
||||
# Klein 9B has hidden_size=4096 (vs 3072 for FLUX.1 and Klein 4B).
|
||||
# Klein 4B has same hidden_size as FLUX.1 (3072) but different mlp_ratio (6 vs 4).
|
||||
kohya_hidden_size: int | None = None
|
||||
for key in state_dict:
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
if not key.startswith("lora_unet_"):
|
||||
continue
|
||||
|
||||
# Check img_attn_proj hidden_size
|
||||
if "_img_attn_proj." in key and key.endswith("lora_down.weight"):
|
||||
kohya_hidden_size = state_dict[key].shape[1]
|
||||
if kohya_hidden_size != _FLUX1_HIDDEN_SIZE:
|
||||
return True
|
||||
break
|
||||
# LoKR variant
|
||||
elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
|
||||
layer_prefix = key.rsplit(".", 1)[0]
|
||||
in_dim = _lokr_in_dim(state_dict, layer_prefix)
|
||||
if in_dim is not None:
|
||||
if in_dim != _FLUX1_HIDDEN_SIZE:
|
||||
return True
|
||||
kohya_hidden_size = in_dim
|
||||
break
|
||||
|
||||
# Kohya format: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
|
||||
# Klein 4B uses mlp_ratio=6 (ffn_dim=18432), FLUX.1 uses mlp_ratio=4 (ffn_dim=12288).
|
||||
if kohya_hidden_size == _FLUX1_HIDDEN_SIZE:
|
||||
for key in state_dict:
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith("lora_up.weight"):
|
||||
ffn_dim = state_dict[key].shape[0]
|
||||
if ffn_dim != kohya_hidden_size * _FLUX1_MLP_RATIO:
|
||||
return True
|
||||
break
|
||||
# LoKR variant
|
||||
if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith((".lokr_w1", ".lokr_w1_a")):
|
||||
layer_prefix = key.rsplit(".", 1)[0]
|
||||
out_dim = _lokr_out_dim(state_dict, layer_prefix)
|
||||
if out_dim is not None and out_dim != kohya_hidden_size * _FLUX1_MLP_RATIO:
|
||||
return True
|
||||
break
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -421,6 +467,33 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
|
||||
return Flux2VariantType.Klein9B
|
||||
return None
|
||||
|
||||
# Kohya format: check transformer block dimensions (hidden_size from img_attn_proj).
|
||||
# This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder).
|
||||
for key in state_dict:
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
if not key.startswith("lora_unet_"):
|
||||
continue
|
||||
|
||||
# Check img_attn_proj hidden_size
|
||||
if "_img_attn_proj." in key and key.endswith("lora_down.weight"):
|
||||
dim = state_dict[key].shape[1]
|
||||
if dim == KLEIN_4B_HIDDEN_SIZE:
|
||||
return Flux2VariantType.Klein4B
|
||||
if dim == KLEIN_9B_HIDDEN_SIZE:
|
||||
return Flux2VariantType.Klein9B
|
||||
return None
|
||||
# LoKR variant
|
||||
elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
|
||||
layer_prefix = key.rsplit(".", 1)[0]
|
||||
in_dim = _lokr_in_dim(state_dict, layer_prefix)
|
||||
if in_dim is not None:
|
||||
if in_dim == KLEIN_4B_HIDDEN_SIZE:
|
||||
return Flux2VariantType.Klein4B
|
||||
if in_dim == KLEIN_9B_HIDDEN_SIZE:
|
||||
return Flux2VariantType.Klein9B
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -253,16 +253,35 @@ class ZImageCheckpointModel(ModelLoader):
|
||||
target_device = TorchDevice.choose_torch_device()
|
||||
model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
|
||||
|
||||
# Filter out keys that don't belong to the ZImageTransformer2DModel.
|
||||
# Merged checkpoints (e.g. LoRA-baked models) may bundle text encoder weights
|
||||
# (text_encoders.*) or other non-transformer keys alongside the transformer weights.
|
||||
# Also filter FP8 quantization metadata (scale_weight, scaled_fp8).
|
||||
valid_prefixes = (
|
||||
"all_x_embedder.",
|
||||
"all_final_layer.",
|
||||
"layers.",
|
||||
"noise_refiner.",
|
||||
"context_refiner.",
|
||||
"t_embedder.",
|
||||
"cap_embedder.",
|
||||
"rope_embedder.",
|
||||
)
|
||||
valid_exact = {"x_pad_token", "cap_pad_token"}
|
||||
keys_to_remove = [
|
||||
k
|
||||
for k in sd.keys()
|
||||
if not (k.startswith(valid_prefixes) or k in valid_exact)
|
||||
or k.endswith(".scale_weight")
|
||||
or k == "scaled_fp8"
|
||||
]
|
||||
for k in keys_to_remove:
|
||||
del sd[k]
|
||||
|
||||
# Handle memory management and dtype conversion
|
||||
new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
|
||||
self._ram_cache.make_room(new_sd_size)
|
||||
|
||||
# Filter out FP8 scale_weight and scaled_fp8 metadata keys
|
||||
# These are quantization metadata that shouldn't be loaded into the model
|
||||
keys_to_remove = [k for k in sd.keys() if k.endswith(".scale_weight") or k == "scaled_fp8"]
|
||||
for k in keys_to_remove:
|
||||
del sd[k]
|
||||
|
||||
# Convert to target dtype
|
||||
for k in sd.keys():
|
||||
sd[k] = sd[k].to(model_dtype)
|
||||
|
||||
@@ -821,7 +821,7 @@ z_image_turbo = StarterModel(
|
||||
name="Z-Image Turbo",
|
||||
base=BaseModelType.ZImage,
|
||||
source="Tongyi-MAI/Z-Image-Turbo",
|
||||
description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~13GB",
|
||||
description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~30.6GB",
|
||||
type=ModelType.Main,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,6 +15,9 @@ const config: KnipConfig = {
|
||||
// Will be using this
|
||||
'src/common/hooks/useAsyncState.ts',
|
||||
'src/app/store/use-debounced-app-selector.ts',
|
||||
// Auth features - exports will be used in follow-up phases
|
||||
'src/features/auth/**',
|
||||
'src/services/api/endpoints/auth.ts',
|
||||
],
|
||||
ignoreBinaries: ['only-allow'],
|
||||
ignoreDependencies: ['magic-string'],
|
||||
|
||||
@@ -53481,6 +53481,36 @@
|
||||
}
|
||||
],
|
||||
"description": "The workflow associated with this queue item"
|
||||
},
|
||||
"user_id": {
|
||||
"type": "string",
|
||||
"title": "User Id",
|
||||
"description": "The id of the user who created this queue item",
|
||||
"default": "system"
|
||||
},
|
||||
"user_display_name": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "User Display Name",
|
||||
"description": "The display name of the user who created this queue item, if available"
|
||||
},
|
||||
"user_email": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "User Email",
|
||||
"description": "The email of the user who created this queue item, if available"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -53571,6 +53601,30 @@
|
||||
"type": "integer",
|
||||
"title": "Total",
|
||||
"description": "Total number of queue items"
|
||||
},
|
||||
"user_pending": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "User Pending",
|
||||
"description": "Number of queue items with status 'pending' for the current user"
|
||||
},
|
||||
"user_in_progress": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "User In Progress",
|
||||
"description": "Number of queue items with status 'in_progress' for the current user"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -60983,6 +61037,45 @@
|
||||
"output": {
|
||||
"$ref": "#/components/schemas/ZImageConditioningOutput"
|
||||
}
|
||||
},
|
||||
"UserDTO": {
|
||||
"type": "object",
|
||||
"required": ["user_id", "email", "is_admin", "is_active"],
|
||||
"properties": {
|
||||
"user_id": {
|
||||
"type": "string",
|
||||
"title": "User Id",
|
||||
"description": "The user ID"
|
||||
},
|
||||
"email": {
|
||||
"type": "string",
|
||||
"title": "Email",
|
||||
"description": "The user email"
|
||||
},
|
||||
"display_name": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Display Name",
|
||||
"description": "The user display name"
|
||||
},
|
||||
"is_admin": {
|
||||
"type": "boolean",
|
||||
"title": "Is Admin",
|
||||
"description": "Whether the user is an admin"
|
||||
},
|
||||
"is_active": {
|
||||
"type": "boolean",
|
||||
"title": "Is Active",
|
||||
"description": "Whether the user is active"
|
||||
}
|
||||
},
|
||||
"title": "UserDTO"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,6 +89,7 @@
|
||||
"react-icons": "^5.5.0",
|
||||
"react-redux": "9.2.0",
|
||||
"react-resizable-panels": "^3.0.3",
|
||||
"react-router-dom": "^7.12.0",
|
||||
"react-textarea-autosize": "^8.5.9",
|
||||
"react-use": "^17.6.0",
|
||||
"react-virtuoso": "^4.13.0",
|
||||
|
||||
45
invokeai/frontend/web/pnpm-lock.yaml
generated
45
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -158,6 +158,9 @@ importers:
|
||||
react-resizable-panels:
|
||||
specifier: ^3.0.3
|
||||
version: 3.0.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
react-router-dom:
|
||||
specifier: ^7.12.0
|
||||
version: 7.12.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
react-textarea-autosize:
|
||||
specifier: ^8.5.9
|
||||
version: 8.5.9(@types/react@18.3.23)(react@18.3.1)
|
||||
@@ -1993,6 +1996,10 @@ packages:
|
||||
convert-source-map@2.0.0:
|
||||
resolution: {integrity: sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==}
|
||||
|
||||
cookie@1.1.1:
|
||||
resolution: {integrity: sha512-ei8Aos7ja0weRpFzJnEA9UHJ/7XQmqglbRwnf2ATjcB9Wq874VKH9kfjjirM6UhU2/E5fFYadylyhFldcqSidQ==}
|
||||
engines: {node: '>=18'}
|
||||
|
||||
copy-to-clipboard@3.3.3:
|
||||
resolution: {integrity: sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==}
|
||||
|
||||
@@ -3459,6 +3466,23 @@ packages:
|
||||
react: ^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc
|
||||
react-dom: ^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc
|
||||
|
||||
react-router-dom@7.12.0:
|
||||
resolution: {integrity: sha512-pfO9fiBcpEfX4Tx+iTYKDtPbrSLLCbwJ5EqP+SPYQu1VYCXdy79GSj0wttR0U4cikVdlImZuEZ/9ZNCgoaxwBA==}
|
||||
engines: {node: '>=20.0.0'}
|
||||
peerDependencies:
|
||||
react: '>=18'
|
||||
react-dom: '>=18'
|
||||
|
||||
react-router@7.12.0:
|
||||
resolution: {integrity: sha512-kTPDYPFzDVGIIGNLS5VJykK0HfHLY5MF3b+xj0/tTyNYL1gF1qs7u67Z9jEhQk2sQ98SUaHxlG31g1JtF7IfVw==}
|
||||
engines: {node: '>=20.0.0'}
|
||||
peerDependencies:
|
||||
react: '>=18'
|
||||
react-dom: '>=18'
|
||||
peerDependenciesMeta:
|
||||
react-dom:
|
||||
optional: true
|
||||
|
||||
react-select@5.10.2:
|
||||
resolution: {integrity: sha512-Z33nHdEFWq9tfnfVXaiM12rbJmk+QjFEztWLtmXqQhz6Al4UZZ9xc0wiatmGtUOCCnHN0WizL3tCMYRENX4rVQ==}
|
||||
peerDependencies:
|
||||
@@ -3675,6 +3699,9 @@ packages:
|
||||
resolution: {integrity: sha512-ZYkZLAvKTKQXWuh5XpBw7CdbSzagarX39WyZ2H07CDLC5/KfsRGlIXV8d4+tfqX1M7916mRqR1QfNHSij+c9Pw==}
|
||||
engines: {node: '>=18'}
|
||||
|
||||
set-cookie-parser@2.7.2:
|
||||
resolution: {integrity: sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==}
|
||||
|
||||
set-function-length@1.2.2:
|
||||
resolution: {integrity: sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==}
|
||||
engines: {node: '>= 0.4'}
|
||||
@@ -6120,6 +6147,8 @@ snapshots:
|
||||
|
||||
convert-source-map@2.0.0: {}
|
||||
|
||||
cookie@1.1.1: {}
|
||||
|
||||
copy-to-clipboard@3.3.3:
|
||||
dependencies:
|
||||
toggle-selection: 1.0.6
|
||||
@@ -7707,6 +7736,20 @@ snapshots:
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
react-router-dom@7.12.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
dependencies:
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
react-router: 7.12.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
|
||||
react-router@7.12.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
dependencies:
|
||||
cookie: 1.1.1
|
||||
react: 18.3.1
|
||||
set-cookie-parser: 2.7.2
|
||||
optionalDependencies:
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
react-select@5.10.2(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
dependencies:
|
||||
'@babel/runtime': 7.28.3
|
||||
@@ -7982,6 +8025,8 @@ snapshots:
|
||||
dependencies:
|
||||
type-fest: 4.41.0
|
||||
|
||||
set-cookie-parser@2.7.2: {}
|
||||
|
||||
set-function-length@1.2.2:
|
||||
dependencies:
|
||||
define-data-property: 1.1.4
|
||||
|
||||
@@ -15,6 +15,95 @@
|
||||
"uploadImage": "Upload Image",
|
||||
"uploadImages": "Upload Image(s)"
|
||||
},
|
||||
"auth": {
|
||||
"login": {
|
||||
"title": "Sign In to InvokeAI",
|
||||
"email": "Email",
|
||||
"emailPlaceholder": "Email",
|
||||
"password": "Password",
|
||||
"passwordPlaceholder": "Password",
|
||||
"rememberMe": "Remember me for 7 days",
|
||||
"signIn": "Sign In",
|
||||
"signingIn": "Signing in...",
|
||||
"loginFailed": "Login failed. Please check your credentials."
|
||||
},
|
||||
"setup": {
|
||||
"title": "Welcome to InvokeAI",
|
||||
"subtitle": "Set up your administrator account to get started",
|
||||
"email": "Email",
|
||||
"emailPlaceholder": "admin@example.com",
|
||||
"emailHelper": "This will be your username for signing in",
|
||||
"displayName": "Display Name",
|
||||
"displayNamePlaceholder": "Administrator",
|
||||
"displayNameHelper": "Your name as it will appear in the application",
|
||||
"password": "Password",
|
||||
"passwordPlaceholder": "Password",
|
||||
"passwordHelper": "Must be at least 8 characters with uppercase, lowercase, and numbers",
|
||||
"passwordTooShort": "Password must be at least 8 characters long",
|
||||
"passwordMissingRequirements": "Password must contain uppercase, lowercase, and numbers",
|
||||
"confirmPassword": "Confirm Password",
|
||||
"confirmPasswordPlaceholder": "Confirm Password",
|
||||
"passwordsDoNotMatch": "Passwords do not match",
|
||||
"createAccount": "Create Administrator Account",
|
||||
"creatingAccount": "Setting up...",
|
||||
"setupFailed": "Setup failed. Please try again."
|
||||
},
|
||||
"userMenu": "User Menu",
|
||||
"admin": "Admin",
|
||||
"logout": "Logout",
|
||||
"adminOnlyFeature": "This feature is only available to administrators.",
|
||||
"profile": {
|
||||
"menuItem": "My Profile",
|
||||
"title": "My Profile",
|
||||
"email": "Email",
|
||||
"emailReadOnly": "Email address cannot be changed",
|
||||
"displayName": "Display Name",
|
||||
"displayNamePlaceholder": "Your name",
|
||||
"changePassword": "Change Password",
|
||||
"currentPassword": "Current Password",
|
||||
"currentPasswordPlaceholder": "Current password",
|
||||
"newPassword": "New Password",
|
||||
"newPasswordPlaceholder": "New password",
|
||||
"confirmPassword": "Confirm New Password",
|
||||
"confirmPasswordPlaceholder": "Confirm new password",
|
||||
"passwordsDoNotMatch": "Passwords do not match",
|
||||
"saveSuccess": "Profile updated successfully",
|
||||
"saveFailed": "Failed to save profile. Please try again."
|
||||
},
|
||||
"userManagement": {
|
||||
"menuItem": "User Management",
|
||||
"title": "User Management",
|
||||
"email": "Email",
|
||||
"emailPlaceholder": "user@example.com",
|
||||
"displayName": "Display Name",
|
||||
"displayNamePlaceholder": "Display name",
|
||||
"password": "Password",
|
||||
"passwordPlaceholder": "Password",
|
||||
"newPassword": "New Password",
|
||||
"newPasswordPlaceholder": "Leave blank to keep current password",
|
||||
"role": "Role",
|
||||
"status": "Status",
|
||||
"actions": "Actions",
|
||||
"isAdmin": "Administrator",
|
||||
"user": "User",
|
||||
"you": "You",
|
||||
"createUser": "Create User",
|
||||
"editUser": "Edit User",
|
||||
"deleteUser": "Delete User",
|
||||
"deleteConfirm": "Are you sure you want to delete \"{{name}}\"? This action cannot be undone.",
|
||||
"generatePassword": "Generate Strong Password",
|
||||
"showPassword": "Show password",
|
||||
"hidePassword": "Hide password",
|
||||
"activate": "Activate",
|
||||
"deactivate": "Deactivate",
|
||||
"saveFailed": "Failed to save user. Please try again.",
|
||||
"deleteFailed": "Failed to delete user. Please try again.",
|
||||
"loadFailed": "Failed to load users.",
|
||||
"back": "Back",
|
||||
"cannotDeleteSelf": "You cannot delete your own account",
|
||||
"cannotDeactivateSelf": "You cannot deactivate your own account"
|
||||
}
|
||||
},
|
||||
"boards": {
|
||||
"addBoard": "Add Board",
|
||||
"addPrivateBoard": "Add Private Board",
|
||||
@@ -272,6 +361,7 @@
|
||||
"cancelTooltip": "Cancel Current Item",
|
||||
"cancelSucceeded": "Item Canceled",
|
||||
"cancelFailed": "Problem Canceling Item",
|
||||
"cancelFailedAccessDenied": "Problem Canceling Item: Access Denied",
|
||||
"retrySucceeded": "Item Retried",
|
||||
"retryFailed": "Problem Retrying Item",
|
||||
"confirm": "Confirm",
|
||||
@@ -283,6 +373,7 @@
|
||||
"clearTooltip": "Cancel and Clear All Items",
|
||||
"clearSucceeded": "Queue Cleared",
|
||||
"clearFailed": "Problem Clearing Queue",
|
||||
"clearFailedAccessDenied": "Problem Clearing Queue: Access Denied",
|
||||
"cancelBatch": "Cancel Batch",
|
||||
"cancelItem": "Cancel Item",
|
||||
"retryItem": "Retry Item",
|
||||
@@ -304,6 +395,7 @@
|
||||
"canceled": "Canceled",
|
||||
"completedIn": "Completed in",
|
||||
"batch": "Batch",
|
||||
"user": "User",
|
||||
"origin": "Origin",
|
||||
"destination": "Dest",
|
||||
"upscaling": "Upscaling",
|
||||
@@ -313,6 +405,8 @@
|
||||
"other": "Other",
|
||||
"gallery": "Gallery",
|
||||
"batchFieldValues": "Batch Field Values",
|
||||
"fieldValuesHidden": "<Hidden>",
|
||||
"cannotViewDetails": "You do not have permission to view the details of this queue item",
|
||||
"item": "Item",
|
||||
"session": "Session",
|
||||
"notReady": "Unable to Queue",
|
||||
@@ -1094,6 +1188,7 @@
|
||||
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
|
||||
"queueEmpty": "The install queue is empty.",
|
||||
"selectAll": "Select All",
|
||||
"selectModelToView": "Select a model to view its details",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"t5Encoder": "T5 Encoder",
|
||||
"qwen3Encoder": "Qwen3 Encoder",
|
||||
@@ -1121,7 +1216,15 @@
|
||||
"installingXModels_other": "Installing {{count}} models",
|
||||
"skippingXDuplicates_one": ", skipping {{count}} duplicate",
|
||||
"skippingXDuplicates_other": ", skipping {{count}} duplicates",
|
||||
"manageModels": "Manage Models"
|
||||
"manageModels": "Manage Models",
|
||||
"exportSettings": "Export Settings",
|
||||
"importSettings": "Import Settings",
|
||||
"settingsExported": "Model settings exported",
|
||||
"settingsImported": "Model settings imported",
|
||||
"settingsImportedPartial": "Model settings partially imported. Incompatible settings were skipped: {{fields}}",
|
||||
"settingsImportFailed": "Failed to import model settings",
|
||||
"settingsImportIncompatible": "The settings file contains no compatible settings for this model type",
|
||||
"settingsImportInvalidFile": "Invalid settings file"
|
||||
},
|
||||
"models": {
|
||||
"addLora": "Add LoRA",
|
||||
@@ -1515,6 +1618,8 @@
|
||||
"general": "General",
|
||||
"generation": "Generation",
|
||||
"models": "Models",
|
||||
"preferAttentionStyleNumeric": "Prefer Numeric Attention Style",
|
||||
"prompt": "Prompt",
|
||||
"resetComplete": "Web UI has been reset.",
|
||||
"resetWebUI": "Reset Web UI",
|
||||
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
||||
@@ -2370,10 +2475,14 @@
|
||||
"text": {
|
||||
"font": "Font",
|
||||
"size": "Size",
|
||||
"lineHeight": "Spacing",
|
||||
"lineHeightDense": "Dense",
|
||||
"lineHeightNormal": "Normal",
|
||||
"lineHeightSpacious": "Spacious"
|
||||
"bold": "Bold",
|
||||
"italic": "Italic",
|
||||
"underline": "Underline",
|
||||
"strikethrough": "Strikethrough",
|
||||
"alignLeft": "Align Left",
|
||||
"alignCenter": "Align Center",
|
||||
"alignRight": "Align Right",
|
||||
"px": "px"
|
||||
},
|
||||
"newCanvasFromImage": "New Canvas from Image",
|
||||
"newImg2ImgCanvasFromImage": "New Img2Img from Image",
|
||||
@@ -2538,18 +2647,6 @@
|
||||
"colorPicker": "Color Picker",
|
||||
"text": "Text"
|
||||
},
|
||||
"text": {
|
||||
"font": "Font",
|
||||
"size": "Size",
|
||||
"bold": "Bold",
|
||||
"italic": "Italic",
|
||||
"underline": "Underline",
|
||||
"strikethrough": "Strikethrough",
|
||||
"alignLeft": "Align Left",
|
||||
"alignCenter": "Align Center",
|
||||
"alignRight": "Align Right",
|
||||
"px": "px"
|
||||
},
|
||||
"filter": {
|
||||
"filter": "Filter",
|
||||
"filters": "Filters",
|
||||
|
||||
@@ -822,7 +822,29 @@
|
||||
"orphanedModelsDeleted": "Eliminato con successo {{count}} modello orfano",
|
||||
"orphanedModelsDeleteErrors": "Alcuni modelli non possono essere eliminati",
|
||||
"orphanedModelsDeleteFailed": "Impossibile eliminare i modelli orfani",
|
||||
"errorLoadingOrphanedModels": "Errore durante il caricamento dei modelli orfani. Riprova."
|
||||
"errorLoadingOrphanedModels": "Errore durante il caricamento dei modelli orfani. Riprova.",
|
||||
"pause": "Pausa",
|
||||
"pauseAll": "Metti in pausa tutto",
|
||||
"pauseAllTooltip": "Metti in pausa tutti i download attivi",
|
||||
"resume": "Riprendi",
|
||||
"resumeAll": "Riprendi tutto",
|
||||
"resumeAllTooltip": "Riprendi tutti i download in pausa",
|
||||
"restartFailed": "Riavvio non riuscito",
|
||||
"restartFile": "Riavvia il file",
|
||||
"restartRequired": "Riavvio richiesto",
|
||||
"resumeRefused": "Ripristino rifiutato dal server. Riavvio richiesto.",
|
||||
"backendDisconnected": "Backend disconnesso",
|
||||
"cancelAll": "Annulla tutto",
|
||||
"cancelAllTooltip": "Annulla tutti i download attivi",
|
||||
"selectModelToView": "Seleziona un modello per visualizzarne i dettagli",
|
||||
"exportSettings": "Impostazioni di esportazione",
|
||||
"importSettings": "Impostazioni di importazione",
|
||||
"settingsExported": "Impostazioni del modello esportate",
|
||||
"settingsImported": "Impostazioni del modello importate",
|
||||
"settingsImportedPartial": "Impostazioni del modello parzialmente importate. Le impostazioni incompatibili sono state ignorate: {{fields}}",
|
||||
"settingsImportFailed": "Impossibile importare le impostazioni del modello",
|
||||
"settingsImportIncompatible": "Il file delle impostazioni non contiene impostazioni compatibili per questo tipo di modello",
|
||||
"settingsImportInvalidFile": "File di impostazioni non valido"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@@ -993,7 +1015,8 @@
|
||||
"showDetailedInvocationProgress": "Mostra dettagli avanzamento",
|
||||
"enableHighlightFocusedRegions": "Evidenzia le regioni interessate",
|
||||
"modelDescriptionsDisabled": "Descrizioni dei modelli nei menu a discesa disabilitate",
|
||||
"modelDescriptionsDisabledDesc": "Le descrizioni dei modelli nei menu a discesa sono state disattivate. Abilitale nelle Impostazioni."
|
||||
"modelDescriptionsDisabledDesc": "Le descrizioni dei modelli nei menu a discesa sono state disattivate. Abilitale nelle Impostazioni.",
|
||||
"preferAttentionStyleNumeric": "Preferisci lo stile di attenzione numerico"
|
||||
},
|
||||
"toast": {
|
||||
"uploadFailed": "Caricamento fallito",
|
||||
@@ -1098,7 +1121,12 @@
|
||||
"kleinEncoderClearedDescription": "Selezionare un encoder Qwen3 compatibile per la nuova variante del modello Klein",
|
||||
"kleinEncoderCleared": "Encoder Qwen3 cancellato",
|
||||
"schedulerReset": "Ripristino campionatore",
|
||||
"schedulerResetZImageBase": "Il campionatore LCM non è compatibile con i modelli Z-Image Base. Reimpostare su Euler."
|
||||
"schedulerResetZImageBase": "Il campionatore LCM non è compatibile con i modelli Z-Image Base. Reimpostare su Euler.",
|
||||
"modelDownloadPaused": "Download del modello in pausa",
|
||||
"modelDownloadResumed": "Ripresa del download",
|
||||
"modelDownloadRestartFailed": "Riavvia i download non riusciti",
|
||||
"modelDownloadRestartFile": "Riavvio del download del file",
|
||||
"modelDownloadRestartedFromScratch": "Manca una parte del file. Riavviato il download dall'inizio."
|
||||
},
|
||||
"accessibility": {
|
||||
"invokeProgressBar": "Barra di avanzamento generazione",
|
||||
@@ -1357,7 +1385,13 @@
|
||||
"locateInGalery": "Trova nella Galleria",
|
||||
"deletedImagesCannotBeRestored": "Le immagini eliminate non possono essere ripristinate.",
|
||||
"hideBoards": "Nascondi bacheche",
|
||||
"viewBoards": "Visualizza le bacheche"
|
||||
"viewBoards": "Visualizza le bacheche",
|
||||
"pause": "Pausa",
|
||||
"resume": "Riprendi",
|
||||
"restartFailed": "Riavvio non riuscito",
|
||||
"restartFile": "Riavvia il file",
|
||||
"restartRequired": "Riavvio richiesto",
|
||||
"resumeRefused": "Ripristino rifiutato dal server. Riavvio richiesto."
|
||||
},
|
||||
"queue": {
|
||||
"queueFront": "Aggiungi all'inizio della coda",
|
||||
@@ -1449,7 +1483,13 @@
|
||||
"sortOrderDescending": "Discendente",
|
||||
"createdAt": "Creato",
|
||||
"completedAt": "Completato",
|
||||
"batchFieldValues": "Valori del campo Lotto"
|
||||
"batchFieldValues": "Valori del campo Lotto",
|
||||
"paused": "In pausa",
|
||||
"cancelFailedAccessDenied": "Problema durante l'annullamento dell'articolo: accesso negato",
|
||||
"clearFailedAccessDenied": "Problema durante la cancellazione della coda: accesso negato",
|
||||
"user": "Utente",
|
||||
"cannotViewDetails": "Non hai l'autorizzazione per visualizzare i dettagli di questo elemento della coda",
|
||||
"fieldValuesHidden": "<Nascosto>"
|
||||
},
|
||||
"models": {
|
||||
"noMatchingModels": "Nessun modello corrispondente",
|
||||
@@ -2557,7 +2597,8 @@
|
||||
"isEmpty": "{{title}} è vuoto",
|
||||
"isDisabled": "{{title}} è disabilitato"
|
||||
},
|
||||
"scaledBbox": "Riquadro scalato"
|
||||
"scaledBbox": "Riquadro scalato",
|
||||
"textSessionActive": "L'inserimento del testo è attivo"
|
||||
},
|
||||
"canvasContextMenu": {
|
||||
"newControlLayer": "Nuovo Livello di Controllo",
|
||||
@@ -3024,5 +3065,80 @@
|
||||
},
|
||||
"lora": {
|
||||
"weight": "Peso"
|
||||
},
|
||||
"auth": {
|
||||
"login": {
|
||||
"title": "Accedi a InvokeAI",
|
||||
"rememberMe": "Ricordami per 7 giorni",
|
||||
"signIn": "Accedi",
|
||||
"signingIn": "Accesso in corso...",
|
||||
"loginFailed": "Accesso non riuscito. Controlla le tue credenziali."
|
||||
},
|
||||
"setup": {
|
||||
"title": "Benvenuti a InvokeAI",
|
||||
"subtitle": "Configura il tuo account amministratore per iniziare",
|
||||
"emailHelper": "Questo sarà il tuo nome utente per accedere",
|
||||
"displayName": "Nome da visualizzare",
|
||||
"displayNamePlaceholder": "Amministratore",
|
||||
"displayNameHelper": "Il tuo nome come apparirà nell'applicazione",
|
||||
"passwordHelper": "Deve contenere almeno 8 caratteri, tra maiuscole, minuscole e numeri",
|
||||
"passwordTooShort": "La password deve essere lunga almeno 8 caratteri",
|
||||
"passwordMissingRequirements": "La password deve contenere maiuscole, minuscole e numeri",
|
||||
"confirmPassword": "Conferma password",
|
||||
"confirmPasswordPlaceholder": "Conferma password",
|
||||
"passwordsDoNotMatch": "Le password non corrispondono",
|
||||
"createAccount": "Crea un account amministratore",
|
||||
"creatingAccount": "Impostazione in corso...",
|
||||
"setupFailed": "Installazione non riuscita. Riprova."
|
||||
},
|
||||
"userMenu": "Menu utente",
|
||||
"logout": "Esci",
|
||||
"adminOnlyFeature": "Questa funzionalità è disponibile solo per gli amministratori.",
|
||||
"profile": {
|
||||
"menuItem": "Il mio profilo",
|
||||
"title": "Il mio profilo",
|
||||
"emailReadOnly": "L'indirizzo email non può essere modificato",
|
||||
"displayName": "Nome da visualizzare",
|
||||
"displayNamePlaceholder": "Il tuo nome",
|
||||
"changePassword": "Cambiare la password",
|
||||
"currentPassword": "Password attuale",
|
||||
"currentPasswordPlaceholder": "Password attuale",
|
||||
"newPassword": "Nuova password",
|
||||
"newPasswordPlaceholder": "Nuova password",
|
||||
"confirmPassword": "Conferma nuova password",
|
||||
"confirmPasswordPlaceholder": "Conferma nuova password",
|
||||
"passwordsDoNotMatch": "Le password non corrispondono",
|
||||
"saveSuccess": "Profilo aggiornato con successo",
|
||||
"saveFailed": "Impossibile salvare il profilo. Riprova."
|
||||
},
|
||||
"userManagement": {
|
||||
"menuItem": "Gestione utenti",
|
||||
"title": "Gestione utenti",
|
||||
"displayName": "Nome da visualizzare",
|
||||
"displayNamePlaceholder": "Nome da visualizzare",
|
||||
"newPassword": "Nuova password",
|
||||
"newPasswordPlaceholder": "Lasciare vuoto per mantenere la password corrente",
|
||||
"role": "Ruolo",
|
||||
"status": "Stato",
|
||||
"actions": "Azioni",
|
||||
"isAdmin": "Amministratore",
|
||||
"user": "Utente",
|
||||
"you": "Tu",
|
||||
"createUser": "Crea utente",
|
||||
"editUser": "Modifica utente",
|
||||
"deleteUser": "Elimina utente",
|
||||
"deleteConfirm": "Vuoi davvero eliminare \"{{name}}\"? Questa azione non può essere annullata.",
|
||||
"generatePassword": "Genera una password complessa",
|
||||
"showPassword": "Mostra password",
|
||||
"hidePassword": "Nascondi password",
|
||||
"activate": "Attiva",
|
||||
"deactivate": "Disattiva",
|
||||
"saveFailed": "Impossibile salvare l'utente. Riprova.",
|
||||
"deleteFailed": "Impossibile eliminare l'utente. Riprova.",
|
||||
"loadFailed": "Impossibile caricare gli utenti.",
|
||||
"back": "Indietro",
|
||||
"cannotDeleteSelf": "Non puoi eliminare il tuo account",
|
||||
"cannotDeactivateSelf": "Non puoi disattivare il tuo account"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +89,16 @@
|
||||
"none": "Ничего",
|
||||
"new": "Новый",
|
||||
"ok": "Ok",
|
||||
"close": "Закрыть"
|
||||
"close": "Закрыть",
|
||||
"error_withCount_one": "{{count}} Ошибка",
|
||||
"error_withCount_few": "{{count}} Ошибки",
|
||||
"error_withCount_many": "{{count}} Ошибок",
|
||||
"model_withCount_one": "{{count}} Модель",
|
||||
"model_withCount_few": "{{count}} Модели",
|
||||
"model_withCount_many": "{{count}} Моделей",
|
||||
"options_withCount_one": "{{count}} Опция",
|
||||
"options_withCount_few": "{{count}} Опции",
|
||||
"options_withCount_many": "{{count}} Опций"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Размер изображений",
|
||||
@@ -108,7 +117,7 @@
|
||||
"downloadSelection": "Скачать выделенное",
|
||||
"currentlyInUse": "В настоящее время это изображение используется в следующих функциях:",
|
||||
"unstarImage": "Удалить из избранного",
|
||||
"dropOrUpload": "$t(gallery.drop) или загрузить",
|
||||
"dropOrUpload": "Перетащите или загрузите",
|
||||
"copy": "Копировать",
|
||||
"download": "Скачать",
|
||||
"noImageSelected": "Изображение не выбрано",
|
||||
@@ -239,7 +248,7 @@
|
||||
},
|
||||
"filterSelected": {
|
||||
"title": "Filter",
|
||||
"desc": "Filter the selected layer. Only applies to Raster and Control layers."
|
||||
"desc": "Применяет фильтр к выбранному слою. Применимо только к растровым слоям и слоям управления."
|
||||
},
|
||||
"undo": {
|
||||
"desc": "Отменяет последнее действие на холсте.",
|
||||
@@ -483,7 +492,7 @@
|
||||
"deleteMsg1": "Вы точно хотите удалить модель из InvokeAI?",
|
||||
"deleteMsg2": "Это приведет К УДАЛЕНИЮ модели С ДИСКА, если она находится в корневой папке Invoke. Если вы используете пользовательское расположение, то модель НЕ будет удалена с диска.",
|
||||
"convertToDiffusersHelpText5": "Пожалуйста, убедитесь, что у вас достаточно места на диске. Модели обычно занимают 2–7 Гб.",
|
||||
"convertToDiffusersHelpText3": "Ваш файл контрольной точки НА ДИСКЕ будет УДАЛЕН, если он находится в корневой папке InvokeAI. Если он находится в пользовательском расположении, то он НЕ будет удален.",
|
||||
"convertToDiffusersHelpText3": "Файл чекпоинта будет удалён с диска, если он находится в корневой папке InvokeAI. Если файл расположен в пользовательской папке, он удалён не будет.",
|
||||
"allModels": "Все модели",
|
||||
"repo_id": "ID репозитория",
|
||||
"convert": "Преобразовать",
|
||||
@@ -541,7 +550,7 @@
|
||||
"pathToConfig": "Путь к конфигурации",
|
||||
"loraTriggerPhrases": "Триггерные фразы LoRA",
|
||||
"mainModelTriggerPhrases": "Триггерные фразы основной модели",
|
||||
"inplaceInstallDesc": "Устанавливайте модели без копирования файлов. При использовании модели она будет загружаться из этого места. Если этот параметр отключен, файлы модели будут скопированы в каталог моделей, управляемых Invoke, во время установки.",
|
||||
"inplaceInstallDesc": "Устанавливать модели без перемещения файлов. В этом случае модель будет загружаться из исходной папки. Если опция отключена, файлы модели при установке будут перемещены в каталог моделей Invoke.",
|
||||
"huggingFaceRepoID": "ID репозитория HuggingFace",
|
||||
"installQueue": "Очередь установки",
|
||||
"installAll": "Установить все",
|
||||
@@ -575,8 +584,8 @@
|
||||
"skippingXDuplicates_one": ", пропуская {{count}} дубликат",
|
||||
"skippingXDuplicates_few": ", пропуская {{count}} дубликата",
|
||||
"skippingXDuplicates_many": ", пропуская {{count}} дубликатов",
|
||||
"includesNModels": "Включает в себя {{n}} моделей и их зависимостей",
|
||||
"starterBundleHelpText": "Легко установите все модели, необходимые для начала работы с базовой моделью, включая основную модель, сети управления, IP-адаптеры и многое другое. При выборе комплекта все уже установленные модели будут пропущены."
|
||||
"includesNModels": "Включает в себя {{n}} моделей и их зависимостей.",
|
||||
"starterBundleHelpText": "Легко установите все модели, необходимые для начала работы с базовой моделью, включая основную модель, ControlNet, IP-адаптеры и другие. При выборе набора уже установленные модели будут пропущены."
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Изображения",
|
||||
@@ -632,8 +641,8 @@
|
||||
"canvasIsFiltering": "Холст фильтруется",
|
||||
"canvasIsTransforming": "Холст трансформируется",
|
||||
"noCLIPEmbedModelSelected": "Для генерации FLUX не выбрана модель CLIP Embed",
|
||||
"canvasIsRasterizing": "Холст растрируется",
|
||||
"canvasIsCompositing": "Холст составляется"
|
||||
"canvasIsRasterizing": "Холст занят (идёт растеризация)",
|
||||
"canvasIsCompositing": "Холст занят (идёт компоновка)"
|
||||
},
|
||||
"cfgRescaleMultiplier": "Множитель масштабирования CFG",
|
||||
"patchmatchDownScaleSize": "уменьшить",
|
||||
@@ -660,7 +669,10 @@
|
||||
"optimizedImageToImage": "Оптимизированное img2img",
|
||||
"sendToCanvas": "Отправить на холст",
|
||||
"guidance": "Точность",
|
||||
"boxBlur": "Box Blur"
|
||||
"boxBlur": "Box Blur",
|
||||
"images_withCount_one": "Изображение",
|
||||
"images_withCount_few": "Изображения",
|
||||
"images_withCount_many": "Изображений"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Модели",
|
||||
@@ -690,7 +702,7 @@
|
||||
"intermediatesCleared_one": "Очищено {{count}} промежуточное",
|
||||
"intermediatesCleared_few": "Очищено {{count}} промежуточных",
|
||||
"intermediatesCleared_many": "Очищено {{count}} промежуточных",
|
||||
"clearIntermediatesDesc1": "Очистка промежуточных элементов приведет к сбросу состояния Canvas и ControlNet.",
|
||||
"clearIntermediatesDesc1": "Очистка промежуточных данных приведёт к сбросу состояния холста и ControlNet.",
|
||||
"intermediatesClearedFailed": "Проблема очистки промежуточных",
|
||||
"reloadingIn": "Перезагрузка через",
|
||||
"informationalPopoversDisabled": "Информационные всплывающие окна отключены",
|
||||
@@ -704,7 +716,7 @@
|
||||
"serverError": "Ошибка сервера",
|
||||
"connected": "Подключено к серверу",
|
||||
"canceled": "Обработка отменена",
|
||||
"uploadFailedInvalidUploadDesc": "Это должны быть изображения PNG или JPEG.",
|
||||
"uploadFailedInvalidUploadDesc": "Допускаются только изображения в формате PNG, JPEG или WEBP.",
|
||||
"parameterNotSet": "Параметр не задан",
|
||||
"parameterSet": "Параметр задан",
|
||||
"problemCopyingImage": "Не удается скопировать изображение",
|
||||
@@ -747,7 +759,12 @@
|
||||
"sentToUpscale": "Отправить на увеличение",
|
||||
"linkCopied": "Ссылка скопирована",
|
||||
"addedToUncategorized": "Добавлено в активы доски $t(boards.uncategorized)",
|
||||
"imagesWillBeAddedTo": "Загруженные изображения будут добавлены в активы доски {{boardName}}."
|
||||
"imagesWillBeAddedTo": "Загруженные изображения будут добавлены в активы доски {{boardName}}.",
|
||||
"schedulerResetZImageBase": "Планировщик LCM несовместим с моделями Z-Image Base. Переключено на Euler.",
|
||||
"schedulerReset": "Планировщик сброшен",
|
||||
"uploadFailedInvalidUploadDesc_withCount_one": "Допускается не более 1 изображения в формате PNG, JPEG или WEBP.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_few": "Допускается не более {{count}} изображения в формате PNG, JPEG или WEBP.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_many": "Допускается не более {{count}} изображений в формате PNG, JPEG или WEBP."
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Загрузить изображение",
|
||||
@@ -892,7 +909,13 @@
|
||||
"saveToGallery": "Сохранить в галерею",
|
||||
"noWorkflows": "Нет рабочих процессов",
|
||||
"noMatchingWorkflows": "Нет совпадающих рабочих процессов",
|
||||
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>."
|
||||
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>.",
|
||||
"generatorImages_one": "{{count}} изображение",
|
||||
"generatorImages_few": "{{count}} изображения",
|
||||
"generatorImages_many": "{{count}} изображений",
|
||||
"generatorNRandomValues_one": "{{count}} случайное значение",
|
||||
"generatorNRandomValues_few": "{{count}} случайных значения",
|
||||
"generatorNRandomValues_many": "{{count}} случайных значений"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Коллекция для автодобавления",
|
||||
@@ -935,7 +958,19 @@
|
||||
"shared": "Общие коллекции",
|
||||
"noBoards": "Нет коллекций {{boardType}}",
|
||||
"deletedPrivateBoardsCannotbeRestored": "Удалённые коллекции и изображения нельзя восстановить. При выборе «Удалить только коллекцию» изображения будут перемещены в личный раздел «Без категории» автора изображения.",
|
||||
"updateBoardError": "Ошибка обновления коллекции"
|
||||
"updateBoardError": "Ошибка обновления коллекции",
|
||||
"pause": "Пауза",
|
||||
"resume": "Возобновить",
|
||||
"restartFailed": "Ошибка перезапуска",
|
||||
"restartFile": "Перезапустить загрузку",
|
||||
"restartRequired": "Требуется перезапуск",
|
||||
"resumeRefused": "Сервер отклонил попытку возобновления. Требуется перезапуск.",
|
||||
"uncategorizedImages": "Без категории",
|
||||
"deleteAllUncategorizedImages": "Удалить все изображения без категории",
|
||||
"deletedImagesCannotBeRestored": "Удалённые изображения нельзя восстановить.",
|
||||
"hideBoards": "Скрыть коллекции",
|
||||
"locateInGalery": "Показать в галерее",
|
||||
"viewBoards": "Просмотреть коллекции"
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"seedBehaviour": {
|
||||
@@ -1031,19 +1066,19 @@
|
||||
"controlNetResizeMode": {
|
||||
"heading": "Режим изменения размера",
|
||||
"paragraphs": [
|
||||
"Метод подгонки размера входного изображения Control Adaptor к размеру выходного изображения."
|
||||
"Метод подгонки размера входного изображения Control Adapter под размер выходного изображения."
|
||||
]
|
||||
},
|
||||
"controlNetBeginEnd": {
|
||||
"paragraphs": [
|
||||
"Часть процесса шумоподавления, к которой будет применен адаптер контроля.",
|
||||
"ControlNet, применяемые в начале процесса, направляют композицию, а ControlNet, применяемые в конце, направляют детали."
|
||||
"Эта настройка определяет, на каком этапе денойзинга (генерации) используется влияние данного слоя.",
|
||||
"• Начальный шаг (%): Определяет, с какого момента в процессе генерации начинает учитываться влияние данного слоя."
|
||||
],
|
||||
"heading": "Процент начала/конца шага"
|
||||
},
|
||||
"dynamicPromptsSeedBehaviour": {
|
||||
"paragraphs": [
|
||||
"Управляет использованием сида при создании запросов.",
|
||||
"Определяет, как используется сид при генерации промптов.",
|
||||
"Для каждой итерации будет использоваться уникальный сид. Используйте это, чтобы изучить варианты запросов для одного сида.",
|
||||
"Например, если у вас 5 запросов, каждое изображение будет использовать один и то же сид.",
|
||||
"для каждого изображения будет использоваться уникальный сид. Это обеспечивает большую вариативность."
|
||||
@@ -1071,8 +1106,8 @@
|
||||
},
|
||||
"paramDenoisingStrength": {
|
||||
"paragraphs": [
|
||||
"Количество шума, добавляемого к входному изображению.",
|
||||
"0 приведет к идентичному изображению, а 1 - к совершенно новому."
|
||||
"Определяет, насколько сгенерированное изображение отличается от растрового слоя (слоёв).",
|
||||
"Меньшее значение сохраняет больше сходства с объединёнными видимыми растровыми слоями. Большее значение усиливает влияние глобального промпта."
|
||||
],
|
||||
"heading": "Шумоподавление"
|
||||
},
|
||||
@@ -1111,7 +1146,7 @@
|
||||
"controlNetWeight": {
|
||||
"heading": "Вес",
|
||||
"paragraphs": [
|
||||
"Вес адаптера управления. Более высокий вес приведет к большему воздействию на окончательное изображение."
|
||||
"Определяет, насколько сильно слой влияет на процесс генерации."
|
||||
]
|
||||
},
|
||||
"controlNet": {
|
||||
@@ -1123,13 +1158,13 @@
|
||||
"paramCFGScale": {
|
||||
"heading": "Шкала точности (CFG)",
|
||||
"paragraphs": [
|
||||
"Контролирует, насколько запрос влияет на процесс генерации.",
|
||||
"Определяет, насколько сильно промпт влияет на процесс генерации.",
|
||||
"Высокие значения шкалы CFG могут привести к перенасыщению и искажению результатов генерации. "
|
||||
]
|
||||
},
|
||||
"controlNetControlMode": {
|
||||
"paragraphs": [
|
||||
"Придает больший вес либо запросу, либо ControlNet."
|
||||
"Смещает приоритет в сторону промпта или ControlNet."
|
||||
],
|
||||
"heading": "Режим управления"
|
||||
},
|
||||
@@ -1181,7 +1216,7 @@
|
||||
"refinerCfgScale": {
|
||||
"heading": "Шкала CFG",
|
||||
"paragraphs": [
|
||||
"Контролирует, насколько сильно запрос влияет на процесс генерации.",
|
||||
"Определяет, насколько сильно промпт влияет на процесс генерации.",
|
||||
"Аналогично CFG шкале генерации."
|
||||
]
|
||||
},
|
||||
@@ -1290,24 +1325,24 @@
|
||||
"ipAdapterMethod": {
|
||||
"heading": "Метод",
|
||||
"paragraphs": [
|
||||
"Метод, с помощью которого применяется текущий IP-адаптер."
|
||||
"Метод определяет, как референсное изображение будет влиять на процесс генерации."
|
||||
]
|
||||
},
|
||||
"structure": {
|
||||
"paragraphs": [
|
||||
"Структура контролирует, насколько точно выходное изображение будет соответствовать макету оригинала. Низкая структура допускает значительные изменения, в то время как высокая структура строго сохраняет исходную композицию и макет."
|
||||
"Структура определяет, насколько точно выходное изображение сохраняет компоновку исходного. Низкое значение допускает значительные изменения, а высокое строго сохраняет исходную композицию и расположение элементов."
|
||||
],
|
||||
"heading": "Структура"
|
||||
},
|
||||
"scale": {
|
||||
"paragraphs": [
|
||||
"Масштаб управляет размером выходного изображения и основывается на кратном разрешении входного изображения. Например, при увеличении в 2 раза изображения 1024x1024 на выходе получится 2048 x 2048."
|
||||
"Масштаб определяет размер выходного изображения и рассчитывается как кратное разрешению исходного изображения. Например, увеличение в 2 раза для изображения 1024×1024 даст результат 2048×2048."
|
||||
],
|
||||
"heading": "Масштаб"
|
||||
},
|
||||
"creativity": {
|
||||
"paragraphs": [
|
||||
"Креативность контролирует степень свободы, предоставляемой модели при добавлении деталей. При низкой креативности модель остается близкой к оригинальному изображению, в то время как высокая креативность позволяет вносить больше изменений. При использовании подсказки высокая креативность увеличивает влияние подсказки."
|
||||
"Креативность определяет степень свободы модели при добавлении деталей. Низкое значение сохраняет больше сходства с исходным изображением, а высокое допускает более значительные изменения. При использовании промпта высокое значение усиливает его влияние."
|
||||
],
|
||||
"heading": "Креативность"
|
||||
},
|
||||
@@ -1320,18 +1355,18 @@
|
||||
"fluxDevLicense": {
|
||||
"heading": "Некоммерческая лицензия",
|
||||
"paragraphs": [
|
||||
"Модели FLUX.1 [dev] лицензируются по некоммерческой лицензии FLUX [dev]. Чтобы использовать этот тип модели в коммерческих целях в Invoke, посетите наш веб-сайт, чтобы узнать больше."
|
||||
"Модели FLUX.1 [dev] распространяются по некоммерческой лицензии FLUX [dev]. Для их коммерческого использования требуется отдельная лицензия."
|
||||
]
|
||||
},
|
||||
"optimizedDenoising": {
|
||||
"heading": "Оптимизированный img2img",
|
||||
"paragraphs": [
|
||||
"Включите опцию «Оптимизированный img2img», чтобы получить более плавную шкалу Denoise Strength для img2img и перерисовки с моделями Flux. Эта настройка улучшает возможность контролировать степень изменения изображения, но может быть отключена, если вы предпочитаете использовать стандартную шкалу Denoise Strength. Эта настройка все еще находится в стадии настройки и в настоящее время имеет статус бета-версии."
|
||||
"Включите «Optimized Image-to-Image», чтобы использовать более плавную шкалу Denoise Strength для преобразований image-to-image и инпейнтинга с моделями Flux. Эта настройка улучшает контроль над степенью изменений изображения, однако её можно отключить, если вы предпочитаете стандартную шкалу Denoise Strength. Функция находится в стадии настройки и имеет статус бета-версии."
|
||||
]
|
||||
},
|
||||
"paramGuidance": {
|
||||
"paragraphs": [
|
||||
"Контролирует, насколько сильно запрос влияет на процесс генерации.",
|
||||
"Определяет, насколько сильно промпт влияет на процесс генерации.",
|
||||
"Высокие значения точности могут привести к перенасыщению, а высокие или низкие значения точности могут привести к искажению результатов генерации. Точность применима только к моделям FLUX DEV."
|
||||
],
|
||||
"heading": "Точность"
|
||||
@@ -1363,7 +1398,7 @@
|
||||
"parameterSet": "Параметр {{parameter}} установлен",
|
||||
"allPrompts": "Все запросы",
|
||||
"imageDimensions": "Размеры изображения",
|
||||
"canvasV2Metadata": "Холст",
|
||||
"canvasV2Metadata": "Слои холста",
|
||||
"guidance": "Точность"
|
||||
},
|
||||
"queue": {
|
||||
@@ -1393,7 +1428,7 @@
|
||||
"graphQueued": "График поставлен в очередь",
|
||||
"queue": "Очередь",
|
||||
"batch": "Пакет",
|
||||
"clearQueueAlertDialog": "Очистка очереди немедленно отменяет все элементы обработки и полностью очищает очередь. Ожидающие фильтры будут отменены.",
|
||||
"clearQueueAlertDialog": "Очистка очереди немедленно отменит все текущие задачи и очистит очередь. Ожидающие фильтры будут отменены, а область предпросмотра на холсте сброшена.",
|
||||
"pending": "В ожидании",
|
||||
"completedIn": "Завершено за",
|
||||
"resumeFailed": "Проблема с возобновлением рендеринга",
|
||||
@@ -1477,7 +1512,7 @@
|
||||
"workflowEditorMenu": "Меню редактора рабочего процесса",
|
||||
"workflowName": "Имя рабочего процесса",
|
||||
"saveWorkflow": "Сохранить рабочий процесс",
|
||||
"workflowLibrary": "Библиотека",
|
||||
"workflowLibrary": "Библиотека схем генерации",
|
||||
"downloadWorkflow": "Сохранить в файл",
|
||||
"workflowSaved": "Рабочий процесс сохранен",
|
||||
"unnamedWorkflow": "Безымянный рабочий процесс",
|
||||
@@ -1560,7 +1595,7 @@
|
||||
"autoNegative": "Авто негатив",
|
||||
"rectangle": "Прямоугольник",
|
||||
"addNegativePrompt": "Добавить $t(controlLayers.negativePrompt)",
|
||||
"regionalGuidance": "Региональная точность",
|
||||
"regionalGuidance": "Региональное влияние",
|
||||
"opacity": "Непрозрачность",
|
||||
"addLayer": "Добавить слой",
|
||||
"moveToFront": "На передний план",
|
||||
@@ -1568,33 +1603,33 @@
|
||||
"regional": "Региональный",
|
||||
"bookmark": "Закладка для быстрого переключения",
|
||||
"fitBboxToLayers": "Подогнать рамку к слоям",
|
||||
"mergeVisibleOk": "Объединенные видимые слои",
|
||||
"mergeVisibleError": "Ошибка объединения видимых слоев",
|
||||
"mergeVisibleOk": "Объединенные слои",
|
||||
"mergeVisibleError": "Ошибка объединения слоев",
|
||||
"clearHistory": "Очистить историю",
|
||||
"mergeVisible": "Объединить видимые",
|
||||
"removeBookmark": "Удалить закладку",
|
||||
"saveLayerToAssets": "Сохранить слой в активы",
|
||||
"saveLayerToAssets": "Сохранить слой в ресурсы",
|
||||
"clearCaches": "Очистить кэши",
|
||||
"recalculateRects": "Пересчитать прямоугольники",
|
||||
"saveBboxToGallery": "Сохранить рамку в галерею",
|
||||
"saveBboxToGallery": "Сохранить область в галерею",
|
||||
"canvas": "Холст",
|
||||
"global": "Глобальный",
|
||||
"newGlobalReferenceImageError": "Проблема с созданием глобального эталонного изображения",
|
||||
"newRegionalReferenceImageOk": "Создано региональное эталонное изображение",
|
||||
"newRegionalReferenceImageError": "Проблема создания регионального эталонного изображения",
|
||||
"newGlobalReferenceImageError": "Проблема с созданием глобального референсного изображения",
|
||||
"newRegionalReferenceImageOk": "Создано региональное референсное изображение",
|
||||
"newRegionalReferenceImageError": "Проблема создания регионального референсного изображения",
|
||||
"newControlLayerOk": "Создан слой управления",
|
||||
"newControlLayerError": "Ошибка создания слоя управления",
|
||||
"newRasterLayerOk": "Создан растровый слой",
|
||||
"newRasterLayerError": "Ошибка создания растрового слоя",
|
||||
"newGlobalReferenceImageOk": "Создано глобальное эталонное изображение",
|
||||
"bboxOverlay": "Показать наложение ограничительной рамки",
|
||||
"newGlobalReferenceImageOk": "Создано глобальное референсное изображение",
|
||||
"bboxOverlay": "Показать наложение рамки",
|
||||
"saveCanvasToGallery": "Сохранить холст в галерею",
|
||||
"pullBboxIntoReferenceImageOk": "рамка перенесена в эталонное изображение",
|
||||
"pullBboxIntoReferenceImageError": "Ошибка переноса рамки в эталонное изображение",
|
||||
"pullBboxIntoReferenceImageOk": "Рамка перенесена в референсное изображение",
|
||||
"pullBboxIntoReferenceImageError": "Ошибка переноса рамки в референсное изображение",
|
||||
"regionIsEmpty": "Выбранный регион пуст",
|
||||
"savedToGalleryOk": "Сохранено в галерею",
|
||||
"savedToGalleryError": "Ошибка сохранения в галерею",
|
||||
"pullBboxIntoLayerOk": "Рамка перенесена в слой",
|
||||
"pullBboxIntoLayerOk": "Содержимое рамки перенесено в слой",
|
||||
"pullBboxIntoLayerError": "Проблема с переносом рамки в слой",
|
||||
"newLayerFromImage": "Новый слой из изображения",
|
||||
"filter": {
|
||||
@@ -1693,11 +1728,12 @@
|
||||
"isTransforming": "{{title}} трансформируется"
|
||||
},
|
||||
"scaledBbox": "Масштабированная рамка",
|
||||
"bbox": "Ограничительная рамка"
|
||||
"bbox": "Ограничительная рамка",
|
||||
"textSessionActive": "Активен режим ввода"
|
||||
},
|
||||
"canvasContextMenu": {
|
||||
"saveBboxToGallery": "Сохранить рамку в галерею",
|
||||
"newGlobalReferenceImage": "Новое глобальное эталонное изображение",
|
||||
"newGlobalReferenceImage": "Новое глобальное референсное изображение",
|
||||
"bboxGroup": "Сохдать из рамки",
|
||||
"canvasGroup": "Холст",
|
||||
"newControlLayer": "Новый контрольный слой",
|
||||
@@ -1709,8 +1745,8 @@
|
||||
},
|
||||
"fill": {
|
||||
"solid": "Сплошной",
|
||||
"fillStyle": "Стиль заполнения",
|
||||
"fillColor": "Цвет заполнения",
|
||||
"fillStyle": "Стиль заливки",
|
||||
"fillColor": "Цвет заливкии",
|
||||
"grid": "Сетка",
|
||||
"horizontal": "Горизонтальная",
|
||||
"diagonal": "Диагональная",
|
||||
@@ -1729,8 +1765,8 @@
|
||||
"inpaintMask": "Маска перерисовки",
|
||||
"sendToCanvas": "Отправить на холст",
|
||||
"regionalGuidance_withCount_one": "$t(controlLayers.regionalGuidance)",
|
||||
"regionalGuidance_withCount_few": "Региональных точности",
|
||||
"regionalGuidance_withCount_many": "Региональных точностей",
|
||||
"regionalGuidance_withCount_few": "Региональных влияния",
|
||||
"regionalGuidance_withCount_many": "Региональных влияний",
|
||||
"controlLayer_withCount_one": "$t(controlLayers.controlLayer)",
|
||||
"controlLayer_withCount_few": "Контрольных слоя",
|
||||
"controlLayer_withCount_many": "Контрольных слоев",
|
||||
@@ -1739,9 +1775,9 @@
|
||||
"inpaintMask_withCount_few": "Маски перерисовки",
|
||||
"inpaintMask_withCount_many": "Масок перерисовки",
|
||||
"controlMode": {
|
||||
"prompt": "Запрос",
|
||||
"prompt": "Промпт",
|
||||
"controlMode": "Режим контроля",
|
||||
"megaControl": "Мега контроль",
|
||||
"megaControl": "Максимальный контроль",
|
||||
"balanced": "Сбалансированный",
|
||||
"control": "Контроль"
|
||||
},
|
||||
@@ -1770,24 +1806,25 @@
|
||||
"showResultsOn": "Показать результаты",
|
||||
"showResultsOff": "Скрыть результаты"
|
||||
},
|
||||
"pullBboxIntoReferenceImage": "Поместить рамку в эталонное изображение",
|
||||
"pullBboxIntoReferenceImage": "Преобразовать рамку в референсное изображение",
|
||||
"enableAutoNegative": "Включить авто негатив",
|
||||
"maskFill": "Заполнение маски",
|
||||
"maskFill": "Заливка маски",
|
||||
"tool": {
|
||||
"move": "Двигать",
|
||||
"move": "Перемещение",
|
||||
"bbox": "Ограничительная рамка",
|
||||
"view": "Смотреть",
|
||||
"view": "Перемещение холста",
|
||||
"brush": "Кисть",
|
||||
"eraser": "Ластик",
|
||||
"rectangle": "Прямоугольник",
|
||||
"colorPicker": "Подборщик цветов"
|
||||
"colorPicker": "Пипетка",
|
||||
"text": "Текст"
|
||||
},
|
||||
"rasterLayer": "Растровый слой",
|
||||
"enableTransparencyEffect": "Включить эффект прозрачности",
|
||||
"hidingType": "Скрыть {{type}}",
|
||||
"addRegionalGuidance": "Добавить $t(controlLayers.regionalGuidance)",
|
||||
"deleteSelected": "Удалить выбранное",
|
||||
"pullBboxIntoLayer": "Поместить рамку в слой",
|
||||
"pullBboxIntoLayer": "Преобразовать рамку в слой",
|
||||
"locked": "Заблокировано",
|
||||
"replaceLayer": "Заменить слой",
|
||||
"width": "Ширина",
|
||||
@@ -1795,15 +1832,15 @@
|
||||
"addRasterLayer": "Добавить $t(controlLayers.rasterLayer)",
|
||||
"addControlLayer": "Добавить $t(controlLayers.controlLayer)",
|
||||
"addInpaintMask": "Добавить $t(controlLayers.inpaintMask)",
|
||||
"cropLayerToBbox": "Обрезать слой по ограничительной рамке",
|
||||
"clipToBbox": "Обрезка штрихов в рамке",
|
||||
"outputOnlyMaskedRegions": "Вывод только маскированных областей",
|
||||
"cropLayerToBbox": "Обрезать слой по рамке",
|
||||
"clipToBbox": "Ограничить мазки рамкой",
|
||||
"outputOnlyMaskedRegions": "Выводить только сгенерированные области",
|
||||
"duplicate": "Дублировать",
|
||||
"layer_one": "Слой",
|
||||
"layer_few": "Слоя",
|
||||
"layer_many": "Слоев",
|
||||
"prompt": "Запрос",
|
||||
"negativePrompt": "Исключающий запрос",
|
||||
"prompt": "Промпт",
|
||||
"negativePrompt": "Негативный промпт",
|
||||
"beginEndStepPercentShort": "Начало/конец %",
|
||||
"transform": {
|
||||
"transform": "Трансформировать",
|
||||
@@ -1816,7 +1853,7 @@
|
||||
"fitModeFill": "Заполнить"
|
||||
},
|
||||
"disableAutoNegative": "Отключить авто негатив",
|
||||
"deleteReferenceImage": "Удалить эталонное изображение",
|
||||
"deleteReferenceImage": "Удалить референсное изображение",
|
||||
"rasterLayer_withCount_one": "$t(controlLayers.rasterLayer)",
|
||||
"rasterLayer_withCount_few": "Растровых слоя",
|
||||
"rasterLayer_withCount_many": "Растровых слоев",
|
||||
@@ -1828,9 +1865,42 @@
|
||||
"logDebugInfo": "Писать отладочную информацию",
|
||||
"unlocked": "Разблокировано",
|
||||
"showProgressOnCanvas": "Показать прогресс на холсте",
|
||||
"regionalReferenceImage": "Региональное эталонное изображение",
|
||||
"globalReferenceImage": "Глобальное эталонное изображение",
|
||||
"referenceImage": "Эталонное изображение"
|
||||
"regionalReferenceImage": "Региональное референсное изображение",
|
||||
"globalReferenceImage": "Глобальное референсное изображение",
|
||||
"referenceImage": "Референсное изображение",
|
||||
"text": {
|
||||
"px": "px",
|
||||
"alignRight": "По правому краю",
|
||||
"alignCenter": "По центру",
|
||||
"alignLeft": "По левому краю",
|
||||
"strikethrough": "Зачёркнутый",
|
||||
"italic": "Курсив",
|
||||
"bold": "Полужирный",
|
||||
"size": "Размер",
|
||||
"font": "Шрифт"
|
||||
},
|
||||
"newImg2ImgCanvasFromImage": "Новое изображение из Img2Img",
|
||||
"sendToCanvasDesc": "При нажатии Invoke результат появляется на холсте в режиме предпросмотра.",
|
||||
"compositeOperation": {
|
||||
"blendModes": {
|
||||
"darken": "Затемнение",
|
||||
"multiply": "Умножение",
|
||||
"color-dodge": "Осветление основы",
|
||||
"color-burn": "Затемнение основы",
|
||||
"screen": "Экран",
|
||||
"hard-light": "Жёсткий свет",
|
||||
"soft-light": "Мягкий свет",
|
||||
"overlay": "Перекрытие",
|
||||
"hue": "Тон",
|
||||
"color": "Цвет",
|
||||
"source-over": "Обычный"
|
||||
}
|
||||
},
|
||||
"globalReferenceImage_withCount_one": "$t(controlLayers.globalReferenceImage)",
|
||||
"globalReferenceImage_withCount_few": "Глобальных референсных изображения",
|
||||
"globalReferenceImage_withCount_many": "Глобальных референсных изображений",
|
||||
"regionalGuidance_withCount_hidden": "Региональное влияние (скрыто: {{count}})",
|
||||
"controlLayers_withCount_hidden": "Слои управления (скрыто: {{count}})"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { Box, Center, Spinner } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
|
||||
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
|
||||
import { clearStorage } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { AdministratorSetup } from 'features/auth/components/AdministratorSetup';
|
||||
import { LoginPage } from 'features/auth/components/LoginPage';
|
||||
import { ProtectedRoute } from 'features/auth/components/ProtectedRoute';
|
||||
import { UserManagement } from 'features/auth/components/UserManagement';
|
||||
import { UserProfile } from 'features/auth/components/UserProfile';
|
||||
import { AppContent } from 'features/ui/components/AppContent';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { memo } from 'react';
|
||||
import type { ReactNode } from 'react';
|
||||
import { memo, useEffect } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
import { Route, Routes, useNavigate } from 'react-router-dom';
|
||||
import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
|
||||
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import ThemeLocaleProvider from './ThemeLocaleProvider';
|
||||
@@ -18,14 +26,94 @@ const errorBoundaryOnReset = () => {
|
||||
return false;
|
||||
};
|
||||
|
||||
const App = () => {
|
||||
const MainApp = () => {
|
||||
const isNavigationAPIConnected = useStore(navigationApi.$isConnected);
|
||||
return (
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
{isNavigationAPIConnected ? <AppContent /> : <Loading />}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
const SetupChecker = () => {
|
||||
const { data, isLoading } = useGetSetupStatusQuery();
|
||||
const navigate = useNavigate();
|
||||
|
||||
// Check if user is already authenticated
|
||||
const token = localStorage.getItem('auth_token');
|
||||
const isAuthenticated = !!token;
|
||||
|
||||
useEffect(() => {
|
||||
if (!isLoading && data) {
|
||||
// If multiuser mode is disabled, go directly to the app
|
||||
if (!data.multiuser_enabled) {
|
||||
navigate('/app', { replace: true });
|
||||
} else if (isAuthenticated) {
|
||||
// In multiuser mode, check authentication
|
||||
navigate('/app', { replace: true });
|
||||
} else if (data.setup_required) {
|
||||
navigate('/setup', { replace: true });
|
||||
} else {
|
||||
navigate('/login', { replace: true });
|
||||
}
|
||||
}
|
||||
}, [data, isLoading, navigate, isAuthenticated]);
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Center w="100dvw" h="100dvh">
|
||||
<Spinner size="xl" />
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
/** Full-page wrapper for user management / profile pages rendered inside the protected area */
|
||||
const FullPageWrapper = ({ children }: { children: ReactNode }) => (
|
||||
<Box w="100dvw" h="100dvh" overflowY="auto" bg="base.900">
|
||||
{children}
|
||||
</Box>
|
||||
);
|
||||
|
||||
const App = () => {
|
||||
return (
|
||||
<ThemeLocaleProvider>
|
||||
<ErrorBoundary onReset={errorBoundaryOnReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
{isNavigationAPIConnected ? <AppContent /> : <Loading />}
|
||||
</Box>
|
||||
<Routes>
|
||||
<Route path="/" element={<SetupChecker />} />
|
||||
<Route path="/login" element={<LoginPage />} />
|
||||
<Route path="/setup" element={<AdministratorSetup />} />
|
||||
<Route
|
||||
path="/profile"
|
||||
element={
|
||||
<ProtectedRoute>
|
||||
<FullPageWrapper>
|
||||
<UserProfile />
|
||||
</FullPageWrapper>
|
||||
</ProtectedRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/admin/users"
|
||||
element={
|
||||
<ProtectedRoute requireAdmin>
|
||||
<FullPageWrapper>
|
||||
<UserManagement />
|
||||
</FullPageWrapper>
|
||||
</ProtectedRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/*"
|
||||
element={
|
||||
<ProtectedRoute>
|
||||
<MainApp />
|
||||
</ProtectedRoute>
|
||||
}
|
||||
/>
|
||||
</Routes>
|
||||
<GlobalHookIsolator />
|
||||
<GlobalModalIsolator />
|
||||
</ErrorBoundary>
|
||||
|
||||
@@ -7,6 +7,7 @@ import { createStore } from 'app/store/store';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import React, { lazy, memo, useEffect, useState } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { BrowserRouter } from 'react-router-dom';
|
||||
|
||||
/*
|
||||
* We need to configure logging before anything else happens - useLayoutEffect ensures we set this at the first
|
||||
@@ -51,9 +52,11 @@ const InvokeAIUI = () => {
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<App />
|
||||
</React.Suspense>
|
||||
<BrowserRouter>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<App />
|
||||
</React.Suspense>
|
||||
</BrowserRouter>
|
||||
</Provider>
|
||||
</React.StrictMode>
|
||||
);
|
||||
|
||||
@@ -68,10 +68,26 @@ const getIdbKey = (key: string) => {
|
||||
return `${IDB_STORAGE_PREFIX}${key}`;
|
||||
};
|
||||
|
||||
// Helper to get auth headers for client_state requests
|
||||
const getAuthHeaders = (): Record<string, string> => {
|
||||
const headers: Record<string, string> = {};
|
||||
// Safe access to localStorage (not available in Node.js test environment)
|
||||
if (typeof window !== 'undefined' && window.localStorage) {
|
||||
const token = localStorage.getItem('auth_token');
|
||||
if (token) {
|
||||
headers['Authorization'] = `Bearer ${token}`;
|
||||
}
|
||||
}
|
||||
return headers;
|
||||
};
|
||||
|
||||
const getItem = async (key: string) => {
|
||||
try {
|
||||
const url = getUrl('get_by_key', key);
|
||||
const res = await fetch(url, { method: 'GET' });
|
||||
const res = await fetch(url, {
|
||||
method: 'GET',
|
||||
headers: getAuthHeaders(),
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
@@ -130,7 +146,11 @@ const setItem = async (key: string, value: string) => {
|
||||
}
|
||||
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`);
|
||||
const url = getUrl('set_by_key', key);
|
||||
const res = await fetch(url, { method: 'POST', body: value });
|
||||
const res = await fetch(url, {
|
||||
method: 'POST',
|
||||
body: value,
|
||||
headers: getAuthHeaders(),
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
@@ -158,7 +178,10 @@ export const clearStorage = async () => {
|
||||
try {
|
||||
persistRefCount++;
|
||||
const url = getUrl('delete');
|
||||
const res = await fetch(url, { method: 'POST' });
|
||||
const res = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: getAuthHeaders(),
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
|
||||
@@ -12,27 +12,12 @@ export const appStarted = createAction('app/appStarted');
|
||||
export const addAppStartedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: appStarted,
|
||||
effect: (action, { unsubscribe, cancelActiveListeners, take, getState, dispatch }) => {
|
||||
effect: async (action, { unsubscribe, cancelActiveListeners, take, getState, dispatch }) => {
|
||||
// this should only run once
|
||||
cancelActiveListeners();
|
||||
unsubscribe();
|
||||
|
||||
// ensure an image is selected when we load the first board
|
||||
take(imagesApi.endpoints.getImageNames.matchFulfilled).then((firstImageLoad) => {
|
||||
if (firstImageLoad === null) {
|
||||
// timeout or cancelled
|
||||
return;
|
||||
}
|
||||
const [{ payload }] = firstImageLoad;
|
||||
const selectedImage = selectLastSelectedItem(getState());
|
||||
if (selectedImage) {
|
||||
return;
|
||||
}
|
||||
if (payload.image_names[0]) {
|
||||
dispatch(imageSelected(payload.image_names[0]));
|
||||
}
|
||||
});
|
||||
|
||||
// Fire patchmatch check without blocking the image-selection logic below
|
||||
dispatch(appInfoApi.endpoints.getPatchmatchStatus.initiate())
|
||||
.unwrap()
|
||||
.then((isPatchmatchAvailable) => {
|
||||
@@ -43,6 +28,24 @@ export const addAppStartedListener = (startAppListening: AppStartListening) => {
|
||||
}
|
||||
})
|
||||
.catch(noop);
|
||||
|
||||
// ensure an image is selected when we load the first board.
|
||||
// The effect must be async and await take() so that RTK keeps the listener's AbortController
|
||||
// alive until the query resolves; a synchronous effect causes the controller to be aborted
|
||||
// immediately after the effect returns, before any network response arrives.
|
||||
const firstImageLoad = await take(imagesApi.endpoints.getImageNames.matchFulfilled, 5000);
|
||||
if (firstImageLoad === null) {
|
||||
// timeout or cancelled
|
||||
return;
|
||||
}
|
||||
const [{ payload }] = firstImageLoad;
|
||||
const selectedImage = selectLastSelectedItem(getState());
|
||||
if (selectedImage) {
|
||||
return;
|
||||
}
|
||||
if (payload.image_names[0]) {
|
||||
dispatch(imageSelected(payload.image_names[0]));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -19,6 +19,7 @@ import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMi
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { merge } from 'es-toolkit';
|
||||
import { omit, pick } from 'es-toolkit/compat';
|
||||
import { authSliceConfig } from 'features/auth/store/authSlice';
|
||||
import { changeBoardModalSliceConfig } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
|
||||
@@ -61,6 +62,7 @@ const log = logger('system');
|
||||
|
||||
// When adding a slice, add the config to the SLICE_CONFIGS object below, then add the reducer to ALL_REDUCERS.
|
||||
const SLICE_CONFIGS = {
|
||||
[authSliceConfig.slice.reducerPath]: authSliceConfig,
|
||||
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
|
||||
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
|
||||
[canvasTextSliceConfig.slice.reducerPath]: canvasTextSliceConfig,
|
||||
@@ -87,6 +89,7 @@ const SLICE_CONFIGS = {
|
||||
// Remember to wrap undoable reducers in `undoable()`!
|
||||
const ALL_REDUCERS = {
|
||||
[api.reducerPath]: api.reducer,
|
||||
[authSliceConfig.slice.reducerPath]: authSliceConfig.slice.reducer,
|
||||
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
|
||||
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig.slice.reducer,
|
||||
[canvasTextSliceConfig.slice.reducerPath]: canvasTextSliceConfig.slice.reducer,
|
||||
|
||||
@@ -78,6 +78,44 @@ describe('promptAST', () => {
|
||||
{ type: 'rembed', start: 15, end: 16 },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should tokenize prompt function syntax', () => {
|
||||
const tokens = tokenize("('a', 'b').and()");
|
||||
expect(tokens).toEqual([
|
||||
{ type: 'lparen', start: 0, end: 1 },
|
||||
{ type: 'punct', value: "'", start: 1, end: 2 },
|
||||
{ type: 'word', value: 'a', start: 2, end: 3 },
|
||||
{ type: 'punct', value: "'", start: 3, end: 4 },
|
||||
{ type: 'punct', value: ',', start: 4, end: 5 },
|
||||
{ type: 'whitespace', value: ' ', start: 5, end: 6 },
|
||||
{ type: 'punct', value: "'", start: 6, end: 7 },
|
||||
{ type: 'word', value: 'b', start: 7, end: 8 },
|
||||
{ type: 'punct', value: "'", start: 8, end: 9 },
|
||||
{ type: 'rparen', start: 9, end: 10 },
|
||||
{ type: 'punct', value: '.', start: 10, end: 11 },
|
||||
{ type: 'word', value: 'and', start: 11, end: 14 },
|
||||
{ type: 'lparen', start: 14, end: 15 },
|
||||
{ type: 'rparen', start: 15, end: 16 },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should tokenize curly/smart quotes as punctuation', () => {
|
||||
const tokens = tokenize('\u201chello\u201d');
|
||||
expect(tokens).toEqual([
|
||||
{ type: 'punct', value: '\u201c', start: 0, end: 1 },
|
||||
{ type: 'word', value: 'hello', start: 1, end: 6 },
|
||||
{ type: 'punct', value: '\u201d', start: 6, end: 7 },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should tokenize curly single quotes as punctuation', () => {
|
||||
const tokens = tokenize('\u2018hello\u2019');
|
||||
expect(tokens).toEqual([
|
||||
{ type: 'punct', value: '\u2018', start: 0, end: 1 },
|
||||
{ type: 'word', value: 'hello', start: 1, end: 6 },
|
||||
{ type: 'punct', value: '\u2019', start: 6, end: 7 },
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseTokens', () => {
|
||||
@@ -167,6 +205,312 @@ describe('promptAST', () => {
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name', range: { start: 0, end: 16 } }]);
|
||||
});
|
||||
|
||||
describe('prompt functions', () => {
|
||||
it('should parse .and() prompt function with single-quoted args', () => {
|
||||
const tokens = tokenize("('one two', 'three four').and()");
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.name).toBe('and');
|
||||
expect(pf.functionParams).toBe('');
|
||||
expect(pf.promptArgs).toHaveLength(2);
|
||||
|
||||
// First arg: 'one two'
|
||||
expect(pf.promptArgs[0]!.quote).toBe("'");
|
||||
expect(pf.promptArgs[0]!.nodes).toHaveLength(3); // word, ws, word
|
||||
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
|
||||
expect(pf.promptArgs[0]!.nodes[2]).toMatchObject({ type: 'word', text: 'two' });
|
||||
|
||||
// Second arg: 'three four'
|
||||
expect(pf.promptArgs[1]!.quote).toBe("'");
|
||||
expect(pf.promptArgs[1]!.nodes).toHaveLength(3);
|
||||
expect(pf.promptArgs[1]!.nodes[0]).toMatchObject({ type: 'word', text: 'three' });
|
||||
expect(pf.promptArgs[1]!.nodes[2]).toMatchObject({ type: 'word', text: 'four' });
|
||||
});
|
||||
|
||||
it('should parse .or() prompt function', () => {
|
||||
const tokens = tokenize("('one', 'two three. four.').or()");
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.name).toBe('or');
|
||||
expect(pf.promptArgs).toHaveLength(2);
|
||||
|
||||
// First arg: 'one'
|
||||
expect(pf.promptArgs[0]!.nodes).toHaveLength(1);
|
||||
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
|
||||
|
||||
// Second arg: 'two three. four.'
|
||||
expect(pf.promptArgs[1]!.nodes.length).toBeGreaterThanOrEqual(5);
|
||||
});
|
||||
|
||||
it('should parse .blend() prompt function with params', () => {
|
||||
const tokens = tokenize("('one', 'two').blend(0.7, 0.3)");
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.name).toBe('blend');
|
||||
expect(pf.functionParams).toBe('0.7, 0.3');
|
||||
expect(pf.promptArgs).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should parse prompt function with double-quoted args', () => {
|
||||
const tokens = tokenize('("one", "two").and()');
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.name).toBe('and');
|
||||
expect(pf.promptArgs[0]!.quote).toBe('"');
|
||||
});
|
||||
|
||||
it('should parse prompt function with curly double quotes', () => {
|
||||
const tokens = tokenize('(\u201cone\u201d, \u201ctwo\u201d).and()');
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.name).toBe('and');
|
||||
expect(pf.promptArgs).toHaveLength(2);
|
||||
expect(pf.promptArgs[0]!.quote).toBe('\u201c');
|
||||
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
|
||||
expect(pf.promptArgs[1]!.nodes[0]).toMatchObject({ type: 'word', text: 'two' });
|
||||
});
|
||||
|
||||
it('should parse prompt function with curly single quotes', () => {
|
||||
const tokens = tokenize('(\u2018one\u2019, \u2018two\u2019).and()');
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.name).toBe('and');
|
||||
expect(pf.promptArgs[0]!.quote).toBe('\u2018');
|
||||
});
|
||||
|
||||
it('should parse prompt function with curly quotes containing commas in args', () => {
|
||||
const prompt = '(\u201chigh detail, cinematic\u201d, \u201csoft light, portrait\u201d).and()';
|
||||
const ast = parseTokens(tokenize(prompt));
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.promptArgs).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should parse prompt function with newline before .method()', () => {
|
||||
const prompt = '(\u201cone\u201d, \u201ctwo\u201d)\n.and()';
|
||||
const ast = parseTokens(tokenize(prompt));
|
||||
expect(ast).toHaveLength(1);
|
||||
expect(ast[0]!.type).toBe('prompt_function');
|
||||
});
|
||||
|
||||
it('should parse quoted prompt function with newline before .method()', () => {
|
||||
const prompt = "('one', 'two')\n.and()";
|
||||
const ast = parseTokens(tokenize(prompt));
|
||||
expect(ast).toHaveLength(1);
|
||||
expect(ast[0]!.type).toBe('prompt_function');
|
||||
});
|
||||
|
||||
it('should parse prompt function with attention inside args', () => {
|
||||
const tokens = tokenize("('hello+', '(world)-').and()");
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
|
||||
// First arg: hello+
|
||||
const arg0Word = pf.promptArgs[0]!.nodes[0]!;
|
||||
expect(arg0Word).toMatchObject({ type: 'word', text: 'hello', attention: '+' });
|
||||
|
||||
// Second arg: (world)-
|
||||
const arg1Group = pf.promptArgs[1]!.nodes[0]!;
|
||||
expect(arg1Group.type).toBe('group');
|
||||
if (arg1Group.type === 'group') {
|
||||
expect(arg1Group.attention).toBe('-');
|
||||
}
|
||||
});
|
||||
|
||||
it('should preserve content range for each arg', () => {
|
||||
const tokens = tokenize("('one two', 'three four').and()");
|
||||
const ast = parseTokens(tokens);
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
|
||||
// 'one two' content is between quotes at positions 1 and 9
|
||||
expect(pf.promptArgs[0]!.contentRange.start).toBe(2);
|
||||
expect(pf.promptArgs[0]!.contentRange.end).toBe(9);
|
||||
|
||||
// 'three four' content is between quotes at positions 12 and 23
|
||||
expect(pf.promptArgs[1]!.contentRange.start).toBe(13);
|
||||
expect(pf.promptArgs[1]!.contentRange.end).toBe(23);
|
||||
});
|
||||
|
||||
it('should parse prompt function embedded in larger prompt', () => {
|
||||
const tokens = tokenize("some text, ('a', 'b').and(), more text");
|
||||
const ast = parseTokens(tokens);
|
||||
|
||||
// Should have: word, ws, word, punct, ws, prompt_function, punct, ws, word, ws, word
|
||||
const pfNodes = ast.filter((n) => n.type === 'prompt_function');
|
||||
expect(pfNodes).toHaveLength(1);
|
||||
expect(pfNodes[0]!.type).toBe('prompt_function');
|
||||
});
|
||||
|
||||
it('should fall back to regular group when no method call follows', () => {
|
||||
const tokens = tokenize("('a', 'b')");
|
||||
const ast = parseTokens(tokens);
|
||||
|
||||
// Without .method(), this should be parsed as a regular group
|
||||
expect(ast[0]!.type).toBe('group');
|
||||
});
|
||||
|
||||
it('should parse three-arg prompt function', () => {
|
||||
const tokens = tokenize("('a', 'b', 'c').blend(0.5, 0.3, 0.2)");
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.promptArgs).toHaveLength(3);
|
||||
expect(pf.functionParams).toBe('0.5, 0.3, 0.2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('unquoted prompt functions', () => {
|
||||
it('should parse unquoted .and() prompt function', () => {
|
||||
const tokens = tokenize('(one,two).and()');
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.name).toBe('and');
|
||||
expect(pf.functionParams).toBe('');
|
||||
expect(pf.promptArgs).toHaveLength(2);
|
||||
expect(pf.promptArgs[0]!.quote).toBe('');
|
||||
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
|
||||
expect(pf.promptArgs[1]!.quote).toBe('');
|
||||
expect(pf.promptArgs[1]!.nodes[0]).toMatchObject({ type: 'word', text: 'two' });
|
||||
});
|
||||
|
||||
it('should parse unquoted .and() with spaces', () => {
|
||||
const tokens = tokenize('(one two, three four).and()');
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.name).toBe('and');
|
||||
expect(pf.promptArgs).toHaveLength(2);
|
||||
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
|
||||
expect(pf.promptArgs[0]!.nodes[2]).toMatchObject({ type: 'word', text: 'two' });
|
||||
expect(pf.promptArgs[1]!.nodes[0]).toMatchObject({ type: 'word', text: 'three' });
|
||||
expect(pf.promptArgs[1]!.nodes[2]).toMatchObject({ type: 'word', text: 'four' });
|
||||
});
|
||||
|
||||
it('should parse unquoted .blend() with params', () => {
|
||||
const tokens = tokenize('(one two, three four).blend(0.7, 0.3)');
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.name).toBe('blend');
|
||||
expect(pf.functionParams).toBe('0.7, 0.3');
|
||||
expect(pf.promptArgs).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should parse unquoted three-arg prompt function', () => {
|
||||
const tokens = tokenize('(a, b, c).blend(0.5, 0.3, 0.2)');
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
expect(pf.promptArgs).toHaveLength(3);
|
||||
expect(pf.functionParams).toBe('0.5, 0.3, 0.2');
|
||||
});
|
||||
|
||||
it('should parse unquoted prompt function with attention inside args', () => {
|
||||
const tokens = tokenize('(hello+, world).and()');
|
||||
const ast = parseTokens(tokens);
|
||||
expect(ast).toHaveLength(1);
|
||||
|
||||
const pf = ast[0]!;
|
||||
expect(pf.type).toBe('prompt_function');
|
||||
if (pf.type !== 'prompt_function') {
|
||||
return;
|
||||
}
|
||||
const arg0Word = pf.promptArgs[0]!.nodes[0]!;
|
||||
expect(arg0Word).toMatchObject({ type: 'word', text: 'hello', attention: '+' });
|
||||
});
|
||||
|
||||
it('should fall back to regular group for single-arg unquoted function', () => {
|
||||
const tokens = tokenize('(hello world).and()');
|
||||
const ast = parseTokens(tokens);
|
||||
// Without a comma, this is not detected as a prompt function
|
||||
expect(ast[0]!.type).toBe('group');
|
||||
});
|
||||
|
||||
it('should parse unquoted prompt function embedded in larger prompt', () => {
|
||||
const tokens = tokenize('some text, (a, b).and(), more text');
|
||||
const ast = parseTokens(tokens);
|
||||
const pfNodes = ast.filter((n) => n.type === 'prompt_function');
|
||||
expect(pfNodes).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('serialize', () => {
|
||||
@@ -218,6 +562,163 @@ describe('promptAST', () => {
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe('<embedding_name>');
|
||||
});
|
||||
|
||||
describe('prompt functions', () => {
|
||||
it('should serialize .and() prompt function', () => {
|
||||
const tokens = tokenize("('one two', 'three four').and()");
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe("('one two', 'three four').and()");
|
||||
});
|
||||
|
||||
it('should serialize .or() prompt function', () => {
|
||||
const tokens = tokenize("('one', 'two three. four.').or()");
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe("('one', 'two three. four.').or()");
|
||||
});
|
||||
|
||||
it('should serialize .blend() with params', () => {
|
||||
const tokens = tokenize("('one', 'two').blend(0.7, 0.3)");
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe("('one', 'two').blend(0.7, 0.3)");
|
||||
});
|
||||
|
||||
it('should serialize prompt function with attention inside args', () => {
|
||||
const tokens = tokenize("('hello+', '(world)-').and()");
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe("('hello+', '(world)-').and()");
|
||||
});
|
||||
|
||||
it('should serialize prompt function embedded in larger prompt', () => {
|
||||
const prompt = "some text, ('a', 'b').and(), more text";
|
||||
const tokens = tokenize(prompt);
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe(prompt);
|
||||
});
|
||||
|
||||
it('should serialize three-arg blend', () => {
|
||||
const tokens = tokenize("('a', 'b', 'c').blend(0.5, 0.3, 0.2)");
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe("('a', 'b', 'c').blend(0.5, 0.3, 0.2)");
|
||||
});
|
||||
|
||||
it('should serialize double-quoted prompt function', () => {
|
||||
const tokens = tokenize('("one", "two").and()');
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe('("one", "two").and()');
|
||||
});
|
||||
|
||||
it('should serialize curly double-quoted prompt function', () => {
|
||||
const tokens = tokenize('(\u201cone\u201d, \u201ctwo\u201d).and()');
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe('(\u201cone\u201d, \u201ctwo\u201d).and()');
|
||||
});
|
||||
|
||||
it('should serialize curly single-quoted prompt function', () => {
|
||||
const tokens = tokenize('(\u2018one\u2019, \u2018two\u2019).and()');
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe('(\u2018one\u2019, \u2018two\u2019).and()');
|
||||
});
|
||||
});
|
||||
|
||||
describe('unquoted prompt functions', () => {
|
||||
it('should serialize unquoted .and()', () => {
|
||||
const tokens = tokenize('(one two, three four).and()');
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe('(one two, three four).and()');
|
||||
});
|
||||
|
||||
it('should serialize unquoted .blend() with params', () => {
|
||||
const tokens = tokenize('(one two, three four).blend(0.7, 0.3)');
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe('(one two, three four).blend(0.7, 0.3)');
|
||||
});
|
||||
|
||||
it('should serialize unquoted prompt function embedded in larger prompt', () => {
|
||||
const prompt = 'some text, (a, b).and(), more text';
|
||||
const tokens = tokenize(prompt);
|
||||
const ast = parseTokens(tokens);
|
||||
const result = serialize(ast);
|
||||
expect(result).toBe(prompt);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('round-trip (tokenize → parse → serialize)', () => {
|
||||
const roundTrip = (prompt: string) => {
|
||||
const tokens = tokenize(prompt);
|
||||
const ast = parseTokens(tokens);
|
||||
return serialize(ast);
|
||||
};
|
||||
|
||||
it.each([
|
||||
'a cat',
|
||||
'(a cat)',
|
||||
'(a cat)1.2',
|
||||
'cat+',
|
||||
'cat++',
|
||||
'cat-',
|
||||
'(hello world)+',
|
||||
'(hello world)++',
|
||||
'(hello world)-',
|
||||
'\\(medium\\)',
|
||||
'colored pencil \\(medium\\) (enhanced)',
|
||||
'<embedding_name>',
|
||||
'portrait \\(realistic\\) (high quality)1.2',
|
||||
'(masterpiece)1.3, best quality, (high detail)1.2',
|
||||
"('one two', 'three four').and()",
|
||||
"('one', 'two three. four.').or()",
|
||||
"('one', 'two').blend(0.7, 0.3)",
|
||||
"('hello+', '(world)-').and()",
|
||||
"some text, ('a', 'b').and(), more text",
|
||||
"('a', 'b', 'c').blend(0.5, 0.3, 0.2)",
|
||||
'("one", "two").and()',
|
||||
// Curly double-quoted prompt functions
|
||||
'(\u201cone\u201d, \u201ctwo\u201d).and()',
|
||||
'(\u201chigh detail, cinematic\u201d, \u201csoft light, portrait\u201d).and()',
|
||||
'(\u201cone\u201d, \u201ctwo\u201d).blend(0.7, 0.3)',
|
||||
// Curly single-quoted prompt functions
|
||||
'(\u2018one\u2019, \u2018two\u2019).and()',
|
||||
'(\u2018one\u2019, \u2018two\u2019).or()',
|
||||
// Unquoted prompt functions
|
||||
'(one two, three four).and()',
|
||||
'(one two, three four).blend(0.7, 0.3)',
|
||||
'(a, b, c).blend(0.5, 0.3, 0.2)',
|
||||
'some text, (a, b).and(), more text',
|
||||
"('one',\n 'two',\n 'three').and()",
|
||||
])('should round-trip: %s', (prompt) => {
|
||||
expect(roundTrip(prompt)).toBe(prompt);
|
||||
});
|
||||
});
|
||||
|
||||
describe('newline normalization', () => {
|
||||
const roundTrip = (prompt: string) => {
|
||||
const tokens = tokenize(prompt);
|
||||
const ast = parseTokens(tokens);
|
||||
return serialize(ast);
|
||||
};
|
||||
|
||||
it('should normalize newline before .method() in quoted prompt function', () => {
|
||||
expect(roundTrip("('one', 'two')\n.and()")).toBe("('one', 'two').and()");
|
||||
});
|
||||
|
||||
it('should normalize newline before .method() in curly-quoted prompt function', () => {
|
||||
expect(roundTrip('(\u201cone\u201d, \u201ctwo\u201d)\n.and()')).toBe('(\u201cone\u201d, \u201ctwo\u201d).and()');
|
||||
});
|
||||
|
||||
it('should normalize newline before .method() in unquoted prompt function', () => {
|
||||
expect(roundTrip('(one, two)\n.and()')).toBe('(one, two).and()');
|
||||
});
|
||||
});
|
||||
|
||||
describe('compel compatibility examples', () => {
|
||||
|
||||
@@ -3,18 +3,10 @@
|
||||
*/
|
||||
export type Attention = string | number;
|
||||
|
||||
type Word = string;
|
||||
|
||||
type Punct = string;
|
||||
|
||||
type Whitespace = string;
|
||||
|
||||
type Embedding = string;
|
||||
|
||||
type Token =
|
||||
| { type: 'word'; value: Word; start: number; end: number }
|
||||
| { type: 'whitespace'; value: Whitespace; start: number; end: number }
|
||||
| { type: 'punct'; value: Punct; start: number; end: number }
|
||||
| { type: 'word'; value: string; start: number; end: number }
|
||||
| { type: 'whitespace'; value: string; start: number; end: number }
|
||||
| { type: 'punct'; value: string; start: number; end: number }
|
||||
| { type: 'lparen'; start: number; end: number }
|
||||
| { type: 'rparen'; start: number; end: number }
|
||||
| { type: 'weight'; value: Attention; start: number; end: number }
|
||||
@@ -22,8 +14,21 @@ type Token =
|
||||
| { type: 'rembed'; start: number; end: number }
|
||||
| { type: 'escaped_paren'; value: '(' | ')'; start: number; end: number };
|
||||
|
||||
/**
|
||||
* A single argument in a prompt function like .and(), .or(), or .blend().
|
||||
* Contains the parsed AST nodes of the argument content and metadata about quoting/range.
|
||||
*/
|
||||
export type PromptFunctionArg = {
|
||||
nodes: ASTNode[];
|
||||
quote: string;
|
||||
/** Range of the content between the quotes (exclusive of quotes themselves) in original prompt coordinates. */
|
||||
contentRange: { start: number; end: number };
|
||||
/** Raw separator whitespace after the comma before this arg (args[1+] only). */
|
||||
separator?: string;
|
||||
};
|
||||
|
||||
export type ASTNode =
|
||||
| { type: 'word'; text: Word; attention?: Attention; range: { start: number; end: number }; isSelection?: boolean }
|
||||
| { type: 'word'; text: string; attention?: Attention; range: { start: number; end: number }; isSelection?: boolean }
|
||||
| {
|
||||
type: 'group';
|
||||
children: ASTNode[];
|
||||
@@ -31,20 +36,60 @@ export type ASTNode =
|
||||
range: { start: number; end: number };
|
||||
isSelection?: boolean;
|
||||
}
|
||||
| { type: 'embedding'; value: Embedding; range: { start: number; end: number }; isSelection?: boolean }
|
||||
| { type: 'whitespace'; value: Whitespace; range: { start: number; end: number }; isSelection?: boolean }
|
||||
| { type: 'punct'; value: Punct; range: { start: number; end: number }; isSelection?: boolean }
|
||||
| { type: 'escaped_paren'; value: '(' | ')'; range: { start: number; end: number }; isSelection?: boolean };
|
||||
| { type: 'embedding'; value: string; range: { start: number; end: number }; isSelection?: boolean }
|
||||
| { type: 'whitespace'; value: string; range: { start: number; end: number }; isSelection?: boolean }
|
||||
| { type: 'punct'; value: string; range: { start: number; end: number }; isSelection?: boolean }
|
||||
| { type: 'escaped_paren'; value: '(' | ')'; range: { start: number; end: number }; isSelection?: boolean }
|
||||
| {
|
||||
type: 'prompt_function';
|
||||
name: string;
|
||||
promptArgs: PromptFunctionArg[];
|
||||
functionParams: string;
|
||||
range: { start: number; end: number };
|
||||
isSelection?: boolean;
|
||||
};
|
||||
|
||||
const WEIGHT_PATTERN = /^[+-]?(\d+(\.\d+)?|[+-]+)/;
|
||||
const WHITESPACE_PATTERN = /^\s+/;
|
||||
const PUNCTUATION_PATTERN = /^[.,]/;
|
||||
const OTHER_PATTERN = /\s/;
|
||||
const WORD_CHAR_PATTERN = /[a-zA-Z0-9_]/;
|
||||
// prettier-ignore
|
||||
const PUNCTUATION_PATTERN = /^[.,/!?;:'"""''\u2018\u2019\u201c\u201d`~@#$%^&*=_|]/;
|
||||
|
||||
/** All characters that can serve as an opening quote in a prompt function argument. */
|
||||
const OPEN_QUOTE_CHARS = new Set(["'", '"', '\u2018', '\u201c']);
|
||||
|
||||
/** Map from opening curly quote to the matching closing curly quote. Straight quotes match themselves. */
|
||||
const CLOSE_QUOTE_MAP: Record<string, string> = {
|
||||
"'": "'",
|
||||
'"': '"',
|
||||
'\u2018': '\u2019', // ' → '
|
||||
'\u201c': '\u201d', // " → "
|
||||
};
|
||||
|
||||
// #region Token Helpers
|
||||
|
||||
/** Get the string value of a token, if it has one. */
|
||||
function tokenValue(t: Token | undefined): string | undefined {
|
||||
if (!t) {
|
||||
return undefined;
|
||||
}
|
||||
if ('value' in t) {
|
||||
return String(t.value);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Check if a token is a punct token with a specific value. */
|
||||
function isPunctValue(t: Token | undefined, value: string): boolean {
|
||||
return t?.type === 'punct' && tokenValue(t) === value;
|
||||
}
|
||||
|
||||
// #region Tokenizer
|
||||
|
||||
/**
|
||||
* Convert a prompt string into an AST.
|
||||
* Convert a prompt string into a token stream.
|
||||
* @param prompt string
|
||||
* @returns ASTNode[]
|
||||
* @returns Token[]
|
||||
*/
|
||||
export function tokenize(prompt: string): Token[] {
|
||||
if (!prompt) {
|
||||
@@ -52,7 +97,7 @@ export function tokenize(prompt: string): Token[] {
|
||||
}
|
||||
|
||||
const len = prompt.length;
|
||||
let tokens: Token[] = [];
|
||||
const tokens: Token[] = [];
|
||||
let i = 0;
|
||||
|
||||
while (i < len) {
|
||||
@@ -69,7 +114,7 @@ export function tokenize(prompt: string): Token[] {
|
||||
tokenizeEmbedding(char, i) ||
|
||||
tokenizeWord(prompt, i) ||
|
||||
tokenizePunctuation(char, i) ||
|
||||
tokenizeOther(char, i);
|
||||
tokenizeFallback(char, i);
|
||||
|
||||
if (result) {
|
||||
if (result.token) {
|
||||
@@ -168,15 +213,15 @@ function tokenizeWord(prompt: string, i: number): TokenizeResult {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (/[a-zA-Z0-9_]/.test(char)) {
|
||||
if (WORD_CHAR_PATTERN.test(char)) {
|
||||
let j = i;
|
||||
while (j < prompt.length && /[a-zA-Z0-9_]/.test(prompt[j]!)) {
|
||||
while (j < prompt.length && WORD_CHAR_PATTERN.test(prompt[j]!)) {
|
||||
j++;
|
||||
}
|
||||
const word = prompt.slice(i, j);
|
||||
|
||||
// Check for weight immediately after word (e.g., "Lorem+", "consectetur-")
|
||||
const weightMatch = prompt.slice(j).match(/^[+-]?(\d+(\.\d+)?|[+-]+)/);
|
||||
const weightMatch = prompt.slice(j).match(WEIGHT_PATTERN);
|
||||
if (weightMatch && weightMatch[0]) {
|
||||
const weightEnd = j + weightMatch[0].length;
|
||||
return {
|
||||
@@ -210,17 +255,20 @@ function tokenizeEmbedding(char: string, i: number): TokenizeResult {
|
||||
return null;
|
||||
}
|
||||
|
||||
function tokenizeOther(char: string, i: number): TokenizeResult {
|
||||
// Any other single character punctuation
|
||||
if (OTHER_PATTERN.test(char)) {
|
||||
return {
|
||||
token: { type: 'punct', value: char, start: i, end: i + 1 },
|
||||
nextIndex: i + 1,
|
||||
};
|
||||
}
|
||||
return null;
|
||||
/**
|
||||
* Fallback tokenizer for characters not matched by any other tokenizer.
|
||||
* Emits them as word tokens so they are preserved in the AST rather than silently dropped.
|
||||
* This handles non-Latin Unicode text (CJK, emoji, etc.) and any other unrecognized characters.
|
||||
*/
|
||||
function tokenizeFallback(char: string, i: number): TokenizeResult {
|
||||
return {
|
||||
token: { type: 'word', value: char, start: i, end: i + 1 },
|
||||
nextIndex: i + 1,
|
||||
};
|
||||
}
|
||||
|
||||
// #region Parser
|
||||
|
||||
/**
|
||||
* Convert tokens into an AST.
|
||||
* @param tokens Token[]
|
||||
@@ -233,10 +281,373 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
|
||||
return tokens[pos];
|
||||
}
|
||||
|
||||
function peekAt(offset: number): Token | undefined {
|
||||
return tokens[pos + offset];
|
||||
}
|
||||
|
||||
function consume(): Token | undefined {
|
||||
return tokens[pos++];
|
||||
}
|
||||
|
||||
/**
|
||||
* Quick lookahead check: does the current lparen (already consumed) start a quoted prompt function?
|
||||
* A quoted prompt function looks like ('...', '...').method(...)
|
||||
* We check if the first non-whitespace token after lparen is a quote character.
|
||||
*/
|
||||
function isQuotedPromptFunctionAhead(): boolean {
|
||||
let p = 0;
|
||||
while (peekAt(p)?.type === 'whitespace') {
|
||||
p++;
|
||||
}
|
||||
const t = peekAt(p);
|
||||
return t?.type === 'punct' && OPEN_QUOTE_CHARS.has(tokenValue(t)!);
|
||||
}
|
||||
|
||||
/**
|
||||
* Lookahead check: does the current lparen (already consumed) start an unquoted prompt function?
|
||||
* An unquoted prompt function looks like (arg1, arg2).method(...) where args are not quoted.
|
||||
* We scan forward looking for a comma at the same nesting depth, then rparen followed by .word(
|
||||
*/
|
||||
function isUnquotedPromptFunctionAhead(): boolean {
|
||||
let p = 0;
|
||||
let depth = 0;
|
||||
let hasComma = false;
|
||||
|
||||
// Scan forward through tokens to find the matching rparen
|
||||
while (peekAt(p)) {
|
||||
const t = peekAt(p)!;
|
||||
|
||||
if (t.type === 'lparen') {
|
||||
depth++;
|
||||
} else if (t.type === 'rparen') {
|
||||
if (depth === 0) {
|
||||
// Found matching rparen — now check for .methodName( pattern
|
||||
// (possibly with whitespace between ) and .)
|
||||
if (!hasComma) {
|
||||
return false; // No comma means it's just a regular group
|
||||
}
|
||||
let next = p + 1;
|
||||
while (peekAt(next)?.type === 'whitespace') {
|
||||
next++;
|
||||
}
|
||||
return (
|
||||
isPunctValue(peekAt(next), '.') && peekAt(next + 1)?.type === 'word' && peekAt(next + 2)?.type === 'lparen'
|
||||
);
|
||||
}
|
||||
depth--;
|
||||
} else if (isPunctValue(t, ',') && depth === 0) {
|
||||
hasComma = true;
|
||||
}
|
||||
|
||||
p++;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the `.methodName(params)` suffix that follows the closing rparen of a prompt function.
|
||||
* Assumes whitespace has already been skipped. Returns null and restores pos if the pattern
|
||||
* doesn't match.
|
||||
*/
|
||||
function tryParseMethodTail(savedPos: number): { name: string; functionParams: string; endPos: number } | null {
|
||||
// Skip whitespace between ) and .methodName (allows newlines)
|
||||
while (peek()?.type === 'whitespace') {
|
||||
consume();
|
||||
}
|
||||
|
||||
// Expect .methodName(params)
|
||||
if (!isPunctValue(peek(), '.')) {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
consume(); // consume dot
|
||||
|
||||
if (peek()?.type !== 'word') {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
const methodName = tokenValue(consume())!;
|
||||
|
||||
// Expect opening paren for method call
|
||||
if (peek()?.type !== 'lparen') {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
consume(); // consume method open paren
|
||||
|
||||
// Collect method params until closing rparen
|
||||
let functionParams = '';
|
||||
while (pos < tokens.length) {
|
||||
const t = peek()!;
|
||||
if (t.type === 'rparen') {
|
||||
break;
|
||||
}
|
||||
const tok = consume()!;
|
||||
const v = tokenValue(tok);
|
||||
if (v !== undefined) {
|
||||
functionParams += v;
|
||||
}
|
||||
}
|
||||
|
||||
// Expect closing rparen for method call
|
||||
if (peek()?.type !== 'rparen') {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
const methodCloseParen = consume()!; // consume method close paren
|
||||
|
||||
return { name: methodName, functionParams, endPos: methodCloseParen.end };
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to parse a prompt function starting after the opening lparen.
|
||||
* Returns the PromptFunctionNode if successful, or null if the pattern doesn't match
|
||||
* (in which case `pos` is restored to `savedPos`).
|
||||
*/
|
||||
function tryParsePromptFunction(lparenToken: Token & { type: 'lparen' }, savedPos: number): ASTNode | null {
|
||||
const args: PromptFunctionArg[] = [];
|
||||
let openQuoteChar: string | null = null;
|
||||
let closeQuoteChar: string | null = null;
|
||||
let pendingSeparator: string | undefined;
|
||||
|
||||
while (pos < tokens.length) {
|
||||
// Skip whitespace before arg or closing paren
|
||||
while (peek()?.type === 'whitespace') {
|
||||
consume();
|
||||
}
|
||||
|
||||
// Check for rparen (end of prompt function args)
|
||||
if (peek()?.type === 'rparen') {
|
||||
break;
|
||||
}
|
||||
|
||||
// Expect comma separator between args
|
||||
if (args.length > 0) {
|
||||
if (isPunctValue(peek(), ',')) {
|
||||
consume();
|
||||
let sep = '';
|
||||
while (peek()?.type === 'whitespace') {
|
||||
const sepToken = consume()!;
|
||||
const sepValue = tokenValue(sepToken);
|
||||
if (sepValue !== undefined) {
|
||||
sep += sepValue;
|
||||
}
|
||||
}
|
||||
pendingSeparator = sep;
|
||||
} else {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// Expect opening quote
|
||||
const openQuoteTok = peek();
|
||||
if (!openQuoteTok || openQuoteTok.type !== 'punct') {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
const thisOpenQuote = tokenValue(openQuoteTok)!;
|
||||
if (!OPEN_QUOTE_CHARS.has(thisOpenQuote)) {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
|
||||
const thisCloseQuote = CLOSE_QUOTE_MAP[thisOpenQuote]!;
|
||||
if (openQuoteChar === null) {
|
||||
openQuoteChar = thisOpenQuote;
|
||||
closeQuoteChar = thisCloseQuote;
|
||||
} else if (thisOpenQuote !== openQuoteChar) {
|
||||
// Mismatched quote style between args
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
|
||||
consume(); // consume opening quote
|
||||
const contentStart = openQuoteTok.end;
|
||||
|
||||
// Collect tokens until closing quote
|
||||
const argTokens: Token[] = [];
|
||||
let contentEnd = contentStart;
|
||||
while (pos < tokens.length) {
|
||||
const t = peek();
|
||||
if (isPunctValue(t, closeQuoteChar!)) {
|
||||
contentEnd = t!.start;
|
||||
break;
|
||||
}
|
||||
const consumed = consume()!;
|
||||
argTokens.push(consumed);
|
||||
contentEnd = consumed.end;
|
||||
}
|
||||
|
||||
// Expect closing quote
|
||||
if (!isPunctValue(peek(), closeQuoteChar!)) {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
consume(); // consume closing quote
|
||||
|
||||
// Parse sub-tokens as AST
|
||||
const argNodes = parseTokens(argTokens);
|
||||
|
||||
args.push({
|
||||
nodes: argNodes,
|
||||
quote: openQuoteChar,
|
||||
contentRange: { start: contentStart, end: contentEnd },
|
||||
separator: pendingSeparator,
|
||||
});
|
||||
pendingSeparator = undefined;
|
||||
}
|
||||
|
||||
if (args.length === 0) {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
|
||||
// Expect rparen
|
||||
if (peek()?.type !== 'rparen') {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
consume(); // consume rparen
|
||||
|
||||
// Parse .methodName(params) suffix
|
||||
const methodTail = tryParseMethodTail(savedPos);
|
||||
if (!methodTail) {
|
||||
return null; // pos already restored by tryParseMethodTail
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'prompt_function',
|
||||
name: methodTail.name,
|
||||
promptArgs: args,
|
||||
functionParams: methodTail.functionParams,
|
||||
range: { start: lparenToken.start, end: methodTail.endPos },
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to parse an unquoted prompt function starting after the opening lparen.
|
||||
* Unquoted prompt functions look like (arg1 words, arg2 words).method(params)
|
||||
* where arguments are separated by commas without quotes.
|
||||
* Returns the PromptFunctionNode if successful, or null if the pattern doesn't match
|
||||
* (in which case `pos` is restored to `savedPos`).
|
||||
*/
|
||||
function tryParseUnquotedPromptFunction(lparenToken: Token & { type: 'lparen' }, savedPos: number): ASTNode | null {
|
||||
const args: PromptFunctionArg[] = [];
|
||||
let pendingSeparator: string | undefined;
|
||||
|
||||
while (pos < tokens.length) {
|
||||
// Check for rparen (end of prompt function args)
|
||||
if (peek()?.type === 'rparen') {
|
||||
break;
|
||||
}
|
||||
|
||||
// Expect comma separator between args (consume the comma)
|
||||
if (args.length > 0) {
|
||||
if (isPunctValue(peek(), ',')) {
|
||||
consume(); // consume comma
|
||||
let sep = '';
|
||||
while (peek()?.type === 'whitespace') {
|
||||
const sepToken = consume()!;
|
||||
const sepValue = tokenValue(sepToken);
|
||||
if (sepValue !== undefined) {
|
||||
sep += sepValue;
|
||||
}
|
||||
}
|
||||
pendingSeparator = sep;
|
||||
} else {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// Collect tokens until comma or rparen (at nesting depth 0)
|
||||
const argTokens: Token[] = [];
|
||||
let contentStart: number | null = null;
|
||||
let contentEnd: number | null = null;
|
||||
let depth = 0;
|
||||
|
||||
while (pos < tokens.length) {
|
||||
const t = peek()!;
|
||||
|
||||
if (t.type === 'lparen') {
|
||||
depth++;
|
||||
} else if (t.type === 'rparen') {
|
||||
if (depth === 0) {
|
||||
break; // End of all args
|
||||
}
|
||||
depth--;
|
||||
} else if (isPunctValue(t, ',') && depth === 0) {
|
||||
break; // End of this arg
|
||||
}
|
||||
|
||||
if (contentStart === null) {
|
||||
contentStart = t.start;
|
||||
}
|
||||
const consumed = consume()!;
|
||||
argTokens.push(consumed);
|
||||
contentEnd = consumed.end;
|
||||
}
|
||||
|
||||
if (argTokens.length === 0) {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
|
||||
// Trim leading/trailing whitespace tokens from the arg content
|
||||
let firstNonWs = 0;
|
||||
while (firstNonWs < argTokens.length && argTokens[firstNonWs]!.type === 'whitespace') {
|
||||
firstNonWs++;
|
||||
}
|
||||
let lastNonWs = argTokens.length - 1;
|
||||
while (lastNonWs >= 0 && argTokens[lastNonWs]!.type === 'whitespace') {
|
||||
lastNonWs--;
|
||||
}
|
||||
|
||||
const trimmedArgTokens = argTokens.slice(firstNonWs, lastNonWs + 1);
|
||||
const trimmedStart = trimmedArgTokens.length > 0 ? trimmedArgTokens[0]!.start : contentStart!;
|
||||
const trimmedEnd = trimmedArgTokens.length > 0 ? trimmedArgTokens[trimmedArgTokens.length - 1]!.end : contentEnd!;
|
||||
|
||||
// Parse sub-tokens as AST
|
||||
const argNodes = parseTokens(trimmedArgTokens);
|
||||
|
||||
args.push({
|
||||
nodes: argNodes,
|
||||
quote: '', // Unquoted
|
||||
contentRange: { start: trimmedStart, end: trimmedEnd },
|
||||
separator: pendingSeparator,
|
||||
});
|
||||
pendingSeparator = undefined;
|
||||
}
|
||||
|
||||
if (args.length < 2) {
|
||||
// An unquoted prompt function must have at least 2 args (otherwise it's a regular group)
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
|
||||
// Expect rparen
|
||||
if (peek()?.type !== 'rparen') {
|
||||
pos = savedPos;
|
||||
return null;
|
||||
}
|
||||
consume(); // consume rparen
|
||||
|
||||
// Parse .methodName(params) suffix
|
||||
const methodTail = tryParseMethodTail(savedPos);
|
||||
if (!methodTail) {
|
||||
return null; // pos already restored by tryParseMethodTail
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'prompt_function',
|
||||
name: methodTail.name,
|
||||
promptArgs: args,
|
||||
functionParams: methodTail.functionParams,
|
||||
range: { start: lparenToken.start, end: methodTail.endPos },
|
||||
};
|
||||
}
|
||||
|
||||
function parseGroup(): ASTNode[] {
|
||||
const nodes: ASTNode[] = [];
|
||||
|
||||
@@ -254,6 +665,30 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
|
||||
}
|
||||
case 'lparen': {
|
||||
const lparen = consume() as Token & { type: 'lparen' };
|
||||
|
||||
// Try to parse as a quoted prompt function first
|
||||
if (isQuotedPromptFunctionAhead()) {
|
||||
const savedPos = pos;
|
||||
const pfResult = tryParsePromptFunction(lparen, savedPos);
|
||||
if (pfResult) {
|
||||
nodes.push(pfResult);
|
||||
break;
|
||||
}
|
||||
// pos was restored by tryParsePromptFunction on failure
|
||||
}
|
||||
|
||||
// Try to parse as an unquoted prompt function
|
||||
if (isUnquotedPromptFunctionAhead()) {
|
||||
const savedPos = pos;
|
||||
const pfResult = tryParseUnquotedPromptFunction(lparen, savedPos);
|
||||
if (pfResult) {
|
||||
nodes.push(pfResult);
|
||||
break;
|
||||
}
|
||||
// pos was restored by tryParseUnquotedPromptFunction on failure
|
||||
}
|
||||
|
||||
// Regular group parsing
|
||||
const groupChildren = parseGroup();
|
||||
|
||||
let attention: Attention | undefined;
|
||||
@@ -283,10 +718,10 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
|
||||
let end = lembed.end;
|
||||
while (peek() && peek()!.type !== 'rembed') {
|
||||
const embedToken = consume()!;
|
||||
embedValue +=
|
||||
embedToken.type === 'word' || embedToken.type === 'punct' || embedToken.type === 'whitespace'
|
||||
? embedToken.value
|
||||
: '';
|
||||
const v = tokenValue(embedToken);
|
||||
if (v !== undefined) {
|
||||
embedValue += v;
|
||||
}
|
||||
end = embedToken.end;
|
||||
}
|
||||
if (peek()?.type === 'rembed') {
|
||||
@@ -341,47 +776,131 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
|
||||
return parseGroup();
|
||||
}
|
||||
|
||||
// #region Serialization
|
||||
|
||||
/**
|
||||
* Visitor callbacks for AST serialization. All callbacks are optional.
|
||||
* Called during traversal to allow tracking node positions in the output string.
|
||||
*/
|
||||
type SerializeVisitor = {
|
||||
/** Called after a node has been fully serialized, with its start and end positions in the output. */
|
||||
onNode?: (node: ASTNode, start: number, end: number) => void;
|
||||
};
|
||||
|
||||
/** Mutable buffer used by serializeCore so all recursive calls share the same position tracking. */
|
||||
type SerializeBuffer = { prompt: string };
|
||||
|
||||
/**
|
||||
* Shared serialization core. Converts an AST back into a prompt string,
|
||||
* optionally calling visitor hooks for position tracking.
|
||||
*
|
||||
* Uses a shared mutable buffer so that node positions reported via
|
||||
* `visitor.onNode` are always absolute offsets in the final output string,
|
||||
* even for nodes nested inside groups or prompt function args.
|
||||
*/
|
||||
function serializeCore(ast: ASTNode[], visitor: SerializeVisitor | undefined, buf: SerializeBuffer): void {
|
||||
for (const node of ast) {
|
||||
const nodeStart = buf.prompt.length;
|
||||
|
||||
switch (node.type) {
|
||||
case 'punct':
|
||||
case 'whitespace': {
|
||||
buf.prompt += node.value;
|
||||
break;
|
||||
}
|
||||
case 'escaped_paren': {
|
||||
buf.prompt += `\\${node.value}`;
|
||||
break;
|
||||
}
|
||||
case 'word': {
|
||||
buf.prompt += node.text;
|
||||
if (node.attention) {
|
||||
buf.prompt += String(node.attention);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'group': {
|
||||
buf.prompt += '(';
|
||||
serializeCore(node.children, visitor, buf);
|
||||
buf.prompt += ')';
|
||||
if (node.attention) {
|
||||
buf.prompt += String(node.attention);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'embedding': {
|
||||
buf.prompt += `<${node.value}>`;
|
||||
break;
|
||||
}
|
||||
case 'prompt_function': {
|
||||
buf.prompt += '(';
|
||||
for (let i = 0; i < node.promptArgs.length; i++) {
|
||||
if (i > 0) {
|
||||
const sep = node.promptArgs[i]!.separator ?? ' ';
|
||||
buf.prompt += `,${sep}`;
|
||||
}
|
||||
const arg = node.promptArgs[i]!;
|
||||
buf.prompt += arg.quote;
|
||||
serializeCore(arg.nodes, visitor, buf);
|
||||
buf.prompt += CLOSE_QUOTE_MAP[arg.quote] ?? arg.quote;
|
||||
}
|
||||
buf.prompt += ').';
|
||||
buf.prompt += node.name;
|
||||
buf.prompt += '(';
|
||||
buf.prompt += node.functionParams;
|
||||
buf.prompt += ')';
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
visitor?.onNode?.(node, nodeStart, buf.prompt.length);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert an AST back into a prompt string.
|
||||
* @param ast ASTNode[]
|
||||
* @returns string
|
||||
*/
|
||||
export function serialize(ast: ASTNode[]): string {
|
||||
let prompt = '';
|
||||
const buf: SerializeBuffer = { prompt: '' };
|
||||
serializeCore(ast, undefined, buf);
|
||||
return buf.prompt;
|
||||
}
|
||||
|
||||
for (const node of ast) {
|
||||
switch (node.type) {
|
||||
case 'punct':
|
||||
case 'whitespace': {
|
||||
prompt += node.value;
|
||||
break;
|
||||
}
|
||||
case 'escaped_paren': {
|
||||
prompt += `\\${node.value}`;
|
||||
break;
|
||||
}
|
||||
case 'word': {
|
||||
prompt += node.text;
|
||||
if (node.attention) {
|
||||
prompt += String(node.attention);
|
||||
/**
|
||||
* Serialize an AST to a prompt string while simultaneously computing the
|
||||
* selection range from `isSelection` flags on nodes.
|
||||
*
|
||||
* This is more reliable than separate serialize + selection computation because
|
||||
* the position tracking is guaranteed to match the serialized output.
|
||||
*/
|
||||
export function serializeWithSelection(ast: ASTNode[]): {
|
||||
prompt: string;
|
||||
selectionStart: number;
|
||||
selectionEnd: number;
|
||||
} {
|
||||
let selStart = Infinity;
|
||||
let selEnd = -1;
|
||||
|
||||
const buf: SerializeBuffer = { prompt: '' };
|
||||
serializeCore(
|
||||
ast,
|
||||
{
|
||||
onNode(node, start, end) {
|
||||
if (node.isSelection) {
|
||||
selStart = Math.min(selStart, start);
|
||||
selEnd = Math.max(selEnd, end);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'group': {
|
||||
prompt += '(';
|
||||
prompt += serialize(node.children);
|
||||
prompt += ')';
|
||||
if (node.attention) {
|
||||
prompt += String(node.attention);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'embedding': {
|
||||
prompt += `<${node.value}>`;
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
buf
|
||||
);
|
||||
|
||||
if (selStart === Infinity) {
|
||||
selStart = 0;
|
||||
selEnd = buf.prompt.length;
|
||||
}
|
||||
|
||||
return prompt;
|
||||
return { prompt: buf.prompt, selectionStart: selStart, selectionEnd: selEnd };
|
||||
}
|
||||
|
||||
@@ -2,170 +2,706 @@ import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { adjustPromptAttention } from './promptAttention';
|
||||
|
||||
/**
|
||||
* Helper: select by substring match within the prompt.
|
||||
* If `selected` is a string, finds it in the prompt and uses its position.
|
||||
* If `selected` is a [start, end] tuple, uses those positions directly.
|
||||
*/
|
||||
function adj(
|
||||
prompt: string,
|
||||
selected: string | [number, number],
|
||||
direction: 'increment' | 'decrement',
|
||||
prefersNumericWeights = false
|
||||
) {
|
||||
const [start, end] =
|
||||
typeof selected === 'string' ? [prompt.indexOf(selected), prompt.indexOf(selected) + selected.length] : selected;
|
||||
return adjustPromptAttention(prompt, start, end, direction, prefersNumericWeights);
|
||||
}
|
||||
|
||||
/** Helper that calls adj with prefersNumericWeights=true */
|
||||
function adjNumeric(prompt: string, selected: string | [number, number], direction: 'increment' | 'decrement') {
|
||||
return adj(prompt, selected, direction, true);
|
||||
}
|
||||
|
||||
describe('adjustPromptAttention', () => {
|
||||
describe('cross-boundary selection', () => {
|
||||
it('should split group and apply attention when selection spans from inside group to outside (increment)', () => {
|
||||
const prompt = '(a b)+ c';
|
||||
const result = adjustPromptAttention(prompt, 3, 8, 'increment');
|
||||
|
||||
expect(result.prompt).toBe('(a b+ c)+');
|
||||
});
|
||||
|
||||
it('should split group and apply attention when selection spans from inside group to outside (decrement)', () => {
|
||||
const prompt = '(a b)+ c';
|
||||
const result = adjustPromptAttention(prompt, 3, 8, 'decrement');
|
||||
|
||||
expect(result.prompt).toBe('a+ b c-');
|
||||
});
|
||||
|
||||
it('should split group when selection starts before group and ends inside (increment)', () => {
|
||||
const prompt = 'a (b c)+';
|
||||
const result = adjustPromptAttention(prompt, 0, 4, 'increment');
|
||||
|
||||
expect(result.prompt).toBe('(a b+ c)+');
|
||||
});
|
||||
|
||||
it('should split group when selection starts before group and ends inside (decrement)', () => {
|
||||
const prompt = 'a (b c)+';
|
||||
const result = adjustPromptAttention(prompt, 0, 4, 'decrement');
|
||||
|
||||
expect(result.prompt).toBe('a- b c+');
|
||||
});
|
||||
|
||||
it('should handle nested groups with cross-boundary selection (increment)', () => {
|
||||
const prompt = '((a b)+)+ c';
|
||||
const result = adjustPromptAttention(prompt, 2, 11, 'increment');
|
||||
|
||||
expect(result.prompt).toBe('((a b)++ c)+');
|
||||
});
|
||||
|
||||
it('should handle nested groups with cross-boundary selection (decrement)', () => {
|
||||
const prompt = '((a b)+)+ c';
|
||||
const result = adjustPromptAttention(prompt, 2, 11, 'decrement');
|
||||
|
||||
expect(result.prompt).toBe('(a b)+ c-');
|
||||
});
|
||||
|
||||
it('should handle selection spanning multiple groups (increment)', () => {
|
||||
const prompt = '(a)+ (b)+';
|
||||
const result = adjustPromptAttention(prompt, 0, 9, 'increment');
|
||||
|
||||
expect(result.prompt).toBe('(a b)++');
|
||||
});
|
||||
|
||||
it('should handle selection spanning multiple groups (decrement)', () => {
|
||||
const prompt = '(a)+ (b)+';
|
||||
const result = adjustPromptAttention(prompt, 0, 9, 'decrement');
|
||||
|
||||
expect(result.prompt).toBe('a b');
|
||||
});
|
||||
|
||||
it('should split negative group correctly (decrement on negative group)', () => {
|
||||
const prompt = '(a b)- c';
|
||||
const result = adjustPromptAttention(prompt, 3, 8, 'decrement');
|
||||
|
||||
expect(result.prompt).toBe('(a b- c)-');
|
||||
});
|
||||
|
||||
it('should split negative group correctly (increment on negative group)', () => {
|
||||
const prompt = '(a b)- c';
|
||||
const result = adjustPromptAttention(prompt, 3, 8, 'increment');
|
||||
|
||||
expect(result.prompt).toBe('a- b c+');
|
||||
});
|
||||
|
||||
it('should handle multiple non-selected items in group', () => {
|
||||
const prompt = '(a b c)+ d';
|
||||
const result = adjustPromptAttention(prompt, 5, 10, 'decrement');
|
||||
|
||||
expect(result.prompt).toBe('(a b)+ c d-');
|
||||
});
|
||||
|
||||
it('should handle word with existing attention in group when crossing boundary', () => {
|
||||
const prompt = 'c (d- e)+';
|
||||
const result = adjustPromptAttention(prompt, 0, 5, 'increment');
|
||||
|
||||
expect(result.prompt).toBe('c+ d e+');
|
||||
});
|
||||
|
||||
it('should handle complex multi-group case', () => {
|
||||
const prompt = '(a+ b)+ c (d- e)+';
|
||||
const result = adjustPromptAttention(prompt, 8, 14, 'increment');
|
||||
|
||||
expect(result.prompt).toBe('(a+ b c)+ d e+');
|
||||
});
|
||||
});
|
||||
// Basic Attention
|
||||
|
||||
describe('single word', () => {
|
||||
it('should add + when incrementing word without attention', () => {
|
||||
const prompt = 'hello world';
|
||||
const result = adjustPromptAttention(prompt, 0, 5, 'increment');
|
||||
|
||||
expect(result.prompt).toBe('hello+ world');
|
||||
});
|
||||
|
||||
it('should add - when decrementing word without attention', () => {
|
||||
const prompt = 'hello world';
|
||||
const result = adjustPromptAttention(prompt, 0, 5, 'decrement');
|
||||
|
||||
expect(result.prompt).toBe('hello- world');
|
||||
it.each([
|
||||
['hello world', 'hello', 'increment', 'hello+ world'],
|
||||
['hello world', 'hello', 'decrement', 'hello- world'],
|
||||
['hello+ world', 'hello+', 'increment', 'hello++ world'],
|
||||
['hello+ world', 'hello+', 'decrement', 'hello world'],
|
||||
['hello- world', 'hello-', 'decrement', 'hello-- world'],
|
||||
['hello- world', 'hello-', 'increment', 'hello world'],
|
||||
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
|
||||
expect(adj(prompt, selected, direction).prompt).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe('existing group', () => {
|
||||
it('should adjust group attention when cursor is at group boundary', () => {
|
||||
const prompt = '(hello world)+';
|
||||
const result = adjustPromptAttention(prompt, 13, 14, 'increment');
|
||||
describe('multiple words', () => {
|
||||
it.each([
|
||||
['hello world', [0, 11] as [number, number], 'increment', '(hello world)+'],
|
||||
['hello world', [0, 11] as [number, number], 'decrement', '(hello world)-'],
|
||||
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
|
||||
expect(adj(prompt, selected, direction).prompt).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.prompt).toBe('(hello world)++');
|
||||
describe('cursor at word-punctuation boundary', () => {
|
||||
it('should select word, not punctuation, when cursor is between word and comma', () => {
|
||||
// "one|, two" — cursor at position 3, between "one" (0-3) and "," (3-4)
|
||||
expect(adj('one, two', [3, 3], 'increment').prompt).toBe('one+, two');
|
||||
});
|
||||
|
||||
it('should select word, not punctuation, when cursor is between word and period', () => {
|
||||
expect(adj('one. two', [3, 3], 'increment').prompt).toBe('one+. two');
|
||||
});
|
||||
|
||||
it('should select word when cursor is at start of word after punctuation', () => {
|
||||
// "one, |two" — cursor at position 5, between " " (4-5) and "two" (5-8)
|
||||
expect(adj('one, two', [5, 5], 'increment').prompt).toBe('one, two+');
|
||||
});
|
||||
|
||||
it('should still select punctuation when cursor is only touching punctuation', () => {
|
||||
// Cursor in the middle of a run of punctuation with no adjacent word
|
||||
// e.g. "one ,, two" cursor at position 5 — between "," (4-5) and "," (5-6)
|
||||
// Both neighbors are punct, so no word to prefer — should still work
|
||||
const result = adj('one ,, two', [5, 5], 'increment');
|
||||
expect(result).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
// Existing Groups
|
||||
|
||||
describe('existing groups', () => {
|
||||
it('should increment group when cursor is at group boundary', () => {
|
||||
expect(adj('(hello world)+', [13, 14], 'increment').prompt).toBe('(hello world)++');
|
||||
});
|
||||
|
||||
it('should remove group when attention becomes neutral', () => {
|
||||
const prompt = '(hello world)+';
|
||||
const result = adjustPromptAttention(prompt, 0, 14, 'decrement');
|
||||
expect(adj('(hello world)+', [0, 14], 'decrement').prompt).toBe('hello world');
|
||||
});
|
||||
|
||||
expect(result.prompt).toBe('hello world');
|
||||
it('should increment inner word within group', () => {
|
||||
const result = adj('(a b)+', [1, 2], 'increment');
|
||||
expect(result.prompt).toBe('(a+ b)+');
|
||||
});
|
||||
});
|
||||
|
||||
describe('multiple words without group', () => {
|
||||
it('should create new group with + when incrementing multiple words', () => {
|
||||
const prompt = 'hello world';
|
||||
const result = adjustPromptAttention(prompt, 0, 11, 'increment');
|
||||
// Cross-Boundary Selection
|
||||
|
||||
expect(result.prompt).toBe('(hello world)+');
|
||||
});
|
||||
|
||||
it('should create new group with - when decrementing multiple words', () => {
|
||||
const prompt = 'hello world';
|
||||
const result = adjustPromptAttention(prompt, 0, 11, 'decrement');
|
||||
|
||||
expect(result.prompt).toBe('(hello world)-');
|
||||
describe('cross-boundary selection', () => {
|
||||
it.each([
|
||||
// Selection from inside group to outside
|
||||
['(a b)+ c', [3, 8], 'increment', '(a b+ c)+'],
|
||||
['(a b)+ c', [3, 8], 'decrement', 'a+ b c-'],
|
||||
// Selection from outside to inside group
|
||||
['a (b c)+', [0, 4], 'increment', '(a b+ c)+'],
|
||||
['a (b c)+', [0, 4], 'decrement', 'a- b c+'],
|
||||
// Nested groups
|
||||
['((a b)+)+ c', [2, 11], 'increment', '((a b)++ c)+'],
|
||||
['((a b)+)+ c', [2, 11], 'decrement', '(a b)+ c-'],
|
||||
// Spanning multiple groups
|
||||
['(a)+ (b)+', [0, 9], 'increment', '(a b)++'],
|
||||
['(a)+ (b)+', [0, 9], 'decrement', 'a b'],
|
||||
// Negative groups
|
||||
['(a b)- c', [3, 8], 'decrement', '(a b- c)-'],
|
||||
['(a b)- c', [3, 8], 'increment', 'a- b c+'],
|
||||
// Multiple non-selected items in group
|
||||
['(a b c)+ d', [5, 10], 'decrement', '(a b)+ c d-'],
|
||||
// Word with existing attention crossing boundary
|
||||
['c (d- e)+', [0, 5], 'increment', 'c+ d e+'],
|
||||
// Complex multi-group
|
||||
['(a+ b)+ c (d- e)+', [8, 14], 'increment', '(a+ b c)+ d e+'],
|
||||
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
|
||||
expect(adj(prompt, selected as string | [number, number], direction).prompt).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
// Selection Preservation
|
||||
|
||||
describe('selection preservation', () => {
|
||||
it('should preserve selection when incrementing single word', () => {
|
||||
const prompt = 'hello world';
|
||||
const result = adjustPromptAttention(prompt, 0, 5, 'increment');
|
||||
it('should track selection when incrementing single word', () => {
|
||||
const result = adj('hello world', 'hello', 'increment');
|
||||
expect(result.prompt).toBe('hello+ world');
|
||||
expect(result.prompt.slice(result.selectionStart, result.selectionEnd)).toBe('hello+');
|
||||
});
|
||||
|
||||
it('should preserve selection when incrementing group', () => {
|
||||
const prompt = '(hello world)+';
|
||||
const result = adjustPromptAttention(prompt, 0, 14, 'increment');
|
||||
it('should track selection when incrementing full group', () => {
|
||||
const result = adj('(hello world)+', [0, 14], 'increment');
|
||||
expect(result.prompt).toBe('(hello world)++');
|
||||
expect(result.prompt.slice(result.selectionStart, result.selectionEnd)).toBe('(hello world)++');
|
||||
});
|
||||
|
||||
it('should preserve selection when splitting group', () => {
|
||||
const prompt = '(a b)+';
|
||||
const result = adjustPromptAttention(prompt, 1, 2, 'increment'); // Select 'a' (index 1 to 2)
|
||||
// 'a' becomes 1.21, 'b' stays 1.1
|
||||
// Result: (a+ b)+ which is equivalent to a++ b+
|
||||
it('should track selection when splitting group', () => {
|
||||
const result = adj('(a b)+', [1, 2], 'increment');
|
||||
expect(result.prompt).toBe('(a+ b)+');
|
||||
expect(result.prompt.slice(result.selectionStart, result.selectionEnd)).toBe('a+');
|
||||
});
|
||||
});
|
||||
|
||||
// Numeric Attention Weights
|
||||
|
||||
describe('numeric attention weights', () => {
|
||||
it.each([
|
||||
// Increment / decrement numeric weights with additive step
|
||||
['(masterpiece)1.3', [0, 16], 'increment', '(masterpiece)1.4'],
|
||||
['(masterpiece)1.3', [0, 16], 'decrement', '(masterpiece)1.2'],
|
||||
['(high detail)1.2', [0, 16], 'increment', '(high detail)1.3'],
|
||||
['(sunny midday light)1.15', [0, 24], 'increment', '(sunny midday light)1.25'],
|
||||
['(sunny midday light)1.15', [0, 24], 'decrement', '(sunny midday light)1.05'],
|
||||
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
|
||||
expect(adj(prompt, selected as [number, number], direction).prompt).toBe(expected);
|
||||
});
|
||||
|
||||
it('should preserve non-selected numeric weights when adjusting elsewhere', () => {
|
||||
const prompt = '(masterpiece)1.3, best quality';
|
||||
const result = adj(prompt, 'best quality', 'increment');
|
||||
expect(result.prompt).toContain('(masterpiece)1.3');
|
||||
expect(result.prompt).not.toContain('masterpiece1.3');
|
||||
});
|
||||
|
||||
it('should not produce floating point garbage', () => {
|
||||
const prompt = '(high detail)1.2, oil painting';
|
||||
const result = adj(prompt, 'oil painting', 'increment');
|
||||
expect(result.prompt).toContain('(high detail)1.2');
|
||||
expect(result.prompt).not.toMatch(/1\.19999/);
|
||||
expect(result.prompt).not.toMatch(/1\.20000/);
|
||||
});
|
||||
|
||||
it('should preserve numeric weight 1.15 without corruption', () => {
|
||||
const prompt = '(sunny midday light)1.15, landscape';
|
||||
const result = adj(prompt, 'landscape', 'increment');
|
||||
expect(result.prompt).toContain('(sunny midday light)1.15');
|
||||
expect(result.prompt).not.toMatch(/1\.15005/);
|
||||
});
|
||||
|
||||
it('should normalize numeric 1.1 weight to + syntax', () => {
|
||||
const prompt = '(lush rolling hills)1.1, landscape';
|
||||
const result = adj(prompt, 'landscape', 'increment');
|
||||
expect(result.prompt).toMatch(/\(lush rolling hills\)(\+|1\.1)/);
|
||||
});
|
||||
|
||||
it('should handle the full complex prompt without corrupting non-selected weights', () => {
|
||||
const prompt =
|
||||
'(masterpiece)1.3, best quality, (high detail)1.2, oil painting, (sunny midday light)1.15, an old stone castle standing on a hill, medieval architecture, weathered stone walls, (lush rolling hills)1.1, expansive landscape, clear blue sky';
|
||||
const result = adj(prompt, 'clear blue sky', 'increment');
|
||||
|
||||
expect(result.prompt).toContain('(masterpiece)1.3');
|
||||
expect(result.prompt).toContain('(high detail)1.2');
|
||||
expect(result.prompt).toContain('(sunny midday light)1.15');
|
||||
expect(result.prompt).toContain('(clear blue sky)+');
|
||||
expect(result.prompt).not.toMatch(/\d\.\d{5,}/);
|
||||
});
|
||||
});
|
||||
|
||||
// Prompt Functions
|
||||
|
||||
describe('prompt functions', () => {
|
||||
describe('within a single argument', () => {
|
||||
it.each([
|
||||
// Single word inside an arg
|
||||
["('hello world', 'other').and()", 'hello', 'increment', "('hello+ world', 'other').and()"],
|
||||
["('hello world', 'other').and()", 'hello', 'decrement', "('hello- world', 'other').and()"],
|
||||
// Multiple words in second arg
|
||||
["('a', 'hello world').or()", 'hello world', 'increment', "('a', '(hello world)+').or()"],
|
||||
["('a', 'hello world').or()", 'hello world', 'decrement', "('a', '(hello world)-').or()"],
|
||||
// Single word in .blend()
|
||||
["('one two', 'three four').blend(0.7, 0.3)", 'two', 'increment', "('one two+', 'three four').blend(0.7, 0.3)"],
|
||||
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
|
||||
expect(adj(prompt, selected, direction).prompt).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe('across argument separator', () => {
|
||||
it('should adjust both args simultaneously when selection spans separator (increment)', () => {
|
||||
const prompt = "('one two', 'three four').and()";
|
||||
// Select across the separator: "two', 'three"
|
||||
const start = prompt.indexOf('two');
|
||||
const end = prompt.indexOf('three') + 'three'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toBe("('one two+', 'three+ four').and()");
|
||||
});
|
||||
|
||||
it('should adjust both args simultaneously when selection spans separator (decrement)', () => {
|
||||
const prompt = "('one two', 'three four').and()";
|
||||
const start = prompt.indexOf('two');
|
||||
const end = prompt.indexOf('three') + 'three'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'decrement');
|
||||
expect(result.prompt).toBe("('one two-', 'three- four').and()");
|
||||
});
|
||||
|
||||
it('should adjust across separator for .or()', () => {
|
||||
const prompt = "('alpha beta', 'gamma delta').or()";
|
||||
const start = prompt.indexOf('beta');
|
||||
const end = prompt.indexOf('gamma') + 'gamma'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toBe("('alpha beta+', 'gamma+ delta').or()");
|
||||
});
|
||||
|
||||
it('should adjust across separator for .blend() preserving params', () => {
|
||||
const prompt = "('one two', 'three four').blend(0.7, 0.3)";
|
||||
const start = prompt.indexOf('two');
|
||||
const end = prompt.indexOf('three') + 'three'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toBe("('one two+', 'three+ four').blend(0.7, 0.3)");
|
||||
});
|
||||
|
||||
it('should handle repeated increment across separator', () => {
|
||||
const prompt = "('one two+', 'three+ four').and()";
|
||||
const start = prompt.indexOf('two');
|
||||
const end = prompt.indexOf('three') + 'three'.length;
|
||||
// "two+" is at the boundary, "three+" is at the boundary
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toBe("('one two++', 'three++ four').and()");
|
||||
});
|
||||
});
|
||||
|
||||
describe('whole function selected', () => {
|
||||
it('should increment all content in all args when whole function is selected', () => {
|
||||
const prompt = "('one', 'two').and()";
|
||||
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
|
||||
expect(result.prompt).toBe("('one+', 'two+').and()");
|
||||
});
|
||||
|
||||
it('should decrement all content in all args', () => {
|
||||
const prompt = "('one', 'two').and()";
|
||||
const result = adjustPromptAttention(prompt, 0, prompt.length, 'decrement');
|
||||
expect(result.prompt).toBe("('one-', 'two-').and()");
|
||||
});
|
||||
|
||||
it('should increment all args of .blend() preserving params', () => {
|
||||
const prompt = "('one', 'two').blend(0.7, 0.3)";
|
||||
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
|
||||
expect(result.prompt).toBe("('one+', 'two+').blend(0.7, 0.3)");
|
||||
});
|
||||
});
|
||||
|
||||
describe('prompt function embedded in larger prompt', () => {
|
||||
it('should adjust only the targeted region outside the function', () => {
|
||||
const prompt = "some text, ('a', 'b').and(), more text";
|
||||
const result = adj(prompt, 'some', 'increment');
|
||||
expect(result.prompt).toContain('some+');
|
||||
expect(result.prompt).toContain("('a', 'b').and()");
|
||||
});
|
||||
|
||||
it('should adjust only the targeted region inside the function', () => {
|
||||
const prompt = "prefix ('alpha beta', 'gamma').and() suffix";
|
||||
const result = adj(prompt, 'alpha', 'increment');
|
||||
expect(result.prompt).toContain("'alpha+ beta'");
|
||||
expect(result.prompt).toContain('prefix');
|
||||
expect(result.prompt).toContain('suffix');
|
||||
});
|
||||
|
||||
it('should adjust text outside and inside function when selection spans boundary', () => {
|
||||
const prompt = "text ('one two', 'three').and()";
|
||||
// Select from 'text' through 'one'
|
||||
const start = prompt.indexOf('text');
|
||||
const end = prompt.indexOf('one') + 'one'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toContain('text+');
|
||||
expect(result.prompt).toContain("'one+ two'");
|
||||
});
|
||||
});
|
||||
|
||||
describe('prompt function with existing attention inside args', () => {
|
||||
it('should further increment already-weighted word inside arg', () => {
|
||||
const prompt = "('hello+', 'world').and()";
|
||||
// Select hello+ (the word with its weight marker)
|
||||
const result = adj(prompt, 'hello+', 'increment');
|
||||
expect(result.prompt).toBe("('hello++', 'world').and()");
|
||||
});
|
||||
|
||||
it('should cancel attention to neutral inside arg', () => {
|
||||
const prompt = "('hello+', 'world').and()";
|
||||
const result = adj(prompt, 'hello+', 'decrement');
|
||||
expect(result.prompt).toBe("('hello', 'world').and()");
|
||||
});
|
||||
|
||||
it('should handle group attention inside arg', () => {
|
||||
const prompt = "('(a b)+', 'c').and()";
|
||||
// Select everything in first arg
|
||||
const start = prompt.indexOf('(a b)+');
|
||||
const end = start + '(a b)+'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toBe("('(a b)++', 'c').and()");
|
||||
});
|
||||
});
|
||||
|
||||
describe('three-arg prompt functions', () => {
|
||||
it('should adjust a word in one arg of a three-arg blend', () => {
|
||||
const prompt = "('a', 'b', 'c').blend(0.5, 0.3, 0.2)";
|
||||
const result = adj(prompt, 'b', 'increment');
|
||||
expect(result.prompt).toBe("('a', 'b+', 'c').blend(0.5, 0.3, 0.2)");
|
||||
});
|
||||
|
||||
it('should adjust across two separators in a three-arg blend', () => {
|
||||
const prompt = "('aa bb', 'cc dd', 'ee ff').blend(0.5, 0.3, 0.2)";
|
||||
// Select from bb through ee
|
||||
const start = prompt.indexOf('bb');
|
||||
const end = prompt.indexOf('ee') + 'ee'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toBe("('aa bb+', '(cc dd)+', 'ee+ ff').blend(0.5, 0.3, 0.2)");
|
||||
});
|
||||
});
|
||||
|
||||
describe('unquoted prompt functions', () => {
|
||||
it('should increment a word in unquoted .and()', () => {
|
||||
const prompt = '(one, two).and()';
|
||||
const result = adj(prompt, 'one', 'increment');
|
||||
expect(result.prompt).toBe('(one+, two).and()');
|
||||
});
|
||||
|
||||
it('should decrement a word in unquoted .and()', () => {
|
||||
const prompt = '(one, two).and()';
|
||||
const result = adj(prompt, 'one', 'decrement');
|
||||
expect(result.prompt).toBe('(one-, two).and()');
|
||||
});
|
||||
|
||||
it('should increment a word in unquoted multi-word arg', () => {
|
||||
const prompt = '(hello world, foo bar).and()';
|
||||
const result = adj(prompt, 'hello', 'increment');
|
||||
expect(result.prompt).toBe('(hello+ world, foo bar).and()');
|
||||
});
|
||||
|
||||
it('should increment all args when whole unquoted function is selected', () => {
|
||||
const prompt = '(one, two).and()';
|
||||
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
|
||||
expect(result.prompt).toBe('(one+, two+).and()');
|
||||
});
|
||||
|
||||
it('should preserve unquoted prompt function when adjusting text outside', () => {
|
||||
const prompt = 'prefix (a, b).and() suffix';
|
||||
const result = adj(prompt, 'prefix', 'increment');
|
||||
expect(result.prompt).toContain('(a, b).and()');
|
||||
expect(result.prompt).toContain('prefix+');
|
||||
});
|
||||
|
||||
it('should handle unquoted .blend() with params', () => {
|
||||
const prompt = '(one two, three four).blend(0.7, 0.3)';
|
||||
const result = adj(prompt, 'one', 'increment');
|
||||
expect(result.prompt).toBe('(one+ two, three four).blend(0.7, 0.3)');
|
||||
});
|
||||
|
||||
it('should adjust across separator in unquoted prompt function', () => {
|
||||
const prompt = '(one two, three four).and()';
|
||||
const start = prompt.indexOf('two');
|
||||
const end = prompt.indexOf('three') + 'three'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toBe('(one two+, three+ four).and()');
|
||||
});
|
||||
});
|
||||
|
||||
describe('curly-quoted prompt functions', () => {
|
||||
it('should increment a word inside curly double-quoted arg', () => {
|
||||
const prompt = '(\u201chello world\u201d, \u201cother\u201d).and()';
|
||||
const result = adj(prompt, 'hello', 'increment');
|
||||
expect(result.prompt).toBe('(\u201chello+ world\u201d, \u201cother\u201d).and()');
|
||||
});
|
||||
|
||||
it('should decrement a word inside curly double-quoted arg', () => {
|
||||
const prompt = '(\u201chello world\u201d, \u201cother\u201d).and()';
|
||||
const result = adj(prompt, 'hello', 'decrement');
|
||||
expect(result.prompt).toBe('(\u201chello- world\u201d, \u201cother\u201d).and()');
|
||||
});
|
||||
|
||||
it('should increment a word inside curly single-quoted arg', () => {
|
||||
const prompt = '(\u2018hello world\u2019, \u2018other\u2019).and()';
|
||||
const result = adj(prompt, 'hello', 'increment');
|
||||
expect(result.prompt).toBe('(\u2018hello+ world\u2019, \u2018other\u2019).and()');
|
||||
});
|
||||
|
||||
it('should increment all args when whole curly-quoted function is selected', () => {
|
||||
const prompt = '(\u201cone\u201d, \u201ctwo\u201d).and()';
|
||||
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
|
||||
expect(result.prompt).toBe('(\u201cone+\u201d, \u201ctwo+\u201d).and()');
|
||||
});
|
||||
|
||||
it('should adjust across separator in curly double-quoted prompt function', () => {
|
||||
const prompt = '(\u201cone two\u201d, \u201cthree four\u201d).and()';
|
||||
const start = prompt.indexOf('two');
|
||||
const end = prompt.indexOf('three') + 'three'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toBe('(\u201cone two+\u201d, \u201cthree+ four\u201d).and()');
|
||||
});
|
||||
|
||||
it('should preserve curly-quoted function when adjusting text outside', () => {
|
||||
const prompt = 'prefix (\u201ca\u201d, \u201cb\u201d).and() suffix';
|
||||
const result = adj(prompt, 'prefix', 'increment');
|
||||
expect(result.prompt).toContain('(\u201ca\u201d, \u201cb\u201d).and()');
|
||||
expect(result.prompt).toContain('prefix+');
|
||||
});
|
||||
|
||||
it('should handle curly-quoted .blend() with params', () => {
|
||||
const prompt = '(\u201cone two\u201d, \u201cthree four\u201d).blend(0.7, 0.3)';
|
||||
const result = adj(prompt, 'one', 'increment');
|
||||
expect(result.prompt).toBe('(\u201cone+ two\u201d, \u201cthree four\u201d).blend(0.7, 0.3)');
|
||||
});
|
||||
});
|
||||
|
||||
describe('newline before .method()', () => {
|
||||
it('should increment a word in quoted prompt function with newline before .method()', () => {
|
||||
const prompt = "('hello world', 'other')\n.and()";
|
||||
const result = adj(prompt, 'hello', 'increment');
|
||||
// Newline is normalized away in output
|
||||
expect(result.prompt).toBe("('hello+ world', 'other').and()");
|
||||
});
|
||||
|
||||
it('should increment a word in curly-quoted prompt function with newline before .method()', () => {
|
||||
const prompt = '(\u201chello world\u201d, \u201cother\u201d)\n.and()';
|
||||
const result = adj(prompt, 'hello', 'increment');
|
||||
expect(result.prompt).toBe('(\u201chello+ world\u201d, \u201cother\u201d).and()');
|
||||
});
|
||||
|
||||
it('should increment a word in unquoted prompt function with newline before .method()', () => {
|
||||
const prompt = '(hello, other)\n.and()';
|
||||
const result = adj(prompt, 'hello', 'increment');
|
||||
expect(result.prompt).toBe('(hello+, other).and()');
|
||||
});
|
||||
});
|
||||
|
||||
describe('paragraph separators between args', () => {
|
||||
it('should preserve newlines between quoted args when adjusting', () => {
|
||||
const prompt = "('chunk 1\n\nline',\n 'chunk 2').and()";
|
||||
const result = adj(prompt, 'chunk', 'increment');
|
||||
expect(result.prompt).toBe("('chunk+ 1\n\nline',\n 'chunk 2').and()");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Selection Preservation with Prompt Functions
|
||||
|
||||
describe('selection preservation with prompt functions', () => {
|
||||
it('should track selection for single word inside prompt function arg', () => {
|
||||
const prompt = "('hello world', 'other').and()";
|
||||
const result = adj(prompt, 'hello', 'increment');
|
||||
expect(result.prompt).toBe("('hello+ world', 'other').and()");
|
||||
expect(result.prompt.slice(result.selectionStart, result.selectionEnd)).toBe('hello+');
|
||||
});
|
||||
|
||||
it('should track selection spanning across prompt function separator', () => {
|
||||
const prompt = "('one two', 'three four').and()";
|
||||
const start = prompt.indexOf('two');
|
||||
const end = prompt.indexOf('three') + 'three'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment');
|
||||
expect(result.prompt).toBe("('one two+', 'three+ four').and()");
|
||||
// Selection should span from 'two+' through 'three+' (including structural chars between)
|
||||
const sel = result.prompt.slice(result.selectionStart, result.selectionEnd);
|
||||
expect(sel).toContain('two+');
|
||||
expect(sel).toContain('three+');
|
||||
});
|
||||
});
|
||||
|
||||
// Edge Cases
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should return prompt unchanged when no selection overlap', () => {
|
||||
const prompt = 'hello world';
|
||||
const result = adjustPromptAttention(prompt, 5, 5, 'increment');
|
||||
// Cursor at the boundary between hello and space — should still find a terminal
|
||||
expect(result.prompt).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle empty prompt', () => {
|
||||
const result = adjustPromptAttention('', 0, 0, 'increment');
|
||||
expect(result.prompt).toBe('');
|
||||
});
|
||||
|
||||
it('should not modify prompt function structure when cursor is on structural char', () => {
|
||||
const prompt = "('a', 'b').and()";
|
||||
// Cursor on the dot between ) and and
|
||||
const dotPos = prompt.indexOf('.and');
|
||||
const result = adjustPromptAttention(prompt, dotPos, dotPos, 'increment');
|
||||
// Should either not change or only affect content, not break the structure
|
||||
expect(result.prompt).toContain('.and()');
|
||||
});
|
||||
});
|
||||
|
||||
// Numeric Weight Preference
|
||||
|
||||
describe('prefersNumericWeights', () => {
|
||||
describe('single word (no existing attention)', () => {
|
||||
it.each([
|
||||
['hello world', 'hello', 'increment', '(hello)1.1 world'],
|
||||
['hello world', 'hello', 'decrement', '(hello)0.9 world'],
|
||||
['hello world', 'world', 'increment', 'hello (world)1.1'],
|
||||
['hello world', 'world', 'decrement', 'hello (world)0.9'],
|
||||
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
|
||||
expect(adjNumeric(prompt, selected, direction).prompt).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe('successive numeric adjustments', () => {
|
||||
it('should use additive step on second increment', () => {
|
||||
const result = adjNumeric('(hello)1.1 world', '(hello)1.1', 'increment');
|
||||
expect(result.prompt).toBe('(hello)1.2 world');
|
||||
});
|
||||
|
||||
it('should use additive step on second decrement', () => {
|
||||
const result = adjNumeric('(hello)0.9 world', '(hello)0.9', 'decrement');
|
||||
expect(result.prompt).toBe('(hello)0.8 world');
|
||||
});
|
||||
|
||||
it('should return to neutral from 1.1 on decrement', () => {
|
||||
const result = adjNumeric('(hello)1.1 world', '(hello)1.1', 'decrement');
|
||||
expect(result.prompt).toBe('hello world');
|
||||
});
|
||||
});
|
||||
|
||||
describe('does not convert existing +/- attention on unselected terminals', () => {
|
||||
it('should preserve +/- on unselected word when adjusting another', () => {
|
||||
const result = adjNumeric('hello+ world', 'world', 'increment');
|
||||
expect(result.prompt).toContain('hello+');
|
||||
expect(result.prompt).toContain('(world)1.1');
|
||||
});
|
||||
|
||||
it('should preserve - on unselected word', () => {
|
||||
const result = adjNumeric('hello- world', 'world', 'decrement');
|
||||
expect(result.prompt).toContain('hello-');
|
||||
expect(result.prompt).toContain('(world)0.9');
|
||||
});
|
||||
});
|
||||
|
||||
describe('existing +/- attention on selected terminals', () => {
|
||||
it('should increment existing + word with multiplicative step (respects existing style)', () => {
|
||||
const result = adjNumeric('hello+ world', 'hello+', 'increment');
|
||||
// The terminal already has explicit +/- attention, so it keeps that style
|
||||
expect(result.prompt).toBe('hello++ world');
|
||||
});
|
||||
|
||||
it('should decrement existing + word to neutral', () => {
|
||||
const result = adjNumeric('hello+ world', 'hello+', 'decrement');
|
||||
expect(result.prompt).toBe('hello world');
|
||||
});
|
||||
});
|
||||
|
||||
describe('existing numeric attention on selected terminals', () => {
|
||||
it('should increment existing numeric weight additively', () => {
|
||||
const result = adjNumeric('(detail)1.3 world', '(detail)1.3', 'increment');
|
||||
expect(result.prompt).toBe('(detail)1.4 world');
|
||||
});
|
||||
|
||||
it('should decrement existing numeric weight additively', () => {
|
||||
const result = adjNumeric('(detail)1.3 world', '(detail)1.3', 'decrement');
|
||||
expect(result.prompt).toBe('(detail)1.2 world');
|
||||
});
|
||||
});
|
||||
|
||||
describe('multiple words selected', () => {
|
||||
it('should wrap multiple words in numeric group on increment', () => {
|
||||
const result = adjNumeric('hello world', [0, 11], 'increment');
|
||||
expect(result.prompt).toBe('(hello world)1.1');
|
||||
});
|
||||
|
||||
it('should wrap multiple words in numeric group on decrement', () => {
|
||||
const result = adjNumeric('hello world', [0, 11], 'decrement');
|
||||
expect(result.prompt).toBe('(hello world)0.9');
|
||||
});
|
||||
});
|
||||
|
||||
describe('inside prompt functions', () => {
|
||||
it('should use numeric format inside prompt function arg', () => {
|
||||
const prompt = "('hello world', 'other').and()";
|
||||
const result = adjNumeric(prompt, 'hello', 'increment');
|
||||
expect(result.prompt).toBe("('(hello)1.1 world', 'other').and()");
|
||||
});
|
||||
|
||||
it('should use numeric format across prompt function separator', () => {
|
||||
const prompt = "('one two', 'three four').and()";
|
||||
const start = prompt.indexOf('two');
|
||||
const end = prompt.indexOf('three') + 'three'.length;
|
||||
const result = adjustPromptAttention(prompt, start, end, 'increment', true);
|
||||
expect(result.prompt).toBe("('one (two)1.1', '(three)1.1 four').and()");
|
||||
});
|
||||
});
|
||||
|
||||
describe('group splitting inside prompt function args', () => {
|
||||
it('should correctly split weighted group when decrementing a single word inside it', () => {
|
||||
const prompt =
|
||||
'("high detail, (cinematic lighting)1.25, soft volumetric light, (sharp focus)+, professional photography", "a young woman with balanced natural proportions, medium length brown hair, neutral expression, casual modern clothing", "subtle rim light, shallow depth of field, natural skin texture, clean background").and()';
|
||||
const result = adj(prompt, 'lighting', 'decrement');
|
||||
// "lighting" gets decremented from 1.25 → 1.25/1.1 ≈ 1.1364
|
||||
// "cinematic" stays at 1.25
|
||||
// The key thing: no space should be lost/misplaced
|
||||
expect(result.prompt).toContain('(cinematic)1.25');
|
||||
expect(result.prompt).toContain('lighting)');
|
||||
// Verify there's a space between the cinematic group and lighting group
|
||||
const cinIdx = result.prompt.indexOf('(cinematic)1.25');
|
||||
const afterCinematic = result.prompt.substring(
|
||||
cinIdx + '(cinematic)1.25'.length,
|
||||
cinIdx + '(cinematic)1.25'.length + 2
|
||||
);
|
||||
expect(afterCinematic).toMatch(/^ /); // Should start with a space
|
||||
});
|
||||
|
||||
it('should rejoin groups when incrementing back to the same weight', () => {
|
||||
const prompt =
|
||||
'("high detail, (cinematic lighting)1.25, soft volumetric light, (sharp focus)+, professional photography", "a young woman with balanced natural proportions, medium length brown hair, neutral expression, casual modern clothing", "subtle rim light, shallow depth of field, natural skin texture, clean background").and()';
|
||||
// Decrement "lighting" to split the group
|
||||
const step1 = adj(prompt, 'lighting', 'decrement');
|
||||
expect(step1.prompt).toContain('(cinematic)1.25');
|
||||
// Now increment "lighting" back — should rejoin into (cinematic lighting)1.25
|
||||
const step2 = adj(step1.prompt, 'lighting', 'increment');
|
||||
expect(step2.prompt).toContain('(cinematic lighting)1.25');
|
||||
});
|
||||
});
|
||||
|
||||
describe('numeric group whitespace trimming', () => {
|
||||
it('should not capture trailing whitespace inside numeric weighted groups', () => {
|
||||
// (foo bar)1.3 → decrement "bar" → (foo)1.3 (bar)X, with space between
|
||||
const result = adj('(foo bar)1.3', 'bar', 'decrement');
|
||||
expect(result.prompt).toContain('(foo)1.3');
|
||||
// Space should be outside the group, not inside
|
||||
expect(result.prompt).not.toContain('(foo )');
|
||||
expect(result.prompt).toMatch(/\(foo\)1\.3 /);
|
||||
});
|
||||
|
||||
it('should not capture leading whitespace inside numeric weighted groups', () => {
|
||||
// (foo bar)1.3 → decrement "foo" → (foo)X (bar)1.3, with space between
|
||||
const result = adj('(foo bar)1.3', 'foo', 'decrement');
|
||||
expect(result.prompt).toContain('(bar)1.3');
|
||||
// Space should be outside the group, not inside
|
||||
expect(result.prompt).not.toContain('( bar)');
|
||||
expect(result.prompt).toMatch(/ \(bar\)1\.3/);
|
||||
});
|
||||
});
|
||||
|
||||
describe('numeric group conjoining', () => {
|
||||
it('should merge adjacent same-weight numeric groups back together', () => {
|
||||
// Two separate groups with same weight should conjoin into one
|
||||
const result = adj('(foo)1.25 (bar)1.25', [0, 19], 'increment');
|
||||
// Both words get the same increment, so they should stay in one group
|
||||
expect(result.prompt).not.toContain(') (');
|
||||
});
|
||||
|
||||
it('should merge adjacent same-weight groups when incrementing to match', () => {
|
||||
// Start with (foo bar)1.3, decrement "bar", then increment it back
|
||||
const step1 = adj('(foo bar)1.3', 'bar', 'decrement');
|
||||
// Now increment "bar" back — it should rejoin into a single group
|
||||
const step2 = adj(step1.prompt, 'bar', 'increment');
|
||||
expect(step2.prompt).toBe('(foo bar)1.3');
|
||||
});
|
||||
|
||||
it('should merge inside prompt function args', () => {
|
||||
const prompt = '("(cinematic)1.25 (lighting)1.25", "other").and()';
|
||||
const start = prompt.indexOf('cinematic');
|
||||
const end = prompt.indexOf('lighting') + 'lighting'.length;
|
||||
const result = adj(prompt, [start, end], 'increment');
|
||||
// Both get incremented to same weight, should be one group
|
||||
expect(result.prompt).not.toMatch(/\)\d[.\d]* \(/);
|
||||
});
|
||||
});
|
||||
|
||||
describe('without prefersNumericWeights (default behavior unchanged)', () => {
|
||||
it('should still use +/- syntax by default', () => {
|
||||
expect(adj('hello world', 'hello', 'increment').prompt).toBe('hello+ world');
|
||||
expect(adj('hello world', 'hello', 'decrement').prompt).toBe('hello- world');
|
||||
});
|
||||
|
||||
it('should still use +/- for multiple words by default', () => {
|
||||
expect(adj('hello world', [0, 11], 'increment').prompt).toBe('(hello world)+');
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,152 +1,56 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { serializeError } from 'serialize-error';
|
||||
|
||||
import { type ASTNode, type Attention, parseTokens, serialize, tokenize } from './promptAST';
|
||||
import {
|
||||
type ASTNode,
|
||||
type Attention,
|
||||
parseTokens,
|
||||
type PromptFunctionArg,
|
||||
serializeWithSelection,
|
||||
tokenize,
|
||||
} from './promptAST';
|
||||
|
||||
const log = logger('events');
|
||||
const log = logger('generation');
|
||||
|
||||
type AttentionDirection = 'increment' | 'decrement';
|
||||
type AdjustmentResult = { prompt: string; selectionStart: number; selectionEnd: number };
|
||||
|
||||
const ATTENTION_STEP = 1.1;
|
||||
const NUMERIC_ATTENTION_STEP = 0.1;
|
||||
|
||||
/** Tolerance for floating-point weight comparisons. */
|
||||
const WEIGHT_TOLERANCE = 0.001;
|
||||
|
||||
/** Tolerance for checking if a weight is a power of ATTENTION_STEP. */
|
||||
const STEP_COUNT_TOLERANCE = 0.005;
|
||||
|
||||
// #region Weight Helpers
|
||||
|
||||
/**
|
||||
* Adjusts the attention of the prompt at the current cursor/selection position.
|
||||
* Check if a weight is approximately ATTENTION_STEP^n for some integer n.
|
||||
* Returns n if so, or null if the weight is not a power of ATTENTION_STEP.
|
||||
*/
|
||||
export function adjustPromptAttention(
|
||||
prompt: string,
|
||||
selectionStart: number,
|
||||
selectionEnd: number,
|
||||
direction: AttentionDirection
|
||||
): AdjustmentResult {
|
||||
try {
|
||||
const tokens = tokenize(prompt);
|
||||
const ast = parseTokens(tokens);
|
||||
const terminals = flattenAST(ast);
|
||||
|
||||
let selectedTerminals = terminals.filter((t) => {
|
||||
const isSelected =
|
||||
(t.range.start < selectionEnd && t.range.end > selectionStart) ||
|
||||
(selectionStart === selectionEnd && t.range.start <= selectionStart && t.range.end >= selectionStart);
|
||||
|
||||
if (!isSelected) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (t.parentRange) {
|
||||
const parentContainsSelection = t.parentRange.start <= selectionStart && t.parentRange.end >= selectionEnd;
|
||||
const selectionCoversParent = selectionStart <= t.parentRange.start && selectionEnd >= t.parentRange.end;
|
||||
|
||||
if (!parentContainsSelection && !selectionCoversParent) {
|
||||
// Partial overlap.
|
||||
if (t.hasExplicitAttention) {
|
||||
return false; // Don't modify explicit weight in partial group
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
for (const t of selectedTerminals) {
|
||||
t.isSelected = true;
|
||||
}
|
||||
|
||||
if (selectedTerminals.length === 0) {
|
||||
const selectedGroup = findSelectedGroup(ast, selectionStart, selectionEnd);
|
||||
if (selectedGroup) {
|
||||
selectedTerminals = terminals.filter(
|
||||
(t) => t.range.start >= selectedGroup.range.start && t.range.end <= selectedGroup.range.end
|
||||
);
|
||||
for (const t of selectedTerminals) {
|
||||
t.isSelected = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (selectedTerminals.length === 0) {
|
||||
return { prompt, selectionStart, selectionEnd };
|
||||
}
|
||||
|
||||
for (const terminal of selectedTerminals) {
|
||||
if (direction === 'increment') {
|
||||
terminal.weight *= ATTENTION_STEP;
|
||||
} else {
|
||||
terminal.weight /= ATTENTION_STEP;
|
||||
}
|
||||
}
|
||||
|
||||
const newAST = groupTerminals(terminals);
|
||||
const newPrompt = serialize(newAST);
|
||||
const newSelection = calculateSelectionRange(newAST);
|
||||
|
||||
return {
|
||||
prompt: newPrompt,
|
||||
selectionStart: newSelection.start,
|
||||
selectionEnd: newSelection.end,
|
||||
};
|
||||
} catch (e) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
log.error({ error: serializeError(e) as any }, 'Failed to adjust prompt attention');
|
||||
return { prompt, selectionStart, selectionEnd };
|
||||
function getAttentionStepCount(weight: number): number | null {
|
||||
if (weight <= 0) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
type Terminal = {
|
||||
text: string;
|
||||
type: ASTNode['type'];
|
||||
weight: number;
|
||||
range: { start: number; end: number };
|
||||
hasExplicitAttention: boolean;
|
||||
parentRange?: { start: number; end: number };
|
||||
isSelected: boolean;
|
||||
};
|
||||
|
||||
function flattenAST(ast: ASTNode[], currentWeight = 1.0, parentRange?: { start: number; end: number }): Terminal[] {
|
||||
let terminals: Terminal[] = [];
|
||||
|
||||
for (const node of ast) {
|
||||
let nodeWeight = currentWeight;
|
||||
if ('attention' in node && node.attention) {
|
||||
nodeWeight *= parseAttention(node.attention);
|
||||
}
|
||||
|
||||
if (node.type === 'group') {
|
||||
terminals.push(...flattenAST(node.children, nodeWeight, node.range));
|
||||
} else {
|
||||
terminals.push({
|
||||
text: node.type === 'word' ? node.text : node.value,
|
||||
type: node.type,
|
||||
weight: nodeWeight,
|
||||
range: node.range,
|
||||
hasExplicitAttention: 'attention' in node && !!node.attention,
|
||||
parentRange: parentRange,
|
||||
isSelected: false,
|
||||
});
|
||||
}
|
||||
if (Math.abs(weight - 1.0) < WEIGHT_TOLERANCE) {
|
||||
return 0;
|
||||
}
|
||||
return terminals;
|
||||
}
|
||||
|
||||
function findSelectedGroup(nodes: ASTNode[], start: number, end: number): ASTNode | null {
|
||||
for (const node of nodes) {
|
||||
if (node.type === 'group') {
|
||||
const foundInChildren = findSelectedGroup(node.children, start, end);
|
||||
if (foundInChildren) {
|
||||
return foundInChildren;
|
||||
}
|
||||
|
||||
if (rangesOverlap(node.range, { start, end })) {
|
||||
return node;
|
||||
}
|
||||
}
|
||||
const n = Math.round(Math.log(weight) / Math.log(ATTENTION_STEP));
|
||||
if (n === 0) {
|
||||
return null;
|
||||
}
|
||||
const expected = Math.pow(ATTENTION_STEP, n);
|
||||
if (Math.abs(expected - weight) < STEP_COUNT_TOLERANCE) {
|
||||
return n;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function rangesOverlap(a: { start: number; end: number }, b: { start: number; end: number }) {
|
||||
return a.start < b.end && a.end > b.start;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert an Attention value ('+', '--', 1.2, etc.) into a numeric multiplier.
|
||||
*/
|
||||
function parseAttention(attention: Attention): number {
|
||||
if (typeof attention === 'number') {
|
||||
return attention;
|
||||
@@ -161,83 +65,435 @@ function parseAttention(attention: Attention): number {
|
||||
return isNaN(num) ? 1.0 : num;
|
||||
}
|
||||
|
||||
function calculateSelectionRange(nodes: ASTNode[]): { start: number; end: number } {
|
||||
let selectionStart = Infinity;
|
||||
let selectionEnd = -1;
|
||||
let currentPos = 0;
|
||||
/**
|
||||
* Combine an existing attention value with an additional '+' or '-' level.
|
||||
* Handles cancellation: e.g. '++' + '-' → '+', '+' + '-' → undefined (neutral).
|
||||
*/
|
||||
function addAttention(current: Attention | undefined, added: '+' | '-'): Attention | undefined {
|
||||
if (!current) {
|
||||
return added;
|
||||
}
|
||||
if (typeof current === 'number') {
|
||||
if (added === '+') {
|
||||
return Number((current * ATTENTION_STEP).toFixed(4));
|
||||
}
|
||||
return Number((current / ATTENTION_STEP).toFixed(4));
|
||||
}
|
||||
// Check if the added direction cancels the current one
|
||||
const isCancel = (current.startsWith('+') && added === '-') || (current.startsWith('-') && added === '+');
|
||||
if (isCancel) {
|
||||
const res = current.substring(1);
|
||||
return res === '' ? undefined : res;
|
||||
}
|
||||
return `${current}${added}`;
|
||||
}
|
||||
|
||||
function traverse(nodes: ASTNode[]) {
|
||||
for (const node of nodes) {
|
||||
if (node.isSelection) {
|
||||
const len = serialize([node]).length;
|
||||
selectionStart = Math.min(selectionStart, currentPos);
|
||||
selectionEnd = Math.max(selectionEnd, currentPos + len);
|
||||
currentPos += len;
|
||||
} else {
|
||||
if (node.type === 'group') {
|
||||
// Group is not fully selected, but children might be.
|
||||
// Group structure: "(" + children + ")" + attention
|
||||
currentPos += 1; // '('
|
||||
traverse(node.children);
|
||||
currentPos += 1; // ')'
|
||||
if (node.attention) {
|
||||
currentPos += String(node.attention).length;
|
||||
// #region Terminal Type
|
||||
|
||||
type Terminal = {
|
||||
text: string;
|
||||
type: ASTNode['type'];
|
||||
weight: number;
|
||||
range: { start: number; end: number };
|
||||
hasExplicitAttention: boolean;
|
||||
hasNumericAttention: boolean;
|
||||
parentRange?: { start: number; end: number };
|
||||
isSelected: boolean;
|
||||
};
|
||||
|
||||
// #region Main Entry Point
|
||||
|
||||
/**
|
||||
* Adjusts the attention of the prompt at the current cursor/selection position.
|
||||
* Supports regular prompts and prompt functions (.and(), .or(), .blend()).
|
||||
*
|
||||
* When a selection spans across a prompt function's argument separator, each
|
||||
* affected argument is adjusted independently and simultaneously.
|
||||
*/
|
||||
export function adjustPromptAttention(
|
||||
prompt: string,
|
||||
selectionStart: number,
|
||||
selectionEnd: number,
|
||||
direction: AttentionDirection,
|
||||
prefersNumericWeights = false
|
||||
): AdjustmentResult {
|
||||
try {
|
||||
const tokens = tokenize(prompt);
|
||||
const ast = parseTokens(tokens);
|
||||
|
||||
const regions = extractRegions(ast);
|
||||
const processedNodes: ASTNode[] = [];
|
||||
let anyModified = false;
|
||||
|
||||
for (const region of regions) {
|
||||
if (region.type === 'normal') {
|
||||
const clipped = clipSelection(selectionStart, selectionEnd, region.range);
|
||||
if (clipped) {
|
||||
const result = adjustRegionNodes(region.nodes, clipped.start, clipped.end, direction, prefersNumericWeights);
|
||||
if (result.modified) {
|
||||
anyModified = true;
|
||||
}
|
||||
processedNodes.push(...result.nodes);
|
||||
} else {
|
||||
// Leaf node not selected.
|
||||
const len = serialize([node]).length;
|
||||
currentPos += len;
|
||||
processedNodes.push(...region.nodes);
|
||||
}
|
||||
} else {
|
||||
// prompt_function region
|
||||
const pfNode = region.node;
|
||||
const clipped = clipSelection(selectionStart, selectionEnd, pfNode.range);
|
||||
if (clipped) {
|
||||
const result = adjustPromptFunctionNode(pfNode, clipped.start, clipped.end, direction, prefersNumericWeights);
|
||||
if (result.modified) {
|
||||
anyModified = true;
|
||||
}
|
||||
processedNodes.push(result.node);
|
||||
} else {
|
||||
processedNodes.push(pfNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
traverse(nodes);
|
||||
if (!anyModified) {
|
||||
return { prompt, selectionStart, selectionEnd };
|
||||
}
|
||||
|
||||
if (selectionStart === Infinity) {
|
||||
return { start: 0, end: serialize(nodes).length };
|
||||
return serializeWithSelection(processedNodes);
|
||||
} catch (e) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
log.error({ error: serializeError(e) as any }, 'Failed to adjust prompt attention');
|
||||
return { prompt, selectionStart, selectionEnd };
|
||||
}
|
||||
return { start: selectionStart, end: selectionEnd };
|
||||
}
|
||||
|
||||
// #region Region Extraction
|
||||
|
||||
type Region =
|
||||
| { type: 'normal'; nodes: ASTNode[]; range: { start: number; end: number } }
|
||||
| { type: 'prompt_function'; node: ASTNode & { type: 'prompt_function' } };
|
||||
|
||||
/**
|
||||
* Split the top-level AST into contiguous "normal" regions and prompt function regions.
|
||||
* This allows us to process prompt function arguments independently.
|
||||
*/
|
||||
function extractRegions(ast: ASTNode[]): Region[] {
|
||||
const regions: Region[] = [];
|
||||
let currentNormal: ASTNode[] = [];
|
||||
|
||||
const flushNormal = () => {
|
||||
if (currentNormal.length > 0) {
|
||||
const first = currentNormal[0]!;
|
||||
const last = currentNormal[currentNormal.length - 1]!;
|
||||
regions.push({
|
||||
type: 'normal',
|
||||
nodes: currentNormal,
|
||||
range: { start: first.range.start, end: last.range.end },
|
||||
});
|
||||
currentNormal = [];
|
||||
}
|
||||
};
|
||||
|
||||
for (const node of ast) {
|
||||
if (node.type === 'prompt_function') {
|
||||
flushNormal();
|
||||
regions.push({ type: 'prompt_function', node });
|
||||
} else {
|
||||
currentNormal.push(node);
|
||||
}
|
||||
}
|
||||
flushNormal();
|
||||
|
||||
return regions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clip a selection range to a target range. Returns null if there is no overlap.
|
||||
* For cursor positions (start === end), checks containment including boundaries.
|
||||
*/
|
||||
function clipSelection(
|
||||
selStart: number,
|
||||
selEnd: number,
|
||||
range: { start: number; end: number }
|
||||
): { start: number; end: number } | null {
|
||||
if (selStart === selEnd) {
|
||||
// Cursor position: check if within range (inclusive of boundaries)
|
||||
if (selStart >= range.start && selStart <= range.end) {
|
||||
return { start: selStart, end: selEnd };
|
||||
}
|
||||
return null;
|
||||
}
|
||||
const clippedStart = Math.max(selStart, range.start);
|
||||
const clippedEnd = Math.min(selEnd, range.end);
|
||||
if (clippedStart >= clippedEnd) {
|
||||
return null;
|
||||
}
|
||||
return { start: clippedStart, end: clippedEnd };
|
||||
}
|
||||
|
||||
// #region Prompt Function Handling
|
||||
|
||||
/**
|
||||
* Adjust attention within a prompt function node by processing each argument
|
||||
* whose content range overlaps the selection independently.
|
||||
* Returns the (possibly updated) node and whether any modification was made.
|
||||
*/
|
||||
function adjustPromptFunctionNode(
|
||||
pf: ASTNode & { type: 'prompt_function' },
|
||||
selStart: number,
|
||||
selEnd: number,
|
||||
direction: AttentionDirection,
|
||||
prefersNumericWeights = false
|
||||
): { node: ASTNode & { type: 'prompt_function' }; modified: boolean } {
|
||||
let modified = false;
|
||||
const newArgs: PromptFunctionArg[] = pf.promptArgs.map((arg) => {
|
||||
const clipped = clipSelection(selStart, selEnd, arg.contentRange);
|
||||
if (clipped) {
|
||||
const result = adjustRegionNodes(arg.nodes, clipped.start, clipped.end, direction, prefersNumericWeights);
|
||||
if (result.modified) {
|
||||
modified = true;
|
||||
return { ...arg, nodes: result.nodes };
|
||||
}
|
||||
}
|
||||
return arg;
|
||||
});
|
||||
|
||||
if (!modified) {
|
||||
return { node: pf, modified: false };
|
||||
}
|
||||
|
||||
return { node: { ...pf, promptArgs: newArgs }, modified: true };
|
||||
}
|
||||
|
||||
// #region Core Attention Adjustment
|
||||
|
||||
/**
|
||||
* Adjust attention for a set of AST nodes (a "region") given a selection range.
|
||||
* This is the core flatten → select → adjust → regroup pipeline.
|
||||
* Returns the adjusted nodes and whether any modification was made.
|
||||
*/
|
||||
function adjustRegionNodes(
|
||||
nodes: ASTNode[],
|
||||
selStart: number,
|
||||
selEnd: number,
|
||||
direction: AttentionDirection,
|
||||
prefersNumericWeights = false
|
||||
): { nodes: ASTNode[]; modified: boolean } {
|
||||
const terminals = flattenAST(nodes);
|
||||
|
||||
let selectedTerminals = selectTerminals(terminals, selStart, selEnd);
|
||||
|
||||
// Fallback: if no terminals were selected, try to find an overlapping group
|
||||
if (selectedTerminals.length === 0) {
|
||||
const group = findSelectedGroup(nodes, selStart, selEnd);
|
||||
if (group) {
|
||||
selectedTerminals = terminals.filter((t) => t.range.start >= group.range.start && t.range.end <= group.range.end);
|
||||
}
|
||||
}
|
||||
|
||||
if (selectedTerminals.length === 0) {
|
||||
return { nodes, modified: false };
|
||||
}
|
||||
|
||||
for (const t of selectedTerminals) {
|
||||
t.isSelected = true;
|
||||
// When the user prefers numeric weights and the terminal doesn't already
|
||||
// have explicit attention, mark it as numeric so adjustWeights uses
|
||||
// additive steps and groupTerminals emits numeric syntax.
|
||||
if (prefersNumericWeights && !t.hasExplicitAttention) {
|
||||
t.hasNumericAttention = true;
|
||||
}
|
||||
}
|
||||
|
||||
adjustWeights(selectedTerminals, direction);
|
||||
|
||||
return { nodes: groupTerminals(terminals), modified: true };
|
||||
}
|
||||
|
||||
// #region Flatten AST to Terminals
|
||||
|
||||
/**
|
||||
* Flatten an AST into a flat list of terminals, computing the effective weight
|
||||
* of each terminal by accumulating attention from ancestor groups.
|
||||
*/
|
||||
function flattenAST(
|
||||
ast: ASTNode[],
|
||||
currentWeight = 1.0,
|
||||
parentRange?: { start: number; end: number },
|
||||
numericAttention = false
|
||||
): Terminal[] {
|
||||
const terminals: Terminal[] = [];
|
||||
|
||||
for (const node of ast) {
|
||||
let nodeWeight = currentWeight;
|
||||
let nodeNumericAttention = numericAttention;
|
||||
if ((node.type === 'word' || node.type === 'group') && node.attention) {
|
||||
nodeWeight *= parseAttention(node.attention);
|
||||
nodeNumericAttention = typeof node.attention === 'number';
|
||||
}
|
||||
|
||||
if (node.type === 'group') {
|
||||
terminals.push(...flattenAST(node.children, nodeWeight, node.range, nodeNumericAttention));
|
||||
} else if (node.type === 'prompt_function') {
|
||||
// Prompt functions should not appear inside regions being flattened;
|
||||
// they are handled at the region level. If one somehow appears, skip it.
|
||||
continue;
|
||||
} else {
|
||||
terminals.push({
|
||||
text: node.type === 'word' ? node.text : node.value,
|
||||
type: node.type,
|
||||
weight: nodeWeight,
|
||||
range: node.range,
|
||||
hasExplicitAttention: node.type === 'word' && !!node.attention,
|
||||
hasNumericAttention: nodeNumericAttention,
|
||||
parentRange,
|
||||
isSelected: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
return terminals;
|
||||
}
|
||||
|
||||
// #region Terminal Selection
|
||||
|
||||
/**
|
||||
* Find terminals that overlap the selection range and should be affected
|
||||
* by the attention adjustment. Handles partial group overlap carefully:
|
||||
* terminals with explicit attention inside partially-overlapping groups
|
||||
* are excluded to avoid corrupting explicit weights.
|
||||
*
|
||||
* When the cursor is at a boundary between two tokens (e.g. "word|,"),
|
||||
* both tokens technically overlap the cursor position. In this case we
|
||||
* prefer word/embedding terminals over punctuation/whitespace so that
|
||||
* adjusting attention at a word boundary doesn't accidentally include
|
||||
* adjacent punctuation.
|
||||
*/
|
||||
function selectTerminals(terminals: Terminal[], selStart: number, selEnd: number): Terminal[] {
|
||||
const result = terminals.filter((t) => {
|
||||
const isOverlapping =
|
||||
(t.range.start < selEnd && t.range.end > selStart) ||
|
||||
(selStart === selEnd && t.range.start <= selStart && t.range.end >= selStart);
|
||||
|
||||
if (!isOverlapping) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (t.parentRange) {
|
||||
const parentContainsSelection = t.parentRange.start <= selStart && t.parentRange.end >= selEnd;
|
||||
const selectionCoversParent = selStart <= t.parentRange.start && selEnd >= t.parentRange.end;
|
||||
|
||||
if (!parentContainsSelection && !selectionCoversParent) {
|
||||
// Partial overlap between selection and parent group
|
||||
if (t.hasExplicitAttention) {
|
||||
return false; // Don't modify explicit weight in partially-overlapping group
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
// When the cursor is at a token boundary (no selection range), multiple tokens
|
||||
// can match. Prefer word/embedding terminals over punctuation/whitespace.
|
||||
if (selStart === selEnd && result.length > 1) {
|
||||
const contentTerminals = result.filter((t) => t.type === 'word' || t.type === 'embedding');
|
||||
if (contentTerminals.length > 0) {
|
||||
return contentTerminals;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// #region Weight Adjustment
|
||||
|
||||
/**
|
||||
* Apply weight changes to the selected terminals based on direction.
|
||||
* Numeric weights use additive steps; +/- syntax uses multiplicative steps.
|
||||
* All results are rounded to 4 decimal places to prevent floating-point drift.
|
||||
*/
|
||||
function adjustWeights(terminals: Terminal[], direction: AttentionDirection): void {
|
||||
for (const terminal of terminals) {
|
||||
if (terminal.hasNumericAttention) {
|
||||
// Additive step for explicit numeric weights (e.g. 1.1 → 1.2)
|
||||
if (direction === 'increment') {
|
||||
terminal.weight = Number((terminal.weight + NUMERIC_ATTENTION_STEP).toFixed(4));
|
||||
} else {
|
||||
terminal.weight = Number((terminal.weight - NUMERIC_ATTENTION_STEP).toFixed(4));
|
||||
}
|
||||
} else {
|
||||
// Multiplicative step for +/- syntax weights, rounded to prevent drift
|
||||
if (direction === 'increment') {
|
||||
terminal.weight = Number((terminal.weight * ATTENTION_STEP).toFixed(4));
|
||||
} else {
|
||||
terminal.weight = Number((terminal.weight / ATTENTION_STEP).toFixed(4));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #region Find Selected Group (fallback)
|
||||
|
||||
/**
|
||||
* When no terminals directly overlap the selection (e.g. cursor is on a group
|
||||
* boundary character), find the innermost group that overlaps the selection.
|
||||
*/
|
||||
function findSelectedGroup(nodes: ASTNode[], start: number, end: number): ASTNode | null {
|
||||
for (const node of nodes) {
|
||||
if (node.type === 'group') {
|
||||
const foundInChildren = findSelectedGroup(node.children, start, end);
|
||||
if (foundInChildren) {
|
||||
return foundInChildren;
|
||||
}
|
||||
if (node.range.start < end && node.range.end > start) {
|
||||
return node;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// #region Regroup Terminals into AST
|
||||
|
||||
/**
|
||||
* Reconstruct an AST from a flat list of terminals with adjusted weights.
|
||||
* Groups consecutive terminals with compatible weights using +/- or numeric syntax.
|
||||
*
|
||||
* Note: Reconstructed group nodes use `range: { start: 0, end: 0 }` as a sentinel
|
||||
* value since the original source positions are no longer meaningful after regrouping.
|
||||
* These nodes are only used for serialization output, never for source-position lookups.
|
||||
*/
|
||||
function groupTerminals(terminals: Terminal[]): ASTNode[] {
|
||||
if (terminals.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
/** Sentinel range for reconstructed nodes whose original positions are not applicable. */
|
||||
const NO_RANGE = { start: 0, end: 0 };
|
||||
|
||||
const nodes: ASTNode[] = [];
|
||||
let i = 0;
|
||||
|
||||
while (i < terminals.length) {
|
||||
const t = terminals[i]!;
|
||||
const weight = t.weight;
|
||||
const stepCount = getAttentionStepCount(weight);
|
||||
|
||||
const findRunEnd = (predicate: (w: number) => boolean) => {
|
||||
let j = i;
|
||||
while (j < terminals.length) {
|
||||
const next = terminals[j]!;
|
||||
if (predicate(next.weight)) {
|
||||
j++;
|
||||
} else if (next.type === 'whitespace') {
|
||||
let k = j + 1;
|
||||
while (k < terminals.length && terminals[k]!.type === 'whitespace') {
|
||||
k++;
|
||||
}
|
||||
if (k < terminals.length && predicate(terminals[k]!.weight)) {
|
||||
j = k;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
// ── +/- attention (weight is a non-zero power of ATTENTION_STEP) ──
|
||||
// Skip this branch if the terminal prefers numeric format to avoid an
|
||||
// infinite loop (predicate would reject it, findRunEnd returns i, i never advances).
|
||||
if (stepCount !== null && stepCount !== 0 && !t.hasNumericAttention) {
|
||||
const isPositive = stepCount > 0;
|
||||
const sign: '+' | '-' = isPositive ? '+' : '-';
|
||||
const predicate = (t: Terminal): boolean => {
|
||||
if (t.hasNumericAttention) {
|
||||
return false; // Numeric-preference terminals should not join +/- runs
|
||||
}
|
||||
}
|
||||
return j;
|
||||
};
|
||||
const sc = getAttentionStepCount(t.weight);
|
||||
return sc !== null && (isPositive ? sc > 0 : sc < 0);
|
||||
};
|
||||
const factor = isPositive ? ATTENTION_STEP : 1 / ATTENTION_STEP;
|
||||
|
||||
// Check for + (>= 1.1)
|
||||
if (weight >= ATTENTION_STEP - 0.001) {
|
||||
const j = findRunEnd((w) => w >= ATTENTION_STEP - 0.001);
|
||||
const j = findRunEnd(terminals, i, predicate);
|
||||
|
||||
// Trim whitespace from the content run boundaries
|
||||
let runStart = i;
|
||||
let runEnd = j;
|
||||
while (runStart < runEnd && terminals[runStart]!.type === 'whitespace') {
|
||||
@@ -247,28 +503,31 @@ function groupTerminals(terminals: Terminal[]): ASTNode[] {
|
||||
runEnd--;
|
||||
}
|
||||
|
||||
// Emit leading whitespace as standalone nodes
|
||||
for (let k = i; k < runStart; k++) {
|
||||
nodes.push(createNodeFromTerminal(terminals[k]!));
|
||||
}
|
||||
|
||||
if (runStart < runEnd) {
|
||||
const slice = terminals.slice(runStart, runEnd).map((t) => ({ ...t, weight: t.weight / ATTENTION_STEP }));
|
||||
// Factor out one level of attention and recurse
|
||||
const slice = terminals.slice(runStart, runEnd).map((t) => ({ ...t, weight: t.weight / factor }));
|
||||
const children = groupTerminals(slice);
|
||||
const isSelection = slice.every((t) => t.isSelected);
|
||||
|
||||
if (children.length === 1) {
|
||||
const child = children[0]!;
|
||||
if (child.type === 'word' || child.type === 'group') {
|
||||
const newAttention = addAttention(child.attention, '+');
|
||||
nodes.push({ ...child, attention: newAttention });
|
||||
const newAttention = addAttention(child.attention, sign);
|
||||
nodes.push({ ...child, attention: newAttention, isSelection: isSelection || undefined });
|
||||
} else {
|
||||
nodes.push({ type: 'group', children, attention: '+', range: { start: 0, end: 0 }, isSelection });
|
||||
nodes.push({ type: 'group', children, attention: sign, range: NO_RANGE, isSelection });
|
||||
}
|
||||
} else {
|
||||
nodes.push({ type: 'group', children, attention: '+', range: { start: 0, end: 0 }, isSelection });
|
||||
nodes.push({ type: 'group', children, attention: sign, range: NO_RANGE, isSelection });
|
||||
}
|
||||
}
|
||||
|
||||
// Emit trailing whitespace as standalone nodes
|
||||
for (let k = runEnd; k < j; k++) {
|
||||
nodes.push(createNodeFromTerminal(terminals[k]!));
|
||||
}
|
||||
@@ -277,126 +536,103 @@ function groupTerminals(terminals: Terminal[]): ASTNode[] {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for - (<= 0.909)
|
||||
if (weight <= 1 / ATTENTION_STEP + 0.001) {
|
||||
const j = findRunEnd((w) => w <= 1 / ATTENTION_STEP + 0.001);
|
||||
|
||||
let runStart = i;
|
||||
let runEnd = j;
|
||||
while (runStart < runEnd && terminals[runStart]!.type === 'whitespace') {
|
||||
runStart++;
|
||||
}
|
||||
while (runEnd > runStart && terminals[runEnd - 1]!.type === 'whitespace') {
|
||||
runEnd--;
|
||||
}
|
||||
|
||||
for (let k = i; k < runStart; k++) {
|
||||
nodes.push(createNodeFromTerminal(terminals[k]!));
|
||||
}
|
||||
|
||||
if (runStart < runEnd) {
|
||||
const slice = terminals.slice(runStart, runEnd).map((t) => ({ ...t, weight: t.weight * ATTENTION_STEP }));
|
||||
const children = groupTerminals(slice);
|
||||
const isSelection = slice.every((t) => t.isSelected);
|
||||
|
||||
if (children.length === 1) {
|
||||
const child = children[0]!;
|
||||
if (child.type === 'word' || child.type === 'group') {
|
||||
const newAttention = addAttention(child.attention, '-');
|
||||
nodes.push({ ...child, attention: newAttention });
|
||||
} else {
|
||||
nodes.push({ type: 'group', children, attention: '-', range: { start: 0, end: 0 }, isSelection });
|
||||
}
|
||||
} else {
|
||||
nodes.push({ type: 'group', children, attention: '-', range: { start: 0, end: 0 }, isSelection });
|
||||
}
|
||||
}
|
||||
|
||||
for (let k = runEnd; k < j; k++) {
|
||||
nodes.push(createNodeFromTerminal(terminals[k]!));
|
||||
}
|
||||
|
||||
i = j;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Residual or 1.0
|
||||
if (Math.abs(weight - 1.0) < 0.001) {
|
||||
// ── Neutral weight (≈ 1.0) ──
|
||||
if (Math.abs(weight - 1.0) < WEIGHT_TOLERANCE) {
|
||||
nodes.push(createNodeFromTerminal(t));
|
||||
i++;
|
||||
} else {
|
||||
let j = i;
|
||||
while (j < terminals.length && Math.abs(terminals[j]!.weight - weight) < 0.001) {
|
||||
j++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// ── Numeric weight (not a power of ATTENTION_STEP) ──
|
||||
{
|
||||
const j = findRunEnd(terminals, i, (t) => Math.abs(t.weight - weight) < WEIGHT_TOLERANCE);
|
||||
|
||||
// Trim whitespace from the content run boundaries (same as +/- branch)
|
||||
let runStart = i;
|
||||
let runEnd = j;
|
||||
while (runStart < runEnd && terminals[runStart]!.type === 'whitespace') {
|
||||
runStart++;
|
||||
}
|
||||
while (runEnd > runStart && terminals[runEnd - 1]!.type === 'whitespace') {
|
||||
runEnd--;
|
||||
}
|
||||
|
||||
const groupTerminalsSlice = terminals.slice(i, j).map((t) => ({ ...t, weight: 1.0 }));
|
||||
const children = groupTerminals(groupTerminalsSlice);
|
||||
const isSelection = groupTerminalsSlice.every((t) => t.isSelected);
|
||||
|
||||
const weightStr = Number(weight.toFixed(4));
|
||||
|
||||
if (children.length === 1) {
|
||||
const child = children[0]!;
|
||||
if (child.type === 'word' || child.type === 'group') {
|
||||
nodes.push({ ...child, attention: weightStr });
|
||||
} else {
|
||||
nodes.push({ type: 'group', children, attention: weightStr, range: { start: 0, end: 0 }, isSelection });
|
||||
}
|
||||
} else {
|
||||
nodes.push({ type: 'group', children, attention: weightStr, range: { start: 0, end: 0 }, isSelection });
|
||||
// Emit leading whitespace as standalone nodes
|
||||
for (let k = i; k < runStart; k++) {
|
||||
nodes.push(createNodeFromTerminal(terminals[k]!));
|
||||
}
|
||||
|
||||
if (runStart < runEnd) {
|
||||
const groupSlice = terminals.slice(runStart, runEnd).map((t) => ({ ...t, weight: 1.0 }));
|
||||
const children = groupTerminals(groupSlice);
|
||||
const isSelection = groupSlice.every((t) => t.isSelected);
|
||||
const weightNum = Number(weight.toFixed(4));
|
||||
|
||||
nodes.push({ type: 'group', children, attention: weightNum, range: NO_RANGE, isSelection });
|
||||
}
|
||||
|
||||
// Emit trailing whitespace as standalone nodes
|
||||
for (let k = runEnd; k < j; k++) {
|
||||
nodes.push(createNodeFromTerminal(terminals[k]!));
|
||||
}
|
||||
|
||||
i = j;
|
||||
}
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
function createNodeFromTerminal(t: Terminal): ASTNode {
|
||||
if (t.type === 'word') {
|
||||
return { type: 'word', text: t.text, range: t.range, isSelection: t.isSelected };
|
||||
/**
|
||||
* Find the end of a "run" of terminals whose weights satisfy a predicate.
|
||||
* Whitespace terminals are included if the next non-whitespace terminal also satisfies the predicate.
|
||||
* Note: The returned index may point to a whitespace token that is NOT included in the run;
|
||||
* the caller is responsible for trimming trailing whitespace from the run boundaries.
|
||||
*/
|
||||
function findRunEnd(terminals: Terminal[], start: number, predicate: (t: Terminal) => boolean): number {
|
||||
let j = start;
|
||||
while (j < terminals.length) {
|
||||
const next = terminals[j]!;
|
||||
if (predicate(next)) {
|
||||
j++;
|
||||
} else if (next.type === 'whitespace') {
|
||||
// Look ahead past consecutive whitespace
|
||||
let k = j + 1;
|
||||
while (k < terminals.length && terminals[k]!.type === 'whitespace') {
|
||||
k++;
|
||||
}
|
||||
if (k < terminals.length && predicate(terminals[k]!)) {
|
||||
j = k;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (t.type === 'whitespace') {
|
||||
return { type: 'whitespace', value: t.text, range: t.range, isSelection: t.isSelected };
|
||||
}
|
||||
if (t.type === 'punct') {
|
||||
return { type: 'punct', value: t.text, range: t.range, isSelection: t.isSelected };
|
||||
}
|
||||
if (t.type === 'embedding') {
|
||||
return { type: 'embedding', value: t.text, range: t.range, isSelection: t.isSelected };
|
||||
}
|
||||
if (t.type === 'escaped_paren') {
|
||||
return { type: 'escaped_paren', value: t.text as '(' | ')', range: t.range, isSelection: t.isSelected };
|
||||
}
|
||||
return { type: 'word', text: t.text, range: t.range, isSelection: t.isSelected };
|
||||
return j;
|
||||
}
|
||||
|
||||
function addAttention(current: Attention | undefined, added: string): Attention | undefined {
|
||||
if (!current) {
|
||||
return added;
|
||||
/**
|
||||
* Convert a Terminal back into a leaf ASTNode.
|
||||
*/
|
||||
function createNodeFromTerminal(t: Terminal): ASTNode {
|
||||
switch (t.type) {
|
||||
case 'word':
|
||||
return { type: 'word', text: t.text, range: t.range, isSelection: t.isSelected || undefined };
|
||||
case 'whitespace':
|
||||
return { type: 'whitespace', value: t.text, range: t.range, isSelection: t.isSelected || undefined };
|
||||
case 'punct':
|
||||
return { type: 'punct', value: t.text, range: t.range, isSelection: t.isSelected || undefined };
|
||||
case 'embedding':
|
||||
return { type: 'embedding', value: t.text, range: t.range, isSelection: t.isSelected || undefined };
|
||||
case 'escaped_paren':
|
||||
return {
|
||||
type: 'escaped_paren',
|
||||
value: t.text as '(' | ')',
|
||||
range: t.range,
|
||||
isSelection: t.isSelected || undefined,
|
||||
};
|
||||
default:
|
||||
return { type: 'word', text: t.text, range: t.range, isSelection: t.isSelected || undefined };
|
||||
}
|
||||
if (typeof current === 'number') {
|
||||
if (added === '+') {
|
||||
return current * ATTENTION_STEP;
|
||||
}
|
||||
if (added === '-') {
|
||||
return current / ATTENTION_STEP;
|
||||
}
|
||||
return current;
|
||||
}
|
||||
if (added === '+') {
|
||||
if (current.startsWith('-')) {
|
||||
const res = current.substring(1);
|
||||
return res === '' ? undefined : res;
|
||||
}
|
||||
return `${current}+`;
|
||||
}
|
||||
if (added === '-') {
|
||||
if (current.startsWith('+')) {
|
||||
const res = current.substring(1);
|
||||
return res === '' ? undefined : res;
|
||||
}
|
||||
return `${current}-`;
|
||||
}
|
||||
return current;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Center,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
Grid,
|
||||
GridItem,
|
||||
Heading,
|
||||
Input,
|
||||
Spinner,
|
||||
Text,
|
||||
VStack,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import type { ChangeEvent, FormEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { useGetSetupStatusQuery, useSetupMutation } from 'services/api/endpoints/auth';
|
||||
|
||||
const validatePasswordStrength = (
|
||||
password: string,
|
||||
t: (key: string) => string
|
||||
): { isValid: boolean; message: string } => {
|
||||
if (password.length < 8) {
|
||||
return { isValid: false, message: t('auth.setup.passwordTooShort') };
|
||||
}
|
||||
|
||||
const hasUpper = /[A-Z]/.test(password);
|
||||
const hasLower = /[a-z]/.test(password);
|
||||
const hasDigit = /\d/.test(password);
|
||||
|
||||
if (!hasUpper || !hasLower || !hasDigit) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: t('auth.setup.passwordMissingRequirements'),
|
||||
};
|
||||
}
|
||||
|
||||
return { isValid: true, message: '' };
|
||||
};
|
||||
|
||||
export const AdministratorSetup = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const [email, setEmail] = useState('');
|
||||
const [displayName, setDisplayName] = useState('');
|
||||
const [password, setPassword] = useState('');
|
||||
const [confirmPassword, setConfirmPassword] = useState('');
|
||||
const [setup, { isLoading, error }] = useSetupMutation();
|
||||
const { data: setupStatus, isLoading: isLoadingSetup } = useGetSetupStatusQuery();
|
||||
|
||||
// Redirect to app if multiuser mode is disabled
|
||||
useEffect(() => {
|
||||
if (!isLoadingSetup && setupStatus && !setupStatus.multiuser_enabled) {
|
||||
navigate('/app', { replace: true });
|
||||
}
|
||||
}, [setupStatus, isLoadingSetup, navigate]);
|
||||
|
||||
const passwordValidation = validatePasswordStrength(password, t);
|
||||
const passwordsMatch = password === confirmPassword;
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (e: FormEvent) => {
|
||||
e.preventDefault();
|
||||
|
||||
if (!passwordValidation.isValid) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!passwordsMatch) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await setup({ email, display_name: displayName, password }).unwrap();
|
||||
if (result.success) {
|
||||
// Auto-login after setup - need to call login API
|
||||
// For now, just redirect to login page
|
||||
window.location.href = '/login';
|
||||
}
|
||||
} catch {
|
||||
// Error is handled by RTK Query and displayed via error state
|
||||
}
|
||||
},
|
||||
[email, displayName, password, passwordValidation.isValid, passwordsMatch, setup]
|
||||
);
|
||||
|
||||
const handleEmailChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setEmail(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleDisplayNameChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setDisplayName(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handlePasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setPassword(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleConfirmPasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setConfirmPassword(e.target.value);
|
||||
}, []);
|
||||
|
||||
const errorMessage = error
|
||||
? 'data' in error && typeof error.data === 'object' && error.data && 'detail' in error.data
|
||||
? String(error.data.detail)
|
||||
: t('auth.setup.setupFailed')
|
||||
: null;
|
||||
|
||||
// Show loading spinner while checking setup status or redirecting
|
||||
if (isLoadingSetup || (setupStatus && !setupStatus.multiuser_enabled)) {
|
||||
return (
|
||||
<Center w="100dvw" h="100dvh">
|
||||
<Spinner size="xl" />
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Center w="100dvw" h="100dvh" bg="base.900">
|
||||
<Box w="full" maxW="600px" p={8} borderRadius="lg" bg="base.800" boxShadow="dark-lg">
|
||||
<form onSubmit={handleSubmit}>
|
||||
<VStack spacing={6} align="stretch">
|
||||
<VStack spacing={2}>
|
||||
<Heading size="lg" textAlign="center">
|
||||
{t('auth.setup.title')}
|
||||
</Heading>
|
||||
<Text fontSize="sm" color="base.400" textAlign="center">
|
||||
{t('auth.setup.subtitle')}
|
||||
</Text>
|
||||
</VStack>
|
||||
|
||||
<FormControl isRequired>
|
||||
<Grid templateColumns="140px 1fr" gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" pt={2} mb={0}>
|
||||
{t('auth.setup.email')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input
|
||||
type="email"
|
||||
value={email}
|
||||
onChange={handleEmailChange}
|
||||
placeholder={t('auth.setup.emailPlaceholder')}
|
||||
autoComplete="email"
|
||||
autoFocus
|
||||
/>
|
||||
<FormHelperText mt={1}>{t('auth.setup.emailHelper')}</FormHelperText>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
|
||||
<FormControl isRequired>
|
||||
<Grid templateColumns="140px 1fr" gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" pt={2} mb={0}>
|
||||
{t('auth.setup.displayName')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input
|
||||
type="text"
|
||||
value={displayName}
|
||||
onChange={handleDisplayNameChange}
|
||||
placeholder={t('auth.setup.displayNamePlaceholder')}
|
||||
/>
|
||||
<FormHelperText mt={1}>{t('auth.setup.displayNameHelper')}</FormHelperText>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
|
||||
<FormControl isRequired isInvalid={password.length > 0 && !passwordValidation.isValid}>
|
||||
<Grid templateColumns="140px 1fr" gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" pt={2} mb={0}>
|
||||
{t('auth.setup.password')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input
|
||||
type="password"
|
||||
value={password}
|
||||
onChange={handlePasswordChange}
|
||||
placeholder={t('auth.setup.passwordPlaceholder')}
|
||||
autoComplete="new-password"
|
||||
/>
|
||||
{password.length > 0 && !passwordValidation.isValid && (
|
||||
<FormErrorMessage>{passwordValidation.message}</FormErrorMessage>
|
||||
)}
|
||||
{password.length === 0 && <FormHelperText mt={1}>{t('auth.setup.passwordHelper')}</FormHelperText>}
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
|
||||
<FormControl isRequired isInvalid={confirmPassword.length > 0 && !passwordsMatch}>
|
||||
<Grid templateColumns="140px 1fr" gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" pt={2} mb={0}>
|
||||
{t('auth.setup.confirmPassword')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input
|
||||
type="password"
|
||||
value={confirmPassword}
|
||||
onChange={handleConfirmPasswordChange}
|
||||
placeholder={t('auth.setup.confirmPasswordPlaceholder')}
|
||||
autoComplete="new-password"
|
||||
/>
|
||||
{confirmPassword.length > 0 && !passwordsMatch && (
|
||||
<FormErrorMessage>{t('auth.setup.passwordsDoNotMatch')}</FormErrorMessage>
|
||||
)}
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
isLoading={isLoading}
|
||||
loadingText={t('auth.setup.creatingAccount')}
|
||||
colorScheme="invokeBlue"
|
||||
size="lg"
|
||||
w="full"
|
||||
isDisabled={!passwordValidation.isValid || !passwordsMatch}
|
||||
>
|
||||
{t('auth.setup.createAccount')}
|
||||
</Button>
|
||||
|
||||
{errorMessage && (
|
||||
<Flex p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm" justifyContent="center">
|
||||
<Text fontWeight="semibold">{errorMessage}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</VStack>
|
||||
</form>
|
||||
</Box>
|
||||
</Center>
|
||||
);
|
||||
});
|
||||
|
||||
AdministratorSetup.displayName = 'AdministratorSetup';
|
||||
168
invokeai/frontend/web/src/features/auth/components/LoginPage.tsx
Normal file
168
invokeai/frontend/web/src/features/auth/components/LoginPage.tsx
Normal file
@@ -0,0 +1,168 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Center,
|
||||
Checkbox,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormLabel,
|
||||
Heading,
|
||||
Input,
|
||||
Spinner,
|
||||
Text,
|
||||
VStack,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setCredentials } from 'features/auth/store/authSlice';
|
||||
import type { ChangeEvent, FormEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { useGetSetupStatusQuery, useLoginMutation } from 'services/api/endpoints/auth';
|
||||
|
||||
export const LoginPage = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const [email, setEmail] = useState('');
|
||||
const [password, setPassword] = useState('');
|
||||
const [rememberMe, setRememberMe] = useState(true);
|
||||
const [login, { isLoading, error }] = useLoginMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: setupStatus, isLoading: isLoadingSetup } = useGetSetupStatusQuery();
|
||||
|
||||
// Redirect to app if multiuser mode is disabled
|
||||
useEffect(() => {
|
||||
if (!isLoadingSetup && setupStatus && !setupStatus.multiuser_enabled) {
|
||||
navigate('/app', { replace: true });
|
||||
}
|
||||
}, [setupStatus, isLoadingSetup, navigate]);
|
||||
|
||||
// Redirect to setup page if setup is required
|
||||
useEffect(() => {
|
||||
if (!isLoadingSetup && setupStatus?.setup_required) {
|
||||
navigate('/setup', { replace: true });
|
||||
}
|
||||
}, [setupStatus, isLoadingSetup, navigate]);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (e: FormEvent) => {
|
||||
e.preventDefault();
|
||||
try {
|
||||
const result = await login({ email, password, remember_me: rememberMe }).unwrap();
|
||||
// Map the UserDTO from API to our User type
|
||||
const user = {
|
||||
user_id: result.user.user_id,
|
||||
email: result.user.email,
|
||||
display_name: result.user.display_name || null,
|
||||
is_admin: result.user.is_admin || false,
|
||||
is_active: result.user.is_active || true,
|
||||
};
|
||||
dispatch(setCredentials({ token: result.token, user }));
|
||||
// Force a page reload to ensure all user-specific state is loaded from server
|
||||
// This is important for multiuser isolation to prevent state leakage
|
||||
window.location.href = '/app';
|
||||
} catch {
|
||||
// Error is handled by RTK Query and displayed via error state
|
||||
}
|
||||
},
|
||||
[email, password, rememberMe, login, dispatch]
|
||||
);
|
||||
|
||||
const handleEmailChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setEmail(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handlePasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setPassword(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleRememberMeChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setRememberMe(e.target.checked);
|
||||
}, []);
|
||||
|
||||
const errorMessage = error
|
||||
? 'data' in error && typeof error.data === 'object' && error.data && 'detail' in error.data
|
||||
? String(error.data.detail)
|
||||
: t('auth.login.loginFailed')
|
||||
: null;
|
||||
|
||||
// Show loading spinner while checking setup status or redirecting
|
||||
if (isLoadingSetup || (setupStatus && !setupStatus.multiuser_enabled)) {
|
||||
return (
|
||||
<Center w="100dvw" h="100dvh">
|
||||
<Spinner size="xl" />
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
// Show loading spinner if setup is required (redirecting to setup)
|
||||
if (setupStatus?.setup_required) {
|
||||
return (
|
||||
<Center w="100dvw" h="100dvh">
|
||||
<Spinner size="xl" />
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Center w="100dvw" h="100dvh" bg="base.900">
|
||||
<Box w="full" maxW="400px" p={8} borderRadius="lg" bg="base.800" boxShadow="dark-lg">
|
||||
<form onSubmit={handleSubmit}>
|
||||
<VStack spacing={6} align="stretch">
|
||||
<Heading size="lg" textAlign="center">
|
||||
{t('auth.login.title')}
|
||||
</Heading>
|
||||
|
||||
<FormControl isRequired isInvalid={!!errorMessage}>
|
||||
<FormLabel>{t('auth.login.email')}</FormLabel>
|
||||
<Input
|
||||
type="email"
|
||||
value={email}
|
||||
onChange={handleEmailChange}
|
||||
placeholder={t('auth.login.emailPlaceholder')}
|
||||
autoComplete="email"
|
||||
autoFocus
|
||||
/>
|
||||
</FormControl>
|
||||
|
||||
<FormControl isRequired isInvalid={!!errorMessage}>
|
||||
<FormLabel>{t('auth.login.password')}</FormLabel>
|
||||
<Input
|
||||
type="password"
|
||||
value={password}
|
||||
onChange={handlePasswordChange}
|
||||
placeholder={t('auth.login.passwordPlaceholder')}
|
||||
autoComplete="current-password"
|
||||
/>
|
||||
{errorMessage && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
|
||||
<Checkbox isChecked={rememberMe} onChange={handleRememberMeChange}>
|
||||
{t('auth.login.rememberMe')}
|
||||
</Checkbox>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
isLoading={isLoading}
|
||||
loadingText={t('auth.login.signingIn')}
|
||||
colorScheme="invokeBlue"
|
||||
size="lg"
|
||||
w="full"
|
||||
>
|
||||
{t('auth.login.signIn')}
|
||||
</Button>
|
||||
|
||||
{errorMessage && (
|
||||
<Flex p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm" justifyContent="center">
|
||||
<Text fontWeight="semibold">{errorMessage}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</VStack>
|
||||
</form>
|
||||
</Box>
|
||||
</Center>
|
||||
);
|
||||
});
|
||||
|
||||
LoginPage.displayName = 'LoginPage';
|
||||
@@ -0,0 +1,100 @@
|
||||
import { Center, Spinner } from '@invoke-ai/ui-library';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { logout, setCredentials } from 'features/auth/store/authSlice';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useEffect } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { useGetCurrentUserQuery, useGetSetupStatusQuery } from 'services/api/endpoints/auth';
|
||||
|
||||
interface ProtectedRouteProps {
|
||||
requireAdmin?: boolean;
|
||||
}
|
||||
|
||||
export const ProtectedRoute = memo(({ children, requireAdmin = false }: PropsWithChildren<ProtectedRouteProps>) => {
|
||||
const isAuthenticated = useAppSelector((state: RootState) => state.auth?.isAuthenticated || false);
|
||||
const token = useAppSelector((state: RootState) => state.auth?.token);
|
||||
const user = useAppSelector((state: RootState) => state.auth?.user);
|
||||
const navigate = useNavigate();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
// Check if multiuser mode is enabled
|
||||
const { data: setupStatus } = useGetSetupStatusQuery();
|
||||
const multiuserEnabled = setupStatus?.multiuser_enabled ?? true; // Default to true for safety
|
||||
|
||||
// Only fetch user if we have a token but no user data, and multiuser mode is enabled
|
||||
const shouldFetchUser = multiuserEnabled && isAuthenticated && token && !user;
|
||||
const {
|
||||
data: currentUser,
|
||||
isLoading: isLoadingUser,
|
||||
error: userError,
|
||||
} = useGetCurrentUserQuery(undefined, {
|
||||
skip: !shouldFetchUser,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
// If we have a token but fetching user failed, token is invalid - logout
|
||||
if (userError && isAuthenticated) {
|
||||
dispatch(logout());
|
||||
navigate('/login', { replace: true });
|
||||
}
|
||||
}, [userError, isAuthenticated, dispatch, navigate]);
|
||||
|
||||
useEffect(() => {
|
||||
// If we successfully fetched user data, update auth state
|
||||
if (currentUser && token && !user) {
|
||||
const userObj = {
|
||||
user_id: currentUser.user_id,
|
||||
email: currentUser.email,
|
||||
display_name: currentUser.display_name || null,
|
||||
is_admin: currentUser.is_admin || false,
|
||||
is_active: currentUser.is_active || true,
|
||||
};
|
||||
dispatch(setCredentials({ token, user: userObj }));
|
||||
}
|
||||
}, [currentUser, token, user, dispatch]);
|
||||
|
||||
useEffect(() => {
|
||||
// If multiuser is disabled, allow access without authentication
|
||||
if (!multiuserEnabled) {
|
||||
// Clear any persisted auth state when switching to single-user mode
|
||||
if (isAuthenticated) {
|
||||
dispatch(logout());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// In multiuser mode, check authentication
|
||||
if (!isLoadingUser && !isAuthenticated) {
|
||||
navigate('/login', { replace: true });
|
||||
} else if (!isLoadingUser && isAuthenticated && user && requireAdmin && !user.is_admin) {
|
||||
navigate('/', { replace: true });
|
||||
}
|
||||
}, [isAuthenticated, isLoadingUser, requireAdmin, user, navigate, multiuserEnabled, dispatch]);
|
||||
|
||||
// In single-user mode, always allow access
|
||||
if (!multiuserEnabled) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
// Show loading while fetching user data
|
||||
if (isLoadingUser || (isAuthenticated && !user)) {
|
||||
return (
|
||||
<Center w="100dvw" h="100dvh">
|
||||
<Spinner size="xl" />
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isAuthenticated) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (requireAdmin && !user?.is_admin) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <>{children}</>;
|
||||
});
|
||||
|
||||
ProtectedRoute.displayName = 'ProtectedRoute';
|
||||
@@ -0,0 +1,640 @@
|
||||
import {
|
||||
Badge,
|
||||
Box,
|
||||
Button,
|
||||
Center,
|
||||
Checkbox,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormLabel,
|
||||
Grid,
|
||||
GridItem,
|
||||
Heading,
|
||||
IconButton,
|
||||
Input,
|
||||
InputGroup,
|
||||
InputRightElement,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Spinner,
|
||||
Switch,
|
||||
Table,
|
||||
Tbody,
|
||||
Td,
|
||||
Text,
|
||||
Th,
|
||||
Thead,
|
||||
Tooltip,
|
||||
Tr,
|
||||
useDisclosure,
|
||||
VStack,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCurrentUser } from 'features/auth/store/authSlice';
|
||||
import type { ChangeEvent, FormEvent } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
PiArrowLeftBold,
|
||||
PiEyeBold,
|
||||
PiEyeSlashBold,
|
||||
PiLightningFill,
|
||||
PiPencilBold,
|
||||
PiPlusBold,
|
||||
PiTrashBold,
|
||||
} from 'react-icons/pi';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import type { UserDTO } from 'services/api/endpoints/auth';
|
||||
import {
|
||||
useCreateUserMutation,
|
||||
useDeleteUserMutation,
|
||||
useLazyGeneratePasswordQuery,
|
||||
useListUsersQuery,
|
||||
useUpdateUserMutation,
|
||||
} from 'services/api/endpoints/auth';
|
||||
|
||||
const validatePasswordStrength = (
|
||||
password: string,
|
||||
t: (key: string) => string
|
||||
): { isValid: boolean; message: string } => {
|
||||
if (password.length === 0) {
|
||||
return { isValid: true, message: '' };
|
||||
}
|
||||
if (password.length < 8) {
|
||||
return { isValid: false, message: t('auth.setup.passwordTooShort') };
|
||||
}
|
||||
const hasUpper = /[A-Z]/.test(password);
|
||||
const hasLower = /[a-z]/.test(password);
|
||||
const hasDigit = /\d/.test(password);
|
||||
if (!hasUpper || !hasLower || !hasDigit) {
|
||||
return { isValid: false, message: t('auth.setup.passwordMissingRequirements') };
|
||||
}
|
||||
return { isValid: true, message: '' };
|
||||
};
|
||||
|
||||
const FORM_GRID_COLUMNS = '120px 1fr';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Create / Edit user modal
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type UserFormModalProps = {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
/** When provided, the modal operates in "edit" mode for the given user */
|
||||
editUser?: UserDTO | null;
|
||||
};
|
||||
|
||||
const UserFormModal = memo(({ isOpen, onClose, editUser }: UserFormModalProps) => {
|
||||
const { t } = useTranslation();
|
||||
const isEdit = !!editUser;
|
||||
|
||||
const [email, setEmail] = useState(editUser?.email ?? '');
|
||||
const [displayName, setDisplayName] = useState(editUser?.display_name ?? '');
|
||||
const [password, setPassword] = useState('');
|
||||
const [isAdmin, setIsAdmin] = useState(editUser?.is_admin ?? false);
|
||||
const [showPassword, setShowPassword] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const [createUser, { isLoading: isCreating }] = useCreateUserMutation();
|
||||
const [updateUser, { isLoading: isUpdating }] = useUpdateUserMutation();
|
||||
const [triggerGeneratePassword] = useLazyGeneratePasswordQuery();
|
||||
|
||||
const isLoading = isCreating || isUpdating;
|
||||
const passwordValidation = validatePasswordStrength(password, t);
|
||||
|
||||
const handleGeneratePassword = useCallback(async () => {
|
||||
try {
|
||||
const result = await triggerGeneratePassword().unwrap();
|
||||
setPassword(result.password);
|
||||
setShowPassword(true);
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}, [triggerGeneratePassword]);
|
||||
|
||||
const toggleShowPassword = useCallback(() => {
|
||||
setShowPassword((v) => !v);
|
||||
}, []);
|
||||
|
||||
const handleEmailChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setEmail(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleDisplayNameChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setDisplayName(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handlePasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setPassword(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleIsAdminChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setIsAdmin(e.target.checked);
|
||||
}, []);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (e: FormEvent) => {
|
||||
e.preventDefault();
|
||||
setError(null);
|
||||
|
||||
if (!isEdit && (!password || !passwordValidation.isValid)) {
|
||||
return;
|
||||
}
|
||||
if (isEdit && password && !passwordValidation.isValid) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
if (isEdit && editUser) {
|
||||
const updateData: Parameters<typeof updateUser>[0]['data'] = {
|
||||
display_name: displayName || null,
|
||||
is_admin: isAdmin,
|
||||
};
|
||||
if (password) {
|
||||
updateData.password = password;
|
||||
}
|
||||
await updateUser({
|
||||
userId: editUser.user_id,
|
||||
data: updateData,
|
||||
}).unwrap();
|
||||
} else {
|
||||
await createUser({
|
||||
email,
|
||||
display_name: displayName || null,
|
||||
password,
|
||||
is_admin: isAdmin,
|
||||
}).unwrap();
|
||||
}
|
||||
onClose();
|
||||
} catch (err) {
|
||||
const detail =
|
||||
err && typeof err === 'object' && 'data' in err && typeof (err as { data: unknown }).data === 'object'
|
||||
? ((err as { data: { detail?: string } }).data?.detail ?? t('auth.userManagement.saveFailed'))
|
||||
: t('auth.userManagement.saveFailed');
|
||||
setError(detail);
|
||||
}
|
||||
},
|
||||
[
|
||||
isEdit,
|
||||
editUser,
|
||||
email,
|
||||
displayName,
|
||||
password,
|
||||
isAdmin,
|
||||
passwordValidation.isValid,
|
||||
createUser,
|
||||
updateUser,
|
||||
onClose,
|
||||
t,
|
||||
]
|
||||
);
|
||||
|
||||
// Reset local state when modal closes
|
||||
const handleClose = useCallback(() => {
|
||||
setEmail(editUser?.email ?? '');
|
||||
setDisplayName(editUser?.display_name ?? '');
|
||||
setPassword('');
|
||||
setIsAdmin(editUser?.is_admin ?? false);
|
||||
setShowPassword(false);
|
||||
setError(null);
|
||||
onClose();
|
||||
}, [editUser, onClose]);
|
||||
|
||||
return (
|
||||
<Modal isOpen={isOpen} onClose={handleClose} isCentered size="md">
|
||||
<ModalOverlay />
|
||||
<ModalContent bg="base.800">
|
||||
<form onSubmit={handleSubmit}>
|
||||
<ModalHeader>{isEdit ? t('auth.userManagement.editUser') : t('auth.userManagement.createUser')}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<VStack spacing={4}>
|
||||
{!isEdit && (
|
||||
<FormControl isRequired>
|
||||
<Grid templateColumns={FORM_GRID_COLUMNS} gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" mb={0} pt={2}>
|
||||
{t('auth.userManagement.email')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input
|
||||
type="email"
|
||||
value={email}
|
||||
onChange={handleEmailChange}
|
||||
placeholder={t('auth.userManagement.emailPlaceholder')}
|
||||
autoComplete="off"
|
||||
/>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
)}
|
||||
|
||||
<FormControl>
|
||||
<Grid templateColumns={FORM_GRID_COLUMNS} gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" mb={0} pt={2}>
|
||||
{t('auth.userManagement.displayName')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input
|
||||
type="text"
|
||||
value={displayName}
|
||||
onChange={handleDisplayNameChange}
|
||||
placeholder={t('auth.userManagement.displayNamePlaceholder')}
|
||||
/>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
|
||||
<FormControl isInvalid={password.length > 0 && !passwordValidation.isValid} isRequired={!isEdit}>
|
||||
<Grid templateColumns={FORM_GRID_COLUMNS} gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" mb={0} pt={2}>
|
||||
{isEdit ? t('auth.userManagement.newPassword') : t('auth.userManagement.password')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<InputGroup>
|
||||
<Input
|
||||
type={showPassword ? 'text' : 'password'}
|
||||
value={password}
|
||||
onChange={handlePasswordChange}
|
||||
placeholder={
|
||||
isEdit
|
||||
? t('auth.userManagement.newPasswordPlaceholder')
|
||||
: t('auth.userManagement.passwordPlaceholder')
|
||||
}
|
||||
autoComplete="new-password"
|
||||
pr="4.5rem"
|
||||
/>
|
||||
<InputRightElement w="4.5rem">
|
||||
<Tooltip
|
||||
label={
|
||||
showPassword ? t('auth.userManagement.hidePassword') : t('auth.userManagement.showPassword')
|
||||
}
|
||||
>
|
||||
<IconButton
|
||||
aria-label={
|
||||
showPassword
|
||||
? t('auth.userManagement.hidePassword')
|
||||
: t('auth.userManagement.showPassword')
|
||||
}
|
||||
icon={showPassword ? <PiEyeSlashBold /> : <PiEyeBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={toggleShowPassword}
|
||||
tabIndex={-1}
|
||||
/>
|
||||
</Tooltip>
|
||||
</InputRightElement>
|
||||
</InputGroup>
|
||||
{password.length > 0 && !passwordValidation.isValid && (
|
||||
<FormErrorMessage>{passwordValidation.message}</FormErrorMessage>
|
||||
)}
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
|
||||
<Grid templateColumns={FORM_GRID_COLUMNS} gap={4} w="full">
|
||||
<GridItem />
|
||||
<GridItem>
|
||||
<Button size="sm" variant="ghost" onClick={handleGeneratePassword} leftIcon={<PiLightningFill />}>
|
||||
{t('auth.userManagement.generatePassword')}
|
||||
</Button>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
|
||||
<FormControl display="flex" alignItems="center">
|
||||
<FormLabel mb={0}>{t('auth.userManagement.isAdmin')}</FormLabel>
|
||||
<Checkbox isChecked={isAdmin} onChange={handleIsAdminChange} />
|
||||
</FormControl>
|
||||
|
||||
{error && (
|
||||
<Flex p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm" w="full">
|
||||
<Text fontWeight="semibold">{error}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</VStack>
|
||||
</ModalBody>
|
||||
<ModalFooter gap={2}>
|
||||
<Button variant="ghost" onClick={handleClose}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
type="submit"
|
||||
colorScheme="invokeBlue"
|
||||
isLoading={isLoading}
|
||||
isDisabled={!isEdit && (!password || !passwordValidation.isValid)}
|
||||
>
|
||||
{isEdit ? t('common.save') : t('auth.userManagement.createUser')}
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</form>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
UserFormModal.displayName = 'UserFormModal';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Delete confirmation modal
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type DeleteUserModalProps = {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
user: UserDTO | null;
|
||||
};
|
||||
|
||||
const DeleteUserModal = memo(({ isOpen, onClose, user }: DeleteUserModalProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [deleteUser, { isLoading }] = useDeleteUserMutation();
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const handleDelete = useCallback(async () => {
|
||||
if (!user) {
|
||||
return;
|
||||
}
|
||||
setError(null);
|
||||
try {
|
||||
await deleteUser(user.user_id).unwrap();
|
||||
onClose();
|
||||
} catch (err) {
|
||||
const detail =
|
||||
err && typeof err === 'object' && 'data' in err && typeof (err as { data: unknown }).data === 'object'
|
||||
? ((err as { data: { detail?: string } }).data?.detail ?? t('auth.userManagement.deleteFailed'))
|
||||
: t('auth.userManagement.deleteFailed');
|
||||
setError(detail);
|
||||
}
|
||||
}, [user, deleteUser, onClose, t]);
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
setError(null);
|
||||
onClose();
|
||||
}, [onClose]);
|
||||
|
||||
return (
|
||||
<Modal isOpen={isOpen} onClose={handleClose} isCentered size="sm">
|
||||
<ModalOverlay />
|
||||
<ModalContent bg="base.800">
|
||||
<ModalHeader>{t('auth.userManagement.deleteUser')}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Text>
|
||||
{t('auth.userManagement.deleteConfirm', {
|
||||
name: user?.display_name ?? user?.email ?? '',
|
||||
})}
|
||||
</Text>
|
||||
{error && (
|
||||
<Flex mt={3} p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm">
|
||||
<Text fontWeight="semibold">{error}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</ModalBody>
|
||||
<ModalFooter gap={2}>
|
||||
<Button variant="ghost" onClick={handleClose}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button colorScheme="error" isLoading={isLoading} onClick={handleDelete}>
|
||||
{t('common.delete')}
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
DeleteUserModal.displayName = 'DeleteUserModal';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Inline active/inactive toggle
|
||||
// Wrapping the Switch in a Box lets the Tooltip track mouse-enter/leave
|
||||
// correctly; without it the tooltip may not dismiss on mouse-out.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const UserStatusToggle = memo(({ user, isCurrentUser }: { user: UserDTO; isCurrentUser: boolean }) => {
|
||||
const { t } = useTranslation();
|
||||
const [updateUser, { isLoading }] = useUpdateUserMutation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
async (e: ChangeEvent<HTMLInputElement>) => {
|
||||
await updateUser({ userId: user.user_id, data: { is_active: e.target.checked } })
|
||||
.unwrap()
|
||||
.catch(() => null);
|
||||
},
|
||||
[user.user_id, updateUser]
|
||||
);
|
||||
|
||||
const tooltipLabel = isCurrentUser
|
||||
? t('auth.userManagement.cannotDeactivateSelf')
|
||||
: user.is_active
|
||||
? t('auth.userManagement.deactivate')
|
||||
: t('auth.userManagement.activate');
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltipLabel}>
|
||||
<Box as="span" display="inline-flex">
|
||||
<Switch isChecked={user.is_active} onChange={handleChange} isDisabled={isLoading || isCurrentUser} size="sm" />
|
||||
</Box>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
UserStatusToggle.displayName = 'UserStatusToggle';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export const UserManagement = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const currentUser = useAppSelector(selectCurrentUser);
|
||||
const navigate = useNavigate();
|
||||
const { data: users, isLoading, error } = useListUsersQuery();
|
||||
|
||||
const createModal = useDisclosure();
|
||||
const editModal = useDisclosure();
|
||||
const deleteModal = useDisclosure();
|
||||
|
||||
const [selectedUser, setSelectedUser] = useState<UserDTO | null>(null);
|
||||
|
||||
const handleBack = useCallback(() => {
|
||||
navigate(-1);
|
||||
}, [navigate]);
|
||||
|
||||
const handleEdit = useCallback(
|
||||
(user: UserDTO) => {
|
||||
setSelectedUser(user);
|
||||
editModal.onOpen();
|
||||
},
|
||||
[editModal]
|
||||
);
|
||||
|
||||
const handleDelete = useCallback(
|
||||
(user: UserDTO) => {
|
||||
setSelectedUser(user);
|
||||
deleteModal.onOpen();
|
||||
},
|
||||
[deleteModal]
|
||||
);
|
||||
|
||||
const handleEditClose = useCallback(() => {
|
||||
editModal.onClose();
|
||||
setSelectedUser(null);
|
||||
}, [editModal]);
|
||||
|
||||
const handleDeleteClose = useCallback(() => {
|
||||
deleteModal.onClose();
|
||||
setSelectedUser(null);
|
||||
}, [deleteModal]);
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Center py={12}>
|
||||
<Spinner size="xl" />
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<Center py={12}>
|
||||
<Text color="error.400">{t('auth.userManagement.loadFailed')}</Text>
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Box p={6}>
|
||||
<Flex justify="space-between" align="center" mb={6}>
|
||||
<Flex align="center" gap={4}>
|
||||
<Button leftIcon={<PiArrowLeftBold />} variant="outline" size="sm" onClick={handleBack}>
|
||||
{t('auth.userManagement.back')}
|
||||
</Button>
|
||||
<Heading size="md">{t('auth.userManagement.title')}</Heading>
|
||||
</Flex>
|
||||
<Button leftIcon={<PiPlusBold />} colorScheme="invokeBlue" size="sm" onClick={createModal.onOpen}>
|
||||
{t('auth.userManagement.createUser')}
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Box overflowX="auto">
|
||||
<Table variant="simple" size="sm">
|
||||
<Thead>
|
||||
<Tr>
|
||||
<Th>{t('auth.userManagement.email')}</Th>
|
||||
<Th>{t('auth.userManagement.displayName')}</Th>
|
||||
<Th>{t('auth.userManagement.role')}</Th>
|
||||
<Th>{t('auth.userManagement.status')}</Th>
|
||||
<Th>{t('auth.userManagement.actions')}</Th>
|
||||
</Tr>
|
||||
</Thead>
|
||||
<Tbody>
|
||||
{(users ?? []).map((user) => (
|
||||
<UserRow
|
||||
key={user.user_id}
|
||||
user={user}
|
||||
isCurrentUser={user.user_id === currentUser?.user_id}
|
||||
onEdit={handleEdit}
|
||||
onDelete={handleDelete}
|
||||
/>
|
||||
))}
|
||||
</Tbody>
|
||||
</Table>
|
||||
</Box>
|
||||
|
||||
{/* Create user modal */}
|
||||
<UserFormModal isOpen={createModal.isOpen} onClose={createModal.onClose} />
|
||||
|
||||
{/* Edit user modal */}
|
||||
<UserFormModal isOpen={editModal.isOpen} onClose={handleEditClose} editUser={selectedUser} />
|
||||
|
||||
{/* Delete confirmation modal */}
|
||||
<DeleteUserModal isOpen={deleteModal.isOpen} onClose={handleDeleteClose} user={selectedUser} />
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
UserManagement.displayName = 'UserManagement';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// User table row
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type UserRowProps = {
|
||||
user: UserDTO;
|
||||
isCurrentUser: boolean;
|
||||
onEdit: (user: UserDTO) => void;
|
||||
onDelete: (user: UserDTO) => void;
|
||||
};
|
||||
|
||||
const UserRow = memo(({ user, isCurrentUser, onEdit, onDelete }: UserRowProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleEdit = useCallback(() => {
|
||||
onEdit(user);
|
||||
}, [user, onEdit]);
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
onDelete(user);
|
||||
}, [user, onDelete]);
|
||||
|
||||
return (
|
||||
<Tr>
|
||||
<Td>
|
||||
<Text fontSize="sm">{user.email}</Text>
|
||||
{isCurrentUser && (
|
||||
<Badge colorScheme="invokeBlue" size="xs" ml={1}>
|
||||
{t('auth.userManagement.you')}
|
||||
</Badge>
|
||||
)}
|
||||
</Td>
|
||||
<Td>
|
||||
<Text fontSize="sm">{user.display_name ?? '—'}</Text>
|
||||
</Td>
|
||||
<Td>
|
||||
{user.is_admin ? (
|
||||
<Badge colorScheme="invokeYellow">{t('auth.admin')}</Badge>
|
||||
) : (
|
||||
<Badge colorScheme="base">{t('auth.userManagement.user')}</Badge>
|
||||
)}
|
||||
</Td>
|
||||
<Td>
|
||||
<UserStatusToggle user={user} isCurrentUser={isCurrentUser} />
|
||||
</Td>
|
||||
<Td>
|
||||
<Flex gap={1}>
|
||||
<Tooltip label={t('auth.userManagement.editUser')}>
|
||||
<IconButton
|
||||
aria-label={t('auth.userManagement.editUser')}
|
||||
icon={<PiPencilBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={handleEdit}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip
|
||||
label={isCurrentUser ? t('auth.userManagement.cannotDeleteSelf') : t('auth.userManagement.deleteUser')}
|
||||
>
|
||||
<IconButton
|
||||
aria-label={t('auth.userManagement.deleteUser')}
|
||||
icon={<PiTrashBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
colorScheme="error"
|
||||
isDisabled={isCurrentUser}
|
||||
onClick={handleDelete}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
</Td>
|
||||
</Tr>
|
||||
);
|
||||
});
|
||||
UserRow.displayName = 'UserRow';
|
||||
@@ -0,0 +1,87 @@
|
||||
import { Badge, Flex, IconButton, Menu, MenuButton, MenuItem, MenuList, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { logout, selectCurrentUser } from 'features/auth/store/authSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearBold, PiSignOutBold, PiUserBold, PiUsersBold } from 'react-icons/pi';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { useLogoutMutation } from 'services/api/endpoints/auth';
|
||||
|
||||
export const UserMenu = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const user = useAppSelector(selectCurrentUser);
|
||||
const dispatch = useAppDispatch();
|
||||
const navigate = useNavigate();
|
||||
const [logoutMutation] = useLogoutMutation();
|
||||
|
||||
const handleLogout = useCallback(() => {
|
||||
// Call backend logout endpoint
|
||||
logoutMutation()
|
||||
.unwrap()
|
||||
.catch(() => {
|
||||
// Ignore errors - we'll log out locally anyway
|
||||
})
|
||||
.finally(() => {
|
||||
// Clear local state regardless of backend response
|
||||
dispatch(logout());
|
||||
navigate('/login');
|
||||
});
|
||||
}, [dispatch, navigate, logoutMutation]);
|
||||
|
||||
const handleProfile = useCallback(() => {
|
||||
navigate('/profile');
|
||||
}, [navigate]);
|
||||
|
||||
const handleUserManagement = useCallback(() => {
|
||||
navigate('/admin/users');
|
||||
}, [navigate]);
|
||||
|
||||
if (!user) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<Tooltip label={t('auth.userMenu')}>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
aria-label={t('auth.userMenu')}
|
||||
icon={<PiUserBold />}
|
||||
variant="link"
|
||||
minW={8}
|
||||
w={8}
|
||||
h={8}
|
||||
borderRadius="base"
|
||||
/>
|
||||
</Tooltip>
|
||||
<MenuList>
|
||||
<Flex px={3} py={2} flexDir="column" gap={1}>
|
||||
<Text fontSize="sm" fontWeight="semibold" noOfLines={1}>
|
||||
{user.display_name || user.email}
|
||||
</Text>
|
||||
<Text fontSize="xs" color="base.500" noOfLines={1}>
|
||||
{user.email}
|
||||
</Text>
|
||||
{user.is_admin && (
|
||||
<Badge colorScheme="invokeYellow" size="sm" alignSelf="flex-start" mt={1}>
|
||||
{t('auth.admin')}
|
||||
</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
<MenuItem icon={<PiGearBold />} onClick={handleProfile}>
|
||||
{t('auth.profile.menuItem')}
|
||||
</MenuItem>
|
||||
{user.is_admin && (
|
||||
<MenuItem icon={<PiUsersBold />} onClick={handleUserManagement}>
|
||||
{t('auth.userManagement.menuItem')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem icon={<PiSignOutBold />} onClick={handleLogout}>
|
||||
{t('auth.logout')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
});
|
||||
|
||||
UserMenu.displayName = 'UserMenu';
|
||||
@@ -0,0 +1,390 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Center,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
Grid,
|
||||
GridItem,
|
||||
Heading,
|
||||
IconButton,
|
||||
Input,
|
||||
InputGroup,
|
||||
InputRightElement,
|
||||
Spinner,
|
||||
Text,
|
||||
Tooltip,
|
||||
VStack,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectAuthToken, selectCurrentUser, setCredentials } from 'features/auth/store/authSlice';
|
||||
import type { ChangeEvent, FormEvent } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiEyeBold, PiEyeSlashBold, PiLightningFill } from 'react-icons/pi';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { useLazyGeneratePasswordQuery, useUpdateCurrentUserMutation } from 'services/api/endpoints/auth';
|
||||
|
||||
const validatePasswordStrength = (
|
||||
password: string,
|
||||
t: (key: string) => string
|
||||
): { isValid: boolean; message: string } => {
|
||||
if (password.length === 0) {
|
||||
return { isValid: true, message: '' };
|
||||
}
|
||||
if (password.length < 8) {
|
||||
return { isValid: false, message: t('auth.setup.passwordTooShort') };
|
||||
}
|
||||
const hasUpper = /[A-Z]/.test(password);
|
||||
const hasLower = /[a-z]/.test(password);
|
||||
const hasDigit = /\d/.test(password);
|
||||
if (!hasUpper || !hasLower || !hasDigit) {
|
||||
return { isValid: false, message: t('auth.setup.passwordMissingRequirements') };
|
||||
}
|
||||
return { isValid: true, message: '' };
|
||||
};
|
||||
|
||||
const PASSWORD_GRID_COLUMNS = '180px 1fr';
|
||||
|
||||
export const UserProfile = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const currentUser = useAppSelector(selectCurrentUser);
|
||||
const currentToken = useAppSelector(selectAuthToken);
|
||||
const dispatch = useAppDispatch();
|
||||
const navigate = useNavigate();
|
||||
|
||||
const [displayName, setDisplayName] = useState(currentUser?.display_name ?? '');
|
||||
const [currentPassword, setCurrentPassword] = useState('');
|
||||
const [newPassword, setNewPassword] = useState('');
|
||||
const [confirmPassword, setConfirmPassword] = useState('');
|
||||
const [showCurrentPassword, setShowCurrentPassword] = useState(false);
|
||||
const [showNewPassword, setShowNewPassword] = useState(false);
|
||||
const [showConfirmPassword, setShowConfirmPassword] = useState(false);
|
||||
const [errorMessage, setErrorMessage] = useState<string | null>(null);
|
||||
|
||||
const [updateCurrentUser, { isLoading }] = useUpdateCurrentUserMutation();
|
||||
const [triggerGeneratePassword] = useLazyGeneratePasswordQuery();
|
||||
|
||||
const newPasswordValidation = validatePasswordStrength(newPassword, t);
|
||||
|
||||
const isPasswordChangeAttempted = newPassword.length > 0 || currentPassword.length > 0;
|
||||
const passwordsMatch = newPassword.length > 0 && newPassword === confirmPassword;
|
||||
const isPasswordChangeValid =
|
||||
!isPasswordChangeAttempted || (currentPassword.length > 0 && newPasswordValidation.isValid && passwordsMatch);
|
||||
|
||||
const handleCancel = useCallback(() => {
|
||||
navigate(-1);
|
||||
}, [navigate]);
|
||||
|
||||
const handleGeneratePassword = useCallback(async () => {
|
||||
try {
|
||||
const result = await triggerGeneratePassword().unwrap();
|
||||
setNewPassword(result.password);
|
||||
setConfirmPassword(result.password);
|
||||
setShowNewPassword(true);
|
||||
setShowConfirmPassword(true);
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}, [triggerGeneratePassword]);
|
||||
|
||||
const toggleShowCurrentPassword = useCallback(() => {
|
||||
setShowCurrentPassword((v) => !v);
|
||||
}, []);
|
||||
|
||||
const toggleShowNewPassword = useCallback(() => {
|
||||
setShowNewPassword((v) => !v);
|
||||
}, []);
|
||||
|
||||
const toggleShowConfirmPassword = useCallback(() => {
|
||||
setShowConfirmPassword((v) => !v);
|
||||
}, []);
|
||||
|
||||
const handleDisplayNameChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setDisplayName(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleCurrentPasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setCurrentPassword(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleNewPasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setNewPassword(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleConfirmPasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setConfirmPassword(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (e: FormEvent) => {
|
||||
e.preventDefault();
|
||||
setErrorMessage(null);
|
||||
|
||||
if (!isPasswordChangeValid) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const updatePayload: Parameters<typeof updateCurrentUser>[0] = {
|
||||
display_name: displayName || null,
|
||||
};
|
||||
if (newPassword) {
|
||||
updatePayload.current_password = currentPassword;
|
||||
updatePayload.new_password = newPassword;
|
||||
}
|
||||
const updatedUser = await updateCurrentUser(updatePayload).unwrap();
|
||||
|
||||
// Refresh the stored user info so the header reflects the new display name
|
||||
if (currentToken) {
|
||||
dispatch(
|
||||
setCredentials({
|
||||
token: currentToken,
|
||||
user: {
|
||||
user_id: updatedUser.user_id,
|
||||
email: updatedUser.email,
|
||||
display_name: updatedUser.display_name ?? null,
|
||||
is_admin: updatedUser.is_admin ?? false,
|
||||
is_active: updatedUser.is_active ?? true,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Navigate back after successful save
|
||||
navigate(-1);
|
||||
} catch (err) {
|
||||
const detail =
|
||||
err && typeof err === 'object' && 'data' in err && typeof (err as { data: unknown }).data === 'object'
|
||||
? ((err as { data: { detail?: string } }).data?.detail ?? t('auth.profile.saveFailed'))
|
||||
: t('auth.profile.saveFailed');
|
||||
setErrorMessage(detail);
|
||||
}
|
||||
},
|
||||
[
|
||||
displayName,
|
||||
currentPassword,
|
||||
newPassword,
|
||||
isPasswordChangeValid,
|
||||
updateCurrentUser,
|
||||
currentToken,
|
||||
dispatch,
|
||||
navigate,
|
||||
t,
|
||||
]
|
||||
);
|
||||
|
||||
if (!currentUser) {
|
||||
return (
|
||||
<Center py={12}>
|
||||
<Spinner size="xl" />
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Box p={6} maxW="480px">
|
||||
<Heading size="md" mb={6}>
|
||||
{t('auth.profile.title')}
|
||||
</Heading>
|
||||
|
||||
<form onSubmit={handleSubmit}>
|
||||
<VStack spacing={5} align="stretch">
|
||||
{/* Email (read-only) */}
|
||||
<FormControl>
|
||||
<FormLabel>{t('auth.profile.email')}</FormLabel>
|
||||
<Input type="email" value={currentUser.email} isReadOnly opacity={0.6} />
|
||||
<FormHelperText>{t('auth.profile.emailReadOnly')}</FormHelperText>
|
||||
</FormControl>
|
||||
|
||||
{/* Display name */}
|
||||
<FormControl>
|
||||
<FormLabel>{t('auth.profile.displayName')}</FormLabel>
|
||||
<Input
|
||||
type="text"
|
||||
value={displayName}
|
||||
onChange={handleDisplayNameChange}
|
||||
placeholder={t('auth.profile.displayNamePlaceholder')}
|
||||
/>
|
||||
</FormControl>
|
||||
|
||||
<Box borderTop="1px solid" borderColor="base.600" pt={4}>
|
||||
<Text fontSize="sm" fontWeight="semibold" mb={4} color="base.300">
|
||||
{t('auth.profile.changePassword')}
|
||||
</Text>
|
||||
|
||||
{/* Current password */}
|
||||
<FormControl mb={4} isRequired={newPassword.length > 0}>
|
||||
<Grid templateColumns={PASSWORD_GRID_COLUMNS} gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" mb={0} pt={2}>
|
||||
{t('auth.profile.currentPassword')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<InputGroup>
|
||||
<Input
|
||||
type={showCurrentPassword ? 'text' : 'password'}
|
||||
value={currentPassword}
|
||||
onChange={handleCurrentPasswordChange}
|
||||
placeholder={t('auth.profile.currentPasswordPlaceholder')}
|
||||
autoComplete="current-password"
|
||||
pr="3rem"
|
||||
/>
|
||||
<InputRightElement>
|
||||
<Tooltip
|
||||
label={
|
||||
showCurrentPassword
|
||||
? t('auth.userManagement.hidePassword')
|
||||
: t('auth.userManagement.showPassword')
|
||||
}
|
||||
>
|
||||
<IconButton
|
||||
aria-label={
|
||||
showCurrentPassword
|
||||
? t('auth.userManagement.hidePassword')
|
||||
: t('auth.userManagement.showPassword')
|
||||
}
|
||||
icon={showCurrentPassword ? <PiEyeSlashBold /> : <PiEyeBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={toggleShowCurrentPassword}
|
||||
tabIndex={-1}
|
||||
/>
|
||||
</Tooltip>
|
||||
</InputRightElement>
|
||||
</InputGroup>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
|
||||
{/* New password */}
|
||||
<FormControl isInvalid={newPassword.length > 0 && !newPasswordValidation.isValid} mb={4}>
|
||||
<Grid templateColumns={PASSWORD_GRID_COLUMNS} gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" mb={0} pt={2}>
|
||||
{t('auth.profile.newPassword')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<InputGroup>
|
||||
<Input
|
||||
type={showNewPassword ? 'text' : 'password'}
|
||||
value={newPassword}
|
||||
onChange={handleNewPasswordChange}
|
||||
placeholder={t('auth.profile.newPasswordPlaceholder')}
|
||||
autoComplete="new-password"
|
||||
pr="3rem"
|
||||
/>
|
||||
<InputRightElement>
|
||||
<Tooltip
|
||||
label={
|
||||
showNewPassword
|
||||
? t('auth.userManagement.hidePassword')
|
||||
: t('auth.userManagement.showPassword')
|
||||
}
|
||||
>
|
||||
<IconButton
|
||||
aria-label={
|
||||
showNewPassword
|
||||
? t('auth.userManagement.hidePassword')
|
||||
: t('auth.userManagement.showPassword')
|
||||
}
|
||||
icon={showNewPassword ? <PiEyeSlashBold /> : <PiEyeBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={toggleShowNewPassword}
|
||||
tabIndex={-1}
|
||||
/>
|
||||
</Tooltip>
|
||||
</InputRightElement>
|
||||
</InputGroup>
|
||||
{newPassword.length > 0 && !newPasswordValidation.isValid && (
|
||||
<FormErrorMessage>{newPasswordValidation.message}</FormErrorMessage>
|
||||
)}
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
|
||||
{/* Confirm new password */}
|
||||
<FormControl isInvalid={confirmPassword.length > 0 && !passwordsMatch} mb={4}>
|
||||
<Grid templateColumns={PASSWORD_GRID_COLUMNS} gap={4} alignItems="start">
|
||||
<GridItem>
|
||||
<FormLabel textAlign="right" mb={0} pt={2}>
|
||||
{t('auth.profile.confirmPassword')}
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<InputGroup>
|
||||
<Input
|
||||
type={showConfirmPassword ? 'text' : 'password'}
|
||||
value={confirmPassword}
|
||||
onChange={handleConfirmPasswordChange}
|
||||
placeholder={t('auth.profile.confirmPasswordPlaceholder')}
|
||||
autoComplete="new-password"
|
||||
pr="3rem"
|
||||
/>
|
||||
<InputRightElement>
|
||||
<Tooltip
|
||||
label={
|
||||
showConfirmPassword
|
||||
? t('auth.userManagement.hidePassword')
|
||||
: t('auth.userManagement.showPassword')
|
||||
}
|
||||
>
|
||||
<IconButton
|
||||
aria-label={
|
||||
showConfirmPassword
|
||||
? t('auth.userManagement.hidePassword')
|
||||
: t('auth.userManagement.showPassword')
|
||||
}
|
||||
icon={showConfirmPassword ? <PiEyeSlashBold /> : <PiEyeBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={toggleShowConfirmPassword}
|
||||
tabIndex={-1}
|
||||
/>
|
||||
</Tooltip>
|
||||
</InputRightElement>
|
||||
</InputGroup>
|
||||
{confirmPassword.length > 0 && !passwordsMatch && (
|
||||
<FormErrorMessage>{t('auth.profile.passwordsDoNotMatch')}</FormErrorMessage>
|
||||
)}
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</FormControl>
|
||||
|
||||
{/* Generate password button – aligned with the input column */}
|
||||
<Grid templateColumns={PASSWORD_GRID_COLUMNS} gap={4}>
|
||||
<GridItem />
|
||||
<GridItem>
|
||||
<Button size="sm" variant="ghost" onClick={handleGeneratePassword} leftIcon={<PiLightningFill />}>
|
||||
{t('auth.userManagement.generatePassword')}
|
||||
</Button>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
</Box>
|
||||
|
||||
{errorMessage && (
|
||||
<Flex p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm">
|
||||
<Text fontWeight="semibold">{errorMessage}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
<Flex gap={3}>
|
||||
<Button variant="ghost" onClick={handleCancel}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button type="submit" colorScheme="invokeBlue" isLoading={isLoading} isDisabled={!isPasswordChangeValid}>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</VStack>
|
||||
</form>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
UserProfile.displayName = 'UserProfile';
|
||||
83
invokeai/frontend/web/src/features/auth/store/authSlice.ts
Normal file
83
invokeai/frontend/web/src/features/auth/store/authSlice.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zUser = z.object({
|
||||
user_id: z.string(),
|
||||
email: z.string(),
|
||||
display_name: z.string().nullable(),
|
||||
is_admin: z.boolean(),
|
||||
is_active: z.boolean(),
|
||||
});
|
||||
|
||||
const zAuthState = z.object({
|
||||
isAuthenticated: z.boolean(),
|
||||
token: z.string().nullable(),
|
||||
user: zUser.nullable(),
|
||||
isLoading: z.boolean(),
|
||||
});
|
||||
|
||||
type User = z.infer<typeof zUser>;
|
||||
type AuthState = z.infer<typeof zAuthState>;
|
||||
|
||||
// Helper to safely access localStorage (not available in test environment)
|
||||
const getStoredAuthToken = (): string | null => {
|
||||
if (typeof window !== 'undefined' && window.localStorage) {
|
||||
return localStorage.getItem('auth_token');
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const initialState: AuthState = {
|
||||
isAuthenticated: !!getStoredAuthToken(),
|
||||
token: getStoredAuthToken(),
|
||||
user: null,
|
||||
isLoading: false,
|
||||
};
|
||||
|
||||
const getInitialAuthState = (): AuthState => initialState;
|
||||
|
||||
const authSlice = createSlice({
|
||||
name: 'auth',
|
||||
initialState,
|
||||
reducers: {
|
||||
setCredentials: (state, action: PayloadAction<{ token: string; user: User }>) => {
|
||||
state.token = action.payload.token;
|
||||
state.user = action.payload.user;
|
||||
state.isAuthenticated = true;
|
||||
if (typeof window !== 'undefined' && window.localStorage) {
|
||||
localStorage.setItem('auth_token', action.payload.token);
|
||||
}
|
||||
},
|
||||
logout: (state) => {
|
||||
state.token = null;
|
||||
state.user = null;
|
||||
state.isAuthenticated = false;
|
||||
if (typeof window !== 'undefined' && window.localStorage) {
|
||||
localStorage.removeItem('auth_token');
|
||||
}
|
||||
},
|
||||
setLoading: (state, action: PayloadAction<boolean>) => {
|
||||
state.isLoading = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { setCredentials, logout, setLoading } = authSlice.actions;
|
||||
|
||||
export const authSliceConfig: SliceConfig<typeof authSlice> = {
|
||||
slice: authSlice,
|
||||
schema: zAuthState,
|
||||
getInitialState: getInitialAuthState,
|
||||
persistConfig: {
|
||||
migrate: () => getInitialAuthState(),
|
||||
// Don't persist auth state - token is stored in localStorage
|
||||
persistDenylist: ['isAuthenticated', 'token', 'user', 'isLoading'],
|
||||
},
|
||||
};
|
||||
|
||||
export const selectIsAuthenticated = (state: { auth: AuthState }) => state.auth.isAuthenticated;
|
||||
export const selectCurrentUser = (state: { auth: AuthState }) => state.auth.user;
|
||||
export const selectAuthToken = (state: { auth: AuthState }) => state.auth.token;
|
||||
export const selectIsAuthLoading = (state: { auth: AuthState }) => state.auth.isLoading;
|
||||
@@ -96,16 +96,14 @@ const FontSelect = () => {
|
||||
<Text fontSize="sm" lineHeight="1" whiteSpace="nowrap">
|
||||
{t('controlLayers.text.font', { defaultValue: 'Font' })}
|
||||
</Text>
|
||||
<Tooltip label={t('controlLayers.text.font', { defaultValue: 'Font' })}>
|
||||
<Combobox
|
||||
size="sm"
|
||||
variant="outline"
|
||||
isSearchable={false}
|
||||
options={options}
|
||||
value={selectedOption}
|
||||
onChange={handleFontChange}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Combobox
|
||||
size="sm"
|
||||
variant="outline"
|
||||
isSearchable={false}
|
||||
options={options}
|
||||
value={selectedOption}
|
||||
onChange={handleFontChange}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -6,6 +6,7 @@ import { deepClone } from 'common/util/deepClone';
|
||||
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import { clamp } from 'es-toolkit/compat';
|
||||
import { logout } from 'features/auth/store/authSlice';
|
||||
import type { AspectRatioID, InfillMethod, ParamsState, RgbaColor } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
ASPECT_RATIO_MAP,
|
||||
@@ -428,6 +429,12 @@ const slice = createSlice({
|
||||
},
|
||||
paramsReset: (state) => resetState(state),
|
||||
},
|
||||
extraReducers(builder) {
|
||||
// Reset params state on logout to prevent user data leakage when switching users
|
||||
builder.addCase(logout, () => {
|
||||
return getInitialParamsState();
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const applyClipSkip = (state: { clipSkip: number }, model: ParameterModel | null, clipSkip: number) => {
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCurrentUser } from 'features/auth/store/authSlice';
|
||||
import type { AddImageToBoardDndTargetData } from 'features/dnd/dnd';
|
||||
import { addImageToBoardDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
@@ -36,6 +37,7 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
const autoAssignBoardOnClick = useAppSelector(selectAutoAssignBoardOnClick);
|
||||
const selectedBoardId = useAppSelector(selectSelectedBoardId);
|
||||
const currentUser = useAppSelector(selectCurrentUser);
|
||||
const onClick = useCallback(() => {
|
||||
if (selectedBoardId !== board.board_id) {
|
||||
dispatch(boardIdSelected({ boardId: board.board_id }));
|
||||
@@ -58,6 +60,8 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
|
||||
[board]
|
||||
);
|
||||
|
||||
const showOwner = currentUser?.is_admin && board.owner_username;
|
||||
|
||||
return (
|
||||
<Box position="relative" w="full" h={12}>
|
||||
<BoardContextMenu board={board}>
|
||||
@@ -85,8 +89,13 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
|
||||
h="full"
|
||||
>
|
||||
<CoverImage board={board} />
|
||||
<Flex flex={1}>
|
||||
<Flex flex={1} direction="column" minW={0}>
|
||||
<BoardEditableTitle board={board} isSelected={isSelected} />
|
||||
{showOwner && (
|
||||
<Text fontSize="xs" color="base.500" noOfLines={1}>
|
||||
{board.owner_username}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
{autoAddBoardId === board.board_id && <AutoAddBadge />}
|
||||
{board.archived && <Icon as={PiArchiveBold} fill="base.300" />}
|
||||
|
||||
@@ -13,6 +13,9 @@ import {
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { MutableRefObject } from 'react';
|
||||
import React, { memo, useCallback, useEffect, useMemo, useRef } from 'react';
|
||||
import type {
|
||||
@@ -80,22 +83,41 @@ const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageNam
|
||||
return `${JSON.stringify(queryArgs)}-${imageName ?? index}`;
|
||||
};
|
||||
|
||||
const canHandleGridArrowNavigation = (
|
||||
activeTab: ReturnType<typeof selectActiveTab>,
|
||||
focusedRegion: ReturnType<typeof getFocusedRegion>
|
||||
) => {
|
||||
if (navigationApi.isViewerArrowNavigationMode(activeTab)) {
|
||||
// When gallery is not effectively available, viewer hotkeys own left/right navigation.
|
||||
return false;
|
||||
}
|
||||
|
||||
if (focusedRegion === 'gallery' || focusedRegion === 'viewer') {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Fallback for tab-switch edge case: allow nav when viewer dock tab is active before first click.
|
||||
return navigationApi.isDockviewPanelActive(activeTab, VIEWER_PANEL_ID);
|
||||
};
|
||||
|
||||
/**
|
||||
* Handles keyboard navigation for the gallery.
|
||||
*/
|
||||
const useKeyboardNavigation = (
|
||||
imageNames: string[],
|
||||
navigationImageNames: string[],
|
||||
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
|
||||
rootRef: React.RefObject<HTMLDivElement>
|
||||
) => {
|
||||
const { dispatch, getState } = useAppStore();
|
||||
const activeTab = useAppSelector(selectActiveTab);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(event: KeyboardEvent) => {
|
||||
if (getFocusedRegion() !== 'gallery') {
|
||||
// Only handle keyboard navigation when the gallery is focused
|
||||
const focusedRegion = getFocusedRegion();
|
||||
if (!canHandleGridArrowNavigation(activeTab, focusedRegion)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only handle arrow keys
|
||||
if (!['ArrowUp', 'ArrowDown', 'ArrowLeft', 'ArrowRight'].includes(event.key)) {
|
||||
return;
|
||||
@@ -112,7 +134,7 @@ const useKeyboardNavigation = (
|
||||
return;
|
||||
}
|
||||
|
||||
if (imageNames.length === 0) {
|
||||
if (navigationImageNames.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -132,7 +154,7 @@ const useKeyboardNavigation = (
|
||||
(selectImageToCompare(state) ?? selectLastSelectedItem(state))
|
||||
: selectLastSelectedItem(state);
|
||||
|
||||
const currentIndex = getItemIndex(imageName ?? null, imageNames);
|
||||
const currentIndex = getItemIndex(imageName ?? null, navigationImageNames);
|
||||
|
||||
let newIndex = currentIndex;
|
||||
|
||||
@@ -146,7 +168,7 @@ const useKeyboardNavigation = (
|
||||
}
|
||||
break;
|
||||
case 'ArrowRight':
|
||||
if (currentIndex < imageNames.length - 1) {
|
||||
if (currentIndex < navigationImageNames.length - 1) {
|
||||
newIndex = currentIndex + 1;
|
||||
// } else {
|
||||
// // Wrap to first image
|
||||
@@ -163,16 +185,16 @@ const useKeyboardNavigation = (
|
||||
break;
|
||||
case 'ArrowDown':
|
||||
// If no images below, stay on current image
|
||||
if (currentIndex >= imageNames.length - imagesPerRow) {
|
||||
if (currentIndex >= navigationImageNames.length - imagesPerRow) {
|
||||
newIndex = currentIndex;
|
||||
} else {
|
||||
newIndex = Math.min(imageNames.length - 1, currentIndex + imagesPerRow);
|
||||
newIndex = Math.min(navigationImageNames.length - 1, currentIndex + imagesPerRow);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < imageNames.length) {
|
||||
const newImageName = imageNames[newIndex];
|
||||
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < navigationImageNames.length) {
|
||||
const newImageName = navigationImageNames[newIndex];
|
||||
if (newImageName) {
|
||||
if (event.altKey) {
|
||||
dispatch(imageToCompareChanged(newImageName));
|
||||
@@ -182,7 +204,7 @@ const useKeyboardNavigation = (
|
||||
}
|
||||
}
|
||||
},
|
||||
[rootRef, virtuosoRef, imageNames, getState, dispatch]
|
||||
[activeTab, rootRef, virtuosoRef, navigationImageNames, getState, dispatch]
|
||||
);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
@@ -316,13 +338,14 @@ const useStarImageHotkey = () => {
|
||||
|
||||
type GalleryImageGridContentProps = {
|
||||
imageNames: string[];
|
||||
navigationImageNames?: string[];
|
||||
isLoading: boolean;
|
||||
queryArgs: ListImageNamesQueryArgs;
|
||||
rootRef?: React.RefObject<HTMLDivElement>;
|
||||
};
|
||||
|
||||
export const GalleryImageGridContent = memo(
|
||||
({ imageNames, isLoading, queryArgs, rootRef: rootRefProp }: GalleryImageGridContentProps) => {
|
||||
({ imageNames, navigationImageNames, isLoading, queryArgs, rootRef: rootRefProp }: GalleryImageGridContentProps) => {
|
||||
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
|
||||
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
|
||||
const internalRootRef = useRef<HTMLDivElement>(null);
|
||||
@@ -336,7 +359,7 @@ export const GalleryImageGridContent = memo(
|
||||
|
||||
useStarImageHotkey();
|
||||
useKeepSelectedImageInView(imageNames, virtuosoRef, rootRef, rangeRef);
|
||||
useKeyboardNavigation(imageNames, virtuosoRef, rootRef);
|
||||
useKeyboardNavigation(navigationImageNames ?? imageNames, virtuosoRef, rootRef);
|
||||
const scrollerRef = useScrollableGallery(rootRef);
|
||||
|
||||
/*
|
||||
|
||||
@@ -181,6 +181,7 @@ export const GalleryImageGridPaged = memo(() => {
|
||||
<Flex w="full" h="full">
|
||||
<GalleryImageGridContent
|
||||
imageNames={pageImageNames}
|
||||
navigationImageNames={imageNames}
|
||||
isLoading={false}
|
||||
queryArgs={queryArgs}
|
||||
rootRef={gridRootRef}
|
||||
|
||||
@@ -5,10 +5,18 @@ import { CanvasAlertsInvocationProgress } from 'features/controlLayers/component
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
|
||||
import NextPrevItemButtons from 'features/gallery/components/NextPrevItemButtons';
|
||||
import { selectShouldShowItemDetails, selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors';
|
||||
import { useNextPrevItemNavigation } from 'features/gallery/components/useNextPrevItemNavigation';
|
||||
import { selectLastSelectedItem } from 'features/gallery/store/gallerySelectors';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import {
|
||||
selectActiveTab,
|
||||
selectShouldShowItemDetails,
|
||||
selectShouldShowProgressInViewer,
|
||||
} from 'features/ui/store/uiSelectors';
|
||||
import type { AnimationProps } from 'framer-motion';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { memo, useCallback, useRef, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import { useImageViewerContext } from './context';
|
||||
@@ -17,11 +25,56 @@ import { ProgressImage } from './ProgressImage2';
|
||||
import { ProgressIndicator } from './ProgressIndicator2';
|
||||
|
||||
export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | null }) => {
|
||||
const activeTab = useAppSelector(selectActiveTab);
|
||||
const selectedImageName = useAppSelector(selectLastSelectedItem);
|
||||
const shouldShowItemDetails = useAppSelector(selectShouldShowItemDetails);
|
||||
const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer);
|
||||
const { goToPreviousImage, goToNextImage, isFetching } = useNextPrevItemNavigation();
|
||||
const { onLoadImage, $progressEvent, $progressImage } = useImageViewerContext();
|
||||
const progressEvent = useStore($progressEvent);
|
||||
const progressImage = useStore($progressImage);
|
||||
const [imageToRender, setImageToRender] = useState<ImageDTO | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedImageName) {
|
||||
setImageToRender(null);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!imageDTO || imageToRender?.image_name === imageDTO.image_name) {
|
||||
return;
|
||||
}
|
||||
|
||||
let canceled = false;
|
||||
|
||||
const onReady = () => {
|
||||
if (canceled) {
|
||||
return;
|
||||
}
|
||||
setImageToRender(imageDTO);
|
||||
};
|
||||
|
||||
if (typeof window === 'undefined') {
|
||||
onReady();
|
||||
return;
|
||||
}
|
||||
|
||||
const preloader = new window.Image();
|
||||
|
||||
preloader.onload = onReady;
|
||||
preloader.onerror = onReady;
|
||||
preloader.src = imageDTO.image_url;
|
||||
|
||||
if (preloader.complete) {
|
||||
onReady();
|
||||
}
|
||||
|
||||
return () => {
|
||||
canceled = true;
|
||||
preloader.onload = null;
|
||||
preloader.onerror = null;
|
||||
};
|
||||
}, [imageDTO, imageToRender?.image_name, selectedImageName]);
|
||||
|
||||
// Show and hide the next/prev buttons on mouse move
|
||||
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
|
||||
@@ -36,6 +89,50 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
|
||||
}, 500);
|
||||
}, []);
|
||||
|
||||
const handleViewerArrowNavigation = useCallback(
|
||||
(event: KeyboardEvent, navigate: () => void) => {
|
||||
if (!navigationApi.isViewerArrowNavigationMode(activeTab) || !imageToRender || isFetching) {
|
||||
return;
|
||||
}
|
||||
if (event.target instanceof HTMLInputElement || event.target instanceof HTMLTextAreaElement) {
|
||||
return;
|
||||
}
|
||||
event.preventDefault();
|
||||
navigate();
|
||||
},
|
||||
[activeTab, imageToRender, isFetching]
|
||||
);
|
||||
|
||||
const onHotkeyPrevImage = useCallback(
|
||||
(event: KeyboardEvent) => {
|
||||
handleViewerArrowNavigation(event, goToPreviousImage);
|
||||
},
|
||||
[goToPreviousImage, handleViewerArrowNavigation]
|
||||
);
|
||||
|
||||
const onHotkeyNextImage = useCallback(
|
||||
(event: KeyboardEvent) => {
|
||||
handleViewerArrowNavigation(event, goToNextImage);
|
||||
},
|
||||
[goToNextImage, handleViewerArrowNavigation]
|
||||
);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'galleryNavLeft',
|
||||
category: 'gallery',
|
||||
callback: onHotkeyPrevImage,
|
||||
options: { preventDefault: true },
|
||||
dependencies: [onHotkeyPrevImage],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'galleryNavRight',
|
||||
category: 'gallery',
|
||||
callback: onHotkeyNextImage,
|
||||
options: { preventDefault: true },
|
||||
dependencies: [onHotkeyNextImage],
|
||||
});
|
||||
|
||||
const withProgress = shouldShowProgressInViewer && progressImage !== null;
|
||||
|
||||
return (
|
||||
@@ -48,19 +145,12 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
|
||||
justifyContent="center"
|
||||
position="relative"
|
||||
>
|
||||
{imageDTO && (
|
||||
<Flex
|
||||
key={imageDTO.image_name}
|
||||
w="full"
|
||||
h="full"
|
||||
position="absolute"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<DndImage imageDTO={imageDTO} onLoad={onLoadImage} borderRadius="base" />
|
||||
{imageToRender && (
|
||||
<Flex w="full" h="full" position="absolute" alignItems="center" justifyContent="center">
|
||||
<DndImage imageDTO={imageToRender} onLoad={onLoadImage} borderRadius="base" />
|
||||
</Flex>
|
||||
)}
|
||||
{!imageDTO && <NoContentForViewer />}
|
||||
{!imageToRender && <NoContentForViewer />}
|
||||
{withProgress && (
|
||||
<Flex w="full" h="full" position="absolute" alignItems="center" justifyContent="center" bg="base.900">
|
||||
<ProgressImage progressImage={progressImage} />
|
||||
@@ -72,13 +162,13 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
|
||||
<Flex flexDir="column" gap={2} position="absolute" top={0} insetInlineStart={0} alignItems="flex-start">
|
||||
<CanvasAlertsInvocationProgress />
|
||||
</Flex>
|
||||
{shouldShowItemDetails && imageDTO && !withProgress && (
|
||||
{shouldShowItemDetails && imageToRender && !withProgress && (
|
||||
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
|
||||
<ImageMetadataViewer image={imageDTO} />
|
||||
<ImageMetadataViewer image={imageToRender} />
|
||||
</Box>
|
||||
)}
|
||||
<AnimatePresence>
|
||||
{shouldShowNextPrevButtons && imageDTO && (
|
||||
{shouldShowNextPrevButtons && imageToRender && (
|
||||
<Box
|
||||
as={motion.div}
|
||||
key="nextPrevButtons"
|
||||
|
||||
@@ -1,51 +1,28 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Box, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { clamp } from 'es-toolkit/compat';
|
||||
import { selectLastSelectedItem } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, type MouseEvent, type PointerEvent } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
|
||||
|
||||
import { useGalleryImageNames } from './use-gallery-image-names';
|
||||
import { useNextPrevItemNavigation } from './useNextPrevItemNavigation';
|
||||
|
||||
const ARROW_SIZE = 48;
|
||||
|
||||
const preventButtonFocusOnPointerDown = (event: PointerEvent<HTMLButtonElement>) => {
|
||||
event.preventDefault();
|
||||
};
|
||||
|
||||
const preventButtonFocusOnMouseDown = (event: MouseEvent<HTMLButtonElement>) => {
|
||||
event.preventDefault();
|
||||
};
|
||||
|
||||
const blurButtonOnPointerUp = (event: PointerEvent<HTMLButtonElement>) => {
|
||||
event.currentTarget.blur();
|
||||
};
|
||||
|
||||
const NextPrevItemButtons = ({ inset = 8 }: { inset?: ChakraProps['insetInlineStart' | 'insetInlineEnd'] }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const lastSelectedItem = useAppSelector(selectLastSelectedItem);
|
||||
const { imageNames, isFetching } = useGalleryImageNames();
|
||||
|
||||
const isOnFirstItem = useMemo(
|
||||
() => (lastSelectedItem ? imageNames.at(0) === lastSelectedItem : false),
|
||||
[imageNames, lastSelectedItem]
|
||||
);
|
||||
const isOnLastItem = useMemo(
|
||||
() => (lastSelectedItem ? imageNames.at(-1) === lastSelectedItem : false),
|
||||
[imageNames, lastSelectedItem]
|
||||
);
|
||||
|
||||
const onClickLeftArrow = useCallback(() => {
|
||||
const targetIndex = lastSelectedItem ? imageNames.findIndex((n) => n === lastSelectedItem) - 1 : 0;
|
||||
const clampedIndex = clamp(targetIndex, 0, imageNames.length - 1);
|
||||
const n = imageNames.at(clampedIndex);
|
||||
if (!n) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageSelected(n));
|
||||
}, [dispatch, imageNames, lastSelectedItem]);
|
||||
|
||||
const onClickRightArrow = useCallback(() => {
|
||||
const targetIndex = lastSelectedItem ? imageNames.findIndex((n) => n === lastSelectedItem) + 1 : 0;
|
||||
const clampedIndex = clamp(targetIndex, 0, imageNames.length - 1);
|
||||
const n = imageNames.at(clampedIndex);
|
||||
if (!n) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageSelected(n));
|
||||
}, [dispatch, imageNames, lastSelectedItem]);
|
||||
const { goToPreviousImage, goToNextImage, isOnFirstItem, isOnLastItem, isFetching } = useNextPrevItemNavigation();
|
||||
|
||||
return (
|
||||
<Box pos="relative" h="full" w="full">
|
||||
@@ -62,7 +39,10 @@ const NextPrevItemButtons = ({ inset = 8 }: { inset?: ChakraProps['insetInlineSt
|
||||
minH={0}
|
||||
w={`${ARROW_SIZE}px`}
|
||||
h={`${ARROW_SIZE}px`}
|
||||
onClick={onClickLeftArrow}
|
||||
onClick={goToPreviousImage}
|
||||
onPointerDown={preventButtonFocusOnPointerDown}
|
||||
onMouseDown={preventButtonFocusOnMouseDown}
|
||||
onPointerUp={blurButtonOnPointerUp}
|
||||
isDisabled={isFetching}
|
||||
color="base.100"
|
||||
pointerEvents="auto"
|
||||
@@ -82,7 +62,10 @@ const NextPrevItemButtons = ({ inset = 8 }: { inset?: ChakraProps['insetInlineSt
|
||||
minH={0}
|
||||
w={`${ARROW_SIZE}px`}
|
||||
h={`${ARROW_SIZE}px`}
|
||||
onClick={onClickRightArrow}
|
||||
onClick={goToNextImage}
|
||||
onPointerDown={preventButtonFocusOnPointerDown}
|
||||
onMouseDown={preventButtonFocusOnMouseDown}
|
||||
onPointerUp={blurButtonOnPointerUp}
|
||||
isDisabled={isFetching}
|
||||
color="base.100"
|
||||
pointerEvents="auto"
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { clamp } from 'es-toolkit/compat';
|
||||
import { selectLastSelectedItem } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
|
||||
import { useGalleryImageNames } from './use-gallery-image-names';
|
||||
|
||||
export const useNextPrevItemNavigation = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const lastSelectedItem = useAppSelector(selectLastSelectedItem);
|
||||
const { imageNames, isFetching } = useGalleryImageNames();
|
||||
|
||||
const currentIndex = useMemo(
|
||||
() => (lastSelectedItem ? imageNames.findIndex((n) => n === lastSelectedItem) : -1),
|
||||
[imageNames, lastSelectedItem]
|
||||
);
|
||||
const isOnFirstItem = currentIndex === 0;
|
||||
const isOnLastItem = currentIndex >= 0 && currentIndex === imageNames.length - 1;
|
||||
|
||||
const navigateBy = useCallback(
|
||||
(delta: number) => {
|
||||
const maxIndex = imageNames.length - 1;
|
||||
if (maxIndex < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const targetIndex = currentIndex >= 0 ? clamp(currentIndex + delta, 0, maxIndex) : 0;
|
||||
const imageName = imageNames[targetIndex];
|
||||
if (!imageName) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageSelected(imageName));
|
||||
},
|
||||
[currentIndex, dispatch, imageNames]
|
||||
);
|
||||
|
||||
const goToPreviousImage = useCallback(() => {
|
||||
navigateBy(-1);
|
||||
}, [navigateBy]);
|
||||
|
||||
const goToNextImage = useCallback(() => {
|
||||
navigateBy(1);
|
||||
}, [navigateBy]);
|
||||
|
||||
return { goToPreviousImage, goToNextImage, isOnFirstItem, isOnLastItem, isFetching };
|
||||
};
|
||||
@@ -3,6 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject, uniq } from 'es-toolkit';
|
||||
import { logout } from 'features/auth/store/authSlice';
|
||||
import type { BoardRecordOrderBy } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
@@ -142,6 +143,14 @@ const slice = createSlice({
|
||||
state.boardsListOrderDir = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
// Clear board-related state on logout to prevent stale data when switching users
|
||||
builder.addCase(logout, (state) => {
|
||||
state.selectedBoardId = 'none';
|
||||
state.autoAddBoardId = 'none';
|
||||
state.boardSearchText = '';
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
@@ -182,6 +191,6 @@ export const gallerySliceConfig: SliceConfig<typeof slice> = {
|
||||
}
|
||||
return zGalleryState.parse(state);
|
||||
},
|
||||
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
|
||||
persistDenylist: ['selection', 'galleryView', 'imageToCompare'],
|
||||
},
|
||||
};
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCurrentUser } from 'features/auth/store/authSlice';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
|
||||
|
||||
/**
|
||||
* Hook to determine if model manager features should be enabled for the current user.
|
||||
*
|
||||
* Returns true if:
|
||||
* - Multiuser mode is disabled (single-user mode = always admin)
|
||||
* - Multiuser mode is enabled AND user is an admin
|
||||
*
|
||||
* Returns false if:
|
||||
* - Multiuser mode is enabled AND user is not an admin
|
||||
*/
|
||||
export const useIsModelManagerEnabled = (): boolean => {
|
||||
const user = useAppSelector(selectCurrentUser);
|
||||
const { data: setupStatus } = useGetSetupStatusQuery();
|
||||
|
||||
return useMemo(() => {
|
||||
// If multiuser is disabled, treat as admin (single-user mode)
|
||||
if (setupStatus && !setupStatus.multiuser_enabled) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If multiuser is enabled, check if user is admin
|
||||
return user?.is_admin ?? false;
|
||||
}, [setupStatus, user]);
|
||||
};
|
||||
@@ -1,4 +1,6 @@
|
||||
import { Button, Text, useToast } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsAuthenticated } from 'features/auth/store/authSlice';
|
||||
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
@@ -12,8 +14,14 @@ export const useStarterModelsToast = () => {
|
||||
const [didToast, setDidToast] = useState(false);
|
||||
const [mainModels, { data }] = useMainModels();
|
||||
const toast = useToast();
|
||||
const isAuthenticated = useAppSelector(selectIsAuthenticated);
|
||||
|
||||
useEffect(() => {
|
||||
// Only show the toast if the user is authenticated
|
||||
if (!isAuthenticated) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (toast.isActive(TOAST_ID)) {
|
||||
if (mainModels.length === 0) {
|
||||
return;
|
||||
@@ -32,7 +40,7 @@ export const useStarterModelsToast = () => {
|
||||
onCloseComplete: () => setDidToast(true),
|
||||
});
|
||||
}
|
||||
}, [data, didToast, mainModels.length, t, toast]);
|
||||
}, [data, didToast, isAuthenticated, mainModels.length, t, toast]);
|
||||
};
|
||||
|
||||
const ToastDescription = () => {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsModelManagerEnabled } from 'features/modelManagerV2/hooks/useIsModelManagerEnabled';
|
||||
import { selectSelectedModelKey, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -23,6 +24,7 @@ const modelManagerSx: SystemStyleObject = {
|
||||
export const ModelManager = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const canManageModels = useIsModelManagerEnabled();
|
||||
const handleClickAddModel = useCallback(() => {
|
||||
dispatch(setSelectedModelKey(null));
|
||||
}, [dispatch]);
|
||||
@@ -36,7 +38,7 @@ export const ModelManager = memo(() => {
|
||||
</Heading>
|
||||
<Flex gap={2}>
|
||||
<SyncModelsButton />
|
||||
{!!selectedModelKey && (
|
||||
{!!selectedModelKey && canManageModels && (
|
||||
<Button size="sm" colorScheme="invokeYellow" leftIcon={<PiPlusBold />} onClick={handleClickAddModel}>
|
||||
{t('modelManager.addModels')}
|
||||
</Button>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user