diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index f351be11ad..40d4f48b63 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -26,9 +26,11 @@ from invokeai.app.services.model_install.model_install_common import ModelInstal from invokeai.app.services.model_records import ( InvalidModelException, ModelRecordChanges, + ModelRecordOrderBy, UnknownModelException, ) from invokeai.app.services.orphaned_models import OrphanedModelInfo +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.suppress_output import SuppressOutput from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory @@ -159,6 +161,8 @@ async def list_model_records( model_format: Optional[ModelFormat] = Query( default=None, description="Exact match on the format of the model (e.g. 'diffusers')" ), + order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Name, description="The field to order by"), + direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"), ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_manager.store @@ -167,12 +171,23 @@ async def list_model_records( for base_model in base_models: found_models.extend( record_store.search_by_attr( - base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format + base_model=base_model, + model_type=model_type, + model_name=model_name, + model_format=model_format, + order_by=order_by, + direction=direction, ) ) else: found_models.extend( - record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) + record_store.search_by_attr( + model_type=model_type, + model_name=model_name, + model_format=model_format, + order_by=order_by, + direction=direction, + ) ) for index, model in enumerate(found_models): found_models[index] = prepare_model_config_for_response(model, ApiDependencies) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 6420949c29..31fbadb3cb 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -11,6 +11,7 @@ from typing import List, Optional, Set, Union from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.external_api import ( @@ -60,6 +61,10 @@ class ModelRecordOrderBy(str, Enum): Base = "base" Name = "name" Format = "format" + Size = "size" + DateAdded = "created_at" + DateModified = "updated_at" + Path = "path" class ModelSummary(BaseModel): @@ -200,7 +205,11 @@ class ModelRecordServiceBase(ABC): @abstractmethod def list_models( - self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default + self, + page: int = 0, + per_page: int = 10, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" pass @@ -237,6 +246,8 @@ class ModelRecordServiceBase(ABC): base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, model_format: Optional[ModelFormat] = None, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index edcbba2acd..f104c3855e 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -57,6 +57,7 @@ from invokeai.app.services.model_records.model_records_base import ( UnknownModelException, ) from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType @@ -257,6 +258,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): model_type: Optional[ModelType] = None, model_format: Optional[ModelFormat] = None, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -266,18 +268,24 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): :param model_type: Filter by type of model (optional) :param model_format: Filter by model format (e.g. "diffusers") (optional) :param order_by: Result order + :param direction: Result direction If none of the optional filters are passed, will return all models in the database. """ with self._db.transaction() as cursor: assert isinstance(order_by, ModelRecordOrderBy) + order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC" ordering = { - ModelRecordOrderBy.Default: "type, base, name, format", + ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format", ModelRecordOrderBy.Type: "type", - ModelRecordOrderBy.Base: "base", - ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Base: "base COLLATE NOCASE", + ModelRecordOrderBy.Name: "name COLLATE NOCASE", ModelRecordOrderBy.Format: "format", + ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)", + ModelRecordOrderBy.DateAdded: "created_at", + ModelRecordOrderBy.DateModified: "updated_at", + ModelRecordOrderBy.Path: "path", } where_clause: list[str] = [] @@ -301,7 +309,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): SELECT config FROM models {where} - ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason; + ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason; """, tuple(bindings), ) @@ -357,17 +365,26 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): return results def list_models( - self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default + self, + page: int = 0, + per_page: int = 10, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" with self._db.transaction() as cursor: assert isinstance(order_by, ModelRecordOrderBy) + order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC" ordering = { - ModelRecordOrderBy.Default: "type, base, name, format", + ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format", ModelRecordOrderBy.Type: "type", - ModelRecordOrderBy.Base: "base", - ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Base: "base COLLATE NOCASE", + ModelRecordOrderBy.Name: "name COLLATE NOCASE", ModelRecordOrderBy.Format: "format", + ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)", + ModelRecordOrderBy.DateAdded: "created_at", + ModelRecordOrderBy.DateModified: "updated_at", + ModelRecordOrderBy.Path: "path", } # Lock so that the database isn't updated while we're doing the two queries. @@ -385,7 +402,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): f"""--sql SELECT config FROM models - ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason + ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason LIMIT ? OFFSET ?; """, diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index 88f917d0d3..46606a3c0d 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -714,14 +714,25 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): - diffusion_model.layers.X.attention.to_k.lora_down.weight (DoRA format) - diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format) - diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale) + - lora_unet__layers_X_attention_to_k.lora_down.weight (Kohya format) """ + from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import ( + is_state_dict_likely_z_image_kohya_lora, + ) + state_dict = mod.load_state_dict() - # Check for Z-Image specific LoRA patterns + # Check for Kohya format first + if is_state_dict_likely_z_image_kohya_lora(state_dict): + return + + # Check for Z-Image specific LoRA patterns (dot-notation formats) has_z_image_lora_keys = state_dict_has_any_keys_starting_with( state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "diffusion_model.context_refiner.", + "diffusion_model.noise_refiner.", "transformer.layers.", # OneTrainer/diffusers prefix variant "base_model.model.transformer.layers.", # PEFT-wrapped variant }, @@ -751,15 +762,26 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): Z-Image uses S3-DiT architecture with layer names like: - diffusion_model.layers.0.attention.to_k.lora_A.weight - diffusion_model.layers.0.feed_forward.w1.lora_A.weight + - lora_unet__layers_0_attention_to_k.lora_down.weight (Kohya format) """ + from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import ( + is_state_dict_likely_z_image_kohya_lora, + ) + state_dict = mod.load_state_dict() - # Check for Z-Image transformer layer patterns + # Check for Kohya format + if is_state_dict_likely_z_image_kohya_lora(state_dict): + return BaseModelType.ZImage + + # Check for Z-Image transformer layer patterns (dot-notation formats) # Z-Image uses diffusion_model.layers.X structure (unlike Flux which uses double_blocks/single_blocks) has_z_image_keys = state_dict_has_any_keys_starting_with( state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "diffusion_model.context_refiner.", + "diffusion_model.noise_refiner.", "transformer.layers.", # OneTrainer/diffusers prefix variant "base_model.model.transformer.layers.", # PEFT-wrapped variant }, diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index 1be349f394..a2f008f41e 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -160,17 +160,20 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool: ".lora_A.weight", ".lora_B.weight", ".dora_scale", + ".alpha", ) + # First pass: check if any key has LoRA suffixes - if so, this is a LoRA not a main model for key in state_dict.keys(): if isinstance(key, int): continue - - # If we find any LoRA-specific keys, this is not a main model if key.endswith(lora_suffixes): return False - # Check for Z-Image specific key prefixes + # Second pass: check for Z-Image specific key parts + for key in state_dict.keys(): + if isinstance(key, int): + continue # Handle both direct keys (cap_embedder.0.weight) and # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight) key_parts = key.split(".") diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 2de51a8aca..c802154797 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -178,12 +178,43 @@ class Flux2VAELoader(ModelLoader): if is_bfl_format: sd = self._convert_flux2_vae_bfl_to_diffusers(sd) - # FLUX.2 VAE configuration (32 latent channels) - # Based on the official FLUX.2 VAE architecture - # Use default config - AutoencoderKLFlux2 has built-in defaults + # FLUX.2 VAE configuration (32 latent channels). + # The standard FLUX.2 VAE uses block_out_channels=(128,256,512,512) for both + # encoder and decoder. The "small decoder" variant from + # black-forest-labs/FLUX.2-small-decoder keeps the full encoder but uses a + # narrower decoder with channels (96,192,384,384). AutoencoderKLFlux2 only + # exposes a single block_out_channels, so we build the model with the + # encoder's channels and, if the decoder differs, replace just the decoder + # submodule with a matching one before loading the state dict. + encoder_block_out_channels = (128, 256, 512, 512) + decoder_block_out_channels = encoder_block_out_channels + if "encoder.conv_in.weight" in sd and "encoder.conv_norm_out.weight" in sd: + enc_last = int(sd["encoder.conv_norm_out.weight"].shape[0]) + enc_first = int(sd["encoder.conv_in.weight"].shape[0]) + encoder_block_out_channels = (enc_first, enc_first * 2, enc_last, enc_last) + if "decoder.conv_in.weight" in sd and "decoder.conv_norm_out.weight" in sd: + dec_last = int(sd["decoder.conv_in.weight"].shape[0]) + dec_first = int(sd["decoder.conv_norm_out.weight"].shape[0]) + decoder_block_out_channels = (dec_first, dec_first * 2, dec_last, dec_last) + with SilenceWarnings(): with accelerate.init_empty_weights(): - model = AutoencoderKLFlux2() + model = AutoencoderKLFlux2(block_out_channels=encoder_block_out_channels) + if decoder_block_out_channels != encoder_block_out_channels: + # Rebuild the decoder with the smaller channel widths. + from diffusers.models.autoencoders.vae import Decoder + + cfg = model.config + model.decoder = Decoder( + in_channels=cfg.latent_channels, + out_channels=cfg.out_channels, + up_block_types=cfg.up_block_types, + block_out_channels=decoder_block_out_channels, + layers_per_block=cfg.layers_per_block, + norm_num_groups=cfg.norm_num_groups, + act_fn=cfg.act_fn, + mid_block_add_attention=cfg.mid_block_add_attention, + ) # Convert to bfloat16 and load for k in sd.keys(): diff --git a/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py index e248f9cfc4..70b10de50d 100644 --- a/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py @@ -1,10 +1,11 @@ """Z-Image LoRA conversion utilities. Z-Image uses S3-DiT transformer architecture with Qwen3 text encoder. -LoRAs for Z-Image typically follow the diffusers PEFT format. +LoRAs for Z-Image typically follow the diffusers PEFT format or Kohya format. """ -from typing import Dict +import re +from typing import Any, Dict import torch @@ -16,6 +17,29 @@ from invokeai.backend.patches.lora_conversions.z_image_lora_constants import ( ) from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +# Regex for Kohya-format Z-Image transformer keys. +# Example keys: +# lora_unet__layers_0_attention_to_k.alpha +# lora_unet__layers_0_attention_to_k.lora_down.weight +# lora_unet__context_refiner_0_feed_forward_w1.lora_up.weight +# lora_unet__noise_refiner_1_attention_to_v.lora_down.weight +Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX = ( + r"lora_unet__(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)" +) + + +def is_state_dict_likely_z_image_kohya_lora(state_dict: dict[str | int, Any]) -> bool: + """Checks if the provided state dict is likely a Z-Image LoRA in Kohya format. + + Kohya Z-Image LoRAs have keys like: + - lora_unet__layers_0_attention_to_k.lora_down.weight + - lora_unet__context_refiner_0_feed_forward_w1.alpha + - lora_unet__noise_refiner_1_attention_to_v.lora_up.weight + """ + return any( + isinstance(k, str) and re.match(Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX, k.split(".")[0]) for k in state_dict.keys() + ) + def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) -> bool: """Checks if the provided state dict is likely a Z-Image LoRA. @@ -23,6 +47,9 @@ def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) Z-Image LoRAs can have keys for transformer and/or Qwen3 text encoder. They may use various prefixes depending on the training framework. """ + if is_state_dict_likely_z_image_kohya_lora(state_dict): + return True + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] # Check for Z-Image transformer keys (S3-DiT architecture) @@ -57,6 +84,7 @@ def lora_model_from_z_image_state_dict( - "transformer." or "base_model.model.transformer." for diffusers PEFT format - "diffusion_model." for some training frameworks - "text_encoder." or "base_model.model.text_encoder." for Qwen3 encoder + - "lora_unet__" for Kohya format (underscores instead of dots) Args: state_dict: The LoRA state dict @@ -65,6 +93,10 @@ def lora_model_from_z_image_state_dict( Returns: A ModelPatchRaw containing the LoRA layers """ + # If Kohya format, convert keys first then process normally + if is_state_dict_likely_z_image_kohya_lora(state_dict): + state_dict = _convert_z_image_kohya_state_dict(state_dict) + layers: dict[str, BaseLayerPatch] = {} # Group keys by layer @@ -120,6 +152,45 @@ def lora_model_from_z_image_state_dict( return ModelPatchRaw(layers=layers) +def _convert_z_image_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Converts a Kohya-format Z-Image LoRA state dict to diffusion_model dot-notation. + + Example key conversions: + - lora_unet__layers_0_attention_to_k.lora_down.weight -> diffusion_model.layers.0.attention.to_k.lora_down.weight + - lora_unet__context_refiner_0_feed_forward_w1.alpha -> diffusion_model.context_refiner.0.feed_forward.w1.alpha + - lora_unet__noise_refiner_1_attention_to_v.lora_up.weight -> diffusion_model.noise_refiner.1.attention.to_v.lora_up.weight + """ + converted: Dict[str, torch.Tensor] = {} + for key, value in state_dict.items(): + if not isinstance(key, str) or not key.startswith("lora_unet__"): + converted[key] = value + continue + + # Split into layer name and param suffix (e.g. "lora_down.weight", "alpha") + layer_name, _, param_suffix = key.partition(".") + + # Strip lora_unet__ prefix + remainder = layer_name[len("lora_unet__") :] + + # Convert Kohya underscore format to dot-notation using the known structure + match = re.match( + r"(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)$", + remainder, + ) + if match: + block, idx, submodule, param = match.groups() + new_layer = f"diffusion_model.{block}.{idx}.{submodule}.{param}" + else: + # Fallback: keep original key for unrecognized patterns + converted[key] = value + continue + + new_key = f"{new_layer}.{param_suffix}" if param_suffix else new_layer + converted[new_key] = value + + return converted + + def _get_lora_layer_values(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]: """Convert layer dict keys from PEFT format to internal format.""" if "lora_A.weight" in layer_dict: diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index cc654e4d39..fb8671cec2 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -40,11 +40,9 @@ def directory_size(directory: Path) -> int: Return the aggregate size of all files in a directory (bytes). """ sum = 0 - for root, dirs, files in os.walk(directory): + for root, _, files in os.walk(directory): for f in files: sum += Path(root, f).stat().st_size - for d in dirs: - sum += Path(root, d).stat().st_size return sum diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1d99f2cae4..75c5ad6671 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1203,6 +1203,15 @@ "modelType": "Model Type", "modelUpdated": "Model Updated", "modelUpdateFailed": "Model Update Failed", + "sortByName": "Name", + "sortByBase": "Base", + "sortBySize": "Size", + "sortByDateAdded": "Date Added", + "sortByDateModified": "Date Modified", + "sortByPath": "Path", + "sortByType": "Type", + "sortByFormat": "Format", + "sortDefault": "Default", "name": "Name", "externalProvider": "External Provider", "externalCapabilities": "External Capabilities", @@ -1316,9 +1325,11 @@ "zImageQwen3Source": "Qwen3 & VAE Source Model", "zImageQwen3SourcePlaceholder": "Required if VAE/Encoder empty", "flux2KleinVae": "VAE (optional)", - "flux2KleinVaePlaceholder": "From main model", + "flux2KleinVaePlaceholder": "From diffusers model", + "flux2KleinVaeNoModelPlaceholder": "No diffusers model available", "flux2KleinQwen3Encoder": "Qwen3 Encoder (optional)", - "flux2KleinQwen3EncoderPlaceholder": "From main model", + "flux2KleinQwen3EncoderPlaceholder": "From diffusers model", + "flux2KleinQwen3EncoderNoModelPlaceholder": "No diffusers model available", "qwenImageComponentSource": "VAE/Encoder Source (Diffusers)", "qwenImageComponentSourcePlaceholder": "Required for GGUF models", "qwenImageQuantization": "Encoder Quantization", @@ -1623,6 +1634,8 @@ "noFLUXVAEModelSelected": "No VAE model selected for FLUX generation", "noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation", "noQwen3EncoderModelSelected": "No Qwen3 Encoder model selected for FLUX2 Klein generation", + "noFlux2KleinVaeModelSelected": "No VAE selected. Non-diffusers FLUX.2 Klein models require a standalone VAE", + "noFlux2KleinQwen3EncoderModelSelected": "No Qwen3 Encoder selected. Non-diffusers FLUX.2 Klein models require a standalone Qwen3 Encoder", "noQwenImageComponentSourceSelected": "GGUF Qwen Image models require a Diffusers Component Source for VAE/encoder", "noZImageVaeSourceSelected": "No VAE source: Select VAE (FLUX) or Qwen3 Source model", "noZImageQwen3EncoderSourceSelected": "No Qwen3 Encoder source: Select Qwen3 Encoder or Qwen3 Source model", diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index db0a9a11a6..d823258dbf 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -2660,7 +2660,9 @@ "fitModeCover": "Copri", "smoothingMode": "Modalità di ricampionamento", "smoothingDesc": "Applica un ricampionamento di alta qualità lato backend alla conferma delle trasformazioni.", - "smoothing": "Smussamento" + "smoothing": "Smussamento", + "smoothingModeBilinear": "Bilineare", + "smoothingModeBicubic": "Bicubico" }, "stagingArea": { "next": "Successiva", diff --git a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx index f8ea01909a..4c1bc56e49 100644 --- a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx +++ b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx @@ -151,11 +151,18 @@ export const useSubMenu = (): UseSubMenuReturn => { }; }; -export const SubMenuButtonContent = ({ label }: { label: string }) => { +export const SubMenuButtonContent = ({ label, value }: { label: string; value?: string }) => { return ( {label} - + + {value !== undefined && ( + + {value} + + )} + + ); }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts index 2b7c0f7d17..1ea7626290 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts @@ -35,6 +35,15 @@ type PayloadActionWithId = T extends void } & T >; +/** Fingerprint used to match the same reference image entry after recall when ids are regenerated. */ +/** Empty configs of the same type may collide; the worst case is selecting an equivalent empty entity. */ +const getRefImageRecallMatchKey = (entity: RefImageState): string => { + const { config } = entity; + const imageName = config.image?.original.image.image_name ?? ''; + const modelKey = 'model' in config && config.model ? config.model.key : ''; + return `${config.type}\0${modelKey}\0${imageName}`; +}; + const slice = createSlice({ name: 'refImages', initialState: getInitialRefImagesState(), @@ -54,13 +63,41 @@ const slice = createSlice({ }, refImagesRecalled: (state, action: PayloadAction<{ entities: RefImageState[]; replace: boolean }>) => { const { entities, replace } = action.payload; - if (replace) { - state.entities = entities; - state.isPanelOpen = false; - state.selectedEntityId = null; - } else { + if (!replace) { state.entities.push(...entities); + return; } + const wasPanelOpen = state.isPanelOpen; + const previousSelectedId = state.selectedEntityId; + let previousEntity: RefImageState | null = null; + if (previousSelectedId !== null) { + previousEntity = state.entities.find((e) => e.id === previousSelectedId) ?? null; + } + state.entities = entities; + if (entities.length === 0) { + state.selectedEntityId = null; + state.isPanelOpen = false; + return; + } + if (!wasPanelOpen) { + state.selectedEntityId = null; + return; + } + const firstEntity = entities[0]; + assert(firstEntity); + if (previousSelectedId === null) { + // Open panel must have a selection; otherwise, fall back to the first entity. + state.selectedEntityId = firstEntity.id; + return; + } + if (previousSelectedId !== null && entities.some((e) => e.id === previousSelectedId)) { + state.selectedEntityId = previousSelectedId; + return; + } + const previousKey = previousEntity ? getRefImageRecallMatchKey(previousEntity) : null; + const matched = + previousKey !== null ? entities.find((e) => getRefImageRecallMatchKey(e) === previousKey) : undefined; + state.selectedEntityId = matched?.id ?? firstEntity.id; }, refImageImageChanged: (state, action: PayloadActionWithId<{ croppableImage: CroppableImageWithDims | null }>) => { const { id, croppableImage } = action.payload; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 9cc4ed24d9..7cdba474bb 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -31,7 +31,7 @@ export type ModelCategoryData = { filter: (config: AnyModelConfig) => boolean; }; -export const MODEL_CATEGORIES: Record = { +const MODEL_CATEGORIES: Record = { unknown: { category: 'unknown', i18nKey: 'common.unknown', diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts index 44df38d911..91fb1afd4d 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -25,6 +25,10 @@ const zModelManagerState = z.object({ scanPath: z.string().optional(), shouldInstallInPlace: z.boolean(), selectedModelKeys: z.array(z.string()), + orderBy: z + .enum(['default', 'name', 'type', 'base', 'size', 'created_at', 'updated_at', 'path', 'format']) + .default('name'), + sortDirection: z.enum(['asc', 'desc']).default('asc'), }); type ModelManagerState = z.infer; @@ -38,6 +42,8 @@ const getInitialState = (): ModelManagerState => ({ scanPath: undefined, shouldInstallInPlace: true, selectedModelKeys: [], + orderBy: 'name', + sortDirection: 'asc', }); const slice = createSlice({ @@ -77,6 +83,12 @@ const slice = createSlice({ clearModelSelection: (state) => { state.selectedModelKeys = []; }, + setOrderBy: (state, action: PayloadAction) => { + state.orderBy = action.payload; + }, + setSortDirection: (state, action: PayloadAction) => { + state.sortDirection = action.payload; + }, }, }); @@ -90,6 +102,8 @@ export const { modelSelectionChanged, toggleModelSelection, clearModelSelection, + setOrderBy, + setSortDirection, } = slice.actions; export const modelManagerSliceConfig: SliceConfig = { @@ -119,3 +133,5 @@ export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType); export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace); export const selectSelectedModelKeys = createModelManagerSelector((mm) => mm.selectedModelKeys); +export const selectOrderBy = createModelManagerSelector((mm) => mm.orderBy); +export const selectSortDirection = createModelManagerSelector((mm) => mm.sortDirection); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx new file mode 100644 index 0000000000..57dad58f2c --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx @@ -0,0 +1,231 @@ +import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu'; +import type { ModelCategoryData } from 'features/modelManagerV2/models'; +import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; +import { + selectFilteredModelType, + selectOrderBy, + selectSortDirection, + setFilteredModelType, + setOrderBy, + setSortDirection, +} from 'features/modelManagerV2/store/modelManagerV2Slice'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + PiCheckBold, + PiFunnelBold, + PiListBold, + PiSortAscendingBold, + PiSortDescendingBold, + PiWarningBold, +} from 'react-icons/pi'; + +type OrderBy = 'default' | 'name' | 'type' | 'base' | 'size' | 'created_at' | 'updated_at' | 'path' | 'format'; + +const ORDER_BY_OPTIONS: { key: OrderBy; i18nKey: string }[] = [ + { key: 'default', i18nKey: 'modelManager.sortDefault' }, + { key: 'name', i18nKey: 'modelManager.sortByName' }, + { key: 'base', i18nKey: 'modelManager.sortByBase' }, + { key: 'size', i18nKey: 'modelManager.sortBySize' }, + { key: 'created_at', i18nKey: 'modelManager.sortByDateAdded' }, + { key: 'updated_at', i18nKey: 'modelManager.sortByDateModified' }, + { key: 'path', i18nKey: 'modelManager.sortByPath' }, + { key: 'type', i18nKey: 'modelManager.sortByType' }, + { key: 'format', i18nKey: 'modelManager.sortByFormat' }, +]; + +const SortByMenuItem = memo(({ option, label }: { option: OrderBy; label: string }) => { + const dispatch = useAppDispatch(); + const orderBy = useAppSelector(selectOrderBy); + const onClick = useCallback(() => { + dispatch(setOrderBy(option)); + }, [dispatch, option]); + + return ( + : } + > + {label} + + ); +}); +SortByMenuItem.displayName = 'SortByMenuItem'; + +const SortBySubMenu = memo(() => { + const { t } = useTranslation(); + const subMenu = useSubMenu(); + const orderBy = useAppSelector(selectOrderBy); + + const currentSortLabel = useMemo(() => { + const option = ORDER_BY_OPTIONS.find((o) => o.key === orderBy); + if (!option) { + return ''; + } + return t(option.i18nKey); + }, [orderBy, t]); + + return ( + }> + + + + + + {ORDER_BY_OPTIONS.map(({ key, i18nKey }) => ( + + ))} + + + + ); +}); +SortBySubMenu.displayName = 'SortBySubMenu'; + +const DirectionSubMenu = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const direction = useAppSelector(selectSortDirection); + const subMenu = useSubMenu(); + + const setDirectionAsc = useCallback(() => { + dispatch(setSortDirection('asc')); + }, [dispatch]); + + const setDirectionDesc = useCallback(() => { + dispatch(setSortDirection('desc')); + }, [dispatch]); + + const currentValue = direction === 'asc' ? t('common.ascending', 'Ascending') : t('common.descending', 'Descending'); + + return ( + : } + > + + + + + + : } + > + {t('common.ascending', 'Ascending')} + + : } + > + {t('common.descending', 'Descending')} + + + + + ); +}); +DirectionSubMenu.displayName = 'DirectionSubMenu'; + +const ModelTypeSubMenu = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const subMenu = useSubMenu(); + + const clearModelType = useCallback(() => { + dispatch(setFilteredModelType(null)); + }, [dispatch]); + + const setMissingFilter = useCallback(() => { + dispatch(setFilteredModelType('missing')); + }, [dispatch]); + + const currentValue = useMemo(() => { + if (filteredModelType === null) { + return t('modelManager.allModels'); + } + if (filteredModelType === 'missing') { + return t('modelManager.missingFiles'); + } + const categoryData = MODEL_CATEGORIES_AS_LIST.find((data) => data.category === filteredModelType); + return categoryData ? t(categoryData.i18nKey) : ''; + }, [filteredModelType, t]); + + return ( + }> + + + + + + : } + > + {t('modelManager.allModels')} + + : } + > + + {filteredModelType !== 'missing' && } + {t('modelManager.missingFiles')} + + + {MODEL_CATEGORIES_AS_LIST.map((data) => ( + + ))} + + + + ); +}); +ModelTypeSubMenu.displayName = 'ModelTypeSubMenu'; + +const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const onClick = useCallback(() => { + dispatch(setFilteredModelType(data.category)); + }, [data.category, dispatch]); + return ( + : } + > + {t(data.i18nKey)} + + ); +}); +ModelMenuItem.displayName = 'ModelMenuItem'; + +export const ModelFilterMenu = memo(() => { + const { t } = useTranslation(); + + return ( + + }> + {t('common.filtering', 'Filtering')} + + + + + + + + ); +}); + +ModelFilterMenu.displayName = 'ModelFilterMenu'; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index ed49fa2870..033a439bfc 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -8,8 +8,10 @@ import { clearModelSelection, type FilterableModelType, selectFilteredModelType, + selectOrderBy, selectSearchTerm, selectSelectedModelKeys, + selectSortDirection, setSelectedModelKey, } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { memo, useCallback, useMemo, useState } from 'react'; @@ -39,6 +41,8 @@ const ModelList = () => { const dispatch = useAppDispatch(); const filteredModelType = useAppSelector(selectFilteredModelType); const searchTerm = useAppSelector(selectSearchTerm); + const orderBy = useAppSelector(selectOrderBy); + const direction = useAppSelector(selectSortDirection); const selectedModelKeys = useAppSelector(selectSelectedModelKeys); const { t } = useTranslation(); const toast = useToast(); @@ -47,7 +51,8 @@ const ModelList = () => { const [isDeleting, setIsDeleting] = useState(false); const [isReidentifying, setIsReidentifying] = useState(false); - const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(); + const queryArgs = useMemo(() => ({ order_by: orderBy, direction: direction.toUpperCase() }), [orderBy, direction]); + const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(queryArgs); const { data: missingModelsData, isLoading: isLoadingMissing } = useGetMissingModelsQuery(); const [bulkDeleteModels] = useBulkDeleteModelsMutation(); const [bulkReidentifyModels] = useBulkReidentifyModelsMutation(); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx index 78bed8ab83..bbfb88df5c 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx @@ -6,8 +6,8 @@ import type { ChangeEventHandler } from 'react'; import { memo, useCallback } from 'react'; import { PiXBold } from 'react-icons/pi'; +import { ModelFilterMenu } from './ModelFilterMenu'; import { ModelListBulkActions } from './ModelListBulkActions'; -import { ModelTypeFilter } from './ModelTypeFilter'; export const ModelListNavigation = memo(() => { const dispatch = useAppDispatch(); @@ -50,7 +50,7 @@ export const ModelListNavigation = memo(() => { - + diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx deleted file mode 100644 index 5aa8e62886..0000000000 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ /dev/null @@ -1,78 +0,0 @@ -import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import type { ModelCategoryData } from 'features/modelManagerV2/models'; -import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; -import type { ModelCategoryType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { memo, useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; -import { PiFunnelBold, PiWarningBold } from 'react-icons/pi'; - -const isModelCategoryType = (type: string): type is ModelCategoryType => { - return type in MODEL_CATEGORIES; -}; - -export const ModelTypeFilter = memo(() => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const filteredModelType = useAppSelector(selectFilteredModelType); - - const clearModelType = useCallback(() => { - dispatch(setFilteredModelType(null)); - }, [dispatch]); - - const setMissingFilter = useCallback(() => { - dispatch(setFilteredModelType('missing')); - }, [dispatch]); - - const getButtonLabel = () => { - if (filteredModelType === 'missing') { - return t('modelManager.missingFiles'); - } - if (filteredModelType && isModelCategoryType(filteredModelType)) { - return t(MODEL_CATEGORIES[filteredModelType].i18nKey); - } - return t('modelManager.allModels'); - }; - - return ( - - }> - {getButtonLabel()} - - - {t('modelManager.allModels')} - - - - {t('modelManager.missingFiles')} - - - {MODEL_CATEGORIES_AS_LIST.map((data) => ( - - ))} - - - ); -}); - -ModelTypeFilter.displayName = 'ModelTypeFilter'; - -const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const filteredModelType = useAppSelector(selectFilteredModelType); - const onClick = useCallback(() => { - dispatch(setFilteredModelType(data.category)); - }, [data.category, dispatch]); - return ( - - {t(data.i18nKey)} - - ); -}); -ModelMenuItem.displayName = 'ModelMenuItem'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts new file mode 100644 index 0000000000..7f01becc3d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts @@ -0,0 +1,370 @@ +import { afterEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('app/logging/logger', () => ({ + logger: () => ({ + debug: vi.fn(), + }), +})); + +let nextId = 0; +vi.mock('features/controlLayers/konva/util', () => ({ + getPrefixedId: (prefix: string) => `${prefix}:${nextId++}`, +})); + +// --- Flux2 Klein model fixtures --- + +const flux2DiffusersModel = { + key: 'flux2-klein-diffusers', + hash: 'flux2-diff-hash', + name: 'FLUX.2 Klein 4B', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +}; + +const flux2GGUFModel = { + key: 'flux2-klein-gguf', + hash: 'flux2-gguf-hash', + name: 'FLUX.2 Klein 4B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_4b', +}; + +const kleinVaeModelFixture = { key: 'klein-vae', name: 'Klein VAE', base: 'flux2', type: 'vae' }; +const kleinQwen3EncoderModelFixture = { + key: 'klein-qwen3', + name: 'Qwen3 4B', + base: 'flux2', + type: 'qwen3_encoder', +}; + +const flux2GGUF9BModel = { + key: 'flux2-klein-gguf-9b', + hash: 'flux2-gguf-9b-hash', + name: 'FLUX.2 Klein 9B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_9b', +}; + +const diffusersSourceModelFixture = { + key: 'flux2-source-diffusers', + hash: 'flux2-src-hash', + name: 'FLUX.2 Klein 4B Source', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +}; + +const diffusers9BSourceModelFixture = { + key: 'flux2-source-diffusers-9b', + hash: 'flux2-src-9b-hash', + name: 'FLUX.2 Klein 9B Source', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_9b', +}; + +// --- Mutable state --- + +let model: Record = { ...flux2DiffusersModel }; +let kleinVaeModel: Record | null = null; +let kleinQwen3EncoderModel: Record | null = null; +let diffusersModels: Record[] = []; + +vi.mock('features/controlLayers/store/paramsSlice', () => ({ + selectMainModelConfig: vi.fn(() => model), + selectParamsSlice: vi.fn(() => ({ + guidance: 4, + steps: 20, + fluxScheduler: 'euler', + fluxDypePreset: 'off', + fluxDypeScale: 2.0, + fluxDypeExponent: 2.0, + fluxVAE: null, + t5EncoderModel: null, + clipEmbedModel: null, + })), + selectKleinVaeModel: vi.fn(() => kleinVaeModel), + selectKleinQwen3EncoderModel: vi.fn(() => kleinQwen3EncoderModel), +})); + +vi.mock('features/controlLayers/store/refImagesSlice', () => ({ + selectRefImagesSlice: vi.fn(() => ({ + entities: [], + })), +})); + +vi.mock('features/controlLayers/store/selectors', () => ({ + selectCanvasMetadata: vi.fn(() => ({})), + selectCanvasSlice: vi.fn(() => ({})), +})); + +vi.mock('features/controlLayers/store/types', () => ({ + isFlux2ReferenceImageConfig: vi.fn(() => false), + isFluxKontextReferenceImageConfig: vi.fn(() => false), +})); + +vi.mock('features/controlLayers/store/validators', () => ({ + getGlobalReferenceImageWarnings: vi.fn(() => []), +})); + +vi.mock('features/nodes/util/graph/generation/addFlux2KleinLoRAs', () => ({ + addFlux2KleinLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXFill', () => ({ + addFLUXFill: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXLoRAs', () => ({ + addFLUXLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXRedux', () => ({ + addFLUXReduxes: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addImageToImage', () => ({ + addImageToImage: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({ + addInpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addNSFWChecker', () => ({ + addNSFWChecker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({ + addOutpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addRegions', () => ({ + addRegions: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({ + addTextToImage: vi.fn(({ l2i }) => l2i), +})); + +vi.mock('features/nodes/util/graph/generation/addWatermarker', () => ({ + addWatermarker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/generation/addControlAdapters', () => ({ + addControlLoRA: vi.fn(), + addControlNets: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addIPAdapters', () => ({ + addIPAdapters: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({ + selectCanvasOutputFields: vi.fn(() => ({})), + selectPresetModifiedPrompts: vi.fn(() => ({ + positive: 'a prompt', + negative: '', + })), +})); + +vi.mock('features/ui/store/uiSelectors', () => ({ + selectActiveTab: vi.fn(() => 'generation'), +})); + +vi.mock('services/api/hooks/modelsByType', () => ({ + selectFlux2DiffusersModels: vi.fn(() => diffusersModels), +})); + +vi.mock('services/api/types', async () => { + const actual = await vi.importActual('services/api/types'); + return { + ...actual, + isNonRefinerMainModelConfig: vi.fn(() => true), + }; +}); + +import { buildFLUXGraph } from './buildFLUXGraph'; + +const buildGraphArg = () => ({ + generationMode: 'txt2img' as const, + manager: null, + state: { + system: { + shouldUseNSFWChecker: false, + shouldUseWatermarker: false, + }, + } as never, +}); + +/** Find the flux2_klein_model_loader node in the graph. */ +const getLoaderNode = async () => { + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const loaderEntry = Object.entries(graph.nodes).find(([id]) => id.startsWith('flux2_klein_model_loader:')); + return loaderEntry?.[1] as Record | undefined; +}; + +describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { + afterEach(() => { + nextId = 0; + model = { ...flux2DiffusersModel }; + kleinVaeModel = null; + kleinQwen3EncoderModel = null; + diffusersModels = []; + }); + + it('does not set qwen3_source_model when main model is diffusers', async () => { + model = { ...flux2DiffusersModel }; + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('sets qwen3_source_model when main model is GGUF and a diffusers model is available', async () => { + model = { ...flux2GGUFModel }; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toEqual({ + key: diffusersSourceModelFixture.key, + hash: diffusersSourceModelFixture.hash, + name: diffusersSourceModelFixture.name, + base: diffusersSourceModelFixture.base, + type: diffusersSourceModelFixture.type, + }); + }); + + it('does not set qwen3_source_model when main model is GGUF but standalone VAE and Qwen3 are both selected', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('does not set qwen3_source_model when main model is GGUF and no diffusers model is available', async () => { + model = { ...flux2GGUFModel }; + diffusersModels = []; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('sets qwen3_source_model when only VAE is selected but Qwen3 is missing', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = null; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeDefined(); + }); + + it('sets qwen3_source_model when only Qwen3 is selected but VAE is missing', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = null; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeDefined(); + }); + + it('passes standalone vae_model and qwen3_encoder_model when selected', async () => { + model = { ...flux2DiffusersModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.vae_model).toEqual(kleinVaeModelFixture); + expect(loader!.qwen3_encoder_model).toEqual(kleinQwen3EncoderModelFixture); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + describe('variant matching', () => { + it('selects a variant-matching diffusers model when multiple are available', async () => { + model = { ...flux2GGUF9BModel }; + diffusersModels = [diffusersSourceModelFixture, diffusers9BSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should pick the 9B diffusers model, not the 4B + expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusers9BSourceModelFixture.key })); + }); + + it('falls back to any diffusers model for VAE when standalone Qwen3 is selected but no variant match', async () => { + model = { ...flux2GGUF9BModel }; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + // Only 4B diffusers available, no 9B — but Qwen3 is already provided standalone + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should use the 4B diffusers model just for VAE extraction + expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusersSourceModelFixture.key })); + }); + + it('does not set qwen3_source_model when GGUF 9B with only 4B diffusers available and no standalone Qwen3', async () => { + model = { ...flux2GGUF9BModel }; + kleinQwen3EncoderModel = null; + // Only 4B diffusers available — wrong variant for Qwen3, no standalone Qwen3 selected + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should NOT use the 4B diffusers since it has the wrong Qwen3 encoder + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + }); + + describe('graph structure', () => { + it('uses flux2_klein_model_loader for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_klein_model_loader:'))).toBe(true); + }); + + it('uses flux2_vae_decode for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_vae_decode:'))).toBe(true); + }); + + it('uses flux2_klein_text_encoder for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_klein_text_encoder:'))).toBe(true); + }); + + it('uses flux2_denoise for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeTypes = Object.values(graph.nodes).map((n) => n.type); + expect(nodeTypes).toContain('flux2_denoise'); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index ba27e5dbf6..407c921421 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -10,7 +10,8 @@ import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlic import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { isFlux2ReferenceImageConfig, isFluxKontextReferenceImageConfig } from 'features/controlLayers/store/types'; import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators'; -import { zImageField } from 'features/nodes/types/common'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { zImageField, zModelIdentifierField } from 'features/nodes/types/common'; import { addFlux2KleinLoRAs } from 'features/nodes/util/graph/generation/addFlux2KleinLoRAs'; import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill'; import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs'; @@ -26,8 +27,10 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils'; import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; +import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { t } from 'i18next'; +import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType'; import type { Invocation } from 'services/api/types'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; @@ -141,7 +144,23 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); + const sourceModel = variantMatch ?? (kleinQwen3EncoderModel ? diffusersModels[0] : undefined); + if (sourceModel) { + qwen3SourceModel = zModelIdentifierField.parse(sourceModel); + } + } + modelLoader = g.addNode({ type: 'flux2_klein_model_loader', id: getPrefixedId('flux2_klein_model_loader'), @@ -149,6 +168,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise { const dispatch = useAppDispatch(); const { t } = useTranslation(); const kleinVaeModel = useAppSelector(selectKleinVaeModel); + const mainModelConfig = useAppSelector(selectMainModelConfig); const [modelConfigs, { isLoading }] = useFlux2VAEModels(); + const [diffusersModels] = useFlux2DiffusersModels(); const _onChange = useCallback( (model: VAEModelConfig | null) => { @@ -42,6 +45,11 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { isLoading, }); + const hasDiffusersSource = mainModelConfig?.format === 'diffusers' || diffusersModels.length > 0; + const placeholder = hasDiffusersSource + ? t('modelManager.flux2KleinVaePlaceholder') + : t('modelManager.flux2KleinVaeNoModelPlaceholder'); + return ( {t('modelManager.flux2KleinVae')} @@ -51,7 +59,7 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { onChange={onChange} noOptionsMessage={noOptionsMessage} isClearable - placeholder={t('modelManager.flux2KleinVaePlaceholder')} + placeholder={placeholder} /> ); @@ -59,15 +67,6 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { ParamFlux2KleinVaeModelSelect.displayName = 'ParamFlux2KleinVaeModelSelect'; -/** - * Maps FLUX.2 Klein variants to compatible Qwen3 encoder variants - */ -const KLEIN_TO_QWEN3_VARIANT_MAP: Record = { - klein_4b: 'qwen3_4b', - klein_9b: 'qwen3_8b', - klein_9b_base: 'qwen3_8b', -}; - /** * FLUX.2 Klein Qwen3 Encoder Model Select * Selects a Qwen3 text encoder model for FLUX.2 Klein @@ -79,6 +78,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { const kleinQwen3EncoderModel = useAppSelector(selectKleinQwen3EncoderModel); const mainModelConfig = useAppSelector(selectMainModelConfig); const [allModelConfigs, { isLoading }] = useQwen3EncoderModels(); + const [diffusersModels] = useFlux2DiffusersModels(); // Filter Qwen3 encoders based on the main model's variant const modelConfigs = useMemo(() => { @@ -112,6 +112,20 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { isLoading, }); + // Qwen3 encoder requires a Qwen3-compatible diffusers model (variants that share the same Qwen3 encoder). + const hasMatchingDiffusersSource = + mainModelConfig?.format === 'diffusers' || + diffusersModels.some( + (m) => + 'variant' in m && + mainModelConfig && + 'variant' in mainModelConfig && + isFlux2KleinQwen3Compatible(m.variant, mainModelConfig.variant) + ); + const placeholder = hasMatchingDiffusersSource + ? t('modelManager.flux2KleinQwen3EncoderPlaceholder') + : t('modelManager.flux2KleinQwen3EncoderNoModelPlaceholder'); + return ( {t('modelManager.flux2KleinQwen3Encoder')} @@ -121,7 +135,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { onChange={onChange} noOptionsMessage={noOptionsMessage} isClearable - placeholder={t('modelManager.flux2KleinQwen3EncoderPlaceholder')} + placeholder={placeholder} /> ); diff --git a/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts new file mode 100644 index 0000000000..b9508a4f82 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts @@ -0,0 +1,24 @@ +/** + * Maps a FLUX.2 Klein main-model variant to the Qwen3 encoder variant it uses. + * Multiple Klein variants can share the same Qwen3 variant (e.g. `klein_9b` and + * `klein_9b_base` both use `qwen3_8b`), so two different Klein variants can be + * Qwen3-compatible sources for each other. + */ +export const KLEIN_TO_QWEN3_VARIANT_MAP: Record = { + klein_4b: 'qwen3_4b', + klein_9b: 'qwen3_8b', + klein_9b_base: 'qwen3_8b', +}; + +/** + * Returns true if two Klein variants share the same Qwen3 encoder and can therefore + * be used as a Qwen3 source for each other. + */ +export const isFlux2KleinQwen3Compatible = (variantA: unknown, variantB: unknown): boolean => { + if (typeof variantA !== 'string' || typeof variantB !== 'string') { + return false; + } + const qwen3A = KLEIN_TO_QWEN3_VARIANT_MAP[variantA]; + const qwen3B = KLEIN_TO_QWEN3_VARIANT_MAP[variantB]; + return qwen3A !== undefined && qwen3A === qwen3B; +}; diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts new file mode 100644 index 0000000000..632006050e --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts @@ -0,0 +1,272 @@ +import { describe, expect, it, vi } from 'vitest'; + +vi.mock('features/dynamicPrompts/util/getShouldProcessPrompt', () => ({ + getShouldProcessPrompt: vi.fn(() => false), +})); + +vi.mock('i18next', () => ({ + default: { + t: (key: string) => key, + }, +})); + +import type { ParamsState, RefImagesState } from 'features/controlLayers/store/types'; +import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; +import type { MainModelConfig } from 'services/api/types'; + +import { getReasonsWhyCannotEnqueueCanvasTab, getReasonsWhyCannotEnqueueGenerateTab } from './readiness'; + +// --- Fixtures --- + +const flux2DiffusersModel = { + key: 'flux2-diff', + hash: 'h', + name: 'FLUX.2 Klein 4B', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +} as unknown as MainModelConfig; + +const flux2GGUF4BModel = { + key: 'flux2-gguf-4b', + hash: 'h', + name: 'FLUX.2 Klein 4B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_4b', +} as unknown as MainModelConfig; + +const flux2GGUF9BModel = { + key: 'flux2-gguf-9b', + hash: 'h', + name: 'FLUX.2 Klein 9B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_9b', +} as unknown as MainModelConfig; + +const kleinVaeModel = { key: 'vae', name: 'VAE', base: 'flux2', type: 'vae' }; +const kleinQwen3Model = { key: 'qwen3', name: 'Qwen3', base: 'flux2', type: 'qwen3_encoder' }; + +const baseDynamicPrompts: DynamicPromptsState = { + _version: 1, + maxPrompts: 100, + combinatorial: false, + prompts: ['test prompt'], + parsingError: undefined, + isError: false, + isLoading: false, + seedBehaviour: 'PER_PROMPT', +}; + +const baseRefImages: RefImagesState = { + entities: [], + ipAdapters: { entities: [], ids: [] }, +} as unknown as RefImagesState; + +const baseParams = { + positivePrompt: 'test', + kleinVaeModel: null, + kleinQwen3EncoderModel: null, +} as unknown as ParamsState; + +// --- Helpers --- + +const buildGenerateTabArg = (overrides: { + model?: MainModelConfig | null; + kleinVaeModel?: unknown; + kleinQwen3EncoderModel?: unknown; + hasFlux2DiffusersVaeSource?: boolean; + hasFlux2DiffusersQwen3Source?: boolean; +}) => ({ + isConnected: true, + model: overrides.model ?? flux2DiffusersModel, + params: { + ...baseParams, + kleinVaeModel: overrides.kleinVaeModel ?? null, + kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, + } as unknown as ParamsState, + refImages: baseRefImages, + loras: [], + dynamicPrompts: baseDynamicPrompts, + hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false, + hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false, +}); + +const buildCanvasTabArg = (overrides: { + model?: MainModelConfig | null; + kleinVaeModel?: unknown; + kleinQwen3EncoderModel?: unknown; + hasFlux2DiffusersVaeSource?: boolean; + hasFlux2DiffusersQwen3Source?: boolean; +}) => ({ + isConnected: true, + model: overrides.model ?? flux2DiffusersModel, + canvas: { + bbox: { + scaleMethod: 'none', + rect: { width: 1024, height: 1024 }, + scaledSize: { width: 1024, height: 1024 }, + }, + controlLayers: { entities: [] }, + regionalGuidance: { entities: [] }, + rasterLayers: { entities: [] }, + inpaintMasks: { entities: [] }, + }, + params: { + ...baseParams, + kleinVaeModel: overrides.kleinVaeModel ?? null, + kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, + } as unknown as ParamsState, + refImages: baseRefImages, + loras: [], + dynamicPrompts: baseDynamicPrompts, + canvasIsFiltering: false, + canvasIsTransforming: false, + canvasIsRasterizing: false, + canvasIsCompositing: false, + canvasIsSelectingObject: false, + hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false, + hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false, +}); + +const hasFlux2VaeReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('noFlux2KleinVaeModelSelected')); + +const hasFlux2Qwen3Reason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('noFlux2KleinQwen3EncoderModelSelected')); + +// --- Tests --- + +describe('FLUX.2 Klein readiness checks – generate tab', () => { + it('no errors when main model is diffusers (VAE/Qwen3 extracted from it)', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2DiffusersModel })); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF4BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('errors for both VAE and Qwen3 when GGUF model with no diffusers source and no standalone models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2GGUF4BModel })); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('errors only for Qwen3 when GGUF model with standalone VAE but no Qwen3 and no diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: flux2GGUF4BModel, kleinVaeModel: kleinVaeModel }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('errors only for VAE when GGUF model with standalone Qwen3 but no VAE and no diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: flux2GGUF4BModel, kleinQwen3EncoderModel: kleinQwen3Model }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both standalone VAE and Qwen3', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF4BModel, + kleinVaeModel: kleinVaeModel, + kleinQwen3EncoderModel: kleinQwen3Model, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('VAE ok but Qwen3 errors when GGUF 9B model with only a 4B diffusers source (variant mismatch)', () => { + // User has Klein 9B GGUF selected, only a 4B diffusers model installed. + // VAE is shared across variants so it's ok. Qwen3 encoder differs, so it's not ok. + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: false, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('no errors when GGUF 9B model with variant-matching diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); +}); + +describe('FLUX.2 Klein readiness checks – canvas tab', () => { + it('no errors when main model is diffusers', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2DiffusersModel }) as never); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF4BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('errors for both VAE and Qwen3 when GGUF model with no sources', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2GGUF4BModel }) as never); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('no errors when GGUF model with both standalone VAE and Qwen3', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF4BModel, + kleinVaeModel: kleinVaeModel, + kleinQwen3EncoderModel: kleinQwen3Model, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('VAE ok but Qwen3 errors when GGUF 9B with variant-mismatched diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: false, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); +}); diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 67dfe3141c..5802a2aed5 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -33,12 +33,14 @@ import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/ import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue'; import type { UpscaleState } from 'features/parameters/store/upscaleSlice'; import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice'; +import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein'; import { getGridSize } from 'features/parameters/util/optimalDimension'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import type { TabName } from 'features/ui/store/uiTypes'; import i18n from 'i18next'; import { atom, computed } from 'nanostores'; import { useEffect } from 'react'; +import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType'; import type { MainOrExternalModelConfig } from 'services/api/types'; import { isExternalApiModelConfig } from 'services/api/types'; import { $isConnected } from 'services/events/stores'; @@ -109,6 +111,12 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { } = arg; if (tab === 'generate') { const model = selectMainModelConfig(store.getState()); + const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState()); + const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0; + const modelVariant = model && 'variant' in model ? model.variant : undefined; + const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some( + (m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); const reasons = await getReasonsWhyCannotEnqueueGenerateTab({ isConnected, model, @@ -116,10 +124,18 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { refImages, dynamicPrompts, loras, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'canvas') { const model = selectMainModelConfig(store.getState()); + const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState()); + const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0; + const modelVariant = model && 'variant' in model ? model.variant : undefined; + const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some( + (m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); const reasons = await getReasonsWhyCannotEnqueueCanvasTab({ isConnected, model, @@ -133,6 +149,8 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { canvasIsCompositing, canvasIsSelectingObject, loras, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'workflows') { @@ -220,15 +238,26 @@ export const useReadinessWatcher = () => { const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') }); -const getReasonsWhyCannotEnqueueGenerateTab = (arg: { +export const getReasonsWhyCannotEnqueueGenerateTab = (arg: { isConnected: boolean; model: MainOrExternalModelConfig | null | undefined; params: ParamsState; refImages: RefImagesState; loras: LoRA[]; dynamicPrompts: DynamicPromptsState; + hasFlux2DiffusersVaeSource: boolean; + hasFlux2DiffusersQwen3Source: boolean; }) => { - const { isConnected, model, params, refImages, loras, dynamicPrompts } = arg; + const { + isConnected, + model, + params, + refImages, + loras, + dynamicPrompts, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, + } = arg; const { positivePrompt } = params; const reasons: Reason[] = []; @@ -260,7 +289,17 @@ const getReasonsWhyCannotEnqueueGenerateTab = (arg: { } } - // FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed + if (model?.base === 'flux2' && model.format !== 'diffusers') { + // Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder + // unless a diffusers flux2 model is available to extract them from. + // VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model. + if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') }); + } + if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') }); + } + } if (model?.base === 'qwen-image' && model.format === 'gguf_quantized') { if (!params.qwenImageComponentSource) { @@ -452,7 +491,7 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: { return reasons; }; -const getReasonsWhyCannotEnqueueCanvasTab = (arg: { +export const getReasonsWhyCannotEnqueueCanvasTab = (arg: { isConnected: boolean; model: MainOrExternalModelConfig | null | undefined; canvas: CanvasState; @@ -465,6 +504,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { canvasIsRasterizing: boolean; canvasIsCompositing: boolean; canvasIsSelectingObject: boolean; + hasFlux2DiffusersVaeSource: boolean; + hasFlux2DiffusersQwen3Source: boolean; }) => { const { isConnected, @@ -479,6 +520,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { canvasIsRasterizing, canvasIsCompositing, canvasIsSelectingObject, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, } = arg; const { positivePrompt } = params; const reasons: Reason[] = []; @@ -571,7 +614,17 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { } if (model?.base === 'flux2') { - // FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed + // Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder + // unless a diffusers flux2 model is available to extract them from. + // VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model. + if (model.format !== 'diffusers') { + if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') }); + } + if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') }); + } + } const { bbox } = canvas; const gridSize = getGridSize('flux'); // FLUX.2 uses same grid size as FLUX.1 diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index c3d0decd53..f279d46d82 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -111,9 +111,13 @@ type DeleteOrphanedModelsResponse = { errors: Record; }; +type GetModelConfigsArg = { + order_by?: string; + direction?: string; +} | void; + const modelConfigsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions); @@ -338,8 +342,11 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['ModelInstalls'], }), - getModelConfigs: build.query, void>({ - query: () => ({ url: buildModelsUrl() }), + getModelConfigs: build.query, GetModelConfigsArg>({ + query: (arg) => { + const queryStr = arg ? `?${queryString.stringify(arg)}` : ''; + return { url: buildModelsUrl(queryStr) }; + }, providesTags: (result) => { const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }]; if (result) { @@ -498,5 +505,5 @@ export const { useDeleteOrphanedModelsMutation, } = modelsApi; -export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(); +export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(undefined); export const selectMissingModelsQuery = modelsApi.endpoints.getMissingModels.select(); diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts index 55746e5294..2496c06ed0 100644 --- a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -18,6 +18,7 @@ import { isControlNetModelConfig, isExternalApiModelConfig, isFlux1VAEModelConfig, + isFlux2DiffusersMainModelConfig, isFlux2VAEModelConfig, isFluxKontextModelConfig, isFluxReduxModelConfig, @@ -101,6 +102,7 @@ export const useFlux2VAEModels = () => buildModelsHook(isFlux2VAEModelConfig)(); export const useAnimaVAEModels = () => buildModelsHook(isAnimaVAEModelConfig)(); export const useAnimaQwen3EncoderModels = () => buildModelsHook(isAnimaQwen3EncoderModelConfig)(); export const useZImageDiffusersModels = () => buildModelsHook(isZImageDiffusersMainModelConfig)(); +export const useFlux2DiffusersModels = () => buildModelsHook(isFlux2DiffusersMainModelConfig)(); export const useQwenImageDiffusersModels = () => buildModelsHook(isQwenImageDiffusersMainModelConfig)(); export const useQwen3EncoderModels = () => buildModelsHook(isQwen3EncoderModelConfig)(); export const useGlobalReferenceImageModels = buildModelsHook( @@ -140,6 +142,7 @@ export const selectAnimaQwen3EncoderModels = buildModelsSelector(isAnimaQwen3Enc export const selectQwen3EncoderModels = buildModelsSelector(isQwen3EncoderModelConfig); export const selectQwenImageDiffusersModels = buildModelsSelector(isQwenImageDiffusersMainModelConfig); export const selectZImageDiffusersModels = buildModelsSelector(isZImageDiffusersMainModelConfig); +export const selectFlux2DiffusersModels = buildModelsSelector(isFlux2DiffusersMainModelConfig); export const selectFluxVAEModels = buildModelsSelector(isFluxVAEModelConfig); export const selectAnimaVAEModels = buildModelsSelector(isAnimaVAEModelConfig); export const selectT5EncoderModels = buildModelsSelector(isT5EncoderModelConfigOrSubmodel); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index f23be1707d..6de38aecf1 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -22999,6 +22999,12 @@ export type components = { */ config_path?: string | null; }; + /** + * ModelRecordOrderBy + * @description The order in which to return model summaries. + * @enum {string} + */ + ModelRecordOrderBy: "default" | "type" | "base" | "name" | "format" | "size" | "created_at" | "updated_at" | "path"; /** ModelRelationshipBatchRequest */ ModelRelationshipBatchRequest: { /** @@ -31550,6 +31556,10 @@ export interface operations { model_name?: string | null; /** @description Exact match on the format of the model (e.g. 'diffusers') */ model_format?: components["schemas"]["ModelFormat"] | null; + /** @description The field to order by */ + order_by?: components["schemas"]["ModelRecordOrderBy"]; + /** @description The direction to order by */ + direction?: components["schemas"]["SQLiteDirection"]; }; header?: never; path?: never; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 3624d7ef6a..9deefada23 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -457,6 +457,10 @@ export const isZImageDiffusersMainModelConfig = (config: AnyModelConfig): config return config.type === 'main' && config.base === 'z-image' && config.format === 'diffusers'; }; +export const isFlux2DiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && config.base === 'flux2' && config.format === 'diffusers'; +}; + export const isQwenImageDiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && config.base === 'qwen-image' && config.format === 'diffusers'; }; diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 2b6c54d5b0..19a1b74e73 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -11,11 +11,13 @@ from pydantic import ValidationError from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.model_records import ( DuplicateModelException, + ModelRecordOrderBy, ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException, ) from invokeai.app.services.model_records.model_records_base import ModelRecordChanges +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config from invokeai.backend.model_manager.configs.main import ( @@ -364,6 +366,73 @@ def test_filter_2(store: ModelRecordServiceBase): assert len(matches) == 1 +def test_search_by_attr_sorting(store: ModelRecordServiceSQL): + config1 = Main_Diffusers_SD1_Config( + path="/tmp/config1", + name="alpha", + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, + hash="CONFIG1HASH", + file_size=1000, + source="test/source/", + source_type=ModelSourceType.Path, + variant=ModelVariantType.Normal, + prediction_type=SchedulerPredictionType.Epsilon, + repo_variant=ModelRepoVariant.Default, + ) + config2 = Main_Diffusers_SD2_Config( + path="/tmp/config2", + name="beta", + base=BaseModelType.StableDiffusion2, + type=ModelType.Main, + hash="CONFIG2HASH", + file_size=2000, + source="test/source/", + source_type=ModelSourceType.Path, + variant=ModelVariantType.Normal, + prediction_type=SchedulerPredictionType.Epsilon, + repo_variant=ModelRepoVariant.Default, + ) + config3 = VAE_Diffusers_SD1_Config( + path="/tmp/config3", + name="gamma", + base=BaseModelType.StableDiffusion1, + type=ModelType.VAE, + hash="CONFIG3HASH", + file_size=500, + source="test/source/", + source_type=ModelSourceType.Path, + repo_variant=ModelRepoVariant.Default, + ) + for c in config1, config2, config3: + store.add_model(c) + + # Test sorting by Name Ascending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Ascending) + assert len(matches) == 3 + assert matches[0].name == "alpha" + assert matches[1].name == "beta" + assert matches[2].name == "gamma" + + # Test sorting by Name Descending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Descending) + assert matches[0].name == "gamma" + assert matches[1].name == "beta" + assert matches[2].name == "alpha" + + # Test sorting by Size Ascending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Ascending) + assert matches[0].name == "gamma" # 500 + assert matches[1].name == "alpha" # 1000 + assert matches[2].name == "beta" # 2000 + + # Test sorting by Size Descending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Descending) + assert matches[0].name == "beta" # 2000 + assert matches[1].name == "alpha" # 1000 + assert matches[2].name == "gamma" # 500 + + def test_model_record_changes(): # This test guards against some unexpected behaviours from pydantic's union evaluation. See #6035 changes = ModelRecordChanges.model_validate({"default_settings": {"preprocessor": "value"}})