Merge branch 'main' into feat/model-manager-queue-redesign

This commit is contained in:
Josh Corbett
2026-03-23 13:06:29 -06:00
committed by GitHub
29 changed files with 688 additions and 115 deletions

View File

@@ -79,6 +79,7 @@ class SetupStatusResponse(BaseModel):
setup_required: bool = Field(description="Whether initial setup is required")
multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled")
strict_password_checking: bool = Field(description="Whether strict password requirements are enforced")
@auth_router.get("/status", response_model=SetupStatusResponse)
@@ -92,13 +93,17 @@ async def get_setup_status() -> SetupStatusResponse:
# If multiuser is disabled, setup is never required
if not config.multiuser:
return SetupStatusResponse(setup_required=False, multiuser_enabled=False)
return SetupStatusResponse(
setup_required=False, multiuser_enabled=False, strict_password_checking=config.strict_password_checking
)
# In multiuser mode, check if an admin exists
user_service = ApiDependencies.invoker.services.users
setup_required = not user_service.has_admin()
return SetupStatusResponse(setup_required=setup_required, multiuser_enabled=True)
return SetupStatusResponse(
setup_required=setup_required, multiuser_enabled=True, strict_password_checking=config.strict_password_checking
)
@auth_router.post("/login", response_model=LoginResponse)
@@ -248,7 +253,7 @@ async def setup_admin(
password=request.password,
is_admin=True,
)
user = user_service.create_admin(user_data)
user = user_service.create_admin(user_data, strict_password_checking=config.strict_password_checking)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
@@ -359,6 +364,7 @@ async def create_user(
HTTPException: 400 if email already exists or password is weak
"""
user_service = ApiDependencies.invoker.services.users
config = ApiDependencies.invoker.services.configuration
try:
user_data = UserCreateRequest(
email=request.email,
@@ -366,7 +372,7 @@ async def create_user(
password=request.password,
is_admin=request.is_admin,
)
return user_service.create(user_data)
return user_service.create(user_data, strict_password_checking=config.strict_password_checking)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
@@ -414,6 +420,7 @@ async def update_user(
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users
config = ApiDependencies.invoker.services.configuration
try:
changes = UserUpdateRequest(
display_name=request.display_name,
@@ -421,7 +428,7 @@ async def update_user(
is_admin=request.is_admin,
is_active=request.is_active,
)
return user_service.update(user_id, changes)
return user_service.update(user_id, changes, strict_password_checking=config.strict_password_checking)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
@@ -483,6 +490,7 @@ async def update_current_user(
HTTPException: 404 if user not found
"""
user_service = ApiDependencies.invoker.services.users
config = ApiDependencies.invoker.services.configuration
# Verify current password when attempting a password change
if request.new_password is not None:
@@ -509,6 +517,8 @@ async def update_current_user(
display_name=request.display_name,
password=request.new_password,
)
return user_service.update(current_user.user_id, changes)
return user_service.update(
current_user.user_id, changes, strict_password_checking=config.strict_password_checking
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

View File

@@ -193,6 +193,23 @@ async def get_model_records_by_attrs(
return configs[0]
@model_manager_router.get(
"/get_by_hash",
operation_id="get_model_records_by_hash",
response_model=AnyModelConfig,
)
async def get_model_records_by_hash(
hash: str = Query(description="The hash of the model"),
) -> AnyModelConfig:
"""Gets a model by its hash. This is useful for recalling models that were deleted and reinstalled,
as the hash remains stable across reinstallations while the key (UUID) changes."""
configs = ApiDependencies.invoker.services.model_manager.store.search_by_hash(hash)
if not configs:
raise HTTPException(status_code=404, detail="No model found with this hash")
return configs[0]
@model_manager_router.get(
"/i/{key}",
operation_id="get_model_record",

View File

@@ -6,6 +6,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import GlmEncoderField
from invokeai.app.invocations.primitives import CogView4ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
CogView4ConditioningInfo,
ConditioningFieldData,
@@ -46,10 +47,18 @@ class CogView4TextEncoderInvocation(BaseInvocation):
prompt = [self.prompt]
# TODO(ryand): Add model inputs to the invocation rather than hard-coding.
glm_text_encoder_info = context.models.load(self.glm_encoder.text_encoder)
with (
context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder),
glm_text_encoder_info.model_on_device() as (_, glm_text_encoder),
context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer),
):
repaired_tensors = glm_text_encoder_info.repair_required_tensors_on_device()
device = get_effective_device(glm_text_encoder)
if repaired_tensors > 0:
context.logger.warning(
f"Recovered {repaired_tensors} required GLM tensor(s) onto {device} after a partial device mismatch."
)
context.util.signal_progress("Running GLM text encoder")
assert isinstance(glm_text_encoder, GlmModel)
assert isinstance(glm_tokenizer, PreTrainedTokenizerFast)
@@ -85,9 +94,7 @@ class CogView4TextEncoderInvocation(BaseInvocation):
device=text_input_ids.device,
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
prompt_embeds = glm_text_encoder(
text_input_ids.to(glm_text_encoder.device), output_hidden_states=True
).hidden_states[-2]
prompt_embeds = glm_text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds

View File

@@ -25,6 +25,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import Qwen3EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -100,7 +101,12 @@ class Flux2KleinTextEncoderInvocation(BaseInvocation):
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
device = text_encoder.device
repaired_tensors = text_encoder_info.repair_required_tensors_on_device()
device = get_effective_device(text_encoder)
if repaired_tensors > 0:
context.logger.warning(
f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch."
)
# Apply LoRA models
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)

View File

@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import Qwen3EncoderField
from invokeai.app.invocations.primitives import ZImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_QWEN3_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -76,11 +77,17 @@ class ZImageTextEncoderInvocation(BaseInvocation):
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
with ExitStack() as exit_stack:
(_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
(cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
# Use the device that the text_encoder is actually on
device = text_encoder.device
# Use the device that the text encoder is effectively executing on, and repair any required tensors left on
# the CPU by a previous interrupted run.
repaired_tensors = text_encoder_info.repair_required_tensors_on_device()
device = get_effective_device(text_encoder)
if repaired_tensors > 0:
context.logger.warning(
f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch."
)
# Apply LoRA models to the text encoder
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
@@ -90,6 +97,7 @@ class ZImageTextEncoderInvocation(BaseInvocation):
patches=self._lora_iterator(context),
prefix=Z_IMAGE_LORA_QWEN3_PREFIX,
dtype=lora_dtype,
cached_weights=cached_weights,
)
)

View File

@@ -1,6 +1,6 @@
"""Password hashing and validation utilities."""
from typing import cast
from typing import Literal, cast
from passlib.context import CryptContext
@@ -84,3 +84,30 @@ def validate_password_strength(password: str) -> tuple[bool, str]:
return False, "Password must contain uppercase, lowercase, and numbers"
return True, ""
def get_password_strength(password: str) -> Literal["weak", "moderate", "strong"]:
"""Determine the strength of a password.
Strength levels:
- weak: less than 8 characters
- moderate: 8+ characters but missing at least one of uppercase, lowercase, or digit
- strong: 8+ characters with uppercase, lowercase, and digit
Args:
password: The password to evaluate
Returns:
One of "weak", "moderate", or "strong"
"""
if len(password) < 8:
return "weak"
has_upper = any(c.isupper() for c in password)
has_lower = any(c.islower() for c in password)
has_digit = any(c.isdigit() for c in password)
if not (has_upper and has_lower and has_digit):
return "moderate"
return "strong"

View File

@@ -111,6 +111,7 @@ class InvokeAIAppConfig(BaseSettings):
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.
multiuser: Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.
strict_password_checking: Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.
"""
_root: Optional[Path] = PrivateAttr(default=None)
@@ -206,6 +207,7 @@ class InvokeAIAppConfig(BaseSettings):
# MULTIUSER
multiuser: bool = Field(default=False, description="Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.")
strict_password_checking: bool = Field(default=False, description="Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.")
# fmt: on

View File

@@ -663,10 +663,12 @@ class ModelInstallService(ModelInstallServiceBase):
# directory. However, the path we store in the model record may be either a file within the key directory,
# or the directory itself. So we have to handle both cases.
if model_path.is_file() or model_path.is_symlink():
# Sanity check - file models should be in their own directory under the models dir. The parent of the
# file should be the model's directory, not the Invoke models dir!
assert model_path.parent != self.app_config.models_path
rmtree(model_path.parent)
# Delete the individual model file, not the entire parent directory.
# Other unrelated files may exist in the same directory.
model_path.unlink()
# Clean up the parent directory only if it is now empty
if model_path.parent != self.app_config.models_path and not any(model_path.parent.iterdir()):
model_path.parent.rmdir()
elif model_path.is_dir():
# Sanity check - folder models should be in their own directory under the models dir. The path should
# not be the Invoke models dir itself!

View File

@@ -9,17 +9,19 @@ class UserServiceBase(ABC):
"""High-level service for user management."""
@abstractmethod
def create(self, user_data: UserCreateRequest) -> UserDTO:
def create(self, user_data: UserCreateRequest, strict_password_checking: bool = True) -> UserDTO:
"""Create a new user.
Args:
user_data: User creation data
strict_password_checking: If True (default), passwords must meet strength requirements.
If False, any non-empty password is accepted.
Returns:
The created user
Raises:
ValueError: If email already exists or password is weak
ValueError: If email already exists or (when strict) password is weak
"""
pass
@@ -48,18 +50,20 @@ class UserServiceBase(ABC):
pass
@abstractmethod
def update(self, user_id: str, changes: UserUpdateRequest) -> UserDTO:
def update(self, user_id: str, changes: UserUpdateRequest, strict_password_checking: bool = True) -> UserDTO:
"""Update user.
Args:
user_id: The user ID
changes: Fields to update
strict_password_checking: If True (default), passwords must meet strength requirements.
If False, any non-empty password is accepted.
Returns:
The updated user
Raises:
ValueError: If user not found or password is weak
ValueError: If user not found or (when strict) password is weak
"""
pass
@@ -98,17 +102,19 @@ class UserServiceBase(ABC):
pass
@abstractmethod
def create_admin(self, user_data: UserCreateRequest) -> UserDTO:
def create_admin(self, user_data: UserCreateRequest, strict_password_checking: bool = True) -> UserDTO:
"""Create an admin user (for initial setup).
Args:
user_data: User creation data
strict_password_checking: If True (default), passwords must meet strength requirements.
If False, any non-empty password is accepted.
Returns:
The created admin user
Raises:
ValueError: If admin already exists or password is weak
ValueError: If admin already exists or (when strict) password is weak
"""
pass

View File

@@ -21,12 +21,15 @@ class UserService(UserServiceBase):
"""
self._db = db
def create(self, user_data: UserCreateRequest) -> UserDTO:
def create(self, user_data: UserCreateRequest, strict_password_checking: bool = True) -> UserDTO:
"""Create a new user."""
# Validate password strength
is_valid, error_msg = validate_password_strength(user_data.password)
if not is_valid:
raise ValueError(error_msg)
if strict_password_checking:
is_valid, error_msg = validate_password_strength(user_data.password)
if not is_valid:
raise ValueError(error_msg)
elif not user_data.password:
raise ValueError("Password cannot be empty")
# Check if email already exists
if self.get_by_email(user_data.email) is not None:
@@ -106,7 +109,7 @@ class UserService(UserServiceBase):
last_login_at=datetime.fromisoformat(row[7]) if row[7] else None,
)
def update(self, user_id: str, changes: UserUpdateRequest) -> UserDTO:
def update(self, user_id: str, changes: UserUpdateRequest, strict_password_checking: bool = True) -> UserDTO:
"""Update user."""
# Check if user exists
user = self.get(user_id)
@@ -115,9 +118,12 @@ class UserService(UserServiceBase):
# Validate password if provided
if changes.password is not None:
is_valid, error_msg = validate_password_strength(changes.password)
if not is_valid:
raise ValueError(error_msg)
if strict_password_checking:
is_valid, error_msg = validate_password_strength(changes.password)
if not is_valid:
raise ValueError(error_msg)
elif not changes.password:
raise ValueError("Password cannot be empty")
# Build update query dynamically based on provided fields
updates: list[str] = []
@@ -208,7 +214,7 @@ class UserService(UserServiceBase):
count = row[0] if row else 0
return bool(count > 0)
def create_admin(self, user_data: UserCreateRequest) -> UserDTO:
def create_admin(self, user_data: UserCreateRequest, strict_password_checking: bool = True) -> UserDTO:
"""Create an admin user (for initial setup)."""
if self.has_admin():
raise ValueError("Admin user already exists")
@@ -220,7 +226,7 @@ class UserService(UserServiceBase):
password=user_data.password,
is_admin=True,
)
return self.create(admin_data)
return self.create(admin_data, strict_password_checking=strict_password_checking)
def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
"""List all users."""

View File

@@ -14,6 +14,9 @@ import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
@@ -80,6 +83,13 @@ class LoadedModelWithoutConfig:
"""Return the model without locking it."""
return self._cache_record.cached_model.model
def repair_required_tensors_on_device(self) -> int:
"""Repair required tensors that should be resident on the cached model's execution device."""
cached_model = self._cache_record.cached_model
if not isinstance(cached_model, CachedModelWithPartialLoad):
return 0
return cached_model.repair_required_tensors_on_compute_device()
class LoadedModel(LoadedModelWithoutConfig):
"""Context manager object that mediates transfer from RAM<->VRAM."""

View File

@@ -149,6 +149,27 @@ class CachedModelWithPartialLoad:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())
@torch.no_grad()
def repair_required_tensors_on_compute_device(self) -> int:
"""Repair required non-autocast tensors that were left off the compute device.
This can happen if an interrupted run leaves the model in a partially inconsistent state. Any repaired device
movement invalidates the cached VRAM accounting.
"""
cur_state_dict = self._model.state_dict()
keys_to_repair = {
key
for key in self._keys_in_modules_that_do_not_support_autocast
if cur_state_dict[key].device.type != self._compute_device.type
}
if len(keys_to_repair) == 0:
return 0
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_repair, self._compute_device)
self._move_non_persistent_buffers_to_device(self._compute_device)
self._cur_vram_bytes = None
return len(keys_to_repair)
def _load_state_dict_with_device_conversion(
self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
):

View File

@@ -46,7 +46,8 @@
"passwordsDoNotMatch": "Passwords do not match",
"createAccount": "Create Administrator Account",
"creatingAccount": "Setting up...",
"setupFailed": "Setup failed. Please try again."
"setupFailed": "Setup failed. Please try again.",
"passwordHelperRelaxed": "Enter any password (strength will be shown)"
},
"userMenu": "User Menu",
"admin": "Admin",
@@ -102,6 +103,11 @@
"back": "Back",
"cannotDeleteSelf": "You cannot delete your own account",
"cannotDeactivateSelf": "You cannot deactivate your own account"
},
"passwordStrength": {
"weak": "Weak password",
"moderate": "Moderate password",
"strong": "Strong password"
}
},
"boards": {

View File

@@ -4,7 +4,8 @@
"uploadImage": "Lataa kuva",
"invokeProgressBar": "Invoken edistymispalkki",
"nextImage": "Seuraava kuva",
"previousImage": "Edellinen kuva"
"previousImage": "Edellinen kuva",
"uploadImages": "Lähetä Kuva(t)"
},
"common": {
"languagePickerLabel": "Kielen valinta",
@@ -29,5 +30,28 @@
"galleryImageSize": "Kuvan koko",
"gallerySettings": "Gallerian asetukset",
"autoSwitchNewImages": "Vaihda uusiin kuviin automaattisesti"
},
"modelManager": {
"t5Encoder": "T5-kooderi",
"qwen3Encoder": "Qwen3-kooderi",
"zImageVae": "VAE (valinnainen)",
"zImageQwen3Encoder": "Qwen3-kooderi (valinnainen)",
"zImageQwen3SourcePlaceholder": "Pakollinen, jos VAE/Enkooderi on tyhjä",
"flux2KleinVae": "VAE (valinnainen)",
"flux2KleinQwen3Encoder": "Qwen3-kooderi (valinnainen)"
},
"auth": {
"login": {
"title": "Kirjaudu sisään InvokeAI:hin",
"password": "Salasana",
"passwordPlaceholder": "Salasana",
"signIn": "Kirjaudu sisään",
"signingIn": "Kirjaudutaan sisään...",
"loginFailed": "Kirjautuminen epäonnistui. Tarkista käyttäjätunnuksesi tiedot."
},
"setup": {
"title": "Tervetuloa InvokeAI:hin",
"subtitle": "Määritä ensimmäiseksi järjestelmänvalvojan tili"
}
}
}

View File

@@ -3139,6 +3139,11 @@
"back": "Indietro",
"cannotDeleteSelf": "Non puoi eliminare il tuo account",
"cannotDeactivateSelf": "Non puoi disattivare il tuo account"
},
"passwordStrength": {
"weak": "Password debole",
"moderate": "Password moderata",
"strong": "Password forte"
}
}
}

View File

@@ -15,34 +15,13 @@ import {
Text,
VStack,
} from '@invoke-ai/ui-library';
import { validatePasswordField } from 'features/auth/util/passwordUtils';
import type { ChangeEvent, FormEvent } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useNavigate } from 'react-router-dom';
import { useGetSetupStatusQuery, useSetupMutation } from 'services/api/endpoints/auth';
const validatePasswordStrength = (
password: string,
t: (key: string) => string
): { isValid: boolean; message: string } => {
if (password.length < 8) {
return { isValid: false, message: t('auth.setup.passwordTooShort') };
}
const hasUpper = /[A-Z]/.test(password);
const hasLower = /[a-z]/.test(password);
const hasDigit = /\d/.test(password);
if (!hasUpper || !hasLower || !hasDigit) {
return {
isValid: false,
message: t('auth.setup.passwordMissingRequirements'),
};
}
return { isValid: true, message: '' };
};
export const AdministratorSetup = memo(() => {
const { t } = useTranslation();
const navigate = useNavigate();
@@ -60,7 +39,8 @@ export const AdministratorSetup = memo(() => {
}
}, [setupStatus, isLoadingSetup, navigate]);
const passwordValidation = validatePasswordStrength(password, t);
const strictPasswordChecking = setupStatus?.strict_password_checking ?? true;
const passwordValidation = validatePasswordField(password, t, strictPasswordChecking, false);
const passwordsMatch = password === confirmPassword;
const handleSubmit = useCallback(
@@ -120,6 +100,13 @@ export const AdministratorSetup = memo(() => {
);
}
const passwordStrengthColor =
passwordValidation.strength === 'weak'
? 'error.300'
: passwordValidation.strength === 'moderate'
? 'warning.300'
: 'invokeBlue.300';
return (
<Center w="100dvw" h="100dvh" bg="base.900">
<Box w="full" maxW="600px" p={8} borderRadius="lg" bg="base.800" boxShadow="dark-lg">
@@ -192,7 +179,16 @@ export const AdministratorSetup = memo(() => {
{password.length > 0 && !passwordValidation.isValid && (
<FormErrorMessage>{passwordValidation.message}</FormErrorMessage>
)}
{password.length === 0 && <FormHelperText mt={1}>{t('auth.setup.passwordHelper')}</FormHelperText>}
{password.length > 0 && passwordValidation.isValid && passwordValidation.message && (
<Text mt={1} fontSize="sm" color={passwordStrengthColor}>
{passwordValidation.message}
</Text>
)}
{password.length === 0 && (
<FormHelperText mt={1}>
{strictPasswordChecking ? t('auth.setup.passwordHelper') : t('auth.setup.passwordHelperRelaxed')}
</FormHelperText>
)}
</GridItem>
</Grid>
</FormControl>

View File

@@ -37,6 +37,7 @@ import {
} from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import { validatePasswordField } from 'features/auth/util/passwordUtils';
import type { ChangeEvent, FormEvent } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
@@ -54,30 +55,12 @@ import type { UserDTO } from 'services/api/endpoints/auth';
import {
useCreateUserMutation,
useDeleteUserMutation,
useGetSetupStatusQuery,
useLazyGeneratePasswordQuery,
useListUsersQuery,
useUpdateUserMutation,
} from 'services/api/endpoints/auth';
const validatePasswordStrength = (
password: string,
t: (key: string) => string
): { isValid: boolean; message: string } => {
if (password.length === 0) {
return { isValid: true, message: '' };
}
if (password.length < 8) {
return { isValid: false, message: t('auth.setup.passwordTooShort') };
}
const hasUpper = /[A-Z]/.test(password);
const hasLower = /[a-z]/.test(password);
const hasDigit = /\d/.test(password);
if (!hasUpper || !hasLower || !hasDigit) {
return { isValid: false, message: t('auth.setup.passwordMissingRequirements') };
}
return { isValid: true, message: '' };
};
const FORM_GRID_COLUMNS = '120px 1fr';
// ---------------------------------------------------------------------------
@@ -105,9 +88,12 @@ const UserFormModal = memo(({ isOpen, onClose, editUser }: UserFormModalProps) =
const [createUser, { isLoading: isCreating }] = useCreateUserMutation();
const [updateUser, { isLoading: isUpdating }] = useUpdateUserMutation();
const [triggerGeneratePassword] = useLazyGeneratePasswordQuery();
const { data: setupStatus } = useGetSetupStatusQuery();
const isLoading = isCreating || isUpdating;
const passwordValidation = validatePasswordStrength(password, t);
const strictPasswordChecking = setupStatus?.strict_password_checking ?? true;
// In edit mode, empty password means "no change" (allowEmpty=true); in create mode password is required (allowEmpty=false)
const passwordValidation = validatePasswordField(password, t, strictPasswordChecking, isEdit);
const handleGeneratePassword = useCallback(async () => {
try {
@@ -300,6 +286,21 @@ const UserFormModal = memo(({ isOpen, onClose, editUser }: UserFormModalProps) =
{password.length > 0 && !passwordValidation.isValid && (
<FormErrorMessage>{passwordValidation.message}</FormErrorMessage>
)}
{password.length > 0 && passwordValidation.isValid && passwordValidation.message && (
<Text
mt={1}
fontSize="sm"
color={
passwordValidation.strength === 'weak'
? 'error.300'
: passwordValidation.strength === 'moderate'
? 'warning.300'
: 'invokeBlue.300'
}
>
{passwordValidation.message}
</Text>
)}
</GridItem>
</Grid>
</FormControl>

View File

@@ -21,31 +21,17 @@ import {
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectAuthToken, selectCurrentUser, setCredentials } from 'features/auth/store/authSlice';
import { validatePasswordField } from 'features/auth/util/passwordUtils';
import type { ChangeEvent, FormEvent } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiEyeBold, PiEyeSlashBold, PiLightningFill } from 'react-icons/pi';
import { useNavigate } from 'react-router-dom';
import { useLazyGeneratePasswordQuery, useUpdateCurrentUserMutation } from 'services/api/endpoints/auth';
const validatePasswordStrength = (
password: string,
t: (key: string) => string
): { isValid: boolean; message: string } => {
if (password.length === 0) {
return { isValid: true, message: '' };
}
if (password.length < 8) {
return { isValid: false, message: t('auth.setup.passwordTooShort') };
}
const hasUpper = /[A-Z]/.test(password);
const hasLower = /[a-z]/.test(password);
const hasDigit = /\d/.test(password);
if (!hasUpper || !hasLower || !hasDigit) {
return { isValid: false, message: t('auth.setup.passwordMissingRequirements') };
}
return { isValid: true, message: '' };
};
import {
useGetSetupStatusQuery,
useLazyGeneratePasswordQuery,
useUpdateCurrentUserMutation,
} from 'services/api/endpoints/auth';
const PASSWORD_GRID_COLUMNS = '180px 1fr';
@@ -67,8 +53,10 @@ export const UserProfile = memo(() => {
const [updateCurrentUser, { isLoading }] = useUpdateCurrentUserMutation();
const [triggerGeneratePassword] = useLazyGeneratePasswordQuery();
const { data: setupStatus } = useGetSetupStatusQuery();
const newPasswordValidation = validatePasswordStrength(newPassword, t);
const strictPasswordChecking = setupStatus?.strict_password_checking ?? true;
const newPasswordValidation = validatePasswordField(newPassword, t, strictPasswordChecking, true);
const isPasswordChangeAttempted = newPassword.length > 0 || currentPassword.length > 0;
const passwordsMatch = newPassword.length > 0 && newPassword === confirmPassword;
@@ -305,6 +293,21 @@ export const UserProfile = memo(() => {
{newPassword.length > 0 && !newPasswordValidation.isValid && (
<FormErrorMessage>{newPasswordValidation.message}</FormErrorMessage>
)}
{newPassword.length > 0 && newPasswordValidation.isValid && newPasswordValidation.message && (
<Text
mt={1}
fontSize="sm"
color={
newPasswordValidation.strength === 'weak'
? 'error.300'
: newPasswordValidation.strength === 'moderate'
? 'warning.300'
: 'invokeBlue.300'
}
>
{newPasswordValidation.message}
</Text>
)}
</GridItem>
</Grid>
</FormControl>

View File

@@ -0,0 +1,70 @@
export type PasswordStrength = 'weak' | 'moderate' | 'strong';
export type PasswordValidationResult = {
isValid: boolean;
message: string;
strength: PasswordStrength | null;
};
/**
* Returns the strength level of a password.
* - weak: less than 8 characters
* - moderate: 8+ characters but missing uppercase, lowercase, or digit
* - strong: 8+ characters with uppercase, lowercase, and digit
*/
export const getPasswordStrength = (password: string): PasswordStrength => {
if (password.length < 8) {
return 'weak';
}
const hasUpper = /[A-Z]/.test(password);
const hasLower = /[a-z]/.test(password);
const hasDigit = /\d/.test(password);
if (!hasUpper || !hasLower || !hasDigit) {
return 'moderate';
}
return 'strong';
};
/**
* Validates a password field.
*
* In strict mode, passwords must be 8+ characters with uppercase, lowercase, and digits.
* In non-strict mode, any non-empty password is accepted but strength is reported.
*
* @param password - The password to validate
* @param t - Translation function
* @param strictPasswordChecking - Whether to enforce strict requirements
* @param allowEmpty - When true, an empty string is treated as "no change" (valid with no message)
*/
export const validatePasswordField = (
password: string,
t: (key: string) => string,
strictPasswordChecking: boolean,
allowEmpty = false
): PasswordValidationResult => {
if (password.length === 0) {
return { isValid: allowEmpty, message: '', strength: null };
}
const strength = getPasswordStrength(password);
if (!strictPasswordChecking) {
return {
isValid: true,
message: t(`auth.passwordStrength.${strength}`),
strength,
};
}
// Strict mode
if (password.length < 8) {
return { isValid: false, message: t('auth.setup.passwordTooShort'), strength };
}
const hasUpper = /[A-Z]/.test(password);
const hasLower = /[a-z]/.test(password);
const hasDigit = /\d/.test(password);
if (!hasUpper || !hasLower || !hasDigit) {
return { isValid: false, message: t('auth.setup.passwordMissingRequirements'), strength };
}
return { isValid: true, message: '', strength };
};

View File

@@ -9,7 +9,8 @@ import { rasterLayerGlobalCompositeOperationChanged } from 'features/controlLaye
import type { CanvasEntityIdentifier, CompositeOperation } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { CgPathBack, CgPathCrop, CgPathExclude, CgPathFront, CgPathIntersect } from 'react-icons/cg';
import { CgPathBack, CgPathExclude, CgPathFront, CgPathIntersect } from 'react-icons/cg';
import { PiIntersectSquareBold } from 'react-icons/pi';
export const RasterLayerMenuItemsBooleanSubMenu = memo(() => {
const { t } = useTranslation();
@@ -48,7 +49,7 @@ export const RasterLayerMenuItemsBooleanSubMenu = memo(() => {
const disabled = isBusy || !entityIdentifierBelowThisOne;
return (
<MenuItem {...subMenu.parentMenuItemProps} isDisabled={disabled} icon={<CgPathCrop size={18} />}>
<MenuItem {...subMenu.parentMenuItemProps} isDisabled={disabled} icon={<PiIntersectSquareBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.booleanOps.label')} />

View File

@@ -1063,7 +1063,8 @@ const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
for (const entity of parsed.controlLayers) {
if (entity.controlAdapter.model) {
await throwIfModelDoesNotExist(entity.controlAdapter.model.key, store);
const resolvedConfig = await resolveModel(entity.controlAdapter.model, store);
entity.controlAdapter.model = zModelIdentifierField.parse(resolvedConfig);
}
for (const object of entity.objects) {
if (object.type === 'image' && 'image_name' in object.image) {
@@ -1099,7 +1100,8 @@ const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
}
if (refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);
const resolvedConfig = await resolveModel(refImage.config.model, store);
refImage.config.model = zModelIdentifierField.parse(resolvedConfig);
}
}
}
@@ -1165,7 +1167,9 @@ const RefImages: CollectionMetadataHandler<RefImageState[]> = {
}
// FLUX.2 reference images don't have a model field (built-in support)
if ('model' in refImage.config && refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);
const resolvedConfig = await resolveModel(refImage.config.model, store);
// Update the model reference in case the key changed (e.g. model was reinstalled)
refImage.config.model = zModelIdentifierField.parse(resolvedConfig);
}
}
@@ -1534,7 +1538,19 @@ const parseModelIdentifier = async (raw: unknown, store: AppStore, type: ModelTy
const modelConfig = await req.unwrap();
return zModelIdentifierField.parse(modelConfig);
} catch {
// We'll try to parse the old format identifier next
// We'll try hash-based lookup next
}
// Try hash-based lookup (handles reinstalled models with new UUID keys)
try {
const { hash } = zModelIdentifierField.parse(raw);
if (hash) {
const req = store.dispatch(modelsApi.endpoints.getModelConfigByHash.initiate(hash, options));
const modelConfig = await req.unwrap();
return zModelIdentifierField.parse(modelConfig);
}
} catch {
// We'll try the old format identifier next
}
// Fall back to old format identifier: model_name, base_model
@@ -1562,10 +1578,44 @@ const throwIfImageDoesNotExist = async (name: string, store: AppStore): Promise<
}
};
const throwIfModelDoesNotExist = async (key: string, store: AppStore): Promise<void> => {
/**
* Resolve a model by key, falling back to hash or name+base+type lookup if the key is not found.
* This handles the case where a model was deleted and reinstalled (getting a new UUID key).
* Fallback order: key → hash → name+base+type
* Returns the resolved model config, or throws if the model cannot be found by any method.
*/
const resolveModel = async (
model: { key: string; hash?: string; name: string; base: string; type: string },
store: AppStore
): Promise<AnyModelConfig> => {
// First try by key (fast path)
try {
await store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
const req = store.dispatch(modelsApi.endpoints.getModelConfig.initiate(model.key, { subscribe: false }));
return await req.unwrap();
} catch {
throw new Error(`Model with key ${key} does not exist`);
// Key not found - try fallback
}
// Second try by hash (most reliable for reinstalled models - hash is content-based)
if (model.hash) {
try {
const req = store.dispatch(modelsApi.endpoints.getModelConfigByHash.initiate(model.hash, { subscribe: false }));
return await req.unwrap();
} catch {
// Hash not found - try next fallback
}
}
// Last resort: look up by name + base + type
try {
const req = store.dispatch(
modelsApi.endpoints.getModelConfigByAttrs.initiate(
{ name: model.name, base: model.base as any, type: model.type as any },
{ subscribe: false }
)
);
return await req.unwrap();
} catch {
throw new Error(`Model "${model.name}" (key: ${model.key}) does not exist`);
}
};

View File

@@ -33,6 +33,7 @@ type LogoutResponse = {
type SetupStatusResponse = {
setup_required: boolean;
multiuser_enabled: boolean;
strict_password_checking: boolean;
};
export type UserDTO = components['schemas']['UserDTO'];

View File

@@ -239,6 +239,18 @@ export const modelsApi = api.injectEndpoints({
},
serializeQueryArgs: ({ queryArgs }) => `${queryArgs.name}.${queryArgs.base}.${queryArgs.type}`,
}),
getModelConfigByHash: build.query<AnyModelConfig, string>({
query: (hash) => buildModelsUrl(`get_by_hash?${queryString.stringify({ hash })}`),
providesTags: (result) => {
const tags: ApiTagDescription[] = [];
if (result) {
tags.push({ type: 'ModelConfig', id: result.key });
}
return tags;
},
}),
scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => {
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';

View File

@@ -369,6 +369,27 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v2/models/get_by_hash": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Model Records By Hash
* @description Gets a model by its hash. This is useful for recalling models that were deleted and reinstalled,
* as the hash remains stable across reinstallations while the key (UUID) changes.
*/
get: operations["get_model_records_by_hash"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v2/models/i/{key}": {
parameters: {
query?: never;
@@ -14375,6 +14396,7 @@ export type components = {
* unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
* allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.
* multiuser: Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.
* strict_password_checking: Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.
*/
InvokeAIAppConfig: {
/**
@@ -14748,6 +14770,12 @@ export type components = {
* @default false
*/
multiuser?: boolean;
/**
* Strict Password Checking
* @description Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.
* @default false
*/
strict_password_checking?: boolean;
};
/**
* InvokeAIAppConfigWithSetFields
@@ -24486,6 +24514,11 @@ export type components = {
* @description Whether multiuser mode is enabled
*/
multiuser_enabled: boolean;
/**
* Strict Password Checking
* @description Whether strict password requirements are enforced
*/
strict_password_checking: boolean;
};
/**
* Show Image
@@ -29105,6 +29138,38 @@ export interface operations {
};
};
};
get_model_records_by_hash: {
parameters: {
query: {
/** @description The hash of the model */
hash: string;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["Main_Diffusers_SD1_Config"] | components["schemas"]["Main_Diffusers_SD2_Config"] | components["schemas"]["Main_Diffusers_SDXL_Config"] | components["schemas"]["Main_Diffusers_SDXLRefiner_Config"] | components["schemas"]["Main_Diffusers_SD3_Config"] | components["schemas"]["Main_Diffusers_FLUX_Config"] | components["schemas"]["Main_Diffusers_Flux2_Config"] | components["schemas"]["Main_Diffusers_CogView4_Config"] | components["schemas"]["Main_Diffusers_ZImage_Config"] | components["schemas"]["Main_Checkpoint_SD1_Config"] | components["schemas"]["Main_Checkpoint_SD2_Config"] | components["schemas"]["Main_Checkpoint_SDXL_Config"] | components["schemas"]["Main_Checkpoint_SDXLRefiner_Config"] | components["schemas"]["Main_Checkpoint_Flux2_Config"] | components["schemas"]["Main_Checkpoint_FLUX_Config"] | components["schemas"]["Main_Checkpoint_ZImage_Config"] | components["schemas"]["Main_BnBNF4_FLUX_Config"] | components["schemas"]["Main_GGUF_Flux2_Config"] | components["schemas"]["Main_GGUF_FLUX_Config"] | components["schemas"]["Main_GGUF_ZImage_Config"] | components["schemas"]["VAE_Checkpoint_SD1_Config"] | components["schemas"]["VAE_Checkpoint_SD2_Config"] | components["schemas"]["VAE_Checkpoint_SDXL_Config"] | components["schemas"]["VAE_Checkpoint_FLUX_Config"] | components["schemas"]["VAE_Checkpoint_Flux2_Config"] | components["schemas"]["VAE_Diffusers_SD1_Config"] | components["schemas"]["VAE_Diffusers_SDXL_Config"] | components["schemas"]["VAE_Diffusers_Flux2_Config"] | components["schemas"]["ControlNet_Checkpoint_SD1_Config"] | components["schemas"]["ControlNet_Checkpoint_SD2_Config"] | components["schemas"]["ControlNet_Checkpoint_SDXL_Config"] | components["schemas"]["ControlNet_Checkpoint_FLUX_Config"] | components["schemas"]["ControlNet_Checkpoint_ZImage_Config"] | components["schemas"]["ControlNet_Diffusers_SD1_Config"] | components["schemas"]["ControlNet_Diffusers_SD2_Config"] | components["schemas"]["ControlNet_Diffusers_SDXL_Config"] | components["schemas"]["ControlNet_Diffusers_FLUX_Config"] | components["schemas"]["LoRA_LyCORIS_SD1_Config"] | components["schemas"]["LoRA_LyCORIS_SD2_Config"] | components["schemas"]["LoRA_LyCORIS_SDXL_Config"] | components["schemas"]["LoRA_LyCORIS_Flux2_Config"] | components["schemas"]["LoRA_LyCORIS_FLUX_Config"] | components["schemas"]["LoRA_LyCORIS_ZImage_Config"] | components["schemas"]["LoRA_OMI_SDXL_Config"] | components["schemas"]["LoRA_OMI_FLUX_Config"] | components["schemas"]["LoRA_Diffusers_SD1_Config"] | components["schemas"]["LoRA_Diffusers_SD2_Config"] | components["schemas"]["LoRA_Diffusers_SDXL_Config"] | components["schemas"]["LoRA_Diffusers_Flux2_Config"] | components["schemas"]["LoRA_Diffusers_FLUX_Config"] | components["schemas"]["LoRA_Diffusers_ZImage_Config"] | components["schemas"]["ControlLoRA_LyCORIS_FLUX_Config"] | components["schemas"]["T5Encoder_T5Encoder_Config"] | components["schemas"]["T5Encoder_BnBLLMint8_Config"] | components["schemas"]["Qwen3Encoder_Qwen3Encoder_Config"] | components["schemas"]["Qwen3Encoder_Checkpoint_Config"] | components["schemas"]["Qwen3Encoder_GGUF_Config"] | components["schemas"]["TI_File_SD1_Config"] | components["schemas"]["TI_File_SD2_Config"] | components["schemas"]["TI_File_SDXL_Config"] | components["schemas"]["TI_Folder_SD1_Config"] | components["schemas"]["TI_Folder_SD2_Config"] | components["schemas"]["TI_Folder_SDXL_Config"] | components["schemas"]["IPAdapter_InvokeAI_SD1_Config"] | components["schemas"]["IPAdapter_InvokeAI_SD2_Config"] | components["schemas"]["IPAdapter_InvokeAI_SDXL_Config"] | components["schemas"]["IPAdapter_Checkpoint_SD1_Config"] | components["schemas"]["IPAdapter_Checkpoint_SD2_Config"] | components["schemas"]["IPAdapter_Checkpoint_SDXL_Config"] | components["schemas"]["IPAdapter_Checkpoint_FLUX_Config"] | components["schemas"]["T2IAdapter_Diffusers_SD1_Config"] | components["schemas"]["T2IAdapter_Diffusers_SDXL_Config"] | components["schemas"]["Spandrel_Checkpoint_Config"] | components["schemas"]["CLIPEmbed_Diffusers_G_Config"] | components["schemas"]["CLIPEmbed_Diffusers_L_Config"] | components["schemas"]["CLIPVision_Diffusers_Config"] | components["schemas"]["SigLIP_Diffusers_Config"] | components["schemas"]["FLUXRedux_Checkpoint_Config"] | components["schemas"]["LlavaOnevision_Diffusers_Config"] | components["schemas"]["Unknown_Config"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_model_record: {
parameters: {
query?: never;

View File

@@ -0,0 +1,80 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock
import torch
from invokeai.app.invocations.cogview4_text_encoder import CogView4TextEncoderInvocation
class FakeGlmModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_parameter("weight", torch.nn.Parameter(torch.ones(1)))
self.repaired = False
self.forward_input_device: torch.device | None = None
def forward(self, input_ids: torch.Tensor, output_hidden_states: bool = False):
assert output_hidden_states
if not self.repaired:
raise RuntimeError("model must be repaired before forward")
self.forward_input_device = input_ids.device
hidden = input_ids.unsqueeze(-1).float()
return SimpleNamespace(hidden_states=[hidden, hidden + 1])
class FakeTokenizer:
pad_token_id = 0
def __call__(self, prompt, padding, max_length=None, truncation=None, add_special_tokens=None, return_tensors=None):
del prompt, padding, max_length, truncation, add_special_tokens, return_tensors
return SimpleNamespace(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long))
def batch_decode(self, input_ids):
del input_ids
return ["decoded"]
class FakeLoadedModel:
def __init__(self, model):
self._model = model
self.repair_calls = 0
@contextmanager
def model_on_device(self):
yield (None, self._model)
def repair_required_tensors_on_device(self) -> int:
self.repair_calls += 1
self._model.repaired = True
return 1
def test_cogview4_text_encoder_repairs_model_before_forward(monkeypatch):
fake_model = FakeGlmModel()
fake_tokenizer = FakeTokenizer()
fake_model_info = FakeLoadedModel(fake_model)
fake_tokenizer_info = FakeLoadedModel(fake_tokenizer)
mock_context = MagicMock()
mock_context.models.load.side_effect = [fake_model_info, fake_tokenizer_info]
mock_context.util.signal_progress = MagicMock()
mock_context.logger.warning = MagicMock()
invocation = CogView4TextEncoderInvocation.model_construct(
prompt="test prompt",
glm_encoder=SimpleNamespace(text_encoder=SimpleNamespace(), tokenizer=SimpleNamespace()),
)
module_path = "invokeai.app.invocations.cogview4_text_encoder"
monkeypatch.setattr(f"{module_path}.GlmModel", FakeGlmModel)
monkeypatch.setattr(f"{module_path}.PreTrainedTokenizerFast", FakeTokenizer)
embeds = invocation._glm_encode(mock_context, max_seq_len=16)
assert fake_model_info.repair_calls == 1
mock_context.logger.warning.assert_called_once()
mock_context.util.signal_progress.assert_called_once_with("Running GLM text encoder")
assert fake_model.forward_input_device == torch.device("cpu")
assert embeds.shape == (1, 16, 1)

View File

@@ -300,8 +300,9 @@ def test_setup_admin_already_exists(monkeypatch: Any, mock_invoker: Invoker, cli
def test_setup_admin_weak_password(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
"""Test setup fails with weak password."""
"""Test setup fails with weak password when strict password checking is enabled."""
monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", MockApiDependencies(mock_invoker))
mock_invoker.services.configuration.strict_password_checking = True
response = client.post(
"/api/v1/auth/setup",
@@ -316,6 +317,25 @@ def test_setup_admin_weak_password(monkeypatch: Any, mock_invoker: Invoker, clie
assert "Password" in response.json()["detail"]
def test_setup_admin_weak_password_non_strict(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
"""Test setup succeeds with weak password when strict password checking is disabled (the default)."""
monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", MockApiDependencies(mock_invoker))
mock_invoker.services.configuration.strict_password_checking = False
response = client.post(
"/api/v1/auth/setup",
json={
"email": "admin3b@example.com",
"display_name": "Admin User",
"password": "weak",
},
)
assert response.status_code == 200
json_response = response.json()
assert json_response["success"] is True
def test_admin_user_token_has_admin_flag(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
"""Test that admin user login returns token with admin flag."""
monkeypatch.setattr("invokeai.app.api.routers.auth.ApiDependencies", MockApiDependencies(mock_invoker))

View File

@@ -1,6 +1,11 @@
"""Unit tests for password utilities."""
from invokeai.app.services.auth.password_utils import hash_password, validate_password_strength, verify_password
from invokeai.app.services.auth.password_utils import (
get_password_strength,
hash_password,
validate_password_strength,
verify_password,
)
class TestPasswordHashing:
@@ -223,6 +228,58 @@ class TestPasswordStrengthValidation:
assert message == ""
class TestGetPasswordStrength:
"""Tests for get_password_strength function."""
def test_weak_password_too_short(self):
"""Test that passwords shorter than 8 characters are 'weak'."""
assert get_password_strength("Ab1") == "weak"
assert get_password_strength("Ab1defg") == "weak" # 7 chars
assert get_password_strength("") == "weak"
def test_moderate_password_missing_uppercase(self):
"""Test that 8+ char passwords missing uppercase are 'moderate'."""
assert get_password_strength("lowercase1") == "moderate"
def test_moderate_password_missing_lowercase(self):
"""Test that 8+ char passwords missing lowercase are 'moderate'."""
assert get_password_strength("UPPERCASE1") == "moderate"
def test_moderate_password_missing_digit(self):
"""Test that 8+ char passwords missing digits are 'moderate'."""
assert get_password_strength("NoDigitsHere") == "moderate"
def test_moderate_password_only_lowercase_and_digit(self):
"""Test that 8+ char passwords with only lowercase and digit are 'moderate'."""
assert get_password_strength("lowercase1") == "moderate"
def test_strong_password(self):
"""Test that 8+ char passwords with upper, lower, and digit are 'strong'."""
assert get_password_strength("StrongPass1") == "strong"
assert get_password_strength("Pass123A") == "strong"
def test_strong_password_with_special_chars(self):
"""Test that passwords meeting all requirements plus special chars are 'strong'."""
assert get_password_strength("Pass!@#$123") == "strong"
def test_exactly_8_characters_meeting_requirements(self):
"""Test that exactly 8 characters meeting requirements is 'strong'."""
assert get_password_strength("Pass123A") == "strong"
def test_exactly_8_characters_missing_uppercase(self):
"""Test that exactly 8 characters missing uppercase is 'moderate'."""
assert get_password_strength("pass123a") == "moderate"
def test_strength_progression(self):
"""Test that strength improves as requirements are met."""
# Too short - weak
assert get_password_strength("Abc1") == "weak"
# Long enough but only lowercase - moderate
assert get_password_strength("abcdefgh") == "moderate"
# Meets all requirements - strong
assert get_password_strength("Abcdefg1") == "strong"
class TestPasswordSecurityProperties:
"""Tests for security properties of password handling."""

View File

@@ -62,7 +62,7 @@ def test_create_user(user_service: UserService):
def test_create_user_weak_password(user_service: UserService):
"""Test creating a user with weak password."""
"""Test creating a user with weak password fails when strict checking is enabled."""
user_data = UserCreateRequest(
email="test@example.com",
display_name="Test User",
@@ -71,7 +71,20 @@ def test_create_user_weak_password(user_service: UserService):
)
with pytest.raises(ValueError, match="at least 8 characters"):
user_service.create(user_data)
user_service.create(user_data, strict_password_checking=True)
def test_create_user_weak_password_non_strict(user_service: UserService):
"""Test creating a user with weak password succeeds when strict checking is disabled."""
user_data = UserCreateRequest(
email="weakpass@example.com",
display_name="Test User",
password="weak",
is_admin=False,
)
user = user_service.create(user_data, strict_password_checking=False)
assert user.email == "weakpass@example.com"
def test_create_duplicate_user(user_service: UserService):

View File

@@ -0,0 +1,47 @@
import pytest
import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
apply_custom_layers_to_model,
)
class ModelWithRequiredScale(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)
self.scale = torch.nn.Parameter(torch.ones(4))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x) * self.scale
@pytest.mark.parametrize(
"device",
[
pytest.param(
torch.device("cuda"), marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
),
pytest.param(
torch.device("mps"),
marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS device"),
),
],
)
@pytest.mark.parametrize("keep_ram_copy", [True, False])
@torch.no_grad()
def test_repair_required_tensors_on_compute_device(device: torch.device, keep_ram_copy: bool):
model = ModelWithRequiredScale()
apply_custom_layers_to_model(model, device_autocasting_enabled=True)
cached_model = CachedModelWithPartialLoad(model=model, compute_device=device, keep_ram_copy=keep_ram_copy)
cached_model._cur_vram_bytes = 0
repaired_tensors = cached_model.repair_required_tensors_on_compute_device()
assert repaired_tensors == 1
assert cached_model._cur_vram_bytes is None
assert model.scale.device.type == device.type
assert all(param.device.type == "cpu" for param in model.linear.parameters())