Merge branch 'main' into feat/model-manager-queue-redesign

This commit is contained in:
Josh Corbett
2026-03-09 18:17:28 -06:00
committed by GitHub
162 changed files with 18404 additions and 1276 deletions

View File

@@ -16,20 +16,20 @@ help:
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "wheel Build the wheel for the current version"
@echo "frontend-prettier Format the frontend using lint:prettier"
@echo "wheel Build the wheel for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
@echo "docs Serve the mkdocs site with live reload"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
ruff check . --fix
ruff format .
cd invokeai && uv tool run ruff@0.11.2 format
# Runs ruff, fixing all errors it can fix and formatting
ruff-unsafe:
ruff check . --fix --unsafe-fixes
ruff format .
ruff format
# Runs mypy, using the config in pyproject.toml
mypy:
@@ -64,6 +64,13 @@ frontend-dev:
frontend-typegen:
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
frontend-lint:
cd invokeai/frontend/web/src && \
pnpm lint:tsc && \
pnpm lint:dpdm && \
pnpm lint:eslint --fix && \
pnpm lint:prettier --write
# Tag the release
wheel:
cd scripts && ./build_wheel.sh
@@ -79,4 +86,4 @@ openapi:
# Serve the mkdocs site w/ live reload
.PHONY: docs
docs:
mkdocs serve
mkdocs serve

View File

@@ -52,7 +52,7 @@ The Unified Canvas is a fully integrated canvas implementation with support for
### Workflows & Nodes
Invoke offers a fully featured workflow management solution, enabling users to combine the power of node-based workflows with the easy of a UI. This allows for customizable generation pipelines to be developed and shared by users looking to create specific workflows to support their production use-cases.
Invoke offers a fully featured workflow management solution, enabling users to combine the power of node-based workflows with the ease of a UI. This allows for customizable generation pipelines to be developed and shared by users looking to create specific workflows to support their production use-cases.
### Board & Gallery Management

View File

@@ -0,0 +1,169 @@
# User Isolation Implementation Summary
This document describes the implementation of user isolation features in the InvokeAI session queue and processing system to address issues identified in the enhancement request.
## Issues Addressed
### 1. Cross-User Image/Preview Visibility
**Problem:** When two users are logged in simultaneously and one initiates a generation, the generation preview shows up in both users' browsers and the generated image gets saved to both users' image boards.
**Solution:** Implemented socket-level event filtering based on user authentication:
#### Backend Changes (`invokeai/app/api/sockets.py`):
- Added socket authentication middleware in `_handle_connect()` method
- Extracts JWT token from socket auth data or HTTP headers
- Verifies token using existing `verify_token()` function
- Stores `user_id` and `is_admin` in socket session for later use
- Modified `_handle_queue_event()` to filter events by user:
- For `QueueItemEventBase` events, only emit to:
- The user who owns the queue item (`user_id` matches)
- Admin users (`is_admin` is True)
- For general queue events, emit to all subscribers
#### Event System Changes (`invokeai/app/services/events/events_common.py`):
- Added `user_id` field to `QueueItemEventBase` class
- Updated all event builders to include `user_id` from queue items:
- `InvocationStartedEvent.build()`
- `InvocationProgressEvent.build()`
- `InvocationCompleteEvent.build()`
- `InvocationErrorEvent.build()`
- `QueueItemStatusChangedEvent.build()`
### 2. Batch Field Values Privacy
**Problem:** Users can see batch field values from generation processes launched by other users.
**Solution:** Implemented field value sanitization at the API level:
#### API Router Changes (`invokeai/app/api/routers/session_queue.py`):
- Created `sanitize_queue_item_for_user()` helper function
- Clears `field_values` for non-admin users viewing other users' items
- Admins and item owners can see all field values
- Updated endpoints to require authentication and sanitize responses:
- `list_all_queue_items()` - Added `CurrentUser` dependency
- `get_queue_items_by_item_ids()` - Added `CurrentUser` dependency
- `get_queue_item()` - Added `CurrentUser` dependency
### 3. Queue Updates Across Browser Windows
**Problem:** When the job queue tab is open in multiple browsers and a generation is begun in one browser window, the queue does not update in the other window.
**Status:** This issue is likely resolved by the socket authentication and event filtering changes. The existing socket subscription mechanism (`subscribe_queue` event) already supports multiple connections per user. Testing is required to confirm this works correctly with the new authentication flow.
### 4. User Information Display
**Problem:** Queue table lacks user identification, making it difficult to know who launched which job.
**Solution:** Added user information to queue items and UI:
#### Database Layer (`invokeai/app/services/session_queue/session_queue_sqlite.py`):
- Updated SQL queries to JOIN with `users` table
- Modified methods to fetch user information:
- `get_queue_item()` - Now selects `display_name` and `email` from users table
- `dequeue()` - Includes user info
- `get_next()` - Includes user info
- `get_current()` - Includes user info
- `list_all_queue_items()` - Includes user info
#### Data Model Changes (`invokeai/app/services/session_queue/session_queue_common.py`):
- Added optional fields to `SessionQueueItem`:
- `user_display_name: Optional[str]` - Display name from users table
- `user_email: Optional[str]` - Email from users table
- Note: `user_id` field already existed from Migration 25
#### Frontend UI Changes:
- **Constants** (`constants.ts`): Added `user: '8rem'` column width
- **Header** (`QueueListHeader.tsx`): Added "User" column header
- **Item Component** (`QueueItemComponent.tsx`):
- Added logic to display user information (display_name → email → user_id)
- Added user column to queue item row
- Added tooltip with full username on hover
- Added "Hidden for privacy" message when field_values are null for non-owned items
- **Localization** (`en.json`): Added translations:
- `"user": "User"`
- `"fieldValuesHidden": "Hidden for privacy"`
## Security Considerations
### Token Verification
- Tokens are verified using the existing `verify_token()` function from `invokeai.app.services.auth.token_service`
- Invalid or missing tokens default to "system" user with non-admin privileges
- Socket connections without valid tokens are still accepted for backward compatibility but have limited access
### Data Privacy
- Field values are only visible to:
- The user who created the queue item
- Admin users
- Non-admin users viewing other users' queue items see "Hidden for privacy" instead of field values
### Admin Privileges
- Admin users can see all queue events and field values across all users
- Admin status is determined from the JWT token's `is_admin` field
## Migration Notes
No database migration is required. The changes leverage:
- Existing `user_id` column in `session_queue` table (added in Migration 25)
- Existing `users` table (added in Migration 25)
- SQL LEFT JOINs to fetch user information (gracefully handles missing user records)
## Testing Requirements
### Backend Testing
1. **Socket Authentication:**
- Verify valid tokens are accepted and user context is stored
- Verify invalid tokens default to system user
- Verify expired tokens are rejected
2. **Event Filtering:**
- User A should only receive events for their own queue items
- Admin users should receive all events
- Non-admin users should not receive events from other users
3. **Field Value Sanitization:**
- Non-admin users should see null field_values for other users' items
- Admins should see all field values
- Users should see their own field values
### Frontend Testing
1. **UI Display:**
- User column should display in queue list
- Display name should be shown when available
- Email should be shown as fallback when display name is missing
- User ID should be shown when both display name and email are missing
- Tooltip should show full username on hover
2. **Field Values Display:**
- "Hidden for privacy" message should appear when viewing other users' items
- Own items should show field values normally
3. **Multi-Browser Testing:**
- Open queue tab in two browsers with different users
- Start generation in one browser
- Verify other browser doesn't see the preview/progress
- Verify admin user can see all generations
### Integration Testing
1. Multi-user scenarios with simultaneous generations
2. Queue updates across multiple browser windows
3. Admin vs. non-admin privilege differentiation
4. Socket reconnection handling
## Known Limitations
1. **TypeScript Types:**
- The OpenAPI schema needs to be regenerated to include new fields
- Run: `cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen`
2. **Backward Compatibility:**
- System user ("system") entries will not have display name or email
- Existing queue items from before Migration 25 will have user_id="system"
3. **Socket.IO Session Storage:**
- Socket.IO's in-memory session storage may not persist across server restarts
- Consider implementing persistent session storage if needed for production
## Future Enhancements
1. Add user filtering to queue list (show only my items vs. all items)
2. Add permission system for queue management operations (cancel, retry, delete)
3. Implement queue item ownership transfer for administrative purposes
4. Add audit logging for queue operations with user attribution
5. Consider implementing user-specific queue limits or quotas

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

View File

@@ -18,7 +18,7 @@ If youd like to add a Node, please see our [nodes contribution guide](../node
Helping support other users in [Discord](https://discord.gg/ZmtBAhwWhy) and on Github are valuable forms of contribution that we greatly appreciate.
We receive many issues and requests for help from users. We're limited in bandwidth relative to our the user base, so providing answers to questions or helping identify causes of issues is very helpful. By doing this, you enable us to spend time on the highest priority work.
We receive many issues and requests for help from users. We're limited in bandwidth relative to our user base, so providing answers to questions or helping identify causes of issues is very helpful. By doing this, you enable us to spend time on the highest priority work.
## Documentation

View File

@@ -6,7 +6,9 @@ Invoke runs on Windows 10+, macOS 14+ and Linux (Ubuntu 20.04+ is well-tested).
Hardware requirements vary significantly depending on model and image output size.
The requirements below are rough guidelines for best performance. GPUs with less VRAM typically still work, if a bit slower. Follow the [Low-VRAM mode guide](./features/low-vram.md) to optimize performance.
The requirements below are rough guidelines for best performance. GPUs
with less VRAM typically still work, if a bit slower. Follow the
[Low-VRAM mode guide](../features/low-vram.md) to optimize performance.
- All Apple Silicon (M1, M2, etc) Macs work, but 16GB+ memory is recommended.
- AMD GPUs are supported on Linux only. The VRAM requirements are the same as Nvidia GPUs.

View File

@@ -0,0 +1,876 @@
# InvokeAI Multi-User Administrator Guide
## Overview
This guide is for administrators managing a multi-user InvokeAI installation. It covers initial setup, user management, security best practices, and troubleshooting.
## Prerequisites
Before enabling multi-user support, ensure you have:
- InvokeAI installed and running
- Access to the server filesystem (for initial setup)
- Understanding of your deployment environment
- Backup of your existing data (recommended)
## Initial Setup
### Activating Multiuser Mode
To put InvokeAI into multiuser mode, you will need to add the option
`multiuser: true` to its configuration file. This file is located at
`INVOKEAI_ROOT/invokeai.yaml` With the InvokeAI backend halted, add
the new configuration option to the end of the file with a text editor
so that it looks like this:
```yaml
# Internal metadata - do not edit:
schema_version: 4.0.2
# Enable/disable multi-user mode
multiuser: true
```
Then restart the InvokeAI server backend from the command line or
using the launcher.
!!! note "Reverting to single-user mode"
If at any time you wish to revert to single-user mode, simply comment
out the `multiuser` line, or change "true" to "false". Then
restart the server. Because of the way that browsers cache pages,
users with open InvokeAI sessions may need to force-refresh their
browsers.
### First Administrator Account
When InvokeAI starts for the first time in multi-user mode, you'll see the **Administrator Setup** dialog.
**Setup Steps:**
1. **Email Address**: Enter a valid email address (this becomes your username)
* Example: `admin@example.com` or `admin@localhost` for testing
* Must be a valid email format
* Cannot be changed later without database access
2. **Display Name**: Enter a friendly name
* Example: "System Administrator" or your real name
* Can be changed later in your profile
* Visible to other users in shared contexts
3. **Password**: Create a strong administrator password
* **Minimum requirements:**
* At least 8 characters long
* Contains uppercase letters (A-Z)
* Contains lowercase letters (a-z)
* Contains numbers (0-9)
* **Recommended:**
* Use 12+ characters
* Include special characters (!@#$%^&*)
* Use a password manager to generate and store
* Don't reuse passwords from other services
4. **Confirm Password**: Re-enter the password
5. Click **Create Administrator Account**
!!! warning "Important"
Store these credentials securely! The
first administrator account can reset
the password to something new, but cannot
retrieve a lost one.
### Configuration
InvokeAI can run in single-user or multi-user mode, controlled by the `multiuser` configuration option in `invokeai.yaml`:
```yaml
# Enable/disable multi-user mode
multiuser: true # Enable multi-user mode (requires authentication)
# multiuser: false # Single-user mode (no authentication required)
# If the multiuser option is absent, single-user mode is used
# Database configuration
use_memory_db: false # Use persistent database
db_path: databases/invokeai.db # Database location
# Session configuration (multi-user mode only)
jwt_secret_key: "your-secret-key-here" # Auto-generated if not specified
jwt_token_expiry_hours: 24 # Default session timeout
jwt_remember_me_days: 7 # "Remember me" duration
```
**Single-User Mode** (`multiuser: false` or option absent):
- No authentication required
- All functionality enabled by default
- All boards and images visible in unified view
- Ideal for personal use or trusted environments
**Multi-User Mode** (`multiuser: true`):
- Authentication required for access
- User isolation for boards, images, and workflows
- Role-based permissions enforced
- Ideal for shared servers or team environments
!!! warning "Mode Switching Behavior"
**Switching to Single-User Mode:** If boards or images were created in multi-user mode, they will all be combined into a single unified view when switching to single-user mode.
**Switching to Multi-User Mode:** Legacy boards and images created under single-user mode will be owned by an internal user named "system." Only the Administrator will have access to these legacy assets. A utility to migrate these legacy assets to another user will be part of a future release.
### Migration from Single-User
When upgrading from a single-user installation or switching modes:
1. **Automatic Migration**: The database will automatically migrate to multi-user schema when multi-user mode is first enabled
2. **Legacy Data Ownership**: Existing data (boards, images, workflows) created in single-user mode is assigned to an internal user named "system"
3. **Administrator Access**: Only administrators will have access to legacy "system"-owned assets when in multi-user mode
4. **No Data Loss**: All existing content is preserved
**Migration Process:**
```bash
# Backup your database first
cp databases/invokeai.db databases/invokeai.db.backup
# Enable multi-user mode in invokeai.yaml
# multiuser: true
# Start InvokeAI (migration happens automatically)
invokeai-web
# Complete the administrator setup dialog
# Legacy data will be owned by "system" user
```
!!! note "Legacy Asset Migration"
A utility to migrate legacy "system"-owned assets to specific user accounts will be available in a future release. Until then, administrators can access and manage all legacy content.
## User Management
### Creating Users
**Via Web Interface (Coming Soon):**
!!! info "Web UI for User Management"
A web-based user interface that allows administrators to manage users is coming in a future release. Until then, use the command-line scripts described below.
**Via Command Line Scripts:**
InvokeAI provides several command-line scripts in the `scripts/` directory for user management:
**useradd.py** - Add a new user:
```bash
# Interactive mode (prompts for details)
python scripts/useradd.py
# Create a regular user
python scripts/useradd.py \
--email user@example.com \
--password TempPass123 \
--name "User Name"
# Create an administrator
python scripts/useradd.py \
--email admin@example.com \
--password AdminPass123 \
--name "Admin Name" \
--admin
```
**userlist.py** - List all users:
```bash
# List all users
python scripts/userlist.py
# Show detailed information
python scripts/userlist.py --verbose
```
**usermod.py** - Modify an existing user:
```bash
# Change display name
python scripts/usermod.py --email user@example.com --name "New Name"
# Promote to administrator
python scripts/usermod.py --email user@example.com --admin
# Demote from administrator
python scripts/usermod.py --email user@example.com --no-admin
# Deactivate account
python scripts/usermod.py --email user@example.com --deactivate
# Reactivate account
python scripts/usermod.py --email user@example.com --activate
# Change password
python scripts/usermod.py --email user@example.com --password NewPassword123
```
**userdel.py** - Delete a user:
```bash
# Delete a user (prompts for confirmation)
python scripts/userdel.py --email user@example.com
# Delete without confirmation
python scripts/userdel.py --email user@example.com --force
```
!!! tip "Script Usage"
Run any script with `--help` to see all available options:
```bash
python scripts/useradd.py --help
```
!!! warning "Command Line Management"
- These scripts directly modify the database
- Always backup your database before making changes
- Changes take effect immediately (users may need to log in again)
- Deleting a user permanently removes all their content
### Editing Users
**Via Command Line:**
Use `usermod.py` as described above to modify user properties.
!!! warning "Last Administrator"
You cannot remove admin privileges from the last remaining administrator account.
### Resetting User Passwords
**Via Web Interface (Coming Soon):**
Web-based password reset functionality for administrators is coming in a future release.
**Via Command Line:**
```bash
# Reset a user's password
python scripts/usermod.py --email user@example.com --password NewTempPassword123
```
**Security Note:** Never send passwords via email or unsecured channels. Use secure communication methods.
### Deactivating Users
**Via Command Line:**
```bash
# Deactivate a user account
python scripts/usermod.py --email user@example.com --deactivate
# Reactivate a user account
python scripts/usermod.py --email user@example.com --activate
```
**Effects:**
- User cannot log in when deactivated
- Existing sessions are immediately invalidated
- User's data is preserved
- Can be reactivated at any time
### Deleting Users
**Via Command Line:**
```bash
# Delete a user (prompts for confirmation)
python scripts/userdel.py --email user@example.com
# Delete without confirmation prompt
python scripts/userdel.py --email user@example.com --force
```
**Important:**
- ⚠️ This action is **permanent**
- User's boards, images, and workflows are deleted
- Cannot be undone
- Consider deactivating instead of deleting
!!! warning "Data Loss"
Deleting a user permanently removes all their content. Back up the database first if recovery might be needed.
### Viewing User Activity
**Queue Management:**
1. Navigate to **Admin** → **Queue Overview**
2. View all users' active and pending generations
3. Filter by user
4. Cancel stuck or problematic tasks
**User Statistics:**
- Number of boards created
- Number of images generated
- Storage usage (if enabled)
- Last login time
## Model Management
As an administrator, you have full access to model management.
### Adding Models
**Via Model Manager UI:**
1. Go to **Models** tab
2. Click **Add Model**
3. Choose installation method:
- **From URL**: Provide HuggingFace repo or download URL
- **From Local Path**: Scan local directories
- **Import**: Import model from filesystem
**Supported Model Types:**
- Main models (Stable Diffusion, SDXL, FLUX)
- LoRA models
- ControlNet models
- VAE models
- Textual Inversions
- IP-Adapters
### Configuring Models
**Model Settings:**
- Display name
- Description
- Default generation settings (CFG, steps, scheduler)
- Variant selection (fp16/fp32)
- Model thumbnail image
**Default Settings:**
Set default parameters that users will start with:
1. Select a model
2. Go to **Default Settings** tab
3. Configure:
- CFG Scale
- Steps
- Scheduler
- VAE selection
4. Save settings
### Removing Models
1. Go to **Models** tab
2. Select model(s) to remove
3. Click **Delete**
4. Confirm deletion
!!! warning "Impact"
Removing a model affects all users who may be using it in workflows or saved settings.
## Shared Boards
Shared boards enable collaboration between users while maintaining control.
!!! note "Future Feature"
Board sharing will be implemented in a future release.
### Creating Shared Boards
1. Log in as administrator
2. Create a new board (or use existing board)
3. Right-click the board → **Share Board**
4. Add users and set permissions
5. Click **Save Sharing Settings**
### Permission Levels
| Level | View | Add Images | Edit/Delete | Manage Sharing |
|-------|------|------------|-------------|----------------|
| **Read** | ✅ | ❌ | ❌ | ❌ |
| **Write** | ✅ | ✅ | ✅ | ❌ |
| **Admin** | ✅ | ✅ | ✅ | ✅ |
**Permission Recommendations:**
- **Read**: For viewers who should see but not modify content
- **Write**: For active collaborators who add and organize images
- **Admin**: For trusted users who help manage the shared board
### Managing Shared Boards
**Add Users to Shared Board:**
1. Right-click shared board → **Manage Sharing**
2. Click **Add User**
3. Select user from dropdown
4. Choose permission level
5. Save changes
**Remove Users from Shared Board:**
1. Right-click shared board → **Manage Sharing**
2. Find user in list
3. Click **Remove**
4. Confirm removal
**Change User Permissions:**
1. Right-click shared board → **Manage Sharing**
2. Find user in list
3. Change permission dropdown
4. Save changes
### Shared Board Best Practices
- Give meaningful names to shared boards
- Document the board's purpose in the description
- Assign minimum necessary permissions
- Regularly audit access lists
- Remove users who no longer need access
## Security
### Password Policies
**Enforced Requirements:**
- Minimum 8 characters
- Must contain uppercase letters
- Must contain lowercase letters
- Must contain numbers
**Recommended Policies:**
- Require 12+ character passwords
- Include special characters
- Implement password rotation every 90 days
- Prevent password reuse
- Use multi-factor authentication (when available)
### Session Management
**Session Security and Token Management:**
This system uses stateless JWT tokens with HMAC signatures to
identify users after they provide their initial credentials. The
tokens will persist for 24 hours by default, or for 7 days if the user
clicks the "Remember me" checkbox at login. Expired tokens are
automatically rejected and the user will have to log in again.
At the client side, tokens are stored in browser localStorage. Logging
out clears them. No server-side session storage is required.
The tokens include the user's ID, email, and admin status, along with
an HMAC signature.
### Secret Key Management
**Important:** The JWT secret key must be kept confidential.
To generate tokens, each InvokeAI instance has a distinct secret JWT key that must be
kept confidential. The key is stored in the `app_settings` table of
the InvokeAI database with in a field value named `jwt_secret`.
The secret key is automatically generated during database creation or
migration. If you wish to change the key, you may generate a
replacement using either of these commands:
```bash
# Python
python -c "import secrets; print(secrets.token_urlsafe(32))"
# OpenSSL
openssl rand -base64 32
```
Then cut and paste the printed secret into this Sqlite3 command:
```bash
sqlite3 INVOKE_ROOT/databases/invokeai.db 'update app_settings set value="THE_SECRET" where key="jwt_secret"'
```
(replace INVOKE_ROOT with your InvokeAI root directory and THE_SECRET
with the new secret).
After this, restart the server. All logged in users will be logged out
and will need to provide their usernames and passwords again.
### Hosting a Shared InvokeAI Instance
The multiuser feature allows you to run an InvokeAI backend that can
be accessed by your friends and family across your home network. It is
also possible to host a backend that is accessible over the Internet.
By default, InvokeAI runs on `localhost`, IP address `127.0.0.1`,
which is only accessible to browsers running on the same machine as
the backend. To make the backend accessible to any machine on your
home or work LAN, add the line `host: 0.0.0.0` to the InvokeAI
configuration file, usually stored at `INVOKE_ROOT/invokeai.yaml`.
Here is a minimal example.
```yaml
# Internal metadata - do not edit:
schema_version: 4.0.2
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/configuration/:
multiuser: true
host: 0.0.0.0
```
After relaunching the backend you will be able to reach the server
from other machines on the LAN using the server machine's IP address
or hostname and port 9090.
#### Connecting to the Internet
!!! warning "Use at your own risk"
The InvokeAI team has done its best to make the software free of
exploitable bugs, but the software has not undergone a rigorous security
audit or intrusion testing. Use at your own risk
It is also possible to create a (semi) public server accessible from
the Internet. The details of how to do this depend very much on your
home or corporate router/firewall system and are beyond the scope of
this document.
If you expose InvokeAI to the Internet, there are a number of
precautions to take. Here is a brief list of recommended network
security practices.
**HTTPS Configuration:**
For internet deployments, always use HTTPS:
```yaml
# Use a reverse proxy like nginx or Traefik
# Example nginx configuration:
server {
listen 443 ssl http2;
server_name invoke.example.com;
ssl_certificate /path/to/cert.pem;
ssl_certificate_key /path/to/key.pem;
location / {
proxy_pass http://localhost:9090;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# WebSocket support
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
}
}
```
**Firewall Rules:**
It is best to restrict access to trusted networks and remote IP
addresses, or use a VPN to connect to your home network. Rate limit
connections to InvokeAI's authentication endpoint
`http://your.host:9090/login`.
**Backup and Recovery:**
It is a good idea to periodically backup your InvokeAI database,
images, and possibly models in the event of unauthorized use of a
publicly-accessible server.
**Manual Backup:**
```bash
# Stop InvokeAI
# Copy database file
cd INVOKE_ROOT
cp databases/invokeai.db databases/invokeai.db.$(date +%Y%m%d)
# Or create compressed backup
tar -czf invokeai_backup_$(date +%Y%m%d).tar.gz databases/
```
**Automated Backup Script:**
```bash
#!/bin/bash
# backup_invokeai.sh
INVOKE_ROOT="/path/to/invoke_root"
BACKUP_DIR="/path/to/backups"
DB_PATH="$INVOKE_ROOT/databases/invokeai.db"
DATE=$(date +%Y%m%d_%H%M%S)
# Create backup directory
mkdir -p "$BACKUP_DIR"
# Copy database
cp "$DB_PATH" "$BACKUP_DIR/invokeai_$DATE.db"
# Keep only last 30 days
find "$BACKUP_DIR" -name "invokeai_*.db" -mtime +30 -delete
echo "Backup completed: invokeai_$DATE.db"
```
**Schedule with cron:**
```bash
# Edit crontab
crontab -e
# Add daily backup at 2 AM
0 2 * * * /path/to/backup_invokeai.sh
```
```bash
# Stop InvokeAI
# Replace current database with backup
cd INVOKE_ROOT
cp databases/invokeai.db databases/invokeai.db.old # Save current
cp databases/invokeai_backup.db databases/invokeai.db
# Restart InvokeAI
invokeai-web
```
**Disaster Recover - Complete System Backup:**
Include these directories/files:
- `databases/` - All database files
- `models/` - Installed models (if locally stored)
- `outputs/` - Generated images
- `invokeai.yaml` - Configuration file
- Any custom scripts or modifications
**Recovery Process:**
1. Install InvokeAI on new system
2. Restore configuration file
3. Restore database directory
4. Restore models and outputs
5. Verify file permissions
6. Start InvokeAI and test
## Troubleshooting
### User Cannot Login
**Symptom:** User reports unable to log in
**Diagnosis:**
1. Verify account exists and is active
```bash
sqlite3 databases/invokeai.db "SELECT * FROM users WHERE email = 'user@example.com';"
```
2. Check password (have user try resetting)
3. Verify account is active (`is_active = 1`)
4. Check for account lockout (if implemented)
**Solutions:**
- Reset user password
- Reactivate disabled account
- Verify email address is correct
- Check system logs for auth errors
### Database Locked Errors
**Symptom:** "Database is locked" errors
**Causes:**
- Concurrent write operations
- Long-running transactions
- Backup process accessing database
- File system issues
**Solutions:**
```bash
# Check for locks
fuser databases/invokeai.db
# Increase timeout (in config)
# Or switch to WAL mode:
sqlite3 databases/invokeai.db "PRAGMA journal_mode=WAL;"
```
### Forgotten Admin Password
**Recovery Process:**
1. Stop InvokeAI
2. Direct database access:
```bash
sqlite3 databases/invokeai.db
```
3. Reset admin password (requires password hash):
```sql
-- Generate hash first using Python:
-- from passlib.context import CryptContext
-- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
-- print(pwd_context.hash("NewPassword123"))
UPDATE users
SET password_hash = '$2b$12$...'
WHERE email = 'admin@example.com';
```
4. Restart InvokeAI
**Alternative:** Remove `jwt_secret_key` from config to trigger setup wizard (will create new admin).
### Performance Issues
**Symptom:** Slow generation or UI
**Diagnosis:**
1. Check active generation count
2. Review resource usage (CPU/GPU/RAM)
3. Check database size and performance
4. Review network latency
**Solutions:**
- Limit concurrent generations
- Increase hardware resources
- Optimize database (`VACUUM`, `ANALYZE`)
- Add indexes for slow queries
- Consider load balancing
### Migration Failures
**Symptom:** Database migration fails on upgrade
**Prevention:**
- Always backup before upgrading
- Test migration on copy of database
- Review migration logs
**Recovery:**
```bash
# Restore backup
cp databases/invokeai.db.backup databases/invokeai.db
# Try migration again with verbose logging
invokeai-web --log-level DEBUG
```
## Configuration Reference
### Complete Configuration Example for a Public Site
```yaml
# invokeai.yaml - Multi-user configuration
# Internal metadata - do not edit:
schema_version: 4.0.2
# Put user settings here
multiuser: true
# Server
host: "0.0.0.0"
port: 9090
# Performance
enable_partial_loading: true
precision: float16
pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"
hashing_algorithm: blake3_multi
```
## Frequently Asked Questions
### How many users can InvokeAI support?
The backend will support dozens of concurrent users. However, because
the image generation queue is single-threaded, image generation tasks
are processed on a first-come, first-serve basis. This means that a
user may have to wait for all the other users' image generation jobs
to complete before their generation job starts to execute.
A future version of InvokeAI may support concurrent execution on
systems with multiple GPUs/graphics cards.
### Can I integrate with existing authentication systems?
OAuth2/OpenID Connect support is planned for a future release. Currently, InvokeAI uses its own authentication system.
### How do I audit user actions?
Full audit logging is planned for a future release. Currently, you can:
- Monitor the generation queue
- Review database changes
- Check application logs
### Can users have different model access?
Not in the current release. All users can view and use all installed models. Per-user model access is a possible enhancement.
### How do I handle user data when they leave?
Best practice:
1. Deactivate the account first
2. Transfer ownership of shared boards
3. After transition period, delete the account
4. Or keep the account deactivated for audit purposes
### What's the licensing impact of multi-user mode?
InvokeAI remains under its existing license. Multi-user mode does not change licensing terms.
## Getting Help
### Support Resources
- **Documentation**: [InvokeAI Docs](https://invoke-ai.github.io/InvokeAI/)
- **Discord**: [Join Community](https://discord.gg/ZmtBAhwWhy)
- **GitHub Issues**: [Report Problems](https://github.com/invoke-ai/InvokeAI/issues)
- **User Guide**: [For Users](user_guide.md)
- **API Guide**: [For Developers](api_guide.md)
### Reporting Issues
When reporting administrator issues, include:
- InvokeAI version
- Operating system and version
- Database size and user count
- Relevant log excerpts
- Steps to reproduce
- Expected vs actual behavior
## Additional Resources
- [User Guide](user_guide.md) - For end users
- [API Guide](api_guide.md) - For API consumers
- [Multiuser Specification](specification.md) - Technical details
---
**Need additional assistance?** Visit the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy) or file an issue on [GitHub](https://github.com/invoke-ai/InvokeAI/issues).

1224
docs/multiuser/api_guide.md Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,870 @@
# InvokeAI Multi-User Support - Detailed Specification
## 1. Executive Summary
This document provides a comprehensive specification for adding multi-user support to InvokeAI. The feature will enable a single InvokeAI instance to support multiple isolated users, each with their own generation settings, image boards, and workflows, while maintaining administrative controls for model management and system configuration.
## 2. Overview
### 2.1 Goals
- Enable multiple users to share a single InvokeAI instance
- Provide user isolation for personal content (boards, images, workflows, settings)
- Maintain centralized model management by administrators
- Support shared boards for collaboration
- Provide secure authentication and authorization
- Minimize impact on existing single-user installations
### 2.2 Non-Goals
- Real-time collaboration features (multiple users editing same workflow simultaneously)
- Advanced team management features (in initial release)
- Migration of existing multi-user enterprise edition data
- Support for external identity providers (in initial release, can be added later)
## 3. User Roles and Permissions
### 3.1 Administrator Role
**Capabilities:**
- Full access to all InvokeAI features
- Model management (add, delete, configure models)
- User management (create, edit, delete users)
- View and manage all users' queue sessions
- Access system configuration
- Create and manage shared boards
- Grant/revoke administrative privileges to other users
**Restrictions:**
- Cannot delete their own account if they are the last administrator
- Cannot revoke their own admin privileges if they are the last administrator
### 3.2 Regular User Role
**Capabilities:**
- Create, edit, and delete their own image boards
- Upload and manage their own assets
- Use all image generation tools (linear, canvas, upscale, workflow tabs)
- Create, edit, save, and load workflows
- Access public/shared workflows
- View and manage their own queue sessions
- Adjust personal UI preferences (theme, hotkeys, etc.)
- Access shared boards (read/write based on permissions)
- **View model configurations** (read-only access to model manager)
- **View model details, default settings, and metadata**
**Restrictions:**
- Cannot add, delete, or edit models
- **Can view but cannot modify model manager settings** (read-only access)
- Cannot reidentify, convert, or update model paths
- Cannot upload or change model thumbnail images
- Cannot save changes to model default settings
- Cannot perform bulk delete operations on models
- Cannot view or modify other users' boards, images, or workflows
- Cannot cancel or modify other users' queue sessions
- Cannot access system configuration
- Cannot manage users or permissions
### 3.3 Future Role Considerations
- **Viewer Role**: Read-only access (future enhancement)
- **Team/Group-based Permissions**: Organizational hierarchy (future enhancement)
## 4. Authentication System
### 4.1 Authentication Method
- **Primary Method**: Username and password authentication with secure password hashing
- **Password Hashing**: Use bcrypt or Argon2 for password storage
- **Session Management**: JWT tokens or secure session cookies
- **Token Expiration**: Configurable session timeout (default: 7 days for "remember me", 24 hours otherwise)
### 4.2 Initial Administrator Setup
**First-time Launch Flow:**
1. Application detects no administrator account exists
2. Displays mandatory setup dialog (cannot be skipped)
3. Prompts for:
- Administrator username (email format recommended)
- Administrator display name
- Strong password (minimum requirements enforced)
- Password confirmation
4. Stores hashed credentials in configuration
5. Creates administrator account in database
6. Proceeds to normal login screen
**Reset Capability:**
- Administrators can be reset by manually editing the config file
- Requires access to server filesystem (intentional security measure)
- Database maintains user records; config file contains root admin credentials
### 4.3 Password Requirements
- Minimum 8 characters
- At least one uppercase letter
- At least one lowercase letter
- At least one number
- At least one special character (optional but recommended)
- Not in common password list
### 4.4 Login Flow
1. User navigates to InvokeAI URL
2. If not authenticated, redirect to login page
3. User enters username/email and password
4. Optional "Remember me" checkbox for extended session
5. Backend validates credentials
6. On success: Generate session token, redirect to application
7. On failure: Display error, allow retry with rate limiting (prevent brute force)
### 4.5 Logout Flow
- User clicks logout button
- Frontend clears session token
- Backend invalidates session (if using server-side sessions)
- Redirect to login page
### 4.6 Future Authentication Enhancements
- OAuth2/OpenID Connect support
- Two-factor authentication (2FA)
- SSO integration
- API key authentication for programmatic access
## 5. User Management
### 5.1 User Creation (Administrator)
**Flow:**
1. Administrator navigates to user management interface
2. Clicks "Add User" button
3. Enters user information:
- Email address (required, used as username)
- Display name (optional, defaults to email)
- Role (User or Administrator)
- Initial password or "Send invitation email"
4. System validates email uniqueness
5. System creates user account
6. If invitation mode:
- Generate one-time secure token
- Send email with setup link
- Link expires after 7 days
7. If direct password mode:
- Administrator provides initial password
- User must change on first login
**Invitation Email Flow:**
1. User receives email with unique link
2. Link contains secure token
3. User clicks link, redirected to setup page
4. User enters desired password
5. Token validated and consumed (single-use)
6. Account activated
7. User redirected to login page
### 5.2 User Profile Management
**User Self-Service:**
- Update display name
- Change password (requires current password)
- Update email address (requires verification)
- Manage UI preferences
- View account creation date and last login
**Administrator Actions:**
- Edit user information (name, email)
- Reset user password (generates reset link)
- Toggle administrator privileges
- Assign to groups (future feature)
- Suspend/unsuspend account
- Delete account (with data retention options)
### 5.3 Password Reset Flow
**User-Initiated (Future Enhancement):**
1. User clicks "Forgot Password" on login page
2. Enters email address
3. System sends password reset link (if email exists)
4. User clicks link, enters new password
5. Password updated, user can login
**Administrator-Initiated:**
1. Administrator selects user
2. Clicks "Send Password Reset"
3. System generates reset token and link
4. Email sent to user
5. User follows same flow as user-initiated reset
## 6. Data Model and Database Schema
### 6.1 New Tables
#### 6.1.1 users
```sql
CREATE TABLE users (
user_id TEXT NOT NULL PRIMARY KEY,
email TEXT NOT NULL UNIQUE,
display_name TEXT,
password_hash TEXT NOT NULL,
is_admin BOOLEAN NOT NULL DEFAULT FALSE,
is_active BOOLEAN NOT NULL DEFAULT TRUE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
last_login_at DATETIME
);
CREATE INDEX idx_users_email ON users(email);
CREATE INDEX idx_users_is_admin ON users(is_admin);
CREATE INDEX idx_users_is_active ON users(is_active);
```
#### 6.1.2 user_sessions
```sql
CREATE TABLE user_sessions (
session_id TEXT NOT NULL PRIMARY KEY,
user_id TEXT NOT NULL,
token_hash TEXT NOT NULL,
expires_at DATETIME NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
last_activity_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
user_agent TEXT,
ip_address TEXT,
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
CREATE INDEX idx_user_sessions_user_id ON user_sessions(user_id);
CREATE INDEX idx_user_sessions_expires_at ON user_sessions(expires_at);
CREATE INDEX idx_user_sessions_token_hash ON user_sessions(token_hash);
```
#### 6.1.3 user_invitations
```sql
CREATE TABLE user_invitations (
invitation_id TEXT NOT NULL PRIMARY KEY,
email TEXT NOT NULL,
token_hash TEXT NOT NULL,
invited_by_user_id TEXT NOT NULL,
expires_at DATETIME NOT NULL,
used_at DATETIME,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
FOREIGN KEY (invited_by_user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
CREATE INDEX idx_user_invitations_email ON user_invitations(email);
CREATE INDEX idx_user_invitations_token_hash ON user_invitations(token_hash);
CREATE INDEX idx_user_invitations_expires_at ON user_invitations(expires_at);
```
#### 6.1.4 shared_boards
```sql
CREATE TABLE shared_boards (
board_id TEXT NOT NULL,
user_id TEXT NOT NULL,
permission TEXT NOT NULL CHECK(permission IN ('read', 'write', 'admin')),
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
PRIMARY KEY (board_id, user_id),
FOREIGN KEY (board_id) REFERENCES boards(board_id) ON DELETE CASCADE,
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
CREATE INDEX idx_shared_boards_user_id ON shared_boards(user_id);
CREATE INDEX idx_shared_boards_board_id ON shared_boards(board_id);
```
### 6.2 Modified Tables
#### 6.2.1 boards
```sql
-- Add columns:
ALTER TABLE boards ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
ALTER TABLE boards ADD COLUMN is_shared BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE boards ADD COLUMN created_by_user_id TEXT;
-- Add foreign key (requires recreation in SQLite):
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
FOREIGN KEY (created_by_user_id) REFERENCES users(user_id) ON DELETE SET NULL
-- Add indices:
CREATE INDEX idx_boards_user_id ON boards(user_id);
CREATE INDEX idx_boards_is_shared ON boards(is_shared);
```
#### 6.2.2 images
```sql
-- Add column:
ALTER TABLE images ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
-- Add foreign key:
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
-- Add index:
CREATE INDEX idx_images_user_id ON images(user_id);
```
#### 6.2.3 workflows
```sql
-- Add columns:
ALTER TABLE workflows ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
ALTER TABLE workflows ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;
-- Add foreign key:
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
-- Add indices:
CREATE INDEX idx_workflows_user_id ON workflows(user_id);
CREATE INDEX idx_workflows_is_public ON workflows(is_public);
```
#### 6.2.4 session_queue
```sql
-- Add column:
ALTER TABLE session_queue ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
-- Add foreign key:
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
-- Add index:
CREATE INDEX idx_session_queue_user_id ON session_queue(user_id);
```
#### 6.2.5 style_presets
```sql
-- Add columns:
ALTER TABLE style_presets ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
ALTER TABLE style_presets ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;
-- Add foreign key:
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
-- Add indices:
CREATE INDEX idx_style_presets_user_id ON style_presets(user_id);
CREATE INDEX idx_style_presets_is_public ON style_presets(is_public);
```
### 6.3 Migration Strategy
1. Create new user tables (users, user_sessions, user_invitations, shared_boards)
2. Create default 'system' user for backward compatibility
3. Update existing data to reference 'system' user
4. Add foreign key constraints
5. Version as database migration (e.g., migration_25.py)
### 6.4 Migration for Existing Installations
- Single-user installations: Prompt to create admin account on first launch after update
- Existing data migration: Administrator can specify an arbitrary user account to hold legacy data (can be the admin account or a separate user)
- System provides UI during migration to choose destination user for existing data
## 7. API Endpoints
### 7.1 Authentication Endpoints
#### POST /api/v1/auth/setup
- Initialize first administrator account
- Only works if no admin exists
- Body: `{ email, display_name, password }`
- Response: `{ success, user }`
#### POST /api/v1/auth/login
- Authenticate user
- Body: `{ email, password, remember_me? }`
- Response: `{ token, user, expires_at }`
#### POST /api/v1/auth/logout
- Invalidate current session
- Headers: `Authorization: Bearer <token>`
- Response: `{ success }`
#### GET /api/v1/auth/me
- Get current user information
- Headers: `Authorization: Bearer <token>`
- Response: `{ user }`
#### POST /api/v1/auth/change-password
- Change current user's password
- Body: `{ current_password, new_password }`
- Headers: `Authorization: Bearer <token>`
- Response: `{ success }`
### 7.2 User Management Endpoints (Admin Only)
#### GET /api/v1/users
- List all users (paginated)
- Query params: `offset`, `limit`, `search`, `role_filter`
- Response: `{ users[], total, offset, limit }`
#### POST /api/v1/users
- Create new user
- Body: `{ email, display_name, is_admin, send_invitation?, initial_password? }`
- Response: `{ user, invitation_link? }`
#### GET /api/v1/users/{user_id}
- Get user details
- Response: `{ user }`
#### PATCH /api/v1/users/{user_id}
- Update user
- Body: `{ display_name?, is_admin?, is_active? }`
- Response: `{ user }`
#### DELETE /api/v1/users/{user_id}
- Delete user
- Query params: `delete_data` (true/false)
- Response: `{ success }`
#### POST /api/v1/users/{user_id}/reset-password
- Send password reset email
- Response: `{ success, reset_link }`
### 7.3 Shared Boards Endpoints
#### POST /api/v1/boards/{board_id}/share
- Share board with users
- Body: `{ user_ids[], permission: 'read' | 'write' | 'admin' }`
- Response: `{ success, shared_with[] }`
#### GET /api/v1/boards/{board_id}/shares
- Get board sharing information
- Response: `{ shares[] }`
#### DELETE /api/v1/boards/{board_id}/share/{user_id}
- Remove board sharing
- Response: `{ success }`
### 7.4 Modified Endpoints
All existing endpoints will be modified to:
1. Require authentication (except setup/login)
2. Filter data by current user (unless admin viewing all)
3. Enforce permissions (e.g., model management requires admin)
4. Include user context in operations
Example modifications:
- `GET /api/v1/boards` → Returns only user's boards + shared boards
- `POST /api/v1/session/queue` → Associates queue item with current user
- `GET /api/v1/queue` → Returns all items for admin, only user's items for regular users
## 8. Frontend Changes
### 8.1 New Components
#### LoginPage
- Email/password form
- "Remember me" checkbox
- Login button
- Forgot password link (future)
- Branding and welcome message
#### AdministratorSetup
- Modal dialog (cannot be dismissed)
- Administrator account creation form
- Password strength indicator
- Terms/welcome message
#### UserManagementPage (Admin only)
- User list table
- Add user button
- User actions (edit, delete, reset password)
- Search and filter
- Role toggle
#### UserProfilePage
- Display user information
- Change password form
- UI preferences
- Account details
#### BoardSharingDialog
- User picker/search
- Permission selector
- Share button
- Current shares list
### 8.2 Modified Components
#### App Root
- Add authentication check
- Redirect to login if not authenticated
- Handle session expiration
- Add global error boundary for auth errors
#### Navigation/Header
- Add user menu with logout
- Display current user name
- Admin indicator badge
#### ModelManagerTab
- Hide/disable for non-admin users
- Show "Admin only" message
#### QueuePanel
- Filter by current user (for non-admin)
- Show all with user indicators (for admin)
- Disable actions on other users' items (for non-admin)
#### BoardsPanel
- Show personal boards section
- Show shared boards section
- Add sharing controls to board actions
### 8.3 State Management
New Redux slices/zustand stores:
- `authSlice`: Current user, authentication status, token
- `usersSlice`: User list for admin interface
- `sharingSlice`: Board sharing state
Updated slices:
- `boardsSlice`: Include shared boards, ownership info
- `queueSlice`: Include user filtering
- `workflowsSlice`: Include public/private status
## 9. Configuration
### 9.1 New Config Options
Add to `InvokeAIAppConfig`:
```python
# Authentication
auth_enabled: bool = True # Enable/disable multi-user auth
session_expiry_hours: int = 24 # Default session expiration
session_expiry_hours_remember: int = 168 # "Remember me" expiration (7 days)
password_min_length: int = 8 # Minimum password length
require_strong_passwords: bool = True # Enforce password complexity
# Session tracking
enable_server_side_sessions: bool = False # Optional server-side session tracking
# Audit logging
audit_log_auth_events: bool = True # Log authentication events
audit_log_admin_actions: bool = True # Log administrative actions
# Email (optional - for invitations and password reset)
email_enabled: bool = False
smtp_host: str = ""
smtp_port: int = 587
smtp_username: str = ""
smtp_password: str = ""
smtp_from_address: str = ""
smtp_from_name: str = "InvokeAI"
# Initial admin (stored as hash)
admin_email: Optional[str] = None
admin_password_hash: Optional[str] = None
```
### 9.2 Backward Compatibility
- If `auth_enabled = False`, system runs in legacy single-user mode
- All data belongs to implicit "system" user
- No authentication required
- Smooth upgrade path for existing installations
## 10. Security Considerations
### 10.1 Password Security
- Never store passwords in plain text
- Use bcrypt or Argon2id for password hashing
- Implement proper salt generation
- Enforce password complexity requirements
- Implement rate limiting on login attempts
- Consider password breach checking (Have I Been Pwned API)
### 10.2 Session Security
- Use cryptographically secure random tokens
- Implement token rotation
- Set appropriate cookie flags (HttpOnly, Secure, SameSite)
- Implement session timeout and renewal
- Invalidate sessions on logout
- Clean up expired sessions periodically
### 10.3 Authorization
- Always verify user identity from session token (never trust client)
- Check permissions on every API call
- Implement principle of least privilege
- Validate user ownership of resources before operations
- Implement proper error messages (avoid information leakage)
### 10.4 Data Isolation
- Strict separation of user data in database queries
- Prevent SQL injection via parameterized queries
- Validate all user inputs
- Implement proper access control checks
- Audit trail for sensitive operations
### 10.5 API Security
- Implement rate limiting on sensitive endpoints
- Use HTTPS in production (enforce via config)
- Implement CSRF protection
- Validate and sanitize all inputs
- Implement proper CORS configuration
- Add security headers (CSP, X-Frame-Options, etc.)
### 10.6 Deployment Security
- Document secure deployment practices
- Recommend reverse proxy configuration (nginx, Apache)
- Provide example configurations for HTTPS
- Document firewall requirements
- Recommend network isolation strategies
## 11. Email Integration (Optional)
**Note**: Email/SMTP configuration is optional. Many administrators will not have ready access to an outgoing SMTP server. When email is not configured, the system provides fallback mechanisms by displaying setup links directly in the admin UI.
### 11.1 Email Templates
#### User Invitation
```
Subject: You've been invited to InvokeAI
Hello,
You've been invited to join InvokeAI by [Administrator Name].
Click the link below to set up your account:
[Setup Link]
This link expires in 7 days.
---
InvokeAI
```
#### Password Reset
```
Subject: Reset your InvokeAI password
Hello [User Name],
A password reset was requested for your account.
Click the link below to reset your password:
[Reset Link]
This link expires in 24 hours.
If you didn't request this, please ignore this email.
---
InvokeAI
```
### 11.2 Email Service
- Support SMTP configuration
- Use secure connection (TLS)
- Handle email failures gracefully
- Implement email queue for reliability
- Log email activities (without sensitive data)
- Provide fallback for no-email deployments (show links in admin UI)
## 12. Testing Requirements
### 12.1 Unit Tests
- Authentication service (password hashing, validation)
- Authorization checks
- Token generation and validation
- User management operations
- Shared board permissions
- Data isolation queries
### 12.2 Integration Tests
- Complete authentication flows
- User creation and invitation
- Password reset flow
- Multi-user data isolation
- Shared board access
- Session management
- Admin operations
### 12.3 Security Tests
- SQL injection prevention
- XSS prevention
- CSRF protection
- Session hijacking prevention
- Brute force protection
- Authorization bypass attempts
### 12.4 Performance Tests
- Authentication overhead
- Query performance with user filters
- Concurrent user sessions
- Database scalability with many users
## 13. Documentation Requirements
### 13.1 User Documentation
- Getting started with multi-user InvokeAI
- Login and account management
- Using shared boards
- Understanding permissions
- Troubleshooting authentication issues
### 13.2 Administrator Documentation
- Setting up multi-user InvokeAI
- User management guide
- Creating and managing shared boards
- Email configuration
- Security best practices
- Backup and restore with user data
### 13.3 Developer Documentation
- Authentication architecture
- API authentication requirements
- Adding new multi-user features
- Database schema changes
- Testing multi-user features
### 13.4 Migration Documentation
- Upgrading from single-user to multi-user
- Data migration strategies
- Rollback procedures
- Common issues and solutions
## 14. Future Enhancements
### 14.1 Phase 2 Features
- **OAuth2/OpenID Connect integration** (deferred from initial release to keep scope manageable)
- Two-factor authentication
- API keys for programmatic access
- Enhanced team/group management
- Advanced permission system (roles and capabilities)
### 14.2 Phase 3 Features
- SSO integration (SAML, LDAP)
- User quotas and limits
- Resource usage tracking
- Advanced collaboration features
- Workflow template library with permissions
- Model access controls per user/group
## 15. Success Metrics
### 15.1 Functionality Metrics
- Successful user authentication rate
- Zero unauthorized data access incidents
- All tests passing (unit, integration, security)
- API response time within acceptable limits
### 15.2 Usability Metrics
- User setup completion time < 2 minutes
- Login time < 2 seconds
- Clear error messages for all auth failures
- Positive user feedback on multi-user features
### 15.3 Security Metrics
- No critical security vulnerabilities identified
- CodeQL scan passes
- Penetration testing completed
- Security best practices followed
## 16. Risks and Mitigations
### 16.1 Technical Risks
| Risk | Impact | Probability | Mitigation |
|------|--------|-------------|------------|
| Performance degradation with user filtering | Medium | Low | Index optimization, query caching |
| Database migration failures | High | Low | Thorough testing, rollback procedures |
| Session management complexity | Medium | Medium | Use proven libraries (PyJWT), extensive testing |
| Auth bypass vulnerabilities | High | Low | Security review, penetration testing |
### 16.2 UX Risks
| Risk | Impact | Probability | Mitigation |
|------|--------|-------------|------------|
| Confusion in migration for existing users | Medium | High | Clear documentation, migration wizard |
| Friction from additional login step | Low | High | Remember me option, long session timeout |
| Complexity of admin interface | Medium | Medium | Intuitive UI design, user testing |
### 16.3 Operational Risks
| Risk | Impact | Probability | Mitigation |
|------|--------|-------------|------------|
| Email delivery failures | Low | Medium | Show links in UI, document manual methods |
| Lost admin password | High | Low | Document recovery procedure, config reset |
| User data conflicts in migration | Medium | Low | Data validation, backup requirements |
## 17. Implementation Phases
### Phase 1: Foundation (Weeks 1-2)
- Database schema design and migration
- Basic authentication service
- Password hashing and validation
- Session management
### Phase 2: Backend API (Weeks 3-4)
- Authentication endpoints
- User management endpoints
- Authorization middleware
- Update existing endpoints with auth
### Phase 3: Frontend Auth (Weeks 5-6)
- Login page and flow
- Administrator setup
- Session management
- Auth state management
### Phase 4: Multi-tenancy (Weeks 7-9)
- User isolation in all services
- Shared boards implementation
- Queue permission filtering
- Workflow public/private
### Phase 5: Admin Interface (Weeks 10-11)
- User management UI
- Board sharing UI
- Admin-specific features
- User profile page
### Phase 6: Testing & Polish (Weeks 12-13)
- Comprehensive testing
- Security audit
- Performance optimization
- Documentation
- Bug fixes
### Phase 7: Beta & Release (Week 14+)
- Beta testing with selected users
- Feedback incorporation
- Final testing
- Release preparation
- Documentation finalization
## 18. Acceptance Criteria
- [ ] Administrator can set up initial account on first launch
- [ ] Users can log in with email and password
- [ ] Users can change their password
- [ ] Administrators can create, edit, and delete users
- [ ] User data is properly isolated (boards, images, workflows)
- [ ] Shared boards work correctly with permissions
- [ ] Non-admin users cannot access model management
- [ ] Queue filtering works correctly for users and admins
- [ ] Session management works correctly (expiry, renewal, logout)
- [ ] All security tests pass
- [ ] API documentation is updated
- [ ] User and admin documentation is complete
- [ ] Migration from single-user works smoothly
- [ ] Performance is acceptable with multiple concurrent users
- [ ] Backward compatibility mode works (auth disabled)
## 19. Design Decisions
The following design decisions have been approved for implementation:
1. **OAuth2 Priority**: OAuth2/OpenID Connect integration will be a **future enhancement**. The initial release will focus on username/password authentication to keep scope manageable.
2. **Email Requirement**: Email/SMTP configuration is **optional**. Many administrators will not have ready access to an outgoing SMTP server. The system will provide fallback mechanisms (showing setup links directly in the admin UI) when email is not configured.
3. **Data Migration**: During migration from single-user to multi-user mode, the administrator will be given the **option to specify an arbitrary user account** to hold legacy data. The admin account can be used for this purpose if the administrator wishes.
4. **API Compatibility**: Authentication will be **required on all APIs**, but authentication will not be required if multi-user support is disabled (backward compatibility mode with `auth_enabled: false`).
5. **Session Storage**: The system will use **JWT tokens with optional server-side session tracking**. This provides scalability while allowing administrators to enable server-side tracking if needed.
6. **Audit Logging**: The system will **log authentication events and admin actions**. This provides accountability and security monitoring for critical operations.
## 20. Conclusion
This specification provides a comprehensive blueprint for implementing multi-user support in InvokeAI. The design prioritizes:
- **Security**: Proper authentication, authorization, and data isolation
- **Usability**: Intuitive UI, smooth migration, minimal friction
- **Scalability**: Efficient database design, performant queries
- **Maintainability**: Clean architecture, comprehensive testing
- **Flexibility**: Future enhancement paths, optional features
The phased implementation approach allows for iterative development and testing, while the detailed specifications ensure all stakeholders have clear expectations of the final system.

View File

@@ -0,0 +1,399 @@
# InvokeAI Multi-User Guide
## Overview
Multi-User mode is a recent feature (introduced in version 6.12), which allows multiple individuals to share a single InvokeAI server while keeping their work separate and organized. Each user has their own username and login password, images, assets, image boards, customization settings and workflows.
Two types of users are recognized:
* A user with **Administrator** status can add, remove and modify other users, and can install models. They also have the ability to view the full session queue and pause or kill other users' jobs.
* **Non-administrator** users can modify their own profile but not others. They also do not have the ability to install or configure models, but must ask an Administrator to do this task.
Multiple users can be granted Administrator status.
***
## Getting Started
To activate Multi-User mode, open the `INVOKEAI_ROOT/invokeai.yaml` configuration file in a text editor. Add this line anywhere in the file:
```yaml
multiuser: true
```
You may also wish to make InvokeAI available to other machines on your local LAN. Add an additional line to `invokeai.yaml`:
```yaml
host: 0.0.0.0
```
Restart the server. It will now be in multi-user mode. If you enabled
the `host` option, other users on your home or office LAN will be able
to reach it by browsing to the IP address of the machine the backend
is running on (`http://host-ip-address:9090`).
!!! tip "Do not expose InvokeAI to the internet"
It is not recommended to expose the InvokeAI host to the internet
due to security concerns.
### Initial Setup (First Time in Multi-User Mode)
If you're the first person to access a fresh InvokeAI installation in multi-user mode, you'll see the **Administrator Setup** dialog:
![Administrator Setup Screen](../../assets/multiuser/admin-setup.png)
Now
1. Enter your email address (this will be your login name)
2. Create a display name (this will be the name other users see)
3. Choose a strong password that meets the requirements:
- At least 8 characters long
- Contains uppercase letters
- Contains lowercase letters
- Contains numbers
4. Confirm your password
5. Click **Create Administrator Account**
You'll now be taken to a login screen and can enter the credentials
you just created.
### Adding and Modifying Users
If you are logged in as Administrator, you can add additional users. Click on the small "person silhouette" icon at the bottom left of the main Invoke screen and select "User Management:"
![Administrator Menu](../../assets/multiuser/admin-add-user-1.png)
This will take you to the User Management screen...
![User Management screen](../../assets/multiuser/admin-add-user-2.png)
...where you can click "Create User" to add a new user.
![Add User Screen](../../assets/multiuser/admin-add-user-3.png)
The User Management screen also allows you to:
1. Temporarily change a user's status to Inactive, preventing them from logging in to Invoke.
2. Edit a user (by clicking on the pencil icon) to change the user's display name or password.
3. Permanently delete a user.
4. Grant a user Administrator privileges.
### Command-line User Management Scripts
Administrators can also use a series of command-line scripts to add, modify, or delete users. If you use the launcher, click the ">" icon to enter the command-line interface. Otherwise, if you are a native command-line user, activate the InvokeAI environment from your terminal.
The commands are named:
* **invoke-useradd** -- add a user
* **invoke-usermod** -- modify a user
* **invoke-userdel** -- delete a user
* **invoke-userlist** -- list all users
Pass the `--help` argument to get the usage of each script. For example:
```bash
> invoke-useradd --help
usage: invoke-useradd [-h] [--root ROOT] [--email EMAIL] [--password PASSWORD] [--name NAME] [--admin]
Add a user to the InvokeAI database
options:
-h, --help show this help message and exit
--root ROOT, -r ROOT Path to the InvokeAI root directory. If omitted, the root is resolved in this order: the $INVOKEAI_ROOT environment
variable, the active virtual environment's parent directory, or $HOME/invokeai.
--email EMAIL, -e EMAIL
User email address
--password PASSWORD, -p PASSWORD
User password
--name NAME, -n NAME User display name (optional)
--admin, -a Make user an administrator
If no arguments are provided, the script will run in interactive mode.
```
***
## Logging in as a Non-Administrative User
If you are a registered user on the system, enter your email address and password to log in. The Administrator will be able to provide you with the values to use:
![Login Screen](../../assets/multiuser/user-login-1.png)
As an unprivileged user you can do pretty much anything that's allowed under single-user mode -- generating images, using LoRAs, creating and running workflows, creating image boards -- but you are restricted against installing new models, changing low-level server settings, or interfering with other users. More information on user roles is given below.
### Changing your Profile
To change your display name or profile, click on the person silhouette icon at the bottom left of the screen and choose "My Profile". This will take you to a screen that lets you change these values. At this time you can change your display name but not your login ID (ordinarily your contact email address).
***
## Understanding User Roles
In single-user mode, you have access to all features without restrictions. In multi-user mode, InvokeAI has two user roles:
### Regular User
As a regular user, you can:
- ✅ Create and manage your own image boards
- ✅ Generate images using all AI tools (Linear, Canvas, Upscale, Workflows)
- ✅ Create, save, and load your own workflows
- ✅ View your own generation queue
- ✅ Customize your UI preferences (theme, hotkeys, etc.)
- ✅ View available models (read-only access to Model Manager)
- ✅ Access shared boards (based on permissions granted to you) (FUTURE FEATURE)
- ✅ Access workflows marked as public (FUTURE FEATURE)
You cannot:
- ❌ Add, delete, or modify models
- ❌ View or modify other users' boards, images, or workflows
- ❌ Manage user accounts
- ❌ Access system configuration
- ❌ View or cancel other users' generation tasks
!!! tip "The generation queue"
When two or more users are accessing InvokeAI at the same time,
their image generation jobs will be placed on the session queue on
a first-come, first-serve basis. This means that you will have to
wait for other users' image rendering jobs to complete before
yours will start.
When another user's job is running, you will see the image
generation progress bar and a queue badge that reads `X/Y`, where
"X" is the number of jobs you have queued and "Y" is the total
number of jobs queued, including your own and others.
You can also pull up the Queue tab in order to see where your job
is in relationship to other queued tasks.
### Administrator
Administrators have all regular user capabilities, plus:
- ✅ Full model management (add, delete, configure models)
- ✅ Create and manage user accounts
- ✅ View and manage all users' generation queues
- ✅ Create and manage shared boards (FUTURE FEATURE)
- ✅ Access system configuration
- ✅ Grant or revoke admin privileges
***
## Working with Your Content in Multi-User Mode
### Image Boards
In multi-user model, Image Boards work as before. Each user can create an unlimited number of boards and organize their images and assets as they see fit. Boards are private: you cannot see a board owned by a different user.
!!! tip "Shared Boards"
InvokeAI 6.13 will add support for creating public boards that are accessible to all users.
The Administrator can see all users Image Boards and their contents.
### Going From Multi-User to Single-User mode
If an InvokeAI instance was in multiuser mode and then restarted in single user mode (by setting `multiuser: false` in the configuration file), all users' boards will be consolidated in one place. Any images that were in "Uncategorized" will be merged together into a single Uncategorized board. If, at a later date, the server is restarted in multi-user mode, the boards and images will be separated and restored to their owners.
### Workflows
In the current released version (6.12) workflows are always shared among users. Any workflow that you create will be visible to other users and vice-versa, and there is no protection against one user modifying another user's workflow.
!!! tip "Private and Shared Workflows"
InvokeAI 6.13 will provide the ability to create private and shared workflows. A private workflow can only be viewed by the user who created it. At any time, however, the user can designate the workflow *shared*, in which case it can be opened on a read-only basis by all logged-in users.
### The Generation Queue
The queue shows your pending and running generation tasks.
**Queue Features:**
- View your current and completed generations
- Cancel pending tasks
- Re-run previous generations
- Monitor progress in real-time
**Queue Isolation:**
- You will see your own queue items, as well as the items generated by
either users, but the generation parameters (e.g. prompts) for other
users' are hidden for privacy reasons.
- Administrators can view all queues for troubleshooting
- Your generations won't interfere with other users' tasks
***
## Customizing Your Experience
### Personal Preferences
Your UI preferences are saved to your account and are restored when you log in:
- **Theme**: Choose between light and dark modes
- **Hotkeys**: Customize keyboard shortcuts
- **Canvas Settings**: Default zoom, grid visibility, etc.
- **Generation Defaults**: Default values for width, height, steps, etc.
These settings are stored per-user and won't affect other users.
***
## Troubleshooting
### Cannot Log In
**Issue:** Login fails with "Incorrect email or password"
**Solutions:**
- Verify you're entering the correct email address
- Check that Caps Lock is off
- Try typing the password slowly to avoid mistakes
- Contact your administrator if you've forgotten your password
**Issue:** Login fails with "Account is disabled"
**Solution:** Contact your administrator to reactivate your account
### Session Expired
**Issue:** You're suddenly logged out and see "Session expired"
**Explanation:** Sessions expire after 24 hours (or 7 days with "remember me")
**Solution:** Simply log in again with your credentials
### Cannot Access Features
**Issue:** Features like Model Manager show "Admin privileges required"
**Explanation:** Some features are restricted to administrators
**Solution:**
- For model viewing: You can view but not modify models
- For user management: Contact an administrator
- For system configuration: Contact an administrator
### Missing Boards or Images
**Issue:** Boards or images you created are not visible
**Possible Causes:**
1. **Filter Applied:** Check if a filter is hiding content
2. **Wrong User:** Ensure you're logged in with the correct account
3. **Archived Board:** Check the "Show Archived" option
**Solution:**
- Clear any active filters
- Verify you're logged in as the right user
- Check archived items
### Slow Performance
**Issue:** Generation or UI feels slower than expected
**Possible Causes:**
- Other users generating images simultaneously
- Server resource limits
- Network latency
**Solutions:**
- Check the queue to see if others are generating
- Wait for current generations to complete
- Contact administrator if persistent
### Generation Stuck in Queue
**Issue:** Your generation is queued but not starting
**Possible Causes:**
- Server is processing other users' generations
- Server resources are fully utilized
- Technical issue with the server
**Solutions:**
- Wait for your turn in the queue
- Check if your generation is paused
- Contact administrator if stuck for extended period
***
## Frequently Asked Questions
### Can other users see my images?
No, unless you add them to a shared board (FUTURE FEATURE). All your personal boards and images are private.
### Can I share my workflows with others?
Not directly. Ask your administrator to mark workflows as public if you want to share them.
### How long do sessions last?
- 24 hours by default
- 7 days if you check "Remember me" during login
### Can I use the API with multi-user mode?
Yes, but you'll need to authenticate with a JWT token. See the [API Guide](api_guide.md) for details.
### What happens if I forget my password?
Contact your administrator. They can reset your password for you.
### Can I have multiple sessions?
Yes, you can log in from multiple devices or browsers simultaneously. All sessions will use the same account and see the same content.
### Why can't I see the Model Manager "Add Models" tab?
Regular users can see the Models tab but with read-only access. Check that you're logged in and try refreshing the page.
### How do I know if I'm an administrator?
Administrators see an "Admin" badge next to their name in the top-right corner and have access to additional features like User Management.
### Can I request admin privileges?
Yes, ask your current administrator to grant you admin
privileges. Admin privileges will give you the ability to see all
other user's boards and images, as well as to add models and change
various server-wide settings.
## Getting Help
### Support Channels
- **Administrator:** Contact your system administrator for account issues
- **Documentation:** Check the [FAQ](../faq.md) for common issues
- **Community:** Join the [Discord](https://discord.gg/ZmtBAhwWhy) for help
- **Bug Reports:** File issues on [GitHub](https://github.com/invoke-ai/InvokeAI/issues)
### Reporting Issues
When reporting an issue, include:
- Your role (regular user or administrator)
- What you were trying to do
- What happened instead
- Any error messages you saw
- Your browser and operating system
## Additional Resources
- [Administrator Guide](admin_guide.md) - For administrators managing users and the system
- [API Guide](api_guide.md) - For developers using the InvokeAI API
- [Multiuser Specification](specification.md) - Technical details about the feature
- [InvokeAI Documentation](../index.md) - Main documentation hub
---
**Need more help?** Contact your administrator or visit the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).

View File

@@ -0,0 +1,166 @@
"""FastAPI dependencies for authentication."""
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.auth.token_service import TokenData, verify_token
from invokeai.backend.util.logging import logging
logger = logging.getLogger(__name__)
# HTTP Bearer token security scheme
security = HTTPBearer(auto_error=False)
async def get_current_user(
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
) -> TokenData:
"""Get current authenticated user from Bearer token.
Note: This function accesses ApiDependencies.invoker.services.users directly,
which is the established pattern in this codebase. The ApiDependencies.invoker
is initialized in the FastAPI lifespan context before any requests are handled.
Args:
credentials: The HTTP authorization credentials containing the Bearer token
Returns:
TokenData containing user information from the token
Raises:
HTTPException: If token is missing, invalid, or expired (401 Unauthorized)
"""
if credentials is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
token = credentials.credentials
token_data = verify_token(token)
if token_data is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired authentication token",
headers={"WWW-Authenticate": "Bearer"},
)
# Verify user still exists and is active
user_service = ApiDependencies.invoker.services.users
user = user_service.get(token_data.user_id)
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account is inactive or does not exist",
headers={"WWW-Authenticate": "Bearer"},
)
return token_data
async def get_current_user_or_default(
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
) -> TokenData:
"""Get current authenticated user from Bearer token, or return a default system user if not authenticated.
This dependency is useful for endpoints that should work in both single-user and multiuser modes.
When multiuser mode is disabled (default), this always returns a system user with admin privileges,
allowing unrestricted access to all operations.
When multiuser mode is enabled, authentication is required and this function validates the token,
returning authenticated user data or raising 401 Unauthorized if no valid credentials are provided.
Args:
credentials: The HTTP authorization credentials containing the Bearer token
Returns:
TokenData containing user information from the token, or system user in single-user mode
Raises:
HTTPException: 401 Unauthorized if in multiuser mode and credentials are missing, invalid, or user is inactive
"""
# Get configuration to check if multiuser is enabled
config = ApiDependencies.invoker.services.configuration
# In single-user mode (multiuser=False), always return system user with admin privileges
if not config.multiuser:
return TokenData(user_id="system", email="system@system.invokeai", is_admin=True)
# Multiuser mode is enabled - validate credentials
if credentials is None:
# In multiuser mode, authentication is required
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
token = credentials.credentials
token_data = verify_token(token)
if token_data is None:
# Invalid token in multiuser mode - reject
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token")
# Verify user still exists and is active
user_service = ApiDependencies.invoker.services.users
user = user_service.get(token_data.user_id)
if user is None or not user.is_active:
# User doesn't exist or is inactive in multiuser mode - reject
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
return token_data
async def require_admin(
current_user: Annotated[TokenData, Depends(get_current_user)],
) -> TokenData:
"""Require admin role for the current user.
Args:
current_user: The current authenticated user's token data
Returns:
The token data if user is an admin
Raises:
HTTPException: If user does not have admin privileges (403 Forbidden)
"""
if not current_user.is_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")
return current_user
async def require_admin_or_default(
current_user: Annotated[TokenData, Depends(get_current_user_or_default)],
) -> TokenData:
"""Require admin role for the current user, or return default system admin in single-user mode.
This dependency is useful for admin-only endpoints that should work in both single-user and multiuser modes.
When multiuser mode is disabled (default), this always returns a system user with admin privileges.
When multiuser mode is enabled, this validates that the authenticated user has admin privileges.
Args:
current_user: The current authenticated user's token data (or default system user)
Returns:
The token data if user is an admin (or system user in single-user mode)
Raises:
HTTPException: If user does not have admin privileges (403 Forbidden) in multiuser mode
"""
if not current_user.is_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")
return current_user
# Type aliases for convenient use in route dependencies
CurrentUser = Annotated[TokenData, Depends(get_current_user)]
CurrentUserOrDefault = Annotated[TokenData, Depends(get_current_user_or_default)]
AdminUser = Annotated[TokenData, Depends(require_admin)]
AdminUserOrDefault = Annotated[TokenData, Depends(require_admin_or_default)]

View File

@@ -5,6 +5,8 @@ from logging import Logger
import torch
from invokeai.app.services.app_settings import AppSettingsService
from invokeai.app.services.auth.token_service import set_jwt_secret
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
from invokeai.app.services.board_images.board_images_default import BoardImagesService
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
@@ -40,6 +42,7 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.users.users_default import UserService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
@@ -101,6 +104,12 @@ class ApiDependencies:
db = init_db(config=config, logger=logger, image_files=image_files)
# Initialize JWT secret from database
app_settings = AppSettingsService(db=db)
jwt_secret = app_settings.get_jwt_secret()
set_jwt_secret(jwt_secret)
logger.info("JWT secret loaded from database")
configuration = config
logger = logger
@@ -155,6 +164,7 @@ class ApiDependencies:
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
users = UserService(db=db)
services = InvocationServices(
board_image_records=board_image_records,
@@ -186,6 +196,7 @@ class ApiDependencies:
style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails,
client_state_persistence=client_state_persistence,
users=users,
)
ApiDependencies.invoker = Invoker(services)

View File

@@ -1,7 +1,9 @@
from typing import Any
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.staticfiles import StaticFiles
from starlette.types import Scope
class NoCacheStaticFiles(StaticFiles):
@@ -12,6 +14,10 @@ class NoCacheStaticFiles(StaticFiles):
Static files include the javascript bundles, fonts, locales, and some images. Generated
images are not included, as they are served by a router.
This class also implements proper SPA (Single Page Application) routing by serving index.html
for any routes that don't match static files, enabling client-side routing to work correctly
in production builds.
"""
def __init__(self, *args: Any, **kwargs: Any):
@@ -26,3 +32,19 @@ class NoCacheStaticFiles(StaticFiles):
resp.headers.setdefault("Pragma", self.pragma)
resp.headers.setdefault("Expires", self.expires)
return resp
async def get_response(self, path: str, scope: Scope) -> Response:
"""
Override get_response to implement SPA routing.
When a file is not found and html mode is enabled, serve index.html instead of raising a 404.
This allows client-side routing to work correctly in SPAs.
"""
try:
return await super().get_response(path, scope)
except HTTPException as exc:
# If the file is not found (404) and html mode is enabled, serve index.html
# This allows client-side routing to handle the path
if exc.status_code == 404 and self.html:
return await super().get_response("index.html", scope)
raise

View File

@@ -0,0 +1,514 @@
"""Authentication endpoints."""
import secrets
import string
from datetime import timedelta
from typing import Annotated
from fastapi import APIRouter, Body, HTTPException, Path, status
from pydantic import BaseModel, Field, field_validator
from invokeai.app.api.auth_dependencies import AdminUser, CurrentUser
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.auth.token_service import TokenData, create_access_token
from invokeai.app.services.users.users_common import (
UserCreateRequest,
UserDTO,
UserUpdateRequest,
validate_email_with_special_domains,
)
auth_router = APIRouter(prefix="/v1/auth", tags=["authentication"])
# Token expiration constants (in days)
TOKEN_EXPIRATION_NORMAL = 1 # 1 day for normal login
TOKEN_EXPIRATION_REMEMBER_ME = 7 # 7 days for "remember me" login
class LoginRequest(BaseModel):
"""Request body for user login."""
email: str = Field(description="User email address")
password: str = Field(description="User password")
remember_me: bool = Field(default=False, description="Whether to extend session duration")
@field_validator("email")
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email address, allowing special-use domains."""
return validate_email_with_special_domains(v)
class LoginResponse(BaseModel):
"""Response from successful login."""
token: str = Field(description="JWT access token")
user: UserDTO = Field(description="User information")
expires_in: int = Field(description="Token expiration time in seconds")
class SetupRequest(BaseModel):
"""Request body for initial admin setup."""
email: str = Field(description="Admin email address")
display_name: str | None = Field(default=None, description="Admin display name")
password: str = Field(description="Admin password")
@field_validator("email")
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email address, allowing special-use domains."""
return validate_email_with_special_domains(v)
class SetupResponse(BaseModel):
"""Response from successful admin setup."""
success: bool = Field(description="Whether setup was successful")
user: UserDTO = Field(description="Created admin user information")
class LogoutResponse(BaseModel):
"""Response from logout."""
success: bool = Field(description="Whether logout was successful")
class SetupStatusResponse(BaseModel):
"""Response for setup status check."""
setup_required: bool = Field(description="Whether initial setup is required")
multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled")
@auth_router.get("/status", response_model=SetupStatusResponse)
async def get_setup_status() -> SetupStatusResponse:
"""Check if initial administrator setup is required.
Returns:
SetupStatusResponse indicating whether setup is needed and multiuser mode status
"""
config = ApiDependencies.invoker.services.configuration
# If multiuser is disabled, setup is never required
if not config.multiuser:
return SetupStatusResponse(setup_required=False, multiuser_enabled=False)
# In multiuser mode, check if an admin exists
user_service = ApiDependencies.invoker.services.users
setup_required = not user_service.has_admin()
return SetupStatusResponse(setup_required=setup_required, multiuser_enabled=True)
@auth_router.post("/login", response_model=LoginResponse)
async def login(
request: Annotated[LoginRequest, Body(description="Login credentials")],
) -> LoginResponse:
"""Authenticate user and return access token.
Args:
request: Login credentials (email and password)
Returns:
LoginResponse containing JWT token and user information
Raises:
HTTPException: 401 if credentials are invalid or user is inactive
HTTPException: 403 if multiuser mode is disabled
"""
config = ApiDependencies.invoker.services.configuration
# Check if multiuser is enabled
if not config.multiuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Multiuser mode is disabled. Authentication is not required in single-user mode.",
)
user_service = ApiDependencies.invoker.services.users
user = user_service.authenticate(request.email, request.password)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"},
)
if not user.is_active:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User account is disabled")
# Create token with appropriate expiration
expires_delta = timedelta(days=TOKEN_EXPIRATION_REMEMBER_ME if request.remember_me else TOKEN_EXPIRATION_NORMAL)
token_data = TokenData(
user_id=user.user_id,
email=user.email,
is_admin=user.is_admin,
)
token = create_access_token(token_data, expires_delta)
return LoginResponse(
token=token,
user=user,
expires_in=int(expires_delta.total_seconds()),
)
@auth_router.post("/logout", response_model=LogoutResponse)
async def logout(
current_user: CurrentUser,
) -> LogoutResponse:
"""Logout current user.
Currently a no-op since we use stateless JWT tokens. For token invalidation in
future implementations, consider:
- Token blacklist: Store invalidated tokens in Redis/database with expiration
- Token versioning: Add version field to user record, increment on logout
- Short-lived tokens: Use refresh token pattern with token rotation
- Session storage: Track active sessions server-side for revocation
Args:
current_user: The authenticated user (validates token)
Returns:
LogoutResponse indicating success
"""
# TODO: Implement token invalidation when server-side session management is added
# For now, this is a no-op since we use stateless JWT tokens
return LogoutResponse(success=True)
@auth_router.get("/me", response_model=UserDTO)
async def get_current_user_info(
current_user: CurrentUser,
) -> UserDTO:
"""Get current authenticated user's information.
Args:
current_user: The authenticated user's token data
Returns:
UserDTO containing user information
Raises:
HTTPException: 404 if user is not found (should not happen normally)
"""
user_service = ApiDependencies.invoker.services.users
user = user_service.get(current_user.user_id)
if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
return user
@auth_router.post("/setup", response_model=SetupResponse)
async def setup_admin(
request: Annotated[SetupRequest, Body(description="Admin account details")],
) -> SetupResponse:
"""Set up initial administrator account.
This endpoint can only be called once, when no admin user exists. It creates
the first admin user for the system.
Args:
request: Admin account details (email, display_name, password)
Returns:
SetupResponse containing the created admin user
Raises:
HTTPException: 400 if admin already exists or password is weak
HTTPException: 403 if multiuser mode is disabled
"""
config = ApiDependencies.invoker.services.configuration
# Check if multiuser is enabled
if not config.multiuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Multiuser mode is disabled. Admin setup is not required in single-user mode.",
)
user_service = ApiDependencies.invoker.services.users
# Check if any admin exists
if user_service.has_admin():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Administrator account already configured",
)
# Create admin user - this will validate password strength
try:
user_data = UserCreateRequest(
email=request.email,
display_name=request.display_name,
password=request.password,
is_admin=True,
)
user = user_service.create_admin(user_data)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
return SetupResponse(success=True, user=user)
# ---------------------------------------------------------------------------
# User management models
# ---------------------------------------------------------------------------
_PASSWORD_ALPHABET = string.ascii_letters + string.digits + string.punctuation
class AdminUserCreateRequest(BaseModel):
"""Request body for admin to create a new user."""
email: str = Field(description="User email address")
display_name: str | None = Field(default=None, description="Display name")
password: str = Field(description="User password")
is_admin: bool = Field(default=False, description="Whether user should have admin privileges")
@field_validator("email")
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email address, allowing special-use domains."""
return validate_email_with_special_domains(v)
class AdminUserUpdateRequest(BaseModel):
"""Request body for admin to update any user."""
display_name: str | None = Field(default=None, description="Display name")
password: str | None = Field(default=None, description="New password")
is_admin: bool | None = Field(default=None, description="Whether user should have admin privileges")
is_active: bool | None = Field(default=None, description="Whether user account should be active")
class UserProfileUpdateRequest(BaseModel):
"""Request body for a user to update their own profile."""
display_name: str | None = Field(default=None, description="New display name")
current_password: str | None = Field(default=None, description="Current password (required when changing password)")
new_password: str | None = Field(default=None, description="New password")
class GeneratePasswordResponse(BaseModel):
"""Response containing a generated password."""
password: str = Field(description="Generated strong password")
# ---------------------------------------------------------------------------
# User management endpoints
# ---------------------------------------------------------------------------
@auth_router.get("/generate-password", response_model=GeneratePasswordResponse)
async def generate_password(
current_user: CurrentUser,
) -> GeneratePasswordResponse:
"""Generate a strong random password.
Returns a cryptographically secure random password of 16 characters
containing uppercase, lowercase, digits, and punctuation.
"""
# Ensure the generated password always meets strength requirements:
# at least one uppercase, one lowercase, one digit, one special char.
while True:
password = "".join(secrets.choice(_PASSWORD_ALPHABET) for _ in range(16))
if (
any(c.isupper() for c in password)
and any(c.islower() for c in password)
and any(c.isdigit() for c in password)
):
return GeneratePasswordResponse(password=password)
@auth_router.get("/users", response_model=list[UserDTO])
async def list_users(
current_user: AdminUser,
) -> list[UserDTO]:
"""List all users. Requires admin privileges.
The internal 'system' user (created for backward compatibility) is excluded
from the results since it cannot be managed through this interface.
Returns:
List of all real users (system user excluded)
"""
user_service = ApiDependencies.invoker.services.users
return [u for u in user_service.list_users() if u.user_id != "system"]
@auth_router.post("/users", response_model=UserDTO, status_code=status.HTTP_201_CREATED)
async def create_user(
request: Annotated[AdminUserCreateRequest, Body(description="New user details")],
current_user: AdminUser,
) -> UserDTO:
"""Create a new user. Requires admin privileges.
Args:
request: New user details
Returns:
The created user
Raises:
HTTPException: 400 if email already exists or password is weak
"""
user_service = ApiDependencies.invoker.services.users
try:
user_data = UserCreateRequest(
email=request.email,
display_name=request.display_name,
password=request.password,
is_admin=request.is_admin,
)
return user_service.create(user_data)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
@auth_router.get("/users/{user_id}", response_model=UserDTO)
async def get_user(
user_id: Annotated[str, Path(description="User ID")],
current_user: AdminUser,
) -> UserDTO:
"""Get a user by ID. Requires admin privileges.
Args:
user_id: The user ID
Returns:
The user
Raises:
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users
user = user_service.get(user_id)
if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
return user
@auth_router.patch("/users/{user_id}", response_model=UserDTO)
async def update_user(
user_id: Annotated[str, Path(description="User ID")],
request: Annotated[AdminUserUpdateRequest, Body(description="User fields to update")],
current_user: AdminUser,
) -> UserDTO:
"""Update a user. Requires admin privileges.
Args:
user_id: The user ID
request: Fields to update
Returns:
The updated user
Raises:
HTTPException: 400 if password is weak
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users
try:
changes = UserUpdateRequest(
display_name=request.display_name,
password=request.password,
is_admin=request.is_admin,
is_active=request.is_active,
)
return user_service.update(user_id, changes)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
@auth_router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
user_id: Annotated[str, Path(description="User ID")],
current_user: AdminUser,
) -> None:
"""Delete a user. Requires admin privileges.
Admins can delete any user including other admins, but cannot delete the last
remaining admin.
Args:
user_id: The user ID
Raises:
HTTPException: 400 if attempting to delete the last admin
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users
user = user_service.get(user_id)
if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
# Prevent deleting the last active admin
if user.is_admin and user.is_active and user_service.count_admins() <= 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot delete the last administrator",
)
try:
user_service.delete(user_id)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
@auth_router.patch("/me", response_model=UserDTO)
async def update_current_user(
request: Annotated[UserProfileUpdateRequest, Body(description="Profile fields to update")],
current_user: CurrentUser,
) -> UserDTO:
"""Update the current user's own profile.
To change the password, both ``current_password`` and ``new_password`` must
be provided. The current password is verified before the change is applied.
Args:
request: Profile fields to update
current_user: The authenticated user
Returns:
The updated user
Raises:
HTTPException: 400 if current password is incorrect or new password is weak
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users
# Verify current password when attempting a password change
if request.new_password is not None:
if not request.current_password:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is required to set a new password",
)
# Re-authenticate to verify the current password
user = user_service.get(current_user.user_id)
if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
authenticated = user_service.authenticate(user.email, request.current_password)
if authenticated is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect",
)
try:
changes = UserUpdateRequest(
display_name=request.display_name,
password=request.new_password,
)
return user_service.update(current_user.user_id, changes)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

View File

@@ -4,6 +4,7 @@ from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.boards.boards_common import BoardDTO
@@ -32,11 +33,12 @@ class DeleteBoardResult(BaseModel):
response_model=BoardDTO,
)
async def create_board(
current_user: CurrentUserOrDefault,
board_name: str = Query(description="The name of the board to create", max_length=300),
) -> BoardDTO:
"""Creates a board"""
"""Creates a board for the current user"""
try:
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
result = ApiDependencies.invoker.services.boards.create(board_name=board_name, user_id=current_user.user_id)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to create board")
@@ -44,16 +46,21 @@ async def create_board(
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
async def get_board(
current_user: CurrentUserOrDefault,
board_id: str = Path(description="The id of board to get"),
) -> BoardDTO:
"""Gets a board"""
"""Gets a board (user must have access to it)"""
try:
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
return result
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if not current_user.is_admin and result.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to access this board")
return result
@boards_router.patch(
"/{board_id}",
@@ -67,10 +74,19 @@ async def get_board(
response_model=BoardDTO,
)
async def update_board(
current_user: CurrentUserOrDefault,
board_id: str = Path(description="The id of board to update"),
changes: BoardChanges = Body(description="The changes to apply to the board"),
) -> BoardDTO:
"""Updates a board"""
"""Updates a board (user must have access to it)"""
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if not current_user.is_admin and board.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this board")
try:
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
return result
@@ -80,10 +96,19 @@ async def update_board(
@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
async def delete_board(
current_user: CurrentUserOrDefault,
board_id: str = Path(description="The id of board to delete"),
include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
) -> DeleteBoardResult:
"""Deletes a board"""
"""Deletes a board (user must have access to it)"""
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if not current_user.is_admin and board.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to delete this board")
try:
if include_images is True:
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
@@ -120,6 +145,7 @@ async def delete_board(
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
)
async def list_boards(
current_user: CurrentUserOrDefault,
order_by: BoardRecordOrderBy = Query(default=BoardRecordOrderBy.CreatedAt, description="The attribute to order by"),
direction: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The direction to order by"),
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
@@ -127,11 +153,15 @@ async def list_boards(
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
"""Gets a list of boards for the current user, including shared boards. Admin users see all boards."""
if all:
return ApiDependencies.invoker.services.boards.get_all(order_by, direction, include_archived)
return ApiDependencies.invoker.services.boards.get_all(
current_user.user_id, current_user.is_admin, order_by, direction, include_archived
)
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(order_by, direction, offset, limit, include_archived)
return ApiDependencies.invoker.services.boards.get_many(
current_user.user_id, current_user.is_admin, order_by, direction, offset, limit, include_archived
)
else:
raise HTTPException(
status_code=400,
@@ -145,12 +175,22 @@ async def list_boards(
response_model=list[str],
)
async def list_all_board_image_names(
current_user: CurrentUserOrDefault,
board_id: str = Path(description="The id of the board or 'none' for uncategorized images"),
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
) -> list[str]:
"""Gets a list of images for a board"""
if board_id != "none":
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if not current_user.is_admin and board.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to access this board")
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id,
categories,

View File

@@ -1,6 +1,7 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.util.logging import logging
@@ -13,15 +14,16 @@ client_state_router = APIRouter(prefix="/v1/client_state", tags=["client_state"]
response_model=str | None,
)
async def get_client_state_by_key(
queue_id: str = Path(description="The queue id to perform this operation on"),
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
key: str = Query(..., description="Key to get"),
) -> str | None:
"""Gets the client state"""
"""Gets the client state for the current user (or system user if not authenticated)"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(current_user.user_id, key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
raise HTTPException(status_code=500, detail="Error getting client state")
@client_state_router.post(
@@ -30,13 +32,14 @@ async def get_client_state_by_key(
response_model=str,
)
async def set_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
key: str = Query(..., description="Key to set"),
value: str = Body(..., description="Stringified value to set"),
) -> str:
"""Sets the client state"""
"""Sets the client state for the current user (or system user if not authenticated)"""
try:
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(current_user.user_id, key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@@ -48,11 +51,12 @@ async def set_client_state(
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
) -> None:
"""Deletes the client state"""
"""Deletes the client state for the current user (or system user if not authenticated)"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
ApiDependencies.invoker.services.client_state_persistence.delete(current_user.user_id)
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")

View File

@@ -9,6 +9,7 @@ from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field, model_validator
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_image
from invokeai.app.invocations.fields import MetadataField
@@ -61,6 +62,7 @@ class ResizeToDimensions(BaseModel):
response_model=ImageDTO,
)
async def upload_image(
current_user: CurrentUserOrDefault,
file: UploadFile,
request: Request,
response: Response,
@@ -80,7 +82,7 @@ async def upload_image(
embed=True,
),
) -> ImageDTO:
"""Uploads an image"""
"""Uploads an image for the current user"""
if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@@ -133,6 +135,7 @@ async def upload_image(
workflow=extracted_metadata.invokeai_workflow,
graph=extracted_metadata.invokeai_graph,
is_intermediate=is_intermediate,
user_id=current_user.user_id,
)
response.status_code = 201
@@ -373,6 +376,7 @@ async def get_image_urls(
response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_image_dtos(
current_user: CurrentUserOrDefault,
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
@@ -386,10 +390,19 @@ async def list_image_dtos(
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs"""
"""Gets a list of image DTOs for the current user"""
image_dtos = ApiDependencies.invoker.services.images.get_many(
offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
offset,
limit,
starred_first,
order_dir,
image_origin,
categories,
is_intermediate,
board_id,
search_term,
current_user.user_id,
)
return image_dtos
@@ -567,6 +580,7 @@ async def get_bulk_download_item(
@images_router.get("/names", operation_id="get_image_names")
async def get_image_names(
current_user: CurrentUserOrDefault,
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
@@ -589,6 +603,8 @@ async def get_image_names(
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
user_id=current_user.user_id,
is_admin=current_user.is_admin,
)
return result
except Exception:

View File

@@ -19,6 +19,7 @@ from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.api.auth_dependencies import AdminUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
@@ -229,6 +230,7 @@ async def get_model_record(
)
async def reidentify_model(
key: Annotated[str, Path(description="Key of the model to reidentify.")],
current_admin: AdminUserOrDefault,
) -> AnyModelConfig:
"""Attempt to reidentify a model by re-probing its weights file."""
try:
@@ -244,11 +246,13 @@ async def reidentify_model(
raise InvalidModelException("Unable to identify model format")
# Retain user-editable fields from the original config
result.config.path = config.path
result.config.key = config.key
result.config.name = config.name
result.config.description = config.description
result.config.cover_image = config.cover_image
result.config.trigger_phrases = config.trigger_phrases
if hasattr(config, "trigger_phrases") and hasattr(result.config, "trigger_phrases"):
result.config.trigger_phrases = config.trigger_phrases
result.config.source = config.source
result.config.source_type = config.source_type
@@ -364,6 +368,7 @@ async def get_hugging_face_models(
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
current_admin: AdminUserOrDefault,
) -> AnyModelConfig:
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
@@ -426,6 +431,7 @@ async def get_model_image(
async def update_model_image(
key: Annotated[str, Path(description="Unique key of model")],
image: UploadFile,
current_admin: AdminUserOrDefault,
) -> None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@@ -459,6 +465,7 @@ async def update_model_image(
status_code=204,
)
async def delete_model(
current_admin: AdminUserOrDefault,
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""
@@ -501,6 +508,7 @@ class BulkDeleteModelsResponse(BaseModel):
status_code=200,
)
async def bulk_delete_models(
current_admin: AdminUserOrDefault,
request: BulkDeleteModelsRequest = Body(description="List of model keys to delete"),
) -> BulkDeleteModelsResponse:
"""
@@ -542,6 +550,7 @@ async def bulk_delete_models(
status_code=204,
)
async def delete_model_image(
current_admin: AdminUserOrDefault,
key: str = Path(description="Unique key of model image to remove from model_images directory."),
) -> None:
logger = ApiDependencies.invoker.services.logger
@@ -567,6 +576,7 @@ async def delete_model_image(
status_code=201,
)
async def install_model(
current_admin: AdminUserOrDefault,
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
@@ -637,6 +647,7 @@ async def install_model(
response_class=HTMLResponse,
)
async def install_hugging_face_model(
current_admin: AdminUserOrDefault,
source: str = Query(description="HuggingFace repo_id to install"),
) -> HTMLResponse:
"""Install a Hugging Face model using a string identifier."""
@@ -809,7 +820,10 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
},
status_code=201,
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
async def cancel_model_install_job(
current_admin: AdminUserOrDefault,
id: int = Path(description="Model install job ID"),
) -> None:
"""Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -910,7 +924,7 @@ async def restart_model_install_file(
400: {"description": "Bad request"},
},
)
async def prune_model_install_jobs() -> Response:
async def prune_model_install_jobs(current_admin: AdminUserOrDefault) -> Response:
"""Prune all completed and errored jobs from the install job list."""
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
return Response(status_code=204)
@@ -930,6 +944,7 @@ async def prune_model_install_jobs() -> Response:
},
)
async def convert_model(
current_admin: AdminUserOrDefault,
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
) -> AnyModelConfig:
"""
@@ -1111,7 +1126,7 @@ async def get_stats() -> Optional[CacheStats]:
operation_id="empty_model_cache",
status_code=200,
)
async def empty_model_cache() -> None:
async def empty_model_cache(current_admin: AdminUserOrDefault) -> None:
"""Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped."""
# Request 1000GB of room in order to force the cache to drop all models.
ApiDependencies.invoker.services.logger.info("Emptying model cache.")
@@ -1128,11 +1143,11 @@ class HFTokenHelper:
@classmethod
def get_status(cls) -> HFTokenStatus:
try:
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
# Valid token!
return HFTokenStatus.VALID
# No token set
return HFTokenStatus.INVALID
token = huggingface_hub.get_token()
if not token:
return HFTokenStatus.INVALID
huggingface_hub.whoami(token=token)
return HFTokenStatus.VALID
except Exception:
return HFTokenStatus.UNKNOWN
@@ -1161,6 +1176,7 @@ async def get_hf_login_status() -> HFTokenStatus:
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
async def do_hf_login(
current_admin: AdminUserOrDefault,
token: str = Body(description="Hugging Face token to use for login", embed=True),
) -> HFTokenStatus:
HFTokenHelper.set_token(token)
@@ -1173,7 +1189,7 @@ async def do_hf_login(
@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
async def reset_hf_token() -> HFTokenStatus:
async def reset_hf_token(current_admin: AdminUserOrDefault) -> HFTokenStatus:
return HFTokenHelper.reset_token()

View File

@@ -4,6 +4,7 @@ from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel
from invokeai.app.api.auth_dependencies import AdminUserOrDefault, CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
@@ -24,6 +25,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemNotFoundError,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
@@ -36,6 +38,40 @@ class SessionQueueAndProcessorStatus(BaseModel):
processor: SessionProcessorStatus
def sanitize_queue_item_for_user(
queue_item: SessionQueueItem, current_user_id: str, is_admin: bool
) -> SessionQueueItem:
"""Sanitize queue item for non-admin users viewing other users' items.
For non-admin users viewing queue items belonging to other users,
the field_values, session graph, and workflow should be hidden/cleared to protect privacy.
Args:
queue_item: The queue item to sanitize
current_user_id: The ID of the current user viewing the item
is_admin: Whether the current user is an admin
Returns:
The sanitized queue item (sensitive fields cleared if necessary)
"""
# Admins and item owners can see everything
if is_admin or queue_item.user_id == current_user_id:
return queue_item
# For non-admins viewing other users' items, clear sensitive fields
# Create a shallow copy to avoid mutating the original
sanitized_item = queue_item.model_copy(deep=False)
sanitized_item.field_values = None
sanitized_item.workflow = None
# Clear the session graph by replacing it with an empty graph execution state
# This prevents information leakage through the generation graph
sanitized_item.session = GraphExecutionState(
id=queue_item.session.id,
graph=Graph(),
)
return sanitized_item
@session_queue_router.post(
"/{queue_id}/enqueue_batch",
operation_id="enqueue_batch",
@@ -44,14 +80,15 @@ class SessionQueueAndProcessorStatus(BaseModel):
},
)
async def enqueue_batch(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
batch: Batch = Body(description="Batch to process"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
"""Processes a batch and enqueues the output graphs for execution for the current user."""
try:
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
queue_id=queue_id, batch=batch, prepend=prepend, user_id=current_user.user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
@@ -65,15 +102,18 @@ async def enqueue_batch(
},
)
async def list_all_queue_items(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
items = ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
# Sanitize items for non-admin users
return [sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin) for item in items]
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
@@ -102,6 +142,7 @@ async def get_queue_item_ids(
responses={200: {"model": list[SessionQueueItem]}},
)
async def get_queue_items_by_item_ids(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
item_ids: list[int] = Body(
embed=True, description="Object containing list of queue item ids to fetch queue items for"
@@ -118,7 +159,9 @@ async def get_queue_items_by_item_ids(
queue_item = session_queue_service.get_queue_item(item_id=item_id)
if queue_item.queue_id != queue_id: # Auth protection for items from other queues
continue
queue_items.append(queue_item)
# Sanitize item for non-admin users
sanitized_item = sanitize_queue_item_for_user(queue_item, current_user.user_id, current_user.is_admin)
queue_items.append(sanitized_item)
except Exception:
# Skip missing queue items - they may have been deleted between item id fetch and queue item fetch
continue
@@ -134,9 +177,10 @@ async def get_queue_items_by_item_ids(
responses={200: {"model": SessionProcessorStatus}},
)
async def resume(
current_user: AdminUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Resumes session processor"""
"""Resumes session processor. Admin only."""
try:
return ApiDependencies.invoker.services.session_processor.resume()
except Exception as e:
@@ -148,10 +192,11 @@ async def resume(
operation_id="pause",
responses={200: {"model": SessionProcessorStatus}},
)
async def Pause(
async def pause(
current_user: AdminUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Pauses session processor"""
"""Pauses session processor. Admin only."""
try:
return ApiDependencies.invoker.services.session_processor.pause()
except Exception as e:
@@ -164,11 +209,16 @@ async def Pause(
responses={200: {"model": CancelAllExceptCurrentResult}},
)
async def cancel_all_except_current(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> CancelAllExceptCurrentResult:
"""Immediately cancels all queue items except in-processing items"""
"""Immediately cancels all queue items except in-processing items. Non-admin users can only cancel their own items."""
try:
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
# Admin users can cancel all items, non-admin users can only cancel their own
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(
queue_id=queue_id, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
@@ -179,11 +229,16 @@ async def cancel_all_except_current(
responses={200: {"model": DeleteAllExceptCurrentResult}},
)
async def delete_all_except_current(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> DeleteAllExceptCurrentResult:
"""Immediately deletes all queue items except in-processing items"""
"""Immediately deletes all queue items except in-processing items. Non-admin users can only delete their own items."""
try:
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
# Admin users can delete all items, non-admin users can only delete their own
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(
queue_id=queue_id, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
@@ -194,13 +249,16 @@ async def delete_all_except_current(
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_batch_ids(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
"""Immediately cancels all queue items from the given batch ids. Non-admin users can only cancel their own items."""
try:
# Admin users can cancel all items, non-admin users can only cancel their own
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
queue_id=queue_id, batch_ids=batch_ids
queue_id=queue_id, batch_ids=batch_ids, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
@@ -212,13 +270,16 @@ async def cancel_by_batch_ids(
responses={200: {"model": CancelByDestinationResult}},
)
async def cancel_by_destination(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
destination: str = Query(description="The destination to cancel all queue items for"),
) -> CancelByDestinationResult:
"""Immediately cancels all queue items with the given origin"""
"""Immediately cancels all queue items with the given destination. Non-admin users can only cancel their own items."""
try:
# Admin users can cancel all items, non-admin users can only cancel their own
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
queue_id=queue_id, destination=destination, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
@@ -230,12 +291,28 @@ async def cancel_by_destination(
responses={200: {"model": RetryItemsResult}},
)
async def retry_items_by_id(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
item_ids: list[int] = Body(description="The queue item ids to retry"),
) -> RetryItemsResult:
"""Immediately cancels all queue items with the given origin"""
"""Retries the given queue items. Users can only retry their own items unless they are an admin."""
try:
# Check authorization: user must own all items or be an admin
if not current_user.is_admin:
for item_id in item_ids:
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
if queue_item.user_id != current_user.user_id:
raise HTTPException(
status_code=403, detail=f"You do not have permission to retry queue item {item_id}"
)
except SessionQueueItemNotFoundError:
# Skip items that don't exist - they will be handled by retry_items_by_id
continue
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
@@ -248,15 +325,25 @@ async def retry_items_by_id(
},
)
async def clear(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
"""Clears the queue entirely. Admin users clear all items; non-admin users only clear their own items. If there's a currently-executing item, users can only cancel it if they own it or are an admin."""
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
# Check authorization for canceling the current item
if queue_item.user_id != current_user.user_id and not current_user.is_admin:
raise HTTPException(
status_code=403, detail="You do not have permission to cancel the currently executing queue item"
)
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
# Admin users can clear all items, non-admin users can only clear their own
user_id = None if current_user.is_admin else current_user.user_id
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id, user_id=user_id)
return clear_result
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
@@ -269,11 +356,14 @@ async def clear(
},
)
async def prune(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
"""Prunes all completed or errored queue items. Non-admin users can only prune their own items."""
try:
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
# Admin users can prune all items, non-admin users can only prune their own
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.prune(queue_id, user_id=user_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
@@ -320,11 +410,12 @@ async def get_next_queue_item(
},
)
async def get_queue_status(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
@@ -358,6 +449,7 @@ async def get_batch_status(
response_model_exclude_none=True,
)
async def get_queue_item(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to get"),
) -> SessionQueueItem:
@@ -366,7 +458,8 @@ async def get_queue_item(
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id=item_id)
if queue_item.queue_id != queue_id:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
return queue_item
# Sanitize item for non-admin users
return sanitize_queue_item_for_user(queue_item, current_user.user_id, current_user.is_admin)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
@@ -378,12 +471,24 @@ async def get_queue_item(
operation_id="delete_queue_item",
)
async def delete_queue_item(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to delete"),
) -> None:
"""Deletes a queue item"""
"""Deletes a queue item. Users can only delete their own items unless they are an admin."""
try:
# Get the queue item to check ownership
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
# Check authorization: user must own the item or be an admin
if queue_item.user_id != current_user.user_id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="You do not have permission to delete this queue item")
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
@@ -396,14 +501,24 @@ async def delete_queue_item(
},
)
async def cancel_queue_item(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Deletes a queue item"""
"""Cancels a queue item. Users can only cancel their own items unless they are an admin."""
try:
# Get the queue item to check ownership
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
# Check authorization: user must own the item or be an admin
if queue_item.user_id != current_user.user_id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="You do not have permission to cancel this queue item")
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
@@ -432,13 +547,16 @@ async def counts_by_destination(
responses={200: {"model": DeleteByDestinationResult}},
)
async def delete_by_destination(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to query"),
destination: str = Path(description="The destination to query"),
) -> DeleteByDestinationResult:
"""Deletes all items with the given destination"""
"""Deletes all items with the given destination. Non-admin users can only delete their own items."""
try:
# Admin users can delete all items, non-admin users can only delete their own
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
queue_id=queue_id, destination=destination, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")

View File

@@ -6,6 +6,7 @@ from fastapi import FastAPI
from pydantic import BaseModel
from socketio import ASGIApp, AsyncServer
from invokeai.app.services.auth.token_service import verify_token
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
@@ -38,6 +39,9 @@ from invokeai.app.services.events.events_common import (
RecallParametersUpdatedEvent,
register_events,
)
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
class QueueSubscriptionEvent(BaseModel):
@@ -96,6 +100,13 @@ class SocketIO:
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
app.mount("/ws", self._app)
# Track user information for each socket connection
self._socket_users: dict[str, dict[str, Any]] = {}
# Set up authentication middleware
self._sio.on("connect", handler=self._handle_connect)
self._sio.on("disconnect", handler=self._handle_disconnect)
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
@@ -105,8 +116,83 @@ class SocketIO:
register_events(MODEL_EVENTS, self._handle_model_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> bool:
"""Handle socket connection and authenticate the user.
Returns True to accept the connection, False to reject it.
Stores user_id in the internal socket users dict for later use.
"""
# Extract token from auth data or headers
token = None
if auth and isinstance(auth, dict):
token = auth.get("token")
if not token and environ:
# Try to get token from headers
headers = environ.get("HTTP_AUTHORIZATION", "")
if headers.startswith("Bearer "):
token = headers[7:]
# Verify the token
if token:
token_data = verify_token(token)
if token_data:
# Store user_id and is_admin in socket users dict
self._socket_users[sid] = {
"user_id": token_data.user_id,
"is_admin": token_data.is_admin,
}
logger.info(
f"Socket {sid} connected with user_id: {token_data.user_id}, is_admin: {token_data.is_admin}"
)
return True
# If no valid token, store system user for backward compatibility
self._socket_users[sid] = {
"user_id": "system",
"is_admin": False,
}
logger.debug(f"Socket {sid} connected as system user (no valid token)")
return True
async def _handle_disconnect(self, sid: str) -> None:
"""Handle socket disconnection and cleanup user info."""
if sid in self._socket_users:
del self._socket_users[sid]
logger.debug(f"Socket {sid} disconnected and cleaned up")
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
"""Handle queue subscription and add socket to both queue and user-specific rooms."""
queue_id = QueueSubscriptionEvent(**data).queue_id
# Check if we have user info for this socket
if sid not in self._socket_users:
logger.warning(
f"Socket {sid} subscribing to queue {queue_id} but has no user info - need to authenticate via connect event"
)
# Store as system user temporarily - real auth should happen in connect
self._socket_users[sid] = {
"user_id": "system",
"is_admin": False,
}
user_id = self._socket_users[sid]["user_id"]
is_admin = self._socket_users[sid]["is_admin"]
# Add socket to the queue room
await self._sio.enter_room(sid, queue_id)
# Also add socket to a user-specific room for event filtering
user_room = f"user:{user_id}"
await self._sio.enter_room(sid, user_room)
# If admin, also add to admin room to receive all events
if is_admin:
await self._sio.enter_room(sid, "admin")
logger.debug(
f"Socket {sid} (user_id: {user_id}, is_admin: {is_admin}) subscribed to queue {queue_id} and user room {user_room}"
)
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
@@ -118,7 +204,62 @@ class SocketIO:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
"""Handle queue events with user isolation.
Invocation events (progress, started, complete) are private - only emit to owner and admins.
Queue item status events are public - emit to all users (field values hidden via API).
Other queue events emit to all subscribers.
IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase
inherits from QueueItemEventBase. The order of isinstance checks matters!
"""
try:
event_name, event_data = event
# Import here to avoid circular dependency
from invokeai.app.services.events.events_common import InvocationEventBase, QueueItemEventBase
# Check InvocationEventBase FIRST (before QueueItemEventBase) since it's a subclass
# Invocation events (progress, started, complete, error) are private to owner + admins
if isinstance(event_data, InvocationEventBase) and hasattr(event_data, "user_id"):
user_room = f"user:{event_data.user_id}"
# Emit to the user's room
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
# Also emit to admin room so admins can see all events, but strip image preview data
# from InvocationProgressEvent to prevent admins from seeing other users' image content
if isinstance(event_data, InvocationProgressEvent):
admin_event_data = event_data.model_copy(update={"image": None})
await self._sio.emit(event=event_name, data=admin_event_data.model_dump(mode="json"), room="admin")
else:
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room")
# Queue item status events are visible to all users (field values masked via API)
# This catches QueueItemStatusChangedEvent but NOT InvocationEvents (already handled above)
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"):
# Emit to all subscribers in the queue
await self._sio.emit(
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
)
logger.info(
f"Emitted public queue item event {event_name} to all subscribers in queue {event_data.queue_id}"
)
else:
# For other queue events (like QueueClearedEvent, BatchEnqueuedEvent), emit to all subscribers
await self._sio.emit(
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
)
logger.info(
f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}"
)
except Exception as e:
# Log any unhandled exceptions in event handling to prevent silent failures
logger.error(f"Error handling queue event {event[0]}: {e}", exc_info=True)
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))

View File

@@ -17,6 +17,7 @@ from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.api.routers import (
app_info,
auth,
board_images,
boards,
client_state,
@@ -122,6 +123,8 @@ app.add_middleware(GZipMiddleware, minimum_size=1000)
# Include all routers
# Authentication router should be first so it's registered before protected routes
app.include_router(auth.auth_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")

View File

@@ -45,7 +45,7 @@ KLEIN_MAX_SEQ_LEN = 512
title="Prompt - Flux2 Klein",
tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
category="conditioning",
version="1.1.0",
version="1.1.1",
classification=Classification.Prototype,
)
class Flux2KleinTextEncoderInvocation(BaseInvocation):
@@ -73,140 +73,111 @@ class Flux2KleinTextEncoderInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
qwen3_embeds, pooled_embeds = self._encode_prompt(context)
# Open the exitstack here to lock models for the duration of the node
with ExitStack() as exit_stack:
# Pass the locked stack down to the helper function
qwen3_embeds, pooled_embeds = self._encode_prompt(context, exit_stack)
# Use FLUXConditioningInfo for compatibility with existing Flux denoiser
# t5_embeds -> qwen3 stacked embeddings
# clip_embeds -> pooled qwen3 embedding
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=pooled_embeds, t5_embeds=qwen3_embeds)]
)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=pooled_embeds, t5_embeds=qwen3_embeds)]
)
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput(
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
)
# The models are still locked while we save the data
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput(
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
)
def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode prompt using Qwen3 text encoder with Klein-style layer extraction.
This matches the diffusers Flux2KleinPipeline._get_qwen3_prompt_embeds() exactly.
Returns:
Tuple of (stacked_embeddings, pooled_embedding):
- stacked_embeddings: Hidden states from layers (9, 18, 27) stacked together.
Shape: (1, seq_len, hidden_size * 3)
- pooled_embedding: Pooled representation for global conditioning.
Shape: (1, hidden_size)
"""
def _encode_prompt(self, context: InvocationContext, exit_stack: ExitStack) -> Tuple[torch.Tensor, torch.Tensor]:
prompt = self.prompt
# Reordered loading to prevent the annoying cache drop issue
# This prevents it from being evicted while we look up the tokenizer
text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
(cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
# Now it is safe to load and lock the tokenizer
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
with ExitStack() as exit_stack:
(cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
device = text_encoder.device
# you can now define the device, as the text_encoder exists here
device = text_encoder.device
# Apply LoRA models
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_T5_PREFIX,
dtype=lora_dtype,
cached_weights=cached_weights,
)
)
# Apply LoRA models to the text encoder
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_T5_PREFIX, # Reuse T5 prefix for Qwen3 LoRAs
dtype=lora_dtype,
cached_weights=cached_weights,
)
context.util.signal_progress("Running Qwen3 text encoder (Klein)")
if not isinstance(text_encoder, PreTrainedModel):
raise TypeError(
f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
"The Qwen3 encoder model may be corrupted or incompatible."
)
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise TypeError(
f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
"The Qwen3 tokenizer may be corrupted or incompatible."
)
context.util.signal_progress("Running Qwen3 text encoder (Klein)")
messages = [{"role": "user", "content": prompt}]
if not isinstance(text_encoder, PreTrainedModel):
raise TypeError(
f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
"The Qwen3 encoder model may be corrupted or incompatible."
)
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise TypeError(
f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
"The Qwen3 tokenizer may be corrupted or incompatible."
)
text: str = tokenizer.apply_chat_template( # type: ignore[assignment]
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
# Format messages exactly like diffusers Flux2KleinPipeline:
# - Only user message, NO system message
# - add_generation_prompt=True (adds assistant prefix)
# - enable_thinking=False
messages = [{"role": "user", "content": prompt}]
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_seq_len,
)
# Step 1: Apply chat template to get formatted text (tokenize=False)
text: str = tokenizer.apply_chat_template( # type: ignore[assignment]
messages,
tokenize=False,
add_generation_prompt=True, # Adds assistant prefix like diffusers
enable_thinking=False, # Disable thinking mode
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
# Forward pass through the model
outputs = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
raise RuntimeError(
"Text encoder did not return hidden_states. "
"Ensure output_hidden_states=True is supported by this model."
)
num_hidden_layers = len(outputs.hidden_states)
# Step 2: Tokenize the formatted text
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_seq_len,
)
hidden_states_list = []
for layer_idx in KLEIN_EXTRACTION_LAYERS:
if layer_idx >= num_hidden_layers:
layer_idx = num_hidden_layers - 1
hidden_states_list.append(outputs.hidden_states[layer_idx])
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
out = torch.stack(hidden_states_list, dim=1)
out = out.to(dtype=text_encoder.dtype, device=device)
# Move to device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
# Forward pass through the model - matching diffusers exactly
# Explicitly move inputs to the same device as the text_encoder
outputs = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Validate hidden_states output
if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
raise RuntimeError(
"Text encoder did not return hidden_states. "
"Ensure output_hidden_states=True is supported by this model."
)
num_hidden_layers = len(outputs.hidden_states)
# Extract and stack hidden states - EXACTLY like diffusers:
# out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
# prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
hidden_states_list = []
for layer_idx in KLEIN_EXTRACTION_LAYERS:
if layer_idx >= num_hidden_layers:
layer_idx = num_hidden_layers - 1
hidden_states_list.append(outputs.hidden_states[layer_idx])
# Stack along dim=1, then permute and reshape - exactly like diffusers
out = torch.stack(hidden_states_list, dim=1)
out = out.to(dtype=text_encoder.dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
# Create pooled embedding for global conditioning
# Use mean pooling over the sequence (excluding padding)
# This serves a similar role to CLIP's pooled output in standard FLUX
last_hidden_state = outputs.hidden_states[-1] # Use last layer for pooling
# Expand mask to match hidden state dimensions
expanded_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
sum_embeds = (last_hidden_state * expanded_mask).sum(dim=1)
num_tokens = expanded_mask.sum(dim=1).clamp(min=1)
pooled_embeds = sum_embeds / num_tokens
last_hidden_state = outputs.hidden_states[-1]
expanded_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
sum_embeds = (last_hidden_state * expanded_mask).sum(dim=1)
num_tokens = expanded_mask.sum(dim=1).clamp(min=1)
pooled_embeds = sum_embeds / num_tokens
return prompt_embeds, pooled_embeds

View File

@@ -9,6 +9,11 @@ def get_app():
def run_app() -> None:
"""The main entrypoint for the app."""
import asyncio
import sys
import threading
import traceback
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
# Parse the CLI arguments before doing anything else, which ensures CLI args correctly override settings from other
@@ -100,4 +105,41 @@ def run_app() -> None:
for hdlr in logger.handlers:
uvicorn_logger.addHandler(hdlr)
loop.run_until_complete(server.serve())
try:
loop.run_until_complete(server.serve())
except KeyboardInterrupt:
logger.info("InvokeAI shutting down...")
# Gracefully shut down services (e.g. model download and install managers) so that any
# active work is completed or cleanly cancelled before the process exits.
from invokeai.app.api.dependencies import ApiDependencies
ApiDependencies.shutdown()
# Cancel any pending asyncio tasks (e.g. socket.io ping tasks) so that loop.close() does
# not emit "Task was destroyed but it is pending!" warnings for each one.
pending = [t for t in asyncio.all_tasks(loop) if not t.done()]
for task in pending:
task.cancel()
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
# Shut down the asyncio default thread executor. asyncio.to_thread() (used e.g. in the
# session queue for SQLite operations during generation) creates non-daemon threads via the
# event loop's default ThreadPoolExecutor. Without this call those threads remain alive and
# cause threading._shutdown() to hang indefinitely after the process's main code finishes.
loop.run_until_complete(loop.shutdown_default_executor())
loop.close()
# After graceful shutdown, log any non-daemon threads that are still alive. These are the
# threads that will cause Python's threading._shutdown() to block, preventing the process
# from exiting cleanly. This helps identify threads that need to be fixed or joined.
frames = sys._current_frames()
for thread in threading.enumerate():
if thread.daemon or thread is threading.main_thread():
continue
frame = frames.get(thread.ident)
stack = "".join(traceback.format_stack(frame)) if frame else "(no frame available)"
logger.warning(
f"Non-daemon thread still alive after shutdown: {thread.name!r} "
f"(ident={thread.ident})\nStack trace:\n{stack}"
)

View File

@@ -0,0 +1,5 @@
"""App settings service exports."""
from invokeai.app.services.app_settings.app_settings_service import AppSettingsService
__all__ = ["AppSettingsService"]

View File

@@ -0,0 +1,74 @@
"""Service for managing application-level settings stored in the database."""
from typing import Optional
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class AppSettingsService:
"""Service for accessing application-level settings from the database.
This service provides a simple key-value store for application-level configuration
that needs to be persisted across restarts, such as JWT secrets.
"""
def __init__(self, db: SqliteDatabase) -> None:
"""Initialize the app settings service.
Args:
db: The SQLite database instance
"""
self._db = db
def get(self, key: str) -> Optional[str]:
"""Get a setting value by key.
Args:
key: The setting key
Returns:
The setting value if found, None otherwise
"""
try:
with self._db.transaction() as cursor:
cursor.execute("SELECT value FROM app_settings WHERE key = ?;", (key,))
row = cursor.fetchone()
return row[0] if row else None
except Exception:
return None
def set(self, key: str, value: str) -> None:
"""Set a setting value.
Args:
key: The setting key
value: The setting value
"""
with self._db.transaction() as cursor:
cursor.execute(
"""
INSERT INTO app_settings (key, value)
VALUES (?, ?)
ON CONFLICT(key) DO UPDATE SET
value = excluded.value,
updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW');
""",
(key, value),
)
def get_jwt_secret(self) -> str:
"""Get the JWT secret key from the database.
Returns:
The JWT secret key
Raises:
RuntimeError: If the JWT secret is not found in the database
"""
secret = self.get("jwt_secret")
if secret is None:
raise RuntimeError(
"JWT secret not found in database. This should have been created during database migration. "
"Please ensure database migrations have been run successfully."
)
return secret

View File

@@ -0,0 +1 @@
"""Authentication service module."""

View File

@@ -0,0 +1,86 @@
"""Password hashing and validation utilities."""
from typing import cast
from passlib.context import CryptContext
# Configure bcrypt context - set truncate_error=False to allow passwords >72 bytes
# without raising an error. They will be automatically truncated by bcrypt to 72 bytes.
pwd_context = CryptContext(
schemes=["bcrypt"],
deprecated="auto",
bcrypt__truncate_error=False,
)
def hash_password(password: str) -> str:
"""Hash a password using bcrypt.
bcrypt has a maximum password length of 72 bytes. Longer passwords
are automatically truncated to comply with this limit.
Args:
password: The plain text password to hash
Returns:
The hashed password
"""
# bcrypt has a 72 byte limit - encode and truncate if necessary
password_bytes = password.encode("utf-8")
if len(password_bytes) > 72:
# Truncate to 72 bytes and decode back, dropping incomplete UTF-8 sequences
password = password_bytes[:72].decode("utf-8", errors="ignore")
return cast(str, pwd_context.hash(password))
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a hash.
bcrypt has a maximum password length of 72 bytes. Longer passwords
are automatically truncated to match hash_password behavior.
Args:
plain_password: The plain text password to verify
hashed_password: The hashed password to verify against
Returns:
True if the password matches the hash, False otherwise
"""
try:
# bcrypt has a 72 byte limit - encode and truncate if necessary to match hash_password
password_bytes = plain_password.encode("utf-8")
if len(password_bytes) > 72:
# Truncate to 72 bytes and decode back, dropping incomplete UTF-8 sequences
plain_password = password_bytes[:72].decode("utf-8", errors="ignore")
return cast(bool, pwd_context.verify(plain_password, hashed_password))
except Exception:
# Invalid hash format or other error - return False
return False
def validate_password_strength(password: str) -> tuple[bool, str]:
"""Validate password meets minimum security requirements.
Password requirements:
- At least 8 characters long
- Contains at least one uppercase letter
- Contains at least one lowercase letter
- Contains at least one digit
Args:
password: The password to validate
Returns:
A tuple of (is_valid, error_message). If valid, error_message is empty.
"""
if len(password) < 8:
return False, "Password must be at least 8 characters long"
has_upper = any(c.isupper() for c in password)
has_lower = any(c.islower() for c in password)
has_digit = any(c.isdigit() for c in password)
if not (has_upper and has_lower and has_digit):
return False, "Password must contain uppercase, lowercase, and numbers"
return True, ""

View File

@@ -0,0 +1,105 @@
"""JWT token generation and validation."""
from datetime import datetime, timedelta, timezone
from typing import cast
from jose import JWTError, jwt
from pydantic import BaseModel
ALGORITHM = "HS256"
DEFAULT_EXPIRATION_HOURS = 24
# Module-level variable to store the JWT secret. This is set during application initialization
# by calling set_jwt_secret(). The secret is loaded from the database where it is stored
# securely after being generated during database migration.
_jwt_secret: str | None = None
class TokenData(BaseModel):
"""Data stored in JWT token."""
user_id: str
email: str
is_admin: bool
def set_jwt_secret(secret: str) -> None:
"""Set the JWT secret key for token signing and verification.
This should be called once during application initialization with the secret
loaded from the database.
Args:
secret: The JWT secret key
"""
global _jwt_secret
_jwt_secret = secret
def get_jwt_secret() -> str:
"""Get the JWT secret key.
Returns:
The JWT secret key
Raises:
RuntimeError: If the secret has not been initialized
"""
if _jwt_secret is None:
raise RuntimeError("JWT secret has not been initialized. Call set_jwt_secret() during application startup.")
return _jwt_secret
def create_access_token(data: TokenData, expires_delta: timedelta | None = None) -> str:
"""Create a JWT access token.
Args:
data: The token data to encode
expires_delta: Optional expiration time delta. Defaults to 24 hours.
Returns:
The encoded JWT token
"""
to_encode = data.model_dump()
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(hours=DEFAULT_EXPIRATION_HOURS))
to_encode.update({"exp": expire})
return cast(str, jwt.encode(to_encode, get_jwt_secret(), algorithm=ALGORITHM))
def verify_token(token: str) -> TokenData | None:
"""Verify and decode a JWT token.
Args:
token: The JWT token to verify
Returns:
TokenData if valid, None if invalid or expired
"""
try:
# python-jose 3.5.0 has a bug where exp verification doesn't work properly
# We need to manually check expiration, but MUST verify signature first
# to prevent accepting tokens with valid payloads but invalid signatures
# First, verify the signature - this will raise JWTError if signature is invalid
# Note: python-jose won't reject expired tokens here due to the bug
payload = jwt.decode(
token,
get_jwt_secret(),
algorithms=[ALGORITHM],
)
# Now manually check expiration (because python-jose 3.5.0 doesn't do this properly)
if "exp" in payload:
exp_timestamp = payload["exp"]
current_timestamp = datetime.now(timezone.utc).timestamp()
if current_timestamp >= exp_timestamp:
# Token is expired
return None
return TokenData(**payload)
except JWTError:
# Token is invalid (bad signature, malformed, etc.)
return None
except Exception:
# Catch any other exceptions (e.g., Pydantic validation errors)
return None

View File

@@ -17,8 +17,9 @@ class BoardRecordStorageBase(ABC):
def save(
self,
board_name: str,
user_id: str,
) -> BoardRecord:
"""Saves a board record."""
"""Saves a board record for a specific user."""
pass
@abstractmethod
@@ -41,18 +42,25 @@ class BoardRecordStorageBase(ABC):
@abstractmethod
def get_many(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
"""Gets many board records for a specific user, including shared boards. Admin users see all boards."""
pass
@abstractmethod
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
include_archived: bool = False,
) -> list[BoardRecord]:
"""Gets all board records."""
"""Gets all board records for a specific user, including shared boards. Admin users see all boards."""
pass

View File

@@ -16,6 +16,8 @@ class BoardRecord(BaseModelExcludeNull):
"""The unique ID of the board."""
board_name: str = Field(description="The name of the board.")
"""The name of the board."""
user_id: str = Field(description="The user ID of the board owner.")
"""The user ID of the board owner."""
created_at: Union[datetime, str] = Field(description="The created timestamp of the board.")
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
@@ -35,6 +37,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown")
# Default to 'system' for backwards compatibility with boards created before multiuser support
user_id = board_dict.get("user_id", "system")
cover_image_name = board_dict.get("cover_image_name", "unknown")
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
@@ -44,6 +48,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
return BoardRecord(
board_id=board_id,
board_name=board_name,
user_id=user_id,
cover_image_name=cover_image_name,
created_at=created_at,
updated_at=updated_at,

View File

@@ -38,16 +38,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
def save(
self,
board_name: str,
user_id: str,
) -> BoardRecord:
with self._db.transaction() as cursor:
try:
board_id = uuid_string()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
INSERT OR IGNORE INTO boards (board_id, board_name, user_id)
VALUES (?, ?, ?);
""",
(board_id, board_name),
(board_id, board_name, user_id),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
@@ -121,6 +122,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
def get_many(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
@@ -128,74 +131,147 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
with self._db.transaction() as cursor:
# Build base query
base_query = """
SELECT *
# Build base query - admins see all boards, regular users see owned, shared, or public boards
if is_admin:
base_query = """
SELECT DISTINCT boards.*
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
# Determine archived filter condition
archived_filter = "WHERE 1=1" if include_archived else "WHERE boards.archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
else:
base_query = """
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "AND boards.archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (user_id, user_id, limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
# Determine count query - admins count all boards, regular users count accessible boards
if is_admin:
if include_archived:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
else:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
WHERE archived = 0;
WHERE boards.archived = 0;
"""
cursor.execute(count_query)
else:
if include_archived:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1);
"""
else:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
AND boards.archived = 0;
"""
# Execute count query
cursor.execute(count_query)
# Execute count query
cursor.execute(count_query, (user_id, user_id))
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
include_archived: bool = False,
) -> list[BoardRecord]:
with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
# Build query - admins see all boards, regular users see owned, shared, or public boards
if is_admin:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT DISTINCT boards.*
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
ORDER BY LOWER(boards.board_name) {direction}
"""
else:
base_query = """
SELECT *
else:
base_query = """
SELECT DISTINCT boards.*
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "WHERE 1=1" if include_archived else "WHERE boards.archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query)
cursor.execute(final_query)
else:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
{archived_filter}
ORDER BY LOWER(boards.board_name) {direction}
"""
else:
base_query = """
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "AND boards.archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query, (user_id, user_id))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

View File

@@ -13,8 +13,9 @@ class BoardServiceABC(ABC):
def create(
self,
board_name: str,
user_id: str,
) -> BoardDTO:
"""Creates a board."""
"""Creates a board for a specific user."""
pass
@abstractmethod
@@ -45,18 +46,25 @@ class BoardServiceABC(ABC):
@abstractmethod
def get_many(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
"""Gets many boards for a specific user, including shared boards. Admin users see all boards."""
pass
@abstractmethod
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
include_archived: bool = False,
) -> list[BoardDTO]:
"""Gets all boards."""
"""Gets all boards for a specific user, including shared boards. Admin users see all boards."""
pass

View File

@@ -14,10 +14,16 @@ class BoardDTO(BoardRecord):
"""The number of images in the board."""
asset_count: int = Field(description="The number of assets in the board.")
"""The number of assets in the board."""
owner_username: Optional[str] = Field(default=None, description="The username of the board owner (for admin view).")
"""The username of the board owner (for admin view)."""
def board_record_to_dto(
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int, asset_count: int
board_record: BoardRecord,
cover_image_name: Optional[str],
image_count: int,
asset_count: int,
owner_username: Optional[str] = None,
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
@@ -25,4 +31,5 @@ def board_record_to_dto(
cover_image_name=cover_image_name,
image_count=image_count,
asset_count=asset_count,
owner_username=owner_username,
)

View File

@@ -15,8 +15,9 @@ class BoardService(BoardServiceABC):
def create(
self,
board_name: str,
user_id: str,
) -> BoardDTO:
board_record = self.__invoker.services.board_records.save(board_name)
board_record = self.__invoker.services.board_records.save(board_name, user_id)
return board_record_to_dto(board_record, None, 0, 0)
def get_dto(self, board_id: str) -> BoardDTO:
@@ -51,6 +52,8 @@ class BoardService(BoardServiceABC):
def get_many(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
@@ -58,7 +61,7 @@ class BoardService(BoardServiceABC):
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(
order_by, direction, offset, limit, include_archived
user_id, is_admin, order_by, direction, offset, limit, include_archived
)
board_dtos = []
for r in board_records.items:
@@ -70,14 +73,29 @@ class BoardService(BoardServiceABC):
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
asset_count = self.__invoker.services.board_image_records.get_asset_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count))
# For admin users, include owner username
owner_username = None
if is_admin:
owner = self.__invoker.services.users.get(r.user_id)
if owner:
owner_username = owner.display_name or owner.email
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count, owner_username))
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
include_archived: bool = False,
) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(order_by, direction, include_archived)
board_records = self.__invoker.services.board_records.get_all(
user_id, is_admin, order_by, direction, include_archived
)
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@@ -88,6 +106,14 @@ class BoardService(BoardServiceABC):
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
asset_count = self.__invoker.services.board_image_records.get_asset_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count))
# For admin users, include owner username
owner_username = None
if is_admin:
owner = self.__invoker.services.users.get(r.user_id)
if owner:
owner_username = owner.display_name or owner.email
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count, owner_username))
return board_dtos

View File

@@ -4,15 +4,16 @@ from abc import ABC, abstractmethod
class ClientStatePersistenceABC(ABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
This class defines the interface for persisting client data per user.
"""
@abstractmethod
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
def set_by_key(self, user_id: str, key: str, value: str) -> str:
"""
Set a key-value pair for the client.
Args:
user_id (str): The user ID to set state for.
key (str): The key to set.
value (str): The value to set for the key.
@@ -22,11 +23,12 @@ class ClientStatePersistenceABC(ABC):
pass
@abstractmethod
def get_by_key(self, queue_id: str, key: str) -> str | None:
def get_by_key(self, user_id: str, key: str) -> str | None:
"""
Get the value for a specific key of the client.
Args:
user_id (str): The user ID to get state for.
key (str): The key to retrieve the value for.
Returns:
@@ -35,8 +37,11 @@ class ClientStatePersistenceABC(ABC):
pass
@abstractmethod
def delete(self, queue_id: str) -> None:
def delete(self, user_id: str) -> None:
"""
Delete all client state.
Delete all client state for a user.
Args:
user_id (str): The user ID to delete state for.
"""
pass

View File

@@ -1,5 +1,3 @@
import json
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -7,59 +5,51 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
SQLite implementation for client state persistence.
This class stores client state data per user to prevent data leakage between users.
"""
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._default_row_id = 1
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def _get(self) -> dict[str, str] | None:
def set_by_key(self, user_id: str, key: str, value: str) -> str:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
state = self._get() or {}
state.update({key: value})
with self._db.transaction() as cursor:
cursor.execute(
f"""
INSERT INTO client_state (id, data)
VALUES ({self._default_row_id}, ?)
ON CONFLICT(id) DO UPDATE
SET data = excluded.data;
INSERT INTO client_state (user_id, key, value)
VALUES (?, ?, ?)
ON CONFLICT(user_id, key) DO UPDATE
SET value = excluded.value;
""",
(json.dumps(state),),
(user_id, key, value),
)
return value
def get_by_key(self, queue_id: str, key: str) -> str | None:
state = self._get()
if state is None:
return None
return state.get(key, None)
def delete(self, queue_id: str) -> None:
def get_by_key(self, user_id: str, key: str) -> str | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
DELETE FROM client_state
WHERE id = {self._default_row_id}
"""
SELECT value FROM client_state
WHERE user_id = ? AND key = ?
""",
(user_id, key),
)
row = cursor.fetchone()
if row is None:
return None
return row[0]
def delete(self, user_id: str) -> None:
with self._db.transaction() as cursor:
cursor.execute(
"""
DELETE FROM client_state
WHERE user_id = ?
""",
(user_id,),
)

View File

@@ -110,6 +110,7 @@ class InvokeAIAppConfig(BaseSettings):
scan_models_on_startup: Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.
multiuser: Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.
"""
_root: Optional[Path] = PrivateAttr(default=None)
@@ -203,6 +204,9 @@ class InvokeAIAppConfig(BaseSettings):
unsafe_disable_picklescan: bool = Field(default=False, description="UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.")
allow_unknown_models: bool = Field(default=True, description="Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.")
# MULTIUSER
multiuser: bool = Field(default=False, description="Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.")
# fmt: on
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True)

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2023, Lincoln D. Stein
# Copyright (c) 2023,2026 Lincoln D. Stein
"""Implementation of multithreaded download queue for invokeai."""
import os
@@ -60,7 +60,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
"""
self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {}
self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
self._download_part2parent: Dict[int, MultiFileDownloadJob] = {}
self._mfd_pending: Dict[int, list[DownloadJob]] = {}
self._mfd_active: Dict[int, DownloadJob] = {}
self._next_job_id = 0
@@ -88,7 +88,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
"""Stop the download worker threads."""
with self._lock:
if not self._worker_pool:
raise Exception("Attempt to stop the download service before it was started")
return
self._accept_download_requests = False # reject attempts to add new jobs to queue
queued_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.WAITING]
active_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.RUNNING]
@@ -118,7 +118,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
job.id = self._next_id()
if job.id == -1:
job.id = self._next_id()
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
@@ -197,12 +198,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
dest=path,
access_token=access_token or self._lookup_access_token(url),
)
job.id = self._next_id() # pre-assign ID so _download_part2parent can be keyed by ID
if part.size and part.size > 0:
job.total_bytes = part.size
job.expected_total_bytes = part.size
job.canonical_url = str(url)
mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj
self._download_part2parent[job.id] = mfdj
if submit_job:
self.submit_multifile_download(mfdj)
return mfdj
@@ -327,7 +329,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
finally:
job.job_ended = get_iso_timestamp()
self._job_terminated_event.set() # signal a change to terminal state
self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
self._download_part2parent.pop(job.id, None) # if this is a subpart of a multipart job, remove it
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
@@ -386,18 +388,23 @@ class DownloadQueueService(DownloadQueueServiceBase):
if len(candidates) == 1:
inferred = candidates[0].with_name(candidates[0].name.removesuffix(".downloading"))
job.download_path = inferred
resume_from = candidates[0].stat().st_size
job.bytes = resume_from
self._logger.debug(
f"Resume check (dir): inferred in-progress file path={candidates[0]} size={resume_from} bytes"
)
if resume_from > 0:
if job.etag:
header["If-Range"] = job.etag
elif job.last_modified:
header["If-Range"] = job.last_modified
header["Range"] = f"bytes={resume_from}-"
open_mode = "ab"
try:
resume_from = candidates[0].stat().st_size
except FileNotFoundError:
# The .downloading file was renamed/deleted between glob and stat (race condition); skip resume.
job.download_path = None
else:
job.bytes = resume_from
self._logger.debug(
f"Resume check (dir): inferred in-progress file path={candidates[0]} size={resume_from} bytes"
)
if resume_from > 0:
if job.etag:
header["If-Range"] = job.etag
elif job.last_modified:
header["If-Range"] = job.last_modified
header["Range"] = f"bytes={resume_from}-"
open_mode = "ab"
else:
self._logger.debug(
"Resume check (dir): no prior download_path available; cannot resume from disk "
@@ -622,7 +629,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._event_bus.emit_download_cancelled(job)
# if multifile download, then signal the parent
if parent_job := self._download_part2parent.get(job.source, None):
if parent_job := self._download_part2parent.get(job.id, None):
if not parent_job.in_terminal_state:
parent_job.status = DownloadJobStatus.CANCELLED
self._execute_cb(parent_job, "on_cancelled")
@@ -639,7 +646,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
if self._event_bus:
self._event_bus.emit_download_paused(job)
if parent_job := self._download_part2parent.get(job.source, None):
if parent_job := self._download_part2parent.get(job.id, None):
if not parent_job.in_terminal_state:
parent_job.status = DownloadJobStatus.PAUSED
self._execute_cb(parent_job, "on_cancelled")
@@ -669,7 +676,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _mfd_started(self, download_job: DownloadJob) -> None:
self._logger.info(f"File download started: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
mf_job = self._download_part2parent[download_job.id]
if mf_job.waiting:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.status = DownloadJobStatus.RUNNING
@@ -682,7 +689,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _mfd_progress(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
mf_job = self._download_part2parent[download_job.id]
if mf_job.cancelled:
for part in mf_job.download_parts:
self.cancel_job(part)
@@ -696,7 +703,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
submit_next = False
mf_job: Optional[MultiFileDownloadJob] = None
with self._lock:
mf_job = self._download_part2parent[download_job.source]
mf_job = self._download_part2parent[download_job.id]
self._mfd_active.pop(mf_job.id, None)
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.bytes = sum(x.bytes for x in mf_job.download_parts)
@@ -715,7 +722,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _mfd_cancelled(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
mf_job = self._download_part2parent[download_job.id]
assert mf_job is not None
self._mfd_active.pop(mf_job.id, None)
@@ -735,7 +742,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
mf_job = self._download_part2parent[download_job.id]
assert mf_job is not None
self._mfd_active.pop(mf_job.id, None)
if not mf_job.in_terminal_state:
@@ -748,7 +755,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
)
for s in [x for x in mf_job.download_parts if x.running]:
self.cancel_job(s)
self._download_part2parent.pop(download_job.source)
self._mfd_pending.pop(mf_job.id, None)
self._job_terminated_event.set()

View File

@@ -91,6 +91,7 @@ class QueueItemEventBase(QueueEventBase):
batch_id: str = Field(description="The ID of the queue batch")
origin: str | None = Field(default=None, description="The origin of the queue item")
destination: str | None = Field(default=None, description="The destination of the queue item")
user_id: str = Field(default="system", description="The ID of the user who created the queue item")
class InvocationEventBase(QueueItemEventBase):
@@ -117,6 +118,7 @@ class InvocationStartedEvent(InvocationEventBase):
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
user_id=queue_item.user_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -152,6 +154,7 @@ class InvocationProgressEvent(InvocationEventBase):
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
user_id=queue_item.user_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -179,6 +182,7 @@ class InvocationCompleteEvent(InvocationEventBase):
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
user_id=queue_item.user_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -211,6 +215,7 @@ class InvocationErrorEvent(InvocationEventBase):
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
user_id=queue_item.user_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -248,6 +253,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
user_id=queue_item.user_id,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,

View File

@@ -28,6 +28,10 @@ class FastAPIEventService(EventServiceBase):
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
def dispatch(self, event: EventBase) -> None:
if self._loop.is_closed():
# The event loop was closed during shutdown. Events can no longer be dispatched;
# silently drop this one so the generation thread can wind down cleanly.
return
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
async def _dispatch_from_queue(self, stop_event: threading.Event):

View File

@@ -50,8 +50,10 @@ class ImageRecordStorageBase(ABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
"""Gets a page of image records. When board_id is 'none', filters by user_id for per-user uncategorized images unless is_admin is True."""
pass
# TODO: The database has a nullable `deleted_at` column, currently unused.
@@ -90,6 +92,7 @@ class ImageRecordStorageBase(ABC):
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
user_id: Optional[str] = None,
) -> datetime:
"""Saves an image record."""
pass
@@ -109,6 +112,8 @@ class ImageRecordStorageBase(ABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass

View File

@@ -134,6 +134,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> OffsetPaginatedResults[ImageRecord]:
with self._db.transaction() as cursor:
# Manually build two queries - one for the count, one for the records
@@ -186,6 +188,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
# For uncategorized images, filter by user_id to ensure per-user isolation
# Admin users can see all uncategorized images from all users
if user_id is not None and not is_admin:
query_conditions += """--sql
AND images.user_id = ?
"""
query_params.append(user_id)
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
@@ -305,6 +314,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
user_id: Optional[str] = None,
) -> datetime:
with self._db.transaction() as cursor:
try:
@@ -321,9 +331,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata,
is_intermediate,
starred,
has_workflow
has_workflow,
user_id
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
@@ -337,6 +348,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
is_intermediate,
starred,
has_workflow,
user_id or "system",
),
)
@@ -386,6 +398,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> ImageNamesResult:
with self._db.transaction() as cursor:
# Build query conditions (reused for both starred count and image names queries)
@@ -417,6 +431,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
# For uncategorized images, filter by user_id to ensure per-user isolation
# Admin users can see all uncategorized images from all users
if user_id is not None and not is_admin:
query_conditions += """--sql
AND images.user_id = ?
"""
query_params.append(user_id)
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?

View File

@@ -55,6 +55,7 @@ class ImageServiceABC(ABC):
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
user_id: Optional[str] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@@ -125,6 +126,8 @@ class ImageServiceABC(ABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs with starred images first when starred_first=True."""
pass
@@ -159,6 +162,8 @@ class ImageServiceABC(ABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass

View File

@@ -45,6 +45,7 @@ class ImageService(ImageServiceABC):
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
user_id: Optional[str] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@@ -72,6 +73,7 @@ class ImageService(ImageServiceABC):
node_id=node_id,
metadata=metadata,
session_id=session_id,
user_id=user_id,
)
if board_id is not None:
try:
@@ -215,6 +217,8 @@ class ImageService(ImageServiceABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self.__invoker.services.image_records.get_many(
@@ -227,6 +231,8 @@ class ImageService(ImageServiceABC):
is_intermediate,
board_id,
search_term,
user_id,
is_admin,
)
image_dtos = [
@@ -320,6 +326,8 @@ class ImageService(ImageServiceABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> ImageNamesResult:
try:
return self.__invoker.services.image_records.get_image_names(
@@ -330,6 +338,8 @@ class ImageService(ImageServiceABC):
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
user_id=user_id,
is_admin=is_admin,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting image names")

View File

@@ -36,6 +36,7 @@ if TYPE_CHECKING:
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.urls.urls_base import UrlServiceBase
from invokeai.app.services.users.users_base import UserServiceBase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_base import WorkflowThumbnailServiceBase
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -75,6 +76,7 @@ class InvocationServices:
style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
client_state_persistence: "ClientStatePersistenceABC",
users: "UserServiceBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
@@ -105,3 +107,4 @@ class InvocationServices:
self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails
self.client_state_persistence = client_state_persistence
self.users = users

View File

@@ -132,6 +132,9 @@ class ModelInstallService(ModelInstallServiceBase):
marker = {
"version": INSTALL_MARKER_VERSION,
"source": str(job.source),
"access_token": (
job.source.access_token if isinstance(job.source, (HFModelSource, URLModelSource)) else None
),
"config_in": job.config_in.model_dump(),
"status": (status or job.status).value,
"updated_at": get_iso_timestamp(),
@@ -186,6 +189,11 @@ class ModelInstallService(ModelInstallServiceBase):
def _restore_incomplete_installs(self) -> None:
path = self._app_config.models_path
seen_sources: set[str] = set()
# Collect sources already tracked by active jobs (including those being downloaded right now).
# We must not re-queue these or delete their tmpdirs.
with self._lock:
active_sources = {str(j.source) for j in self._install_jobs if not j.in_terminal_state}
active_sources.update(str(j.source) for j in self._download_cache.values() if not j.in_terminal_state)
for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
marker = self._read_install_marker(tmpdir)
if not marker:
@@ -195,13 +203,22 @@ class ModelInstallService(ModelInstallServiceBase):
continue
try:
source_str = marker["source"]
source_str = marker.get("source")
if not isinstance(source_str, str):
raise ValueError("Missing source in install marker")
source = self._guess_source(source_str)
access_token = marker.get("access_token")
if isinstance(source, (HFModelSource, URLModelSource)) and isinstance(access_token, str):
source.access_token = access_token
if source_str in active_sources:
# This tmpdir belongs to an install already in progress; leave it alone.
self._logger.debug(f"Skipping restore for {source_str} - already being tracked")
continue
if source_str in seen_sources:
self._logger.info(f"Removing duplicate temporary directory {tmpdir}")
self._safe_rmtree(tmpdir, self._logger)
continue
seen_sources.add(source_str)
source = self._guess_source(source_str)
except Exception as e:
self._logger.warning(f"Skipping install marker in {tmpdir}: {e}")
continue
@@ -313,7 +330,7 @@ class ModelInstallService(ModelInstallServiceBase):
def stop(self, invoker: Optional[Invoker] = None) -> None:
"""Stop the installer thread; after this the object can be deleted and garbage collected."""
if not self._running:
raise Exception("Attempt to stop the install service before it was started")
return
self._logger.debug("calling stop_event.set()")
self._stop_event.set()
self._clear_pending_jobs()

View File

@@ -355,6 +355,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._thread = Thread(
name="session_processor",
target=self._process,
daemon=True,
kwargs={
"stop_event": self._stop_event,
"poll_now_event": self._poll_now_event,
@@ -366,6 +367,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
def stop(self, *args, **kwargs) -> None:
self._stop_event.set()
# Cancel any in-progress generation so that long-running nodes (e.g. denoising) stop at
# the next step boundary instead of running to completion. Without this, the generation
# thread may still be executing CUDA operations when Python teardown begins, which can
# cause a C++ std::terminate() crash ("terminate called without an active exception").
self._cancel_event.set()
# Wake the thread if it is sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused).
self._poll_now_event.set()
self._resume_event.set()
def _poll_now(self) -> None:
self._poll_now_event.set()

View File

@@ -36,8 +36,10 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
"""Enqueues all permutations of a batch for execution."""
def enqueue_batch(
self, queue_id: str, batch: Batch, prepend: bool, user_id: str = "system"
) -> Coroutine[Any, Any, EnqueueBatchResult]:
"""Enqueues all permutations of a batch for execution for a specific user."""
pass
@abstractmethod
@@ -51,13 +53,13 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def clear(self, queue_id: str) -> ClearResult:
"""Deletes all session queue items"""
def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult:
"""Deletes all session queue items. If user_id is provided, only clears items owned by that user."""
pass
@abstractmethod
def prune(self, queue_id: str) -> PruneResult:
"""Deletes all completed and errored session queue items"""
def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult:
"""Deletes all completed and errored session queue items. If user_id is provided, only prunes items owned by that user."""
pass
@abstractmethod
@@ -71,8 +73,8 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
"""Gets the status of the queue"""
def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
"""Gets the status of the queue. If user_id is provided, also includes user-specific counts."""
pass
@abstractmethod
@@ -108,18 +110,24 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs"""
def cancel_by_batch_ids(
self, queue_id: str, batch_ids: list[str], user_id: Optional[str] = None
) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs. If user_id is provided, only cancels items owned by that user."""
pass
@abstractmethod
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
"""Cancels all queue items with the given batch destination"""
def cancel_by_destination(
self, queue_id: str, destination: str, user_id: Optional[str] = None
) -> CancelByDestinationResult:
"""Cancels all queue items with the given batch destination. If user_id is provided, only cancels items owned by that user."""
pass
@abstractmethod
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
"""Deletes all queue items with the given batch destination"""
def delete_by_destination(
self, queue_id: str, destination: str, user_id: Optional[str] = None
) -> DeleteByDestinationResult:
"""Deletes all queue items with the given batch destination. If user_id is provided, only deletes items owned by that user."""
pass
@abstractmethod
@@ -128,13 +136,13 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
"""Cancels all queue items except in-progress items"""
def cancel_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> CancelAllExceptCurrentResult:
"""Cancels all queue items except in-progress items. If user_id is provided, only cancels items owned by that user."""
pass
@abstractmethod
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
"""Deletes all queue items except in-progress items"""
def delete_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> DeleteAllExceptCurrentResult:
"""Deletes all queue items except in-progress items. If user_id is provided, only deletes items owned by that user."""
pass
@abstractmethod

View File

@@ -170,6 +170,7 @@ class Batch(BaseModel):
# region Queue Items
DEFAULT_QUEUE_ID = "default"
SYSTEM_USER_ID = "system" # Default user_id for system-generated queue items
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
@@ -243,6 +244,13 @@ class SessionQueueItem(BaseModel):
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
queue_id: str = Field(description="The id of the queue with which this item is associated")
user_id: str = Field(default="system", description="The id of the user who created this queue item")
user_display_name: Optional[str] = Field(
default=None, description="The display name of the user who created this queue item, if available"
)
user_email: Optional[str] = Field(
default=None, description="The email of the user who created this queue item, if available"
)
field_values: Optional[list[NodeFieldValue]] = Field(
default=None, description="The field values that were used for this queue item"
)
@@ -296,6 +304,12 @@ class SessionQueueStatus(BaseModel):
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
user_pending: Optional[int] = Field(
default=None, description="Number of queue items with status 'pending' for the current user"
)
user_in_progress: Optional[int] = Field(
default=None, description="Number of queue items with status 'in_progress' for the current user"
)
class SessionQueueCountsByDestination(BaseModel):
@@ -565,6 +579,7 @@ ValueToInsertTuple: TypeAlias = tuple[
str | None, # origin (optional)
str | None, # destination (optional)
int | None, # retried_from_item_id (optional, this is always None for new items)
str, # user_id
]
"""A type alias for the tuple of values to insert into the session queue table.
@@ -573,7 +588,7 @@ ValueToInsertTuple: TypeAlias = tuple[
def prepare_values_to_insert(
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int, user_id: str = "system"
) -> list[ValueToInsertTuple]:
"""
Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
@@ -584,6 +599,7 @@ def prepare_values_to_insert(
batch: The batch to prepare the values for
priority: The priority of the queue items
max_new_queue_items: The maximum number of queue items to insert
user_id: The user ID who is creating these queue items
Returns:
A list of tuples to insert into the session queue table. Each tuple contains the following values:
@@ -597,6 +613,7 @@ def prepare_values_to_insert(
- origin (optional)
- destination (optional)
- retried_from_item_id (optional, this is always None for new items)
- user_id
"""
# A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
@@ -626,6 +643,7 @@ def prepare_values_to_insert(
batch.origin,
batch.destination,
None,
user_id,
)
)
return values_to_insert

View File

@@ -100,7 +100,9 @@ class SqliteSessionQueue(SessionQueueBase):
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
return priority
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
async def enqueue_batch(
self, queue_id: str, batch: Batch, prepend: bool, user_id: str = "system"
) -> EnqueueBatchResult:
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
@@ -119,14 +121,15 @@ class SqliteSessionQueue(SessionQueueBase):
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
user_id=user_id,
)
enqueued_count = len(values_to_insert)
with self._db.transaction() as cursor:
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id, user_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
@@ -155,12 +158,16 @@ class SqliteSessionQueue(SessionQueueBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
SELECT
sq.*,
u.display_name as user_display_name,
u.email as user_email
FROM session_queue sq
LEFT JOIN users u ON sq.user_id = u.user_id
WHERE sq.status = 'pending'
ORDER BY
priority DESC,
item_id ASC
sq.priority DESC,
sq.item_id ASC
LIMIT 1
"""
)
@@ -175,14 +182,18 @@ class SqliteSessionQueue(SessionQueueBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
SELECT
sq.*,
u.display_name as user_display_name,
u.email as user_email
FROM session_queue sq
LEFT JOIN users u ON sq.user_id = u.user_id
WHERE
queue_id = ?
AND status = 'pending'
sq.queue_id = ?
AND sq.status = 'pending'
ORDER BY
priority DESC,
created_at ASC
sq.priority DESC,
sq.created_at ASC
LIMIT 1
""",
(queue_id,),
@@ -196,11 +207,15 @@ class SqliteSessionQueue(SessionQueueBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
SELECT
sq.*,
u.display_name as user_display_name,
u.email as user_email
FROM session_queue sq
LEFT JOIN users u ON sq.user_id = u.user_id
WHERE
queue_id = ?
AND status = 'in_progress'
sq.queue_id = ?
AND sq.status = 'in_progress'
LIMIT 1
""",
(queue_id,),
@@ -277,31 +292,41 @@ class SqliteSessionQueue(SessionQueueBase):
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult:
with self._db.transaction() as cursor:
user_filter = "AND user_id = ?" if user_id is not None else ""
where = f"""--sql
WHERE queue_id = ?
{user_filter}
"""
params: list[str] = [queue_id]
if user_id is not None:
params.append(user_id)
cursor.execute(
"""--sql
f"""--sql
SELECT COUNT(*)
FROM session_queue
WHERE queue_id = ?
{where}
""",
(queue_id,),
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
"""--sql
f"""--sql
DELETE
FROM session_queue
WHERE queue_id = ?
{where}
""",
(queue_id,),
tuple(params),
)
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult:
with self._db.transaction() as cursor:
where = """--sql
# Build WHERE clause with optional user_id filter
user_filter = "AND user_id = ?" if user_id is not None else ""
where = f"""--sql
WHERE
queue_id = ?
AND (
@@ -309,14 +334,19 @@ class SqliteSessionQueue(SessionQueueBase):
OR status = 'failed'
OR status = 'canceled'
)
{user_filter}
"""
params = [queue_id]
if user_id is not None:
params.append(user_id)
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
@@ -325,7 +355,7 @@ class SqliteSessionQueue(SessionQueueBase):
FROM session_queue
{where};
""",
(queue_id,),
tuple(params),
)
return PruneResult(deleted=count)
@@ -369,10 +399,15 @@ class SqliteSessionQueue(SessionQueueBase):
)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
def cancel_by_batch_ids(
self, queue_id: str, batch_ids: list[str], user_id: Optional[str] = None
) -> CancelByBatchIDsResult:
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
# Build WHERE clause with optional user_id filter
user_filter = "AND user_id = ?" if user_id is not None else ""
where = f"""--sql
WHERE
queue_id == ?
@@ -382,8 +417,12 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
{user_filter}
"""
params = [queue_id] + batch_ids
if user_id is not None:
params.append(user_id)
cursor.execute(
f"""--sql
SELECT COUNT(*)
@@ -402,15 +441,22 @@ class SqliteSessionQueue(SessionQueueBase):
tuple(params),
)
# Handle current item separately - check ownership if user_id is provided
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
if user_id is None or current_queue_item.user_id == user_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
def cancel_by_destination(
self, queue_id: str, destination: str, user_id: Optional[str] = None
) -> CancelByDestinationResult:
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
# Build WHERE clause with optional user_id filter
user_filter = "AND user_id = ?" if user_id is not None else ""
where = f"""--sql
WHERE
queue_id == ?
AND destination == ?
@@ -419,15 +465,19 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
{user_filter}
"""
params = (queue_id, destination)
params = [queue_id, destination]
if user_id is not None:
params.append(user_id)
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
params,
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
@@ -436,55 +486,78 @@ class SqliteSessionQueue(SessionQueueBase):
SET status = 'canceled'
{where};
""",
params,
tuple(params),
)
# Handle current item separately - check ownership if user_id is provided
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
if user_id is None or current_queue_item.user_id == user_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
def delete_by_destination(
self, queue_id: str, destination: str, user_id: Optional[str] = None
) -> DeleteByDestinationResult:
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
# Handle current item separately - check ownership if user_id is provided
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
params = (queue_id, destination)
if user_id is None or current_queue_item.user_id == user_id:
self.cancel_queue_item(current_queue_item.item_id)
# Build WHERE clause with optional user_id filter
user_filter = "AND user_id = ?" if user_id is not None else ""
params = [queue_id, destination]
if user_id is not None:
params.append(user_id)
cursor.execute(
"""--sql
f"""--sql
SELECT COUNT(*)
FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
queue_id == ?
AND destination == ?
{user_filter}
""",
params,
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
"""--sql
DELETE
FROM session_queue
f"""--sql
DELETE FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
queue_id == ?
AND destination == ?
{user_filter}
""",
params,
tuple(params),
)
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
def delete_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> DeleteAllExceptCurrentResult:
with self._db.transaction() as cursor:
where = """--sql
# Build WHERE clause with optional user_id filter
user_filter = "AND user_id = ?" if user_id is not None else ""
where = f"""--sql
WHERE
queue_id == ?
AND status == 'pending'
{user_filter}
"""
params = [queue_id]
if user_id is not None:
params.append(user_id)
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
@@ -493,7 +566,7 @@ class SqliteSessionQueue(SessionQueueBase):
FROM session_queue
{where};
""",
(queue_id,),
tuple(params),
)
return DeleteAllExceptCurrentResult(deleted=count)
@@ -532,20 +605,27 @@ class SqliteSessionQueue(SessionQueueBase):
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
def cancel_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> CancelAllExceptCurrentResult:
with self._db.transaction() as cursor:
where = """--sql
# Build WHERE clause with optional user_id filter
user_filter = "AND user_id = ?" if user_id is not None else ""
where = f"""--sql
WHERE
queue_id == ?
AND status == 'pending'
{user_filter}
"""
params = [queue_id]
if user_id is not None:
params.append(user_id)
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
@@ -554,7 +634,7 @@ class SqliteSessionQueue(SessionQueueBase):
SET status = 'canceled'
{where};
""",
(queue_id,),
tuple(params),
)
return CancelAllExceptCurrentResult(canceled=count)
@@ -562,9 +642,13 @@ class SqliteSessionQueue(SessionQueueBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
SELECT
sq.*,
u.display_name as user_display_name,
u.email as user_email
FROM session_queue sq
LEFT JOIN users u ON sq.user_id = u.user_id
WHERE sq.item_id = ?
""",
(item_id,),
)
@@ -650,22 +734,26 @@ class SqliteSessionQueue(SessionQueueBase):
"""Gets all queue items that match the given parameters"""
with self._db.transaction() as cursor:
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
SELECT
sq.*,
u.display_name as user_display_name,
u.email as user_email
FROM session_queue sq
LEFT JOIN users u ON sq.user_id = u.user_id
WHERE sq.queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if destination is not None:
query += """---sql
AND destination = ?
AND sq.destination = ?
"""
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
sq.priority DESC,
sq.item_id ASC
;
"""
cursor.execute(query, params)
@@ -693,8 +781,9 @@ class SqliteSessionQueue(SessionQueueBase):
return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids))
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
with self._db.transaction() as cursor:
# Get total counts
cursor.execute(
"""--sql
SELECT status, count(*)
@@ -706,9 +795,32 @@ class SqliteSessionQueue(SessionQueueBase):
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
# Get user-specific counts if user_id is provided (using a single query with CASE)
user_counts_result = []
if user_id is not None:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ? AND user_id = ?
GROUP BY status
""",
(queue_id, user_id),
)
user_counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
# Process user-specific counts if available
user_pending = None
user_in_progress = None
if user_id is not None:
user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result}
user_pending = user_counts.get("pending", 0)
user_in_progress = user_counts.get("in_progress", 0)
return SessionQueueStatus(
queue_id=queue_id,
item_id=current_item.item_id if current_item else None,
@@ -720,6 +832,8 @@ class SqliteSessionQueue(SessionQueueBase):
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
user_pending=user_pending,
user_in_progress=user_in_progress,
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
@@ -822,6 +936,7 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item.origin,
queue_item.destination,
retried_from_item_id,
queue_item.user_id,
)
values_to_insert.append(value_to_insert)
@@ -829,8 +944,8 @@ class SqliteSessionQueue(SessionQueueBase):
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id, user_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)

View File

@@ -72,7 +72,7 @@ class InvocationContextInterface:
class BoardsInterface(InvocationContextInterface):
def create(self, board_name: str) -> BoardDTO:
"""Creates a board.
"""Creates a board for the current user.
Args:
board_name: The name of the board to create.
@@ -80,7 +80,8 @@ class BoardsInterface(InvocationContextInterface):
Returns:
The created board DTO.
"""
return self._services.boards.create(board_name)
user_id = self._data.queue_item.user_id
return self._services.boards.create(board_name, user_id)
def get_dto(self, board_id: str) -> BoardDTO:
"""Gets a board DTO.
@@ -94,13 +95,14 @@ class BoardsInterface(InvocationContextInterface):
return self._services.boards.get_dto(board_id)
def get_all(self) -> list[BoardDTO]:
"""Gets all boards.
"""Gets all boards accessible to the current user.
Returns:
A list of all boards.
A list of all boards accessible to the current user.
"""
user_id = self._data.queue_item.user_id
return self._services.boards.get_all(
order_by=BoardRecordOrderBy.CreatedAt, direction=SQLiteDirection.Descending
user_id, order_by=BoardRecordOrderBy.CreatedAt, direction=SQLiteDirection.Descending
)
def add_image_to_board(self, board_id: str, image_name: str) -> None:
@@ -228,6 +230,7 @@ class ImagesInterface(InvocationContextInterface):
graph=graph_,
session_id=self._data.queue_item.session_id,
node_id=self._data.invocation.id,
user_id=self._data.queue_item.user_id,
)
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:

View File

@@ -29,6 +29,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_27 import build_migration_27
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -75,6 +76,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_24(app_config=config, logger=logger))
migrator.register_migration(build_migration_25(app_config=config, logger=logger))
migrator.register_migration(build_migration_26(app_config=config, logger=logger))
migrator.register_migration(build_migration_27())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,366 @@
"""Migration 27: Add multi-user support, per-user client state, and app settings.
This migration adds the database schema for multi-user support, including:
- users table for user accounts
- user_sessions table for session management
- user_invitations table for invitation system
- shared_boards table for board sharing
- Adding user_id columns to existing tables for data ownership
- Restructuring client_state table to support per-user storage
- app_settings table for storing JWT secret and other app-level settings
"""
import json
import secrets
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration27Callback:
"""Migration to add multi-user support, per-user client state, and app settings."""
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_users_table(cursor)
self._create_user_sessions_table(cursor)
self._create_user_invitations_table(cursor)
self._create_shared_boards_table(cursor)
self._update_boards_table(cursor)
self._update_images_table(cursor)
self._update_workflows_table(cursor)
self._update_session_queue_table(cursor)
self._update_style_presets_table(cursor)
self._create_system_user(cursor)
self._update_client_state_table(cursor)
self._create_app_settings_table(cursor)
self._generate_jwt_secret(cursor)
def _create_users_table(self, cursor: sqlite3.Cursor) -> None:
"""Create users table."""
cursor.execute("""
CREATE TABLE IF NOT EXISTS users (
user_id TEXT NOT NULL PRIMARY KEY,
email TEXT NOT NULL UNIQUE,
display_name TEXT,
password_hash TEXT NOT NULL,
is_admin BOOLEAN NOT NULL DEFAULT FALSE,
is_active BOOLEAN NOT NULL DEFAULT TRUE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
last_login_at DATETIME
);
""")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_is_admin ON users(is_admin);")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_is_active ON users(is_active);")
cursor.execute("""
CREATE TRIGGER IF NOT EXISTS tg_users_updated_at
AFTER UPDATE ON users FOR EACH ROW
BEGIN
UPDATE users SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE user_id = old.user_id;
END;
""")
def _create_user_sessions_table(self, cursor: sqlite3.Cursor) -> None:
"""Create user_sessions table for session management."""
cursor.execute("""
CREATE TABLE IF NOT EXISTS user_sessions (
session_id TEXT NOT NULL PRIMARY KEY,
user_id TEXT NOT NULL,
token_hash TEXT NOT NULL,
expires_at DATETIME NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
last_activity_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
""")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id);")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_sessions_token_hash ON user_sessions(token_hash);")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_sessions_expires_at ON user_sessions(expires_at);")
def _create_user_invitations_table(self, cursor: sqlite3.Cursor) -> None:
"""Create user_invitations table for invitation system."""
cursor.execute("""
CREATE TABLE IF NOT EXISTS user_invitations (
invitation_id TEXT NOT NULL PRIMARY KEY,
email TEXT NOT NULL,
invited_by TEXT NOT NULL,
invitation_code TEXT NOT NULL UNIQUE,
is_admin BOOLEAN NOT NULL DEFAULT FALSE,
expires_at DATETIME NOT NULL,
used_at DATETIME,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
FOREIGN KEY (invited_by) REFERENCES users(user_id) ON DELETE CASCADE
);
""")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_invitations_email ON user_invitations(email);")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_user_invitations_invitation_code ON user_invitations(invitation_code);"
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_invitations_expires_at ON user_invitations(expires_at);")
def _create_shared_boards_table(self, cursor: sqlite3.Cursor) -> None:
"""Create shared_boards table for board sharing."""
cursor.execute("""
CREATE TABLE IF NOT EXISTS shared_boards (
board_id TEXT NOT NULL,
user_id TEXT NOT NULL,
can_edit BOOLEAN NOT NULL DEFAULT FALSE,
shared_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
PRIMARY KEY (board_id, user_id),
FOREIGN KEY (board_id) REFERENCES boards(board_id) ON DELETE CASCADE,
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
""")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_shared_boards_user_id ON shared_boards(user_id);")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_shared_boards_board_id ON shared_boards(board_id);")
def _update_boards_table(self, cursor: sqlite3.Cursor) -> None:
"""Add user_id and is_public columns to boards table."""
# Check if boards table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='boards';")
if cursor.fetchone() is None:
return
# Check if user_id column exists
cursor.execute("PRAGMA table_info(boards);")
columns = [row[1] for row in cursor.fetchall()]
if "user_id" not in columns:
cursor.execute("ALTER TABLE boards ADD COLUMN user_id TEXT DEFAULT 'system';")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_boards_user_id ON boards(user_id);")
if "is_public" not in columns:
cursor.execute("ALTER TABLE boards ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_boards_is_public ON boards(is_public);")
def _update_images_table(self, cursor: sqlite3.Cursor) -> None:
"""Add user_id column to images table."""
# Check if images table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='images';")
if cursor.fetchone() is None:
return
cursor.execute("PRAGMA table_info(images);")
columns = [row[1] for row in cursor.fetchall()]
if "user_id" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN user_id TEXT DEFAULT 'system';")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_images_user_id ON images(user_id);")
def _update_workflows_table(self, cursor: sqlite3.Cursor) -> None:
"""Add user_id and is_public columns to workflows table."""
# Check if workflows table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='workflows';")
if cursor.fetchone() is None:
return
cursor.execute("PRAGMA table_info(workflows);")
columns = [row[1] for row in cursor.fetchall()]
if "user_id" not in columns:
cursor.execute("ALTER TABLE workflows ADD COLUMN user_id TEXT DEFAULT 'system';")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflows_user_id ON workflows(user_id);")
if "is_public" not in columns:
cursor.execute("ALTER TABLE workflows ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflows_is_public ON workflows(is_public);")
def _update_session_queue_table(self, cursor: sqlite3.Cursor) -> None:
"""Add user_id column to session_queue table."""
# Check if session_queue table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='session_queue';")
if cursor.fetchone() is None:
return
cursor.execute("PRAGMA table_info(session_queue);")
columns = [row[1] for row in cursor.fetchall()]
if "user_id" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN user_id TEXT DEFAULT 'system';")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_session_queue_user_id ON session_queue(user_id);")
def _update_style_presets_table(self, cursor: sqlite3.Cursor) -> None:
"""Add user_id and is_public columns to style_presets table."""
# Check if style_presets table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='style_presets';")
if cursor.fetchone() is None:
return
cursor.execute("PRAGMA table_info(style_presets);")
columns = [row[1] for row in cursor.fetchall()]
if "user_id" not in columns:
cursor.execute("ALTER TABLE style_presets ADD COLUMN user_id TEXT DEFAULT 'system';")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_style_presets_user_id ON style_presets(user_id);")
if "is_public" not in columns:
cursor.execute("ALTER TABLE style_presets ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_style_presets_is_public ON style_presets(is_public);")
def _create_system_user(self, cursor: sqlite3.Cursor) -> None:
"""Create system user for backward compatibility.
The system user is NOT an admin - it's just used to own existing data
from before multi-user support was added. Real admin users should be
created through the /auth/setup endpoint.
"""
cursor.execute("""
INSERT OR IGNORE INTO users (user_id, email, display_name, password_hash, is_admin, is_active)
VALUES ('system', 'system@system.invokeai', 'System', '', FALSE, TRUE);
""")
def _update_client_state_table(self, cursor: sqlite3.Cursor) -> None:
"""Restructure client_state table to support per-user storage."""
# Check if client_state table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='client_state';")
if cursor.fetchone() is None:
# Table doesn't exist, create it with the new schema
cursor.execute(
"""
CREATE TABLE client_state (
user_id TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP),
PRIMARY KEY (user_id, key),
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);")
cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE user_id = OLD.user_id AND key = OLD.key;
END;
"""
)
return
# Table exists with old schema - migrate it
# Get existing data if the data column is present (it may be absent if an older
# version of migration 21 was deployed without the column)
cursor.execute("PRAGMA table_info(client_state);")
columns = [row[1] for row in cursor.fetchall()]
existing_data = {}
if "data" in columns:
cursor.execute("SELECT data FROM client_state WHERE id = 1;")
row = cursor.fetchone()
if row is not None:
try:
existing_data = json.loads(row[0])
except (json.JSONDecodeError, TypeError):
# If data is corrupt, just start fresh
pass
# Drop the old table
cursor.execute("DROP TABLE IF EXISTS client_state;")
# Create new table with per-user schema
cursor.execute(
"""
CREATE TABLE client_state (
user_id TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP),
PRIMARY KEY (user_id, key),
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
);
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_client_state_user_id ON client_state(user_id);")
cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE user_id = OLD.user_id AND key = OLD.key;
END;
"""
)
# Migrate existing data to 'system' user
for key, value in existing_data.items():
cursor.execute(
"""
INSERT INTO client_state (user_id, key, value)
VALUES ('system', ?, ?);
""",
(key, value),
)
def _create_app_settings_table(self, cursor: sqlite3.Cursor) -> None:
"""Create app_settings table for storing application-level configuration."""
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS app_settings (
key TEXT NOT NULL PRIMARY KEY,
value TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute(
"""
CREATE TRIGGER IF NOT EXISTS tg_app_settings_updated_at
AFTER UPDATE ON app_settings
FOR EACH ROW
BEGIN
UPDATE app_settings SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE key = OLD.key;
END;
"""
)
def _generate_jwt_secret(self, cursor: sqlite3.Cursor) -> None:
"""Generate and store a cryptographically secure JWT secret key.
The secret is a 64-character hexadecimal string (256 bits of entropy),
which is suitable for HS256 JWT signing.
"""
# Check if JWT secret already exists
cursor.execute("SELECT value FROM app_settings WHERE key = 'jwt_secret';")
existing_secret = cursor.fetchone()
if existing_secret is None:
# Generate a new cryptographically secure secret (256 bits)
jwt_secret = secrets.token_hex(32) # 32 bytes = 256 bits = 64 hex characters
# Store in database
cursor.execute(
"INSERT INTO app_settings (key, value) VALUES ('jwt_secret', ?);",
(jwt_secret,),
)
def build_migration_27() -> Migration:
"""Builds the migration object for migrating from version 26 to version 27.
This migration adds multi-user support, per-user client state, and app settings
(including a JWT secret) to the database schema.
"""
return Migration(
from_version=26,
to_version=27,
callback=Migration27Callback(),
)

View File

@@ -0,0 +1 @@
"""User service module."""

View File

@@ -0,0 +1,135 @@
"""Abstract base class for user service."""
from abc import ABC, abstractmethod
from invokeai.app.services.users.users_common import UserCreateRequest, UserDTO, UserUpdateRequest
class UserServiceBase(ABC):
"""High-level service for user management."""
@abstractmethod
def create(self, user_data: UserCreateRequest) -> UserDTO:
"""Create a new user.
Args:
user_data: User creation data
Returns:
The created user
Raises:
ValueError: If email already exists or password is weak
"""
pass
@abstractmethod
def get(self, user_id: str) -> UserDTO | None:
"""Get user by ID.
Args:
user_id: The user ID
Returns:
UserDTO if found, None otherwise
"""
pass
@abstractmethod
def get_by_email(self, email: str) -> UserDTO | None:
"""Get user by email.
Args:
email: The email address
Returns:
UserDTO if found, None otherwise
"""
pass
@abstractmethod
def update(self, user_id: str, changes: UserUpdateRequest) -> UserDTO:
"""Update user.
Args:
user_id: The user ID
changes: Fields to update
Returns:
The updated user
Raises:
ValueError: If user not found or password is weak
"""
pass
@abstractmethod
def delete(self, user_id: str) -> None:
"""Delete user.
Args:
user_id: The user ID
Raises:
ValueError: If user not found
"""
pass
@abstractmethod
def authenticate(self, email: str, password: str) -> UserDTO | None:
"""Authenticate user credentials.
Args:
email: User email
password: User password
Returns:
UserDTO if authentication successful, None otherwise
"""
pass
@abstractmethod
def has_admin(self) -> bool:
"""Check if any admin user exists.
Returns:
True if at least one admin user exists, False otherwise
"""
pass
@abstractmethod
def create_admin(self, user_data: UserCreateRequest) -> UserDTO:
"""Create an admin user (for initial setup).
Args:
user_data: User creation data
Returns:
The created admin user
Raises:
ValueError: If admin already exists or password is weak
"""
pass
@abstractmethod
def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
"""List all users.
Args:
limit: Maximum number of users to return
offset: Number of users to skip
Returns:
List of users
"""
pass
@abstractmethod
def count_admins(self) -> int:
"""Count active admin users.
Returns:
The number of active admin users
"""
pass

View File

@@ -0,0 +1,114 @@
"""Common types and data models for user service."""
from datetime import datetime
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticCustomError
def validate_email_with_special_domains(email: str) -> str:
"""Validate email address, allowing special-use domains like .local for testing.
This validator first tries standard email validation using email-validator library.
If it fails due to special-use domains (like .local, .test, .localhost), it performs
a basic syntax check instead. This allows development/testing with non-routable domains
while still catching actual typos and malformed emails.
Args:
email: The email address to validate
Returns:
The validated email address (lowercased)
Raises:
PydanticCustomError: If the email format is invalid
"""
try:
# Try standard email validation using email-validator
from email_validator import EmailNotValidError, validate_email
result = validate_email(email, check_deliverability=False)
return result.normalized
except EmailNotValidError as e:
error_msg = str(e)
# Check if the error is specifically about special-use/reserved domains or localhost
if (
"special-use" in error_msg.lower()
or "reserved" in error_msg.lower()
or "should have a period" in error_msg.lower()
):
# Perform basic email syntax validation
email = email.strip().lower()
if "@" not in email:
raise PydanticCustomError(
"value_error",
"Email address must contain an @ symbol",
)
local_part, domain = email.rsplit("@", 1)
if not local_part or not domain:
raise PydanticCustomError(
"value_error",
"Email address must have both local and domain parts",
)
# Allow localhost and domains with dots
if domain == "localhost" or "." in domain:
return email
raise PydanticCustomError(
"value_error",
"Email domain must contain a dot or be 'localhost'",
)
else:
# Re-raise other validation errors
raise PydanticCustomError(
"value_error",
f"Invalid email address: {error_msg}",
)
class UserDTO(BaseModel):
"""User data transfer object."""
user_id: str = Field(description="Unique user identifier")
email: str = Field(description="User email address")
display_name: str | None = Field(default=None, description="Display name")
is_admin: bool = Field(default=False, description="Whether user has admin privileges")
is_active: bool = Field(default=True, description="Whether user account is active")
created_at: datetime = Field(description="When the user was created")
updated_at: datetime = Field(description="When the user was last updated")
last_login_at: datetime | None = Field(default=None, description="When user last logged in")
@field_validator("email")
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email address, allowing special-use domains."""
return validate_email_with_special_domains(v)
class UserCreateRequest(BaseModel):
"""Request to create a new user."""
email: str = Field(description="User email address")
display_name: str | None = Field(default=None, description="Display name")
password: str = Field(description="User password")
is_admin: bool = Field(default=False, description="Whether user should have admin privileges")
@field_validator("email")
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email address, allowing special-use domains."""
return validate_email_with_special_domains(v)
class UserUpdateRequest(BaseModel):
"""Request to update a user."""
display_name: str | None = Field(default=None, description="Display name")
password: str | None = Field(default=None, description="New password")
is_admin: bool | None = Field(default=None, description="Whether user should have admin privileges")
is_active: bool | None = Field(default=None, description="Whether user account should be active")

View File

@@ -0,0 +1,258 @@
"""Default SQLite implementation of user service."""
import sqlite3
from datetime import datetime, timezone
from uuid import uuid4
from invokeai.app.services.auth.password_utils import hash_password, validate_password_strength, verify_password
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_base import UserServiceBase
from invokeai.app.services.users.users_common import UserCreateRequest, UserDTO, UserUpdateRequest
class UserService(UserServiceBase):
"""SQLite-based user service."""
def __init__(self, db: SqliteDatabase):
"""Initialize user service.
Args:
db: SQLite database instance
"""
self._db = db
def create(self, user_data: UserCreateRequest) -> UserDTO:
"""Create a new user."""
# Validate password strength
is_valid, error_msg = validate_password_strength(user_data.password)
if not is_valid:
raise ValueError(error_msg)
# Check if email already exists
if self.get_by_email(user_data.email) is not None:
raise ValueError(f"User with email {user_data.email} already exists")
user_id = str(uuid4())
password_hash = hash_password(user_data.password)
with self._db.transaction() as cursor:
try:
cursor.execute(
"""
INSERT INTO users (user_id, email, display_name, password_hash, is_admin)
VALUES (?, ?, ?, ?, ?)
""",
(user_id, user_data.email, user_data.display_name, password_hash, user_data.is_admin),
)
except sqlite3.IntegrityError as e:
raise ValueError(f"Failed to create user: {e}") from e
user = self.get(user_id)
if user is None:
raise RuntimeError("Failed to retrieve created user")
return user
def get(self, user_id: str) -> UserDTO | None:
"""Get user by ID."""
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT user_id, email, display_name, is_admin, is_active, created_at, updated_at, last_login_at
FROM users
WHERE user_id = ?
""",
(user_id,),
)
row = cursor.fetchone()
if row is None:
return None
return UserDTO(
user_id=row[0],
email=row[1],
display_name=row[2],
is_admin=bool(row[3]),
is_active=bool(row[4]),
created_at=datetime.fromisoformat(row[5]),
updated_at=datetime.fromisoformat(row[6]),
last_login_at=datetime.fromisoformat(row[7]) if row[7] else None,
)
def get_by_email(self, email: str) -> UserDTO | None:
"""Get user by email."""
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT user_id, email, display_name, is_admin, is_active, created_at, updated_at, last_login_at
FROM users
WHERE email = ?
""",
(email,),
)
row = cursor.fetchone()
if row is None:
return None
return UserDTO(
user_id=row[0],
email=row[1],
display_name=row[2],
is_admin=bool(row[3]),
is_active=bool(row[4]),
created_at=datetime.fromisoformat(row[5]),
updated_at=datetime.fromisoformat(row[6]),
last_login_at=datetime.fromisoformat(row[7]) if row[7] else None,
)
def update(self, user_id: str, changes: UserUpdateRequest) -> UserDTO:
"""Update user."""
# Check if user exists
user = self.get(user_id)
if user is None:
raise ValueError(f"User {user_id} not found")
# Validate password if provided
if changes.password is not None:
is_valid, error_msg = validate_password_strength(changes.password)
if not is_valid:
raise ValueError(error_msg)
# Build update query dynamically based on provided fields
updates: list[str] = []
params: list[str | bool | int] = []
if changes.display_name is not None:
updates.append("display_name = ?")
params.append(changes.display_name)
if changes.password is not None:
updates.append("password_hash = ?")
params.append(hash_password(changes.password))
if changes.is_admin is not None:
updates.append("is_admin = ?")
params.append(changes.is_admin)
if changes.is_active is not None:
updates.append("is_active = ?")
params.append(changes.is_active)
if not updates:
return user
params.append(user_id)
query = f"UPDATE users SET {', '.join(updates)} WHERE user_id = ?"
with self._db.transaction() as cursor:
cursor.execute(query, params)
updated_user = self.get(user_id)
if updated_user is None:
raise RuntimeError("Failed to retrieve updated user")
return updated_user
def delete(self, user_id: str) -> None:
"""Delete user."""
user = self.get(user_id)
if user is None:
raise ValueError(f"User {user_id} not found")
with self._db.transaction() as cursor:
cursor.execute("DELETE FROM users WHERE user_id = ?", (user_id,))
def authenticate(self, email: str, password: str) -> UserDTO | None:
"""Authenticate user credentials."""
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT user_id, email, display_name, password_hash, is_admin, is_active, created_at, updated_at, last_login_at
FROM users
WHERE email = ?
""",
(email,),
)
row = cursor.fetchone()
if row is None:
return None
password_hash = row[3]
if not verify_password(password, password_hash):
return None
# Update last login time
with self._db.transaction() as cursor:
cursor.execute(
"UPDATE users SET last_login_at = ? WHERE user_id = ?",
(datetime.now(timezone.utc).isoformat(), row[0]),
)
return UserDTO(
user_id=row[0],
email=row[1],
display_name=row[2],
is_admin=bool(row[4]),
is_active=bool(row[5]),
created_at=datetime.fromisoformat(row[6]),
updated_at=datetime.fromisoformat(row[7]),
last_login_at=datetime.now(timezone.utc),
)
def has_admin(self) -> bool:
"""Check if any admin user exists."""
with self._db.transaction() as cursor:
cursor.execute("SELECT COUNT(*) FROM users WHERE is_admin = TRUE AND is_active = TRUE")
row = cursor.fetchone()
count = row[0] if row else 0
return bool(count > 0)
def create_admin(self, user_data: UserCreateRequest) -> UserDTO:
"""Create an admin user (for initial setup)."""
if self.has_admin():
raise ValueError("Admin user already exists")
# Force is_admin to True
admin_data = UserCreateRequest(
email=user_data.email,
display_name=user_data.display_name,
password=user_data.password,
is_admin=True,
)
return self.create(admin_data)
def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
"""List all users."""
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT user_id, email, display_name, is_admin, is_active, created_at, updated_at, last_login_at
FROM users
ORDER BY created_at DESC
LIMIT ? OFFSET ?
""",
(limit, offset),
)
rows = cursor.fetchall()
return [
UserDTO(
user_id=row[0],
email=row[1],
display_name=row[2],
is_admin=bool(row[3]),
is_active=bool(row[4]),
created_at=datetime.fromisoformat(row[5]),
updated_at=datetime.fromisoformat(row[6]),
last_login_at=datetime.fromisoformat(row[7]) if row[7] else None,
)
for row in rows
]
def count_admins(self) -> int:
"""Count active admin users."""
with self._db.transaction() as cursor:
cursor.execute("SELECT COUNT(*) FROM users WHERE is_admin = TRUE AND is_active = TRUE")
row = cursor.fetchone()
return int(row[0]) if row else 0

View File

@@ -0,0 +1,579 @@
"""User management command entry points for InvokeAI.
These functions are registered as console scripts in pyproject.toml and can be
called from the command line after installing the package:
invoke-useradd -- add a user
invoke-userdel -- delete a user
invoke-userlist -- list users
invoke-usermod -- modify a user
"""
import argparse
import getpass
import json
import os
import sys
_root_help = (
"Path to the InvokeAI root directory. If omitted, the root is resolved in this order: "
"the $INVOKEAI_ROOT environment variable, the active virtual environment's parent directory, "
"or $HOME/invokeai."
)
# ---------------------------------------------------------------------------
# useradd
# ---------------------------------------------------------------------------
def _add_user_interactive() -> bool:
"""Add a user interactively by prompting for details."""
from invokeai.app.services.auth.password_utils import validate_password_strength
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_common import UserCreateRequest
from invokeai.app.services.users.users_default import UserService
from invokeai.backend.util.logging import InvokeAILogger
print("=== Add InvokeAI User ===\n")
email = input("Email address: ").strip()
if not email:
print("Error: Email is required")
return False
display_name = input("Display name (optional): ").strip() or None
while True:
password = getpass.getpass("Password: ")
password_confirm = getpass.getpass("Confirm password: ")
if password != password_confirm:
print("Error: Passwords do not match. Please try again.\n")
continue
is_valid, error_msg = validate_password_strength(password)
if not is_valid:
print(f"Error: {error_msg}\n")
continue
break
is_admin_input = input("Make this user an administrator? (y/N): ").strip().lower()
is_admin = is_admin_input in ("y", "yes")
try:
config = get_config()
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
user_service = UserService(db)
user_data = UserCreateRequest(email=email, display_name=display_name, password=password, is_admin=is_admin)
user = user_service.create(user_data)
print("\n✅ User created successfully!")
print(f" User ID: {user.user_id}")
print(f" Email: {user.email}")
print(f" Display Name: {user.display_name or '(not set)'}")
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
print(f" Active: {'Yes' if user.is_active else 'No'}")
return True
except ValueError as e:
print(f"\n❌ Error: {e}")
return False
except Exception as e:
print(f"\n❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
return False
def _add_user_cli(email: str, password: str, display_name: str | None = None, is_admin: bool = False) -> bool:
"""Add a user via CLI arguments."""
from invokeai.app.services.auth.password_utils import validate_password_strength
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_common import UserCreateRequest
from invokeai.app.services.users.users_default import UserService
from invokeai.backend.util.logging import InvokeAILogger
is_valid, error_msg = validate_password_strength(password)
if not is_valid:
print(f"❌ Password validation failed: {error_msg}")
return False
try:
config = get_config()
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
user_service = UserService(db)
user_data = UserCreateRequest(email=email, display_name=display_name, password=password, is_admin=is_admin)
user = user_service.create(user_data)
print("✅ User created successfully!")
print(f" User ID: {user.user_id}")
print(f" Email: {user.email}")
print(f" Display Name: {user.display_name or '(not set)'}")
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
print(f" Active: {'Yes' if user.is_active else 'No'}")
return True
except ValueError as e:
print(f"❌ Error: {e}")
return False
except Exception as e:
print(f"❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
return False
def useradd() -> None:
"""Entry point for ``invoke-useradd``."""
parser = argparse.ArgumentParser(
description="Add a user to the InvokeAI database",
epilog="If no arguments are provided, the script will run in interactive mode.",
)
parser.add_argument("--root", "-r", help=_root_help)
parser.add_argument("--email", "-e", help="User email address")
parser.add_argument("--password", "-p", help="User password")
parser.add_argument("--name", "-n", help="User display name (optional)")
parser.add_argument("--admin", "-a", action="store_true", help="Make user an administrator")
args = parser.parse_args()
if args.root:
os.environ["INVOKEAI_ROOT"] = args.root
if args.email or args.password:
if not args.email or not args.password:
print("❌ Error: Both --email and --password are required when using CLI mode")
print(" Run without arguments for interactive mode")
sys.exit(1)
success = _add_user_cli(args.email, args.password, args.name, args.admin)
else:
success = _add_user_interactive()
sys.exit(0 if success else 1)
# ---------------------------------------------------------------------------
# userdel
# ---------------------------------------------------------------------------
def _delete_user_interactive() -> bool:
"""Delete a user interactively by prompting for email."""
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_default import UserService
from invokeai.backend.util.logging import InvokeAILogger
print("=== Delete InvokeAI User ===\n")
email = input("Email address of user to delete: ").strip()
if not email:
print("Error: Email is required")
return False
try:
config = get_config()
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
user_service = UserService(db)
user = user_service.get_by_email(email)
if not user:
print(f"\n❌ Error: No user found with email '{email}'")
return False
print("\nUser to delete:")
print(f" User ID: {user.user_id}")
print(f" Email: {user.email}")
print(f" Display Name: {user.display_name or '(not set)'}")
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
print(f" Active: {'Yes' if user.is_active else 'No'}")
confirm = input("\n⚠️ Are you sure you want to delete this user? (yes/no): ").strip().lower()
if confirm not in ("yes", "y"):
print("Deletion cancelled.")
return False
user_service.delete(user.user_id)
print("\n✅ User deleted successfully!")
return True
except ValueError as e:
print(f"\n❌ Error: {e}")
return False
except Exception as e:
print(f"\n❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
return False
def _delete_user_cli(email: str, force: bool = False) -> bool:
"""Delete a user via CLI arguments."""
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_default import UserService
from invokeai.backend.util.logging import InvokeAILogger
try:
config = get_config()
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
user_service = UserService(db)
user = user_service.get_by_email(email)
if not user:
print(f"❌ Error: No user found with email '{email}'")
return False
if not force:
print("User to delete:")
print(f" User ID: {user.user_id}")
print(f" Email: {user.email}")
print(f" Display Name: {user.display_name or '(not set)'}")
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
print(f" Active: {'Yes' if user.is_active else 'No'}")
confirm = input("\n⚠️ Are you sure you want to delete this user? (yes/no): ").strip().lower()
if confirm not in ("yes", "y"):
print("Deletion cancelled.")
return False
user_service.delete(user.user_id)
print("✅ User deleted successfully!")
return True
except ValueError as e:
print(f"❌ Error: {e}")
return False
except Exception as e:
print(f"❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
return False
def userdel() -> None:
"""Entry point for ``invoke-userdel``."""
parser = argparse.ArgumentParser(
description="Delete a user from the InvokeAI database",
epilog="If no arguments are provided, the script will run in interactive mode.",
)
parser.add_argument("--root", "-r", help=_root_help)
parser.add_argument("--email", "-e", help="User email address")
parser.add_argument("--force", "-f", action="store_true", help="Delete without confirmation prompt")
args = parser.parse_args()
if args.root:
os.environ["INVOKEAI_ROOT"] = args.root
if args.email:
success = _delete_user_cli(args.email, args.force)
else:
success = _delete_user_interactive()
sys.exit(0 if success else 1)
# ---------------------------------------------------------------------------
# userlist
# ---------------------------------------------------------------------------
def _list_users_table() -> bool:
"""List all users in a formatted table."""
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_default import UserService
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config.db_path, logger)
user_service = UserService(db)
try:
users = user_service.list_users()
if not users:
print("No users found in database.")
return True
print("\n=== InvokeAI Users ===\n")
print(f"{'User ID':<36} {'Email':<30} {'Display Name':<20} {'Admin':<8} {'Active':<8}")
print("-" * 108)
for user in users:
user_id = user.user_id
email = user.email[:29] if len(user.email) > 29 else user.email
raw_name = user.display_name or ""
name = raw_name[:19] if len(raw_name) > 19 else raw_name
is_admin = "Yes" if user.is_admin else "No"
is_active = "Yes" if user.is_active else "No"
print(f"{user_id:<36} {email:<30} {name:<20} {is_admin:<8} {is_active:<8}")
print(f"\nTotal users: {len(users)}")
return True
except Exception as e:
print(f"Error listing users: {e}")
return False
def _list_users_json() -> bool:
"""List all users in JSON format."""
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_default import UserService
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config.db_path, logger)
user_service = UserService(db)
try:
users = user_service.list_users()
users_data = [
{
"id": user.user_id,
"email": user.email,
"name": user.display_name,
"is_admin": user.is_admin,
"is_active": user.is_active,
}
for user in users
]
print(json.dumps(users_data, indent=2))
return True
except Exception as e:
print(f'{{"error": "{e}"}}', file=sys.stderr)
return False
def userlist() -> None:
"""Entry point for ``invoke-userlist``."""
parser = argparse.ArgumentParser(
description="List users from the InvokeAI database",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
invoke-userlist
invoke-userlist --json
""",
)
parser.add_argument("--root", "-r", help=_root_help)
parser.add_argument(
"--json",
action="store_true",
help="Output users in JSON format instead of table",
)
args = parser.parse_args()
if args.root:
os.environ["INVOKEAI_ROOT"] = args.root
success = _list_users_json() if args.json else _list_users_table()
sys.exit(0 if success else 1)
# ---------------------------------------------------------------------------
# usermod
# ---------------------------------------------------------------------------
def _modify_user_interactive() -> bool:
"""Modify a user interactively by prompting for details."""
from invokeai.app.services.auth.password_utils import validate_password_strength
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_common import UserUpdateRequest
from invokeai.app.services.users.users_default import UserService
from invokeai.backend.util.logging import InvokeAILogger
print("=== Modify InvokeAI User ===\n")
email = input("Email address of user to modify: ").strip()
if not email:
print("Error: Email is required")
return False
try:
config = get_config()
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
user_service = UserService(db)
user = user_service.get_by_email(email)
if not user:
print(f"\n❌ Error: No user found with email '{email}'")
return False
print("\nCurrent user details:")
print(f" User ID: {user.user_id}")
print(f" Email: {user.email}")
print(f" Display Name: {user.display_name or '(not set)'}")
print(f" Admin: {'Yes' if user.is_admin else 'No'}")
print(f" Active: {'Yes' if user.is_active else 'No'}")
print("\n--- What would you like to change? (leave blank to keep current value) ---\n")
new_name = input(f"New display name [{user.display_name or '(not set)'}]: ").strip()
display_name = new_name if new_name else None
change_password = input("Change password? (y/N): ").strip().lower()
password = None
if change_password in ("y", "yes"):
while True:
password = getpass.getpass("New password: ")
if not password:
print("Keeping existing password.")
password = None
break
password_confirm = getpass.getpass("Confirm new password: ")
if password != password_confirm:
print("Error: Passwords do not match. Please try again.\n")
continue
is_valid, error_msg = validate_password_strength(password)
if not is_valid:
print(f"Error: {error_msg}\n")
continue
break
change_admin = input("Change admin status? (y/N): ").strip().lower()
is_admin = None
if change_admin in ("y", "yes"):
is_admin_input = (
input(f"Make administrator? [current: {'Yes' if user.is_admin else 'No'}] (y/N): ").strip().lower()
)
is_admin = is_admin_input in ("y", "yes")
if display_name is None and password is None and is_admin is None:
print("\nNo changes requested. User not modified.")
return True
changes = UserUpdateRequest(display_name=display_name, password=password, is_admin=is_admin)
updated_user = user_service.update(user.user_id, changes)
print("\n✅ User updated successfully!")
print(f" User ID: {updated_user.user_id}")
print(f" Email: {updated_user.email}")
print(f" Display Name: {updated_user.display_name or '(not set)'}")
print(f" Admin: {'Yes' if updated_user.is_admin else 'No'}")
print(f" Active: {'Yes' if updated_user.is_active else 'No'}")
return True
except ValueError as e:
print(f"\n❌ Error: {e}")
return False
except Exception as e:
print(f"\n❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
return False
def _modify_user_cli(
email: str,
display_name: str | None = None,
password: str | None = None,
is_admin: bool | None = None,
) -> bool:
"""Modify a user via CLI arguments."""
from invokeai.app.services.auth.password_utils import validate_password_strength
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_common import UserUpdateRequest
from invokeai.app.services.users.users_default import UserService
from invokeai.backend.util.logging import InvokeAILogger
if password is not None:
is_valid, error_msg = validate_password_strength(password)
if not is_valid:
print(f"❌ Password validation failed: {error_msg}")
return False
try:
config = get_config()
db = SqliteDatabase(config.db_path, InvokeAILogger.get_logger())
user_service = UserService(db)
user = user_service.get_by_email(email)
if not user:
print(f"❌ Error: No user found with email '{email}'")
return False
if display_name is None and password is None and is_admin is None:
print("❌ Error: No changes specified. Use --name, --password, --admin, or --no-admin")
return False
changes = UserUpdateRequest(display_name=display_name, password=password, is_admin=is_admin)
updated_user = user_service.update(user.user_id, changes)
print("✅ User updated successfully!")
print(f" User ID: {updated_user.user_id}")
print(f" Email: {updated_user.email}")
print(f" Display Name: {updated_user.display_name or '(not set)'}")
print(f" Admin: {'Yes' if updated_user.is_admin else 'No'}")
print(f" Active: {'Yes' if updated_user.is_active else 'No'}")
return True
except ValueError as e:
print(f"❌ Error: {e}")
return False
except Exception as e:
print(f"❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
return False
def usermod() -> None:
"""Entry point for ``invoke-usermod``."""
parser = argparse.ArgumentParser(
description="Modify a user in the InvokeAI database",
epilog="If no arguments are provided, the script will run in interactive mode.",
)
parser.add_argument("--root", "-r", help=_root_help)
parser.add_argument("--email", "-e", help="User email address")
parser.add_argument("--name", "-n", help="New display name")
parser.add_argument("--password", "-p", help="New password")
admin_group = parser.add_mutually_exclusive_group()
admin_group.add_argument("--admin", "-a", action="store_true", help="Grant administrator privileges")
admin_group.add_argument("--no-admin", dest="no_admin", action="store_true", help="Remove administrator privileges")
args = parser.parse_args()
if args.root:
os.environ["INVOKEAI_ROOT"] = args.root
is_admin = None
if args.admin:
is_admin = True
elif args.no_admin:
is_admin = False
if args.email:
success = _modify_user_cli(args.email, args.name, args.password, is_admin)
else:
success = _modify_user_interactive()
sys.exit(0 if success else 1)

View File

@@ -239,6 +239,52 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
if in_dim is not None:
return in_dim in _FLUX2_VEC_IN_DIMS
# Kohya format: check transformer block dimensions (hidden_size and MLP ratio).
# This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder).
# Klein 9B has hidden_size=4096 (vs 3072 for FLUX.1 and Klein 4B).
# Klein 4B has same hidden_size as FLUX.1 (3072) but different mlp_ratio (6 vs 4).
kohya_hidden_size: int | None = None
for key in state_dict:
if not isinstance(key, str):
continue
if not key.startswith("lora_unet_"):
continue
# Check img_attn_proj hidden_size
if "_img_attn_proj." in key and key.endswith("lora_down.weight"):
kohya_hidden_size = state_dict[key].shape[1]
if kohya_hidden_size != _FLUX1_HIDDEN_SIZE:
return True
break
# LoKR variant
elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim != _FLUX1_HIDDEN_SIZE:
return True
kohya_hidden_size = in_dim
break
# Kohya format: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
# Klein 4B uses mlp_ratio=6 (ffn_dim=18432), FLUX.1 uses mlp_ratio=4 (ffn_dim=12288).
if kohya_hidden_size == _FLUX1_HIDDEN_SIZE:
for key in state_dict:
if not isinstance(key, str):
continue
if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith("lora_up.weight"):
ffn_dim = state_dict[key].shape[0]
if ffn_dim != kohya_hidden_size * _FLUX1_MLP_RATIO:
return True
break
# LoKR variant
if key.startswith("lora_unet_") and "_img_mlp_0." in key and key.endswith((".lokr_w1", ".lokr_w1_a")):
layer_prefix = key.rsplit(".", 1)[0]
out_dim = _lokr_out_dim(state_dict, layer_prefix)
if out_dim is not None and out_dim != kohya_hidden_size * _FLUX1_MLP_RATIO:
return True
break
return False
@@ -421,6 +467,33 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
return Flux2VariantType.Klein9B
return None
# Kohya format: check transformer block dimensions (hidden_size from img_attn_proj).
# This handles LoRAs that only target transformer blocks (no txt_in/vector_in/context_embedder).
for key in state_dict:
if not isinstance(key, str):
continue
if not key.startswith("lora_unet_"):
continue
# Check img_attn_proj hidden_size
if "_img_attn_proj." in key and key.endswith("lora_down.weight"):
dim = state_dict[key].shape[1]
if dim == KLEIN_4B_HIDDEN_SIZE:
return Flux2VariantType.Klein4B
if dim == KLEIN_9B_HIDDEN_SIZE:
return Flux2VariantType.Klein9B
return None
# LoKR variant
elif "_img_attn_proj." in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim == KLEIN_4B_HIDDEN_SIZE:
return Flux2VariantType.Klein4B
if in_dim == KLEIN_9B_HIDDEN_SIZE:
return Flux2VariantType.Klein9B
return None
return None

View File

@@ -253,16 +253,35 @@ class ZImageCheckpointModel(ModelLoader):
target_device = TorchDevice.choose_torch_device()
model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
# Filter out keys that don't belong to the ZImageTransformer2DModel.
# Merged checkpoints (e.g. LoRA-baked models) may bundle text encoder weights
# (text_encoders.*) or other non-transformer keys alongside the transformer weights.
# Also filter FP8 quantization metadata (scale_weight, scaled_fp8).
valid_prefixes = (
"all_x_embedder.",
"all_final_layer.",
"layers.",
"noise_refiner.",
"context_refiner.",
"t_embedder.",
"cap_embedder.",
"rope_embedder.",
)
valid_exact = {"x_pad_token", "cap_pad_token"}
keys_to_remove = [
k
for k in sd.keys()
if not (k.startswith(valid_prefixes) or k in valid_exact)
or k.endswith(".scale_weight")
or k == "scaled_fp8"
]
for k in keys_to_remove:
del sd[k]
# Handle memory management and dtype conversion
new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
self._ram_cache.make_room(new_sd_size)
# Filter out FP8 scale_weight and scaled_fp8 metadata keys
# These are quantization metadata that shouldn't be loaded into the model
keys_to_remove = [k for k in sd.keys() if k.endswith(".scale_weight") or k == "scaled_fp8"]
for k in keys_to_remove:
del sd[k]
# Convert to target dtype
for k in sd.keys():
sd[k] = sd[k].to(model_dtype)

View File

@@ -821,7 +821,7 @@ z_image_turbo = StarterModel(
name="Z-Image Turbo",
base=BaseModelType.ZImage,
source="Tongyi-MAI/Z-Image-Turbo",
description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~13GB",
description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~30.6GB",
type=ModelType.Main,
)

View File

@@ -15,6 +15,9 @@ const config: KnipConfig = {
// Will be using this
'src/common/hooks/useAsyncState.ts',
'src/app/store/use-debounced-app-selector.ts',
// Auth features - exports will be used in follow-up phases
'src/features/auth/**',
'src/services/api/endpoints/auth.ts',
],
ignoreBinaries: ['only-allow'],
ignoreDependencies: ['magic-string'],

View File

@@ -53481,6 +53481,36 @@
}
],
"description": "The workflow associated with this queue item"
},
"user_id": {
"type": "string",
"title": "User Id",
"description": "The id of the user who created this queue item",
"default": "system"
},
"user_display_name": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "User Display Name",
"description": "The display name of the user who created this queue item, if available"
},
"user_email": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "User Email",
"description": "The email of the user who created this queue item, if available"
}
},
"type": "object",
@@ -53571,6 +53601,30 @@
"type": "integer",
"title": "Total",
"description": "Total number of queue items"
},
"user_pending": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "User Pending",
"description": "Number of queue items with status 'pending' for the current user"
},
"user_in_progress": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "User In Progress",
"description": "Number of queue items with status 'in_progress' for the current user"
}
},
"type": "object",
@@ -60983,6 +61037,45 @@
"output": {
"$ref": "#/components/schemas/ZImageConditioningOutput"
}
},
"UserDTO": {
"type": "object",
"required": ["user_id", "email", "is_admin", "is_active"],
"properties": {
"user_id": {
"type": "string",
"title": "User Id",
"description": "The user ID"
},
"email": {
"type": "string",
"title": "Email",
"description": "The user email"
},
"display_name": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Display Name",
"description": "The user display name"
},
"is_admin": {
"type": "boolean",
"title": "Is Admin",
"description": "Whether the user is an admin"
},
"is_active": {
"type": "boolean",
"title": "Is Active",
"description": "Whether the user is active"
}
},
"title": "UserDTO"
}
}
}

View File

@@ -89,6 +89,7 @@
"react-icons": "^5.5.0",
"react-redux": "9.2.0",
"react-resizable-panels": "^3.0.3",
"react-router-dom": "^7.12.0",
"react-textarea-autosize": "^8.5.9",
"react-use": "^17.6.0",
"react-virtuoso": "^4.13.0",

View File

@@ -158,6 +158,9 @@ importers:
react-resizable-panels:
specifier: ^3.0.3
version: 3.0.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
react-router-dom:
specifier: ^7.12.0
version: 7.12.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
react-textarea-autosize:
specifier: ^8.5.9
version: 8.5.9(@types/react@18.3.23)(react@18.3.1)
@@ -1993,6 +1996,10 @@ packages:
convert-source-map@2.0.0:
resolution: {integrity: sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==}
cookie@1.1.1:
resolution: {integrity: sha512-ei8Aos7ja0weRpFzJnEA9UHJ/7XQmqglbRwnf2ATjcB9Wq874VKH9kfjjirM6UhU2/E5fFYadylyhFldcqSidQ==}
engines: {node: '>=18'}
copy-to-clipboard@3.3.3:
resolution: {integrity: sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==}
@@ -3459,6 +3466,23 @@ packages:
react: ^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc
react-dom: ^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc
react-router-dom@7.12.0:
resolution: {integrity: sha512-pfO9fiBcpEfX4Tx+iTYKDtPbrSLLCbwJ5EqP+SPYQu1VYCXdy79GSj0wttR0U4cikVdlImZuEZ/9ZNCgoaxwBA==}
engines: {node: '>=20.0.0'}
peerDependencies:
react: '>=18'
react-dom: '>=18'
react-router@7.12.0:
resolution: {integrity: sha512-kTPDYPFzDVGIIGNLS5VJykK0HfHLY5MF3b+xj0/tTyNYL1gF1qs7u67Z9jEhQk2sQ98SUaHxlG31g1JtF7IfVw==}
engines: {node: '>=20.0.0'}
peerDependencies:
react: '>=18'
react-dom: '>=18'
peerDependenciesMeta:
react-dom:
optional: true
react-select@5.10.2:
resolution: {integrity: sha512-Z33nHdEFWq9tfnfVXaiM12rbJmk+QjFEztWLtmXqQhz6Al4UZZ9xc0wiatmGtUOCCnHN0WizL3tCMYRENX4rVQ==}
peerDependencies:
@@ -3675,6 +3699,9 @@ packages:
resolution: {integrity: sha512-ZYkZLAvKTKQXWuh5XpBw7CdbSzagarX39WyZ2H07CDLC5/KfsRGlIXV8d4+tfqX1M7916mRqR1QfNHSij+c9Pw==}
engines: {node: '>=18'}
set-cookie-parser@2.7.2:
resolution: {integrity: sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==}
set-function-length@1.2.2:
resolution: {integrity: sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==}
engines: {node: '>= 0.4'}
@@ -6120,6 +6147,8 @@ snapshots:
convert-source-map@2.0.0: {}
cookie@1.1.1: {}
copy-to-clipboard@3.3.3:
dependencies:
toggle-selection: 1.0.6
@@ -7707,6 +7736,20 @@ snapshots:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
react-router-dom@7.12.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
dependencies:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
react-router: 7.12.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
react-router@7.12.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
dependencies:
cookie: 1.1.1
react: 18.3.1
set-cookie-parser: 2.7.2
optionalDependencies:
react-dom: 18.3.1(react@18.3.1)
react-select@5.10.2(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
dependencies:
'@babel/runtime': 7.28.3
@@ -7982,6 +8025,8 @@ snapshots:
dependencies:
type-fest: 4.41.0
set-cookie-parser@2.7.2: {}
set-function-length@1.2.2:
dependencies:
define-data-property: 1.1.4

View File

@@ -15,6 +15,95 @@
"uploadImage": "Upload Image",
"uploadImages": "Upload Image(s)"
},
"auth": {
"login": {
"title": "Sign In to InvokeAI",
"email": "Email",
"emailPlaceholder": "Email",
"password": "Password",
"passwordPlaceholder": "Password",
"rememberMe": "Remember me for 7 days",
"signIn": "Sign In",
"signingIn": "Signing in...",
"loginFailed": "Login failed. Please check your credentials."
},
"setup": {
"title": "Welcome to InvokeAI",
"subtitle": "Set up your administrator account to get started",
"email": "Email",
"emailPlaceholder": "admin@example.com",
"emailHelper": "This will be your username for signing in",
"displayName": "Display Name",
"displayNamePlaceholder": "Administrator",
"displayNameHelper": "Your name as it will appear in the application",
"password": "Password",
"passwordPlaceholder": "Password",
"passwordHelper": "Must be at least 8 characters with uppercase, lowercase, and numbers",
"passwordTooShort": "Password must be at least 8 characters long",
"passwordMissingRequirements": "Password must contain uppercase, lowercase, and numbers",
"confirmPassword": "Confirm Password",
"confirmPasswordPlaceholder": "Confirm Password",
"passwordsDoNotMatch": "Passwords do not match",
"createAccount": "Create Administrator Account",
"creatingAccount": "Setting up...",
"setupFailed": "Setup failed. Please try again."
},
"userMenu": "User Menu",
"admin": "Admin",
"logout": "Logout",
"adminOnlyFeature": "This feature is only available to administrators.",
"profile": {
"menuItem": "My Profile",
"title": "My Profile",
"email": "Email",
"emailReadOnly": "Email address cannot be changed",
"displayName": "Display Name",
"displayNamePlaceholder": "Your name",
"changePassword": "Change Password",
"currentPassword": "Current Password",
"currentPasswordPlaceholder": "Current password",
"newPassword": "New Password",
"newPasswordPlaceholder": "New password",
"confirmPassword": "Confirm New Password",
"confirmPasswordPlaceholder": "Confirm new password",
"passwordsDoNotMatch": "Passwords do not match",
"saveSuccess": "Profile updated successfully",
"saveFailed": "Failed to save profile. Please try again."
},
"userManagement": {
"menuItem": "User Management",
"title": "User Management",
"email": "Email",
"emailPlaceholder": "user@example.com",
"displayName": "Display Name",
"displayNamePlaceholder": "Display name",
"password": "Password",
"passwordPlaceholder": "Password",
"newPassword": "New Password",
"newPasswordPlaceholder": "Leave blank to keep current password",
"role": "Role",
"status": "Status",
"actions": "Actions",
"isAdmin": "Administrator",
"user": "User",
"you": "You",
"createUser": "Create User",
"editUser": "Edit User",
"deleteUser": "Delete User",
"deleteConfirm": "Are you sure you want to delete \"{{name}}\"? This action cannot be undone.",
"generatePassword": "Generate Strong Password",
"showPassword": "Show password",
"hidePassword": "Hide password",
"activate": "Activate",
"deactivate": "Deactivate",
"saveFailed": "Failed to save user. Please try again.",
"deleteFailed": "Failed to delete user. Please try again.",
"loadFailed": "Failed to load users.",
"back": "Back",
"cannotDeleteSelf": "You cannot delete your own account",
"cannotDeactivateSelf": "You cannot deactivate your own account"
}
},
"boards": {
"addBoard": "Add Board",
"addPrivateBoard": "Add Private Board",
@@ -272,6 +361,7 @@
"cancelTooltip": "Cancel Current Item",
"cancelSucceeded": "Item Canceled",
"cancelFailed": "Problem Canceling Item",
"cancelFailedAccessDenied": "Problem Canceling Item: Access Denied",
"retrySucceeded": "Item Retried",
"retryFailed": "Problem Retrying Item",
"confirm": "Confirm",
@@ -283,6 +373,7 @@
"clearTooltip": "Cancel and Clear All Items",
"clearSucceeded": "Queue Cleared",
"clearFailed": "Problem Clearing Queue",
"clearFailedAccessDenied": "Problem Clearing Queue: Access Denied",
"cancelBatch": "Cancel Batch",
"cancelItem": "Cancel Item",
"retryItem": "Retry Item",
@@ -304,6 +395,7 @@
"canceled": "Canceled",
"completedIn": "Completed in",
"batch": "Batch",
"user": "User",
"origin": "Origin",
"destination": "Dest",
"upscaling": "Upscaling",
@@ -313,6 +405,8 @@
"other": "Other",
"gallery": "Gallery",
"batchFieldValues": "Batch Field Values",
"fieldValuesHidden": "<Hidden>",
"cannotViewDetails": "You do not have permission to view the details of this queue item",
"item": "Item",
"session": "Session",
"notReady": "Unable to Queue",
@@ -1094,6 +1188,7 @@
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
"queueEmpty": "The install queue is empty.",
"selectAll": "Select All",
"selectModelToView": "Select a model to view its details",
"typePhraseHere": "Type phrase here",
"t5Encoder": "T5 Encoder",
"qwen3Encoder": "Qwen3 Encoder",
@@ -1121,7 +1216,15 @@
"installingXModels_other": "Installing {{count}} models",
"skippingXDuplicates_one": ", skipping {{count}} duplicate",
"skippingXDuplicates_other": ", skipping {{count}} duplicates",
"manageModels": "Manage Models"
"manageModels": "Manage Models",
"exportSettings": "Export Settings",
"importSettings": "Import Settings",
"settingsExported": "Model settings exported",
"settingsImported": "Model settings imported",
"settingsImportedPartial": "Model settings partially imported. Incompatible settings were skipped: {{fields}}",
"settingsImportFailed": "Failed to import model settings",
"settingsImportIncompatible": "The settings file contains no compatible settings for this model type",
"settingsImportInvalidFile": "Invalid settings file"
},
"models": {
"addLora": "Add LoRA",
@@ -1515,6 +1618,8 @@
"general": "General",
"generation": "Generation",
"models": "Models",
"preferAttentionStyleNumeric": "Prefer Numeric Attention Style",
"prompt": "Prompt",
"resetComplete": "Web UI has been reset.",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
@@ -2370,10 +2475,14 @@
"text": {
"font": "Font",
"size": "Size",
"lineHeight": "Spacing",
"lineHeightDense": "Dense",
"lineHeightNormal": "Normal",
"lineHeightSpacious": "Spacious"
"bold": "Bold",
"italic": "Italic",
"underline": "Underline",
"strikethrough": "Strikethrough",
"alignLeft": "Align Left",
"alignCenter": "Align Center",
"alignRight": "Align Right",
"px": "px"
},
"newCanvasFromImage": "New Canvas from Image",
"newImg2ImgCanvasFromImage": "New Img2Img from Image",
@@ -2538,18 +2647,6 @@
"colorPicker": "Color Picker",
"text": "Text"
},
"text": {
"font": "Font",
"size": "Size",
"bold": "Bold",
"italic": "Italic",
"underline": "Underline",
"strikethrough": "Strikethrough",
"alignLeft": "Align Left",
"alignCenter": "Align Center",
"alignRight": "Align Right",
"px": "px"
},
"filter": {
"filter": "Filter",
"filters": "Filters",

View File

@@ -822,7 +822,29 @@
"orphanedModelsDeleted": "Eliminato con successo {{count}} modello orfano",
"orphanedModelsDeleteErrors": "Alcuni modelli non possono essere eliminati",
"orphanedModelsDeleteFailed": "Impossibile eliminare i modelli orfani",
"errorLoadingOrphanedModels": "Errore durante il caricamento dei modelli orfani. Riprova."
"errorLoadingOrphanedModels": "Errore durante il caricamento dei modelli orfani. Riprova.",
"pause": "Pausa",
"pauseAll": "Metti in pausa tutto",
"pauseAllTooltip": "Metti in pausa tutti i download attivi",
"resume": "Riprendi",
"resumeAll": "Riprendi tutto",
"resumeAllTooltip": "Riprendi tutti i download in pausa",
"restartFailed": "Riavvio non riuscito",
"restartFile": "Riavvia il file",
"restartRequired": "Riavvio richiesto",
"resumeRefused": "Ripristino rifiutato dal server. Riavvio richiesto.",
"backendDisconnected": "Backend disconnesso",
"cancelAll": "Annulla tutto",
"cancelAllTooltip": "Annulla tutti i download attivi",
"selectModelToView": "Seleziona un modello per visualizzarne i dettagli",
"exportSettings": "Impostazioni di esportazione",
"importSettings": "Impostazioni di importazione",
"settingsExported": "Impostazioni del modello esportate",
"settingsImported": "Impostazioni del modello importate",
"settingsImportedPartial": "Impostazioni del modello parzialmente importate. Le impostazioni incompatibili sono state ignorate: {{fields}}",
"settingsImportFailed": "Impossibile importare le impostazioni del modello",
"settingsImportIncompatible": "Il file delle impostazioni non contiene impostazioni compatibili per questo tipo di modello",
"settingsImportInvalidFile": "File di impostazioni non valido"
},
"parameters": {
"images": "Immagini",
@@ -993,7 +1015,8 @@
"showDetailedInvocationProgress": "Mostra dettagli avanzamento",
"enableHighlightFocusedRegions": "Evidenzia le regioni interessate",
"modelDescriptionsDisabled": "Descrizioni dei modelli nei menu a discesa disabilitate",
"modelDescriptionsDisabledDesc": "Le descrizioni dei modelli nei menu a discesa sono state disattivate. Abilitale nelle Impostazioni."
"modelDescriptionsDisabledDesc": "Le descrizioni dei modelli nei menu a discesa sono state disattivate. Abilitale nelle Impostazioni.",
"preferAttentionStyleNumeric": "Preferisci lo stile di attenzione numerico"
},
"toast": {
"uploadFailed": "Caricamento fallito",
@@ -1098,7 +1121,12 @@
"kleinEncoderClearedDescription": "Selezionare un encoder Qwen3 compatibile per la nuova variante del modello Klein",
"kleinEncoderCleared": "Encoder Qwen3 cancellato",
"schedulerReset": "Ripristino campionatore",
"schedulerResetZImageBase": "Il campionatore LCM non è compatibile con i modelli Z-Image Base. Reimpostare su Euler."
"schedulerResetZImageBase": "Il campionatore LCM non è compatibile con i modelli Z-Image Base. Reimpostare su Euler.",
"modelDownloadPaused": "Download del modello in pausa",
"modelDownloadResumed": "Ripresa del download",
"modelDownloadRestartFailed": "Riavvia i download non riusciti",
"modelDownloadRestartFile": "Riavvio del download del file",
"modelDownloadRestartedFromScratch": "Manca una parte del file. Riavviato il download dall'inizio."
},
"accessibility": {
"invokeProgressBar": "Barra di avanzamento generazione",
@@ -1357,7 +1385,13 @@
"locateInGalery": "Trova nella Galleria",
"deletedImagesCannotBeRestored": "Le immagini eliminate non possono essere ripristinate.",
"hideBoards": "Nascondi bacheche",
"viewBoards": "Visualizza le bacheche"
"viewBoards": "Visualizza le bacheche",
"pause": "Pausa",
"resume": "Riprendi",
"restartFailed": "Riavvio non riuscito",
"restartFile": "Riavvia il file",
"restartRequired": "Riavvio richiesto",
"resumeRefused": "Ripristino rifiutato dal server. Riavvio richiesto."
},
"queue": {
"queueFront": "Aggiungi all'inizio della coda",
@@ -1449,7 +1483,13 @@
"sortOrderDescending": "Discendente",
"createdAt": "Creato",
"completedAt": "Completato",
"batchFieldValues": "Valori del campo Lotto"
"batchFieldValues": "Valori del campo Lotto",
"paused": "In pausa",
"cancelFailedAccessDenied": "Problema durante l'annullamento dell'articolo: accesso negato",
"clearFailedAccessDenied": "Problema durante la cancellazione della coda: accesso negato",
"user": "Utente",
"cannotViewDetails": "Non hai l'autorizzazione per visualizzare i dettagli di questo elemento della coda",
"fieldValuesHidden": "<Nascosto>"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -2557,7 +2597,8 @@
"isEmpty": "{{title}} è vuoto",
"isDisabled": "{{title}} è disabilitato"
},
"scaledBbox": "Riquadro scalato"
"scaledBbox": "Riquadro scalato",
"textSessionActive": "L'inserimento del testo è attivo"
},
"canvasContextMenu": {
"newControlLayer": "Nuovo Livello di Controllo",
@@ -3024,5 +3065,80 @@
},
"lora": {
"weight": "Peso"
},
"auth": {
"login": {
"title": "Accedi a InvokeAI",
"rememberMe": "Ricordami per 7 giorni",
"signIn": "Accedi",
"signingIn": "Accesso in corso...",
"loginFailed": "Accesso non riuscito. Controlla le tue credenziali."
},
"setup": {
"title": "Benvenuti a InvokeAI",
"subtitle": "Configura il tuo account amministratore per iniziare",
"emailHelper": "Questo sarà il tuo nome utente per accedere",
"displayName": "Nome da visualizzare",
"displayNamePlaceholder": "Amministratore",
"displayNameHelper": "Il tuo nome come apparirà nell'applicazione",
"passwordHelper": "Deve contenere almeno 8 caratteri, tra maiuscole, minuscole e numeri",
"passwordTooShort": "La password deve essere lunga almeno 8 caratteri",
"passwordMissingRequirements": "La password deve contenere maiuscole, minuscole e numeri",
"confirmPassword": "Conferma password",
"confirmPasswordPlaceholder": "Conferma password",
"passwordsDoNotMatch": "Le password non corrispondono",
"createAccount": "Crea un account amministratore",
"creatingAccount": "Impostazione in corso...",
"setupFailed": "Installazione non riuscita. Riprova."
},
"userMenu": "Menu utente",
"logout": "Esci",
"adminOnlyFeature": "Questa funzionalità è disponibile solo per gli amministratori.",
"profile": {
"menuItem": "Il mio profilo",
"title": "Il mio profilo",
"emailReadOnly": "L'indirizzo email non può essere modificato",
"displayName": "Nome da visualizzare",
"displayNamePlaceholder": "Il tuo nome",
"changePassword": "Cambiare la password",
"currentPassword": "Password attuale",
"currentPasswordPlaceholder": "Password attuale",
"newPassword": "Nuova password",
"newPasswordPlaceholder": "Nuova password",
"confirmPassword": "Conferma nuova password",
"confirmPasswordPlaceholder": "Conferma nuova password",
"passwordsDoNotMatch": "Le password non corrispondono",
"saveSuccess": "Profilo aggiornato con successo",
"saveFailed": "Impossibile salvare il profilo. Riprova."
},
"userManagement": {
"menuItem": "Gestione utenti",
"title": "Gestione utenti",
"displayName": "Nome da visualizzare",
"displayNamePlaceholder": "Nome da visualizzare",
"newPassword": "Nuova password",
"newPasswordPlaceholder": "Lasciare vuoto per mantenere la password corrente",
"role": "Ruolo",
"status": "Stato",
"actions": "Azioni",
"isAdmin": "Amministratore",
"user": "Utente",
"you": "Tu",
"createUser": "Crea utente",
"editUser": "Modifica utente",
"deleteUser": "Elimina utente",
"deleteConfirm": "Vuoi davvero eliminare \"{{name}}\"? Questa azione non può essere annullata.",
"generatePassword": "Genera una password complessa",
"showPassword": "Mostra password",
"hidePassword": "Nascondi password",
"activate": "Attiva",
"deactivate": "Disattiva",
"saveFailed": "Impossibile salvare l'utente. Riprova.",
"deleteFailed": "Impossibile eliminare l'utente. Riprova.",
"loadFailed": "Impossibile caricare gli utenti.",
"back": "Indietro",
"cannotDeleteSelf": "Non puoi eliminare il tuo account",
"cannotDeactivateSelf": "Non puoi disattivare il tuo account"
}
}
}

View File

@@ -89,7 +89,16 @@
"none": "Ничего",
"new": "Новый",
"ok": "Ok",
"close": "Закрыть"
"close": "Закрыть",
"error_withCount_one": "{{count}} Ошибка",
"error_withCount_few": "{{count}} Ошибки",
"error_withCount_many": "{{count}} Ошибок",
"model_withCount_one": "{{count}} Модель",
"model_withCount_few": "{{count}} Модели",
"model_withCount_many": "{{count}} Моделей",
"options_withCount_one": "{{count}} Опция",
"options_withCount_few": "{{count}} Опции",
"options_withCount_many": "{{count}} Опций"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@@ -108,7 +117,7 @@
"downloadSelection": "Скачать выделенное",
"currentlyInUse": "В настоящее время это изображение используется в следующих функциях:",
"unstarImage": "Удалить из избранного",
"dropOrUpload": "$t(gallery.drop) или загрузить",
"dropOrUpload": "Перетащите или загрузите",
"copy": "Копировать",
"download": "Скачать",
"noImageSelected": "Изображение не выбрано",
@@ -239,7 +248,7 @@
},
"filterSelected": {
"title": "Filter",
"desc": "Filter the selected layer. Only applies to Raster and Control layers."
"desc": "Применяет фильтр к выбранному слою. Применимо только к растровым слоям и слоям управления."
},
"undo": {
"desc": "Отменяет последнее действие на холсте.",
@@ -483,7 +492,7 @@
"deleteMsg1": "Вы точно хотите удалить модель из InvokeAI?",
"deleteMsg2": "Это приведет К УДАЛЕНИЮ модели С ДИСКА, если она находится в корневой папке Invoke. Если вы используете пользовательское расположение, то модель НЕ будет удалена с диска.",
"convertToDiffusersHelpText5": "Пожалуйста, убедитесь, что у вас достаточно места на диске. Модели обычно занимают 27 Гб.",
"convertToDiffusersHelpText3": "Ваш файл контрольной точки НА ДИСКЕ будет УДАЛЕН, если он находится в корневой папке InvokeAI. Если он находится в пользовательском расположении, то он НЕ будет удален.",
"convertToDiffusersHelpText3": "Файл чекпоинта будет удалён с диска, если он находится в корневой папке InvokeAI. Если файл расположен в пользовательской папке, он удалён не будет.",
"allModels": "Все модели",
"repo_id": "ID репозитория",
"convert": "Преобразовать",
@@ -541,7 +550,7 @@
"pathToConfig": "Путь к конфигурации",
"loraTriggerPhrases": "Триггерные фразы LoRA",
"mainModelTriggerPhrases": "Триггерные фразы основной модели",
"inplaceInstallDesc": "Устанавливайте модели без копирования файлов. При использовании модели она будет загружаться из этого места. Если этот параметр отключен, файлы модели будут скопированы в каталог моделей, управляемых Invoke, во время установки.",
"inplaceInstallDesc": "Устанавливать модели без перемещения файлов. В этом случае модель будет загружаться из исходной папки. Если опция отключена, файлы модели при установке будут перемещены в каталог моделей Invoke.",
"huggingFaceRepoID": "ID репозитория HuggingFace",
"installQueue": "Очередь установки",
"installAll": "Установить все",
@@ -575,8 +584,8 @@
"skippingXDuplicates_one": ", пропуская {{count}} дубликат",
"skippingXDuplicates_few": ", пропуская {{count}} дубликата",
"skippingXDuplicates_many": ", пропуская {{count}} дубликатов",
"includesNModels": "Включает в себя {{n}} моделей и их зависимостей",
"starterBundleHelpText": "Легко установите все модели, необходимые для начала работы с базовой моделью, включая основную модель, сети управления, IP-адаптеры и многое другое. При выборе комплекта все уже установленные модели будут пропущены."
"includesNModels": "Включает в себя {{n}} моделей и их зависимостей.",
"starterBundleHelpText": "Легко установите все модели, необходимые для начала работы с базовой моделью, включая основную модель, ControlNet, IP-адаптеры и другие. При выборе набора уже установленные модели будут пропущены."
},
"parameters": {
"images": "Изображения",
@@ -632,8 +641,8 @@
"canvasIsFiltering": "Холст фильтруется",
"canvasIsTransforming": "Холст трансформируется",
"noCLIPEmbedModelSelected": "Для генерации FLUX не выбрана модель CLIP Embed",
"canvasIsRasterizing": "Холст растрируется",
"canvasIsCompositing": "Холст составляется"
"canvasIsRasterizing": "Холст занят (идёт растеризация)",
"canvasIsCompositing": "Холст занят (идёт компоновка)"
},
"cfgRescaleMultiplier": "Множитель масштабирования CFG",
"patchmatchDownScaleSize": "уменьшить",
@@ -660,7 +669,10 @@
"optimizedImageToImage": "Оптимизированное img2img",
"sendToCanvas": "Отправить на холст",
"guidance": "Точность",
"boxBlur": "Box Blur"
"boxBlur": "Box Blur",
"images_withCount_one": "Изображение",
"images_withCount_few": "Изображения",
"images_withCount_many": "Изображений"
},
"settings": {
"models": "Модели",
@@ -690,7 +702,7 @@
"intermediatesCleared_one": "Очищено {{count}} промежуточное",
"intermediatesCleared_few": "Очищено {{count}} промежуточных",
"intermediatesCleared_many": "Очищено {{count}} промежуточных",
"clearIntermediatesDesc1": "Очистка промежуточных элементов приведет к сбросу состояния Canvas и ControlNet.",
"clearIntermediatesDesc1": "Очистка промежуточных данных приведёт к сбросу состояния холста и ControlNet.",
"intermediatesClearedFailed": "Проблема очистки промежуточных",
"reloadingIn": "Перезагрузка через",
"informationalPopoversDisabled": "Информационные всплывающие окна отключены",
@@ -704,7 +716,7 @@
"serverError": "Ошибка сервера",
"connected": "Подключено к серверу",
"canceled": "Обработка отменена",
"uploadFailedInvalidUploadDesc": "Это должны быть изображения PNG или JPEG.",
"uploadFailedInvalidUploadDesc": "Допускаются только изображения в формате PNG, JPEG или WEBP.",
"parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр задан",
"problemCopyingImage": "Не удается скопировать изображение",
@@ -747,7 +759,12 @@
"sentToUpscale": "Отправить на увеличение",
"linkCopied": "Ссылка скопирована",
"addedToUncategorized": "Добавлено в активы доски $t(boards.uncategorized)",
"imagesWillBeAddedTo": "Загруженные изображения будут добавлены в активы доски {{boardName}}."
"imagesWillBeAddedTo": "Загруженные изображения будут добавлены в активы доски {{boardName}}.",
"schedulerResetZImageBase": "Планировщик LCM несовместим с моделями Z-Image Base. Переключено на Euler.",
"schedulerReset": "Планировщик сброшен",
"uploadFailedInvalidUploadDesc_withCount_one": "Допускается не более 1 изображения в формате PNG, JPEG или WEBP.",
"uploadFailedInvalidUploadDesc_withCount_few": "Допускается не более {{count}} изображения в формате PNG, JPEG или WEBP.",
"uploadFailedInvalidUploadDesc_withCount_many": "Допускается не более {{count}} изображений в формате PNG, JPEG или WEBP."
},
"accessibility": {
"uploadImage": "Загрузить изображение",
@@ -892,7 +909,13 @@
"saveToGallery": "Сохранить в галерею",
"noWorkflows": "Нет рабочих процессов",
"noMatchingWorkflows": "Нет совпадающих рабочих процессов",
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>."
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>.",
"generatorImages_one": "{{count}} изображение",
"generatorImages_few": "{{count}} изображения",
"generatorImages_many": "{{count}} изображений",
"generatorNRandomValues_one": "{{count}} случайное значение",
"generatorNRandomValues_few": "{{count}} случайных значения",
"generatorNRandomValues_many": "{{count}} случайных значений"
},
"boards": {
"autoAddBoard": "Коллекция для автодобавления",
@@ -935,7 +958,19 @@
"shared": "Общие коллекции",
"noBoards": "Нет коллекций {{boardType}}",
"deletedPrivateBoardsCannotbeRestored": "Удалённые коллекции и изображения нельзя восстановить. При выборе «Удалить только коллекцию» изображения будут перемещены в личный раздел «Без категории» автора изображения.",
"updateBoardError": "Ошибка обновления коллекции"
"updateBoardError": "Ошибка обновления коллекции",
"pause": "Пауза",
"resume": "Возобновить",
"restartFailed": "Ошибка перезапуска",
"restartFile": "Перезапустить загрузку",
"restartRequired": "Требуется перезапуск",
"resumeRefused": "Сервер отклонил попытку возобновления. Требуется перезапуск.",
"uncategorizedImages": "Без категории",
"deleteAllUncategorizedImages": "Удалить все изображения без категории",
"deletedImagesCannotBeRestored": "Удалённые изображения нельзя восстановить.",
"hideBoards": "Скрыть коллекции",
"locateInGalery": "Показать в галерее",
"viewBoards": "Просмотреть коллекции"
},
"dynamicPrompts": {
"seedBehaviour": {
@@ -1031,19 +1066,19 @@
"controlNetResizeMode": {
"heading": "Режим изменения размера",
"paragraphs": [
"Метод подгонки размера входного изображения Control Adaptor к размеру выходного изображения."
"Метод подгонки размера входного изображения Control Adapter под размер выходного изображения."
]
},
"controlNetBeginEnd": {
"paragraphs": [
"Часть процесса шумоподавления, к которой будет применен адаптер контроля.",
"ControlNet, применяемые в начале процесса, направляют композицию, а ControlNet, применяемые в конце, направляют детали."
"Эта настройка определяет, на каком этапе денойзинга (генерации) используется влияние данного слоя.",
"• Начальный шаг (%): Определяет, с какого момента в процессе генерации начинает учитываться влияние данного слоя."
],
"heading": "Процент начала/конца шага"
},
"dynamicPromptsSeedBehaviour": {
"paragraphs": [
"Управляет использованием сида при создании запросов.",
"Определяет, как используется сид при генерации промптов.",
"Для каждой итерации будет использоваться уникальный сид. Используйте это, чтобы изучить варианты запросов для одного сида.",
"Например, если у вас 5 запросов, каждое изображение будет использовать один и то же сид.",
"для каждого изображения будет использоваться уникальный сид. Это обеспечивает большую вариативность."
@@ -1071,8 +1106,8 @@
},
"paramDenoisingStrength": {
"paragraphs": [
"Количество шума, добавляемого к входному изображению.",
"0 приведет к идентичному изображению, а 1 - к совершенно новому."
"Определяет, насколько сгенерированное изображение отличается от растрового слоя (слоёв).",
"Меньшее значение сохраняет больше сходства с объединёнными видимыми растровыми слоями. Большее значение усиливает влияние глобального промпта."
],
"heading": "Шумоподавление"
},
@@ -1111,7 +1146,7 @@
"controlNetWeight": {
"heading": "Вес",
"paragraphs": [
"Вес адаптера управления. Более высокий вес приведет к большему воздействию на окончательное изображение."
"Определяет, насколько сильно слой влияет на процесс генерации."
]
},
"controlNet": {
@@ -1123,13 +1158,13 @@
"paramCFGScale": {
"heading": "Шкала точности (CFG)",
"paragraphs": [
"Контролирует, насколько запрос влияет на процесс генерации.",
"Определяет, насколько сильно промпт влияет на процесс генерации.",
"Высокие значения шкалы CFG могут привести к перенасыщению и искажению результатов генерации. "
]
},
"controlNetControlMode": {
"paragraphs": [
"Придает больший вес либо запросу, либо ControlNet."
"Смещает приоритет в сторону промпта или ControlNet."
],
"heading": "Режим управления"
},
@@ -1181,7 +1216,7 @@
"refinerCfgScale": {
"heading": "Шкала CFG",
"paragraphs": [
"Контролирует, насколько сильно запрос влияет на процесс генерации.",
"Определяет, насколько сильно промпт влияет на процесс генерации.",
"Аналогично CFG шкале генерации."
]
},
@@ -1290,24 +1325,24 @@
"ipAdapterMethod": {
"heading": "Метод",
"paragraphs": [
"Метод, с помощью которого применяется текущий IP-адаптер."
"Метод определяет, как референсное изображение будет влиять на процесс генерации."
]
},
"structure": {
"paragraphs": [
"Структура контролирует, насколько точно выходное изображение будет соответствовать макету оригинала. Низкая структура допускает значительные изменения, в то время как высокая структура строго сохраняет исходную композицию и макет."
"Структура определяет, насколько точно выходное изображение сохраняет компоновку исходного. Низкое значение допускает значительные изменения, а высокое строго сохраняет исходную композицию и расположение элементов."
],
"heading": "Структура"
},
"scale": {
"paragraphs": [
"Масштаб управляет размером выходного изображения и основывается на кратном разрешении входного изображения. Например, при увеличении в 2 раза изображения 1024x1024 на выходе получится 2048 x 2048."
"Масштаб определяет размер выходного изображения и рассчитывается как кратное разрешению исходного изображения. Например, увеличение в 2 раза для изображения 1024×1024 даст результат 2048×2048."
],
"heading": "Масштаб"
},
"creativity": {
"paragraphs": [
"Креативность контролирует степень свободы, предоставляемой модели при добавлении деталей. При низкой креативности модель остается близкой к оригинальному изображению, в то время как высокая креативность позволяет вносить больше изменений. При использовании подсказки высокая креативность увеличивает влияние подсказки."
"Креативность определяет степень свободы модели при добавлении деталей. Низкое значение сохраняет больше сходства с исходным изображением, а высокое допускает более значительные изменения. При использовании промпта высокое значение усиливает его влияние."
],
"heading": "Креативность"
},
@@ -1320,18 +1355,18 @@
"fluxDevLicense": {
"heading": "Некоммерческая лицензия",
"paragraphs": [
"Модели FLUX.1 [dev] лицензируются по некоммерческой лицензии FLUX [dev]. Чтобы использовать этот тип модели в коммерческих целях в Invoke, посетите наш веб-сайт, чтобы узнать больше."
"Модели FLUX.1 [dev] распространяются по некоммерческой лицензии FLUX [dev]. Для их коммерческого использования требуется отдельная лицензия."
]
},
"optimizedDenoising": {
"heading": "Оптимизированный img2img",
"paragraphs": [
"Включите опцию «Оптимизированный img2img», чтобы получить более плавную шкалу Denoise Strength для img2img и перерисовки с моделями Flux. Эта настройка улучшает возможность контролировать степень изменения изображения, но может быть отключена, если вы предпочитаете использовать стандартную шкалу Denoise Strength. Эта настройка все еще находится в стадии настройки и в настоящее время имеет статус бета-версии."
"Включите «Optimized Image-to-Image», чтобы использовать более плавную шкалу Denoise Strength для преобразований image-to-image и инпейнтинга с моделями Flux. Эта настройка улучшает контроль над степенью изменений изображения, однако её можно отключить, если вы предпочитаете стандартную шкалу Denoise Strength. Функция находится в стадии настройки и имеет статус бета-версии."
]
},
"paramGuidance": {
"paragraphs": [
"Контролирует, насколько сильно запрос влияет на процесс генерации.",
"Определяет, насколько сильно промпт влияет на процесс генерации.",
"Высокие значения точности могут привести к перенасыщению, а высокие или низкие значения точности могут привести к искажению результатов генерации. Точность применима только к моделям FLUX DEV."
],
"heading": "Точность"
@@ -1363,7 +1398,7 @@
"parameterSet": "Параметр {{parameter}} установлен",
"allPrompts": "Все запросы",
"imageDimensions": "Размеры изображения",
"canvasV2Metadata": "Холст",
"canvasV2Metadata": "Слои холста",
"guidance": "Точность"
},
"queue": {
@@ -1393,7 +1428,7 @@
"graphQueued": "График поставлен в очередь",
"queue": "Очередь",
"batch": "Пакет",
"clearQueueAlertDialog": "Очистка очереди немедленно отменяет все элементы обработки и полностью очищает очередь. Ожидающие фильтры будут отменены.",
"clearQueueAlertDialog": "Очистка очереди немедленно отменит все текущие задачи и очистит очередь. Ожидающие фильтры будут отменены, а область предпросмотра на холсте сброшена.",
"pending": "В ожидании",
"completedIn": "Завершено за",
"resumeFailed": "Проблема с возобновлением рендеринга",
@@ -1477,7 +1512,7 @@
"workflowEditorMenu": "Меню редактора рабочего процесса",
"workflowName": "Имя рабочего процесса",
"saveWorkflow": "Сохранить рабочий процесс",
"workflowLibrary": "Библиотека",
"workflowLibrary": "Библиотека схем генерации",
"downloadWorkflow": "Сохранить в файл",
"workflowSaved": "Рабочий процесс сохранен",
"unnamedWorkflow": "Безымянный рабочий процесс",
@@ -1560,7 +1595,7 @@
"autoNegative": "Авто негатив",
"rectangle": "Прямоугольник",
"addNegativePrompt": "Добавить $t(controlLayers.negativePrompt)",
"regionalGuidance": "Региональная точность",
"regionalGuidance": "Региональное влияние",
"opacity": "Непрозрачность",
"addLayer": "Добавить слой",
"moveToFront": "На передний план",
@@ -1568,33 +1603,33 @@
"regional": "Региональный",
"bookmark": "Закладка для быстрого переключения",
"fitBboxToLayers": "Подогнать рамку к слоям",
"mergeVisibleOk": "Объединенные видимые слои",
"mergeVisibleError": "Ошибка объединения видимых слоев",
"mergeVisibleOk": "Объединенные слои",
"mergeVisibleError": "Ошибка объединения слоев",
"clearHistory": "Очистить историю",
"mergeVisible": "Объединить видимые",
"removeBookmark": "Удалить закладку",
"saveLayerToAssets": "Сохранить слой в активы",
"saveLayerToAssets": "Сохранить слой в ресурсы",
"clearCaches": "Очистить кэши",
"recalculateRects": "Пересчитать прямоугольники",
"saveBboxToGallery": "Сохранить рамку в галерею",
"saveBboxToGallery": "Сохранить область в галерею",
"canvas": "Холст",
"global": "Глобальный",
"newGlobalReferenceImageError": "Проблема с созданием глобального эталонного изображения",
"newRegionalReferenceImageOk": "Создано региональное эталонное изображение",
"newRegionalReferenceImageError": "Проблема создания регионального эталонного изображения",
"newGlobalReferenceImageError": "Проблема с созданием глобального референсного изображения",
"newRegionalReferenceImageOk": "Создано региональное референсное изображение",
"newRegionalReferenceImageError": "Проблема создания регионального референсного изображения",
"newControlLayerOk": "Создан слой управления",
"newControlLayerError": "Ошибка создания слоя управления",
"newRasterLayerOk": "Создан растровый слой",
"newRasterLayerError": "Ошибка создания растрового слоя",
"newGlobalReferenceImageOk": "Создано глобальное эталонное изображение",
"bboxOverlay": "Показать наложение ограничительной рамки",
"newGlobalReferenceImageOk": "Создано глобальное референсное изображение",
"bboxOverlay": "Показать наложение рамки",
"saveCanvasToGallery": "Сохранить холст в галерею",
"pullBboxIntoReferenceImageOk": "рамка перенесена в эталонное изображение",
"pullBboxIntoReferenceImageError": "Ошибка переноса рамки в эталонное изображение",
"pullBboxIntoReferenceImageOk": "Рамка перенесена в референсное изображение",
"pullBboxIntoReferenceImageError": "Ошибка переноса рамки в референсное изображение",
"regionIsEmpty": "Выбранный регион пуст",
"savedToGalleryOk": "Сохранено в галерею",
"savedToGalleryError": "Ошибка сохранения в галерею",
"pullBboxIntoLayerOk": "Рамка перенесена в слой",
"pullBboxIntoLayerOk": "Содержимое рамки перенесено в слой",
"pullBboxIntoLayerError": "Проблема с переносом рамки в слой",
"newLayerFromImage": "Новый слой из изображения",
"filter": {
@@ -1693,11 +1728,12 @@
"isTransforming": "{{title}} трансформируется"
},
"scaledBbox": "Масштабированная рамка",
"bbox": "Ограничительная рамка"
"bbox": "Ограничительная рамка",
"textSessionActive": "Активен режим ввода"
},
"canvasContextMenu": {
"saveBboxToGallery": "Сохранить рамку в галерею",
"newGlobalReferenceImage": "Новое глобальное эталонное изображение",
"newGlobalReferenceImage": "Новое глобальное референсное изображение",
"bboxGroup": "Сохдать из рамки",
"canvasGroup": "Холст",
"newControlLayer": "Новый контрольный слой",
@@ -1709,8 +1745,8 @@
},
"fill": {
"solid": "Сплошной",
"fillStyle": "Стиль заполнения",
"fillColor": "Цвет заполнения",
"fillStyle": "Стиль заливки",
"fillColor": "Цвет заливкии",
"grid": "Сетка",
"horizontal": "Горизонтальная",
"diagonal": "Диагональная",
@@ -1729,8 +1765,8 @@
"inpaintMask": "Маска перерисовки",
"sendToCanvas": "Отправить на холст",
"regionalGuidance_withCount_one": "$t(controlLayers.regionalGuidance)",
"regionalGuidance_withCount_few": "Региональных точности",
"regionalGuidance_withCount_many": "Региональных точностей",
"regionalGuidance_withCount_few": "Региональных влияния",
"regionalGuidance_withCount_many": "Региональных влияний",
"controlLayer_withCount_one": "$t(controlLayers.controlLayer)",
"controlLayer_withCount_few": "Контрольных слоя",
"controlLayer_withCount_many": "Контрольных слоев",
@@ -1739,9 +1775,9 @@
"inpaintMask_withCount_few": "Маски перерисовки",
"inpaintMask_withCount_many": "Масок перерисовки",
"controlMode": {
"prompt": "Запрос",
"prompt": "Промпт",
"controlMode": "Режим контроля",
"megaControl": "Мега контроль",
"megaControl": "Максимальный контроль",
"balanced": "Сбалансированный",
"control": "Контроль"
},
@@ -1770,24 +1806,25 @@
"showResultsOn": "Показать результаты",
"showResultsOff": "Скрыть результаты"
},
"pullBboxIntoReferenceImage": оместить рамку в эталонное изображение",
"pullBboxIntoReferenceImage": реобразовать рамку в референсное изображение",
"enableAutoNegative": "Включить авто негатив",
"maskFill": "Заполнение маски",
"maskFill": "Заливка маски",
"tool": {
"move": "Двигать",
"move": "Перемещение",
"bbox": "Ограничительная рамка",
"view": "Смотреть",
"view": "Перемещение холста",
"brush": "Кисть",
"eraser": "Ластик",
"rectangle": "Прямоугольник",
"colorPicker": одборщик цветов"
"colorPicker": ипетка",
"text": "Текст"
},
"rasterLayer": "Растровый слой",
"enableTransparencyEffect": "Включить эффект прозрачности",
"hidingType": "Скрыть {{type}}",
"addRegionalGuidance": "Добавить $t(controlLayers.regionalGuidance)",
"deleteSelected": "Удалить выбранное",
"pullBboxIntoLayer": оместить рамку в слой",
"pullBboxIntoLayer": реобразовать рамку в слой",
"locked": "Заблокировано",
"replaceLayer": "Заменить слой",
"width": "Ширина",
@@ -1795,15 +1832,15 @@
"addRasterLayer": "Добавить $t(controlLayers.rasterLayer)",
"addControlLayer": "Добавить $t(controlLayers.controlLayer)",
"addInpaintMask": "Добавить $t(controlLayers.inpaintMask)",
"cropLayerToBbox": "Обрезать слой по ограничительной рамке",
"clipToBbox": "Обрезка штрихов в рамке",
"outputOnlyMaskedRegions": "Вывод только маскированных областей",
"cropLayerToBbox": "Обрезать слой по рамке",
"clipToBbox": "Ограничить мазки рамкой",
"outputOnlyMaskedRegions": "Выводить только сгенерированные области",
"duplicate": "Дублировать",
"layer_one": "Слой",
"layer_few": "Слоя",
"layer_many": "Слоев",
"prompt": "Запрос",
"negativePrompt": "Исключающий запрос",
"prompt": "Промпт",
"negativePrompt": "Негативный промпт",
"beginEndStepPercentShort": "Начало/конец %",
"transform": {
"transform": "Трансформировать",
@@ -1816,7 +1853,7 @@
"fitModeFill": "Заполнить"
},
"disableAutoNegative": "Отключить авто негатив",
"deleteReferenceImage": "Удалить эталонное изображение",
"deleteReferenceImage": "Удалить референсное изображение",
"rasterLayer_withCount_one": "$t(controlLayers.rasterLayer)",
"rasterLayer_withCount_few": "Растровых слоя",
"rasterLayer_withCount_many": "Растровых слоев",
@@ -1828,9 +1865,42 @@
"logDebugInfo": "Писать отладочную информацию",
"unlocked": "Разблокировано",
"showProgressOnCanvas": "Показать прогресс на холсте",
"regionalReferenceImage": "Региональное эталонное изображение",
"globalReferenceImage": "Глобальное эталонное изображение",
"referenceImage": "Эталонное изображение"
"regionalReferenceImage": "Региональное референсное изображение",
"globalReferenceImage": "Глобальное референсное изображение",
"referenceImage": "Референсное изображение",
"text": {
"px": "px",
"alignRight": "По правому краю",
"alignCenter": "По центру",
"alignLeft": "По левому краю",
"strikethrough": "Зачёркнутый",
"italic": "Курсив",
"bold": "Полужирный",
"size": "Размер",
"font": "Шрифт"
},
"newImg2ImgCanvasFromImage": "Новое изображение из Img2Img",
"sendToCanvasDesc": "При нажатии Invoke результат появляется на холсте в режиме предпросмотра.",
"compositeOperation": {
"blendModes": {
"darken": "Затемнение",
"multiply": "Умножение",
"color-dodge": "Осветление основы",
"color-burn": "Затемнение основы",
"screen": "Экран",
"hard-light": "Жёсткий свет",
"soft-light": "Мягкий свет",
"overlay": "Перекрытие",
"hue": "Тон",
"color": "Цвет",
"source-over": "Обычный"
}
},
"globalReferenceImage_withCount_one": "$t(controlLayers.globalReferenceImage)",
"globalReferenceImage_withCount_few": "Глобальных референсных изображения",
"globalReferenceImage_withCount_many": "Глобальных референсных изображений",
"regionalGuidance_withCount_hidden": "Региональное влияние (скрыто: {{count}})",
"controlLayers_withCount_hidden": "Слои управления (скрыто: {{count}})"
},
"ui": {
"tabs": {

View File

@@ -1,13 +1,21 @@
import { Box } from '@invoke-ai/ui-library';
import { Box, Center, Spinner } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { clearStorage } from 'app/store/enhancers/reduxRemember/driver';
import Loading from 'common/components/Loading/Loading';
import { AdministratorSetup } from 'features/auth/components/AdministratorSetup';
import { LoginPage } from 'features/auth/components/LoginPage';
import { ProtectedRoute } from 'features/auth/components/ProtectedRoute';
import { UserManagement } from 'features/auth/components/UserManagement';
import { UserProfile } from 'features/auth/components/UserProfile';
import { AppContent } from 'features/ui/components/AppContent';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { memo } from 'react';
import type { ReactNode } from 'react';
import { memo, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import { Route, Routes, useNavigate } from 'react-router-dom';
import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import ThemeLocaleProvider from './ThemeLocaleProvider';
@@ -18,14 +26,94 @@ const errorBoundaryOnReset = () => {
return false;
};
const App = () => {
const MainApp = () => {
const isNavigationAPIConnected = useStore(navigationApi.$isConnected);
return (
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
{isNavigationAPIConnected ? <AppContent /> : <Loading />}
</Box>
);
};
const SetupChecker = () => {
const { data, isLoading } = useGetSetupStatusQuery();
const navigate = useNavigate();
// Check if user is already authenticated
const token = localStorage.getItem('auth_token');
const isAuthenticated = !!token;
useEffect(() => {
if (!isLoading && data) {
// If multiuser mode is disabled, go directly to the app
if (!data.multiuser_enabled) {
navigate('/app', { replace: true });
} else if (isAuthenticated) {
// In multiuser mode, check authentication
navigate('/app', { replace: true });
} else if (data.setup_required) {
navigate('/setup', { replace: true });
} else {
navigate('/login', { replace: true });
}
}
}, [data, isLoading, navigate, isAuthenticated]);
if (isLoading) {
return (
<Center w="100dvw" h="100dvh">
<Spinner size="xl" />
</Center>
);
}
return null;
};
/** Full-page wrapper for user management / profile pages rendered inside the protected area */
const FullPageWrapper = ({ children }: { children: ReactNode }) => (
<Box w="100dvw" h="100dvh" overflowY="auto" bg="base.900">
{children}
</Box>
);
const App = () => {
return (
<ThemeLocaleProvider>
<ErrorBoundary onReset={errorBoundaryOnReset} FallbackComponent={AppErrorBoundaryFallback}>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
{isNavigationAPIConnected ? <AppContent /> : <Loading />}
</Box>
<Routes>
<Route path="/" element={<SetupChecker />} />
<Route path="/login" element={<LoginPage />} />
<Route path="/setup" element={<AdministratorSetup />} />
<Route
path="/profile"
element={
<ProtectedRoute>
<FullPageWrapper>
<UserProfile />
</FullPageWrapper>
</ProtectedRoute>
}
/>
<Route
path="/admin/users"
element={
<ProtectedRoute requireAdmin>
<FullPageWrapper>
<UserManagement />
</FullPageWrapper>
</ProtectedRoute>
}
/>
<Route
path="/*"
element={
<ProtectedRoute>
<MainApp />
</ProtectedRoute>
}
/>
</Routes>
<GlobalHookIsolator />
<GlobalModalIsolator />
</ErrorBoundary>

View File

@@ -7,6 +7,7 @@ import { createStore } from 'app/store/store';
import Loading from 'common/components/Loading/Loading';
import React, { lazy, memo, useEffect, useState } from 'react';
import { Provider } from 'react-redux';
import { BrowserRouter } from 'react-router-dom';
/*
* We need to configure logging before anything else happens - useLayoutEffect ensures we set this at the first
@@ -51,9 +52,11 @@ const InvokeAIUI = () => {
return (
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App />
</React.Suspense>
<BrowserRouter>
<React.Suspense fallback={<Loading />}>
<App />
</React.Suspense>
</BrowserRouter>
</Provider>
</React.StrictMode>
);

View File

@@ -68,10 +68,26 @@ const getIdbKey = (key: string) => {
return `${IDB_STORAGE_PREFIX}${key}`;
};
// Helper to get auth headers for client_state requests
const getAuthHeaders = (): Record<string, string> => {
const headers: Record<string, string> = {};
// Safe access to localStorage (not available in Node.js test environment)
if (typeof window !== 'undefined' && window.localStorage) {
const token = localStorage.getItem('auth_token');
if (token) {
headers['Authorization'] = `Bearer ${token}`;
}
}
return headers;
};
const getItem = async (key: string) => {
try {
const url = getUrl('get_by_key', key);
const res = await fetch(url, { method: 'GET' });
const res = await fetch(url, {
method: 'GET',
headers: getAuthHeaders(),
});
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
@@ -130,7 +146,11 @@ const setItem = async (key: string, value: string) => {
}
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`);
const url = getUrl('set_by_key', key);
const res = await fetch(url, { method: 'POST', body: value });
const res = await fetch(url, {
method: 'POST',
body: value,
headers: getAuthHeaders(),
});
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
@@ -158,7 +178,10 @@ export const clearStorage = async () => {
try {
persistRefCount++;
const url = getUrl('delete');
const res = await fetch(url, { method: 'POST' });
const res = await fetch(url, {
method: 'POST',
headers: getAuthHeaders(),
});
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}

View File

@@ -12,27 +12,12 @@ export const appStarted = createAction('app/appStarted');
export const addAppStartedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: appStarted,
effect: (action, { unsubscribe, cancelActiveListeners, take, getState, dispatch }) => {
effect: async (action, { unsubscribe, cancelActiveListeners, take, getState, dispatch }) => {
// this should only run once
cancelActiveListeners();
unsubscribe();
// ensure an image is selected when we load the first board
take(imagesApi.endpoints.getImageNames.matchFulfilled).then((firstImageLoad) => {
if (firstImageLoad === null) {
// timeout or cancelled
return;
}
const [{ payload }] = firstImageLoad;
const selectedImage = selectLastSelectedItem(getState());
if (selectedImage) {
return;
}
if (payload.image_names[0]) {
dispatch(imageSelected(payload.image_names[0]));
}
});
// Fire patchmatch check without blocking the image-selection logic below
dispatch(appInfoApi.endpoints.getPatchmatchStatus.initiate())
.unwrap()
.then((isPatchmatchAvailable) => {
@@ -43,6 +28,24 @@ export const addAppStartedListener = (startAppListening: AppStartListening) => {
}
})
.catch(noop);
// ensure an image is selected when we load the first board.
// The effect must be async and await take() so that RTK keeps the listener's AbortController
// alive until the query resolves; a synchronous effect causes the controller to be aborted
// immediately after the effect returns, before any network response arrives.
const firstImageLoad = await take(imagesApi.endpoints.getImageNames.matchFulfilled, 5000);
if (firstImageLoad === null) {
// timeout or cancelled
return;
}
const [{ payload }] = firstImageLoad;
const selectedImage = selectLastSelectedItem(getState());
if (selectedImage) {
return;
}
if (payload.image_names[0]) {
dispatch(imageSelected(payload.image_names[0]));
}
},
});
};

View File

@@ -19,6 +19,7 @@ import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMi
import { deepClone } from 'common/util/deepClone';
import { merge } from 'es-toolkit';
import { omit, pick } from 'es-toolkit/compat';
import { authSliceConfig } from 'features/auth/store/authSlice';
import { changeBoardModalSliceConfig } from 'features/changeBoardModal/store/slice';
import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
@@ -61,6 +62,7 @@ const log = logger('system');
// When adding a slice, add the config to the SLICE_CONFIGS object below, then add the reducer to ALL_REDUCERS.
const SLICE_CONFIGS = {
[authSliceConfig.slice.reducerPath]: authSliceConfig,
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
[canvasTextSliceConfig.slice.reducerPath]: canvasTextSliceConfig,
@@ -87,6 +89,7 @@ const SLICE_CONFIGS = {
// Remember to wrap undoable reducers in `undoable()`!
const ALL_REDUCERS = {
[api.reducerPath]: api.reducer,
[authSliceConfig.slice.reducerPath]: authSliceConfig.slice.reducer,
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig.slice.reducer,
[canvasTextSliceConfig.slice.reducerPath]: canvasTextSliceConfig.slice.reducer,

View File

@@ -78,6 +78,44 @@ describe('promptAST', () => {
{ type: 'rembed', start: 15, end: 16 },
]);
});
it('should tokenize prompt function syntax', () => {
const tokens = tokenize("('a', 'b').and()");
expect(tokens).toEqual([
{ type: 'lparen', start: 0, end: 1 },
{ type: 'punct', value: "'", start: 1, end: 2 },
{ type: 'word', value: 'a', start: 2, end: 3 },
{ type: 'punct', value: "'", start: 3, end: 4 },
{ type: 'punct', value: ',', start: 4, end: 5 },
{ type: 'whitespace', value: ' ', start: 5, end: 6 },
{ type: 'punct', value: "'", start: 6, end: 7 },
{ type: 'word', value: 'b', start: 7, end: 8 },
{ type: 'punct', value: "'", start: 8, end: 9 },
{ type: 'rparen', start: 9, end: 10 },
{ type: 'punct', value: '.', start: 10, end: 11 },
{ type: 'word', value: 'and', start: 11, end: 14 },
{ type: 'lparen', start: 14, end: 15 },
{ type: 'rparen', start: 15, end: 16 },
]);
});
it('should tokenize curly/smart quotes as punctuation', () => {
const tokens = tokenize('\u201chello\u201d');
expect(tokens).toEqual([
{ type: 'punct', value: '\u201c', start: 0, end: 1 },
{ type: 'word', value: 'hello', start: 1, end: 6 },
{ type: 'punct', value: '\u201d', start: 6, end: 7 },
]);
});
it('should tokenize curly single quotes as punctuation', () => {
const tokens = tokenize('\u2018hello\u2019');
expect(tokens).toEqual([
{ type: 'punct', value: '\u2018', start: 0, end: 1 },
{ type: 'word', value: 'hello', start: 1, end: 6 },
{ type: 'punct', value: '\u2019', start: 6, end: 7 },
]);
});
});
describe('parseTokens', () => {
@@ -167,6 +205,312 @@ describe('promptAST', () => {
const ast = parseTokens(tokens);
expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name', range: { start: 0, end: 16 } }]);
});
describe('prompt functions', () => {
it('should parse .and() prompt function with single-quoted args', () => {
const tokens = tokenize("('one two', 'three four').and()");
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.name).toBe('and');
expect(pf.functionParams).toBe('');
expect(pf.promptArgs).toHaveLength(2);
// First arg: 'one two'
expect(pf.promptArgs[0]!.quote).toBe("'");
expect(pf.promptArgs[0]!.nodes).toHaveLength(3); // word, ws, word
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
expect(pf.promptArgs[0]!.nodes[2]).toMatchObject({ type: 'word', text: 'two' });
// Second arg: 'three four'
expect(pf.promptArgs[1]!.quote).toBe("'");
expect(pf.promptArgs[1]!.nodes).toHaveLength(3);
expect(pf.promptArgs[1]!.nodes[0]).toMatchObject({ type: 'word', text: 'three' });
expect(pf.promptArgs[1]!.nodes[2]).toMatchObject({ type: 'word', text: 'four' });
});
it('should parse .or() prompt function', () => {
const tokens = tokenize("('one', 'two three. four.').or()");
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.name).toBe('or');
expect(pf.promptArgs).toHaveLength(2);
// First arg: 'one'
expect(pf.promptArgs[0]!.nodes).toHaveLength(1);
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
// Second arg: 'two three. four.'
expect(pf.promptArgs[1]!.nodes.length).toBeGreaterThanOrEqual(5);
});
it('should parse .blend() prompt function with params', () => {
const tokens = tokenize("('one', 'two').blend(0.7, 0.3)");
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.name).toBe('blend');
expect(pf.functionParams).toBe('0.7, 0.3');
expect(pf.promptArgs).toHaveLength(2);
});
it('should parse prompt function with double-quoted args', () => {
const tokens = tokenize('("one", "two").and()');
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.name).toBe('and');
expect(pf.promptArgs[0]!.quote).toBe('"');
});
it('should parse prompt function with curly double quotes', () => {
const tokens = tokenize('(\u201cone\u201d, \u201ctwo\u201d).and()');
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.name).toBe('and');
expect(pf.promptArgs).toHaveLength(2);
expect(pf.promptArgs[0]!.quote).toBe('\u201c');
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
expect(pf.promptArgs[1]!.nodes[0]).toMatchObject({ type: 'word', text: 'two' });
});
it('should parse prompt function with curly single quotes', () => {
const tokens = tokenize('(\u2018one\u2019, \u2018two\u2019).and()');
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.name).toBe('and');
expect(pf.promptArgs[0]!.quote).toBe('\u2018');
});
it('should parse prompt function with curly quotes containing commas in args', () => {
const prompt = '(\u201chigh detail, cinematic\u201d, \u201csoft light, portrait\u201d).and()';
const ast = parseTokens(tokenize(prompt));
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.promptArgs).toHaveLength(2);
});
it('should parse prompt function with newline before .method()', () => {
const prompt = '(\u201cone\u201d, \u201ctwo\u201d)\n.and()';
const ast = parseTokens(tokenize(prompt));
expect(ast).toHaveLength(1);
expect(ast[0]!.type).toBe('prompt_function');
});
it('should parse quoted prompt function with newline before .method()', () => {
const prompt = "('one', 'two')\n.and()";
const ast = parseTokens(tokenize(prompt));
expect(ast).toHaveLength(1);
expect(ast[0]!.type).toBe('prompt_function');
});
it('should parse prompt function with attention inside args', () => {
const tokens = tokenize("('hello+', '(world)-').and()");
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
// First arg: hello+
const arg0Word = pf.promptArgs[0]!.nodes[0]!;
expect(arg0Word).toMatchObject({ type: 'word', text: 'hello', attention: '+' });
// Second arg: (world)-
const arg1Group = pf.promptArgs[1]!.nodes[0]!;
expect(arg1Group.type).toBe('group');
if (arg1Group.type === 'group') {
expect(arg1Group.attention).toBe('-');
}
});
it('should preserve content range for each arg', () => {
const tokens = tokenize("('one two', 'three four').and()");
const ast = parseTokens(tokens);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
// 'one two' content is between quotes at positions 1 and 9
expect(pf.promptArgs[0]!.contentRange.start).toBe(2);
expect(pf.promptArgs[0]!.contentRange.end).toBe(9);
// 'three four' content is between quotes at positions 12 and 23
expect(pf.promptArgs[1]!.contentRange.start).toBe(13);
expect(pf.promptArgs[1]!.contentRange.end).toBe(23);
});
it('should parse prompt function embedded in larger prompt', () => {
const tokens = tokenize("some text, ('a', 'b').and(), more text");
const ast = parseTokens(tokens);
// Should have: word, ws, word, punct, ws, prompt_function, punct, ws, word, ws, word
const pfNodes = ast.filter((n) => n.type === 'prompt_function');
expect(pfNodes).toHaveLength(1);
expect(pfNodes[0]!.type).toBe('prompt_function');
});
it('should fall back to regular group when no method call follows', () => {
const tokens = tokenize("('a', 'b')");
const ast = parseTokens(tokens);
// Without .method(), this should be parsed as a regular group
expect(ast[0]!.type).toBe('group');
});
it('should parse three-arg prompt function', () => {
const tokens = tokenize("('a', 'b', 'c').blend(0.5, 0.3, 0.2)");
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.promptArgs).toHaveLength(3);
expect(pf.functionParams).toBe('0.5, 0.3, 0.2');
});
});
describe('unquoted prompt functions', () => {
it('should parse unquoted .and() prompt function', () => {
const tokens = tokenize('(one,two).and()');
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.name).toBe('and');
expect(pf.functionParams).toBe('');
expect(pf.promptArgs).toHaveLength(2);
expect(pf.promptArgs[0]!.quote).toBe('');
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
expect(pf.promptArgs[1]!.quote).toBe('');
expect(pf.promptArgs[1]!.nodes[0]).toMatchObject({ type: 'word', text: 'two' });
});
it('should parse unquoted .and() with spaces', () => {
const tokens = tokenize('(one two, three four).and()');
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.name).toBe('and');
expect(pf.promptArgs).toHaveLength(2);
expect(pf.promptArgs[0]!.nodes[0]).toMatchObject({ type: 'word', text: 'one' });
expect(pf.promptArgs[0]!.nodes[2]).toMatchObject({ type: 'word', text: 'two' });
expect(pf.promptArgs[1]!.nodes[0]).toMatchObject({ type: 'word', text: 'three' });
expect(pf.promptArgs[1]!.nodes[2]).toMatchObject({ type: 'word', text: 'four' });
});
it('should parse unquoted .blend() with params', () => {
const tokens = tokenize('(one two, three four).blend(0.7, 0.3)');
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.name).toBe('blend');
expect(pf.functionParams).toBe('0.7, 0.3');
expect(pf.promptArgs).toHaveLength(2);
});
it('should parse unquoted three-arg prompt function', () => {
const tokens = tokenize('(a, b, c).blend(0.5, 0.3, 0.2)');
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
expect(pf.promptArgs).toHaveLength(3);
expect(pf.functionParams).toBe('0.5, 0.3, 0.2');
});
it('should parse unquoted prompt function with attention inside args', () => {
const tokens = tokenize('(hello+, world).and()');
const ast = parseTokens(tokens);
expect(ast).toHaveLength(1);
const pf = ast[0]!;
expect(pf.type).toBe('prompt_function');
if (pf.type !== 'prompt_function') {
return;
}
const arg0Word = pf.promptArgs[0]!.nodes[0]!;
expect(arg0Word).toMatchObject({ type: 'word', text: 'hello', attention: '+' });
});
it('should fall back to regular group for single-arg unquoted function', () => {
const tokens = tokenize('(hello world).and()');
const ast = parseTokens(tokens);
// Without a comma, this is not detected as a prompt function
expect(ast[0]!.type).toBe('group');
});
it('should parse unquoted prompt function embedded in larger prompt', () => {
const tokens = tokenize('some text, (a, b).and(), more text');
const ast = parseTokens(tokens);
const pfNodes = ast.filter((n) => n.type === 'prompt_function');
expect(pfNodes).toHaveLength(1);
});
});
});
describe('serialize', () => {
@@ -218,6 +562,163 @@ describe('promptAST', () => {
const result = serialize(ast);
expect(result).toBe('<embedding_name>');
});
describe('prompt functions', () => {
it('should serialize .and() prompt function', () => {
const tokens = tokenize("('one two', 'three four').and()");
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe("('one two', 'three four').and()");
});
it('should serialize .or() prompt function', () => {
const tokens = tokenize("('one', 'two three. four.').or()");
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe("('one', 'two three. four.').or()");
});
it('should serialize .blend() with params', () => {
const tokens = tokenize("('one', 'two').blend(0.7, 0.3)");
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe("('one', 'two').blend(0.7, 0.3)");
});
it('should serialize prompt function with attention inside args', () => {
const tokens = tokenize("('hello+', '(world)-').and()");
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe("('hello+', '(world)-').and()");
});
it('should serialize prompt function embedded in larger prompt', () => {
const prompt = "some text, ('a', 'b').and(), more text";
const tokens = tokenize(prompt);
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe(prompt);
});
it('should serialize three-arg blend', () => {
const tokens = tokenize("('a', 'b', 'c').blend(0.5, 0.3, 0.2)");
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe("('a', 'b', 'c').blend(0.5, 0.3, 0.2)");
});
it('should serialize double-quoted prompt function', () => {
const tokens = tokenize('("one", "two").and()');
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe('("one", "two").and()');
});
it('should serialize curly double-quoted prompt function', () => {
const tokens = tokenize('(\u201cone\u201d, \u201ctwo\u201d).and()');
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe('(\u201cone\u201d, \u201ctwo\u201d).and()');
});
it('should serialize curly single-quoted prompt function', () => {
const tokens = tokenize('(\u2018one\u2019, \u2018two\u2019).and()');
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe('(\u2018one\u2019, \u2018two\u2019).and()');
});
});
describe('unquoted prompt functions', () => {
it('should serialize unquoted .and()', () => {
const tokens = tokenize('(one two, three four).and()');
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe('(one two, three four).and()');
});
it('should serialize unquoted .blend() with params', () => {
const tokens = tokenize('(one two, three four).blend(0.7, 0.3)');
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe('(one two, three four).blend(0.7, 0.3)');
});
it('should serialize unquoted prompt function embedded in larger prompt', () => {
const prompt = 'some text, (a, b).and(), more text';
const tokens = tokenize(prompt);
const ast = parseTokens(tokens);
const result = serialize(ast);
expect(result).toBe(prompt);
});
});
});
describe('round-trip (tokenize → parse → serialize)', () => {
const roundTrip = (prompt: string) => {
const tokens = tokenize(prompt);
const ast = parseTokens(tokens);
return serialize(ast);
};
it.each([
'a cat',
'(a cat)',
'(a cat)1.2',
'cat+',
'cat++',
'cat-',
'(hello world)+',
'(hello world)++',
'(hello world)-',
'\\(medium\\)',
'colored pencil \\(medium\\) (enhanced)',
'<embedding_name>',
'portrait \\(realistic\\) (high quality)1.2',
'(masterpiece)1.3, best quality, (high detail)1.2',
"('one two', 'three four').and()",
"('one', 'two three. four.').or()",
"('one', 'two').blend(0.7, 0.3)",
"('hello+', '(world)-').and()",
"some text, ('a', 'b').and(), more text",
"('a', 'b', 'c').blend(0.5, 0.3, 0.2)",
'("one", "two").and()',
// Curly double-quoted prompt functions
'(\u201cone\u201d, \u201ctwo\u201d).and()',
'(\u201chigh detail, cinematic\u201d, \u201csoft light, portrait\u201d).and()',
'(\u201cone\u201d, \u201ctwo\u201d).blend(0.7, 0.3)',
// Curly single-quoted prompt functions
'(\u2018one\u2019, \u2018two\u2019).and()',
'(\u2018one\u2019, \u2018two\u2019).or()',
// Unquoted prompt functions
'(one two, three four).and()',
'(one two, three four).blend(0.7, 0.3)',
'(a, b, c).blend(0.5, 0.3, 0.2)',
'some text, (a, b).and(), more text',
"('one',\n 'two',\n 'three').and()",
])('should round-trip: %s', (prompt) => {
expect(roundTrip(prompt)).toBe(prompt);
});
});
describe('newline normalization', () => {
const roundTrip = (prompt: string) => {
const tokens = tokenize(prompt);
const ast = parseTokens(tokens);
return serialize(ast);
};
it('should normalize newline before .method() in quoted prompt function', () => {
expect(roundTrip("('one', 'two')\n.and()")).toBe("('one', 'two').and()");
});
it('should normalize newline before .method() in curly-quoted prompt function', () => {
expect(roundTrip('(\u201cone\u201d, \u201ctwo\u201d)\n.and()')).toBe('(\u201cone\u201d, \u201ctwo\u201d).and()');
});
it('should normalize newline before .method() in unquoted prompt function', () => {
expect(roundTrip('(one, two)\n.and()')).toBe('(one, two).and()');
});
});
describe('compel compatibility examples', () => {

View File

@@ -3,18 +3,10 @@
*/
export type Attention = string | number;
type Word = string;
type Punct = string;
type Whitespace = string;
type Embedding = string;
type Token =
| { type: 'word'; value: Word; start: number; end: number }
| { type: 'whitespace'; value: Whitespace; start: number; end: number }
| { type: 'punct'; value: Punct; start: number; end: number }
| { type: 'word'; value: string; start: number; end: number }
| { type: 'whitespace'; value: string; start: number; end: number }
| { type: 'punct'; value: string; start: number; end: number }
| { type: 'lparen'; start: number; end: number }
| { type: 'rparen'; start: number; end: number }
| { type: 'weight'; value: Attention; start: number; end: number }
@@ -22,8 +14,21 @@ type Token =
| { type: 'rembed'; start: number; end: number }
| { type: 'escaped_paren'; value: '(' | ')'; start: number; end: number };
/**
* A single argument in a prompt function like .and(), .or(), or .blend().
* Contains the parsed AST nodes of the argument content and metadata about quoting/range.
*/
export type PromptFunctionArg = {
nodes: ASTNode[];
quote: string;
/** Range of the content between the quotes (exclusive of quotes themselves) in original prompt coordinates. */
contentRange: { start: number; end: number };
/** Raw separator whitespace after the comma before this arg (args[1+] only). */
separator?: string;
};
export type ASTNode =
| { type: 'word'; text: Word; attention?: Attention; range: { start: number; end: number }; isSelection?: boolean }
| { type: 'word'; text: string; attention?: Attention; range: { start: number; end: number }; isSelection?: boolean }
| {
type: 'group';
children: ASTNode[];
@@ -31,20 +36,60 @@ export type ASTNode =
range: { start: number; end: number };
isSelection?: boolean;
}
| { type: 'embedding'; value: Embedding; range: { start: number; end: number }; isSelection?: boolean }
| { type: 'whitespace'; value: Whitespace; range: { start: number; end: number }; isSelection?: boolean }
| { type: 'punct'; value: Punct; range: { start: number; end: number }; isSelection?: boolean }
| { type: 'escaped_paren'; value: '(' | ')'; range: { start: number; end: number }; isSelection?: boolean };
| { type: 'embedding'; value: string; range: { start: number; end: number }; isSelection?: boolean }
| { type: 'whitespace'; value: string; range: { start: number; end: number }; isSelection?: boolean }
| { type: 'punct'; value: string; range: { start: number; end: number }; isSelection?: boolean }
| { type: 'escaped_paren'; value: '(' | ')'; range: { start: number; end: number }; isSelection?: boolean }
| {
type: 'prompt_function';
name: string;
promptArgs: PromptFunctionArg[];
functionParams: string;
range: { start: number; end: number };
isSelection?: boolean;
};
const WEIGHT_PATTERN = /^[+-]?(\d+(\.\d+)?|[+-]+)/;
const WHITESPACE_PATTERN = /^\s+/;
const PUNCTUATION_PATTERN = /^[.,]/;
const OTHER_PATTERN = /\s/;
const WORD_CHAR_PATTERN = /[a-zA-Z0-9_]/;
// prettier-ignore
const PUNCTUATION_PATTERN = /^[.,/!?;:'"""''\u2018\u2019\u201c\u201d`~@#$%^&*=_|]/;
/** All characters that can serve as an opening quote in a prompt function argument. */
const OPEN_QUOTE_CHARS = new Set(["'", '"', '\u2018', '\u201c']);
/** Map from opening curly quote to the matching closing curly quote. Straight quotes match themselves. */
const CLOSE_QUOTE_MAP: Record<string, string> = {
"'": "'",
'"': '"',
'\u2018': '\u2019', // ' → '
'\u201c': '\u201d', // " → "
};
// #region Token Helpers
/** Get the string value of a token, if it has one. */
function tokenValue(t: Token | undefined): string | undefined {
if (!t) {
return undefined;
}
if ('value' in t) {
return String(t.value);
}
return undefined;
}
/** Check if a token is a punct token with a specific value. */
function isPunctValue(t: Token | undefined, value: string): boolean {
return t?.type === 'punct' && tokenValue(t) === value;
}
// #region Tokenizer
/**
* Convert a prompt string into an AST.
* Convert a prompt string into a token stream.
* @param prompt string
* @returns ASTNode[]
* @returns Token[]
*/
export function tokenize(prompt: string): Token[] {
if (!prompt) {
@@ -52,7 +97,7 @@ export function tokenize(prompt: string): Token[] {
}
const len = prompt.length;
let tokens: Token[] = [];
const tokens: Token[] = [];
let i = 0;
while (i < len) {
@@ -69,7 +114,7 @@ export function tokenize(prompt: string): Token[] {
tokenizeEmbedding(char, i) ||
tokenizeWord(prompt, i) ||
tokenizePunctuation(char, i) ||
tokenizeOther(char, i);
tokenizeFallback(char, i);
if (result) {
if (result.token) {
@@ -168,15 +213,15 @@ function tokenizeWord(prompt: string, i: number): TokenizeResult {
return null;
}
if (/[a-zA-Z0-9_]/.test(char)) {
if (WORD_CHAR_PATTERN.test(char)) {
let j = i;
while (j < prompt.length && /[a-zA-Z0-9_]/.test(prompt[j]!)) {
while (j < prompt.length && WORD_CHAR_PATTERN.test(prompt[j]!)) {
j++;
}
const word = prompt.slice(i, j);
// Check for weight immediately after word (e.g., "Lorem+", "consectetur-")
const weightMatch = prompt.slice(j).match(/^[+-]?(\d+(\.\d+)?|[+-]+)/);
const weightMatch = prompt.slice(j).match(WEIGHT_PATTERN);
if (weightMatch && weightMatch[0]) {
const weightEnd = j + weightMatch[0].length;
return {
@@ -210,17 +255,20 @@ function tokenizeEmbedding(char: string, i: number): TokenizeResult {
return null;
}
function tokenizeOther(char: string, i: number): TokenizeResult {
// Any other single character punctuation
if (OTHER_PATTERN.test(char)) {
return {
token: { type: 'punct', value: char, start: i, end: i + 1 },
nextIndex: i + 1,
};
}
return null;
/**
* Fallback tokenizer for characters not matched by any other tokenizer.
* Emits them as word tokens so they are preserved in the AST rather than silently dropped.
* This handles non-Latin Unicode text (CJK, emoji, etc.) and any other unrecognized characters.
*/
function tokenizeFallback(char: string, i: number): TokenizeResult {
return {
token: { type: 'word', value: char, start: i, end: i + 1 },
nextIndex: i + 1,
};
}
// #region Parser
/**
* Convert tokens into an AST.
* @param tokens Token[]
@@ -233,10 +281,373 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
return tokens[pos];
}
function peekAt(offset: number): Token | undefined {
return tokens[pos + offset];
}
function consume(): Token | undefined {
return tokens[pos++];
}
/**
* Quick lookahead check: does the current lparen (already consumed) start a quoted prompt function?
* A quoted prompt function looks like ('...', '...').method(...)
* We check if the first non-whitespace token after lparen is a quote character.
*/
function isQuotedPromptFunctionAhead(): boolean {
let p = 0;
while (peekAt(p)?.type === 'whitespace') {
p++;
}
const t = peekAt(p);
return t?.type === 'punct' && OPEN_QUOTE_CHARS.has(tokenValue(t)!);
}
/**
* Lookahead check: does the current lparen (already consumed) start an unquoted prompt function?
* An unquoted prompt function looks like (arg1, arg2).method(...) where args are not quoted.
* We scan forward looking for a comma at the same nesting depth, then rparen followed by .word(
*/
function isUnquotedPromptFunctionAhead(): boolean {
let p = 0;
let depth = 0;
let hasComma = false;
// Scan forward through tokens to find the matching rparen
while (peekAt(p)) {
const t = peekAt(p)!;
if (t.type === 'lparen') {
depth++;
} else if (t.type === 'rparen') {
if (depth === 0) {
// Found matching rparen — now check for .methodName( pattern
// (possibly with whitespace between ) and .)
if (!hasComma) {
return false; // No comma means it's just a regular group
}
let next = p + 1;
while (peekAt(next)?.type === 'whitespace') {
next++;
}
return (
isPunctValue(peekAt(next), '.') && peekAt(next + 1)?.type === 'word' && peekAt(next + 2)?.type === 'lparen'
);
}
depth--;
} else if (isPunctValue(t, ',') && depth === 0) {
hasComma = true;
}
p++;
}
return false;
}
/**
* Parse the `.methodName(params)` suffix that follows the closing rparen of a prompt function.
* Assumes whitespace has already been skipped. Returns null and restores pos if the pattern
* doesn't match.
*/
function tryParseMethodTail(savedPos: number): { name: string; functionParams: string; endPos: number } | null {
// Skip whitespace between ) and .methodName (allows newlines)
while (peek()?.type === 'whitespace') {
consume();
}
// Expect .methodName(params)
if (!isPunctValue(peek(), '.')) {
pos = savedPos;
return null;
}
consume(); // consume dot
if (peek()?.type !== 'word') {
pos = savedPos;
return null;
}
const methodName = tokenValue(consume())!;
// Expect opening paren for method call
if (peek()?.type !== 'lparen') {
pos = savedPos;
return null;
}
consume(); // consume method open paren
// Collect method params until closing rparen
let functionParams = '';
while (pos < tokens.length) {
const t = peek()!;
if (t.type === 'rparen') {
break;
}
const tok = consume()!;
const v = tokenValue(tok);
if (v !== undefined) {
functionParams += v;
}
}
// Expect closing rparen for method call
if (peek()?.type !== 'rparen') {
pos = savedPos;
return null;
}
const methodCloseParen = consume()!; // consume method close paren
return { name: methodName, functionParams, endPos: methodCloseParen.end };
}
/**
* Try to parse a prompt function starting after the opening lparen.
* Returns the PromptFunctionNode if successful, or null if the pattern doesn't match
* (in which case `pos` is restored to `savedPos`).
*/
function tryParsePromptFunction(lparenToken: Token & { type: 'lparen' }, savedPos: number): ASTNode | null {
const args: PromptFunctionArg[] = [];
let openQuoteChar: string | null = null;
let closeQuoteChar: string | null = null;
let pendingSeparator: string | undefined;
while (pos < tokens.length) {
// Skip whitespace before arg or closing paren
while (peek()?.type === 'whitespace') {
consume();
}
// Check for rparen (end of prompt function args)
if (peek()?.type === 'rparen') {
break;
}
// Expect comma separator between args
if (args.length > 0) {
if (isPunctValue(peek(), ',')) {
consume();
let sep = '';
while (peek()?.type === 'whitespace') {
const sepToken = consume()!;
const sepValue = tokenValue(sepToken);
if (sepValue !== undefined) {
sep += sepValue;
}
}
pendingSeparator = sep;
} else {
pos = savedPos;
return null;
}
}
// Expect opening quote
const openQuoteTok = peek();
if (!openQuoteTok || openQuoteTok.type !== 'punct') {
pos = savedPos;
return null;
}
const thisOpenQuote = tokenValue(openQuoteTok)!;
if (!OPEN_QUOTE_CHARS.has(thisOpenQuote)) {
pos = savedPos;
return null;
}
const thisCloseQuote = CLOSE_QUOTE_MAP[thisOpenQuote]!;
if (openQuoteChar === null) {
openQuoteChar = thisOpenQuote;
closeQuoteChar = thisCloseQuote;
} else if (thisOpenQuote !== openQuoteChar) {
// Mismatched quote style between args
pos = savedPos;
return null;
}
consume(); // consume opening quote
const contentStart = openQuoteTok.end;
// Collect tokens until closing quote
const argTokens: Token[] = [];
let contentEnd = contentStart;
while (pos < tokens.length) {
const t = peek();
if (isPunctValue(t, closeQuoteChar!)) {
contentEnd = t!.start;
break;
}
const consumed = consume()!;
argTokens.push(consumed);
contentEnd = consumed.end;
}
// Expect closing quote
if (!isPunctValue(peek(), closeQuoteChar!)) {
pos = savedPos;
return null;
}
consume(); // consume closing quote
// Parse sub-tokens as AST
const argNodes = parseTokens(argTokens);
args.push({
nodes: argNodes,
quote: openQuoteChar,
contentRange: { start: contentStart, end: contentEnd },
separator: pendingSeparator,
});
pendingSeparator = undefined;
}
if (args.length === 0) {
pos = savedPos;
return null;
}
// Expect rparen
if (peek()?.type !== 'rparen') {
pos = savedPos;
return null;
}
consume(); // consume rparen
// Parse .methodName(params) suffix
const methodTail = tryParseMethodTail(savedPos);
if (!methodTail) {
return null; // pos already restored by tryParseMethodTail
}
return {
type: 'prompt_function',
name: methodTail.name,
promptArgs: args,
functionParams: methodTail.functionParams,
range: { start: lparenToken.start, end: methodTail.endPos },
};
}
/**
* Try to parse an unquoted prompt function starting after the opening lparen.
* Unquoted prompt functions look like (arg1 words, arg2 words).method(params)
* where arguments are separated by commas without quotes.
* Returns the PromptFunctionNode if successful, or null if the pattern doesn't match
* (in which case `pos` is restored to `savedPos`).
*/
function tryParseUnquotedPromptFunction(lparenToken: Token & { type: 'lparen' }, savedPos: number): ASTNode | null {
const args: PromptFunctionArg[] = [];
let pendingSeparator: string | undefined;
while (pos < tokens.length) {
// Check for rparen (end of prompt function args)
if (peek()?.type === 'rparen') {
break;
}
// Expect comma separator between args (consume the comma)
if (args.length > 0) {
if (isPunctValue(peek(), ',')) {
consume(); // consume comma
let sep = '';
while (peek()?.type === 'whitespace') {
const sepToken = consume()!;
const sepValue = tokenValue(sepToken);
if (sepValue !== undefined) {
sep += sepValue;
}
}
pendingSeparator = sep;
} else {
pos = savedPos;
return null;
}
}
// Collect tokens until comma or rparen (at nesting depth 0)
const argTokens: Token[] = [];
let contentStart: number | null = null;
let contentEnd: number | null = null;
let depth = 0;
while (pos < tokens.length) {
const t = peek()!;
if (t.type === 'lparen') {
depth++;
} else if (t.type === 'rparen') {
if (depth === 0) {
break; // End of all args
}
depth--;
} else if (isPunctValue(t, ',') && depth === 0) {
break; // End of this arg
}
if (contentStart === null) {
contentStart = t.start;
}
const consumed = consume()!;
argTokens.push(consumed);
contentEnd = consumed.end;
}
if (argTokens.length === 0) {
pos = savedPos;
return null;
}
// Trim leading/trailing whitespace tokens from the arg content
let firstNonWs = 0;
while (firstNonWs < argTokens.length && argTokens[firstNonWs]!.type === 'whitespace') {
firstNonWs++;
}
let lastNonWs = argTokens.length - 1;
while (lastNonWs >= 0 && argTokens[lastNonWs]!.type === 'whitespace') {
lastNonWs--;
}
const trimmedArgTokens = argTokens.slice(firstNonWs, lastNonWs + 1);
const trimmedStart = trimmedArgTokens.length > 0 ? trimmedArgTokens[0]!.start : contentStart!;
const trimmedEnd = trimmedArgTokens.length > 0 ? trimmedArgTokens[trimmedArgTokens.length - 1]!.end : contentEnd!;
// Parse sub-tokens as AST
const argNodes = parseTokens(trimmedArgTokens);
args.push({
nodes: argNodes,
quote: '', // Unquoted
contentRange: { start: trimmedStart, end: trimmedEnd },
separator: pendingSeparator,
});
pendingSeparator = undefined;
}
if (args.length < 2) {
// An unquoted prompt function must have at least 2 args (otherwise it's a regular group)
pos = savedPos;
return null;
}
// Expect rparen
if (peek()?.type !== 'rparen') {
pos = savedPos;
return null;
}
consume(); // consume rparen
// Parse .methodName(params) suffix
const methodTail = tryParseMethodTail(savedPos);
if (!methodTail) {
return null; // pos already restored by tryParseMethodTail
}
return {
type: 'prompt_function',
name: methodTail.name,
promptArgs: args,
functionParams: methodTail.functionParams,
range: { start: lparenToken.start, end: methodTail.endPos },
};
}
function parseGroup(): ASTNode[] {
const nodes: ASTNode[] = [];
@@ -254,6 +665,30 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
}
case 'lparen': {
const lparen = consume() as Token & { type: 'lparen' };
// Try to parse as a quoted prompt function first
if (isQuotedPromptFunctionAhead()) {
const savedPos = pos;
const pfResult = tryParsePromptFunction(lparen, savedPos);
if (pfResult) {
nodes.push(pfResult);
break;
}
// pos was restored by tryParsePromptFunction on failure
}
// Try to parse as an unquoted prompt function
if (isUnquotedPromptFunctionAhead()) {
const savedPos = pos;
const pfResult = tryParseUnquotedPromptFunction(lparen, savedPos);
if (pfResult) {
nodes.push(pfResult);
break;
}
// pos was restored by tryParseUnquotedPromptFunction on failure
}
// Regular group parsing
const groupChildren = parseGroup();
let attention: Attention | undefined;
@@ -283,10 +718,10 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
let end = lembed.end;
while (peek() && peek()!.type !== 'rembed') {
const embedToken = consume()!;
embedValue +=
embedToken.type === 'word' || embedToken.type === 'punct' || embedToken.type === 'whitespace'
? embedToken.value
: '';
const v = tokenValue(embedToken);
if (v !== undefined) {
embedValue += v;
}
end = embedToken.end;
}
if (peek()?.type === 'rembed') {
@@ -341,47 +776,131 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
return parseGroup();
}
// #region Serialization
/**
* Visitor callbacks for AST serialization. All callbacks are optional.
* Called during traversal to allow tracking node positions in the output string.
*/
type SerializeVisitor = {
/** Called after a node has been fully serialized, with its start and end positions in the output. */
onNode?: (node: ASTNode, start: number, end: number) => void;
};
/** Mutable buffer used by serializeCore so all recursive calls share the same position tracking. */
type SerializeBuffer = { prompt: string };
/**
* Shared serialization core. Converts an AST back into a prompt string,
* optionally calling visitor hooks for position tracking.
*
* Uses a shared mutable buffer so that node positions reported via
* `visitor.onNode` are always absolute offsets in the final output string,
* even for nodes nested inside groups or prompt function args.
*/
function serializeCore(ast: ASTNode[], visitor: SerializeVisitor | undefined, buf: SerializeBuffer): void {
for (const node of ast) {
const nodeStart = buf.prompt.length;
switch (node.type) {
case 'punct':
case 'whitespace': {
buf.prompt += node.value;
break;
}
case 'escaped_paren': {
buf.prompt += `\\${node.value}`;
break;
}
case 'word': {
buf.prompt += node.text;
if (node.attention) {
buf.prompt += String(node.attention);
}
break;
}
case 'group': {
buf.prompt += '(';
serializeCore(node.children, visitor, buf);
buf.prompt += ')';
if (node.attention) {
buf.prompt += String(node.attention);
}
break;
}
case 'embedding': {
buf.prompt += `<${node.value}>`;
break;
}
case 'prompt_function': {
buf.prompt += '(';
for (let i = 0; i < node.promptArgs.length; i++) {
if (i > 0) {
const sep = node.promptArgs[i]!.separator ?? ' ';
buf.prompt += `,${sep}`;
}
const arg = node.promptArgs[i]!;
buf.prompt += arg.quote;
serializeCore(arg.nodes, visitor, buf);
buf.prompt += CLOSE_QUOTE_MAP[arg.quote] ?? arg.quote;
}
buf.prompt += ').';
buf.prompt += node.name;
buf.prompt += '(';
buf.prompt += node.functionParams;
buf.prompt += ')';
break;
}
}
visitor?.onNode?.(node, nodeStart, buf.prompt.length);
}
}
/**
* Convert an AST back into a prompt string.
* @param ast ASTNode[]
* @returns string
*/
export function serialize(ast: ASTNode[]): string {
let prompt = '';
const buf: SerializeBuffer = { prompt: '' };
serializeCore(ast, undefined, buf);
return buf.prompt;
}
for (const node of ast) {
switch (node.type) {
case 'punct':
case 'whitespace': {
prompt += node.value;
break;
}
case 'escaped_paren': {
prompt += `\\${node.value}`;
break;
}
case 'word': {
prompt += node.text;
if (node.attention) {
prompt += String(node.attention);
/**
* Serialize an AST to a prompt string while simultaneously computing the
* selection range from `isSelection` flags on nodes.
*
* This is more reliable than separate serialize + selection computation because
* the position tracking is guaranteed to match the serialized output.
*/
export function serializeWithSelection(ast: ASTNode[]): {
prompt: string;
selectionStart: number;
selectionEnd: number;
} {
let selStart = Infinity;
let selEnd = -1;
const buf: SerializeBuffer = { prompt: '' };
serializeCore(
ast,
{
onNode(node, start, end) {
if (node.isSelection) {
selStart = Math.min(selStart, start);
selEnd = Math.max(selEnd, end);
}
break;
}
case 'group': {
prompt += '(';
prompt += serialize(node.children);
prompt += ')';
if (node.attention) {
prompt += String(node.attention);
}
break;
}
case 'embedding': {
prompt += `<${node.value}>`;
break;
}
}
},
},
buf
);
if (selStart === Infinity) {
selStart = 0;
selEnd = buf.prompt.length;
}
return prompt;
return { prompt: buf.prompt, selectionStart: selStart, selectionEnd: selEnd };
}

View File

@@ -2,170 +2,706 @@ import { describe, expect, it } from 'vitest';
import { adjustPromptAttention } from './promptAttention';
/**
* Helper: select by substring match within the prompt.
* If `selected` is a string, finds it in the prompt and uses its position.
* If `selected` is a [start, end] tuple, uses those positions directly.
*/
function adj(
prompt: string,
selected: string | [number, number],
direction: 'increment' | 'decrement',
prefersNumericWeights = false
) {
const [start, end] =
typeof selected === 'string' ? [prompt.indexOf(selected), prompt.indexOf(selected) + selected.length] : selected;
return adjustPromptAttention(prompt, start, end, direction, prefersNumericWeights);
}
/** Helper that calls adj with prefersNumericWeights=true */
function adjNumeric(prompt: string, selected: string | [number, number], direction: 'increment' | 'decrement') {
return adj(prompt, selected, direction, true);
}
describe('adjustPromptAttention', () => {
describe('cross-boundary selection', () => {
it('should split group and apply attention when selection spans from inside group to outside (increment)', () => {
const prompt = '(a b)+ c';
const result = adjustPromptAttention(prompt, 3, 8, 'increment');
expect(result.prompt).toBe('(a b+ c)+');
});
it('should split group and apply attention when selection spans from inside group to outside (decrement)', () => {
const prompt = '(a b)+ c';
const result = adjustPromptAttention(prompt, 3, 8, 'decrement');
expect(result.prompt).toBe('a+ b c-');
});
it('should split group when selection starts before group and ends inside (increment)', () => {
const prompt = 'a (b c)+';
const result = adjustPromptAttention(prompt, 0, 4, 'increment');
expect(result.prompt).toBe('(a b+ c)+');
});
it('should split group when selection starts before group and ends inside (decrement)', () => {
const prompt = 'a (b c)+';
const result = adjustPromptAttention(prompt, 0, 4, 'decrement');
expect(result.prompt).toBe('a- b c+');
});
it('should handle nested groups with cross-boundary selection (increment)', () => {
const prompt = '((a b)+)+ c';
const result = adjustPromptAttention(prompt, 2, 11, 'increment');
expect(result.prompt).toBe('((a b)++ c)+');
});
it('should handle nested groups with cross-boundary selection (decrement)', () => {
const prompt = '((a b)+)+ c';
const result = adjustPromptAttention(prompt, 2, 11, 'decrement');
expect(result.prompt).toBe('(a b)+ c-');
});
it('should handle selection spanning multiple groups (increment)', () => {
const prompt = '(a)+ (b)+';
const result = adjustPromptAttention(prompt, 0, 9, 'increment');
expect(result.prompt).toBe('(a b)++');
});
it('should handle selection spanning multiple groups (decrement)', () => {
const prompt = '(a)+ (b)+';
const result = adjustPromptAttention(prompt, 0, 9, 'decrement');
expect(result.prompt).toBe('a b');
});
it('should split negative group correctly (decrement on negative group)', () => {
const prompt = '(a b)- c';
const result = adjustPromptAttention(prompt, 3, 8, 'decrement');
expect(result.prompt).toBe('(a b- c)-');
});
it('should split negative group correctly (increment on negative group)', () => {
const prompt = '(a b)- c';
const result = adjustPromptAttention(prompt, 3, 8, 'increment');
expect(result.prompt).toBe('a- b c+');
});
it('should handle multiple non-selected items in group', () => {
const prompt = '(a b c)+ d';
const result = adjustPromptAttention(prompt, 5, 10, 'decrement');
expect(result.prompt).toBe('(a b)+ c d-');
});
it('should handle word with existing attention in group when crossing boundary', () => {
const prompt = 'c (d- e)+';
const result = adjustPromptAttention(prompt, 0, 5, 'increment');
expect(result.prompt).toBe('c+ d e+');
});
it('should handle complex multi-group case', () => {
const prompt = '(a+ b)+ c (d- e)+';
const result = adjustPromptAttention(prompt, 8, 14, 'increment');
expect(result.prompt).toBe('(a+ b c)+ d e+');
});
});
// Basic Attention
describe('single word', () => {
it('should add + when incrementing word without attention', () => {
const prompt = 'hello world';
const result = adjustPromptAttention(prompt, 0, 5, 'increment');
expect(result.prompt).toBe('hello+ world');
});
it('should add - when decrementing word without attention', () => {
const prompt = 'hello world';
const result = adjustPromptAttention(prompt, 0, 5, 'decrement');
expect(result.prompt).toBe('hello- world');
it.each([
['hello world', 'hello', 'increment', 'hello+ world'],
['hello world', 'hello', 'decrement', 'hello- world'],
['hello+ world', 'hello+', 'increment', 'hello++ world'],
['hello+ world', 'hello+', 'decrement', 'hello world'],
['hello- world', 'hello-', 'decrement', 'hello-- world'],
['hello- world', 'hello-', 'increment', 'hello world'],
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
expect(adj(prompt, selected, direction).prompt).toBe(expected);
});
});
describe('existing group', () => {
it('should adjust group attention when cursor is at group boundary', () => {
const prompt = '(hello world)+';
const result = adjustPromptAttention(prompt, 13, 14, 'increment');
describe('multiple words', () => {
it.each([
['hello world', [0, 11] as [number, number], 'increment', '(hello world)+'],
['hello world', [0, 11] as [number, number], 'decrement', '(hello world)-'],
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
expect(adj(prompt, selected, direction).prompt).toBe(expected);
});
});
expect(result.prompt).toBe('(hello world)++');
describe('cursor at word-punctuation boundary', () => {
it('should select word, not punctuation, when cursor is between word and comma', () => {
// "one|, two" — cursor at position 3, between "one" (0-3) and "," (3-4)
expect(adj('one, two', [3, 3], 'increment').prompt).toBe('one+, two');
});
it('should select word, not punctuation, when cursor is between word and period', () => {
expect(adj('one. two', [3, 3], 'increment').prompt).toBe('one+. two');
});
it('should select word when cursor is at start of word after punctuation', () => {
// "one, |two" — cursor at position 5, between " " (4-5) and "two" (5-8)
expect(adj('one, two', [5, 5], 'increment').prompt).toBe('one, two+');
});
it('should still select punctuation when cursor is only touching punctuation', () => {
// Cursor in the middle of a run of punctuation with no adjacent word
// e.g. "one ,, two" cursor at position 5 — between "," (4-5) and "," (5-6)
// Both neighbors are punct, so no word to prefer — should still work
const result = adj('one ,, two', [5, 5], 'increment');
expect(result).toBeDefined();
});
});
// Existing Groups
describe('existing groups', () => {
it('should increment group when cursor is at group boundary', () => {
expect(adj('(hello world)+', [13, 14], 'increment').prompt).toBe('(hello world)++');
});
it('should remove group when attention becomes neutral', () => {
const prompt = '(hello world)+';
const result = adjustPromptAttention(prompt, 0, 14, 'decrement');
expect(adj('(hello world)+', [0, 14], 'decrement').prompt).toBe('hello world');
});
expect(result.prompt).toBe('hello world');
it('should increment inner word within group', () => {
const result = adj('(a b)+', [1, 2], 'increment');
expect(result.prompt).toBe('(a+ b)+');
});
});
describe('multiple words without group', () => {
it('should create new group with + when incrementing multiple words', () => {
const prompt = 'hello world';
const result = adjustPromptAttention(prompt, 0, 11, 'increment');
// Cross-Boundary Selection
expect(result.prompt).toBe('(hello world)+');
});
it('should create new group with - when decrementing multiple words', () => {
const prompt = 'hello world';
const result = adjustPromptAttention(prompt, 0, 11, 'decrement');
expect(result.prompt).toBe('(hello world)-');
describe('cross-boundary selection', () => {
it.each([
// Selection from inside group to outside
['(a b)+ c', [3, 8], 'increment', '(a b+ c)+'],
['(a b)+ c', [3, 8], 'decrement', 'a+ b c-'],
// Selection from outside to inside group
['a (b c)+', [0, 4], 'increment', '(a b+ c)+'],
['a (b c)+', [0, 4], 'decrement', 'a- b c+'],
// Nested groups
['((a b)+)+ c', [2, 11], 'increment', '((a b)++ c)+'],
['((a b)+)+ c', [2, 11], 'decrement', '(a b)+ c-'],
// Spanning multiple groups
['(a)+ (b)+', [0, 9], 'increment', '(a b)++'],
['(a)+ (b)+', [0, 9], 'decrement', 'a b'],
// Negative groups
['(a b)- c', [3, 8], 'decrement', '(a b- c)-'],
['(a b)- c', [3, 8], 'increment', 'a- b c+'],
// Multiple non-selected items in group
['(a b c)+ d', [5, 10], 'decrement', '(a b)+ c d-'],
// Word with existing attention crossing boundary
['c (d- e)+', [0, 5], 'increment', 'c+ d e+'],
// Complex multi-group
['(a+ b)+ c (d- e)+', [8, 14], 'increment', '(a+ b c)+ d e+'],
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
expect(adj(prompt, selected as string | [number, number], direction).prompt).toBe(expected);
});
});
// Selection Preservation
describe('selection preservation', () => {
it('should preserve selection when incrementing single word', () => {
const prompt = 'hello world';
const result = adjustPromptAttention(prompt, 0, 5, 'increment');
it('should track selection when incrementing single word', () => {
const result = adj('hello world', 'hello', 'increment');
expect(result.prompt).toBe('hello+ world');
expect(result.prompt.slice(result.selectionStart, result.selectionEnd)).toBe('hello+');
});
it('should preserve selection when incrementing group', () => {
const prompt = '(hello world)+';
const result = adjustPromptAttention(prompt, 0, 14, 'increment');
it('should track selection when incrementing full group', () => {
const result = adj('(hello world)+', [0, 14], 'increment');
expect(result.prompt).toBe('(hello world)++');
expect(result.prompt.slice(result.selectionStart, result.selectionEnd)).toBe('(hello world)++');
});
it('should preserve selection when splitting group', () => {
const prompt = '(a b)+';
const result = adjustPromptAttention(prompt, 1, 2, 'increment'); // Select 'a' (index 1 to 2)
// 'a' becomes 1.21, 'b' stays 1.1
// Result: (a+ b)+ which is equivalent to a++ b+
it('should track selection when splitting group', () => {
const result = adj('(a b)+', [1, 2], 'increment');
expect(result.prompt).toBe('(a+ b)+');
expect(result.prompt.slice(result.selectionStart, result.selectionEnd)).toBe('a+');
});
});
// Numeric Attention Weights
describe('numeric attention weights', () => {
it.each([
// Increment / decrement numeric weights with additive step
['(masterpiece)1.3', [0, 16], 'increment', '(masterpiece)1.4'],
['(masterpiece)1.3', [0, 16], 'decrement', '(masterpiece)1.2'],
['(high detail)1.2', [0, 16], 'increment', '(high detail)1.3'],
['(sunny midday light)1.15', [0, 24], 'increment', '(sunny midday light)1.25'],
['(sunny midday light)1.15', [0, 24], 'decrement', '(sunny midday light)1.05'],
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
expect(adj(prompt, selected as [number, number], direction).prompt).toBe(expected);
});
it('should preserve non-selected numeric weights when adjusting elsewhere', () => {
const prompt = '(masterpiece)1.3, best quality';
const result = adj(prompt, 'best quality', 'increment');
expect(result.prompt).toContain('(masterpiece)1.3');
expect(result.prompt).not.toContain('masterpiece1.3');
});
it('should not produce floating point garbage', () => {
const prompt = '(high detail)1.2, oil painting';
const result = adj(prompt, 'oil painting', 'increment');
expect(result.prompt).toContain('(high detail)1.2');
expect(result.prompt).not.toMatch(/1\.19999/);
expect(result.prompt).not.toMatch(/1\.20000/);
});
it('should preserve numeric weight 1.15 without corruption', () => {
const prompt = '(sunny midday light)1.15, landscape';
const result = adj(prompt, 'landscape', 'increment');
expect(result.prompt).toContain('(sunny midday light)1.15');
expect(result.prompt).not.toMatch(/1\.15005/);
});
it('should normalize numeric 1.1 weight to + syntax', () => {
const prompt = '(lush rolling hills)1.1, landscape';
const result = adj(prompt, 'landscape', 'increment');
expect(result.prompt).toMatch(/\(lush rolling hills\)(\+|1\.1)/);
});
it('should handle the full complex prompt without corrupting non-selected weights', () => {
const prompt =
'(masterpiece)1.3, best quality, (high detail)1.2, oil painting, (sunny midday light)1.15, an old stone castle standing on a hill, medieval architecture, weathered stone walls, (lush rolling hills)1.1, expansive landscape, clear blue sky';
const result = adj(prompt, 'clear blue sky', 'increment');
expect(result.prompt).toContain('(masterpiece)1.3');
expect(result.prompt).toContain('(high detail)1.2');
expect(result.prompt).toContain('(sunny midday light)1.15');
expect(result.prompt).toContain('(clear blue sky)+');
expect(result.prompt).not.toMatch(/\d\.\d{5,}/);
});
});
// Prompt Functions
describe('prompt functions', () => {
describe('within a single argument', () => {
it.each([
// Single word inside an arg
["('hello world', 'other').and()", 'hello', 'increment', "('hello+ world', 'other').and()"],
["('hello world', 'other').and()", 'hello', 'decrement', "('hello- world', 'other').and()"],
// Multiple words in second arg
["('a', 'hello world').or()", 'hello world', 'increment', "('a', '(hello world)+').or()"],
["('a', 'hello world').or()", 'hello world', 'decrement', "('a', '(hello world)-').or()"],
// Single word in .blend()
["('one two', 'three four').blend(0.7, 0.3)", 'two', 'increment', "('one two+', 'three four').blend(0.7, 0.3)"],
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
expect(adj(prompt, selected, direction).prompt).toBe(expected);
});
});
describe('across argument separator', () => {
it('should adjust both args simultaneously when selection spans separator (increment)', () => {
const prompt = "('one two', 'three four').and()";
// Select across the separator: "two', 'three"
const start = prompt.indexOf('two');
const end = prompt.indexOf('three') + 'three'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toBe("('one two+', 'three+ four').and()");
});
it('should adjust both args simultaneously when selection spans separator (decrement)', () => {
const prompt = "('one two', 'three four').and()";
const start = prompt.indexOf('two');
const end = prompt.indexOf('three') + 'three'.length;
const result = adjustPromptAttention(prompt, start, end, 'decrement');
expect(result.prompt).toBe("('one two-', 'three- four').and()");
});
it('should adjust across separator for .or()', () => {
const prompt = "('alpha beta', 'gamma delta').or()";
const start = prompt.indexOf('beta');
const end = prompt.indexOf('gamma') + 'gamma'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toBe("('alpha beta+', 'gamma+ delta').or()");
});
it('should adjust across separator for .blend() preserving params', () => {
const prompt = "('one two', 'three four').blend(0.7, 0.3)";
const start = prompt.indexOf('two');
const end = prompt.indexOf('three') + 'three'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toBe("('one two+', 'three+ four').blend(0.7, 0.3)");
});
it('should handle repeated increment across separator', () => {
const prompt = "('one two+', 'three+ four').and()";
const start = prompt.indexOf('two');
const end = prompt.indexOf('three') + 'three'.length;
// "two+" is at the boundary, "three+" is at the boundary
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toBe("('one two++', 'three++ four').and()");
});
});
describe('whole function selected', () => {
it('should increment all content in all args when whole function is selected', () => {
const prompt = "('one', 'two').and()";
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
expect(result.prompt).toBe("('one+', 'two+').and()");
});
it('should decrement all content in all args', () => {
const prompt = "('one', 'two').and()";
const result = adjustPromptAttention(prompt, 0, prompt.length, 'decrement');
expect(result.prompt).toBe("('one-', 'two-').and()");
});
it('should increment all args of .blend() preserving params', () => {
const prompt = "('one', 'two').blend(0.7, 0.3)";
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
expect(result.prompt).toBe("('one+', 'two+').blend(0.7, 0.3)");
});
});
describe('prompt function embedded in larger prompt', () => {
it('should adjust only the targeted region outside the function', () => {
const prompt = "some text, ('a', 'b').and(), more text";
const result = adj(prompt, 'some', 'increment');
expect(result.prompt).toContain('some+');
expect(result.prompt).toContain("('a', 'b').and()");
});
it('should adjust only the targeted region inside the function', () => {
const prompt = "prefix ('alpha beta', 'gamma').and() suffix";
const result = adj(prompt, 'alpha', 'increment');
expect(result.prompt).toContain("'alpha+ beta'");
expect(result.prompt).toContain('prefix');
expect(result.prompt).toContain('suffix');
});
it('should adjust text outside and inside function when selection spans boundary', () => {
const prompt = "text ('one two', 'three').and()";
// Select from 'text' through 'one'
const start = prompt.indexOf('text');
const end = prompt.indexOf('one') + 'one'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toContain('text+');
expect(result.prompt).toContain("'one+ two'");
});
});
describe('prompt function with existing attention inside args', () => {
it('should further increment already-weighted word inside arg', () => {
const prompt = "('hello+', 'world').and()";
// Select hello+ (the word with its weight marker)
const result = adj(prompt, 'hello+', 'increment');
expect(result.prompt).toBe("('hello++', 'world').and()");
});
it('should cancel attention to neutral inside arg', () => {
const prompt = "('hello+', 'world').and()";
const result = adj(prompt, 'hello+', 'decrement');
expect(result.prompt).toBe("('hello', 'world').and()");
});
it('should handle group attention inside arg', () => {
const prompt = "('(a b)+', 'c').and()";
// Select everything in first arg
const start = prompt.indexOf('(a b)+');
const end = start + '(a b)+'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toBe("('(a b)++', 'c').and()");
});
});
describe('three-arg prompt functions', () => {
it('should adjust a word in one arg of a three-arg blend', () => {
const prompt = "('a', 'b', 'c').blend(0.5, 0.3, 0.2)";
const result = adj(prompt, 'b', 'increment');
expect(result.prompt).toBe("('a', 'b+', 'c').blend(0.5, 0.3, 0.2)");
});
it('should adjust across two separators in a three-arg blend', () => {
const prompt = "('aa bb', 'cc dd', 'ee ff').blend(0.5, 0.3, 0.2)";
// Select from bb through ee
const start = prompt.indexOf('bb');
const end = prompt.indexOf('ee') + 'ee'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toBe("('aa bb+', '(cc dd)+', 'ee+ ff').blend(0.5, 0.3, 0.2)");
});
});
describe('unquoted prompt functions', () => {
it('should increment a word in unquoted .and()', () => {
const prompt = '(one, two).and()';
const result = adj(prompt, 'one', 'increment');
expect(result.prompt).toBe('(one+, two).and()');
});
it('should decrement a word in unquoted .and()', () => {
const prompt = '(one, two).and()';
const result = adj(prompt, 'one', 'decrement');
expect(result.prompt).toBe('(one-, two).and()');
});
it('should increment a word in unquoted multi-word arg', () => {
const prompt = '(hello world, foo bar).and()';
const result = adj(prompt, 'hello', 'increment');
expect(result.prompt).toBe('(hello+ world, foo bar).and()');
});
it('should increment all args when whole unquoted function is selected', () => {
const prompt = '(one, two).and()';
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
expect(result.prompt).toBe('(one+, two+).and()');
});
it('should preserve unquoted prompt function when adjusting text outside', () => {
const prompt = 'prefix (a, b).and() suffix';
const result = adj(prompt, 'prefix', 'increment');
expect(result.prompt).toContain('(a, b).and()');
expect(result.prompt).toContain('prefix+');
});
it('should handle unquoted .blend() with params', () => {
const prompt = '(one two, three four).blend(0.7, 0.3)';
const result = adj(prompt, 'one', 'increment');
expect(result.prompt).toBe('(one+ two, three four).blend(0.7, 0.3)');
});
it('should adjust across separator in unquoted prompt function', () => {
const prompt = '(one two, three four).and()';
const start = prompt.indexOf('two');
const end = prompt.indexOf('three') + 'three'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toBe('(one two+, three+ four).and()');
});
});
describe('curly-quoted prompt functions', () => {
it('should increment a word inside curly double-quoted arg', () => {
const prompt = '(\u201chello world\u201d, \u201cother\u201d).and()';
const result = adj(prompt, 'hello', 'increment');
expect(result.prompt).toBe('(\u201chello+ world\u201d, \u201cother\u201d).and()');
});
it('should decrement a word inside curly double-quoted arg', () => {
const prompt = '(\u201chello world\u201d, \u201cother\u201d).and()';
const result = adj(prompt, 'hello', 'decrement');
expect(result.prompt).toBe('(\u201chello- world\u201d, \u201cother\u201d).and()');
});
it('should increment a word inside curly single-quoted arg', () => {
const prompt = '(\u2018hello world\u2019, \u2018other\u2019).and()';
const result = adj(prompt, 'hello', 'increment');
expect(result.prompt).toBe('(\u2018hello+ world\u2019, \u2018other\u2019).and()');
});
it('should increment all args when whole curly-quoted function is selected', () => {
const prompt = '(\u201cone\u201d, \u201ctwo\u201d).and()';
const result = adjustPromptAttention(prompt, 0, prompt.length, 'increment');
expect(result.prompt).toBe('(\u201cone+\u201d, \u201ctwo+\u201d).and()');
});
it('should adjust across separator in curly double-quoted prompt function', () => {
const prompt = '(\u201cone two\u201d, \u201cthree four\u201d).and()';
const start = prompt.indexOf('two');
const end = prompt.indexOf('three') + 'three'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toBe('(\u201cone two+\u201d, \u201cthree+ four\u201d).and()');
});
it('should preserve curly-quoted function when adjusting text outside', () => {
const prompt = 'prefix (\u201ca\u201d, \u201cb\u201d).and() suffix';
const result = adj(prompt, 'prefix', 'increment');
expect(result.prompt).toContain('(\u201ca\u201d, \u201cb\u201d).and()');
expect(result.prompt).toContain('prefix+');
});
it('should handle curly-quoted .blend() with params', () => {
const prompt = '(\u201cone two\u201d, \u201cthree four\u201d).blend(0.7, 0.3)';
const result = adj(prompt, 'one', 'increment');
expect(result.prompt).toBe('(\u201cone+ two\u201d, \u201cthree four\u201d).blend(0.7, 0.3)');
});
});
describe('newline before .method()', () => {
it('should increment a word in quoted prompt function with newline before .method()', () => {
const prompt = "('hello world', 'other')\n.and()";
const result = adj(prompt, 'hello', 'increment');
// Newline is normalized away in output
expect(result.prompt).toBe("('hello+ world', 'other').and()");
});
it('should increment a word in curly-quoted prompt function with newline before .method()', () => {
const prompt = '(\u201chello world\u201d, \u201cother\u201d)\n.and()';
const result = adj(prompt, 'hello', 'increment');
expect(result.prompt).toBe('(\u201chello+ world\u201d, \u201cother\u201d).and()');
});
it('should increment a word in unquoted prompt function with newline before .method()', () => {
const prompt = '(hello, other)\n.and()';
const result = adj(prompt, 'hello', 'increment');
expect(result.prompt).toBe('(hello+, other).and()');
});
});
describe('paragraph separators between args', () => {
it('should preserve newlines between quoted args when adjusting', () => {
const prompt = "('chunk 1\n\nline',\n 'chunk 2').and()";
const result = adj(prompt, 'chunk', 'increment');
expect(result.prompt).toBe("('chunk+ 1\n\nline',\n 'chunk 2').and()");
});
});
});
// Selection Preservation with Prompt Functions
describe('selection preservation with prompt functions', () => {
it('should track selection for single word inside prompt function arg', () => {
const prompt = "('hello world', 'other').and()";
const result = adj(prompt, 'hello', 'increment');
expect(result.prompt).toBe("('hello+ world', 'other').and()");
expect(result.prompt.slice(result.selectionStart, result.selectionEnd)).toBe('hello+');
});
it('should track selection spanning across prompt function separator', () => {
const prompt = "('one two', 'three four').and()";
const start = prompt.indexOf('two');
const end = prompt.indexOf('three') + 'three'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment');
expect(result.prompt).toBe("('one two+', 'three+ four').and()");
// Selection should span from 'two+' through 'three+' (including structural chars between)
const sel = result.prompt.slice(result.selectionStart, result.selectionEnd);
expect(sel).toContain('two+');
expect(sel).toContain('three+');
});
});
// Edge Cases
describe('edge cases', () => {
it('should return prompt unchanged when no selection overlap', () => {
const prompt = 'hello world';
const result = adjustPromptAttention(prompt, 5, 5, 'increment');
// Cursor at the boundary between hello and space — should still find a terminal
expect(result.prompt).toBeDefined();
});
it('should handle empty prompt', () => {
const result = adjustPromptAttention('', 0, 0, 'increment');
expect(result.prompt).toBe('');
});
it('should not modify prompt function structure when cursor is on structural char', () => {
const prompt = "('a', 'b').and()";
// Cursor on the dot between ) and and
const dotPos = prompt.indexOf('.and');
const result = adjustPromptAttention(prompt, dotPos, dotPos, 'increment');
// Should either not change or only affect content, not break the structure
expect(result.prompt).toContain('.and()');
});
});
// Numeric Weight Preference
describe('prefersNumericWeights', () => {
describe('single word (no existing attention)', () => {
it.each([
['hello world', 'hello', 'increment', '(hello)1.1 world'],
['hello world', 'hello', 'decrement', '(hello)0.9 world'],
['hello world', 'world', 'increment', 'hello (world)1.1'],
['hello world', 'world', 'decrement', 'hello (world)0.9'],
] as const)('%s [%s] %s → %s', (prompt, selected, direction, expected) => {
expect(adjNumeric(prompt, selected, direction).prompt).toBe(expected);
});
});
describe('successive numeric adjustments', () => {
it('should use additive step on second increment', () => {
const result = adjNumeric('(hello)1.1 world', '(hello)1.1', 'increment');
expect(result.prompt).toBe('(hello)1.2 world');
});
it('should use additive step on second decrement', () => {
const result = adjNumeric('(hello)0.9 world', '(hello)0.9', 'decrement');
expect(result.prompt).toBe('(hello)0.8 world');
});
it('should return to neutral from 1.1 on decrement', () => {
const result = adjNumeric('(hello)1.1 world', '(hello)1.1', 'decrement');
expect(result.prompt).toBe('hello world');
});
});
describe('does not convert existing +/- attention on unselected terminals', () => {
it('should preserve +/- on unselected word when adjusting another', () => {
const result = adjNumeric('hello+ world', 'world', 'increment');
expect(result.prompt).toContain('hello+');
expect(result.prompt).toContain('(world)1.1');
});
it('should preserve - on unselected word', () => {
const result = adjNumeric('hello- world', 'world', 'decrement');
expect(result.prompt).toContain('hello-');
expect(result.prompt).toContain('(world)0.9');
});
});
describe('existing +/- attention on selected terminals', () => {
it('should increment existing + word with multiplicative step (respects existing style)', () => {
const result = adjNumeric('hello+ world', 'hello+', 'increment');
// The terminal already has explicit +/- attention, so it keeps that style
expect(result.prompt).toBe('hello++ world');
});
it('should decrement existing + word to neutral', () => {
const result = adjNumeric('hello+ world', 'hello+', 'decrement');
expect(result.prompt).toBe('hello world');
});
});
describe('existing numeric attention on selected terminals', () => {
it('should increment existing numeric weight additively', () => {
const result = adjNumeric('(detail)1.3 world', '(detail)1.3', 'increment');
expect(result.prompt).toBe('(detail)1.4 world');
});
it('should decrement existing numeric weight additively', () => {
const result = adjNumeric('(detail)1.3 world', '(detail)1.3', 'decrement');
expect(result.prompt).toBe('(detail)1.2 world');
});
});
describe('multiple words selected', () => {
it('should wrap multiple words in numeric group on increment', () => {
const result = adjNumeric('hello world', [0, 11], 'increment');
expect(result.prompt).toBe('(hello world)1.1');
});
it('should wrap multiple words in numeric group on decrement', () => {
const result = adjNumeric('hello world', [0, 11], 'decrement');
expect(result.prompt).toBe('(hello world)0.9');
});
});
describe('inside prompt functions', () => {
it('should use numeric format inside prompt function arg', () => {
const prompt = "('hello world', 'other').and()";
const result = adjNumeric(prompt, 'hello', 'increment');
expect(result.prompt).toBe("('(hello)1.1 world', 'other').and()");
});
it('should use numeric format across prompt function separator', () => {
const prompt = "('one two', 'three four').and()";
const start = prompt.indexOf('two');
const end = prompt.indexOf('three') + 'three'.length;
const result = adjustPromptAttention(prompt, start, end, 'increment', true);
expect(result.prompt).toBe("('one (two)1.1', '(three)1.1 four').and()");
});
});
describe('group splitting inside prompt function args', () => {
it('should correctly split weighted group when decrementing a single word inside it', () => {
const prompt =
'("high detail, (cinematic lighting)1.25, soft volumetric light, (sharp focus)+, professional photography", "a young woman with balanced natural proportions, medium length brown hair, neutral expression, casual modern clothing", "subtle rim light, shallow depth of field, natural skin texture, clean background").and()';
const result = adj(prompt, 'lighting', 'decrement');
// "lighting" gets decremented from 1.25 → 1.25/1.1 ≈ 1.1364
// "cinematic" stays at 1.25
// The key thing: no space should be lost/misplaced
expect(result.prompt).toContain('(cinematic)1.25');
expect(result.prompt).toContain('lighting)');
// Verify there's a space between the cinematic group and lighting group
const cinIdx = result.prompt.indexOf('(cinematic)1.25');
const afterCinematic = result.prompt.substring(
cinIdx + '(cinematic)1.25'.length,
cinIdx + '(cinematic)1.25'.length + 2
);
expect(afterCinematic).toMatch(/^ /); // Should start with a space
});
it('should rejoin groups when incrementing back to the same weight', () => {
const prompt =
'("high detail, (cinematic lighting)1.25, soft volumetric light, (sharp focus)+, professional photography", "a young woman with balanced natural proportions, medium length brown hair, neutral expression, casual modern clothing", "subtle rim light, shallow depth of field, natural skin texture, clean background").and()';
// Decrement "lighting" to split the group
const step1 = adj(prompt, 'lighting', 'decrement');
expect(step1.prompt).toContain('(cinematic)1.25');
// Now increment "lighting" back — should rejoin into (cinematic lighting)1.25
const step2 = adj(step1.prompt, 'lighting', 'increment');
expect(step2.prompt).toContain('(cinematic lighting)1.25');
});
});
describe('numeric group whitespace trimming', () => {
it('should not capture trailing whitespace inside numeric weighted groups', () => {
// (foo bar)1.3 → decrement "bar" → (foo)1.3 (bar)X, with space between
const result = adj('(foo bar)1.3', 'bar', 'decrement');
expect(result.prompt).toContain('(foo)1.3');
// Space should be outside the group, not inside
expect(result.prompt).not.toContain('(foo )');
expect(result.prompt).toMatch(/\(foo\)1\.3 /);
});
it('should not capture leading whitespace inside numeric weighted groups', () => {
// (foo bar)1.3 → decrement "foo" → (foo)X (bar)1.3, with space between
const result = adj('(foo bar)1.3', 'foo', 'decrement');
expect(result.prompt).toContain('(bar)1.3');
// Space should be outside the group, not inside
expect(result.prompt).not.toContain('( bar)');
expect(result.prompt).toMatch(/ \(bar\)1\.3/);
});
});
describe('numeric group conjoining', () => {
it('should merge adjacent same-weight numeric groups back together', () => {
// Two separate groups with same weight should conjoin into one
const result = adj('(foo)1.25 (bar)1.25', [0, 19], 'increment');
// Both words get the same increment, so they should stay in one group
expect(result.prompt).not.toContain(') (');
});
it('should merge adjacent same-weight groups when incrementing to match', () => {
// Start with (foo bar)1.3, decrement "bar", then increment it back
const step1 = adj('(foo bar)1.3', 'bar', 'decrement');
// Now increment "bar" back — it should rejoin into a single group
const step2 = adj(step1.prompt, 'bar', 'increment');
expect(step2.prompt).toBe('(foo bar)1.3');
});
it('should merge inside prompt function args', () => {
const prompt = '("(cinematic)1.25 (lighting)1.25", "other").and()';
const start = prompt.indexOf('cinematic');
const end = prompt.indexOf('lighting') + 'lighting'.length;
const result = adj(prompt, [start, end], 'increment');
// Both get incremented to same weight, should be one group
expect(result.prompt).not.toMatch(/\)\d[.\d]* \(/);
});
});
describe('without prefersNumericWeights (default behavior unchanged)', () => {
it('should still use +/- syntax by default', () => {
expect(adj('hello world', 'hello', 'increment').prompt).toBe('hello+ world');
expect(adj('hello world', 'hello', 'decrement').prompt).toBe('hello- world');
});
it('should still use +/- for multiple words by default', () => {
expect(adj('hello world', [0, 11], 'increment').prompt).toBe('(hello world)+');
});
});
});
});

View File

@@ -1,152 +1,56 @@
import { logger } from 'app/logging/logger';
import { serializeError } from 'serialize-error';
import { type ASTNode, type Attention, parseTokens, serialize, tokenize } from './promptAST';
import {
type ASTNode,
type Attention,
parseTokens,
type PromptFunctionArg,
serializeWithSelection,
tokenize,
} from './promptAST';
const log = logger('events');
const log = logger('generation');
type AttentionDirection = 'increment' | 'decrement';
type AdjustmentResult = { prompt: string; selectionStart: number; selectionEnd: number };
const ATTENTION_STEP = 1.1;
const NUMERIC_ATTENTION_STEP = 0.1;
/** Tolerance for floating-point weight comparisons. */
const WEIGHT_TOLERANCE = 0.001;
/** Tolerance for checking if a weight is a power of ATTENTION_STEP. */
const STEP_COUNT_TOLERANCE = 0.005;
// #region Weight Helpers
/**
* Adjusts the attention of the prompt at the current cursor/selection position.
* Check if a weight is approximately ATTENTION_STEP^n for some integer n.
* Returns n if so, or null if the weight is not a power of ATTENTION_STEP.
*/
export function adjustPromptAttention(
prompt: string,
selectionStart: number,
selectionEnd: number,
direction: AttentionDirection
): AdjustmentResult {
try {
const tokens = tokenize(prompt);
const ast = parseTokens(tokens);
const terminals = flattenAST(ast);
let selectedTerminals = terminals.filter((t) => {
const isSelected =
(t.range.start < selectionEnd && t.range.end > selectionStart) ||
(selectionStart === selectionEnd && t.range.start <= selectionStart && t.range.end >= selectionStart);
if (!isSelected) {
return false;
}
if (t.parentRange) {
const parentContainsSelection = t.parentRange.start <= selectionStart && t.parentRange.end >= selectionEnd;
const selectionCoversParent = selectionStart <= t.parentRange.start && selectionEnd >= t.parentRange.end;
if (!parentContainsSelection && !selectionCoversParent) {
// Partial overlap.
if (t.hasExplicitAttention) {
return false; // Don't modify explicit weight in partial group
}
}
}
return true;
});
for (const t of selectedTerminals) {
t.isSelected = true;
}
if (selectedTerminals.length === 0) {
const selectedGroup = findSelectedGroup(ast, selectionStart, selectionEnd);
if (selectedGroup) {
selectedTerminals = terminals.filter(
(t) => t.range.start >= selectedGroup.range.start && t.range.end <= selectedGroup.range.end
);
for (const t of selectedTerminals) {
t.isSelected = true;
}
}
}
if (selectedTerminals.length === 0) {
return { prompt, selectionStart, selectionEnd };
}
for (const terminal of selectedTerminals) {
if (direction === 'increment') {
terminal.weight *= ATTENTION_STEP;
} else {
terminal.weight /= ATTENTION_STEP;
}
}
const newAST = groupTerminals(terminals);
const newPrompt = serialize(newAST);
const newSelection = calculateSelectionRange(newAST);
return {
prompt: newPrompt,
selectionStart: newSelection.start,
selectionEnd: newSelection.end,
};
} catch (e) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
log.error({ error: serializeError(e) as any }, 'Failed to adjust prompt attention');
return { prompt, selectionStart, selectionEnd };
function getAttentionStepCount(weight: number): number | null {
if (weight <= 0) {
return null;
}
}
type Terminal = {
text: string;
type: ASTNode['type'];
weight: number;
range: { start: number; end: number };
hasExplicitAttention: boolean;
parentRange?: { start: number; end: number };
isSelected: boolean;
};
function flattenAST(ast: ASTNode[], currentWeight = 1.0, parentRange?: { start: number; end: number }): Terminal[] {
let terminals: Terminal[] = [];
for (const node of ast) {
let nodeWeight = currentWeight;
if ('attention' in node && node.attention) {
nodeWeight *= parseAttention(node.attention);
}
if (node.type === 'group') {
terminals.push(...flattenAST(node.children, nodeWeight, node.range));
} else {
terminals.push({
text: node.type === 'word' ? node.text : node.value,
type: node.type,
weight: nodeWeight,
range: node.range,
hasExplicitAttention: 'attention' in node && !!node.attention,
parentRange: parentRange,
isSelected: false,
});
}
if (Math.abs(weight - 1.0) < WEIGHT_TOLERANCE) {
return 0;
}
return terminals;
}
function findSelectedGroup(nodes: ASTNode[], start: number, end: number): ASTNode | null {
for (const node of nodes) {
if (node.type === 'group') {
const foundInChildren = findSelectedGroup(node.children, start, end);
if (foundInChildren) {
return foundInChildren;
}
if (rangesOverlap(node.range, { start, end })) {
return node;
}
}
const n = Math.round(Math.log(weight) / Math.log(ATTENTION_STEP));
if (n === 0) {
return null;
}
const expected = Math.pow(ATTENTION_STEP, n);
if (Math.abs(expected - weight) < STEP_COUNT_TOLERANCE) {
return n;
}
return null;
}
function rangesOverlap(a: { start: number; end: number }, b: { start: number; end: number }) {
return a.start < b.end && a.end > b.start;
}
/**
* Convert an Attention value ('+', '--', 1.2, etc.) into a numeric multiplier.
*/
function parseAttention(attention: Attention): number {
if (typeof attention === 'number') {
return attention;
@@ -161,83 +65,435 @@ function parseAttention(attention: Attention): number {
return isNaN(num) ? 1.0 : num;
}
function calculateSelectionRange(nodes: ASTNode[]): { start: number; end: number } {
let selectionStart = Infinity;
let selectionEnd = -1;
let currentPos = 0;
/**
* Combine an existing attention value with an additional '+' or '-' level.
* Handles cancellation: e.g. '++' + '-' → '+', '+' + '-' → undefined (neutral).
*/
function addAttention(current: Attention | undefined, added: '+' | '-'): Attention | undefined {
if (!current) {
return added;
}
if (typeof current === 'number') {
if (added === '+') {
return Number((current * ATTENTION_STEP).toFixed(4));
}
return Number((current / ATTENTION_STEP).toFixed(4));
}
// Check if the added direction cancels the current one
const isCancel = (current.startsWith('+') && added === '-') || (current.startsWith('-') && added === '+');
if (isCancel) {
const res = current.substring(1);
return res === '' ? undefined : res;
}
return `${current}${added}`;
}
function traverse(nodes: ASTNode[]) {
for (const node of nodes) {
if (node.isSelection) {
const len = serialize([node]).length;
selectionStart = Math.min(selectionStart, currentPos);
selectionEnd = Math.max(selectionEnd, currentPos + len);
currentPos += len;
} else {
if (node.type === 'group') {
// Group is not fully selected, but children might be.
// Group structure: "(" + children + ")" + attention
currentPos += 1; // '('
traverse(node.children);
currentPos += 1; // ')'
if (node.attention) {
currentPos += String(node.attention).length;
// #region Terminal Type
type Terminal = {
text: string;
type: ASTNode['type'];
weight: number;
range: { start: number; end: number };
hasExplicitAttention: boolean;
hasNumericAttention: boolean;
parentRange?: { start: number; end: number };
isSelected: boolean;
};
// #region Main Entry Point
/**
* Adjusts the attention of the prompt at the current cursor/selection position.
* Supports regular prompts and prompt functions (.and(), .or(), .blend()).
*
* When a selection spans across a prompt function's argument separator, each
* affected argument is adjusted independently and simultaneously.
*/
export function adjustPromptAttention(
prompt: string,
selectionStart: number,
selectionEnd: number,
direction: AttentionDirection,
prefersNumericWeights = false
): AdjustmentResult {
try {
const tokens = tokenize(prompt);
const ast = parseTokens(tokens);
const regions = extractRegions(ast);
const processedNodes: ASTNode[] = [];
let anyModified = false;
for (const region of regions) {
if (region.type === 'normal') {
const clipped = clipSelection(selectionStart, selectionEnd, region.range);
if (clipped) {
const result = adjustRegionNodes(region.nodes, clipped.start, clipped.end, direction, prefersNumericWeights);
if (result.modified) {
anyModified = true;
}
processedNodes.push(...result.nodes);
} else {
// Leaf node not selected.
const len = serialize([node]).length;
currentPos += len;
processedNodes.push(...region.nodes);
}
} else {
// prompt_function region
const pfNode = region.node;
const clipped = clipSelection(selectionStart, selectionEnd, pfNode.range);
if (clipped) {
const result = adjustPromptFunctionNode(pfNode, clipped.start, clipped.end, direction, prefersNumericWeights);
if (result.modified) {
anyModified = true;
}
processedNodes.push(result.node);
} else {
processedNodes.push(pfNode);
}
}
}
}
traverse(nodes);
if (!anyModified) {
return { prompt, selectionStart, selectionEnd };
}
if (selectionStart === Infinity) {
return { start: 0, end: serialize(nodes).length };
return serializeWithSelection(processedNodes);
} catch (e) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
log.error({ error: serializeError(e) as any }, 'Failed to adjust prompt attention');
return { prompt, selectionStart, selectionEnd };
}
return { start: selectionStart, end: selectionEnd };
}
// #region Region Extraction
type Region =
| { type: 'normal'; nodes: ASTNode[]; range: { start: number; end: number } }
| { type: 'prompt_function'; node: ASTNode & { type: 'prompt_function' } };
/**
* Split the top-level AST into contiguous "normal" regions and prompt function regions.
* This allows us to process prompt function arguments independently.
*/
function extractRegions(ast: ASTNode[]): Region[] {
const regions: Region[] = [];
let currentNormal: ASTNode[] = [];
const flushNormal = () => {
if (currentNormal.length > 0) {
const first = currentNormal[0]!;
const last = currentNormal[currentNormal.length - 1]!;
regions.push({
type: 'normal',
nodes: currentNormal,
range: { start: first.range.start, end: last.range.end },
});
currentNormal = [];
}
};
for (const node of ast) {
if (node.type === 'prompt_function') {
flushNormal();
regions.push({ type: 'prompt_function', node });
} else {
currentNormal.push(node);
}
}
flushNormal();
return regions;
}
/**
* Clip a selection range to a target range. Returns null if there is no overlap.
* For cursor positions (start === end), checks containment including boundaries.
*/
function clipSelection(
selStart: number,
selEnd: number,
range: { start: number; end: number }
): { start: number; end: number } | null {
if (selStart === selEnd) {
// Cursor position: check if within range (inclusive of boundaries)
if (selStart >= range.start && selStart <= range.end) {
return { start: selStart, end: selEnd };
}
return null;
}
const clippedStart = Math.max(selStart, range.start);
const clippedEnd = Math.min(selEnd, range.end);
if (clippedStart >= clippedEnd) {
return null;
}
return { start: clippedStart, end: clippedEnd };
}
// #region Prompt Function Handling
/**
* Adjust attention within a prompt function node by processing each argument
* whose content range overlaps the selection independently.
* Returns the (possibly updated) node and whether any modification was made.
*/
function adjustPromptFunctionNode(
pf: ASTNode & { type: 'prompt_function' },
selStart: number,
selEnd: number,
direction: AttentionDirection,
prefersNumericWeights = false
): { node: ASTNode & { type: 'prompt_function' }; modified: boolean } {
let modified = false;
const newArgs: PromptFunctionArg[] = pf.promptArgs.map((arg) => {
const clipped = clipSelection(selStart, selEnd, arg.contentRange);
if (clipped) {
const result = adjustRegionNodes(arg.nodes, clipped.start, clipped.end, direction, prefersNumericWeights);
if (result.modified) {
modified = true;
return { ...arg, nodes: result.nodes };
}
}
return arg;
});
if (!modified) {
return { node: pf, modified: false };
}
return { node: { ...pf, promptArgs: newArgs }, modified: true };
}
// #region Core Attention Adjustment
/**
* Adjust attention for a set of AST nodes (a "region") given a selection range.
* This is the core flatten → select → adjust → regroup pipeline.
* Returns the adjusted nodes and whether any modification was made.
*/
function adjustRegionNodes(
nodes: ASTNode[],
selStart: number,
selEnd: number,
direction: AttentionDirection,
prefersNumericWeights = false
): { nodes: ASTNode[]; modified: boolean } {
const terminals = flattenAST(nodes);
let selectedTerminals = selectTerminals(terminals, selStart, selEnd);
// Fallback: if no terminals were selected, try to find an overlapping group
if (selectedTerminals.length === 0) {
const group = findSelectedGroup(nodes, selStart, selEnd);
if (group) {
selectedTerminals = terminals.filter((t) => t.range.start >= group.range.start && t.range.end <= group.range.end);
}
}
if (selectedTerminals.length === 0) {
return { nodes, modified: false };
}
for (const t of selectedTerminals) {
t.isSelected = true;
// When the user prefers numeric weights and the terminal doesn't already
// have explicit attention, mark it as numeric so adjustWeights uses
// additive steps and groupTerminals emits numeric syntax.
if (prefersNumericWeights && !t.hasExplicitAttention) {
t.hasNumericAttention = true;
}
}
adjustWeights(selectedTerminals, direction);
return { nodes: groupTerminals(terminals), modified: true };
}
// #region Flatten AST to Terminals
/**
* Flatten an AST into a flat list of terminals, computing the effective weight
* of each terminal by accumulating attention from ancestor groups.
*/
function flattenAST(
ast: ASTNode[],
currentWeight = 1.0,
parentRange?: { start: number; end: number },
numericAttention = false
): Terminal[] {
const terminals: Terminal[] = [];
for (const node of ast) {
let nodeWeight = currentWeight;
let nodeNumericAttention = numericAttention;
if ((node.type === 'word' || node.type === 'group') && node.attention) {
nodeWeight *= parseAttention(node.attention);
nodeNumericAttention = typeof node.attention === 'number';
}
if (node.type === 'group') {
terminals.push(...flattenAST(node.children, nodeWeight, node.range, nodeNumericAttention));
} else if (node.type === 'prompt_function') {
// Prompt functions should not appear inside regions being flattened;
// they are handled at the region level. If one somehow appears, skip it.
continue;
} else {
terminals.push({
text: node.type === 'word' ? node.text : node.value,
type: node.type,
weight: nodeWeight,
range: node.range,
hasExplicitAttention: node.type === 'word' && !!node.attention,
hasNumericAttention: nodeNumericAttention,
parentRange,
isSelected: false,
});
}
}
return terminals;
}
// #region Terminal Selection
/**
* Find terminals that overlap the selection range and should be affected
* by the attention adjustment. Handles partial group overlap carefully:
* terminals with explicit attention inside partially-overlapping groups
* are excluded to avoid corrupting explicit weights.
*
* When the cursor is at a boundary between two tokens (e.g. "word|,"),
* both tokens technically overlap the cursor position. In this case we
* prefer word/embedding terminals over punctuation/whitespace so that
* adjusting attention at a word boundary doesn't accidentally include
* adjacent punctuation.
*/
function selectTerminals(terminals: Terminal[], selStart: number, selEnd: number): Terminal[] {
const result = terminals.filter((t) => {
const isOverlapping =
(t.range.start < selEnd && t.range.end > selStart) ||
(selStart === selEnd && t.range.start <= selStart && t.range.end >= selStart);
if (!isOverlapping) {
return false;
}
if (t.parentRange) {
const parentContainsSelection = t.parentRange.start <= selStart && t.parentRange.end >= selEnd;
const selectionCoversParent = selStart <= t.parentRange.start && selEnd >= t.parentRange.end;
if (!parentContainsSelection && !selectionCoversParent) {
// Partial overlap between selection and parent group
if (t.hasExplicitAttention) {
return false; // Don't modify explicit weight in partially-overlapping group
}
}
}
return true;
});
// When the cursor is at a token boundary (no selection range), multiple tokens
// can match. Prefer word/embedding terminals over punctuation/whitespace.
if (selStart === selEnd && result.length > 1) {
const contentTerminals = result.filter((t) => t.type === 'word' || t.type === 'embedding');
if (contentTerminals.length > 0) {
return contentTerminals;
}
}
return result;
}
// #region Weight Adjustment
/**
* Apply weight changes to the selected terminals based on direction.
* Numeric weights use additive steps; +/- syntax uses multiplicative steps.
* All results are rounded to 4 decimal places to prevent floating-point drift.
*/
function adjustWeights(terminals: Terminal[], direction: AttentionDirection): void {
for (const terminal of terminals) {
if (terminal.hasNumericAttention) {
// Additive step for explicit numeric weights (e.g. 1.1 → 1.2)
if (direction === 'increment') {
terminal.weight = Number((terminal.weight + NUMERIC_ATTENTION_STEP).toFixed(4));
} else {
terminal.weight = Number((terminal.weight - NUMERIC_ATTENTION_STEP).toFixed(4));
}
} else {
// Multiplicative step for +/- syntax weights, rounded to prevent drift
if (direction === 'increment') {
terminal.weight = Number((terminal.weight * ATTENTION_STEP).toFixed(4));
} else {
terminal.weight = Number((terminal.weight / ATTENTION_STEP).toFixed(4));
}
}
}
}
// #region Find Selected Group (fallback)
/**
* When no terminals directly overlap the selection (e.g. cursor is on a group
* boundary character), find the innermost group that overlaps the selection.
*/
function findSelectedGroup(nodes: ASTNode[], start: number, end: number): ASTNode | null {
for (const node of nodes) {
if (node.type === 'group') {
const foundInChildren = findSelectedGroup(node.children, start, end);
if (foundInChildren) {
return foundInChildren;
}
if (node.range.start < end && node.range.end > start) {
return node;
}
}
}
return null;
}
// #region Regroup Terminals into AST
/**
* Reconstruct an AST from a flat list of terminals with adjusted weights.
* Groups consecutive terminals with compatible weights using +/- or numeric syntax.
*
* Note: Reconstructed group nodes use `range: { start: 0, end: 0 }` as a sentinel
* value since the original source positions are no longer meaningful after regrouping.
* These nodes are only used for serialization output, never for source-position lookups.
*/
function groupTerminals(terminals: Terminal[]): ASTNode[] {
if (terminals.length === 0) {
return [];
}
/** Sentinel range for reconstructed nodes whose original positions are not applicable. */
const NO_RANGE = { start: 0, end: 0 };
const nodes: ASTNode[] = [];
let i = 0;
while (i < terminals.length) {
const t = terminals[i]!;
const weight = t.weight;
const stepCount = getAttentionStepCount(weight);
const findRunEnd = (predicate: (w: number) => boolean) => {
let j = i;
while (j < terminals.length) {
const next = terminals[j]!;
if (predicate(next.weight)) {
j++;
} else if (next.type === 'whitespace') {
let k = j + 1;
while (k < terminals.length && terminals[k]!.type === 'whitespace') {
k++;
}
if (k < terminals.length && predicate(terminals[k]!.weight)) {
j = k;
} else {
break;
}
} else {
break;
// ── +/- attention (weight is a non-zero power of ATTENTION_STEP) ──
// Skip this branch if the terminal prefers numeric format to avoid an
// infinite loop (predicate would reject it, findRunEnd returns i, i never advances).
if (stepCount !== null && stepCount !== 0 && !t.hasNumericAttention) {
const isPositive = stepCount > 0;
const sign: '+' | '-' = isPositive ? '+' : '-';
const predicate = (t: Terminal): boolean => {
if (t.hasNumericAttention) {
return false; // Numeric-preference terminals should not join +/- runs
}
}
return j;
};
const sc = getAttentionStepCount(t.weight);
return sc !== null && (isPositive ? sc > 0 : sc < 0);
};
const factor = isPositive ? ATTENTION_STEP : 1 / ATTENTION_STEP;
// Check for + (>= 1.1)
if (weight >= ATTENTION_STEP - 0.001) {
const j = findRunEnd((w) => w >= ATTENTION_STEP - 0.001);
const j = findRunEnd(terminals, i, predicate);
// Trim whitespace from the content run boundaries
let runStart = i;
let runEnd = j;
while (runStart < runEnd && terminals[runStart]!.type === 'whitespace') {
@@ -247,28 +503,31 @@ function groupTerminals(terminals: Terminal[]): ASTNode[] {
runEnd--;
}
// Emit leading whitespace as standalone nodes
for (let k = i; k < runStart; k++) {
nodes.push(createNodeFromTerminal(terminals[k]!));
}
if (runStart < runEnd) {
const slice = terminals.slice(runStart, runEnd).map((t) => ({ ...t, weight: t.weight / ATTENTION_STEP }));
// Factor out one level of attention and recurse
const slice = terminals.slice(runStart, runEnd).map((t) => ({ ...t, weight: t.weight / factor }));
const children = groupTerminals(slice);
const isSelection = slice.every((t) => t.isSelected);
if (children.length === 1) {
const child = children[0]!;
if (child.type === 'word' || child.type === 'group') {
const newAttention = addAttention(child.attention, '+');
nodes.push({ ...child, attention: newAttention });
const newAttention = addAttention(child.attention, sign);
nodes.push({ ...child, attention: newAttention, isSelection: isSelection || undefined });
} else {
nodes.push({ type: 'group', children, attention: '+', range: { start: 0, end: 0 }, isSelection });
nodes.push({ type: 'group', children, attention: sign, range: NO_RANGE, isSelection });
}
} else {
nodes.push({ type: 'group', children, attention: '+', range: { start: 0, end: 0 }, isSelection });
nodes.push({ type: 'group', children, attention: sign, range: NO_RANGE, isSelection });
}
}
// Emit trailing whitespace as standalone nodes
for (let k = runEnd; k < j; k++) {
nodes.push(createNodeFromTerminal(terminals[k]!));
}
@@ -277,126 +536,103 @@ function groupTerminals(terminals: Terminal[]): ASTNode[] {
continue;
}
// Check for - (<= 0.909)
if (weight <= 1 / ATTENTION_STEP + 0.001) {
const j = findRunEnd((w) => w <= 1 / ATTENTION_STEP + 0.001);
let runStart = i;
let runEnd = j;
while (runStart < runEnd && terminals[runStart]!.type === 'whitespace') {
runStart++;
}
while (runEnd > runStart && terminals[runEnd - 1]!.type === 'whitespace') {
runEnd--;
}
for (let k = i; k < runStart; k++) {
nodes.push(createNodeFromTerminal(terminals[k]!));
}
if (runStart < runEnd) {
const slice = terminals.slice(runStart, runEnd).map((t) => ({ ...t, weight: t.weight * ATTENTION_STEP }));
const children = groupTerminals(slice);
const isSelection = slice.every((t) => t.isSelected);
if (children.length === 1) {
const child = children[0]!;
if (child.type === 'word' || child.type === 'group') {
const newAttention = addAttention(child.attention, '-');
nodes.push({ ...child, attention: newAttention });
} else {
nodes.push({ type: 'group', children, attention: '-', range: { start: 0, end: 0 }, isSelection });
}
} else {
nodes.push({ type: 'group', children, attention: '-', range: { start: 0, end: 0 }, isSelection });
}
}
for (let k = runEnd; k < j; k++) {
nodes.push(createNodeFromTerminal(terminals[k]!));
}
i = j;
continue;
}
// Residual or 1.0
if (Math.abs(weight - 1.0) < 0.001) {
// ── Neutral weight (≈ 1.0) ──
if (Math.abs(weight - 1.0) < WEIGHT_TOLERANCE) {
nodes.push(createNodeFromTerminal(t));
i++;
} else {
let j = i;
while (j < terminals.length && Math.abs(terminals[j]!.weight - weight) < 0.001) {
j++;
continue;
}
// ── Numeric weight (not a power of ATTENTION_STEP) ──
{
const j = findRunEnd(terminals, i, (t) => Math.abs(t.weight - weight) < WEIGHT_TOLERANCE);
// Trim whitespace from the content run boundaries (same as +/- branch)
let runStart = i;
let runEnd = j;
while (runStart < runEnd && terminals[runStart]!.type === 'whitespace') {
runStart++;
}
while (runEnd > runStart && terminals[runEnd - 1]!.type === 'whitespace') {
runEnd--;
}
const groupTerminalsSlice = terminals.slice(i, j).map((t) => ({ ...t, weight: 1.0 }));
const children = groupTerminals(groupTerminalsSlice);
const isSelection = groupTerminalsSlice.every((t) => t.isSelected);
const weightStr = Number(weight.toFixed(4));
if (children.length === 1) {
const child = children[0]!;
if (child.type === 'word' || child.type === 'group') {
nodes.push({ ...child, attention: weightStr });
} else {
nodes.push({ type: 'group', children, attention: weightStr, range: { start: 0, end: 0 }, isSelection });
}
} else {
nodes.push({ type: 'group', children, attention: weightStr, range: { start: 0, end: 0 }, isSelection });
// Emit leading whitespace as standalone nodes
for (let k = i; k < runStart; k++) {
nodes.push(createNodeFromTerminal(terminals[k]!));
}
if (runStart < runEnd) {
const groupSlice = terminals.slice(runStart, runEnd).map((t) => ({ ...t, weight: 1.0 }));
const children = groupTerminals(groupSlice);
const isSelection = groupSlice.every((t) => t.isSelected);
const weightNum = Number(weight.toFixed(4));
nodes.push({ type: 'group', children, attention: weightNum, range: NO_RANGE, isSelection });
}
// Emit trailing whitespace as standalone nodes
for (let k = runEnd; k < j; k++) {
nodes.push(createNodeFromTerminal(terminals[k]!));
}
i = j;
}
}
return nodes;
}
function createNodeFromTerminal(t: Terminal): ASTNode {
if (t.type === 'word') {
return { type: 'word', text: t.text, range: t.range, isSelection: t.isSelected };
/**
* Find the end of a "run" of terminals whose weights satisfy a predicate.
* Whitespace terminals are included if the next non-whitespace terminal also satisfies the predicate.
* Note: The returned index may point to a whitespace token that is NOT included in the run;
* the caller is responsible for trimming trailing whitespace from the run boundaries.
*/
function findRunEnd(terminals: Terminal[], start: number, predicate: (t: Terminal) => boolean): number {
let j = start;
while (j < terminals.length) {
const next = terminals[j]!;
if (predicate(next)) {
j++;
} else if (next.type === 'whitespace') {
// Look ahead past consecutive whitespace
let k = j + 1;
while (k < terminals.length && terminals[k]!.type === 'whitespace') {
k++;
}
if (k < terminals.length && predicate(terminals[k]!)) {
j = k;
} else {
break;
}
} else {
break;
}
}
if (t.type === 'whitespace') {
return { type: 'whitespace', value: t.text, range: t.range, isSelection: t.isSelected };
}
if (t.type === 'punct') {
return { type: 'punct', value: t.text, range: t.range, isSelection: t.isSelected };
}
if (t.type === 'embedding') {
return { type: 'embedding', value: t.text, range: t.range, isSelection: t.isSelected };
}
if (t.type === 'escaped_paren') {
return { type: 'escaped_paren', value: t.text as '(' | ')', range: t.range, isSelection: t.isSelected };
}
return { type: 'word', text: t.text, range: t.range, isSelection: t.isSelected };
return j;
}
function addAttention(current: Attention | undefined, added: string): Attention | undefined {
if (!current) {
return added;
/**
* Convert a Terminal back into a leaf ASTNode.
*/
function createNodeFromTerminal(t: Terminal): ASTNode {
switch (t.type) {
case 'word':
return { type: 'word', text: t.text, range: t.range, isSelection: t.isSelected || undefined };
case 'whitespace':
return { type: 'whitespace', value: t.text, range: t.range, isSelection: t.isSelected || undefined };
case 'punct':
return { type: 'punct', value: t.text, range: t.range, isSelection: t.isSelected || undefined };
case 'embedding':
return { type: 'embedding', value: t.text, range: t.range, isSelection: t.isSelected || undefined };
case 'escaped_paren':
return {
type: 'escaped_paren',
value: t.text as '(' | ')',
range: t.range,
isSelection: t.isSelected || undefined,
};
default:
return { type: 'word', text: t.text, range: t.range, isSelection: t.isSelected || undefined };
}
if (typeof current === 'number') {
if (added === '+') {
return current * ATTENTION_STEP;
}
if (added === '-') {
return current / ATTENTION_STEP;
}
return current;
}
if (added === '+') {
if (current.startsWith('-')) {
const res = current.substring(1);
return res === '' ? undefined : res;
}
return `${current}+`;
}
if (added === '-') {
if (current.startsWith('+')) {
const res = current.substring(1);
return res === '' ? undefined : res;
}
return `${current}-`;
}
return current;
}

View File

@@ -0,0 +1,246 @@
import {
Box,
Button,
Center,
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
Grid,
GridItem,
Heading,
Input,
Spinner,
Text,
VStack,
} from '@invoke-ai/ui-library';
import type { ChangeEvent, FormEvent } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useNavigate } from 'react-router-dom';
import { useGetSetupStatusQuery, useSetupMutation } from 'services/api/endpoints/auth';
const validatePasswordStrength = (
password: string,
t: (key: string) => string
): { isValid: boolean; message: string } => {
if (password.length < 8) {
return { isValid: false, message: t('auth.setup.passwordTooShort') };
}
const hasUpper = /[A-Z]/.test(password);
const hasLower = /[a-z]/.test(password);
const hasDigit = /\d/.test(password);
if (!hasUpper || !hasLower || !hasDigit) {
return {
isValid: false,
message: t('auth.setup.passwordMissingRequirements'),
};
}
return { isValid: true, message: '' };
};
export const AdministratorSetup = memo(() => {
const { t } = useTranslation();
const navigate = useNavigate();
const [email, setEmail] = useState('');
const [displayName, setDisplayName] = useState('');
const [password, setPassword] = useState('');
const [confirmPassword, setConfirmPassword] = useState('');
const [setup, { isLoading, error }] = useSetupMutation();
const { data: setupStatus, isLoading: isLoadingSetup } = useGetSetupStatusQuery();
// Redirect to app if multiuser mode is disabled
useEffect(() => {
if (!isLoadingSetup && setupStatus && !setupStatus.multiuser_enabled) {
navigate('/app', { replace: true });
}
}, [setupStatus, isLoadingSetup, navigate]);
const passwordValidation = validatePasswordStrength(password, t);
const passwordsMatch = password === confirmPassword;
const handleSubmit = useCallback(
async (e: FormEvent) => {
e.preventDefault();
if (!passwordValidation.isValid) {
return;
}
if (!passwordsMatch) {
return;
}
try {
const result = await setup({ email, display_name: displayName, password }).unwrap();
if (result.success) {
// Auto-login after setup - need to call login API
// For now, just redirect to login page
window.location.href = '/login';
}
} catch {
// Error is handled by RTK Query and displayed via error state
}
},
[email, displayName, password, passwordValidation.isValid, passwordsMatch, setup]
);
const handleEmailChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setEmail(e.target.value);
}, []);
const handleDisplayNameChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setDisplayName(e.target.value);
}, []);
const handlePasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setPassword(e.target.value);
}, []);
const handleConfirmPasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setConfirmPassword(e.target.value);
}, []);
const errorMessage = error
? 'data' in error && typeof error.data === 'object' && error.data && 'detail' in error.data
? String(error.data.detail)
: t('auth.setup.setupFailed')
: null;
// Show loading spinner while checking setup status or redirecting
if (isLoadingSetup || (setupStatus && !setupStatus.multiuser_enabled)) {
return (
<Center w="100dvw" h="100dvh">
<Spinner size="xl" />
</Center>
);
}
return (
<Center w="100dvw" h="100dvh" bg="base.900">
<Box w="full" maxW="600px" p={8} borderRadius="lg" bg="base.800" boxShadow="dark-lg">
<form onSubmit={handleSubmit}>
<VStack spacing={6} align="stretch">
<VStack spacing={2}>
<Heading size="lg" textAlign="center">
{t('auth.setup.title')}
</Heading>
<Text fontSize="sm" color="base.400" textAlign="center">
{t('auth.setup.subtitle')}
</Text>
</VStack>
<FormControl isRequired>
<Grid templateColumns="140px 1fr" gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" pt={2} mb={0}>
{t('auth.setup.email')}
</FormLabel>
</GridItem>
<GridItem>
<Input
type="email"
value={email}
onChange={handleEmailChange}
placeholder={t('auth.setup.emailPlaceholder')}
autoComplete="email"
autoFocus
/>
<FormHelperText mt={1}>{t('auth.setup.emailHelper')}</FormHelperText>
</GridItem>
</Grid>
</FormControl>
<FormControl isRequired>
<Grid templateColumns="140px 1fr" gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" pt={2} mb={0}>
{t('auth.setup.displayName')}
</FormLabel>
</GridItem>
<GridItem>
<Input
type="text"
value={displayName}
onChange={handleDisplayNameChange}
placeholder={t('auth.setup.displayNamePlaceholder')}
/>
<FormHelperText mt={1}>{t('auth.setup.displayNameHelper')}</FormHelperText>
</GridItem>
</Grid>
</FormControl>
<FormControl isRequired isInvalid={password.length > 0 && !passwordValidation.isValid}>
<Grid templateColumns="140px 1fr" gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" pt={2} mb={0}>
{t('auth.setup.password')}
</FormLabel>
</GridItem>
<GridItem>
<Input
type="password"
value={password}
onChange={handlePasswordChange}
placeholder={t('auth.setup.passwordPlaceholder')}
autoComplete="new-password"
/>
{password.length > 0 && !passwordValidation.isValid && (
<FormErrorMessage>{passwordValidation.message}</FormErrorMessage>
)}
{password.length === 0 && <FormHelperText mt={1}>{t('auth.setup.passwordHelper')}</FormHelperText>}
</GridItem>
</Grid>
</FormControl>
<FormControl isRequired isInvalid={confirmPassword.length > 0 && !passwordsMatch}>
<Grid templateColumns="140px 1fr" gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" pt={2} mb={0}>
{t('auth.setup.confirmPassword')}
</FormLabel>
</GridItem>
<GridItem>
<Input
type="password"
value={confirmPassword}
onChange={handleConfirmPasswordChange}
placeholder={t('auth.setup.confirmPasswordPlaceholder')}
autoComplete="new-password"
/>
{confirmPassword.length > 0 && !passwordsMatch && (
<FormErrorMessage>{t('auth.setup.passwordsDoNotMatch')}</FormErrorMessage>
)}
</GridItem>
</Grid>
</FormControl>
<Button
type="submit"
isLoading={isLoading}
loadingText={t('auth.setup.creatingAccount')}
colorScheme="invokeBlue"
size="lg"
w="full"
isDisabled={!passwordValidation.isValid || !passwordsMatch}
>
{t('auth.setup.createAccount')}
</Button>
{errorMessage && (
<Flex p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm" justifyContent="center">
<Text fontWeight="semibold">{errorMessage}</Text>
</Flex>
)}
</VStack>
</form>
</Box>
</Center>
);
});
AdministratorSetup.displayName = 'AdministratorSetup';

View File

@@ -0,0 +1,168 @@
import {
Box,
Button,
Center,
Checkbox,
Flex,
FormControl,
FormErrorMessage,
FormLabel,
Heading,
Input,
Spinner,
Text,
VStack,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { setCredentials } from 'features/auth/store/authSlice';
import type { ChangeEvent, FormEvent } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useNavigate } from 'react-router-dom';
import { useGetSetupStatusQuery, useLoginMutation } from 'services/api/endpoints/auth';
export const LoginPage = memo(() => {
const { t } = useTranslation();
const navigate = useNavigate();
const [email, setEmail] = useState('');
const [password, setPassword] = useState('');
const [rememberMe, setRememberMe] = useState(true);
const [login, { isLoading, error }] = useLoginMutation();
const dispatch = useAppDispatch();
const { data: setupStatus, isLoading: isLoadingSetup } = useGetSetupStatusQuery();
// Redirect to app if multiuser mode is disabled
useEffect(() => {
if (!isLoadingSetup && setupStatus && !setupStatus.multiuser_enabled) {
navigate('/app', { replace: true });
}
}, [setupStatus, isLoadingSetup, navigate]);
// Redirect to setup page if setup is required
useEffect(() => {
if (!isLoadingSetup && setupStatus?.setup_required) {
navigate('/setup', { replace: true });
}
}, [setupStatus, isLoadingSetup, navigate]);
const handleSubmit = useCallback(
async (e: FormEvent) => {
e.preventDefault();
try {
const result = await login({ email, password, remember_me: rememberMe }).unwrap();
// Map the UserDTO from API to our User type
const user = {
user_id: result.user.user_id,
email: result.user.email,
display_name: result.user.display_name || null,
is_admin: result.user.is_admin || false,
is_active: result.user.is_active || true,
};
dispatch(setCredentials({ token: result.token, user }));
// Force a page reload to ensure all user-specific state is loaded from server
// This is important for multiuser isolation to prevent state leakage
window.location.href = '/app';
} catch {
// Error is handled by RTK Query and displayed via error state
}
},
[email, password, rememberMe, login, dispatch]
);
const handleEmailChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setEmail(e.target.value);
}, []);
const handlePasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setPassword(e.target.value);
}, []);
const handleRememberMeChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setRememberMe(e.target.checked);
}, []);
const errorMessage = error
? 'data' in error && typeof error.data === 'object' && error.data && 'detail' in error.data
? String(error.data.detail)
: t('auth.login.loginFailed')
: null;
// Show loading spinner while checking setup status or redirecting
if (isLoadingSetup || (setupStatus && !setupStatus.multiuser_enabled)) {
return (
<Center w="100dvw" h="100dvh">
<Spinner size="xl" />
</Center>
);
}
// Show loading spinner if setup is required (redirecting to setup)
if (setupStatus?.setup_required) {
return (
<Center w="100dvw" h="100dvh">
<Spinner size="xl" />
</Center>
);
}
return (
<Center w="100dvw" h="100dvh" bg="base.900">
<Box w="full" maxW="400px" p={8} borderRadius="lg" bg="base.800" boxShadow="dark-lg">
<form onSubmit={handleSubmit}>
<VStack spacing={6} align="stretch">
<Heading size="lg" textAlign="center">
{t('auth.login.title')}
</Heading>
<FormControl isRequired isInvalid={!!errorMessage}>
<FormLabel>{t('auth.login.email')}</FormLabel>
<Input
type="email"
value={email}
onChange={handleEmailChange}
placeholder={t('auth.login.emailPlaceholder')}
autoComplete="email"
autoFocus
/>
</FormControl>
<FormControl isRequired isInvalid={!!errorMessage}>
<FormLabel>{t('auth.login.password')}</FormLabel>
<Input
type="password"
value={password}
onChange={handlePasswordChange}
placeholder={t('auth.login.passwordPlaceholder')}
autoComplete="current-password"
/>
{errorMessage && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</FormControl>
<Checkbox isChecked={rememberMe} onChange={handleRememberMeChange}>
{t('auth.login.rememberMe')}
</Checkbox>
<Button
type="submit"
isLoading={isLoading}
loadingText={t('auth.login.signingIn')}
colorScheme="invokeBlue"
size="lg"
w="full"
>
{t('auth.login.signIn')}
</Button>
{errorMessage && (
<Flex p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm" justifyContent="center">
<Text fontWeight="semibold">{errorMessage}</Text>
</Flex>
)}
</VStack>
</form>
</Box>
</Center>
);
});
LoginPage.displayName = 'LoginPage';

View File

@@ -0,0 +1,100 @@
import { Center, Spinner } from '@invoke-ai/ui-library';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { logout, setCredentials } from 'features/auth/store/authSlice';
import type { PropsWithChildren } from 'react';
import { memo, useEffect } from 'react';
import { useNavigate } from 'react-router-dom';
import { useGetCurrentUserQuery, useGetSetupStatusQuery } from 'services/api/endpoints/auth';
interface ProtectedRouteProps {
requireAdmin?: boolean;
}
export const ProtectedRoute = memo(({ children, requireAdmin = false }: PropsWithChildren<ProtectedRouteProps>) => {
const isAuthenticated = useAppSelector((state: RootState) => state.auth?.isAuthenticated || false);
const token = useAppSelector((state: RootState) => state.auth?.token);
const user = useAppSelector((state: RootState) => state.auth?.user);
const navigate = useNavigate();
const dispatch = useAppDispatch();
// Check if multiuser mode is enabled
const { data: setupStatus } = useGetSetupStatusQuery();
const multiuserEnabled = setupStatus?.multiuser_enabled ?? true; // Default to true for safety
// Only fetch user if we have a token but no user data, and multiuser mode is enabled
const shouldFetchUser = multiuserEnabled && isAuthenticated && token && !user;
const {
data: currentUser,
isLoading: isLoadingUser,
error: userError,
} = useGetCurrentUserQuery(undefined, {
skip: !shouldFetchUser,
});
useEffect(() => {
// If we have a token but fetching user failed, token is invalid - logout
if (userError && isAuthenticated) {
dispatch(logout());
navigate('/login', { replace: true });
}
}, [userError, isAuthenticated, dispatch, navigate]);
useEffect(() => {
// If we successfully fetched user data, update auth state
if (currentUser && token && !user) {
const userObj = {
user_id: currentUser.user_id,
email: currentUser.email,
display_name: currentUser.display_name || null,
is_admin: currentUser.is_admin || false,
is_active: currentUser.is_active || true,
};
dispatch(setCredentials({ token, user: userObj }));
}
}, [currentUser, token, user, dispatch]);
useEffect(() => {
// If multiuser is disabled, allow access without authentication
if (!multiuserEnabled) {
// Clear any persisted auth state when switching to single-user mode
if (isAuthenticated) {
dispatch(logout());
}
return;
}
// In multiuser mode, check authentication
if (!isLoadingUser && !isAuthenticated) {
navigate('/login', { replace: true });
} else if (!isLoadingUser && isAuthenticated && user && requireAdmin && !user.is_admin) {
navigate('/', { replace: true });
}
}, [isAuthenticated, isLoadingUser, requireAdmin, user, navigate, multiuserEnabled, dispatch]);
// In single-user mode, always allow access
if (!multiuserEnabled) {
return <>{children}</>;
}
// Show loading while fetching user data
if (isLoadingUser || (isAuthenticated && !user)) {
return (
<Center w="100dvw" h="100dvh">
<Spinner size="xl" />
</Center>
);
}
if (!isAuthenticated) {
return null;
}
if (requireAdmin && !user?.is_admin) {
return null;
}
return <>{children}</>;
});
ProtectedRoute.displayName = 'ProtectedRoute';

View File

@@ -0,0 +1,640 @@
import {
Badge,
Box,
Button,
Center,
Checkbox,
Flex,
FormControl,
FormErrorMessage,
FormLabel,
Grid,
GridItem,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Spinner,
Switch,
Table,
Tbody,
Td,
Text,
Th,
Thead,
Tooltip,
Tr,
useDisclosure,
VStack,
} from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import type { ChangeEvent, FormEvent } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
PiArrowLeftBold,
PiEyeBold,
PiEyeSlashBold,
PiLightningFill,
PiPencilBold,
PiPlusBold,
PiTrashBold,
} from 'react-icons/pi';
import { useNavigate } from 'react-router-dom';
import type { UserDTO } from 'services/api/endpoints/auth';
import {
useCreateUserMutation,
useDeleteUserMutation,
useLazyGeneratePasswordQuery,
useListUsersQuery,
useUpdateUserMutation,
} from 'services/api/endpoints/auth';
const validatePasswordStrength = (
password: string,
t: (key: string) => string
): { isValid: boolean; message: string } => {
if (password.length === 0) {
return { isValid: true, message: '' };
}
if (password.length < 8) {
return { isValid: false, message: t('auth.setup.passwordTooShort') };
}
const hasUpper = /[A-Z]/.test(password);
const hasLower = /[a-z]/.test(password);
const hasDigit = /\d/.test(password);
if (!hasUpper || !hasLower || !hasDigit) {
return { isValid: false, message: t('auth.setup.passwordMissingRequirements') };
}
return { isValid: true, message: '' };
};
const FORM_GRID_COLUMNS = '120px 1fr';
// ---------------------------------------------------------------------------
// Create / Edit user modal
// ---------------------------------------------------------------------------
type UserFormModalProps = {
isOpen: boolean;
onClose: () => void;
/** When provided, the modal operates in "edit" mode for the given user */
editUser?: UserDTO | null;
};
const UserFormModal = memo(({ isOpen, onClose, editUser }: UserFormModalProps) => {
const { t } = useTranslation();
const isEdit = !!editUser;
const [email, setEmail] = useState(editUser?.email ?? '');
const [displayName, setDisplayName] = useState(editUser?.display_name ?? '');
const [password, setPassword] = useState('');
const [isAdmin, setIsAdmin] = useState(editUser?.is_admin ?? false);
const [showPassword, setShowPassword] = useState(false);
const [error, setError] = useState<string | null>(null);
const [createUser, { isLoading: isCreating }] = useCreateUserMutation();
const [updateUser, { isLoading: isUpdating }] = useUpdateUserMutation();
const [triggerGeneratePassword] = useLazyGeneratePasswordQuery();
const isLoading = isCreating || isUpdating;
const passwordValidation = validatePasswordStrength(password, t);
const handleGeneratePassword = useCallback(async () => {
try {
const result = await triggerGeneratePassword().unwrap();
setPassword(result.password);
setShowPassword(true);
} catch {
// ignore
}
}, [triggerGeneratePassword]);
const toggleShowPassword = useCallback(() => {
setShowPassword((v) => !v);
}, []);
const handleEmailChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setEmail(e.target.value);
}, []);
const handleDisplayNameChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setDisplayName(e.target.value);
}, []);
const handlePasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setPassword(e.target.value);
}, []);
const handleIsAdminChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setIsAdmin(e.target.checked);
}, []);
const handleSubmit = useCallback(
async (e: FormEvent) => {
e.preventDefault();
setError(null);
if (!isEdit && (!password || !passwordValidation.isValid)) {
return;
}
if (isEdit && password && !passwordValidation.isValid) {
return;
}
try {
if (isEdit && editUser) {
const updateData: Parameters<typeof updateUser>[0]['data'] = {
display_name: displayName || null,
is_admin: isAdmin,
};
if (password) {
updateData.password = password;
}
await updateUser({
userId: editUser.user_id,
data: updateData,
}).unwrap();
} else {
await createUser({
email,
display_name: displayName || null,
password,
is_admin: isAdmin,
}).unwrap();
}
onClose();
} catch (err) {
const detail =
err && typeof err === 'object' && 'data' in err && typeof (err as { data: unknown }).data === 'object'
? ((err as { data: { detail?: string } }).data?.detail ?? t('auth.userManagement.saveFailed'))
: t('auth.userManagement.saveFailed');
setError(detail);
}
},
[
isEdit,
editUser,
email,
displayName,
password,
isAdmin,
passwordValidation.isValid,
createUser,
updateUser,
onClose,
t,
]
);
// Reset local state when modal closes
const handleClose = useCallback(() => {
setEmail(editUser?.email ?? '');
setDisplayName(editUser?.display_name ?? '');
setPassword('');
setIsAdmin(editUser?.is_admin ?? false);
setShowPassword(false);
setError(null);
onClose();
}, [editUser, onClose]);
return (
<Modal isOpen={isOpen} onClose={handleClose} isCentered size="md">
<ModalOverlay />
<ModalContent bg="base.800">
<form onSubmit={handleSubmit}>
<ModalHeader>{isEdit ? t('auth.userManagement.editUser') : t('auth.userManagement.createUser')}</ModalHeader>
<ModalCloseButton />
<ModalBody>
<VStack spacing={4}>
{!isEdit && (
<FormControl isRequired>
<Grid templateColumns={FORM_GRID_COLUMNS} gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" mb={0} pt={2}>
{t('auth.userManagement.email')}
</FormLabel>
</GridItem>
<GridItem>
<Input
type="email"
value={email}
onChange={handleEmailChange}
placeholder={t('auth.userManagement.emailPlaceholder')}
autoComplete="off"
/>
</GridItem>
</Grid>
</FormControl>
)}
<FormControl>
<Grid templateColumns={FORM_GRID_COLUMNS} gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" mb={0} pt={2}>
{t('auth.userManagement.displayName')}
</FormLabel>
</GridItem>
<GridItem>
<Input
type="text"
value={displayName}
onChange={handleDisplayNameChange}
placeholder={t('auth.userManagement.displayNamePlaceholder')}
/>
</GridItem>
</Grid>
</FormControl>
<FormControl isInvalid={password.length > 0 && !passwordValidation.isValid} isRequired={!isEdit}>
<Grid templateColumns={FORM_GRID_COLUMNS} gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" mb={0} pt={2}>
{isEdit ? t('auth.userManagement.newPassword') : t('auth.userManagement.password')}
</FormLabel>
</GridItem>
<GridItem>
<InputGroup>
<Input
type={showPassword ? 'text' : 'password'}
value={password}
onChange={handlePasswordChange}
placeholder={
isEdit
? t('auth.userManagement.newPasswordPlaceholder')
: t('auth.userManagement.passwordPlaceholder')
}
autoComplete="new-password"
pr="4.5rem"
/>
<InputRightElement w="4.5rem">
<Tooltip
label={
showPassword ? t('auth.userManagement.hidePassword') : t('auth.userManagement.showPassword')
}
>
<IconButton
aria-label={
showPassword
? t('auth.userManagement.hidePassword')
: t('auth.userManagement.showPassword')
}
icon={showPassword ? <PiEyeSlashBold /> : <PiEyeBold />}
variant="ghost"
size="sm"
onClick={toggleShowPassword}
tabIndex={-1}
/>
</Tooltip>
</InputRightElement>
</InputGroup>
{password.length > 0 && !passwordValidation.isValid && (
<FormErrorMessage>{passwordValidation.message}</FormErrorMessage>
)}
</GridItem>
</Grid>
</FormControl>
<Grid templateColumns={FORM_GRID_COLUMNS} gap={4} w="full">
<GridItem />
<GridItem>
<Button size="sm" variant="ghost" onClick={handleGeneratePassword} leftIcon={<PiLightningFill />}>
{t('auth.userManagement.generatePassword')}
</Button>
</GridItem>
</Grid>
<FormControl display="flex" alignItems="center">
<FormLabel mb={0}>{t('auth.userManagement.isAdmin')}</FormLabel>
<Checkbox isChecked={isAdmin} onChange={handleIsAdminChange} />
</FormControl>
{error && (
<Flex p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm" w="full">
<Text fontWeight="semibold">{error}</Text>
</Flex>
)}
</VStack>
</ModalBody>
<ModalFooter gap={2}>
<Button variant="ghost" onClick={handleClose}>
{t('common.cancel')}
</Button>
<Button
type="submit"
colorScheme="invokeBlue"
isLoading={isLoading}
isDisabled={!isEdit && (!password || !passwordValidation.isValid)}
>
{isEdit ? t('common.save') : t('auth.userManagement.createUser')}
</Button>
</ModalFooter>
</form>
</ModalContent>
</Modal>
);
});
UserFormModal.displayName = 'UserFormModal';
// ---------------------------------------------------------------------------
// Delete confirmation modal
// ---------------------------------------------------------------------------
type DeleteUserModalProps = {
isOpen: boolean;
onClose: () => void;
user: UserDTO | null;
};
const DeleteUserModal = memo(({ isOpen, onClose, user }: DeleteUserModalProps) => {
const { t } = useTranslation();
const [deleteUser, { isLoading }] = useDeleteUserMutation();
const [error, setError] = useState<string | null>(null);
const handleDelete = useCallback(async () => {
if (!user) {
return;
}
setError(null);
try {
await deleteUser(user.user_id).unwrap();
onClose();
} catch (err) {
const detail =
err && typeof err === 'object' && 'data' in err && typeof (err as { data: unknown }).data === 'object'
? ((err as { data: { detail?: string } }).data?.detail ?? t('auth.userManagement.deleteFailed'))
: t('auth.userManagement.deleteFailed');
setError(detail);
}
}, [user, deleteUser, onClose, t]);
const handleClose = useCallback(() => {
setError(null);
onClose();
}, [onClose]);
return (
<Modal isOpen={isOpen} onClose={handleClose} isCentered size="sm">
<ModalOverlay />
<ModalContent bg="base.800">
<ModalHeader>{t('auth.userManagement.deleteUser')}</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Text>
{t('auth.userManagement.deleteConfirm', {
name: user?.display_name ?? user?.email ?? '',
})}
</Text>
{error && (
<Flex mt={3} p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm">
<Text fontWeight="semibold">{error}</Text>
</Flex>
)}
</ModalBody>
<ModalFooter gap={2}>
<Button variant="ghost" onClick={handleClose}>
{t('common.cancel')}
</Button>
<Button colorScheme="error" isLoading={isLoading} onClick={handleDelete}>
{t('common.delete')}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
});
DeleteUserModal.displayName = 'DeleteUserModal';
// ---------------------------------------------------------------------------
// Inline active/inactive toggle
// Wrapping the Switch in a Box lets the Tooltip track mouse-enter/leave
// correctly; without it the tooltip may not dismiss on mouse-out.
// ---------------------------------------------------------------------------
const UserStatusToggle = memo(({ user, isCurrentUser }: { user: UserDTO; isCurrentUser: boolean }) => {
const { t } = useTranslation();
const [updateUser, { isLoading }] = useUpdateUserMutation();
const handleChange = useCallback(
async (e: ChangeEvent<HTMLInputElement>) => {
await updateUser({ userId: user.user_id, data: { is_active: e.target.checked } })
.unwrap()
.catch(() => null);
},
[user.user_id, updateUser]
);
const tooltipLabel = isCurrentUser
? t('auth.userManagement.cannotDeactivateSelf')
: user.is_active
? t('auth.userManagement.deactivate')
: t('auth.userManagement.activate');
return (
<Tooltip label={tooltipLabel}>
<Box as="span" display="inline-flex">
<Switch isChecked={user.is_active} onChange={handleChange} isDisabled={isLoading || isCurrentUser} size="sm" />
</Box>
</Tooltip>
);
});
UserStatusToggle.displayName = 'UserStatusToggle';
// ---------------------------------------------------------------------------
// Main component
// ---------------------------------------------------------------------------
export const UserManagement = memo(() => {
const { t } = useTranslation();
const currentUser = useAppSelector(selectCurrentUser);
const navigate = useNavigate();
const { data: users, isLoading, error } = useListUsersQuery();
const createModal = useDisclosure();
const editModal = useDisclosure();
const deleteModal = useDisclosure();
const [selectedUser, setSelectedUser] = useState<UserDTO | null>(null);
const handleBack = useCallback(() => {
navigate(-1);
}, [navigate]);
const handleEdit = useCallback(
(user: UserDTO) => {
setSelectedUser(user);
editModal.onOpen();
},
[editModal]
);
const handleDelete = useCallback(
(user: UserDTO) => {
setSelectedUser(user);
deleteModal.onOpen();
},
[deleteModal]
);
const handleEditClose = useCallback(() => {
editModal.onClose();
setSelectedUser(null);
}, [editModal]);
const handleDeleteClose = useCallback(() => {
deleteModal.onClose();
setSelectedUser(null);
}, [deleteModal]);
if (isLoading) {
return (
<Center py={12}>
<Spinner size="xl" />
</Center>
);
}
if (error) {
return (
<Center py={12}>
<Text color="error.400">{t('auth.userManagement.loadFailed')}</Text>
</Center>
);
}
return (
<Box p={6}>
<Flex justify="space-between" align="center" mb={6}>
<Flex align="center" gap={4}>
<Button leftIcon={<PiArrowLeftBold />} variant="outline" size="sm" onClick={handleBack}>
{t('auth.userManagement.back')}
</Button>
<Heading size="md">{t('auth.userManagement.title')}</Heading>
</Flex>
<Button leftIcon={<PiPlusBold />} colorScheme="invokeBlue" size="sm" onClick={createModal.onOpen}>
{t('auth.userManagement.createUser')}
</Button>
</Flex>
<Box overflowX="auto">
<Table variant="simple" size="sm">
<Thead>
<Tr>
<Th>{t('auth.userManagement.email')}</Th>
<Th>{t('auth.userManagement.displayName')}</Th>
<Th>{t('auth.userManagement.role')}</Th>
<Th>{t('auth.userManagement.status')}</Th>
<Th>{t('auth.userManagement.actions')}</Th>
</Tr>
</Thead>
<Tbody>
{(users ?? []).map((user) => (
<UserRow
key={user.user_id}
user={user}
isCurrentUser={user.user_id === currentUser?.user_id}
onEdit={handleEdit}
onDelete={handleDelete}
/>
))}
</Tbody>
</Table>
</Box>
{/* Create user modal */}
<UserFormModal isOpen={createModal.isOpen} onClose={createModal.onClose} />
{/* Edit user modal */}
<UserFormModal isOpen={editModal.isOpen} onClose={handleEditClose} editUser={selectedUser} />
{/* Delete confirmation modal */}
<DeleteUserModal isOpen={deleteModal.isOpen} onClose={handleDeleteClose} user={selectedUser} />
</Box>
);
});
UserManagement.displayName = 'UserManagement';
// ---------------------------------------------------------------------------
// User table row
// ---------------------------------------------------------------------------
type UserRowProps = {
user: UserDTO;
isCurrentUser: boolean;
onEdit: (user: UserDTO) => void;
onDelete: (user: UserDTO) => void;
};
const UserRow = memo(({ user, isCurrentUser, onEdit, onDelete }: UserRowProps) => {
const { t } = useTranslation();
const handleEdit = useCallback(() => {
onEdit(user);
}, [user, onEdit]);
const handleDelete = useCallback(() => {
onDelete(user);
}, [user, onDelete]);
return (
<Tr>
<Td>
<Text fontSize="sm">{user.email}</Text>
{isCurrentUser && (
<Badge colorScheme="invokeBlue" size="xs" ml={1}>
{t('auth.userManagement.you')}
</Badge>
)}
</Td>
<Td>
<Text fontSize="sm">{user.display_name ?? '—'}</Text>
</Td>
<Td>
{user.is_admin ? (
<Badge colorScheme="invokeYellow">{t('auth.admin')}</Badge>
) : (
<Badge colorScheme="base">{t('auth.userManagement.user')}</Badge>
)}
</Td>
<Td>
<UserStatusToggle user={user} isCurrentUser={isCurrentUser} />
</Td>
<Td>
<Flex gap={1}>
<Tooltip label={t('auth.userManagement.editUser')}>
<IconButton
aria-label={t('auth.userManagement.editUser')}
icon={<PiPencilBold />}
variant="ghost"
size="sm"
onClick={handleEdit}
/>
</Tooltip>
<Tooltip
label={isCurrentUser ? t('auth.userManagement.cannotDeleteSelf') : t('auth.userManagement.deleteUser')}
>
<IconButton
aria-label={t('auth.userManagement.deleteUser')}
icon={<PiTrashBold />}
variant="ghost"
size="sm"
colorScheme="error"
isDisabled={isCurrentUser}
onClick={handleDelete}
/>
</Tooltip>
</Flex>
</Td>
</Tr>
);
});
UserRow.displayName = 'UserRow';

View File

@@ -0,0 +1,87 @@
import { Badge, Flex, IconButton, Menu, MenuButton, MenuItem, MenuList, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { logout, selectCurrentUser } from 'features/auth/store/authSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearBold, PiSignOutBold, PiUserBold, PiUsersBold } from 'react-icons/pi';
import { useNavigate } from 'react-router-dom';
import { useLogoutMutation } from 'services/api/endpoints/auth';
export const UserMenu = memo(() => {
const { t } = useTranslation();
const user = useAppSelector(selectCurrentUser);
const dispatch = useAppDispatch();
const navigate = useNavigate();
const [logoutMutation] = useLogoutMutation();
const handleLogout = useCallback(() => {
// Call backend logout endpoint
logoutMutation()
.unwrap()
.catch(() => {
// Ignore errors - we'll log out locally anyway
})
.finally(() => {
// Clear local state regardless of backend response
dispatch(logout());
navigate('/login');
});
}, [dispatch, navigate, logoutMutation]);
const handleProfile = useCallback(() => {
navigate('/profile');
}, [navigate]);
const handleUserManagement = useCallback(() => {
navigate('/admin/users');
}, [navigate]);
if (!user) {
return null;
}
return (
<Menu>
<Tooltip label={t('auth.userMenu')}>
<MenuButton
as={IconButton}
aria-label={t('auth.userMenu')}
icon={<PiUserBold />}
variant="link"
minW={8}
w={8}
h={8}
borderRadius="base"
/>
</Tooltip>
<MenuList>
<Flex px={3} py={2} flexDir="column" gap={1}>
<Text fontSize="sm" fontWeight="semibold" noOfLines={1}>
{user.display_name || user.email}
</Text>
<Text fontSize="xs" color="base.500" noOfLines={1}>
{user.email}
</Text>
{user.is_admin && (
<Badge colorScheme="invokeYellow" size="sm" alignSelf="flex-start" mt={1}>
{t('auth.admin')}
</Badge>
)}
</Flex>
<MenuItem icon={<PiGearBold />} onClick={handleProfile}>
{t('auth.profile.menuItem')}
</MenuItem>
{user.is_admin && (
<MenuItem icon={<PiUsersBold />} onClick={handleUserManagement}>
{t('auth.userManagement.menuItem')}
</MenuItem>
)}
<MenuItem icon={<PiSignOutBold />} onClick={handleLogout}>
{t('auth.logout')}
</MenuItem>
</MenuList>
</Menu>
);
});
UserMenu.displayName = 'UserMenu';

View File

@@ -0,0 +1,390 @@
import {
Box,
Button,
Center,
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
Grid,
GridItem,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
Spinner,
Text,
Tooltip,
VStack,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectAuthToken, selectCurrentUser, setCredentials } from 'features/auth/store/authSlice';
import type { ChangeEvent, FormEvent } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiEyeBold, PiEyeSlashBold, PiLightningFill } from 'react-icons/pi';
import { useNavigate } from 'react-router-dom';
import { useLazyGeneratePasswordQuery, useUpdateCurrentUserMutation } from 'services/api/endpoints/auth';
const validatePasswordStrength = (
password: string,
t: (key: string) => string
): { isValid: boolean; message: string } => {
if (password.length === 0) {
return { isValid: true, message: '' };
}
if (password.length < 8) {
return { isValid: false, message: t('auth.setup.passwordTooShort') };
}
const hasUpper = /[A-Z]/.test(password);
const hasLower = /[a-z]/.test(password);
const hasDigit = /\d/.test(password);
if (!hasUpper || !hasLower || !hasDigit) {
return { isValid: false, message: t('auth.setup.passwordMissingRequirements') };
}
return { isValid: true, message: '' };
};
const PASSWORD_GRID_COLUMNS = '180px 1fr';
export const UserProfile = memo(() => {
const { t } = useTranslation();
const currentUser = useAppSelector(selectCurrentUser);
const currentToken = useAppSelector(selectAuthToken);
const dispatch = useAppDispatch();
const navigate = useNavigate();
const [displayName, setDisplayName] = useState(currentUser?.display_name ?? '');
const [currentPassword, setCurrentPassword] = useState('');
const [newPassword, setNewPassword] = useState('');
const [confirmPassword, setConfirmPassword] = useState('');
const [showCurrentPassword, setShowCurrentPassword] = useState(false);
const [showNewPassword, setShowNewPassword] = useState(false);
const [showConfirmPassword, setShowConfirmPassword] = useState(false);
const [errorMessage, setErrorMessage] = useState<string | null>(null);
const [updateCurrentUser, { isLoading }] = useUpdateCurrentUserMutation();
const [triggerGeneratePassword] = useLazyGeneratePasswordQuery();
const newPasswordValidation = validatePasswordStrength(newPassword, t);
const isPasswordChangeAttempted = newPassword.length > 0 || currentPassword.length > 0;
const passwordsMatch = newPassword.length > 0 && newPassword === confirmPassword;
const isPasswordChangeValid =
!isPasswordChangeAttempted || (currentPassword.length > 0 && newPasswordValidation.isValid && passwordsMatch);
const handleCancel = useCallback(() => {
navigate(-1);
}, [navigate]);
const handleGeneratePassword = useCallback(async () => {
try {
const result = await triggerGeneratePassword().unwrap();
setNewPassword(result.password);
setConfirmPassword(result.password);
setShowNewPassword(true);
setShowConfirmPassword(true);
} catch {
// ignore
}
}, [triggerGeneratePassword]);
const toggleShowCurrentPassword = useCallback(() => {
setShowCurrentPassword((v) => !v);
}, []);
const toggleShowNewPassword = useCallback(() => {
setShowNewPassword((v) => !v);
}, []);
const toggleShowConfirmPassword = useCallback(() => {
setShowConfirmPassword((v) => !v);
}, []);
const handleDisplayNameChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setDisplayName(e.target.value);
}, []);
const handleCurrentPasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setCurrentPassword(e.target.value);
}, []);
const handleNewPasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNewPassword(e.target.value);
}, []);
const handleConfirmPasswordChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setConfirmPassword(e.target.value);
}, []);
const handleSubmit = useCallback(
async (e: FormEvent) => {
e.preventDefault();
setErrorMessage(null);
if (!isPasswordChangeValid) {
return;
}
try {
const updatePayload: Parameters<typeof updateCurrentUser>[0] = {
display_name: displayName || null,
};
if (newPassword) {
updatePayload.current_password = currentPassword;
updatePayload.new_password = newPassword;
}
const updatedUser = await updateCurrentUser(updatePayload).unwrap();
// Refresh the stored user info so the header reflects the new display name
if (currentToken) {
dispatch(
setCredentials({
token: currentToken,
user: {
user_id: updatedUser.user_id,
email: updatedUser.email,
display_name: updatedUser.display_name ?? null,
is_admin: updatedUser.is_admin ?? false,
is_active: updatedUser.is_active ?? true,
},
})
);
}
// Navigate back after successful save
navigate(-1);
} catch (err) {
const detail =
err && typeof err === 'object' && 'data' in err && typeof (err as { data: unknown }).data === 'object'
? ((err as { data: { detail?: string } }).data?.detail ?? t('auth.profile.saveFailed'))
: t('auth.profile.saveFailed');
setErrorMessage(detail);
}
},
[
displayName,
currentPassword,
newPassword,
isPasswordChangeValid,
updateCurrentUser,
currentToken,
dispatch,
navigate,
t,
]
);
if (!currentUser) {
return (
<Center py={12}>
<Spinner size="xl" />
</Center>
);
}
return (
<Box p={6} maxW="480px">
<Heading size="md" mb={6}>
{t('auth.profile.title')}
</Heading>
<form onSubmit={handleSubmit}>
<VStack spacing={5} align="stretch">
{/* Email (read-only) */}
<FormControl>
<FormLabel>{t('auth.profile.email')}</FormLabel>
<Input type="email" value={currentUser.email} isReadOnly opacity={0.6} />
<FormHelperText>{t('auth.profile.emailReadOnly')}</FormHelperText>
</FormControl>
{/* Display name */}
<FormControl>
<FormLabel>{t('auth.profile.displayName')}</FormLabel>
<Input
type="text"
value={displayName}
onChange={handleDisplayNameChange}
placeholder={t('auth.profile.displayNamePlaceholder')}
/>
</FormControl>
<Box borderTop="1px solid" borderColor="base.600" pt={4}>
<Text fontSize="sm" fontWeight="semibold" mb={4} color="base.300">
{t('auth.profile.changePassword')}
</Text>
{/* Current password */}
<FormControl mb={4} isRequired={newPassword.length > 0}>
<Grid templateColumns={PASSWORD_GRID_COLUMNS} gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" mb={0} pt={2}>
{t('auth.profile.currentPassword')}
</FormLabel>
</GridItem>
<GridItem>
<InputGroup>
<Input
type={showCurrentPassword ? 'text' : 'password'}
value={currentPassword}
onChange={handleCurrentPasswordChange}
placeholder={t('auth.profile.currentPasswordPlaceholder')}
autoComplete="current-password"
pr="3rem"
/>
<InputRightElement>
<Tooltip
label={
showCurrentPassword
? t('auth.userManagement.hidePassword')
: t('auth.userManagement.showPassword')
}
>
<IconButton
aria-label={
showCurrentPassword
? t('auth.userManagement.hidePassword')
: t('auth.userManagement.showPassword')
}
icon={showCurrentPassword ? <PiEyeSlashBold /> : <PiEyeBold />}
variant="ghost"
size="sm"
onClick={toggleShowCurrentPassword}
tabIndex={-1}
/>
</Tooltip>
</InputRightElement>
</InputGroup>
</GridItem>
</Grid>
</FormControl>
{/* New password */}
<FormControl isInvalid={newPassword.length > 0 && !newPasswordValidation.isValid} mb={4}>
<Grid templateColumns={PASSWORD_GRID_COLUMNS} gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" mb={0} pt={2}>
{t('auth.profile.newPassword')}
</FormLabel>
</GridItem>
<GridItem>
<InputGroup>
<Input
type={showNewPassword ? 'text' : 'password'}
value={newPassword}
onChange={handleNewPasswordChange}
placeholder={t('auth.profile.newPasswordPlaceholder')}
autoComplete="new-password"
pr="3rem"
/>
<InputRightElement>
<Tooltip
label={
showNewPassword
? t('auth.userManagement.hidePassword')
: t('auth.userManagement.showPassword')
}
>
<IconButton
aria-label={
showNewPassword
? t('auth.userManagement.hidePassword')
: t('auth.userManagement.showPassword')
}
icon={showNewPassword ? <PiEyeSlashBold /> : <PiEyeBold />}
variant="ghost"
size="sm"
onClick={toggleShowNewPassword}
tabIndex={-1}
/>
</Tooltip>
</InputRightElement>
</InputGroup>
{newPassword.length > 0 && !newPasswordValidation.isValid && (
<FormErrorMessage>{newPasswordValidation.message}</FormErrorMessage>
)}
</GridItem>
</Grid>
</FormControl>
{/* Confirm new password */}
<FormControl isInvalid={confirmPassword.length > 0 && !passwordsMatch} mb={4}>
<Grid templateColumns={PASSWORD_GRID_COLUMNS} gap={4} alignItems="start">
<GridItem>
<FormLabel textAlign="right" mb={0} pt={2}>
{t('auth.profile.confirmPassword')}
</FormLabel>
</GridItem>
<GridItem>
<InputGroup>
<Input
type={showConfirmPassword ? 'text' : 'password'}
value={confirmPassword}
onChange={handleConfirmPasswordChange}
placeholder={t('auth.profile.confirmPasswordPlaceholder')}
autoComplete="new-password"
pr="3rem"
/>
<InputRightElement>
<Tooltip
label={
showConfirmPassword
? t('auth.userManagement.hidePassword')
: t('auth.userManagement.showPassword')
}
>
<IconButton
aria-label={
showConfirmPassword
? t('auth.userManagement.hidePassword')
: t('auth.userManagement.showPassword')
}
icon={showConfirmPassword ? <PiEyeSlashBold /> : <PiEyeBold />}
variant="ghost"
size="sm"
onClick={toggleShowConfirmPassword}
tabIndex={-1}
/>
</Tooltip>
</InputRightElement>
</InputGroup>
{confirmPassword.length > 0 && !passwordsMatch && (
<FormErrorMessage>{t('auth.profile.passwordsDoNotMatch')}</FormErrorMessage>
)}
</GridItem>
</Grid>
</FormControl>
{/* Generate password button aligned with the input column */}
<Grid templateColumns={PASSWORD_GRID_COLUMNS} gap={4}>
<GridItem />
<GridItem>
<Button size="sm" variant="ghost" onClick={handleGeneratePassword} leftIcon={<PiLightningFill />}>
{t('auth.userManagement.generatePassword')}
</Button>
</GridItem>
</Grid>
</Box>
{errorMessage && (
<Flex p={3} borderRadius="md" bg="error.500" color="white" fontSize="sm">
<Text fontWeight="semibold">{errorMessage}</Text>
</Flex>
)}
<Flex gap={3}>
<Button variant="ghost" onClick={handleCancel}>
{t('common.cancel')}
</Button>
<Button type="submit" colorScheme="invokeBlue" isLoading={isLoading} isDisabled={!isPasswordChangeValid}>
{t('common.save')}
</Button>
</Flex>
</VStack>
</form>
</Box>
);
});
UserProfile.displayName = 'UserProfile';

View File

@@ -0,0 +1,83 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { SliceConfig } from 'app/store/types';
import { z } from 'zod';
const zUser = z.object({
user_id: z.string(),
email: z.string(),
display_name: z.string().nullable(),
is_admin: z.boolean(),
is_active: z.boolean(),
});
const zAuthState = z.object({
isAuthenticated: z.boolean(),
token: z.string().nullable(),
user: zUser.nullable(),
isLoading: z.boolean(),
});
type User = z.infer<typeof zUser>;
type AuthState = z.infer<typeof zAuthState>;
// Helper to safely access localStorage (not available in test environment)
const getStoredAuthToken = (): string | null => {
if (typeof window !== 'undefined' && window.localStorage) {
return localStorage.getItem('auth_token');
}
return null;
};
const initialState: AuthState = {
isAuthenticated: !!getStoredAuthToken(),
token: getStoredAuthToken(),
user: null,
isLoading: false,
};
const getInitialAuthState = (): AuthState => initialState;
const authSlice = createSlice({
name: 'auth',
initialState,
reducers: {
setCredentials: (state, action: PayloadAction<{ token: string; user: User }>) => {
state.token = action.payload.token;
state.user = action.payload.user;
state.isAuthenticated = true;
if (typeof window !== 'undefined' && window.localStorage) {
localStorage.setItem('auth_token', action.payload.token);
}
},
logout: (state) => {
state.token = null;
state.user = null;
state.isAuthenticated = false;
if (typeof window !== 'undefined' && window.localStorage) {
localStorage.removeItem('auth_token');
}
},
setLoading: (state, action: PayloadAction<boolean>) => {
state.isLoading = action.payload;
},
},
});
export const { setCredentials, logout, setLoading } = authSlice.actions;
export const authSliceConfig: SliceConfig<typeof authSlice> = {
slice: authSlice,
schema: zAuthState,
getInitialState: getInitialAuthState,
persistConfig: {
migrate: () => getInitialAuthState(),
// Don't persist auth state - token is stored in localStorage
persistDenylist: ['isAuthenticated', 'token', 'user', 'isLoading'],
},
};
export const selectIsAuthenticated = (state: { auth: AuthState }) => state.auth.isAuthenticated;
export const selectCurrentUser = (state: { auth: AuthState }) => state.auth.user;
export const selectAuthToken = (state: { auth: AuthState }) => state.auth.token;
export const selectIsAuthLoading = (state: { auth: AuthState }) => state.auth.isLoading;

View File

@@ -96,16 +96,14 @@ const FontSelect = () => {
<Text fontSize="sm" lineHeight="1" whiteSpace="nowrap">
{t('controlLayers.text.font', { defaultValue: 'Font' })}
</Text>
<Tooltip label={t('controlLayers.text.font', { defaultValue: 'Font' })}>
<Combobox
size="sm"
variant="outline"
isSearchable={false}
options={options}
value={selectedOption}
onChange={handleFontChange}
/>
</Tooltip>
<Combobox
size="sm"
variant="outline"
isSearchable={false}
options={options}
value={selectedOption}
onChange={handleFontChange}
/>
</Flex>
);
};

View File

@@ -6,6 +6,7 @@ import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import { isPlainObject } from 'es-toolkit';
import { clamp } from 'es-toolkit/compat';
import { logout } from 'features/auth/store/authSlice';
import type { AspectRatioID, InfillMethod, ParamsState, RgbaColor } from 'features/controlLayers/store/types';
import {
ASPECT_RATIO_MAP,
@@ -428,6 +429,12 @@ const slice = createSlice({
},
paramsReset: (state) => resetState(state),
},
extraReducers(builder) {
// Reset params state on logout to prevent user data leakage when switching users
builder.addCase(logout, () => {
return getInitialParamsState();
});
},
});
const applyClipSkip = (state: { clipSkip: number }, model: ParameterModel | null, clipSkip: number) => {

View File

@@ -2,6 +2,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import type { AddImageToBoardDndTargetData } from 'features/dnd/dnd';
import { addImageToBoardDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -36,6 +37,7 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const autoAssignBoardOnClick = useAppSelector(selectAutoAssignBoardOnClick);
const selectedBoardId = useAppSelector(selectSelectedBoardId);
const currentUser = useAppSelector(selectCurrentUser);
const onClick = useCallback(() => {
if (selectedBoardId !== board.board_id) {
dispatch(boardIdSelected({ boardId: board.board_id }));
@@ -58,6 +60,8 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
[board]
);
const showOwner = currentUser?.is_admin && board.owner_username;
return (
<Box position="relative" w="full" h={12}>
<BoardContextMenu board={board}>
@@ -85,8 +89,13 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
h="full"
>
<CoverImage board={board} />
<Flex flex={1}>
<Flex flex={1} direction="column" minW={0}>
<BoardEditableTitle board={board} isSelected={isSelected} />
{showOwner && (
<Text fontSize="xs" color="base.500" noOfLines={1}>
{board.owner_username}
</Text>
)}
</Flex>
{autoAddBoardId === board.board_id && <AutoAddBadge />}
{board.archived && <Icon as={PiArchiveBold} fill="base.300" />}

View File

@@ -13,6 +13,9 @@ import {
} from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { MutableRefObject } from 'react';
import React, { memo, useCallback, useEffect, useMemo, useRef } from 'react';
import type {
@@ -80,22 +83,41 @@ const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageNam
return `${JSON.stringify(queryArgs)}-${imageName ?? index}`;
};
const canHandleGridArrowNavigation = (
activeTab: ReturnType<typeof selectActiveTab>,
focusedRegion: ReturnType<typeof getFocusedRegion>
) => {
if (navigationApi.isViewerArrowNavigationMode(activeTab)) {
// When gallery is not effectively available, viewer hotkeys own left/right navigation.
return false;
}
if (focusedRegion === 'gallery' || focusedRegion === 'viewer') {
return true;
}
// Fallback for tab-switch edge case: allow nav when viewer dock tab is active before first click.
return navigationApi.isDockviewPanelActive(activeTab, VIEWER_PANEL_ID);
};
/**
* Handles keyboard navigation for the gallery.
*/
const useKeyboardNavigation = (
imageNames: string[],
navigationImageNames: string[],
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
rootRef: React.RefObject<HTMLDivElement>
) => {
const { dispatch, getState } = useAppStore();
const activeTab = useAppSelector(selectActiveTab);
const handleKeyDown = useCallback(
(event: KeyboardEvent) => {
if (getFocusedRegion() !== 'gallery') {
// Only handle keyboard navigation when the gallery is focused
const focusedRegion = getFocusedRegion();
if (!canHandleGridArrowNavigation(activeTab, focusedRegion)) {
return;
}
// Only handle arrow keys
if (!['ArrowUp', 'ArrowDown', 'ArrowLeft', 'ArrowRight'].includes(event.key)) {
return;
@@ -112,7 +134,7 @@ const useKeyboardNavigation = (
return;
}
if (imageNames.length === 0) {
if (navigationImageNames.length === 0) {
return;
}
@@ -132,7 +154,7 @@ const useKeyboardNavigation = (
(selectImageToCompare(state) ?? selectLastSelectedItem(state))
: selectLastSelectedItem(state);
const currentIndex = getItemIndex(imageName ?? null, imageNames);
const currentIndex = getItemIndex(imageName ?? null, navigationImageNames);
let newIndex = currentIndex;
@@ -146,7 +168,7 @@ const useKeyboardNavigation = (
}
break;
case 'ArrowRight':
if (currentIndex < imageNames.length - 1) {
if (currentIndex < navigationImageNames.length - 1) {
newIndex = currentIndex + 1;
// } else {
// // Wrap to first image
@@ -163,16 +185,16 @@ const useKeyboardNavigation = (
break;
case 'ArrowDown':
// If no images below, stay on current image
if (currentIndex >= imageNames.length - imagesPerRow) {
if (currentIndex >= navigationImageNames.length - imagesPerRow) {
newIndex = currentIndex;
} else {
newIndex = Math.min(imageNames.length - 1, currentIndex + imagesPerRow);
newIndex = Math.min(navigationImageNames.length - 1, currentIndex + imagesPerRow);
}
break;
}
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < imageNames.length) {
const newImageName = imageNames[newIndex];
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < navigationImageNames.length) {
const newImageName = navigationImageNames[newIndex];
if (newImageName) {
if (event.altKey) {
dispatch(imageToCompareChanged(newImageName));
@@ -182,7 +204,7 @@ const useKeyboardNavigation = (
}
}
},
[rootRef, virtuosoRef, imageNames, getState, dispatch]
[activeTab, rootRef, virtuosoRef, navigationImageNames, getState, dispatch]
);
useRegisteredHotkeys({
@@ -316,13 +338,14 @@ const useStarImageHotkey = () => {
type GalleryImageGridContentProps = {
imageNames: string[];
navigationImageNames?: string[];
isLoading: boolean;
queryArgs: ListImageNamesQueryArgs;
rootRef?: React.RefObject<HTMLDivElement>;
};
export const GalleryImageGridContent = memo(
({ imageNames, isLoading, queryArgs, rootRef: rootRefProp }: GalleryImageGridContentProps) => {
({ imageNames, navigationImageNames, isLoading, queryArgs, rootRef: rootRefProp }: GalleryImageGridContentProps) => {
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
const internalRootRef = useRef<HTMLDivElement>(null);
@@ -336,7 +359,7 @@ export const GalleryImageGridContent = memo(
useStarImageHotkey();
useKeepSelectedImageInView(imageNames, virtuosoRef, rootRef, rangeRef);
useKeyboardNavigation(imageNames, virtuosoRef, rootRef);
useKeyboardNavigation(navigationImageNames ?? imageNames, virtuosoRef, rootRef);
const scrollerRef = useScrollableGallery(rootRef);
/*

View File

@@ -181,6 +181,7 @@ export const GalleryImageGridPaged = memo(() => {
<Flex w="full" h="full">
<GalleryImageGridContent
imageNames={pageImageNames}
navigationImageNames={imageNames}
isLoading={false}
queryArgs={queryArgs}
rootRef={gridRootRef}

View File

@@ -5,10 +5,18 @@ import { CanvasAlertsInvocationProgress } from 'features/controlLayers/component
import { DndImage } from 'features/dnd/DndImage';
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
import NextPrevItemButtons from 'features/gallery/components/NextPrevItemButtons';
import { selectShouldShowItemDetails, selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors';
import { useNextPrevItemNavigation } from 'features/gallery/components/useNextPrevItemNavigation';
import { selectLastSelectedItem } from 'features/gallery/store/gallerySelectors';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import {
selectActiveTab,
selectShouldShowItemDetails,
selectShouldShowProgressInViewer,
} from 'features/ui/store/uiSelectors';
import type { AnimationProps } from 'framer-motion';
import { AnimatePresence, motion } from 'framer-motion';
import { memo, useCallback, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import type { ImageDTO } from 'services/api/types';
import { useImageViewerContext } from './context';
@@ -17,11 +25,56 @@ import { ProgressImage } from './ProgressImage2';
import { ProgressIndicator } from './ProgressIndicator2';
export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | null }) => {
const activeTab = useAppSelector(selectActiveTab);
const selectedImageName = useAppSelector(selectLastSelectedItem);
const shouldShowItemDetails = useAppSelector(selectShouldShowItemDetails);
const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer);
const { goToPreviousImage, goToNextImage, isFetching } = useNextPrevItemNavigation();
const { onLoadImage, $progressEvent, $progressImage } = useImageViewerContext();
const progressEvent = useStore($progressEvent);
const progressImage = useStore($progressImage);
const [imageToRender, setImageToRender] = useState<ImageDTO | null>(null);
useEffect(() => {
if (!selectedImageName) {
setImageToRender(null);
return;
}
if (!imageDTO || imageToRender?.image_name === imageDTO.image_name) {
return;
}
let canceled = false;
const onReady = () => {
if (canceled) {
return;
}
setImageToRender(imageDTO);
};
if (typeof window === 'undefined') {
onReady();
return;
}
const preloader = new window.Image();
preloader.onload = onReady;
preloader.onerror = onReady;
preloader.src = imageDTO.image_url;
if (preloader.complete) {
onReady();
}
return () => {
canceled = true;
preloader.onload = null;
preloader.onerror = null;
};
}, [imageDTO, imageToRender?.image_name, selectedImageName]);
// Show and hide the next/prev buttons on mouse move
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
@@ -36,6 +89,50 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
}, 500);
}, []);
const handleViewerArrowNavigation = useCallback(
(event: KeyboardEvent, navigate: () => void) => {
if (!navigationApi.isViewerArrowNavigationMode(activeTab) || !imageToRender || isFetching) {
return;
}
if (event.target instanceof HTMLInputElement || event.target instanceof HTMLTextAreaElement) {
return;
}
event.preventDefault();
navigate();
},
[activeTab, imageToRender, isFetching]
);
const onHotkeyPrevImage = useCallback(
(event: KeyboardEvent) => {
handleViewerArrowNavigation(event, goToPreviousImage);
},
[goToPreviousImage, handleViewerArrowNavigation]
);
const onHotkeyNextImage = useCallback(
(event: KeyboardEvent) => {
handleViewerArrowNavigation(event, goToNextImage);
},
[goToNextImage, handleViewerArrowNavigation]
);
useRegisteredHotkeys({
id: 'galleryNavLeft',
category: 'gallery',
callback: onHotkeyPrevImage,
options: { preventDefault: true },
dependencies: [onHotkeyPrevImage],
});
useRegisteredHotkeys({
id: 'galleryNavRight',
category: 'gallery',
callback: onHotkeyNextImage,
options: { preventDefault: true },
dependencies: [onHotkeyNextImage],
});
const withProgress = shouldShowProgressInViewer && progressImage !== null;
return (
@@ -48,19 +145,12 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
justifyContent="center"
position="relative"
>
{imageDTO && (
<Flex
key={imageDTO.image_name}
w="full"
h="full"
position="absolute"
alignItems="center"
justifyContent="center"
>
<DndImage imageDTO={imageDTO} onLoad={onLoadImage} borderRadius="base" />
{imageToRender && (
<Flex w="full" h="full" position="absolute" alignItems="center" justifyContent="center">
<DndImage imageDTO={imageToRender} onLoad={onLoadImage} borderRadius="base" />
</Flex>
)}
{!imageDTO && <NoContentForViewer />}
{!imageToRender && <NoContentForViewer />}
{withProgress && (
<Flex w="full" h="full" position="absolute" alignItems="center" justifyContent="center" bg="base.900">
<ProgressImage progressImage={progressImage} />
@@ -72,13 +162,13 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
<Flex flexDir="column" gap={2} position="absolute" top={0} insetInlineStart={0} alignItems="flex-start">
<CanvasAlertsInvocationProgress />
</Flex>
{shouldShowItemDetails && imageDTO && !withProgress && (
{shouldShowItemDetails && imageToRender && !withProgress && (
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
<ImageMetadataViewer image={imageDTO} />
<ImageMetadataViewer image={imageToRender} />
</Box>
)}
<AnimatePresence>
{shouldShowNextPrevButtons && imageDTO && (
{shouldShowNextPrevButtons && imageToRender && (
<Box
as={motion.div}
key="nextPrevButtons"

View File

@@ -1,51 +1,28 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { clamp } from 'es-toolkit/compat';
import { selectLastSelectedItem } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, type MouseEvent, type PointerEvent } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
import { useGalleryImageNames } from './use-gallery-image-names';
import { useNextPrevItemNavigation } from './useNextPrevItemNavigation';
const ARROW_SIZE = 48;
const preventButtonFocusOnPointerDown = (event: PointerEvent<HTMLButtonElement>) => {
event.preventDefault();
};
const preventButtonFocusOnMouseDown = (event: MouseEvent<HTMLButtonElement>) => {
event.preventDefault();
};
const blurButtonOnPointerUp = (event: PointerEvent<HTMLButtonElement>) => {
event.currentTarget.blur();
};
const NextPrevItemButtons = ({ inset = 8 }: { inset?: ChakraProps['insetInlineStart' | 'insetInlineEnd'] }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const lastSelectedItem = useAppSelector(selectLastSelectedItem);
const { imageNames, isFetching } = useGalleryImageNames();
const isOnFirstItem = useMemo(
() => (lastSelectedItem ? imageNames.at(0) === lastSelectedItem : false),
[imageNames, lastSelectedItem]
);
const isOnLastItem = useMemo(
() => (lastSelectedItem ? imageNames.at(-1) === lastSelectedItem : false),
[imageNames, lastSelectedItem]
);
const onClickLeftArrow = useCallback(() => {
const targetIndex = lastSelectedItem ? imageNames.findIndex((n) => n === lastSelectedItem) - 1 : 0;
const clampedIndex = clamp(targetIndex, 0, imageNames.length - 1);
const n = imageNames.at(clampedIndex);
if (!n) {
return;
}
dispatch(imageSelected(n));
}, [dispatch, imageNames, lastSelectedItem]);
const onClickRightArrow = useCallback(() => {
const targetIndex = lastSelectedItem ? imageNames.findIndex((n) => n === lastSelectedItem) + 1 : 0;
const clampedIndex = clamp(targetIndex, 0, imageNames.length - 1);
const n = imageNames.at(clampedIndex);
if (!n) {
return;
}
dispatch(imageSelected(n));
}, [dispatch, imageNames, lastSelectedItem]);
const { goToPreviousImage, goToNextImage, isOnFirstItem, isOnLastItem, isFetching } = useNextPrevItemNavigation();
return (
<Box pos="relative" h="full" w="full">
@@ -62,7 +39,10 @@ const NextPrevItemButtons = ({ inset = 8 }: { inset?: ChakraProps['insetInlineSt
minH={0}
w={`${ARROW_SIZE}px`}
h={`${ARROW_SIZE}px`}
onClick={onClickLeftArrow}
onClick={goToPreviousImage}
onPointerDown={preventButtonFocusOnPointerDown}
onMouseDown={preventButtonFocusOnMouseDown}
onPointerUp={blurButtonOnPointerUp}
isDisabled={isFetching}
color="base.100"
pointerEvents="auto"
@@ -82,7 +62,10 @@ const NextPrevItemButtons = ({ inset = 8 }: { inset?: ChakraProps['insetInlineSt
minH={0}
w={`${ARROW_SIZE}px`}
h={`${ARROW_SIZE}px`}
onClick={onClickRightArrow}
onClick={goToNextImage}
onPointerDown={preventButtonFocusOnPointerDown}
onMouseDown={preventButtonFocusOnMouseDown}
onPointerUp={blurButtonOnPointerUp}
isDisabled={isFetching}
color="base.100"
pointerEvents="auto"

View File

@@ -0,0 +1,47 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { clamp } from 'es-toolkit/compat';
import { selectLastSelectedItem } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { useCallback, useMemo } from 'react';
import { useGalleryImageNames } from './use-gallery-image-names';
export const useNextPrevItemNavigation = () => {
const dispatch = useAppDispatch();
const lastSelectedItem = useAppSelector(selectLastSelectedItem);
const { imageNames, isFetching } = useGalleryImageNames();
const currentIndex = useMemo(
() => (lastSelectedItem ? imageNames.findIndex((n) => n === lastSelectedItem) : -1),
[imageNames, lastSelectedItem]
);
const isOnFirstItem = currentIndex === 0;
const isOnLastItem = currentIndex >= 0 && currentIndex === imageNames.length - 1;
const navigateBy = useCallback(
(delta: number) => {
const maxIndex = imageNames.length - 1;
if (maxIndex < 0) {
return;
}
const targetIndex = currentIndex >= 0 ? clamp(currentIndex + delta, 0, maxIndex) : 0;
const imageName = imageNames[targetIndex];
if (!imageName) {
return;
}
dispatch(imageSelected(imageName));
},
[currentIndex, dispatch, imageNames]
);
const goToPreviousImage = useCallback(() => {
navigateBy(-1);
}, [navigateBy]);
const goToNextImage = useCallback(() => {
navigateBy(1);
}, [navigateBy]);
return { goToPreviousImage, goToNextImage, isOnFirstItem, isOnLastItem, isFetching };
};

View File

@@ -3,6 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject, uniq } from 'es-toolkit';
import { logout } from 'features/auth/store/authSlice';
import type { BoardRecordOrderBy } from 'services/api/types';
import { assert } from 'tsafe';
@@ -142,6 +143,14 @@ const slice = createSlice({
state.boardsListOrderDir = action.payload;
},
},
extraReducers(builder) {
// Clear board-related state on logout to prevent stale data when switching users
builder.addCase(logout, (state) => {
state.selectedBoardId = 'none';
state.autoAddBoardId = 'none';
state.boardSearchText = '';
});
},
});
export const {
@@ -182,6 +191,6 @@ export const gallerySliceConfig: SliceConfig<typeof slice> = {
}
return zGalleryState.parse(state);
},
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
persistDenylist: ['selection', 'galleryView', 'imageToCompare'],
},
};

View File

@@ -0,0 +1,29 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import { useMemo } from 'react';
import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
/**
* Hook to determine if model manager features should be enabled for the current user.
*
* Returns true if:
* - Multiuser mode is disabled (single-user mode = always admin)
* - Multiuser mode is enabled AND user is an admin
*
* Returns false if:
* - Multiuser mode is enabled AND user is not an admin
*/
export const useIsModelManagerEnabled = (): boolean => {
const user = useAppSelector(selectCurrentUser);
const { data: setupStatus } = useGetSetupStatusQuery();
return useMemo(() => {
// If multiuser is disabled, treat as admin (single-user mode)
if (setupStatus && !setupStatus.multiuser_enabled) {
return true;
}
// If multiuser is enabled, check if user is admin
return user?.is_admin ?? false;
}, [setupStatus, user]);
};

View File

@@ -1,4 +1,6 @@
import { Button, Text, useToast } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsAuthenticated } from 'features/auth/store/authSlice';
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { useCallback, useEffect, useState } from 'react';
@@ -12,8 +14,14 @@ export const useStarterModelsToast = () => {
const [didToast, setDidToast] = useState(false);
const [mainModels, { data }] = useMainModels();
const toast = useToast();
const isAuthenticated = useAppSelector(selectIsAuthenticated);
useEffect(() => {
// Only show the toast if the user is authenticated
if (!isAuthenticated) {
return;
}
if (toast.isActive(TOAST_ID)) {
if (mainModels.length === 0) {
return;
@@ -32,7 +40,7 @@ export const useStarterModelsToast = () => {
onCloseComplete: () => setDidToast(true),
});
}
}, [data, didToast, mainModels.length, t, toast]);
}, [data, didToast, isAuthenticated, mainModels.length, t, toast]);
};
const ToastDescription = () => {

View File

@@ -1,6 +1,7 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useIsModelManagerEnabled } from 'features/modelManagerV2/hooks/useIsModelManagerEnabled';
import { selectSelectedModelKey, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -23,6 +24,7 @@ const modelManagerSx: SystemStyleObject = {
export const ModelManager = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const canManageModels = useIsModelManagerEnabled();
const handleClickAddModel = useCallback(() => {
dispatch(setSelectedModelKey(null));
}, [dispatch]);
@@ -36,7 +38,7 @@ export const ModelManager = memo(() => {
</Heading>
<Flex gap={2}>
<SyncModelsButton />
{!!selectedModelKey && (
{!!selectedModelKey && canManageModels && (
<Button size="sm" colorScheme="invokeYellow" leftIcon={<PiPlusBold />} onClick={handleClickAddModel}>
{t('modelManager.addModels')}
</Button>

Some files were not shown because too many files have changed in this diff Show More