mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Merge branch 'main' into lstein/recall-reference-images
This commit is contained in:
@@ -26,9 +26,11 @@ from invokeai.app.services.model_install.model_install_common import ModelInstal
|
||||
from invokeai.app.services.model_records import (
|
||||
InvalidModelException,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.orphaned_models import OrphanedModelInfo
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
|
||||
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
|
||||
@@ -159,6 +161,8 @@ async def list_model_records(
|
||||
model_format: Optional[ModelFormat] = Query(
|
||||
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
|
||||
),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Name, description="The field to order by"),
|
||||
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
@@ -167,12 +171,23 @@ async def list_model_records(
|
||||
for base_model in base_models:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(
|
||||
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
model_format=model_format,
|
||||
order_by=order_by,
|
||||
direction=direction,
|
||||
)
|
||||
)
|
||||
else:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
record_store.search_by_attr(
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
model_format=model_format,
|
||||
order_by=order_by,
|
||||
direction=direction,
|
||||
)
|
||||
)
|
||||
for index, model in enumerate(found_models):
|
||||
found_models[index] = prepare_model_config_for_response(model, ApiDependencies)
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import List, Optional, Set, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
|
||||
from invokeai.backend.model_manager.configs.external_api import (
|
||||
@@ -60,6 +61,10 @@ class ModelRecordOrderBy(str, Enum):
|
||||
Base = "base"
|
||||
Name = "name"
|
||||
Format = "format"
|
||||
Size = "size"
|
||||
DateAdded = "created_at"
|
||||
DateModified = "updated_at"
|
||||
Path = "path"
|
||||
|
||||
|
||||
class ModelSummary(BaseModel):
|
||||
@@ -200,7 +205,11 @@ class ModelRecordServiceBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
|
||||
direction: SQLiteDirection = SQLiteDirection.Ascending,
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
pass
|
||||
@@ -237,6 +246,8 @@ class ModelRecordServiceBase(ABC):
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
|
||||
direction: SQLiteDirection = SQLiteDirection.Ascending,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
|
||||
@@ -57,6 +57,7 @@ from invokeai.app.services.model_records.model_records_base import (
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
@@ -257,6 +258,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
|
||||
direction: SQLiteDirection = SQLiteDirection.Ascending,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
@@ -266,18 +268,24 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
:param model_type: Filter by type of model (optional)
|
||||
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
||||
:param order_by: Result order
|
||||
:param direction: Result direction
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Base: "base COLLATE NOCASE",
|
||||
ModelRecordOrderBy.Name: "name COLLATE NOCASE",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
|
||||
ModelRecordOrderBy.DateAdded: "created_at",
|
||||
ModelRecordOrderBy.DateModified: "updated_at",
|
||||
ModelRecordOrderBy.Path: "path",
|
||||
}
|
||||
|
||||
where_clause: list[str] = []
|
||||
@@ -301,7 +309,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
SELECT config
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
@@ -357,17 +365,26 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return results
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
|
||||
direction: SQLiteDirection = SQLiteDirection.Ascending,
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
with self._db.transaction() as cursor:
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Base: "base COLLATE NOCASE",
|
||||
ModelRecordOrderBy.Name: "name COLLATE NOCASE",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
|
||||
ModelRecordOrderBy.DateAdded: "created_at",
|
||||
ModelRecordOrderBy.DateModified: "updated_at",
|
||||
ModelRecordOrderBy.Path: "path",
|
||||
}
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
@@ -385,7 +402,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
|
||||
@@ -714,14 +714,25 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
|
||||
- diffusion_model.layers.X.attention.to_k.lora_down.weight (DoRA format)
|
||||
- diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format)
|
||||
- diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale)
|
||||
- lora_unet__layers_X_attention_to_k.lora_down.weight (Kohya format)
|
||||
"""
|
||||
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import (
|
||||
is_state_dict_likely_z_image_kohya_lora,
|
||||
)
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
|
||||
# Check for Z-Image specific LoRA patterns
|
||||
# Check for Kohya format first
|
||||
if is_state_dict_likely_z_image_kohya_lora(state_dict):
|
||||
return
|
||||
|
||||
# Check for Z-Image specific LoRA patterns (dot-notation formats)
|
||||
has_z_image_lora_keys = state_dict_has_any_keys_starting_with(
|
||||
state_dict,
|
||||
{
|
||||
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
|
||||
"diffusion_model.context_refiner.",
|
||||
"diffusion_model.noise_refiner.",
|
||||
"transformer.layers.", # OneTrainer/diffusers prefix variant
|
||||
"base_model.model.transformer.layers.", # PEFT-wrapped variant
|
||||
},
|
||||
@@ -751,15 +762,26 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
|
||||
Z-Image uses S3-DiT architecture with layer names like:
|
||||
- diffusion_model.layers.0.attention.to_k.lora_A.weight
|
||||
- diffusion_model.layers.0.feed_forward.w1.lora_A.weight
|
||||
- lora_unet__layers_0_attention_to_k.lora_down.weight (Kohya format)
|
||||
"""
|
||||
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import (
|
||||
is_state_dict_likely_z_image_kohya_lora,
|
||||
)
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
|
||||
# Check for Z-Image transformer layer patterns
|
||||
# Check for Kohya format
|
||||
if is_state_dict_likely_z_image_kohya_lora(state_dict):
|
||||
return BaseModelType.ZImage
|
||||
|
||||
# Check for Z-Image transformer layer patterns (dot-notation formats)
|
||||
# Z-Image uses diffusion_model.layers.X structure (unlike Flux which uses double_blocks/single_blocks)
|
||||
has_z_image_keys = state_dict_has_any_keys_starting_with(
|
||||
state_dict,
|
||||
{
|
||||
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
|
||||
"diffusion_model.context_refiner.",
|
||||
"diffusion_model.noise_refiner.",
|
||||
"transformer.layers.", # OneTrainer/diffusers prefix variant
|
||||
"base_model.model.transformer.layers.", # PEFT-wrapped variant
|
||||
},
|
||||
|
||||
@@ -160,17 +160,20 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
|
||||
".lora_A.weight",
|
||||
".lora_B.weight",
|
||||
".dora_scale",
|
||||
".alpha",
|
||||
)
|
||||
|
||||
# First pass: check if any key has LoRA suffixes - if so, this is a LoRA not a main model
|
||||
for key in state_dict.keys():
|
||||
if isinstance(key, int):
|
||||
continue
|
||||
|
||||
# If we find any LoRA-specific keys, this is not a main model
|
||||
if key.endswith(lora_suffixes):
|
||||
return False
|
||||
|
||||
# Check for Z-Image specific key prefixes
|
||||
# Second pass: check for Z-Image specific key parts
|
||||
for key in state_dict.keys():
|
||||
if isinstance(key, int):
|
||||
continue
|
||||
# Handle both direct keys (cap_embedder.0.weight) and
|
||||
# ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
|
||||
key_parts = key.split(".")
|
||||
|
||||
@@ -178,12 +178,43 @@ class Flux2VAELoader(ModelLoader):
|
||||
if is_bfl_format:
|
||||
sd = self._convert_flux2_vae_bfl_to_diffusers(sd)
|
||||
|
||||
# FLUX.2 VAE configuration (32 latent channels)
|
||||
# Based on the official FLUX.2 VAE architecture
|
||||
# Use default config - AutoencoderKLFlux2 has built-in defaults
|
||||
# FLUX.2 VAE configuration (32 latent channels).
|
||||
# The standard FLUX.2 VAE uses block_out_channels=(128,256,512,512) for both
|
||||
# encoder and decoder. The "small decoder" variant from
|
||||
# black-forest-labs/FLUX.2-small-decoder keeps the full encoder but uses a
|
||||
# narrower decoder with channels (96,192,384,384). AutoencoderKLFlux2 only
|
||||
# exposes a single block_out_channels, so we build the model with the
|
||||
# encoder's channels and, if the decoder differs, replace just the decoder
|
||||
# submodule with a matching one before loading the state dict.
|
||||
encoder_block_out_channels = (128, 256, 512, 512)
|
||||
decoder_block_out_channels = encoder_block_out_channels
|
||||
if "encoder.conv_in.weight" in sd and "encoder.conv_norm_out.weight" in sd:
|
||||
enc_last = int(sd["encoder.conv_norm_out.weight"].shape[0])
|
||||
enc_first = int(sd["encoder.conv_in.weight"].shape[0])
|
||||
encoder_block_out_channels = (enc_first, enc_first * 2, enc_last, enc_last)
|
||||
if "decoder.conv_in.weight" in sd and "decoder.conv_norm_out.weight" in sd:
|
||||
dec_last = int(sd["decoder.conv_in.weight"].shape[0])
|
||||
dec_first = int(sd["decoder.conv_norm_out.weight"].shape[0])
|
||||
decoder_block_out_channels = (dec_first, dec_first * 2, dec_last, dec_last)
|
||||
|
||||
with SilenceWarnings():
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoencoderKLFlux2()
|
||||
model = AutoencoderKLFlux2(block_out_channels=encoder_block_out_channels)
|
||||
if decoder_block_out_channels != encoder_block_out_channels:
|
||||
# Rebuild the decoder with the smaller channel widths.
|
||||
from diffusers.models.autoencoders.vae import Decoder
|
||||
|
||||
cfg = model.config
|
||||
model.decoder = Decoder(
|
||||
in_channels=cfg.latent_channels,
|
||||
out_channels=cfg.out_channels,
|
||||
up_block_types=cfg.up_block_types,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
layers_per_block=cfg.layers_per_block,
|
||||
norm_num_groups=cfg.norm_num_groups,
|
||||
act_fn=cfg.act_fn,
|
||||
mid_block_add_attention=cfg.mid_block_add_attention,
|
||||
)
|
||||
|
||||
# Convert to bfloat16 and load
|
||||
for k in sd.keys():
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Z-Image LoRA conversion utilities.
|
||||
|
||||
Z-Image uses S3-DiT transformer architecture with Qwen3 text encoder.
|
||||
LoRAs for Z-Image typically follow the diffusers PEFT format.
|
||||
LoRAs for Z-Image typically follow the diffusers PEFT format or Kohya format.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
@@ -16,6 +17,29 @@ from invokeai.backend.patches.lora_conversions.z_image_lora_constants import (
|
||||
)
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
# Regex for Kohya-format Z-Image transformer keys.
|
||||
# Example keys:
|
||||
# lora_unet__layers_0_attention_to_k.alpha
|
||||
# lora_unet__layers_0_attention_to_k.lora_down.weight
|
||||
# lora_unet__context_refiner_0_feed_forward_w1.lora_up.weight
|
||||
# lora_unet__noise_refiner_1_attention_to_v.lora_down.weight
|
||||
Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX = (
|
||||
r"lora_unet__(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)"
|
||||
)
|
||||
|
||||
|
||||
def is_state_dict_likely_z_image_kohya_lora(state_dict: dict[str | int, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely a Z-Image LoRA in Kohya format.
|
||||
|
||||
Kohya Z-Image LoRAs have keys like:
|
||||
- lora_unet__layers_0_attention_to_k.lora_down.weight
|
||||
- lora_unet__context_refiner_0_feed_forward_w1.alpha
|
||||
- lora_unet__noise_refiner_1_attention_to_v.lora_up.weight
|
||||
"""
|
||||
return any(
|
||||
isinstance(k, str) and re.match(Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX, k.split(".")[0]) for k in state_dict.keys()
|
||||
)
|
||||
|
||||
|
||||
def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) -> bool:
|
||||
"""Checks if the provided state dict is likely a Z-Image LoRA.
|
||||
@@ -23,6 +47,9 @@ def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor])
|
||||
Z-Image LoRAs can have keys for transformer and/or Qwen3 text encoder.
|
||||
They may use various prefixes depending on the training framework.
|
||||
"""
|
||||
if is_state_dict_likely_z_image_kohya_lora(state_dict):
|
||||
return True
|
||||
|
||||
str_keys = [k for k in state_dict.keys() if isinstance(k, str)]
|
||||
|
||||
# Check for Z-Image transformer keys (S3-DiT architecture)
|
||||
@@ -57,6 +84,7 @@ def lora_model_from_z_image_state_dict(
|
||||
- "transformer." or "base_model.model.transformer." for diffusers PEFT format
|
||||
- "diffusion_model." for some training frameworks
|
||||
- "text_encoder." or "base_model.model.text_encoder." for Qwen3 encoder
|
||||
- "lora_unet__" for Kohya format (underscores instead of dots)
|
||||
|
||||
Args:
|
||||
state_dict: The LoRA state dict
|
||||
@@ -65,6 +93,10 @@ def lora_model_from_z_image_state_dict(
|
||||
Returns:
|
||||
A ModelPatchRaw containing the LoRA layers
|
||||
"""
|
||||
# If Kohya format, convert keys first then process normally
|
||||
if is_state_dict_likely_z_image_kohya_lora(state_dict):
|
||||
state_dict = _convert_z_image_kohya_state_dict(state_dict)
|
||||
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
|
||||
# Group keys by layer
|
||||
@@ -120,6 +152,45 @@ def lora_model_from_z_image_state_dict(
|
||||
return ModelPatchRaw(layers=layers)
|
||||
|
||||
|
||||
def _convert_z_image_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Converts a Kohya-format Z-Image LoRA state dict to diffusion_model dot-notation.
|
||||
|
||||
Example key conversions:
|
||||
- lora_unet__layers_0_attention_to_k.lora_down.weight -> diffusion_model.layers.0.attention.to_k.lora_down.weight
|
||||
- lora_unet__context_refiner_0_feed_forward_w1.alpha -> diffusion_model.context_refiner.0.feed_forward.w1.alpha
|
||||
- lora_unet__noise_refiner_1_attention_to_v.lora_up.weight -> diffusion_model.noise_refiner.1.attention.to_v.lora_up.weight
|
||||
"""
|
||||
converted: Dict[str, torch.Tensor] = {}
|
||||
for key, value in state_dict.items():
|
||||
if not isinstance(key, str) or not key.startswith("lora_unet__"):
|
||||
converted[key] = value
|
||||
continue
|
||||
|
||||
# Split into layer name and param suffix (e.g. "lora_down.weight", "alpha")
|
||||
layer_name, _, param_suffix = key.partition(".")
|
||||
|
||||
# Strip lora_unet__ prefix
|
||||
remainder = layer_name[len("lora_unet__") :]
|
||||
|
||||
# Convert Kohya underscore format to dot-notation using the known structure
|
||||
match = re.match(
|
||||
r"(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)$",
|
||||
remainder,
|
||||
)
|
||||
if match:
|
||||
block, idx, submodule, param = match.groups()
|
||||
new_layer = f"diffusion_model.{block}.{idx}.{submodule}.{param}"
|
||||
else:
|
||||
# Fallback: keep original key for unrecognized patterns
|
||||
converted[key] = value
|
||||
continue
|
||||
|
||||
new_key = f"{new_layer}.{param_suffix}" if param_suffix else new_layer
|
||||
converted[new_key] = value
|
||||
|
||||
return converted
|
||||
|
||||
|
||||
def _get_lora_layer_values(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]:
|
||||
"""Convert layer dict keys from PEFT format to internal format."""
|
||||
if "lora_A.weight" in layer_dict:
|
||||
|
||||
@@ -40,11 +40,9 @@ def directory_size(directory: Path) -> int:
|
||||
Return the aggregate size of all files in a directory (bytes).
|
||||
"""
|
||||
sum = 0
|
||||
for root, dirs, files in os.walk(directory):
|
||||
for root, _, files in os.walk(directory):
|
||||
for f in files:
|
||||
sum += Path(root, f).stat().st_size
|
||||
for d in dirs:
|
||||
sum += Path(root, d).stat().st_size
|
||||
return sum
|
||||
|
||||
|
||||
|
||||
@@ -1203,6 +1203,15 @@
|
||||
"modelType": "Model Type",
|
||||
"modelUpdated": "Model Updated",
|
||||
"modelUpdateFailed": "Model Update Failed",
|
||||
"sortByName": "Name",
|
||||
"sortByBase": "Base",
|
||||
"sortBySize": "Size",
|
||||
"sortByDateAdded": "Date Added",
|
||||
"sortByDateModified": "Date Modified",
|
||||
"sortByPath": "Path",
|
||||
"sortByType": "Type",
|
||||
"sortByFormat": "Format",
|
||||
"sortDefault": "Default",
|
||||
"name": "Name",
|
||||
"externalProvider": "External Provider",
|
||||
"externalCapabilities": "External Capabilities",
|
||||
@@ -1316,9 +1325,11 @@
|
||||
"zImageQwen3Source": "Qwen3 & VAE Source Model",
|
||||
"zImageQwen3SourcePlaceholder": "Required if VAE/Encoder empty",
|
||||
"flux2KleinVae": "VAE (optional)",
|
||||
"flux2KleinVaePlaceholder": "From main model",
|
||||
"flux2KleinVaePlaceholder": "From diffusers model",
|
||||
"flux2KleinVaeNoModelPlaceholder": "No diffusers model available",
|
||||
"flux2KleinQwen3Encoder": "Qwen3 Encoder (optional)",
|
||||
"flux2KleinQwen3EncoderPlaceholder": "From main model",
|
||||
"flux2KleinQwen3EncoderPlaceholder": "From diffusers model",
|
||||
"flux2KleinQwen3EncoderNoModelPlaceholder": "No diffusers model available",
|
||||
"qwenImageComponentSource": "VAE/Encoder Source (Diffusers)",
|
||||
"qwenImageComponentSourcePlaceholder": "Required for GGUF models",
|
||||
"qwenImageQuantization": "Encoder Quantization",
|
||||
@@ -1623,6 +1634,8 @@
|
||||
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
|
||||
"noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation",
|
||||
"noQwen3EncoderModelSelected": "No Qwen3 Encoder model selected for FLUX2 Klein generation",
|
||||
"noFlux2KleinVaeModelSelected": "No VAE selected. Non-diffusers FLUX.2 Klein models require a standalone VAE",
|
||||
"noFlux2KleinQwen3EncoderModelSelected": "No Qwen3 Encoder selected. Non-diffusers FLUX.2 Klein models require a standalone Qwen3 Encoder",
|
||||
"noQwenImageComponentSourceSelected": "GGUF Qwen Image models require a Diffusers Component Source for VAE/encoder",
|
||||
"noZImageVaeSourceSelected": "No VAE source: Select VAE (FLUX) or Qwen3 Source model",
|
||||
"noZImageQwen3EncoderSourceSelected": "No Qwen3 Encoder source: Select Qwen3 Encoder or Qwen3 Source model",
|
||||
|
||||
@@ -2660,7 +2660,9 @@
|
||||
"fitModeCover": "Copri",
|
||||
"smoothingMode": "Modalità di ricampionamento",
|
||||
"smoothingDesc": "Applica un ricampionamento di alta qualità lato backend alla conferma delle trasformazioni.",
|
||||
"smoothing": "Smussamento"
|
||||
"smoothing": "Smussamento",
|
||||
"smoothingModeBilinear": "Bilineare",
|
||||
"smoothingModeBicubic": "Bicubico"
|
||||
},
|
||||
"stagingArea": {
|
||||
"next": "Successiva",
|
||||
|
||||
@@ -151,11 +151,18 @@ export const useSubMenu = (): UseSubMenuReturn => {
|
||||
};
|
||||
};
|
||||
|
||||
export const SubMenuButtonContent = ({ label }: { label: string }) => {
|
||||
export const SubMenuButtonContent = ({ label, value }: { label: string; value?: string }) => {
|
||||
return (
|
||||
<Flex w="full" h="full" flexDir="row" justifyContent="space-between" alignItems="center">
|
||||
<Text>{label}</Text>
|
||||
<Icon as={PiCaretRightBold} />
|
||||
<Flex alignItems="center" gap={2}>
|
||||
{value !== undefined && (
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{value}
|
||||
</Text>
|
||||
)}
|
||||
<Icon as={PiCaretRightBold} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -35,6 +35,15 @@ type PayloadActionWithId<T = void> = T extends void
|
||||
} & T
|
||||
>;
|
||||
|
||||
/** Fingerprint used to match the same reference image entry after recall when ids are regenerated. */
|
||||
/** Empty configs of the same type may collide; the worst case is selecting an equivalent empty entity. */
|
||||
const getRefImageRecallMatchKey = (entity: RefImageState): string => {
|
||||
const { config } = entity;
|
||||
const imageName = config.image?.original.image.image_name ?? '';
|
||||
const modelKey = 'model' in config && config.model ? config.model.key : '';
|
||||
return `${config.type}\0${modelKey}\0${imageName}`;
|
||||
};
|
||||
|
||||
const slice = createSlice({
|
||||
name: 'refImages',
|
||||
initialState: getInitialRefImagesState(),
|
||||
@@ -54,13 +63,41 @@ const slice = createSlice({
|
||||
},
|
||||
refImagesRecalled: (state, action: PayloadAction<{ entities: RefImageState[]; replace: boolean }>) => {
|
||||
const { entities, replace } = action.payload;
|
||||
if (replace) {
|
||||
state.entities = entities;
|
||||
state.isPanelOpen = false;
|
||||
state.selectedEntityId = null;
|
||||
} else {
|
||||
if (!replace) {
|
||||
state.entities.push(...entities);
|
||||
return;
|
||||
}
|
||||
const wasPanelOpen = state.isPanelOpen;
|
||||
const previousSelectedId = state.selectedEntityId;
|
||||
let previousEntity: RefImageState | null = null;
|
||||
if (previousSelectedId !== null) {
|
||||
previousEntity = state.entities.find((e) => e.id === previousSelectedId) ?? null;
|
||||
}
|
||||
state.entities = entities;
|
||||
if (entities.length === 0) {
|
||||
state.selectedEntityId = null;
|
||||
state.isPanelOpen = false;
|
||||
return;
|
||||
}
|
||||
if (!wasPanelOpen) {
|
||||
state.selectedEntityId = null;
|
||||
return;
|
||||
}
|
||||
const firstEntity = entities[0];
|
||||
assert(firstEntity);
|
||||
if (previousSelectedId === null) {
|
||||
// Open panel must have a selection; otherwise, fall back to the first entity.
|
||||
state.selectedEntityId = firstEntity.id;
|
||||
return;
|
||||
}
|
||||
if (previousSelectedId !== null && entities.some((e) => e.id === previousSelectedId)) {
|
||||
state.selectedEntityId = previousSelectedId;
|
||||
return;
|
||||
}
|
||||
const previousKey = previousEntity ? getRefImageRecallMatchKey(previousEntity) : null;
|
||||
const matched =
|
||||
previousKey !== null ? entities.find((e) => getRefImageRecallMatchKey(e) === previousKey) : undefined;
|
||||
state.selectedEntityId = matched?.id ?? firstEntity.id;
|
||||
},
|
||||
refImageImageChanged: (state, action: PayloadActionWithId<{ croppableImage: CroppableImageWithDims | null }>) => {
|
||||
const { id, croppableImage } = action.payload;
|
||||
|
||||
@@ -31,7 +31,7 @@ export type ModelCategoryData = {
|
||||
filter: (config: AnyModelConfig) => boolean;
|
||||
};
|
||||
|
||||
export const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
|
||||
const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
|
||||
unknown: {
|
||||
category: 'unknown',
|
||||
i18nKey: 'common.unknown',
|
||||
|
||||
@@ -25,6 +25,10 @@ const zModelManagerState = z.object({
|
||||
scanPath: z.string().optional(),
|
||||
shouldInstallInPlace: z.boolean(),
|
||||
selectedModelKeys: z.array(z.string()),
|
||||
orderBy: z
|
||||
.enum(['default', 'name', 'type', 'base', 'size', 'created_at', 'updated_at', 'path', 'format'])
|
||||
.default('name'),
|
||||
sortDirection: z.enum(['asc', 'desc']).default('asc'),
|
||||
});
|
||||
|
||||
type ModelManagerState = z.infer<typeof zModelManagerState>;
|
||||
@@ -38,6 +42,8 @@ const getInitialState = (): ModelManagerState => ({
|
||||
scanPath: undefined,
|
||||
shouldInstallInPlace: true,
|
||||
selectedModelKeys: [],
|
||||
orderBy: 'name',
|
||||
sortDirection: 'asc',
|
||||
});
|
||||
|
||||
const slice = createSlice({
|
||||
@@ -77,6 +83,12 @@ const slice = createSlice({
|
||||
clearModelSelection: (state) => {
|
||||
state.selectedModelKeys = [];
|
||||
},
|
||||
setOrderBy: (state, action: PayloadAction<ModelManagerState['orderBy']>) => {
|
||||
state.orderBy = action.payload;
|
||||
},
|
||||
setSortDirection: (state, action: PayloadAction<ModelManagerState['sortDirection']>) => {
|
||||
state.sortDirection = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -90,6 +102,8 @@ export const {
|
||||
modelSelectionChanged,
|
||||
toggleModelSelection,
|
||||
clearModelSelection,
|
||||
setOrderBy,
|
||||
setSortDirection,
|
||||
} = slice.actions;
|
||||
|
||||
export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
|
||||
@@ -119,3 +133,5 @@ export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm
|
||||
export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType);
|
||||
export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace);
|
||||
export const selectSelectedModelKeys = createModelManagerSelector((mm) => mm.selectedModelKeys);
|
||||
export const selectOrderBy = createModelManagerSelector((mm) => mm.orderBy);
|
||||
export const selectSortDirection = createModelManagerSelector((mm) => mm.sortDirection);
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
|
||||
import type { ModelCategoryData } from 'features/modelManagerV2/models';
|
||||
import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
|
||||
import {
|
||||
selectFilteredModelType,
|
||||
selectOrderBy,
|
||||
selectSortDirection,
|
||||
setFilteredModelType,
|
||||
setOrderBy,
|
||||
setSortDirection,
|
||||
} from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
PiCheckBold,
|
||||
PiFunnelBold,
|
||||
PiListBold,
|
||||
PiSortAscendingBold,
|
||||
PiSortDescendingBold,
|
||||
PiWarningBold,
|
||||
} from 'react-icons/pi';
|
||||
|
||||
type OrderBy = 'default' | 'name' | 'type' | 'base' | 'size' | 'created_at' | 'updated_at' | 'path' | 'format';
|
||||
|
||||
const ORDER_BY_OPTIONS: { key: OrderBy; i18nKey: string }[] = [
|
||||
{ key: 'default', i18nKey: 'modelManager.sortDefault' },
|
||||
{ key: 'name', i18nKey: 'modelManager.sortByName' },
|
||||
{ key: 'base', i18nKey: 'modelManager.sortByBase' },
|
||||
{ key: 'size', i18nKey: 'modelManager.sortBySize' },
|
||||
{ key: 'created_at', i18nKey: 'modelManager.sortByDateAdded' },
|
||||
{ key: 'updated_at', i18nKey: 'modelManager.sortByDateModified' },
|
||||
{ key: 'path', i18nKey: 'modelManager.sortByPath' },
|
||||
{ key: 'type', i18nKey: 'modelManager.sortByType' },
|
||||
{ key: 'format', i18nKey: 'modelManager.sortByFormat' },
|
||||
];
|
||||
|
||||
const SortByMenuItem = memo(({ option, label }: { option: OrderBy; label: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const orderBy = useAppSelector(selectOrderBy);
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(setOrderBy(option));
|
||||
}, [dispatch, option]);
|
||||
|
||||
return (
|
||||
<MenuItem
|
||||
onClick={onClick}
|
||||
bg={orderBy === option ? 'base.700' : 'transparent'}
|
||||
icon={orderBy === option ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
|
||||
>
|
||||
{label}
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
SortByMenuItem.displayName = 'SortByMenuItem';
|
||||
|
||||
const SortBySubMenu = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const subMenu = useSubMenu();
|
||||
const orderBy = useAppSelector(selectOrderBy);
|
||||
|
||||
const currentSortLabel = useMemo(() => {
|
||||
const option = ORDER_BY_OPTIONS.find((o) => o.key === orderBy);
|
||||
if (!option) {
|
||||
return '';
|
||||
}
|
||||
return t(option.i18nKey);
|
||||
}, [orderBy, t]);
|
||||
|
||||
return (
|
||||
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiListBold />}>
|
||||
<Menu {...subMenu.menuProps}>
|
||||
<MenuButton {...subMenu.menuButtonProps}>
|
||||
<SubMenuButtonContent label={t('modelManager.sortBy', 'Sort By')} value={currentSortLabel} />
|
||||
</MenuButton>
|
||||
<MenuList {...subMenu.menuListProps}>
|
||||
{ORDER_BY_OPTIONS.map(({ key, i18nKey }) => (
|
||||
<SortByMenuItem key={key} option={key} label={t(i18nKey)} />
|
||||
))}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
SortBySubMenu.displayName = 'SortBySubMenu';
|
||||
|
||||
const DirectionSubMenu = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const direction = useAppSelector(selectSortDirection);
|
||||
const subMenu = useSubMenu();
|
||||
|
||||
const setDirectionAsc = useCallback(() => {
|
||||
dispatch(setSortDirection('asc'));
|
||||
}, [dispatch]);
|
||||
|
||||
const setDirectionDesc = useCallback(() => {
|
||||
dispatch(setSortDirection('desc'));
|
||||
}, [dispatch]);
|
||||
|
||||
const currentValue = direction === 'asc' ? t('common.ascending', 'Ascending') : t('common.descending', 'Descending');
|
||||
|
||||
return (
|
||||
<MenuItem
|
||||
{...subMenu.parentMenuItemProps}
|
||||
icon={direction === 'asc' ? <PiSortAscendingBold /> : <PiSortDescendingBold />}
|
||||
>
|
||||
<Menu {...subMenu.menuProps}>
|
||||
<MenuButton {...subMenu.menuButtonProps}>
|
||||
<SubMenuButtonContent label={t('common.direction', 'Direction')} value={currentValue} />
|
||||
</MenuButton>
|
||||
<MenuList {...subMenu.menuListProps}>
|
||||
<MenuItem
|
||||
onClick={setDirectionAsc}
|
||||
bg={direction === 'asc' ? 'base.700' : 'transparent'}
|
||||
icon={direction === 'asc' ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
|
||||
>
|
||||
{t('common.ascending', 'Ascending')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClick={setDirectionDesc}
|
||||
bg={direction === 'desc' ? 'base.700' : 'transparent'}
|
||||
icon={direction === 'desc' ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
|
||||
>
|
||||
{t('common.descending', 'Descending')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
DirectionSubMenu.displayName = 'DirectionSubMenu';
|
||||
|
||||
const ModelTypeSubMenu = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const filteredModelType = useAppSelector(selectFilteredModelType);
|
||||
const subMenu = useSubMenu();
|
||||
|
||||
const clearModelType = useCallback(() => {
|
||||
dispatch(setFilteredModelType(null));
|
||||
}, [dispatch]);
|
||||
|
||||
const setMissingFilter = useCallback(() => {
|
||||
dispatch(setFilteredModelType('missing'));
|
||||
}, [dispatch]);
|
||||
|
||||
const currentValue = useMemo(() => {
|
||||
if (filteredModelType === null) {
|
||||
return t('modelManager.allModels');
|
||||
}
|
||||
if (filteredModelType === 'missing') {
|
||||
return t('modelManager.missingFiles');
|
||||
}
|
||||
const categoryData = MODEL_CATEGORIES_AS_LIST.find((data) => data.category === filteredModelType);
|
||||
return categoryData ? t(categoryData.i18nKey) : '';
|
||||
}, [filteredModelType, t]);
|
||||
|
||||
return (
|
||||
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiFunnelBold />}>
|
||||
<Menu {...subMenu.menuProps}>
|
||||
<MenuButton {...subMenu.menuButtonProps}>
|
||||
<SubMenuButtonContent label={t('modelManager.modelType', 'Model Type')} value={currentValue} />
|
||||
</MenuButton>
|
||||
<MenuList {...subMenu.menuListProps}>
|
||||
<MenuItem
|
||||
onClick={clearModelType}
|
||||
bg={filteredModelType === null ? 'base.700' : 'transparent'}
|
||||
icon={filteredModelType === null ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
|
||||
>
|
||||
{t('modelManager.allModels')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClick={setMissingFilter}
|
||||
bg={filteredModelType === 'missing' ? 'base.700' : 'transparent'}
|
||||
color="warning.300"
|
||||
icon={filteredModelType === 'missing' ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
|
||||
>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
{filteredModelType !== 'missing' && <PiWarningBold />}
|
||||
{t('modelManager.missingFiles')}
|
||||
</Flex>
|
||||
</MenuItem>
|
||||
{MODEL_CATEGORIES_AS_LIST.map((data) => (
|
||||
<ModelMenuItem key={data.category} data={data} />
|
||||
))}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
ModelTypeSubMenu.displayName = 'ModelTypeSubMenu';
|
||||
|
||||
const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const filteredModelType = useAppSelector(selectFilteredModelType);
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(setFilteredModelType(data.category));
|
||||
}, [data.category, dispatch]);
|
||||
return (
|
||||
<MenuItem
|
||||
bg={filteredModelType === data.category ? 'base.700' : 'transparent'}
|
||||
onClick={onClick}
|
||||
icon={filteredModelType === data.category ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
|
||||
>
|
||||
{t(data.i18nKey)}
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
ModelMenuItem.displayName = 'ModelMenuItem';
|
||||
|
||||
export const ModelFilterMenu = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Menu placement="bottom-end">
|
||||
<MenuButton as={Button} size="sm" rightIcon={<PiFunnelBold />}>
|
||||
{t('common.filtering', 'Filtering')}
|
||||
</MenuButton>
|
||||
<MenuList>
|
||||
<DirectionSubMenu />
|
||||
<SortBySubMenu />
|
||||
<ModelTypeSubMenu />
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
});
|
||||
|
||||
ModelFilterMenu.displayName = 'ModelFilterMenu';
|
||||
@@ -8,8 +8,10 @@ import {
|
||||
clearModelSelection,
|
||||
type FilterableModelType,
|
||||
selectFilteredModelType,
|
||||
selectOrderBy,
|
||||
selectSearchTerm,
|
||||
selectSelectedModelKeys,
|
||||
selectSortDirection,
|
||||
setSelectedModelKey,
|
||||
} from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
@@ -39,6 +41,8 @@ const ModelList = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const filteredModelType = useAppSelector(selectFilteredModelType);
|
||||
const searchTerm = useAppSelector(selectSearchTerm);
|
||||
const orderBy = useAppSelector(selectOrderBy);
|
||||
const direction = useAppSelector(selectSortDirection);
|
||||
const selectedModelKeys = useAppSelector(selectSelectedModelKeys);
|
||||
const { t } = useTranslation();
|
||||
const toast = useToast();
|
||||
@@ -47,7 +51,8 @@ const ModelList = () => {
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [isReidentifying, setIsReidentifying] = useState(false);
|
||||
|
||||
const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery();
|
||||
const queryArgs = useMemo(() => ({ order_by: orderBy, direction: direction.toUpperCase() }), [orderBy, direction]);
|
||||
const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(queryArgs);
|
||||
const { data: missingModelsData, isLoading: isLoadingMissing } = useGetMissingModelsQuery();
|
||||
const [bulkDeleteModels] = useBulkDeleteModelsMutation();
|
||||
const [bulkReidentifyModels] = useBulkReidentifyModelsMutation();
|
||||
|
||||
@@ -6,8 +6,8 @@ import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import { ModelFilterMenu } from './ModelFilterMenu';
|
||||
import { ModelListBulkActions } from './ModelListBulkActions';
|
||||
import { ModelTypeFilter } from './ModelTypeFilter';
|
||||
|
||||
export const ModelListNavigation = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -50,7 +50,7 @@ export const ModelListNavigation = memo(() => {
|
||||
</InputGroup>
|
||||
</Flex>
|
||||
<Flex shrink={0}>
|
||||
<ModelTypeFilter />
|
||||
<ModelFilterMenu />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<ModelListBulkActions />
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ModelCategoryData } from 'features/modelManagerV2/models';
|
||||
import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
|
||||
import type { ModelCategoryType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFunnelBold, PiWarningBold } from 'react-icons/pi';
|
||||
|
||||
const isModelCategoryType = (type: string): type is ModelCategoryType => {
|
||||
return type in MODEL_CATEGORIES;
|
||||
};
|
||||
|
||||
export const ModelTypeFilter = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const filteredModelType = useAppSelector(selectFilteredModelType);
|
||||
|
||||
const clearModelType = useCallback(() => {
|
||||
dispatch(setFilteredModelType(null));
|
||||
}, [dispatch]);
|
||||
|
||||
const setMissingFilter = useCallback(() => {
|
||||
dispatch(setFilteredModelType('missing'));
|
||||
}, [dispatch]);
|
||||
|
||||
const getButtonLabel = () => {
|
||||
if (filteredModelType === 'missing') {
|
||||
return t('modelManager.missingFiles');
|
||||
}
|
||||
if (filteredModelType && isModelCategoryType(filteredModelType)) {
|
||||
return t(MODEL_CATEGORIES[filteredModelType].i18nKey);
|
||||
}
|
||||
return t('modelManager.allModels');
|
||||
};
|
||||
|
||||
return (
|
||||
<Menu placement="bottom-end">
|
||||
<MenuButton as={Button} size="sm" rightIcon={<PiFunnelBold />}>
|
||||
{getButtonLabel()}
|
||||
</MenuButton>
|
||||
<MenuList>
|
||||
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
|
||||
<MenuItem
|
||||
onClick={setMissingFilter}
|
||||
bg={filteredModelType === 'missing' ? 'base.700' : 'transparent'}
|
||||
color="warning.300"
|
||||
>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<PiWarningBold />
|
||||
{t('modelManager.missingFiles')}
|
||||
</Flex>
|
||||
</MenuItem>
|
||||
{MODEL_CATEGORIES_AS_LIST.map((data) => (
|
||||
<ModelMenuItem key={data.category} data={data} />
|
||||
))}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
});
|
||||
|
||||
ModelTypeFilter.displayName = 'ModelTypeFilter';
|
||||
|
||||
const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const filteredModelType = useAppSelector(selectFilteredModelType);
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(setFilteredModelType(data.category));
|
||||
}, [data.category, dispatch]);
|
||||
return (
|
||||
<MenuItem bg={filteredModelType === data.category ? 'base.700' : 'transparent'} onClick={onClick}>
|
||||
{t(data.i18nKey)}
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
ModelMenuItem.displayName = 'ModelMenuItem';
|
||||
@@ -0,0 +1,370 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
vi.mock('app/logging/logger', () => ({
|
||||
logger: () => ({
|
||||
debug: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
let nextId = 0;
|
||||
vi.mock('features/controlLayers/konva/util', () => ({
|
||||
getPrefixedId: (prefix: string) => `${prefix}:${nextId++}`,
|
||||
}));
|
||||
|
||||
// --- Flux2 Klein model fixtures ---
|
||||
|
||||
const flux2DiffusersModel = {
|
||||
key: 'flux2-klein-diffusers',
|
||||
hash: 'flux2-diff-hash',
|
||||
name: 'FLUX.2 Klein 4B',
|
||||
base: 'flux2',
|
||||
type: 'main',
|
||||
format: 'diffusers',
|
||||
variant: 'klein_4b',
|
||||
};
|
||||
|
||||
const flux2GGUFModel = {
|
||||
key: 'flux2-klein-gguf',
|
||||
hash: 'flux2-gguf-hash',
|
||||
name: 'FLUX.2 Klein 4B GGUF',
|
||||
base: 'flux2',
|
||||
type: 'main',
|
||||
format: 'gguf_quantized',
|
||||
variant: 'klein_4b',
|
||||
};
|
||||
|
||||
const kleinVaeModelFixture = { key: 'klein-vae', name: 'Klein VAE', base: 'flux2', type: 'vae' };
|
||||
const kleinQwen3EncoderModelFixture = {
|
||||
key: 'klein-qwen3',
|
||||
name: 'Qwen3 4B',
|
||||
base: 'flux2',
|
||||
type: 'qwen3_encoder',
|
||||
};
|
||||
|
||||
const flux2GGUF9BModel = {
|
||||
key: 'flux2-klein-gguf-9b',
|
||||
hash: 'flux2-gguf-9b-hash',
|
||||
name: 'FLUX.2 Klein 9B GGUF',
|
||||
base: 'flux2',
|
||||
type: 'main',
|
||||
format: 'gguf_quantized',
|
||||
variant: 'klein_9b',
|
||||
};
|
||||
|
||||
const diffusersSourceModelFixture = {
|
||||
key: 'flux2-source-diffusers',
|
||||
hash: 'flux2-src-hash',
|
||||
name: 'FLUX.2 Klein 4B Source',
|
||||
base: 'flux2',
|
||||
type: 'main',
|
||||
format: 'diffusers',
|
||||
variant: 'klein_4b',
|
||||
};
|
||||
|
||||
const diffusers9BSourceModelFixture = {
|
||||
key: 'flux2-source-diffusers-9b',
|
||||
hash: 'flux2-src-9b-hash',
|
||||
name: 'FLUX.2 Klein 9B Source',
|
||||
base: 'flux2',
|
||||
type: 'main',
|
||||
format: 'diffusers',
|
||||
variant: 'klein_9b',
|
||||
};
|
||||
|
||||
// --- Mutable state ---
|
||||
|
||||
let model: Record<string, unknown> = { ...flux2DiffusersModel };
|
||||
let kleinVaeModel: Record<string, unknown> | null = null;
|
||||
let kleinQwen3EncoderModel: Record<string, unknown> | null = null;
|
||||
let diffusersModels: Record<string, unknown>[] = [];
|
||||
|
||||
vi.mock('features/controlLayers/store/paramsSlice', () => ({
|
||||
selectMainModelConfig: vi.fn(() => model),
|
||||
selectParamsSlice: vi.fn(() => ({
|
||||
guidance: 4,
|
||||
steps: 20,
|
||||
fluxScheduler: 'euler',
|
||||
fluxDypePreset: 'off',
|
||||
fluxDypeScale: 2.0,
|
||||
fluxDypeExponent: 2.0,
|
||||
fluxVAE: null,
|
||||
t5EncoderModel: null,
|
||||
clipEmbedModel: null,
|
||||
})),
|
||||
selectKleinVaeModel: vi.fn(() => kleinVaeModel),
|
||||
selectKleinQwen3EncoderModel: vi.fn(() => kleinQwen3EncoderModel),
|
||||
}));
|
||||
|
||||
vi.mock('features/controlLayers/store/refImagesSlice', () => ({
|
||||
selectRefImagesSlice: vi.fn(() => ({
|
||||
entities: [],
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('features/controlLayers/store/selectors', () => ({
|
||||
selectCanvasMetadata: vi.fn(() => ({})),
|
||||
selectCanvasSlice: vi.fn(() => ({})),
|
||||
}));
|
||||
|
||||
vi.mock('features/controlLayers/store/types', () => ({
|
||||
isFlux2ReferenceImageConfig: vi.fn(() => false),
|
||||
isFluxKontextReferenceImageConfig: vi.fn(() => false),
|
||||
}));
|
||||
|
||||
vi.mock('features/controlLayers/store/validators', () => ({
|
||||
getGlobalReferenceImageWarnings: vi.fn(() => []),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addFlux2KleinLoRAs', () => ({
|
||||
addFlux2KleinLoRAs: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addFLUXFill', () => ({
|
||||
addFLUXFill: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addFLUXLoRAs', () => ({
|
||||
addFLUXLoRAs: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addFLUXRedux', () => ({
|
||||
addFLUXReduxes: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addImageToImage', () => ({
|
||||
addImageToImage: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({
|
||||
addInpaint: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addNSFWChecker', () => ({
|
||||
addNSFWChecker: vi.fn((_g, node) => node),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({
|
||||
addOutpaint: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addRegions', () => ({
|
||||
addRegions: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({
|
||||
addTextToImage: vi.fn(({ l2i }) => l2i),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addWatermarker', () => ({
|
||||
addWatermarker: vi.fn((_g, node) => node),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addControlAdapters', () => ({
|
||||
addControlLoRA: vi.fn(),
|
||||
addControlNets: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/generation/addIPAdapters', () => ({
|
||||
addIPAdapters: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({
|
||||
selectCanvasOutputFields: vi.fn(() => ({})),
|
||||
selectPresetModifiedPrompts: vi.fn(() => ({
|
||||
positive: 'a prompt',
|
||||
negative: '',
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('features/ui/store/uiSelectors', () => ({
|
||||
selectActiveTab: vi.fn(() => 'generation'),
|
||||
}));
|
||||
|
||||
vi.mock('services/api/hooks/modelsByType', () => ({
|
||||
selectFlux2DiffusersModels: vi.fn(() => diffusersModels),
|
||||
}));
|
||||
|
||||
vi.mock('services/api/types', async () => {
|
||||
const actual = await vi.importActual('services/api/types');
|
||||
return {
|
||||
...actual,
|
||||
isNonRefinerMainModelConfig: vi.fn(() => true),
|
||||
};
|
||||
});
|
||||
|
||||
import { buildFLUXGraph } from './buildFLUXGraph';
|
||||
|
||||
const buildGraphArg = () => ({
|
||||
generationMode: 'txt2img' as const,
|
||||
manager: null,
|
||||
state: {
|
||||
system: {
|
||||
shouldUseNSFWChecker: false,
|
||||
shouldUseWatermarker: false,
|
||||
},
|
||||
} as never,
|
||||
});
|
||||
|
||||
/** Find the flux2_klein_model_loader node in the graph. */
|
||||
const getLoaderNode = async () => {
|
||||
const { g } = await buildFLUXGraph(buildGraphArg());
|
||||
const graph = g.getGraph();
|
||||
const loaderEntry = Object.entries(graph.nodes).find(([id]) => id.startsWith('flux2_klein_model_loader:'));
|
||||
return loaderEntry?.[1] as Record<string, unknown> | undefined;
|
||||
};
|
||||
|
||||
describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => {
|
||||
afterEach(() => {
|
||||
nextId = 0;
|
||||
model = { ...flux2DiffusersModel };
|
||||
kleinVaeModel = null;
|
||||
kleinQwen3EncoderModel = null;
|
||||
diffusersModels = [];
|
||||
});
|
||||
|
||||
it('does not set qwen3_source_model when main model is diffusers', async () => {
|
||||
model = { ...flux2DiffusersModel };
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
expect(loader!.qwen3_source_model).toBeUndefined();
|
||||
});
|
||||
|
||||
it('sets qwen3_source_model when main model is GGUF and a diffusers model is available', async () => {
|
||||
model = { ...flux2GGUFModel };
|
||||
diffusersModels = [diffusersSourceModelFixture];
|
||||
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
expect(loader!.qwen3_source_model).toEqual({
|
||||
key: diffusersSourceModelFixture.key,
|
||||
hash: diffusersSourceModelFixture.hash,
|
||||
name: diffusersSourceModelFixture.name,
|
||||
base: diffusersSourceModelFixture.base,
|
||||
type: diffusersSourceModelFixture.type,
|
||||
});
|
||||
});
|
||||
|
||||
it('does not set qwen3_source_model when main model is GGUF but standalone VAE and Qwen3 are both selected', async () => {
|
||||
model = { ...flux2GGUFModel };
|
||||
kleinVaeModel = kleinVaeModelFixture;
|
||||
kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture;
|
||||
diffusersModels = [diffusersSourceModelFixture];
|
||||
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
expect(loader!.qwen3_source_model).toBeUndefined();
|
||||
});
|
||||
|
||||
it('does not set qwen3_source_model when main model is GGUF and no diffusers model is available', async () => {
|
||||
model = { ...flux2GGUFModel };
|
||||
diffusersModels = [];
|
||||
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
expect(loader!.qwen3_source_model).toBeUndefined();
|
||||
});
|
||||
|
||||
it('sets qwen3_source_model when only VAE is selected but Qwen3 is missing', async () => {
|
||||
model = { ...flux2GGUFModel };
|
||||
kleinVaeModel = kleinVaeModelFixture;
|
||||
kleinQwen3EncoderModel = null;
|
||||
diffusersModels = [diffusersSourceModelFixture];
|
||||
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
expect(loader!.qwen3_source_model).toBeDefined();
|
||||
});
|
||||
|
||||
it('sets qwen3_source_model when only Qwen3 is selected but VAE is missing', async () => {
|
||||
model = { ...flux2GGUFModel };
|
||||
kleinVaeModel = null;
|
||||
kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture;
|
||||
diffusersModels = [diffusersSourceModelFixture];
|
||||
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
expect(loader!.qwen3_source_model).toBeDefined();
|
||||
});
|
||||
|
||||
it('passes standalone vae_model and qwen3_encoder_model when selected', async () => {
|
||||
model = { ...flux2DiffusersModel };
|
||||
kleinVaeModel = kleinVaeModelFixture;
|
||||
kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture;
|
||||
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
expect(loader!.vae_model).toEqual(kleinVaeModelFixture);
|
||||
expect(loader!.qwen3_encoder_model).toEqual(kleinQwen3EncoderModelFixture);
|
||||
expect(loader!.qwen3_source_model).toBeUndefined();
|
||||
});
|
||||
|
||||
describe('variant matching', () => {
|
||||
it('selects a variant-matching diffusers model when multiple are available', async () => {
|
||||
model = { ...flux2GGUF9BModel };
|
||||
diffusersModels = [diffusersSourceModelFixture, diffusers9BSourceModelFixture];
|
||||
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
// Should pick the 9B diffusers model, not the 4B
|
||||
expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusers9BSourceModelFixture.key }));
|
||||
});
|
||||
|
||||
it('falls back to any diffusers model for VAE when standalone Qwen3 is selected but no variant match', async () => {
|
||||
model = { ...flux2GGUF9BModel };
|
||||
kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture;
|
||||
// Only 4B diffusers available, no 9B — but Qwen3 is already provided standalone
|
||||
diffusersModels = [diffusersSourceModelFixture];
|
||||
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
// Should use the 4B diffusers model just for VAE extraction
|
||||
expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusersSourceModelFixture.key }));
|
||||
});
|
||||
|
||||
it('does not set qwen3_source_model when GGUF 9B with only 4B diffusers available and no standalone Qwen3', async () => {
|
||||
model = { ...flux2GGUF9BModel };
|
||||
kleinQwen3EncoderModel = null;
|
||||
// Only 4B diffusers available — wrong variant for Qwen3, no standalone Qwen3 selected
|
||||
diffusersModels = [diffusersSourceModelFixture];
|
||||
|
||||
const loader = await getLoaderNode();
|
||||
expect(loader).toBeDefined();
|
||||
// Should NOT use the 4B diffusers since it has the wrong Qwen3 encoder
|
||||
expect(loader!.qwen3_source_model).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('graph structure', () => {
|
||||
it('uses flux2_klein_model_loader for flux2 models', async () => {
|
||||
model = { ...flux2DiffusersModel };
|
||||
const { g } = await buildFLUXGraph(buildGraphArg());
|
||||
const graph = g.getGraph();
|
||||
const nodeIds = Object.keys(graph.nodes);
|
||||
expect(nodeIds.some((id) => id.startsWith('flux2_klein_model_loader:'))).toBe(true);
|
||||
});
|
||||
|
||||
it('uses flux2_vae_decode for flux2 models', async () => {
|
||||
model = { ...flux2DiffusersModel };
|
||||
const { g } = await buildFLUXGraph(buildGraphArg());
|
||||
const graph = g.getGraph();
|
||||
const nodeIds = Object.keys(graph.nodes);
|
||||
expect(nodeIds.some((id) => id.startsWith('flux2_vae_decode:'))).toBe(true);
|
||||
});
|
||||
|
||||
it('uses flux2_klein_text_encoder for flux2 models', async () => {
|
||||
model = { ...flux2DiffusersModel };
|
||||
const { g } = await buildFLUXGraph(buildGraphArg());
|
||||
const graph = g.getGraph();
|
||||
const nodeIds = Object.keys(graph.nodes);
|
||||
expect(nodeIds.some((id) => id.startsWith('flux2_klein_text_encoder:'))).toBe(true);
|
||||
});
|
||||
|
||||
it('uses flux2_denoise for flux2 models', async () => {
|
||||
model = { ...flux2DiffusersModel };
|
||||
const { g } = await buildFLUXGraph(buildGraphArg());
|
||||
const graph = g.getGraph();
|
||||
const nodeTypes = Object.values(graph.nodes).map((n) => n.type);
|
||||
expect(nodeTypes).toContain('flux2_denoise');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -10,7 +10,8 @@ import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlic
|
||||
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { isFlux2ReferenceImageConfig, isFluxKontextReferenceImageConfig } from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import { zImageField } from 'features/nodes/types/common';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { zImageField, zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { addFlux2KleinLoRAs } from 'features/nodes/util/graph/generation/addFlux2KleinLoRAs';
|
||||
import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill';
|
||||
import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
|
||||
@@ -26,8 +27,10 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { t } from 'i18next';
|
||||
import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -141,7 +144,23 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
|
||||
|
||||
if (isFlux2) {
|
||||
// Flux2 Klein: Use Qwen3-based model loader, text encoder, and dedicated denoise node
|
||||
// VAE and Qwen3 encoder can be extracted from the main Diffusers model or selected separately
|
||||
// VAE and Qwen3 encoder can be extracted from the main Diffusers model or selected separately.
|
||||
// For non-diffusers main models, find a diffusers flux2 model to use as the source for VAE/encoder.
|
||||
let qwen3SourceModel: ModelIdentifierField | undefined;
|
||||
if (model.format !== 'diffusers' && (!kleinVaeModel || !kleinQwen3EncoderModel)) {
|
||||
const diffusersModels = selectFlux2DiffusersModels(state);
|
||||
// Prefer a diffusers model that shares the same Qwen3 encoder (e.g. klein_9b and klein_9b_base both use qwen3_8b).
|
||||
// Fall back to any diffusers model if only the VAE is needed.
|
||||
const modelVariant = 'variant' in model ? model.variant : undefined;
|
||||
const variantMatch = diffusersModels.find(
|
||||
(m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant)
|
||||
);
|
||||
const sourceModel = variantMatch ?? (kleinQwen3EncoderModel ? diffusersModels[0] : undefined);
|
||||
if (sourceModel) {
|
||||
qwen3SourceModel = zModelIdentifierField.parse(sourceModel);
|
||||
}
|
||||
}
|
||||
|
||||
modelLoader = g.addNode({
|
||||
type: 'flux2_klein_model_loader',
|
||||
id: getPrefixedId('flux2_klein_model_loader'),
|
||||
@@ -149,6 +168,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
|
||||
// Optional: Use separately selected VAE and Qwen3 encoder models
|
||||
vae_model: kleinVaeModel ?? undefined,
|
||||
qwen3_encoder_model: kleinQwen3EncoderModel ?? undefined,
|
||||
qwen3_source_model: qwen3SourceModel ?? undefined,
|
||||
});
|
||||
|
||||
posCond = g.addNode({
|
||||
|
||||
@@ -9,9 +9,10 @@ import {
|
||||
selectMainModelConfig,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { isFlux2KleinQwen3Compatible, KLEIN_TO_QWEN3_VARIANT_MAP } from 'features/parameters/util/flux2Klein';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFlux2VAEModels, useQwen3EncoderModels } from 'services/api/hooks/modelsByType';
|
||||
import { useFlux2DiffusersModels, useFlux2VAEModels, useQwen3EncoderModels } from 'services/api/hooks/modelsByType';
|
||||
import type { Qwen3EncoderModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
|
||||
/**
|
||||
@@ -22,7 +23,9 @@ const ParamFlux2KleinVaeModelSelect = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const kleinVaeModel = useAppSelector(selectKleinVaeModel);
|
||||
const mainModelConfig = useAppSelector(selectMainModelConfig);
|
||||
const [modelConfigs, { isLoading }] = useFlux2VAEModels();
|
||||
const [diffusersModels] = useFlux2DiffusersModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(model: VAEModelConfig | null) => {
|
||||
@@ -42,6 +45,11 @@ const ParamFlux2KleinVaeModelSelect = memo(() => {
|
||||
isLoading,
|
||||
});
|
||||
|
||||
const hasDiffusersSource = mainModelConfig?.format === 'diffusers' || diffusersModels.length > 0;
|
||||
const placeholder = hasDiffusersSource
|
||||
? t('modelManager.flux2KleinVaePlaceholder')
|
||||
: t('modelManager.flux2KleinVaeNoModelPlaceholder');
|
||||
|
||||
return (
|
||||
<FormControl minW={0} flexGrow={1} gap={2}>
|
||||
<FormLabel m={0}>{t('modelManager.flux2KleinVae')}</FormLabel>
|
||||
@@ -51,7 +59,7 @@ const ParamFlux2KleinVaeModelSelect = memo(() => {
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
isClearable
|
||||
placeholder={t('modelManager.flux2KleinVaePlaceholder')}
|
||||
placeholder={placeholder}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
@@ -59,15 +67,6 @@ const ParamFlux2KleinVaeModelSelect = memo(() => {
|
||||
|
||||
ParamFlux2KleinVaeModelSelect.displayName = 'ParamFlux2KleinVaeModelSelect';
|
||||
|
||||
/**
|
||||
* Maps FLUX.2 Klein variants to compatible Qwen3 encoder variants
|
||||
*/
|
||||
const KLEIN_TO_QWEN3_VARIANT_MAP: Record<string, string> = {
|
||||
klein_4b: 'qwen3_4b',
|
||||
klein_9b: 'qwen3_8b',
|
||||
klein_9b_base: 'qwen3_8b',
|
||||
};
|
||||
|
||||
/**
|
||||
* FLUX.2 Klein Qwen3 Encoder Model Select
|
||||
* Selects a Qwen3 text encoder model for FLUX.2 Klein
|
||||
@@ -79,6 +78,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => {
|
||||
const kleinQwen3EncoderModel = useAppSelector(selectKleinQwen3EncoderModel);
|
||||
const mainModelConfig = useAppSelector(selectMainModelConfig);
|
||||
const [allModelConfigs, { isLoading }] = useQwen3EncoderModels();
|
||||
const [diffusersModels] = useFlux2DiffusersModels();
|
||||
|
||||
// Filter Qwen3 encoders based on the main model's variant
|
||||
const modelConfigs = useMemo(() => {
|
||||
@@ -112,6 +112,20 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => {
|
||||
isLoading,
|
||||
});
|
||||
|
||||
// Qwen3 encoder requires a Qwen3-compatible diffusers model (variants that share the same Qwen3 encoder).
|
||||
const hasMatchingDiffusersSource =
|
||||
mainModelConfig?.format === 'diffusers' ||
|
||||
diffusersModels.some(
|
||||
(m) =>
|
||||
'variant' in m &&
|
||||
mainModelConfig &&
|
||||
'variant' in mainModelConfig &&
|
||||
isFlux2KleinQwen3Compatible(m.variant, mainModelConfig.variant)
|
||||
);
|
||||
const placeholder = hasMatchingDiffusersSource
|
||||
? t('modelManager.flux2KleinQwen3EncoderPlaceholder')
|
||||
: t('modelManager.flux2KleinQwen3EncoderNoModelPlaceholder');
|
||||
|
||||
return (
|
||||
<FormControl minW={0} flexGrow={1} gap={2}>
|
||||
<FormLabel m={0}>{t('modelManager.flux2KleinQwen3Encoder')}</FormLabel>
|
||||
@@ -121,7 +135,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => {
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
isClearable
|
||||
placeholder={t('modelManager.flux2KleinQwen3EncoderPlaceholder')}
|
||||
placeholder={placeholder}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
/**
|
||||
* Maps a FLUX.2 Klein main-model variant to the Qwen3 encoder variant it uses.
|
||||
* Multiple Klein variants can share the same Qwen3 variant (e.g. `klein_9b` and
|
||||
* `klein_9b_base` both use `qwen3_8b`), so two different Klein variants can be
|
||||
* Qwen3-compatible sources for each other.
|
||||
*/
|
||||
export const KLEIN_TO_QWEN3_VARIANT_MAP: Record<string, string> = {
|
||||
klein_4b: 'qwen3_4b',
|
||||
klein_9b: 'qwen3_8b',
|
||||
klein_9b_base: 'qwen3_8b',
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns true if two Klein variants share the same Qwen3 encoder and can therefore
|
||||
* be used as a Qwen3 source for each other.
|
||||
*/
|
||||
export const isFlux2KleinQwen3Compatible = (variantA: unknown, variantB: unknown): boolean => {
|
||||
if (typeof variantA !== 'string' || typeof variantB !== 'string') {
|
||||
return false;
|
||||
}
|
||||
const qwen3A = KLEIN_TO_QWEN3_VARIANT_MAP[variantA];
|
||||
const qwen3B = KLEIN_TO_QWEN3_VARIANT_MAP[variantB];
|
||||
return qwen3A !== undefined && qwen3A === qwen3B;
|
||||
};
|
||||
272
invokeai/frontend/web/src/features/queue/store/readiness.test.ts
Normal file
272
invokeai/frontend/web/src/features/queue/store/readiness.test.ts
Normal file
@@ -0,0 +1,272 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
vi.mock('features/dynamicPrompts/util/getShouldProcessPrompt', () => ({
|
||||
getShouldProcessPrompt: vi.fn(() => false),
|
||||
}));
|
||||
|
||||
vi.mock('i18next', () => ({
|
||||
default: {
|
||||
t: (key: string) => key,
|
||||
},
|
||||
}));
|
||||
|
||||
import type { ParamsState, RefImagesState } from 'features/controlLayers/store/types';
|
||||
import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import { getReasonsWhyCannotEnqueueCanvasTab, getReasonsWhyCannotEnqueueGenerateTab } from './readiness';
|
||||
|
||||
// --- Fixtures ---
|
||||
|
||||
const flux2DiffusersModel = {
|
||||
key: 'flux2-diff',
|
||||
hash: 'h',
|
||||
name: 'FLUX.2 Klein 4B',
|
||||
base: 'flux2',
|
||||
type: 'main',
|
||||
format: 'diffusers',
|
||||
variant: 'klein_4b',
|
||||
} as unknown as MainModelConfig;
|
||||
|
||||
const flux2GGUF4BModel = {
|
||||
key: 'flux2-gguf-4b',
|
||||
hash: 'h',
|
||||
name: 'FLUX.2 Klein 4B GGUF',
|
||||
base: 'flux2',
|
||||
type: 'main',
|
||||
format: 'gguf_quantized',
|
||||
variant: 'klein_4b',
|
||||
} as unknown as MainModelConfig;
|
||||
|
||||
const flux2GGUF9BModel = {
|
||||
key: 'flux2-gguf-9b',
|
||||
hash: 'h',
|
||||
name: 'FLUX.2 Klein 9B GGUF',
|
||||
base: 'flux2',
|
||||
type: 'main',
|
||||
format: 'gguf_quantized',
|
||||
variant: 'klein_9b',
|
||||
} as unknown as MainModelConfig;
|
||||
|
||||
const kleinVaeModel = { key: 'vae', name: 'VAE', base: 'flux2', type: 'vae' };
|
||||
const kleinQwen3Model = { key: 'qwen3', name: 'Qwen3', base: 'flux2', type: 'qwen3_encoder' };
|
||||
|
||||
const baseDynamicPrompts: DynamicPromptsState = {
|
||||
_version: 1,
|
||||
maxPrompts: 100,
|
||||
combinatorial: false,
|
||||
prompts: ['test prompt'],
|
||||
parsingError: undefined,
|
||||
isError: false,
|
||||
isLoading: false,
|
||||
seedBehaviour: 'PER_PROMPT',
|
||||
};
|
||||
|
||||
const baseRefImages: RefImagesState = {
|
||||
entities: [],
|
||||
ipAdapters: { entities: [], ids: [] },
|
||||
} as unknown as RefImagesState;
|
||||
|
||||
const baseParams = {
|
||||
positivePrompt: 'test',
|
||||
kleinVaeModel: null,
|
||||
kleinQwen3EncoderModel: null,
|
||||
} as unknown as ParamsState;
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
const buildGenerateTabArg = (overrides: {
|
||||
model?: MainModelConfig | null;
|
||||
kleinVaeModel?: unknown;
|
||||
kleinQwen3EncoderModel?: unknown;
|
||||
hasFlux2DiffusersVaeSource?: boolean;
|
||||
hasFlux2DiffusersQwen3Source?: boolean;
|
||||
}) => ({
|
||||
isConnected: true,
|
||||
model: overrides.model ?? flux2DiffusersModel,
|
||||
params: {
|
||||
...baseParams,
|
||||
kleinVaeModel: overrides.kleinVaeModel ?? null,
|
||||
kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null,
|
||||
} as unknown as ParamsState,
|
||||
refImages: baseRefImages,
|
||||
loras: [],
|
||||
dynamicPrompts: baseDynamicPrompts,
|
||||
hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false,
|
||||
hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false,
|
||||
});
|
||||
|
||||
const buildCanvasTabArg = (overrides: {
|
||||
model?: MainModelConfig | null;
|
||||
kleinVaeModel?: unknown;
|
||||
kleinQwen3EncoderModel?: unknown;
|
||||
hasFlux2DiffusersVaeSource?: boolean;
|
||||
hasFlux2DiffusersQwen3Source?: boolean;
|
||||
}) => ({
|
||||
isConnected: true,
|
||||
model: overrides.model ?? flux2DiffusersModel,
|
||||
canvas: {
|
||||
bbox: {
|
||||
scaleMethod: 'none',
|
||||
rect: { width: 1024, height: 1024 },
|
||||
scaledSize: { width: 1024, height: 1024 },
|
||||
},
|
||||
controlLayers: { entities: [] },
|
||||
regionalGuidance: { entities: [] },
|
||||
rasterLayers: { entities: [] },
|
||||
inpaintMasks: { entities: [] },
|
||||
},
|
||||
params: {
|
||||
...baseParams,
|
||||
kleinVaeModel: overrides.kleinVaeModel ?? null,
|
||||
kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null,
|
||||
} as unknown as ParamsState,
|
||||
refImages: baseRefImages,
|
||||
loras: [],
|
||||
dynamicPrompts: baseDynamicPrompts,
|
||||
canvasIsFiltering: false,
|
||||
canvasIsTransforming: false,
|
||||
canvasIsRasterizing: false,
|
||||
canvasIsCompositing: false,
|
||||
canvasIsSelectingObject: false,
|
||||
hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false,
|
||||
hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false,
|
||||
});
|
||||
|
||||
const hasFlux2VaeReason = (reasons: { content: string }[]) =>
|
||||
reasons.some((r) => r.content.includes('noFlux2KleinVaeModelSelected'));
|
||||
|
||||
const hasFlux2Qwen3Reason = (reasons: { content: string }[]) =>
|
||||
reasons.some((r) => r.content.includes('noFlux2KleinQwen3EncoderModelSelected'));
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
describe('FLUX.2 Klein readiness checks – generate tab', () => {
|
||||
it('no errors when main model is diffusers (VAE/Qwen3 extracted from it)', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2DiffusersModel }));
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(false);
|
||||
});
|
||||
|
||||
it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueGenerateTab(
|
||||
buildGenerateTabArg({
|
||||
model: flux2GGUF4BModel,
|
||||
hasFlux2DiffusersVaeSource: true,
|
||||
hasFlux2DiffusersQwen3Source: true,
|
||||
})
|
||||
);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(false);
|
||||
});
|
||||
|
||||
it('errors for both VAE and Qwen3 when GGUF model with no diffusers source and no standalone models', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2GGUF4BModel }));
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(true);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(true);
|
||||
});
|
||||
|
||||
it('errors only for Qwen3 when GGUF model with standalone VAE but no Qwen3 and no diffusers source', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueGenerateTab(
|
||||
buildGenerateTabArg({ model: flux2GGUF4BModel, kleinVaeModel: kleinVaeModel })
|
||||
);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(true);
|
||||
});
|
||||
|
||||
it('errors only for VAE when GGUF model with standalone Qwen3 but no VAE and no diffusers source', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueGenerateTab(
|
||||
buildGenerateTabArg({ model: flux2GGUF4BModel, kleinQwen3EncoderModel: kleinQwen3Model })
|
||||
);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(true);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(false);
|
||||
});
|
||||
|
||||
it('no errors when GGUF model with both standalone VAE and Qwen3', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueGenerateTab(
|
||||
buildGenerateTabArg({
|
||||
model: flux2GGUF4BModel,
|
||||
kleinVaeModel: kleinVaeModel,
|
||||
kleinQwen3EncoderModel: kleinQwen3Model,
|
||||
})
|
||||
);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(false);
|
||||
});
|
||||
|
||||
it('VAE ok but Qwen3 errors when GGUF 9B model with only a 4B diffusers source (variant mismatch)', () => {
|
||||
// User has Klein 9B GGUF selected, only a 4B diffusers model installed.
|
||||
// VAE is shared across variants so it's ok. Qwen3 encoder differs, so it's not ok.
|
||||
const reasons = getReasonsWhyCannotEnqueueGenerateTab(
|
||||
buildGenerateTabArg({
|
||||
model: flux2GGUF9BModel,
|
||||
hasFlux2DiffusersVaeSource: true,
|
||||
hasFlux2DiffusersQwen3Source: false,
|
||||
})
|
||||
);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(true);
|
||||
});
|
||||
|
||||
it('no errors when GGUF 9B model with variant-matching diffusers source', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueGenerateTab(
|
||||
buildGenerateTabArg({
|
||||
model: flux2GGUF9BModel,
|
||||
hasFlux2DiffusersVaeSource: true,
|
||||
hasFlux2DiffusersQwen3Source: true,
|
||||
})
|
||||
);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('FLUX.2 Klein readiness checks – canvas tab', () => {
|
||||
it('no errors when main model is diffusers', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2DiffusersModel }) as never);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(false);
|
||||
});
|
||||
|
||||
it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueCanvasTab(
|
||||
buildCanvasTabArg({
|
||||
model: flux2GGUF4BModel,
|
||||
hasFlux2DiffusersVaeSource: true,
|
||||
hasFlux2DiffusersQwen3Source: true,
|
||||
}) as never
|
||||
);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(false);
|
||||
});
|
||||
|
||||
it('errors for both VAE and Qwen3 when GGUF model with no sources', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2GGUF4BModel }) as never);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(true);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(true);
|
||||
});
|
||||
|
||||
it('no errors when GGUF model with both standalone VAE and Qwen3', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueCanvasTab(
|
||||
buildCanvasTabArg({
|
||||
model: flux2GGUF4BModel,
|
||||
kleinVaeModel: kleinVaeModel,
|
||||
kleinQwen3EncoderModel: kleinQwen3Model,
|
||||
}) as never
|
||||
);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(false);
|
||||
});
|
||||
|
||||
it('VAE ok but Qwen3 errors when GGUF 9B with variant-mismatched diffusers source', () => {
|
||||
const reasons = getReasonsWhyCannotEnqueueCanvasTab(
|
||||
buildCanvasTabArg({
|
||||
model: flux2GGUF9BModel,
|
||||
hasFlux2DiffusersVaeSource: true,
|
||||
hasFlux2DiffusersQwen3Source: false,
|
||||
}) as never
|
||||
);
|
||||
expect(hasFlux2VaeReason(reasons)).toBe(false);
|
||||
expect(hasFlux2Qwen3Reason(reasons)).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -33,12 +33,14 @@ import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/
|
||||
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
|
||||
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein';
|
||||
import { getGridSize } from 'features/parameters/util/optimalDimension';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { TabName } from 'features/ui/store/uiTypes';
|
||||
import i18n from 'i18next';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import { useEffect } from 'react';
|
||||
import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainOrExternalModelConfig } from 'services/api/types';
|
||||
import { isExternalApiModelConfig } from 'services/api/types';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
@@ -109,6 +111,12 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => {
|
||||
} = arg;
|
||||
if (tab === 'generate') {
|
||||
const model = selectMainModelConfig(store.getState());
|
||||
const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState());
|
||||
const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0;
|
||||
const modelVariant = model && 'variant' in model ? model.variant : undefined;
|
||||
const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some(
|
||||
(m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant)
|
||||
);
|
||||
const reasons = await getReasonsWhyCannotEnqueueGenerateTab({
|
||||
isConnected,
|
||||
model,
|
||||
@@ -116,10 +124,18 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => {
|
||||
refImages,
|
||||
dynamicPrompts,
|
||||
loras,
|
||||
hasFlux2DiffusersVaeSource,
|
||||
hasFlux2DiffusersQwen3Source,
|
||||
});
|
||||
$reasonsWhyCannotEnqueue.set(reasons);
|
||||
} else if (tab === 'canvas') {
|
||||
const model = selectMainModelConfig(store.getState());
|
||||
const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState());
|
||||
const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0;
|
||||
const modelVariant = model && 'variant' in model ? model.variant : undefined;
|
||||
const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some(
|
||||
(m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant)
|
||||
);
|
||||
const reasons = await getReasonsWhyCannotEnqueueCanvasTab({
|
||||
isConnected,
|
||||
model,
|
||||
@@ -133,6 +149,8 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => {
|
||||
canvasIsCompositing,
|
||||
canvasIsSelectingObject,
|
||||
loras,
|
||||
hasFlux2DiffusersVaeSource,
|
||||
hasFlux2DiffusersQwen3Source,
|
||||
});
|
||||
$reasonsWhyCannotEnqueue.set(reasons);
|
||||
} else if (tab === 'workflows') {
|
||||
@@ -220,15 +238,26 @@ export const useReadinessWatcher = () => {
|
||||
|
||||
const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') });
|
||||
|
||||
const getReasonsWhyCannotEnqueueGenerateTab = (arg: {
|
||||
export const getReasonsWhyCannotEnqueueGenerateTab = (arg: {
|
||||
isConnected: boolean;
|
||||
model: MainOrExternalModelConfig | null | undefined;
|
||||
params: ParamsState;
|
||||
refImages: RefImagesState;
|
||||
loras: LoRA[];
|
||||
dynamicPrompts: DynamicPromptsState;
|
||||
hasFlux2DiffusersVaeSource: boolean;
|
||||
hasFlux2DiffusersQwen3Source: boolean;
|
||||
}) => {
|
||||
const { isConnected, model, params, refImages, loras, dynamicPrompts } = arg;
|
||||
const {
|
||||
isConnected,
|
||||
model,
|
||||
params,
|
||||
refImages,
|
||||
loras,
|
||||
dynamicPrompts,
|
||||
hasFlux2DiffusersVaeSource,
|
||||
hasFlux2DiffusersQwen3Source,
|
||||
} = arg;
|
||||
const { positivePrompt } = params;
|
||||
const reasons: Reason[] = [];
|
||||
|
||||
@@ -260,7 +289,17 @@ const getReasonsWhyCannotEnqueueGenerateTab = (arg: {
|
||||
}
|
||||
}
|
||||
|
||||
// FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed
|
||||
if (model?.base === 'flux2' && model.format !== 'diffusers') {
|
||||
// Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder
|
||||
// unless a diffusers flux2 model is available to extract them from.
|
||||
// VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model.
|
||||
if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') });
|
||||
}
|
||||
if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') });
|
||||
}
|
||||
}
|
||||
|
||||
if (model?.base === 'qwen-image' && model.format === 'gguf_quantized') {
|
||||
if (!params.qwenImageComponentSource) {
|
||||
@@ -452,7 +491,7 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: {
|
||||
return reasons;
|
||||
};
|
||||
|
||||
const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
export const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
isConnected: boolean;
|
||||
model: MainOrExternalModelConfig | null | undefined;
|
||||
canvas: CanvasState;
|
||||
@@ -465,6 +504,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
canvasIsRasterizing: boolean;
|
||||
canvasIsCompositing: boolean;
|
||||
canvasIsSelectingObject: boolean;
|
||||
hasFlux2DiffusersVaeSource: boolean;
|
||||
hasFlux2DiffusersQwen3Source: boolean;
|
||||
}) => {
|
||||
const {
|
||||
isConnected,
|
||||
@@ -479,6 +520,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
canvasIsRasterizing,
|
||||
canvasIsCompositing,
|
||||
canvasIsSelectingObject,
|
||||
hasFlux2DiffusersVaeSource,
|
||||
hasFlux2DiffusersQwen3Source,
|
||||
} = arg;
|
||||
const { positivePrompt } = params;
|
||||
const reasons: Reason[] = [];
|
||||
@@ -571,7 +614,17 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
|
||||
}
|
||||
|
||||
if (model?.base === 'flux2') {
|
||||
// FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed
|
||||
// Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder
|
||||
// unless a diffusers flux2 model is available to extract them from.
|
||||
// VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model.
|
||||
if (model.format !== 'diffusers') {
|
||||
if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') });
|
||||
}
|
||||
if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') });
|
||||
}
|
||||
}
|
||||
|
||||
const { bbox } = canvas;
|
||||
const gridSize = getGridSize('flux'); // FLUX.2 uses same grid size as FLUX.1
|
||||
|
||||
@@ -111,9 +111,13 @@ type DeleteOrphanedModelsResponse = {
|
||||
errors: Record<string, string>;
|
||||
};
|
||||
|
||||
type GetModelConfigsArg = {
|
||||
order_by?: string;
|
||||
direction?: string;
|
||||
} | void;
|
||||
|
||||
const modelConfigsAdapter = createEntityAdapter<AnyModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
|
||||
@@ -338,8 +342,11 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['ModelInstalls'],
|
||||
}),
|
||||
getModelConfigs: build.query<EntityState<AnyModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl() }),
|
||||
getModelConfigs: build.query<EntityState<AnyModelConfig, string>, GetModelConfigsArg>({
|
||||
query: (arg) => {
|
||||
const queryStr = arg ? `?${queryString.stringify(arg)}` : '';
|
||||
return { url: buildModelsUrl(queryStr) };
|
||||
},
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }];
|
||||
if (result) {
|
||||
@@ -498,5 +505,5 @@ export const {
|
||||
useDeleteOrphanedModelsMutation,
|
||||
} = modelsApi;
|
||||
|
||||
export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select();
|
||||
export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(undefined);
|
||||
export const selectMissingModelsQuery = modelsApi.endpoints.getMissingModels.select();
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
isControlNetModelConfig,
|
||||
isExternalApiModelConfig,
|
||||
isFlux1VAEModelConfig,
|
||||
isFlux2DiffusersMainModelConfig,
|
||||
isFlux2VAEModelConfig,
|
||||
isFluxKontextModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
@@ -101,6 +102,7 @@ export const useFlux2VAEModels = () => buildModelsHook(isFlux2VAEModelConfig)();
|
||||
export const useAnimaVAEModels = () => buildModelsHook(isAnimaVAEModelConfig)();
|
||||
export const useAnimaQwen3EncoderModels = () => buildModelsHook(isAnimaQwen3EncoderModelConfig)();
|
||||
export const useZImageDiffusersModels = () => buildModelsHook(isZImageDiffusersMainModelConfig)();
|
||||
export const useFlux2DiffusersModels = () => buildModelsHook(isFlux2DiffusersMainModelConfig)();
|
||||
export const useQwenImageDiffusersModels = () => buildModelsHook(isQwenImageDiffusersMainModelConfig)();
|
||||
export const useQwen3EncoderModels = () => buildModelsHook(isQwen3EncoderModelConfig)();
|
||||
export const useGlobalReferenceImageModels = buildModelsHook(
|
||||
@@ -140,6 +142,7 @@ export const selectAnimaQwen3EncoderModels = buildModelsSelector(isAnimaQwen3Enc
|
||||
export const selectQwen3EncoderModels = buildModelsSelector(isQwen3EncoderModelConfig);
|
||||
export const selectQwenImageDiffusersModels = buildModelsSelector(isQwenImageDiffusersMainModelConfig);
|
||||
export const selectZImageDiffusersModels = buildModelsSelector(isZImageDiffusersMainModelConfig);
|
||||
export const selectFlux2DiffusersModels = buildModelsSelector(isFlux2DiffusersMainModelConfig);
|
||||
export const selectFluxVAEModels = buildModelsSelector(isFluxVAEModelConfig);
|
||||
export const selectAnimaVAEModels = buildModelsSelector(isAnimaVAEModelConfig);
|
||||
export const selectT5EncoderModels = buildModelsSelector(isT5EncoderModelConfigOrSubmodel);
|
||||
|
||||
@@ -22999,6 +22999,12 @@ export type components = {
|
||||
*/
|
||||
config_path?: string | null;
|
||||
};
|
||||
/**
|
||||
* ModelRecordOrderBy
|
||||
* @description The order in which to return model summaries.
|
||||
* @enum {string}
|
||||
*/
|
||||
ModelRecordOrderBy: "default" | "type" | "base" | "name" | "format" | "size" | "created_at" | "updated_at" | "path";
|
||||
/** ModelRelationshipBatchRequest */
|
||||
ModelRelationshipBatchRequest: {
|
||||
/**
|
||||
@@ -31550,6 +31556,10 @@ export interface operations {
|
||||
model_name?: string | null;
|
||||
/** @description Exact match on the format of the model (e.g. 'diffusers') */
|
||||
model_format?: components["schemas"]["ModelFormat"] | null;
|
||||
/** @description The field to order by */
|
||||
order_by?: components["schemas"]["ModelRecordOrderBy"];
|
||||
/** @description The direction to order by */
|
||||
direction?: components["schemas"]["SQLiteDirection"];
|
||||
};
|
||||
header?: never;
|
||||
path?: never;
|
||||
|
||||
@@ -457,6 +457,10 @@ export const isZImageDiffusersMainModelConfig = (config: AnyModelConfig): config
|
||||
return config.type === 'main' && config.base === 'z-image' && config.format === 'diffusers';
|
||||
};
|
||||
|
||||
export const isFlux2DiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base === 'flux2' && config.format === 'diffusers';
|
||||
};
|
||||
|
||||
export const isQwenImageDiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base === 'qwen-image' && config.format === 'diffusers';
|
||||
};
|
||||
|
||||
@@ -11,11 +11,13 @@ from pydantic import ValidationError
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelRecordServiceSQL,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
|
||||
from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config
|
||||
from invokeai.backend.model_manager.configs.main import (
|
||||
@@ -364,6 +366,73 @@ def test_filter_2(store: ModelRecordServiceBase):
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_search_by_attr_sorting(store: ModelRecordServiceSQL):
|
||||
config1 = Main_Diffusers_SD1_Config(
|
||||
path="/tmp/config1",
|
||||
name="alpha",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
hash="CONFIG1HASH",
|
||||
file_size=1000,
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
variant=ModelVariantType.Normal,
|
||||
prediction_type=SchedulerPredictionType.Epsilon,
|
||||
repo_variant=ModelRepoVariant.Default,
|
||||
)
|
||||
config2 = Main_Diffusers_SD2_Config(
|
||||
path="/tmp/config2",
|
||||
name="beta",
|
||||
base=BaseModelType.StableDiffusion2,
|
||||
type=ModelType.Main,
|
||||
hash="CONFIG2HASH",
|
||||
file_size=2000,
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
variant=ModelVariantType.Normal,
|
||||
prediction_type=SchedulerPredictionType.Epsilon,
|
||||
repo_variant=ModelRepoVariant.Default,
|
||||
)
|
||||
config3 = VAE_Diffusers_SD1_Config(
|
||||
path="/tmp/config3",
|
||||
name="gamma",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.VAE,
|
||||
hash="CONFIG3HASH",
|
||||
file_size=500,
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
repo_variant=ModelRepoVariant.Default,
|
||||
)
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(c)
|
||||
|
||||
# Test sorting by Name Ascending
|
||||
matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Ascending)
|
||||
assert len(matches) == 3
|
||||
assert matches[0].name == "alpha"
|
||||
assert matches[1].name == "beta"
|
||||
assert matches[2].name == "gamma"
|
||||
|
||||
# Test sorting by Name Descending
|
||||
matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Descending)
|
||||
assert matches[0].name == "gamma"
|
||||
assert matches[1].name == "beta"
|
||||
assert matches[2].name == "alpha"
|
||||
|
||||
# Test sorting by Size Ascending
|
||||
matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Ascending)
|
||||
assert matches[0].name == "gamma" # 500
|
||||
assert matches[1].name == "alpha" # 1000
|
||||
assert matches[2].name == "beta" # 2000
|
||||
|
||||
# Test sorting by Size Descending
|
||||
matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Descending)
|
||||
assert matches[0].name == "beta" # 2000
|
||||
assert matches[1].name == "alpha" # 1000
|
||||
assert matches[2].name == "gamma" # 500
|
||||
|
||||
|
||||
def test_model_record_changes():
|
||||
# This test guards against some unexpected behaviours from pydantic's union evaluation. See #6035
|
||||
changes = ModelRecordChanges.model_validate({"default_settings": {"preprocessor": "value"}})
|
||||
|
||||
Reference in New Issue
Block a user