diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index f351be11ad..40d4f48b63 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -26,9 +26,11 @@ from invokeai.app.services.model_install.model_install_common import ModelInstal from invokeai.app.services.model_records import ( InvalidModelException, ModelRecordChanges, + ModelRecordOrderBy, UnknownModelException, ) from invokeai.app.services.orphaned_models import OrphanedModelInfo +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.suppress_output import SuppressOutput from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory @@ -159,6 +161,8 @@ async def list_model_records( model_format: Optional[ModelFormat] = Query( default=None, description="Exact match on the format of the model (e.g. 'diffusers')" ), + order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Name, description="The field to order by"), + direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"), ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_manager.store @@ -167,12 +171,23 @@ async def list_model_records( for base_model in base_models: found_models.extend( record_store.search_by_attr( - base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format + base_model=base_model, + model_type=model_type, + model_name=model_name, + model_format=model_format, + order_by=order_by, + direction=direction, ) ) else: found_models.extend( - record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) + record_store.search_by_attr( + model_type=model_type, + model_name=model_name, + model_format=model_format, + order_by=order_by, + direction=direction, + ) ) for index, model in enumerate(found_models): found_models[index] = prepare_model_config_for_response(model, ApiDependencies) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 6420949c29..31fbadb3cb 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -11,6 +11,7 @@ from typing import List, Optional, Set, Union from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.external_api import ( @@ -60,6 +61,10 @@ class ModelRecordOrderBy(str, Enum): Base = "base" Name = "name" Format = "format" + Size = "size" + DateAdded = "created_at" + DateModified = "updated_at" + Path = "path" class ModelSummary(BaseModel): @@ -200,7 +205,11 @@ class ModelRecordServiceBase(ABC): @abstractmethod def list_models( - self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default + self, + page: int = 0, + per_page: int = 10, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> 
PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" pass @@ -237,6 +246,8 @@ class ModelRecordServiceBase(ABC): base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, model_format: Optional[ModelFormat] = None, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index edcbba2acd..f104c3855e 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -57,6 +57,7 @@ from invokeai.app.services.model_records.model_records_base import ( UnknownModelException, ) from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType @@ -257,6 +258,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): model_type: Optional[ModelType] = None, model_format: Optional[ModelFormat] = None, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -266,18 +268,24 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): :param model_type: Filter by type of model (optional) :param model_format: Filter by model format (e.g. "diffusers") (optional) :param order_by: Result order + :param direction: Result direction If none of the optional filters are passed, will return all models in the database. """ with self._db.transaction() as cursor: assert isinstance(order_by, ModelRecordOrderBy) + order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC" ordering = { - ModelRecordOrderBy.Default: "type, base, name, format", + ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format", ModelRecordOrderBy.Type: "type", - ModelRecordOrderBy.Base: "base", - ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Base: "base COLLATE NOCASE", + ModelRecordOrderBy.Name: "name COLLATE NOCASE", ModelRecordOrderBy.Format: "format", + ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)", + ModelRecordOrderBy.DateAdded: "created_at", + ModelRecordOrderBy.DateModified: "updated_at", + ModelRecordOrderBy.Path: "path", } where_clause: list[str] = [] @@ -301,7 +309,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): SELECT config FROM models {where} - ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason; + ORDER BY {ordering[order_by]} {order_dir} -- using ? 
to bind doesn't work here for some reason; """, tuple(bindings), ) @@ -357,17 +365,26 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): return results def list_models( - self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default + self, + page: int = 0, + per_page: int = 10, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" with self._db.transaction() as cursor: assert isinstance(order_by, ModelRecordOrderBy) + order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC" ordering = { - ModelRecordOrderBy.Default: "type, base, name, format", + ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format", ModelRecordOrderBy.Type: "type", - ModelRecordOrderBy.Base: "base", - ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Base: "base COLLATE NOCASE", + ModelRecordOrderBy.Name: "name COLLATE NOCASE", ModelRecordOrderBy.Format: "format", + ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)", + ModelRecordOrderBy.DateAdded: "created_at", + ModelRecordOrderBy.DateModified: "updated_at", + ModelRecordOrderBy.Path: "path", } # Lock so that the database isn't updated while we're doing the two queries. @@ -385,7 +402,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): f"""--sql SELECT config FROM models - ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason + ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason LIMIT ? OFFSET ?; """, diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index 88f917d0d3..46606a3c0d 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -714,14 +714,25 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): - diffusion_model.layers.X.attention.to_k.lora_down.weight (DoRA format) - diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format) - diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale) + - lora_unet__layers_X_attention_to_k.lora_down.weight (Kohya format) """ + from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import ( + is_state_dict_likely_z_image_kohya_lora, + ) + state_dict = mod.load_state_dict() - # Check for Z-Image specific LoRA patterns + # Check for Kohya format first + if is_state_dict_likely_z_image_kohya_lora(state_dict): + return + + # Check for Z-Image specific LoRA patterns (dot-notation formats) has_z_image_lora_keys = state_dict_has_any_keys_starting_with( state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "diffusion_model.context_refiner.", + "diffusion_model.noise_refiner.", "transformer.layers.", # OneTrainer/diffusers prefix variant "base_model.model.transformer.layers.", # PEFT-wrapped variant }, @@ -751,15 +762,26 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): Z-Image uses S3-DiT architecture with layer names like: - diffusion_model.layers.0.attention.to_k.lora_A.weight - diffusion_model.layers.0.feed_forward.w1.lora_A.weight + - lora_unet__layers_0_attention_to_k.lora_down.weight (Kohya format) """ + from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import ( + 
is_state_dict_likely_z_image_kohya_lora, + ) + state_dict = mod.load_state_dict() - # Check for Z-Image transformer layer patterns + # Check for Kohya format + if is_state_dict_likely_z_image_kohya_lora(state_dict): + return BaseModelType.ZImage + + # Check for Z-Image transformer layer patterns (dot-notation formats) # Z-Image uses diffusion_model.layers.X structure (unlike Flux which uses double_blocks/single_blocks) has_z_image_keys = state_dict_has_any_keys_starting_with( state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "diffusion_model.context_refiner.", + "diffusion_model.noise_refiner.", "transformer.layers.", # OneTrainer/diffusers prefix variant "base_model.model.transformer.layers.", # PEFT-wrapped variant }, diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index 1be349f394..a2f008f41e 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -160,17 +160,20 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool: ".lora_A.weight", ".lora_B.weight", ".dora_scale", + ".alpha", ) + # First pass: check if any key has LoRA suffixes - if so, this is a LoRA not a main model for key in state_dict.keys(): if isinstance(key, int): continue - - # If we find any LoRA-specific keys, this is not a main model if key.endswith(lora_suffixes): return False - # Check for Z-Image specific key prefixes + # Second pass: check for Z-Image specific key parts + for key in state_dict.keys(): + if isinstance(key, int): + continue # Handle both direct keys (cap_embedder.0.weight) and # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight) key_parts = key.split(".") diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 2de51a8aca..c802154797 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -178,12 +178,43 @@ class Flux2VAELoader(ModelLoader): if is_bfl_format: sd = self._convert_flux2_vae_bfl_to_diffusers(sd) - # FLUX.2 VAE configuration (32 latent channels) - # Based on the official FLUX.2 VAE architecture - # Use default config - AutoencoderKLFlux2 has built-in defaults + # FLUX.2 VAE configuration (32 latent channels). + # The standard FLUX.2 VAE uses block_out_channels=(128,256,512,512) for both + # encoder and decoder. The "small decoder" variant from + # black-forest-labs/FLUX.2-small-decoder keeps the full encoder but uses a + # narrower decoder with channels (96,192,384,384). AutoencoderKLFlux2 only + # exposes a single block_out_channels, so we build the model with the + # encoder's channels and, if the decoder differs, replace just the decoder + # submodule with a matching one before loading the state dict. 
+ encoder_block_out_channels = (128, 256, 512, 512) + decoder_block_out_channels = encoder_block_out_channels + if "encoder.conv_in.weight" in sd and "encoder.conv_norm_out.weight" in sd: + enc_last = int(sd["encoder.conv_norm_out.weight"].shape[0]) + enc_first = int(sd["encoder.conv_in.weight"].shape[0]) + encoder_block_out_channels = (enc_first, enc_first * 2, enc_last, enc_last) + if "decoder.conv_in.weight" in sd and "decoder.conv_norm_out.weight" in sd: + dec_last = int(sd["decoder.conv_in.weight"].shape[0]) + dec_first = int(sd["decoder.conv_norm_out.weight"].shape[0]) + decoder_block_out_channels = (dec_first, dec_first * 2, dec_last, dec_last) + with SilenceWarnings(): with accelerate.init_empty_weights(): - model = AutoencoderKLFlux2() + model = AutoencoderKLFlux2(block_out_channels=encoder_block_out_channels) + if decoder_block_out_channels != encoder_block_out_channels: + # Rebuild the decoder with the smaller channel widths. + from diffusers.models.autoencoders.vae import Decoder + + cfg = model.config + model.decoder = Decoder( + in_channels=cfg.latent_channels, + out_channels=cfg.out_channels, + up_block_types=cfg.up_block_types, + block_out_channels=decoder_block_out_channels, + layers_per_block=cfg.layers_per_block, + norm_num_groups=cfg.norm_num_groups, + act_fn=cfg.act_fn, + mid_block_add_attention=cfg.mid_block_add_attention, + ) # Convert to bfloat16 and load for k in sd.keys(): diff --git a/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py index e248f9cfc4..70b10de50d 100644 --- a/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py @@ -1,10 +1,11 @@ """Z-Image LoRA conversion utilities. Z-Image uses S3-DiT transformer architecture with Qwen3 text encoder. -LoRAs for Z-Image typically follow the diffusers PEFT format. +LoRAs for Z-Image typically follow the diffusers PEFT format or Kohya format. """ -from typing import Dict +import re +from typing import Any, Dict import torch @@ -16,6 +17,29 @@ from invokeai.backend.patches.lora_conversions.z_image_lora_constants import ( ) from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +# Regex for Kohya-format Z-Image transformer keys. +# Example keys: +# lora_unet__layers_0_attention_to_k.alpha +# lora_unet__layers_0_attention_to_k.lora_down.weight +# lora_unet__context_refiner_0_feed_forward_w1.lora_up.weight +# lora_unet__noise_refiner_1_attention_to_v.lora_down.weight +Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX = ( + r"lora_unet__(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)" +) + + +def is_state_dict_likely_z_image_kohya_lora(state_dict: dict[str | int, Any]) -> bool: + """Checks if the provided state dict is likely a Z-Image LoRA in Kohya format. + + Kohya Z-Image LoRAs have keys like: + - lora_unet__layers_0_attention_to_k.lora_down.weight + - lora_unet__context_refiner_0_feed_forward_w1.alpha + - lora_unet__noise_refiner_1_attention_to_v.lora_up.weight + """ + return any( + isinstance(k, str) and re.match(Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX, k.split(".")[0]) for k in state_dict.keys() + ) + def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) -> bool: """Checks if the provided state dict is likely a Z-Image LoRA. 
@@ -23,6 +47,9 @@ def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) Z-Image LoRAs can have keys for transformer and/or Qwen3 text encoder. They may use various prefixes depending on the training framework. """ + if is_state_dict_likely_z_image_kohya_lora(state_dict): + return True + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] # Check for Z-Image transformer keys (S3-DiT architecture) @@ -57,6 +84,7 @@ def lora_model_from_z_image_state_dict( - "transformer." or "base_model.model.transformer." for diffusers PEFT format - "diffusion_model." for some training frameworks - "text_encoder." or "base_model.model.text_encoder." for Qwen3 encoder + - "lora_unet__" for Kohya format (underscores instead of dots) Args: state_dict: The LoRA state dict @@ -65,6 +93,10 @@ def lora_model_from_z_image_state_dict( Returns: A ModelPatchRaw containing the LoRA layers """ + # If Kohya format, convert keys first then process normally + if is_state_dict_likely_z_image_kohya_lora(state_dict): + state_dict = _convert_z_image_kohya_state_dict(state_dict) + layers: dict[str, BaseLayerPatch] = {} # Group keys by layer @@ -120,6 +152,45 @@ def lora_model_from_z_image_state_dict( return ModelPatchRaw(layers=layers) +def _convert_z_image_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Converts a Kohya-format Z-Image LoRA state dict to diffusion_model dot-notation. + + Example key conversions: + - lora_unet__layers_0_attention_to_k.lora_down.weight -> diffusion_model.layers.0.attention.to_k.lora_down.weight + - lora_unet__context_refiner_0_feed_forward_w1.alpha -> diffusion_model.context_refiner.0.feed_forward.w1.alpha + - lora_unet__noise_refiner_1_attention_to_v.lora_up.weight -> diffusion_model.noise_refiner.1.attention.to_v.lora_up.weight + """ + converted: Dict[str, torch.Tensor] = {} + for key, value in state_dict.items(): + if not isinstance(key, str) or not key.startswith("lora_unet__"): + converted[key] = value + continue + + # Split into layer name and param suffix (e.g. "lora_down.weight", "alpha") + layer_name, _, param_suffix = key.partition(".") + + # Strip lora_unet__ prefix + remainder = layer_name[len("lora_unet__") :] + + # Convert Kohya underscore format to dot-notation using the known structure + match = re.match( + r"(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)$", + remainder, + ) + if match: + block, idx, submodule, param = match.groups() + new_layer = f"diffusion_model.{block}.{idx}.{submodule}.{param}" + else: + # Fallback: keep original key for unrecognized patterns + converted[key] = value + continue + + new_key = f"{new_layer}.{param_suffix}" if param_suffix else new_layer + converted[new_key] = value + + return converted + + def _get_lora_layer_values(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]: """Convert layer dict keys from PEFT format to internal format.""" if "lora_A.weight" in layer_dict: diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index cc654e4d39..fb8671cec2 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -40,11 +40,9 @@ def directory_size(directory: Path) -> int: Return the aggregate size of all files in a directory (bytes). 
""" sum = 0 - for root, dirs, files in os.walk(directory): + for root, _, files in os.walk(directory): for f in files: sum += Path(root, f).stat().st_size - for d in dirs: - sum += Path(root, d).stat().st_size return sum diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 3e7e742934..75c5ad6671 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1203,6 +1203,15 @@ "modelType": "Model Type", "modelUpdated": "Model Updated", "modelUpdateFailed": "Model Update Failed", + "sortByName": "Name", + "sortByBase": "Base", + "sortBySize": "Size", + "sortByDateAdded": "Date Added", + "sortByDateModified": "Date Modified", + "sortByPath": "Path", + "sortByType": "Type", + "sortByFormat": "Format", + "sortDefault": "Default", "name": "Name", "externalProvider": "External Provider", "externalCapabilities": "External Capabilities", diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index db0a9a11a6..d823258dbf 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -2660,7 +2660,9 @@ "fitModeCover": "Copri", "smoothingMode": "Modalità di ricampionamento", "smoothingDesc": "Applica un ricampionamento di alta qualità lato backend alla conferma delle trasformazioni.", - "smoothing": "Smussamento" + "smoothing": "Smussamento", + "smoothingModeBilinear": "Bilineare", + "smoothingModeBicubic": "Bicubico" }, "stagingArea": { "next": "Successiva", diff --git a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx index f8ea01909a..4c1bc56e49 100644 --- a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx +++ b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx @@ -151,11 +151,18 @@ export const useSubMenu = (): UseSubMenuReturn => { }; }; -export const SubMenuButtonContent = ({ label }: { label: string }) => { +export const SubMenuButtonContent = ({ label, value }: { label: string; value?: string }) => { return ( {label} - + + {value !== undefined && ( + + {value} + + )} + + ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 9cc4ed24d9..7cdba474bb 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -31,7 +31,7 @@ export type ModelCategoryData = { filter: (config: AnyModelConfig) => boolean; }; -export const MODEL_CATEGORIES: Record = { +const MODEL_CATEGORIES: Record = { unknown: { category: 'unknown', i18nKey: 'common.unknown', diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts index 44df38d911..91fb1afd4d 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -25,6 +25,10 @@ const zModelManagerState = z.object({ scanPath: z.string().optional(), shouldInstallInPlace: z.boolean(), selectedModelKeys: z.array(z.string()), + orderBy: z + .enum(['default', 'name', 'type', 'base', 'size', 'created_at', 'updated_at', 'path', 'format']) + .default('name'), + sortDirection: z.enum(['asc', 'desc']).default('asc'), }); type ModelManagerState = z.infer; @@ -38,6 +42,8 @@ const getInitialState = (): 
ModelManagerState => ({ scanPath: undefined, shouldInstallInPlace: true, selectedModelKeys: [], + orderBy: 'name', + sortDirection: 'asc', }); const slice = createSlice({ @@ -77,6 +83,12 @@ const slice = createSlice({ clearModelSelection: (state) => { state.selectedModelKeys = []; }, + setOrderBy: (state, action: PayloadAction) => { + state.orderBy = action.payload; + }, + setSortDirection: (state, action: PayloadAction) => { + state.sortDirection = action.payload; + }, }, }); @@ -90,6 +102,8 @@ export const { modelSelectionChanged, toggleModelSelection, clearModelSelection, + setOrderBy, + setSortDirection, } = slice.actions; export const modelManagerSliceConfig: SliceConfig = { @@ -119,3 +133,5 @@ export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType); export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace); export const selectSelectedModelKeys = createModelManagerSelector((mm) => mm.selectedModelKeys); +export const selectOrderBy = createModelManagerSelector((mm) => mm.orderBy); +export const selectSortDirection = createModelManagerSelector((mm) => mm.sortDirection); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx new file mode 100644 index 0000000000..57dad58f2c --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx @@ -0,0 +1,231 @@ +import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu'; +import type { ModelCategoryData } from 'features/modelManagerV2/models'; +import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; +import { + selectFilteredModelType, + selectOrderBy, + selectSortDirection, + setFilteredModelType, + setOrderBy, + setSortDirection, +} from 'features/modelManagerV2/store/modelManagerV2Slice'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + PiCheckBold, + PiFunnelBold, + PiListBold, + PiSortAscendingBold, + PiSortDescendingBold, + PiWarningBold, +} from 'react-icons/pi'; + +type OrderBy = 'default' | 'name' | 'type' | 'base' | 'size' | 'created_at' | 'updated_at' | 'path' | 'format'; + +const ORDER_BY_OPTIONS: { key: OrderBy; i18nKey: string }[] = [ + { key: 'default', i18nKey: 'modelManager.sortDefault' }, + { key: 'name', i18nKey: 'modelManager.sortByName' }, + { key: 'base', i18nKey: 'modelManager.sortByBase' }, + { key: 'size', i18nKey: 'modelManager.sortBySize' }, + { key: 'created_at', i18nKey: 'modelManager.sortByDateAdded' }, + { key: 'updated_at', i18nKey: 'modelManager.sortByDateModified' }, + { key: 'path', i18nKey: 'modelManager.sortByPath' }, + { key: 'type', i18nKey: 'modelManager.sortByType' }, + { key: 'format', i18nKey: 'modelManager.sortByFormat' }, +]; + +const SortByMenuItem = memo(({ option, label }: { option: OrderBy; label: string }) => { + const dispatch = useAppDispatch(); + const orderBy = useAppSelector(selectOrderBy); + const onClick = useCallback(() => { + dispatch(setOrderBy(option)); + }, [dispatch, option]); + + return ( + : } + > + {label} + + ); +}); +SortByMenuItem.displayName = 
'SortByMenuItem'; + +const SortBySubMenu = memo(() => { + const { t } = useTranslation(); + const subMenu = useSubMenu(); + const orderBy = useAppSelector(selectOrderBy); + + const currentSortLabel = useMemo(() => { + const option = ORDER_BY_OPTIONS.find((o) => o.key === orderBy); + if (!option) { + return ''; + } + return t(option.i18nKey); + }, [orderBy, t]); + + return ( + }> + + + + + + {ORDER_BY_OPTIONS.map(({ key, i18nKey }) => ( + + ))} + + + + ); +}); +SortBySubMenu.displayName = 'SortBySubMenu'; + +const DirectionSubMenu = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const direction = useAppSelector(selectSortDirection); + const subMenu = useSubMenu(); + + const setDirectionAsc = useCallback(() => { + dispatch(setSortDirection('asc')); + }, [dispatch]); + + const setDirectionDesc = useCallback(() => { + dispatch(setSortDirection('desc')); + }, [dispatch]); + + const currentValue = direction === 'asc' ? t('common.ascending', 'Ascending') : t('common.descending', 'Descending'); + + return ( + : } + > + + + + + + : } + > + {t('common.ascending', 'Ascending')} + + : } + > + {t('common.descending', 'Descending')} + + + + + ); +}); +DirectionSubMenu.displayName = 'DirectionSubMenu'; + +const ModelTypeSubMenu = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const subMenu = useSubMenu(); + + const clearModelType = useCallback(() => { + dispatch(setFilteredModelType(null)); + }, [dispatch]); + + const setMissingFilter = useCallback(() => { + dispatch(setFilteredModelType('missing')); + }, [dispatch]); + + const currentValue = useMemo(() => { + if (filteredModelType === null) { + return t('modelManager.allModels'); + } + if (filteredModelType === 'missing') { + return t('modelManager.missingFiles'); + } + const categoryData = MODEL_CATEGORIES_AS_LIST.find((data) => data.category === filteredModelType); + return categoryData ? 
t(categoryData.i18nKey) : ''; + }, [filteredModelType, t]); + + return ( + }> + + + + + + : } + > + {t('modelManager.allModels')} + + : } + > + + {filteredModelType !== 'missing' && } + {t('modelManager.missingFiles')} + + + {MODEL_CATEGORIES_AS_LIST.map((data) => ( + + ))} + + + + ); +}); +ModelTypeSubMenu.displayName = 'ModelTypeSubMenu'; + +const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const onClick = useCallback(() => { + dispatch(setFilteredModelType(data.category)); + }, [data.category, dispatch]); + return ( + : } + > + {t(data.i18nKey)} + + ); +}); +ModelMenuItem.displayName = 'ModelMenuItem'; + +export const ModelFilterMenu = memo(() => { + const { t } = useTranslation(); + + return ( + + }> + {t('common.filtering', 'Filtering')} + + + + + + + + ); +}); + +ModelFilterMenu.displayName = 'ModelFilterMenu'; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index ed49fa2870..033a439bfc 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -8,8 +8,10 @@ import { clearModelSelection, type FilterableModelType, selectFilteredModelType, + selectOrderBy, selectSearchTerm, selectSelectedModelKeys, + selectSortDirection, setSelectedModelKey, } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { memo, useCallback, useMemo, useState } from 'react'; @@ -39,6 +41,8 @@ const ModelList = () => { const dispatch = useAppDispatch(); const filteredModelType = useAppSelector(selectFilteredModelType); const searchTerm = useAppSelector(selectSearchTerm); + const orderBy = useAppSelector(selectOrderBy); + const direction = useAppSelector(selectSortDirection); const selectedModelKeys = useAppSelector(selectSelectedModelKeys); const { t } = useTranslation(); const toast = useToast(); @@ -47,7 +51,8 @@ const ModelList = () => { const [isDeleting, setIsDeleting] = useState(false); const [isReidentifying, setIsReidentifying] = useState(false); - const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(); + const queryArgs = useMemo(() => ({ order_by: orderBy, direction: direction.toUpperCase() }), [orderBy, direction]); + const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(queryArgs); const { data: missingModelsData, isLoading: isLoadingMissing } = useGetMissingModelsQuery(); const [bulkDeleteModels] = useBulkDeleteModelsMutation(); const [bulkReidentifyModels] = useBulkReidentifyModelsMutation(); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx index 78bed8ab83..bbfb88df5c 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx @@ -6,8 +6,8 @@ import type { ChangeEventHandler } from 'react'; import { memo, useCallback } from 'react'; import { PiXBold } from 'react-icons/pi'; +import { ModelFilterMenu } from './ModelFilterMenu'; import { 
ModelListBulkActions } from './ModelListBulkActions'; -import { ModelTypeFilter } from './ModelTypeFilter'; export const ModelListNavigation = memo(() => { const dispatch = useAppDispatch(); @@ -50,7 +50,7 @@ export const ModelListNavigation = memo(() => { - + diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx deleted file mode 100644 index 5aa8e62886..0000000000 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ /dev/null @@ -1,78 +0,0 @@ -import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import type { ModelCategoryData } from 'features/modelManagerV2/models'; -import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; -import type { ModelCategoryType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { memo, useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; -import { PiFunnelBold, PiWarningBold } from 'react-icons/pi'; - -const isModelCategoryType = (type: string): type is ModelCategoryType => { - return type in MODEL_CATEGORIES; -}; - -export const ModelTypeFilter = memo(() => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const filteredModelType = useAppSelector(selectFilteredModelType); - - const clearModelType = useCallback(() => { - dispatch(setFilteredModelType(null)); - }, [dispatch]); - - const setMissingFilter = useCallback(() => { - dispatch(setFilteredModelType('missing')); - }, [dispatch]); - - const getButtonLabel = () => { - if (filteredModelType === 'missing') { - return t('modelManager.missingFiles'); - } - if (filteredModelType && isModelCategoryType(filteredModelType)) { - return t(MODEL_CATEGORIES[filteredModelType].i18nKey); - } - return t('modelManager.allModels'); - }; - - return ( - - }> - {getButtonLabel()} - - - {t('modelManager.allModels')} - - - - {t('modelManager.missingFiles')} - - - {MODEL_CATEGORIES_AS_LIST.map((data) => ( - - ))} - - - ); -}); - -ModelTypeFilter.displayName = 'ModelTypeFilter'; - -const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const filteredModelType = useAppSelector(selectFilteredModelType); - const onClick = useCallback(() => { - dispatch(setFilteredModelType(data.category)); - }, [data.category, dispatch]); - return ( - - {t(data.i18nKey)} - - ); -}); -ModelMenuItem.displayName = 'ModelMenuItem'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts index 892b47a408..28aa74db5e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts @@ -260,17 +260,7 @@ export const getDenoisingStartAndEnd = (state: RootState): { denoising_start: nu }; } } - case 'anima': { - // Anima uses a fixed shift=3.0 which makes the sigma schedule highly non-linear. - // Without rescaling, most of the visual 'change' is concentrated in the high denoise - // strength range (>0.8). 
The exponent 0.2 spreads the effective range more evenly, - // matching the approach used for FLUX and SD3. - const animaExponent = optimizedDenoisingEnabled ? 0.2 : 1; - return { - denoising_start: 1 - denoisingStrength ** animaExponent, - denoising_end: 1, - }; - } + case 'anima': case 'sd-1': case 'sd-2': case 'cogview4': diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 2bf8d8643c..5a4abdd4d2 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -4,13 +4,14 @@ import { useStore } from '@nanostores/react'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; -import { $isConnected, $lastProgressEvent } from 'services/events/stores'; +import { $isConnected, $lastProgressEvent, $loadingModelsCount } from 'services/events/stores'; const ProgressBar = (props: ProgressProps) => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); const isConnected = useStore($isConnected); const lastProgressEvent = useStore($lastProgressEvent); + const loadingModelsCount = useStore($loadingModelsCount); const value = useMemo(() => { if (!lastProgressEvent) { return 0; @@ -23,6 +24,10 @@ const ProgressBar = (props: ProgressProps) => { return false; } + if (loadingModelsCount > 0) { + return true; + } + if (!queueStatus?.queue.in_progress) { return false; } @@ -40,7 +45,7 @@ const ProgressBar = (props: ProgressProps) => { } return false; - }, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress]); + }, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress, loadingModelsCount]); return ( ; }; +type GetModelConfigsArg = { + order_by?: string; + direction?: string; +} | void; + const modelConfigsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions); @@ -338,8 +342,11 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['ModelInstalls'], }), - getModelConfigs: build.query, void>({ - query: () => ({ url: buildModelsUrl() }), + getModelConfigs: build.query, GetModelConfigsArg>({ + query: (arg) => { + const queryStr = arg ? `?${queryString.stringify(arg)}` : ''; + return { url: buildModelsUrl(queryStr) }; + }, providesTags: (result) => { const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }]; if (result) { @@ -498,5 +505,5 @@ export const { useDeleteOrphanedModelsMutation, } = modelsApi; -export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(); +export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(undefined); export const selectMissingModelsQuery = modelsApi.endpoints.getMissingModels.select(); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 6c6e259fd1..e1dd2ad361 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -22997,6 +22997,12 @@ export type components = { */ config_path?: string | null; }; + /** + * ModelRecordOrderBy + * @description The order in which to return model summaries. 
+ * @enum {string} + */ + ModelRecordOrderBy: "default" | "type" | "base" | "name" | "format" | "size" | "created_at" | "updated_at" | "path"; /** ModelRelationshipBatchRequest */ ModelRelationshipBatchRequest: { /** @@ -31525,6 +31531,10 @@ export interface operations { model_name?: string | null; /** @description Exact match on the format of the model (e.g. 'diffusers') */ model_format?: components["schemas"]["ModelFormat"] | null; + /** @description The field to order by */ + order_by?: components["schemas"]["ModelRecordOrderBy"]; + /** @description The direction to order by */ + direction?: components["schemas"]["SQLiteDirection"]; }; header?: never; path?: never; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 774acd3f93..fb08fc08dd 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -43,7 +43,7 @@ import type { ClientToServerEvents, ServerToClientEvents } from 'services/events import type { Socket } from 'socket.io-client'; import type { JsonObject } from 'type-fest'; -import { $lastProgressEvent } from './stores'; +import { $lastProgressEvent, $loadingModelsCount } from './stores'; const log = logger('events'); @@ -73,12 +73,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.emit('subscribe_queue', { queue_id: 'default' }); socket.emit('subscribe_bulk_download', { bulk_download_id: 'default' }); $lastProgressEvent.set(null); + $loadingModelsCount.set(0); }); socket.on('connect_error', (error) => { log.debug('Connect error'); setIsConnected(false); $lastProgressEvent.set(null); + $loadingModelsCount.set(0); if (error && error.message) { const data: string | undefined = (error as unknown as { data: string | undefined }).data; if (data === 'ERR_UNAUTHENTICATED') { @@ -95,6 +97,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.on('disconnect', () => { log.debug('Disconnected'); $lastProgressEvent.set(null); + $loadingModelsCount.set(0); setIsConnected(false); }); @@ -183,6 +186,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis const message = `Model load started: ${name} (${extras.join(', ')})`; log.debug({ data }, message); + $loadingModelsCount.set($loadingModelsCount.get() + 1); }); socket.on('model_load_complete', (data) => { @@ -197,6 +201,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis const message = `Model load complete: ${name} (${extras.join(', ')})`; log.debug({ data }, message); + $loadingModelsCount.set(Math.max(0, $loadingModelsCount.get() - 1)); }); socket.on('download_started', (data) => { diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts index 720ba920cf..180f4a3a63 100644 --- a/invokeai/frontend/web/src/services/events/stores.ts +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -6,6 +6,7 @@ import type { AppSocket } from 'services/events/types'; export const $socket = atom(null); export const $isConnected = atom(false); export const $lastProgressEvent = atom(null); +export const $loadingModelsCount = atom(0); export const $lastProgressMessage = computed($lastProgressEvent, (val) => { if (!val) { diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 
2b6c54d5b0..19a1b74e73 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -11,11 +11,13 @@ from pydantic import ValidationError from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.model_records import ( DuplicateModelException, + ModelRecordOrderBy, ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException, ) from invokeai.app.services.model_records.model_records_base import ModelRecordChanges +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config from invokeai.backend.model_manager.configs.main import ( @@ -364,6 +366,73 @@ def test_filter_2(store: ModelRecordServiceBase): assert len(matches) == 1 +def test_search_by_attr_sorting(store: ModelRecordServiceSQL): + config1 = Main_Diffusers_SD1_Config( + path="/tmp/config1", + name="alpha", + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, + hash="CONFIG1HASH", + file_size=1000, + source="test/source/", + source_type=ModelSourceType.Path, + variant=ModelVariantType.Normal, + prediction_type=SchedulerPredictionType.Epsilon, + repo_variant=ModelRepoVariant.Default, + ) + config2 = Main_Diffusers_SD2_Config( + path="/tmp/config2", + name="beta", + base=BaseModelType.StableDiffusion2, + type=ModelType.Main, + hash="CONFIG2HASH", + file_size=2000, + source="test/source/", + source_type=ModelSourceType.Path, + variant=ModelVariantType.Normal, + prediction_type=SchedulerPredictionType.Epsilon, + repo_variant=ModelRepoVariant.Default, + ) + config3 = VAE_Diffusers_SD1_Config( + path="/tmp/config3", + name="gamma", + base=BaseModelType.StableDiffusion1, + type=ModelType.VAE, + hash="CONFIG3HASH", + file_size=500, + source="test/source/", + source_type=ModelSourceType.Path, + repo_variant=ModelRepoVariant.Default, + ) + for c in config1, config2, config3: + store.add_model(c) + + # Test sorting by Name Ascending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Ascending) + assert len(matches) == 3 + assert matches[0].name == "alpha" + assert matches[1].name == "beta" + assert matches[2].name == "gamma" + + # Test sorting by Name Descending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Descending) + assert matches[0].name == "gamma" + assert matches[1].name == "beta" + assert matches[2].name == "alpha" + + # Test sorting by Size Ascending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Ascending) + assert matches[0].name == "gamma" # 500 + assert matches[1].name == "alpha" # 1000 + assert matches[2].name == "beta" # 2000 + + # Test sorting by Size Descending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Descending) + assert matches[0].name == "beta" # 2000 + assert matches[1].name == "alpha" # 1000 + assert matches[2].name == "gamma" # 500 + + def test_model_record_changes(): # This test guards against some unexpected behaviours from pydantic's union evaluation. See #6035 changes = ModelRecordChanges.model_validate({"default_settings": {"preprocessor": "value"}})
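
For reference, the sorting behaviour exercised by the test above can be sketched as a short usage example. This is a minimal sketch, assuming `record_store` is an already-constructed ModelRecordServiceSQL instance (for example, ApiDependencies.invoker.services.model_manager.store); only imports and enum members that appear in this diff are used.

from invokeai.app.services.model_records import ModelRecordOrderBy
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection

# Largest models first. Records without a file_size sort as 0, because the SQL
# ordering uses IFNULL(json_extract(config, '$.file_size'), 0).
largest_first = record_store.search_by_attr(
    order_by=ModelRecordOrderBy.Size,
    direction=SQLiteDirection.Descending,
)

# Case-insensitive name ordering (name and base are compared with COLLATE NOCASE).
by_name = record_store.search_by_attr(
    order_by=ModelRecordOrderBy.Name,
    direction=SQLiteDirection.Ascending,
)

# The paginated listing accepts the same direction argument.
first_page = record_store.list_models(
    page=0,
    per_page=10,
    order_by=ModelRecordOrderBy.DateAdded,
    direction=SQLiteDirection.Descending,
)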
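
The Kohya-format handling added in z_image_lora_conversion_utils.py can be illustrated with dummy tensors. This is only an illustration of the key rewriting; _convert_z_image_kohya_state_dict is a private helper and is called here purely to show the mapping described in its docstring, and the tensor shapes are placeholders.

import torch

from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import (
    _convert_z_image_kohya_state_dict,
    is_state_dict_likely_z_image_kohya_lora,
)

state_dict = {
    "lora_unet__layers_0_attention_to_k.lora_down.weight": torch.zeros(8, 64),
    "lora_unet__layers_0_attention_to_k.lora_up.weight": torch.zeros(64, 8),
    "lora_unet__layers_0_attention_to_k.alpha": torch.tensor(8.0),
}

# The Kohya key prefix matches Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX.
assert is_state_dict_likely_z_image_kohya_lora(state_dict)

converted = _convert_z_image_kohya_state_dict(state_dict)
# Keys are now in diffusion_model dot-notation, e.g.:
#   diffusion_model.layers.0.attention.to_k.lora_down.weight
print(sorted(converted.keys()))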
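
Finally, the shape-based channel inference in Flux2VAELoader amounts to the arithmetic below. The tensor shapes are dummies chosen to match the widths quoted in the loader's comment (standard encoder 128..512, small-decoder variant 96..384, 32 latent channels); they are illustrative, not taken from a real checkpoint.

import torch

sd = {
    # Encoder conv_in: out_channels == first encoder block width (128).
    "encoder.conv_in.weight": torch.zeros(128, 3, 3, 3),
    # Encoder conv_norm_out (GroupNorm weight): channels == last encoder block width (512).
    "encoder.conv_norm_out.weight": torch.zeros(512),
    # Decoder conv_in: out_channels == widest decoder block width (384 for the small decoder).
    "decoder.conv_in.weight": torch.zeros(384, 32, 3, 3),
    # Decoder conv_norm_out (GroupNorm weight): channels == narrowest decoder block width (96).
    "decoder.conv_norm_out.weight": torch.zeros(96),
}

enc_first = int(sd["encoder.conv_in.weight"].shape[0])
enc_last = int(sd["encoder.conv_norm_out.weight"].shape[0])
dec_last = int(sd["decoder.conv_in.weight"].shape[0])
dec_first = int(sd["decoder.conv_norm_out.weight"].shape[0])

print((enc_first, enc_first * 2, enc_last, enc_last))  # (128, 256, 512, 512)
print((dec_first, dec_first * 2, dec_last, dec_last))  # (96, 192, 384, 384): the decoder is rebuilt with these widths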