Merge branch 'main' into feature/sqlmodel-migration

This commit is contained in:
Alexander Eichhorn
2026-04-22 03:24:43 +02:00
committed by GitHub
24 changed files with 576 additions and 129 deletions

View File

@@ -26,9 +26,11 @@ from invokeai.app.services.model_install.model_install_common import ModelInstal
from invokeai.app.services.model_records import (
InvalidModelException,
ModelRecordChanges,
ModelRecordOrderBy,
UnknownModelException,
)
from invokeai.app.services.orphaned_models import OrphanedModelInfo
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
@@ -159,6 +161,8 @@ async def list_model_records(
model_format: Optional[ModelFormat] = Query(
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Name, description="The field to order by"),
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
@@ -167,12 +171,23 @@ async def list_model_records(
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
base_model=base_model,
model_type=model_type,
model_name=model_name,
model_format=model_format,
order_by=order_by,
direction=direction,
)
)
else:
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
record_store.search_by_attr(
model_type=model_type,
model_name=model_name,
model_format=model_format,
order_by=order_by,
direction=direction,
)
)
for index, model in enumerate(found_models):
found_models[index] = prepare_model_config_for_response(model, ApiDependencies)

View File

@@ -11,6 +11,7 @@ from typing import List, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.external_api import (
@@ -60,6 +61,10 @@ class ModelRecordOrderBy(str, Enum):
Base = "base"
Name = "name"
Format = "format"
Size = "size"
DateAdded = "created_at"
DateModified = "updated_at"
Path = "path"
class ModelSummary(BaseModel):
@@ -200,7 +205,11 @@ class ModelRecordServiceBase(ABC):
@abstractmethod
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
self,
page: int = 0,
per_page: int = 10,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
pass
@@ -237,6 +246,8 @@ class ModelRecordServiceBase(ABC):
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.

View File

@@ -57,6 +57,7 @@ from invokeai.app.services.model_records.model_records_base import (
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
@@ -257,6 +258,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@@ -266,18 +268,24 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
:param order_by: Result order
:param direction: Result direction
If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Base: "base COLLATE NOCASE",
ModelRecordOrderBy.Name: "name COLLATE NOCASE",
ModelRecordOrderBy.Format: "format",
ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
ModelRecordOrderBy.DateAdded: "created_at",
ModelRecordOrderBy.DateModified: "updated_at",
ModelRecordOrderBy.Path: "path",
}
where_clause: list[str] = []
@@ -301,7 +309,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
SELECT config
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
@@ -357,17 +365,26 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return results
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
self,
page: int = 0,
per_page: int = 10,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
direction: SQLiteDirection = SQLiteDirection.Ascending,
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC"
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Base: "base COLLATE NOCASE",
ModelRecordOrderBy.Name: "name COLLATE NOCASE",
ModelRecordOrderBy.Format: "format",
ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)",
ModelRecordOrderBy.DateAdded: "created_at",
ModelRecordOrderBy.DateModified: "updated_at",
ModelRecordOrderBy.Path: "path",
}
# Lock so that the database isn't updated while we're doing the two queries.
@@ -385,7 +402,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",

View File

@@ -714,14 +714,25 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
- diffusion_model.layers.X.attention.to_k.lora_down.weight (DoRA format)
- diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format)
- diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale)
- lora_unet__layers_X_attention_to_k.lora_down.weight (Kohya format)
"""
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import (
is_state_dict_likely_z_image_kohya_lora,
)
state_dict = mod.load_state_dict()
# Check for Z-Image specific LoRA patterns
# Check for Kohya format first
if is_state_dict_likely_z_image_kohya_lora(state_dict):
return
# Check for Z-Image specific LoRA patterns (dot-notation formats)
has_z_image_lora_keys = state_dict_has_any_keys_starting_with(
state_dict,
{
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
"diffusion_model.context_refiner.",
"diffusion_model.noise_refiner.",
"transformer.layers.", # OneTrainer/diffusers prefix variant
"base_model.model.transformer.layers.", # PEFT-wrapped variant
},
@@ -751,15 +762,26 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
Z-Image uses S3-DiT architecture with layer names like:
- diffusion_model.layers.0.attention.to_k.lora_A.weight
- diffusion_model.layers.0.feed_forward.w1.lora_A.weight
- lora_unet__layers_0_attention_to_k.lora_down.weight (Kohya format)
"""
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import (
is_state_dict_likely_z_image_kohya_lora,
)
state_dict = mod.load_state_dict()
# Check for Z-Image transformer layer patterns
# Check for Kohya format
if is_state_dict_likely_z_image_kohya_lora(state_dict):
return BaseModelType.ZImage
# Check for Z-Image transformer layer patterns (dot-notation formats)
# Z-Image uses diffusion_model.layers.X structure (unlike Flux which uses double_blocks/single_blocks)
has_z_image_keys = state_dict_has_any_keys_starting_with(
state_dict,
{
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
"diffusion_model.context_refiner.",
"diffusion_model.noise_refiner.",
"transformer.layers.", # OneTrainer/diffusers prefix variant
"base_model.model.transformer.layers.", # PEFT-wrapped variant
},

View File

@@ -160,17 +160,20 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
".lora_A.weight",
".lora_B.weight",
".dora_scale",
".alpha",
)
# First pass: check if any key has LoRA suffixes - if so, this is a LoRA not a main model
for key in state_dict.keys():
if isinstance(key, int):
continue
# If we find any LoRA-specific keys, this is not a main model
if key.endswith(lora_suffixes):
return False
# Check for Z-Image specific key prefixes
# Second pass: check for Z-Image specific key parts
for key in state_dict.keys():
if isinstance(key, int):
continue
# Handle both direct keys (cap_embedder.0.weight) and
# ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
key_parts = key.split(".")

View File

@@ -178,12 +178,43 @@ class Flux2VAELoader(ModelLoader):
if is_bfl_format:
sd = self._convert_flux2_vae_bfl_to_diffusers(sd)
# FLUX.2 VAE configuration (32 latent channels)
# Based on the official FLUX.2 VAE architecture
# Use default config - AutoencoderKLFlux2 has built-in defaults
# FLUX.2 VAE configuration (32 latent channels).
# The standard FLUX.2 VAE uses block_out_channels=(128,256,512,512) for both
# encoder and decoder. The "small decoder" variant from
# black-forest-labs/FLUX.2-small-decoder keeps the full encoder but uses a
# narrower decoder with channels (96,192,384,384). AutoencoderKLFlux2 only
# exposes a single block_out_channels, so we build the model with the
# encoder's channels and, if the decoder differs, replace just the decoder
# submodule with a matching one before loading the state dict.
encoder_block_out_channels = (128, 256, 512, 512)
decoder_block_out_channels = encoder_block_out_channels
if "encoder.conv_in.weight" in sd and "encoder.conv_norm_out.weight" in sd:
enc_last = int(sd["encoder.conv_norm_out.weight"].shape[0])
enc_first = int(sd["encoder.conv_in.weight"].shape[0])
encoder_block_out_channels = (enc_first, enc_first * 2, enc_last, enc_last)
if "decoder.conv_in.weight" in sd and "decoder.conv_norm_out.weight" in sd:
dec_last = int(sd["decoder.conv_in.weight"].shape[0])
dec_first = int(sd["decoder.conv_norm_out.weight"].shape[0])
decoder_block_out_channels = (dec_first, dec_first * 2, dec_last, dec_last)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = AutoencoderKLFlux2()
model = AutoencoderKLFlux2(block_out_channels=encoder_block_out_channels)
if decoder_block_out_channels != encoder_block_out_channels:
# Rebuild the decoder with the smaller channel widths.
from diffusers.models.autoencoders.vae import Decoder
cfg = model.config
model.decoder = Decoder(
in_channels=cfg.latent_channels,
out_channels=cfg.out_channels,
up_block_types=cfg.up_block_types,
block_out_channels=decoder_block_out_channels,
layers_per_block=cfg.layers_per_block,
norm_num_groups=cfg.norm_num_groups,
act_fn=cfg.act_fn,
mid_block_add_attention=cfg.mid_block_add_attention,
)
# Convert to bfloat16 and load
for k in sd.keys():

View File

@@ -1,10 +1,11 @@
"""Z-Image LoRA conversion utilities.
Z-Image uses S3-DiT transformer architecture with Qwen3 text encoder.
LoRAs for Z-Image typically follow the diffusers PEFT format.
LoRAs for Z-Image typically follow the diffusers PEFT format or Kohya format.
"""
from typing import Dict
import re
from typing import Any, Dict
import torch
@@ -16,6 +17,29 @@ from invokeai.backend.patches.lora_conversions.z_image_lora_constants import (
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
# Regex for Kohya-format Z-Image transformer keys.
# Example keys:
# lora_unet__layers_0_attention_to_k.alpha
# lora_unet__layers_0_attention_to_k.lora_down.weight
# lora_unet__context_refiner_0_feed_forward_w1.lora_up.weight
# lora_unet__noise_refiner_1_attention_to_v.lora_down.weight
Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX = (
r"lora_unet__(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)"
)
def is_state_dict_likely_z_image_kohya_lora(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely a Z-Image LoRA in Kohya format.
Kohya Z-Image LoRAs have keys like:
- lora_unet__layers_0_attention_to_k.lora_down.weight
- lora_unet__context_refiner_0_feed_forward_w1.alpha
- lora_unet__noise_refiner_1_attention_to_v.lora_up.weight
"""
return any(
isinstance(k, str) and re.match(Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX, k.split(".")[0]) for k in state_dict.keys()
)
def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) -> bool:
"""Checks if the provided state dict is likely a Z-Image LoRA.
@@ -23,6 +47,9 @@ def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor])
Z-Image LoRAs can have keys for transformer and/or Qwen3 text encoder.
They may use various prefixes depending on the training framework.
"""
if is_state_dict_likely_z_image_kohya_lora(state_dict):
return True
str_keys = [k for k in state_dict.keys() if isinstance(k, str)]
# Check for Z-Image transformer keys (S3-DiT architecture)
@@ -57,6 +84,7 @@ def lora_model_from_z_image_state_dict(
- "transformer." or "base_model.model.transformer." for diffusers PEFT format
- "diffusion_model." for some training frameworks
- "text_encoder." or "base_model.model.text_encoder." for Qwen3 encoder
- "lora_unet__" for Kohya format (underscores instead of dots)
Args:
state_dict: The LoRA state dict
@@ -65,6 +93,10 @@ def lora_model_from_z_image_state_dict(
Returns:
A ModelPatchRaw containing the LoRA layers
"""
# If Kohya format, convert keys first then process normally
if is_state_dict_likely_z_image_kohya_lora(state_dict):
state_dict = _convert_z_image_kohya_state_dict(state_dict)
layers: dict[str, BaseLayerPatch] = {}
# Group keys by layer
@@ -120,6 +152,45 @@ def lora_model_from_z_image_state_dict(
return ModelPatchRaw(layers=layers)
def _convert_z_image_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Converts a Kohya-format Z-Image LoRA state dict to diffusion_model dot-notation.
Example key conversions:
- lora_unet__layers_0_attention_to_k.lora_down.weight -> diffusion_model.layers.0.attention.to_k.lora_down.weight
- lora_unet__context_refiner_0_feed_forward_w1.alpha -> diffusion_model.context_refiner.0.feed_forward.w1.alpha
- lora_unet__noise_refiner_1_attention_to_v.lora_up.weight -> diffusion_model.noise_refiner.1.attention.to_v.lora_up.weight
"""
converted: Dict[str, torch.Tensor] = {}
for key, value in state_dict.items():
if not isinstance(key, str) or not key.startswith("lora_unet__"):
converted[key] = value
continue
# Split into layer name and param suffix (e.g. "lora_down.weight", "alpha")
layer_name, _, param_suffix = key.partition(".")
# Strip lora_unet__ prefix
remainder = layer_name[len("lora_unet__") :]
# Convert Kohya underscore format to dot-notation using the known structure
match = re.match(
r"(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)$",
remainder,
)
if match:
block, idx, submodule, param = match.groups()
new_layer = f"diffusion_model.{block}.{idx}.{submodule}.{param}"
else:
# Fallback: keep original key for unrecognized patterns
converted[key] = value
continue
new_key = f"{new_layer}.{param_suffix}" if param_suffix else new_layer
converted[new_key] = value
return converted
def _get_lora_layer_values(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]:
"""Convert layer dict keys from PEFT format to internal format."""
if "lora_A.weight" in layer_dict:

View File

@@ -40,11 +40,9 @@ def directory_size(directory: Path) -> int:
Return the aggregate size of all files in a directory (bytes).
"""
sum = 0
for root, dirs, files in os.walk(directory):
for root, _, files in os.walk(directory):
for f in files:
sum += Path(root, f).stat().st_size
for d in dirs:
sum += Path(root, d).stat().st_size
return sum

View File

@@ -1203,6 +1203,15 @@
"modelType": "Model Type",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"sortByName": "Name",
"sortByBase": "Base",
"sortBySize": "Size",
"sortByDateAdded": "Date Added",
"sortByDateModified": "Date Modified",
"sortByPath": "Path",
"sortByType": "Type",
"sortByFormat": "Format",
"sortDefault": "Default",
"name": "Name",
"externalProvider": "External Provider",
"externalCapabilities": "External Capabilities",

View File

@@ -2660,7 +2660,9 @@
"fitModeCover": "Copri",
"smoothingMode": "Modalità di ricampionamento",
"smoothingDesc": "Applica un ricampionamento di alta qualità lato backend alla conferma delle trasformazioni.",
"smoothing": "Smussamento"
"smoothing": "Smussamento",
"smoothingModeBilinear": "Bilineare",
"smoothingModeBicubic": "Bicubico"
},
"stagingArea": {
"next": "Successiva",

View File

@@ -151,11 +151,18 @@ export const useSubMenu = (): UseSubMenuReturn => {
};
};
export const SubMenuButtonContent = ({ label }: { label: string }) => {
export const SubMenuButtonContent = ({ label, value }: { label: string; value?: string }) => {
return (
<Flex w="full" h="full" flexDir="row" justifyContent="space-between" alignItems="center">
<Text>{label}</Text>
<Icon as={PiCaretRightBold} />
<Flex alignItems="center" gap={2}>
{value !== undefined && (
<Text fontSize="sm" color="base.400">
{value}
</Text>
)}
<Icon as={PiCaretRightBold} />
</Flex>
</Flex>
);
};

View File

@@ -31,7 +31,7 @@ export type ModelCategoryData = {
filter: (config: AnyModelConfig) => boolean;
};
export const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
unknown: {
category: 'unknown',
i18nKey: 'common.unknown',

View File

@@ -25,6 +25,10 @@ const zModelManagerState = z.object({
scanPath: z.string().optional(),
shouldInstallInPlace: z.boolean(),
selectedModelKeys: z.array(z.string()),
orderBy: z
.enum(['default', 'name', 'type', 'base', 'size', 'created_at', 'updated_at', 'path', 'format'])
.default('name'),
sortDirection: z.enum(['asc', 'desc']).default('asc'),
});
type ModelManagerState = z.infer<typeof zModelManagerState>;
@@ -38,6 +42,8 @@ const getInitialState = (): ModelManagerState => ({
scanPath: undefined,
shouldInstallInPlace: true,
selectedModelKeys: [],
orderBy: 'name',
sortDirection: 'asc',
});
const slice = createSlice({
@@ -77,6 +83,12 @@ const slice = createSlice({
clearModelSelection: (state) => {
state.selectedModelKeys = [];
},
setOrderBy: (state, action: PayloadAction<ModelManagerState['orderBy']>) => {
state.orderBy = action.payload;
},
setSortDirection: (state, action: PayloadAction<ModelManagerState['sortDirection']>) => {
state.sortDirection = action.payload;
},
},
});
@@ -90,6 +102,8 @@ export const {
modelSelectionChanged,
toggleModelSelection,
clearModelSelection,
setOrderBy,
setSortDirection,
} = slice.actions;
export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
@@ -119,3 +133,5 @@ export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm
export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType);
export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace);
export const selectSelectedModelKeys = createModelManagerSelector((mm) => mm.selectedModelKeys);
export const selectOrderBy = createModelManagerSelector((mm) => mm.orderBy);
export const selectSortDirection = createModelManagerSelector((mm) => mm.sortDirection);

View File

@@ -0,0 +1,231 @@
import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import type { ModelCategoryData } from 'features/modelManagerV2/models';
import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
import {
selectFilteredModelType,
selectOrderBy,
selectSortDirection,
setFilteredModelType,
setOrderBy,
setSortDirection,
} from 'features/modelManagerV2/store/modelManagerV2Slice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
PiCheckBold,
PiFunnelBold,
PiListBold,
PiSortAscendingBold,
PiSortDescendingBold,
PiWarningBold,
} from 'react-icons/pi';
type OrderBy = 'default' | 'name' | 'type' | 'base' | 'size' | 'created_at' | 'updated_at' | 'path' | 'format';
const ORDER_BY_OPTIONS: { key: OrderBy; i18nKey: string }[] = [
{ key: 'default', i18nKey: 'modelManager.sortDefault' },
{ key: 'name', i18nKey: 'modelManager.sortByName' },
{ key: 'base', i18nKey: 'modelManager.sortByBase' },
{ key: 'size', i18nKey: 'modelManager.sortBySize' },
{ key: 'created_at', i18nKey: 'modelManager.sortByDateAdded' },
{ key: 'updated_at', i18nKey: 'modelManager.sortByDateModified' },
{ key: 'path', i18nKey: 'modelManager.sortByPath' },
{ key: 'type', i18nKey: 'modelManager.sortByType' },
{ key: 'format', i18nKey: 'modelManager.sortByFormat' },
];
const SortByMenuItem = memo(({ option, label }: { option: OrderBy; label: string }) => {
const dispatch = useAppDispatch();
const orderBy = useAppSelector(selectOrderBy);
const onClick = useCallback(() => {
dispatch(setOrderBy(option));
}, [dispatch, option]);
return (
<MenuItem
onClick={onClick}
bg={orderBy === option ? 'base.700' : 'transparent'}
icon={orderBy === option ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
>
{label}
</MenuItem>
);
});
SortByMenuItem.displayName = 'SortByMenuItem';
const SortBySubMenu = memo(() => {
const { t } = useTranslation();
const subMenu = useSubMenu();
const orderBy = useAppSelector(selectOrderBy);
const currentSortLabel = useMemo(() => {
const option = ORDER_BY_OPTIONS.find((o) => o.key === orderBy);
if (!option) {
return '';
}
return t(option.i18nKey);
}, [orderBy, t]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiListBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('modelManager.sortBy', 'Sort By')} value={currentSortLabel} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
{ORDER_BY_OPTIONS.map(({ key, i18nKey }) => (
<SortByMenuItem key={key} option={key} label={t(i18nKey)} />
))}
</MenuList>
</Menu>
</MenuItem>
);
});
SortBySubMenu.displayName = 'SortBySubMenu';
const DirectionSubMenu = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const direction = useAppSelector(selectSortDirection);
const subMenu = useSubMenu();
const setDirectionAsc = useCallback(() => {
dispatch(setSortDirection('asc'));
}, [dispatch]);
const setDirectionDesc = useCallback(() => {
dispatch(setSortDirection('desc'));
}, [dispatch]);
const currentValue = direction === 'asc' ? t('common.ascending', 'Ascending') : t('common.descending', 'Descending');
return (
<MenuItem
{...subMenu.parentMenuItemProps}
icon={direction === 'asc' ? <PiSortAscendingBold /> : <PiSortDescendingBold />}
>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('common.direction', 'Direction')} value={currentValue} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<MenuItem
onClick={setDirectionAsc}
bg={direction === 'asc' ? 'base.700' : 'transparent'}
icon={direction === 'asc' ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
>
{t('common.ascending', 'Ascending')}
</MenuItem>
<MenuItem
onClick={setDirectionDesc}
bg={direction === 'desc' ? 'base.700' : 'transparent'}
icon={direction === 'desc' ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
>
{t('common.descending', 'Descending')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
);
});
DirectionSubMenu.displayName = 'DirectionSubMenu';
const ModelTypeSubMenu = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const filteredModelType = useAppSelector(selectFilteredModelType);
const subMenu = useSubMenu();
const clearModelType = useCallback(() => {
dispatch(setFilteredModelType(null));
}, [dispatch]);
const setMissingFilter = useCallback(() => {
dispatch(setFilteredModelType('missing'));
}, [dispatch]);
const currentValue = useMemo(() => {
if (filteredModelType === null) {
return t('modelManager.allModels');
}
if (filteredModelType === 'missing') {
return t('modelManager.missingFiles');
}
const categoryData = MODEL_CATEGORIES_AS_LIST.find((data) => data.category === filteredModelType);
return categoryData ? t(categoryData.i18nKey) : '';
}, [filteredModelType, t]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiFunnelBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('modelManager.modelType', 'Model Type')} value={currentValue} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<MenuItem
onClick={clearModelType}
bg={filteredModelType === null ? 'base.700' : 'transparent'}
icon={filteredModelType === null ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
>
{t('modelManager.allModels')}
</MenuItem>
<MenuItem
onClick={setMissingFilter}
bg={filteredModelType === 'missing' ? 'base.700' : 'transparent'}
color="warning.300"
icon={filteredModelType === 'missing' ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
>
<Flex alignItems="center" gap={2}>
{filteredModelType !== 'missing' && <PiWarningBold />}
{t('modelManager.missingFiles')}
</Flex>
</MenuItem>
{MODEL_CATEGORIES_AS_LIST.map((data) => (
<ModelMenuItem key={data.category} data={data} />
))}
</MenuList>
</Menu>
</MenuItem>
);
});
ModelTypeSubMenu.displayName = 'ModelTypeSubMenu';
const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const filteredModelType = useAppSelector(selectFilteredModelType);
const onClick = useCallback(() => {
dispatch(setFilteredModelType(data.category));
}, [data.category, dispatch]);
return (
<MenuItem
bg={filteredModelType === data.category ? 'base.700' : 'transparent'}
onClick={onClick}
icon={filteredModelType === data.category ? <PiCheckBold /> : <PiCheckBold visibility="hidden" />}
>
{t(data.i18nKey)}
</MenuItem>
);
});
ModelMenuItem.displayName = 'ModelMenuItem';
export const ModelFilterMenu = memo(() => {
const { t } = useTranslation();
return (
<Menu placement="bottom-end">
<MenuButton as={Button} size="sm" rightIcon={<PiFunnelBold />}>
{t('common.filtering', 'Filtering')}
</MenuButton>
<MenuList>
<DirectionSubMenu />
<SortBySubMenu />
<ModelTypeSubMenu />
</MenuList>
</Menu>
);
});
ModelFilterMenu.displayName = 'ModelFilterMenu';

View File

@@ -8,8 +8,10 @@ import {
clearModelSelection,
type FilterableModelType,
selectFilteredModelType,
selectOrderBy,
selectSearchTerm,
selectSelectedModelKeys,
selectSortDirection,
setSelectedModelKey,
} from 'features/modelManagerV2/store/modelManagerV2Slice';
import { memo, useCallback, useMemo, useState } from 'react';
@@ -39,6 +41,8 @@ const ModelList = () => {
const dispatch = useAppDispatch();
const filteredModelType = useAppSelector(selectFilteredModelType);
const searchTerm = useAppSelector(selectSearchTerm);
const orderBy = useAppSelector(selectOrderBy);
const direction = useAppSelector(selectSortDirection);
const selectedModelKeys = useAppSelector(selectSelectedModelKeys);
const { t } = useTranslation();
const toast = useToast();
@@ -47,7 +51,8 @@ const ModelList = () => {
const [isDeleting, setIsDeleting] = useState(false);
const [isReidentifying, setIsReidentifying] = useState(false);
const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery();
const queryArgs = useMemo(() => ({ order_by: orderBy, direction: direction.toUpperCase() }), [orderBy, direction]);
const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(queryArgs);
const { data: missingModelsData, isLoading: isLoadingMissing } = useGetMissingModelsQuery();
const [bulkDeleteModels] = useBulkDeleteModelsMutation();
const [bulkReidentifyModels] = useBulkReidentifyModelsMutation();

View File

@@ -6,8 +6,8 @@ import type { ChangeEventHandler } from 'react';
import { memo, useCallback } from 'react';
import { PiXBold } from 'react-icons/pi';
import { ModelFilterMenu } from './ModelFilterMenu';
import { ModelListBulkActions } from './ModelListBulkActions';
import { ModelTypeFilter } from './ModelTypeFilter';
export const ModelListNavigation = memo(() => {
const dispatch = useAppDispatch();
@@ -50,7 +50,7 @@ export const ModelListNavigation = memo(() => {
</InputGroup>
</Flex>
<Flex shrink={0}>
<ModelTypeFilter />
<ModelFilterMenu />
</Flex>
</Flex>
<ModelListBulkActions />

View File

@@ -1,78 +0,0 @@
import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { ModelCategoryData } from 'features/modelManagerV2/models';
import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models';
import type { ModelCategoryType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFunnelBold, PiWarningBold } from 'react-icons/pi';
const isModelCategoryType = (type: string): type is ModelCategoryType => {
return type in MODEL_CATEGORIES;
};
/**
 * Dropdown filter for the model list. Offers "all models", a special
 * "missing files" pseudo-filter, and one entry per model category.
 */
export const ModelTypeFilter = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const filteredModelType = useAppSelector(selectFilteredModelType);

  // Clear the filter so the full model list is shown again.
  const clearModelType = useCallback(() => {
    dispatch(setFilteredModelType(null));
  }, [dispatch]);

  // 'missing' is a pseudo-category surfacing models whose files are absent.
  const setMissingFilter = useCallback(() => {
    dispatch(setFilteredModelType('missing'));
  }, [dispatch]);

  // Button label mirrors the active filter; falls back to "all models".
  let buttonLabel = t('modelManager.allModels');
  if (filteredModelType === 'missing') {
    buttonLabel = t('modelManager.missingFiles');
  } else if (filteredModelType && isModelCategoryType(filteredModelType)) {
    buttonLabel = t(MODEL_CATEGORIES[filteredModelType].i18nKey);
  }

  return (
    <Menu placement="bottom-end">
      <MenuButton as={Button} size="sm" rightIcon={<PiFunnelBold />}>
        {buttonLabel}
      </MenuButton>
      <MenuList>
        <MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
        <MenuItem
          onClick={setMissingFilter}
          bg={filteredModelType === 'missing' ? 'base.700' : 'transparent'}
          color="warning.300"
        >
          <Flex alignItems="center" gap={2}>
            <PiWarningBold />
            {t('modelManager.missingFiles')}
          </Flex>
        </MenuItem>
        {MODEL_CATEGORIES_AS_LIST.map((data) => (
          <ModelMenuItem key={data.category} data={data} />
        ))}
      </MenuList>
    </Menu>
  );
});
ModelTypeFilter.displayName = 'ModelTypeFilter';
const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const filteredModelType = useAppSelector(selectFilteredModelType);
const onClick = useCallback(() => {
dispatch(setFilteredModelType(data.category));
}, [data.category, dispatch]);
return (
<MenuItem bg={filteredModelType === data.category ? 'base.700' : 'transparent'} onClick={onClick}>
{t(data.i18nKey)}
</MenuItem>
);
});
ModelMenuItem.displayName = 'ModelMenuItem';

View File

@@ -260,17 +260,7 @@ export const getDenoisingStartAndEnd = (state: RootState): { denoising_start: nu
};
}
}
case 'anima': {
// Anima uses a fixed shift=3.0 which makes the sigma schedule highly non-linear.
// Without rescaling, most of the visual 'change' is concentrated in the high denoise
// strength range (>0.8). The exponent 0.2 spreads the effective range more evenly,
// matching the approach used for FLUX and SD3.
const animaExponent = optimizedDenoisingEnabled ? 0.2 : 1;
return {
denoising_start: 1 - denoisingStrength ** animaExponent,
denoising_end: 1,
};
}
case 'anima':
case 'sd-1':
case 'sd-2':
case 'cogview4':

View File

@@ -4,13 +4,14 @@ import { useStore } from '@nanostores/react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { $isConnected, $lastProgressEvent } from 'services/events/stores';
import { $isConnected, $lastProgressEvent, $loadingModelsCount } from 'services/events/stores';
const ProgressBar = (props: ProgressProps) => {
const { t } = useTranslation();
const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useStore($isConnected);
const lastProgressEvent = useStore($lastProgressEvent);
const loadingModelsCount = useStore($loadingModelsCount);
const value = useMemo(() => {
if (!lastProgressEvent) {
return 0;
@@ -23,6 +24,10 @@ const ProgressBar = (props: ProgressProps) => {
return false;
}
if (loadingModelsCount > 0) {
return true;
}
if (!queueStatus?.queue.in_progress) {
return false;
}
@@ -40,7 +45,7 @@ const ProgressBar = (props: ProgressProps) => {
}
return false;
}, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress]);
}, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress, loadingModelsCount]);
return (
<Progress

View File

@@ -111,9 +111,13 @@ type DeleteOrphanedModelsResponse = {
errors: Record<string, string>;
};
type GetModelConfigsArg = {
order_by?: string;
direction?: string;
} | void;
const modelConfigsAdapter = createEntityAdapter<AnyModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions);
@@ -338,8 +342,11 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['ModelInstalls'],
}),
getModelConfigs: build.query<EntityState<AnyModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl() }),
getModelConfigs: build.query<EntityState<AnyModelConfig, string>, GetModelConfigsArg>({
query: (arg) => {
const queryStr = arg ? `?${queryString.stringify(arg)}` : '';
return { url: buildModelsUrl(queryStr) };
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }];
if (result) {
@@ -498,5 +505,5 @@ export const {
useDeleteOrphanedModelsMutation,
} = modelsApi;
export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select();
export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(undefined);
export const selectMissingModelsQuery = modelsApi.endpoints.getMissingModels.select();

View File

@@ -22997,6 +22997,12 @@ export type components = {
*/
config_path?: string | null;
};
/**
* ModelRecordOrderBy
* @description The order in which to return model summaries.
* @enum {string}
*/
ModelRecordOrderBy: "default" | "type" | "base" | "name" | "format" | "size" | "created_at" | "updated_at" | "path";
/** ModelRelationshipBatchRequest */
ModelRelationshipBatchRequest: {
/**
@@ -31525,6 +31531,10 @@ export interface operations {
model_name?: string | null;
/** @description Exact match on the format of the model (e.g. 'diffusers') */
model_format?: components["schemas"]["ModelFormat"] | null;
/** @description The field to order by */
order_by?: components["schemas"]["ModelRecordOrderBy"];
/** @description The direction to order by */
direction?: components["schemas"]["SQLiteDirection"];
};
header?: never;
path?: never;

View File

@@ -43,7 +43,7 @@ import type { ClientToServerEvents, ServerToClientEvents } from 'services/events
import type { Socket } from 'socket.io-client';
import type { JsonObject } from 'type-fest';
import { $lastProgressEvent } from './stores';
import { $lastProgressEvent, $loadingModelsCount } from './stores';
const log = logger('events');
@@ -73,12 +73,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
socket.emit('subscribe_queue', { queue_id: 'default' });
socket.emit('subscribe_bulk_download', { bulk_download_id: 'default' });
$lastProgressEvent.set(null);
$loadingModelsCount.set(0);
});
socket.on('connect_error', (error) => {
log.debug('Connect error');
setIsConnected(false);
$lastProgressEvent.set(null);
$loadingModelsCount.set(0);
if (error && error.message) {
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
if (data === 'ERR_UNAUTHENTICATED') {
@@ -95,6 +97,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
socket.on('disconnect', () => {
log.debug('Disconnected');
$lastProgressEvent.set(null);
$loadingModelsCount.set(0);
setIsConnected(false);
});
@@ -183,6 +186,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
const message = `Model load started: ${name} (${extras.join(', ')})`;
log.debug({ data }, message);
$loadingModelsCount.set($loadingModelsCount.get() + 1);
});
socket.on('model_load_complete', (data) => {
@@ -197,6 +201,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
const message = `Model load complete: ${name} (${extras.join(', ')})`;
log.debug({ data }, message);
$loadingModelsCount.set(Math.max(0, $loadingModelsCount.get() - 1));
});
socket.on('download_started', (data) => {

View File

@@ -6,6 +6,7 @@ import type { AppSocket } from 'services/events/types';
export const $socket = atom<AppSocket | null>(null);
export const $isConnected = atom<boolean>(false);
export const $lastProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
export const $loadingModelsCount = atom<number>(0);
export const $lastProgressMessage = computed($lastProgressEvent, (val) => {
if (!val) {

View File

@@ -11,11 +11,13 @@ from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordOrderBy,
ModelRecordServiceBase,
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config
from invokeai.backend.model_manager.configs.main import (
@@ -364,6 +366,73 @@ def test_filter_2(store: ModelRecordServiceBase):
assert len(matches) == 1
def test_search_by_attr_sorting(store: ModelRecordServiceSQL):
    """Verify search_by_attr honors order_by/direction for the Name and Size columns."""
    # Three models with distinct names (alpha < beta < gamma) and distinct
    # sizes (alpha=1000, beta=2000, gamma=500) so name- and size-order differ.
    config1 = Main_Diffusers_SD1_Config(
        path="/tmp/config1",
        name="alpha",
        base=BaseModelType.StableDiffusion1,
        type=ModelType.Main,
        hash="CONFIG1HASH",
        file_size=1000,
        source="test/source/",
        source_type=ModelSourceType.Path,
        variant=ModelVariantType.Normal,
        prediction_type=SchedulerPredictionType.Epsilon,
        repo_variant=ModelRepoVariant.Default,
    )
    config2 = Main_Diffusers_SD2_Config(
        path="/tmp/config2",
        name="beta",
        base=BaseModelType.StableDiffusion2,
        type=ModelType.Main,
        hash="CONFIG2HASH",
        file_size=2000,
        source="test/source/",
        source_type=ModelSourceType.Path,
        variant=ModelVariantType.Normal,
        prediction_type=SchedulerPredictionType.Epsilon,
        repo_variant=ModelRepoVariant.Default,
    )
    config3 = VAE_Diffusers_SD1_Config(
        path="/tmp/config3",
        name="gamma",
        base=BaseModelType.StableDiffusion1,
        type=ModelType.VAE,
        hash="CONFIG3HASH",
        file_size=500,
        source="test/source/",
        source_type=ModelSourceType.Path,
        repo_variant=ModelRepoVariant.Default,
    )
    for config in (config1, config2, config3):
        store.add_model(config)

    def names_ordered_by(order_by: ModelRecordOrderBy, direction: SQLiteDirection) -> list[str]:
        # Helper: run the search and project the result down to model names.
        return [match.name for match in store.search_by_attr(order_by=order_by, direction=direction)]

    # Name ordering, both directions.
    assert names_ordered_by(ModelRecordOrderBy.Name, SQLiteDirection.Ascending) == ["alpha", "beta", "gamma"]
    assert names_ordered_by(ModelRecordOrderBy.Name, SQLiteDirection.Descending) == ["gamma", "beta", "alpha"]
    # Size ordering, both directions (gamma=500, alpha=1000, beta=2000).
    assert names_ordered_by(ModelRecordOrderBy.Size, SQLiteDirection.Ascending) == ["gamma", "alpha", "beta"]
    assert names_ordered_by(ModelRecordOrderBy.Size, SQLiteDirection.Descending) == ["beta", "alpha", "gamma"]
def test_model_record_changes():
# This test guards against some unexpected behaviours from pydantic's union evaluation. See #6035
changes = ModelRecordChanges.model_validate({"default_settings": {"preprocessor": "value"}})