Compare commits

..

5 Commits

Author SHA1 Message Date
Mary Hipp Rogers
649596cec5 fix 1:1 ratio (#8127)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-06-25 19:39:56 -04:00
psychedelicious
45aa84c01a feat: add user_label to FieldIdentifier (#8126)
Co-authored-by: Mary Hipp Rogers <maryhipp@gmail.com>
2025-06-25 09:48:15 -04:00
Mary Hipp Rogers
064d5787c9 Flux Kontext UI support (#8111)
* add support for flux-kontext models in nodes

* flux kontext in canvas

* add aspect ratio support

* lint

* restore aspect ratio logic

* more linting

* typegen

* fix typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-06-25 09:46:58 -04:00
psychedelicious
d81b23adff fix(nodes): ensure each invocation overrides _original_model_fields with own field data 2025-06-19 09:57:11 -04:00
psychedelicious
c72480fd1b fix: opencv dependency conflict (#8095)
* build: prevent `opencv-python` from being installed

Fixes this error: `AttributeError: module 'cv2.ximgproc' has no attribute 'thinning'`

`opencv-contrib-python` supersedes `opencv-python`, providing the same API + additional features. The two packages should not be installed at the same time to avoid conflicts and/or errors.

The `invisible-watermark` package requires `opencv-python`, but we require the contrib variant.

This change updates `pyproject.toml` to prevent `opencv-python` from ever being installed using a `uv` features called dependency overrides.

* feat(ui): data viewer supports disabling wrap

* feat(api): list _all_ pkgs in app deps endpoint

* chore(ui): typegen

* feat(ui): update about modal to display new full deps list

* chore: uv lock
2025-06-10 08:34:00 -04:00
56 changed files with 459 additions and 738 deletions

View File

@@ -99,9 +99,7 @@ async def upload_image(
raise HTTPException(status_code=400, detail="Invalid resize_to format or size")
try:
# heuristic_resize_fast expects an RGB or RGBA image
pil_rgba = pil_image.convert("RGBA")
np_image = pil_to_np(pil_rgba)
np_image = pil_to_np(pil_image)
np_image = heuristic_resize_fast(np_image, (resize_dims.width, resize_dims.height))
pil_image = np_to_pil(np_image)
except Exception:

View File

@@ -158,7 +158,7 @@ web_root_path = Path(list(web_dir.__path__)[0])
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here

View File

@@ -499,7 +499,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
logger.warn(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
field.json_schema_extra.pop("ui_type")
return None
@@ -582,6 +582,8 @@ def invocation(
fields: dict[str, tuple[Any, FieldInfo]] = {}
original_model_fields: dict[str, OriginalModelField] = {}
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
@@ -589,7 +591,7 @@ def invocation(
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
@@ -613,7 +615,7 @@ def invocation(
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
uiconfig["version"] = version
else:
logger.warning(f'No version specified for node "{invocation_type}", using "1.0.0"')
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)
@@ -676,6 +678,7 @@ def invocation(
docstring = cls.__doc__
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) # type: ignore
new_class.__doc__ = docstring
new_class._original_model_fields = original_model_fields
InvocationRegistry.register_invocation(new_class)

View File

@@ -114,13 +114,6 @@ class CompelInvocation(BaseInvocation):
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info
c = c.detach().to("cpu")
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
@@ -229,10 +222,7 @@ class SDXLPromptInvocationBase:
else:
c_pooled = None
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info

View File

@@ -64,6 +64,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
FluxKontextModel = "FluxKontextModelField"
# endregion
# region Misc Field Types
@@ -437,7 +438,7 @@ class WithWorkflow:
workflow = None
def __init_subclass__(cls) -> None:
logger.warning(
logger.warn(
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
)
super().__init_subclass__()
@@ -578,7 +579,7 @@ def InputField(
if default_factory is not _Unset and default_factory is not None:
default = default_factory()
logger.warning('"default_factory" is not supported, calling it now to set "default"')
logger.warn('"default_factory" is not supported, calling it now to set "default"')
# These are the args we may wish pass to the pydantic `Field()` function
field_args = {

View File

@@ -24,6 +24,7 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
@@ -92,7 +93,7 @@ class InvokeAIAppConfig(BaseSettings):
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -175,7 +176,7 @@ class InvokeAIAppConfig(BaseSettings):
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
# DEVICE
device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$")
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION

View File

@@ -196,13 +196,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql

View File

@@ -78,7 +78,7 @@ class ImageService(ImageServiceABC):
board_id=board_id, image_name=image_name
)
except Exception as e:
self.__invoker.services.logger.warning(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)

View File

@@ -148,7 +148,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _clear_pending_jobs(self) -> None:
for job in self.list_jobs():
if not job.in_terminal_state:
self._logger.warning(f"Cancelling job {job.id}")
self._logger.warning("Cancelling job {job.id}")
self.cancel_job(job)
while True:
try:

View File

@@ -1,4 +1,3 @@
import gc
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
@@ -440,12 +439,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
# GC-ing here can reduce peak memory usage of the invoke process by freeing allocated memory blocks.
# Most queue items take seconds to execute, so the relative cost of a GC is very small.
# Python will never cede allocated memory back to the OS, so anything we can do to reduce the peak
# allocation is well worth it.
gc.collect()
self._invoker.services.logger.info(
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
)

View File

@@ -205,6 +205,7 @@ class FieldIdentifier(BaseModel):
kind: Literal["input", "output"] = Field(description="The kind of field")
node_id: str = Field(description="The ID of the node")
field_name: str = Field(description="The name of the field")
user_label: str | None = Field(description="The user label of the field, if any")
class SessionQueueItemWithoutGraph(BaseModel):

View File

@@ -104,7 +104,11 @@ class SqliteSessionQueue(SessionQueueBase):
return cast(Union[int, None], cursor.fetchone()[0]) or 0
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
cursor = self._conn.cursor()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
@@ -114,12 +118,8 @@ class SqliteSessionQueue(SessionQueueBase):
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
requested_count = calc_session_count(batch)
values_to_insert = prepare_values_to_insert(
queue_id=queue_id,
batch=batch,
priority=priority,
@@ -127,16 +127,19 @@ class SqliteSessionQueue(SessionQueueBase):
)
enqueued_count = len(values_to_insert)
with self._conn:
cursor = self._conn.cursor()
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,

View File

@@ -42,5 +42,4 @@ IP-Adapters:
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
- [InvokeAI/ip-adapter-plus_sdxl_vit-h](https://huggingface.co/InvokeAI/ip-adapter-plus_sdxl_vit-h)
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)

View File

@@ -296,7 +296,7 @@ class LoRAConfigBase(ABC, BaseModel):
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd, mod.metadata())
value = flux_format_from_state_dict(sd)
mod.cache[key] = value
return value

View File

@@ -20,10 +20,6 @@ from invokeai.backend.model_manager.taxonomy import (
ModelType,
SubModelType,
)
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
lora_model_from_flux_aitoolkit_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
@@ -96,8 +92,6 @@ class LoRALoader(ModelLoader):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
else:

View File

@@ -297,15 +297,6 @@ ip_adapter_sdxl = StarterModel(
dependencies=[ip_adapter_sdxl_image_encoder],
previous_names=["IP Adapter SDXL"],
)
ip_adapter_plus_sdxl = StarterModel(
name="Precise Reference (IP Adapter Plus ViT-H)",
base=BaseModelType.StableDiffusionXL,
source="https://huggingface.co/InvokeAI/ip-adapter-plus_sdxl_vit-h/resolve/main/ip-adapter-plus_sdxl_vit-h.safetensors",
description="References images with a higher degree of precision.",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sdxl_image_encoder],
previous_names=["IP Adapter Plus SDXL"],
)
ip_adapter_flux = StarterModel(
name="Standard Reference (XLabs FLUX IP-Adapter v2)",
base=BaseModelType.Flux,
@@ -681,7 +672,6 @@ STARTER_MODELS: list[StarterModel] = [
ip_adapter_plus_sd1,
ip_adapter_plus_face_sd1,
ip_adapter_sdxl,
ip_adapter_plus_sdxl,
ip_adapter_flux,
qr_code_cnet_sd1,
qr_code_cnet_sdxl,
@@ -754,7 +744,6 @@ sdxl_bundle: list[StarterModel] = [
juggernaut_sdxl,
sdxl_fp16_vae_fix,
ip_adapter_sdxl,
ip_adapter_plus_sdxl,
canny_sdxl,
depth_sdxl,
softedge_sdxl,

View File

@@ -29,6 +29,7 @@ class BaseModelType(str, Enum):
Imagen3 = "imagen3"
Imagen4 = "imagen4"
ChatGPT4o = "chatgpt-4o"
FluxKontext = "flux-kontext"
class ModelType(str, Enum):
@@ -137,7 +138,6 @@ class FluxLoRAFormat(str, Enum):
Kohya = "flux.kohya"
OneTrainer = "flux.onetrainer"
Control = "flux.control"
AIToolkit = "flux.aitoolkit"
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]

View File

@@ -46,10 +46,6 @@ class ModelPatcher:
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
if len(ti_list) == 0:
yield tokenizer, TextualInversionManager(tokenizer)
return
init_tokens_count = None
new_tokens_added = None

View File

@@ -1,63 +0,0 @@
import json
from dataclasses import dataclass, field
from typing import Any
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import _group_by_layer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.util import InvokeAILogger
def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
if metadata:
try:
software = json.loads(metadata.get("software", "{}"))
except json.JSONDecodeError:
return False
return software.get("name") == "ai-toolkit"
# metadata got lost somewhere
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
@dataclass
class GroupedStateDict:
transformer: dict[str, Any] = field(default_factory=dict)
# might also grow CLIP and T5 submodels
def _group_state_by_submodel(state_dict: dict[str, Any]) -> GroupedStateDict:
logger = InvokeAILogger.get_logger()
grouped = GroupedStateDict()
for key, value in state_dict.items():
submodel_name, param_name = key.split(".", 1)
match submodel_name:
case "diffusion_model":
grouped.transformer[param_name] = value
case _:
logger.warning(f"Unexpected submodel name: {submodel_name}")
return grouped
def _rename_peft_lora_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Renames keys from the PEFT LoRA format to the InvokeAI format."""
renamed_state_dict = {}
for key, value in state_dict.items():
renamed_key = key.replace(".lora_A.", ".lora_down.").replace(".lora_B.", ".lora_up.")
renamed_state_dict[renamed_key] = value
return renamed_state_dict
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
state_dict = _rename_peft_lora_keys(state_dict)
by_layer = _group_by_layer(state_dict)
by_model = _group_state_by_submodel(by_layer)
layers: dict[str, BaseLayerPatch] = {}
for layer_key, layer_state_dict in by_model.transformer.items():
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
return ModelPatchRaw(layers=layers)

View File

@@ -1,7 +1,4 @@
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
@@ -14,7 +11,7 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
)
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
def flux_format_from_state_dict(state_dict):
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
@@ -23,7 +20,5 @@ def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None)
return FluxLoRAFormat.Diffusers
elif is_state_dict_likely_flux_control(state_dict):
return FluxLoRAFormat.Control
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
return FluxLoRAFormat.AIToolkit
else:
return None

View File

@@ -68,7 +68,7 @@
"cmdk": "^1.1.1",
"compare-versions": "^6.1.1",
"filesize": "^10.1.6",
"fracturedjsonjs": "^4.1.0",
"fracturedjsonjs": "^4.0.2",
"framer-motion": "^11.10.0",
"i18next": "^25.0.1",
"i18next-http-backend": "^3.0.2",

View File

@@ -54,8 +54,8 @@ dependencies:
specifier: ^10.1.6
version: 10.1.6
fracturedjsonjs:
specifier: ^4.1.0
version: 4.1.0
specifier: ^4.0.2
version: 4.0.2
framer-motion:
specifier: ^11.10.0
version: 11.10.0(react-dom@18.3.1)(react@18.3.1)
@@ -5280,8 +5280,8 @@ packages:
signal-exit: 4.1.0
dev: true
/fracturedjsonjs@4.1.0:
resolution: {integrity: sha512-qy6LPA8OOiiyRHt5/sNKDayD7h5r3uHmHxSOLbBsgtU/hkt5vOVWOR51MdfDbeCNfj7k/dKCRbXYm8FBAJcgWQ==}
/fracturedjsonjs@4.0.2:
resolution: {integrity: sha512-+vGJH9wK0EEhbbn50V2sOebLRaar1VL3EXr02kxchIwpkhQk0ItrPjIOtYPYuU9hNFpVzxjrPgzjtMJih+ae4A==}
dev: false
/framer-motion@10.18.0(react-dom@18.3.1)(react@18.3.1):

View File

@@ -1147,6 +1147,7 @@
"modelIncompatibleScaledBboxWidth": "Scaled bbox width is {{width}} but {{model}} requires multiple of {{multiple}}",
"modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}",
"fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time",
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with Flux Kontext",
"canvasIsFiltering": "Canvas is busy (filtering)",
"canvasIsTransforming": "Canvas is busy (transforming)",
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
@@ -1337,6 +1338,7 @@
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"fluxKontextIncompatibleGenerationMode": "Flux Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished"

View File

@@ -10,6 +10,7 @@ import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatch
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
@@ -59,6 +60,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return await buildImagen4Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
case 'flux-kontext':
return await buildFluxKontextGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}

View File

@@ -29,6 +29,7 @@ import type {
import {
initialChatGPT4oReferenceImage,
initialControlNet,
initialFluxKontextReferenceImage,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
@@ -87,6 +88,12 @@ export const selectDefaultRefImageConfig = createSelector(
return referenceImage;
}
if (selectedMainModel?.base === 'flux-kontext') {
const referenceImage = deepClone(initialFluxKontextReferenceImage);
referenceImage.model = zModelIdentifierField.parse(selectedMainModel);
return referenceImage;
}
const { data } = query;
let model: IPAdapterModelConfig | null = null;
if (data) {

View File

@@ -2,10 +2,12 @@ import { useAppSelector } from 'app/store/storeHooks';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsFluxKontext,
selectIsImagen3,
selectIsImagen4,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import { selectActiveReferenceImageEntities } from 'features/controlLayers/store/selectors';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
import type { Equals } from 'tsafe';
@@ -17,23 +19,28 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isImagen3 = useAppSelector(selectIsImagen3);
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isFluxKontext = useAppSelector(selectIsFluxKontext);
const activeReferenceImageEntities = useAppSelector(selectActiveReferenceImageEntities);
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
if (isFluxKontext) {
return activeReferenceImageEntities.length === 0;
}
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'inpaint_mask':
return !isImagen3 && !isImagen4 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'raster_layer':
return !isImagen3 && !isImagen4 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isChatGPT4o]);
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isFluxKontext, isChatGPT4o, activeReferenceImageEntities]);
return isEntityTypeEnabled;
};

View File

@@ -69,7 +69,13 @@ import type {
IPMethodV2,
T2IAdapterConfig,
} from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagenAspectRatioID, isRenderableEntity } from './types';
import {
getEntityIdentifier,
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isImagenAspectRatioID,
isRenderableEntity,
} from './types';
import {
converters,
getControlLayerState,
@@ -81,6 +87,7 @@ import {
initialChatGPT4oReferenceImage,
initialControlLoRA,
initialControlNet,
initialFluxKontextReferenceImage,
initialFLUXRedux,
initialIPAdapter,
initialT2IAdapter,
@@ -686,6 +693,16 @@ export const canvasSlice = createSlice({
return;
}
if (entity.ipAdapter.model.base === 'flux-kontext') {
// Switching to flux-kontext
entity.ipAdapter = {
...initialFluxKontextReferenceImage,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.model.type === 'flux_redux') {
// Switching to flux_redux
entity.ipAdapter = {
@@ -1322,6 +1339,31 @@ export const canvasSlice = createSlice({
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else if (state.bbox.modelBase === 'flux-kontext' && isFluxKontextAspectRatioID(id)) {
if (id === '3:4') {
state.bbox.rect.width = 880;
state.bbox.rect.height = 1184;
} else if (id === '4:3') {
state.bbox.rect.width = 1184;
state.bbox.rect.height = 880;
} else if (id === '9:16') {
state.bbox.rect.width = 752;
state.bbox.rect.height = 1392;
} else if (id === '16:9') {
state.bbox.rect.width = 1392;
state.bbox.rect.height = 752;
} else if (id === '21:9') {
state.bbox.rect.width = 1568;
state.bbox.rect.height = 672;
} else if (id === '9:21') {
state.bbox.rect.width = 672;
state.bbox.rect.height = 1568;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;

View File

@@ -383,6 +383,7 @@ export const selectIsCogView4 = createParamsSelector((params) => params.model?.b
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4');
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectIsFluxKontext = createParamsSelector((params) => params.model?.base === 'flux-kontext');
export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);

View File

@@ -258,6 +258,13 @@ const zChatGPT4oReferenceImageConfig = z.object({
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
const zCanvasEntityBase = z.object({
id: zId,
name: zName,
@@ -268,7 +275,12 @@ const zCanvasEntityBase = z.object({
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
type: z.literal('reference_image'),
// This should be named `referenceImage` but we need to keep it as `ipAdapter` for backwards compatibility
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig, zChatGPT4oReferenceImageConfig]),
ipAdapter: z.discriminatedUnion('type', [
zIPAdapterConfig,
zFLUXReduxConfig,
zChatGPT4oReferenceImageConfig,
zFluxKontextReferenceImageConfig,
]),
});
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
@@ -280,6 +292,9 @@ export const isFLUXReduxConfig = (config: CanvasReferenceImageState['ipAdapter']
export const isChatGPT4oReferenceImageConfig = (
config: CanvasReferenceImageState['ipAdapter']
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
export const isFluxKontextReferenceImageConfig = (
config: CanvasReferenceImageState['ipAdapter']
): config is FluxKontextReferenceImageConfig => config.type === 'flux_kontext_reference_image';
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
export type FillStyle = z.infer<typeof zFillStyle>;
@@ -406,7 +421,7 @@ export type StagingAreaImage = {
offsetY: number;
};
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zAspectRatioID = z.enum(['Free', '21:9', '9:21', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
export const isImagenAspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
@@ -416,6 +431,10 @@ export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
zChatGPT4oAspectRatioID.safeParse(v).success;
export const zFluxKontextAspectRatioID = z.enum(['21:9', '4:3', '1:1', '3:4', '9:21', '16:9', '9:16']);
export const isFluxKontextAspectRatioID = (v: unknown): v is z.infer<typeof zFluxKontextAspectRatioID> =>
zFluxKontextAspectRatioID.safeParse(v).success;
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;

View File

@@ -10,6 +10,7 @@ import type {
ChatGPT4oReferenceImageConfig,
ControlLoRAConfig,
ControlNetConfig,
FluxKontextReferenceImageConfig,
FLUXReduxConfig,
ImageWithDims,
IPAdapterConfig,
@@ -83,6 +84,11 @@ export const initialChatGPT4oReferenceImage: ChatGPT4oReferenceImageConfig = {
image: null,
model: null,
};
export const initialFluxKontextReferenceImage: FluxKontextReferenceImageConfig = {
type: 'flux_kontext_reference_image',
image: null,
model: null,
};
export const initialT2IAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
model: null,

View File

@@ -2,7 +2,7 @@ import type { FlexProps } from '@invoke-ai/ui-library';
import { Box, chakra, Flex, IconButton, Tooltip, useShiftModifier } from '@invoke-ai/ui-library';
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import { useClipboard } from 'common/hooks/useClipboard';
import { Formatter, TableCommaPlacement } from 'fracturedjsonjs';
import { Formatter } from 'fracturedjsonjs';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
@@ -11,8 +11,6 @@ import { useTranslation } from 'react-i18next';
import { PiCopyBold, PiDownloadSimpleBold } from 'react-icons/pi';
const formatter = new Formatter();
formatter.Options.TableCommaPlacement = TableCommaPlacement.BeforePadding;
formatter.Options.OmitTrailingWhitespace = true;
type Props = {
label: string;

View File

@@ -19,6 +19,7 @@ export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
imagen3: 'pink',
imagen4: 'pink',
'chatgpt-4o': 'pink',
'flux-kontext': 'pink',
};
const ModelBaseBadge = ({ base }: Props) => {

View File

@@ -4,6 +4,7 @@ import { FloatFieldSlider } from 'features/nodes/components/flow/nodes/Invocatio
import ChatGPT4oModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ChatGPT4oModelFieldInputComponent';
import { FloatFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatFieldCollectionInputComponent';
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
import FluxKontextModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxKontextModelFieldInputComponent';
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
@@ -50,6 +51,8 @@ import {
isFloatFieldInputTemplate,
isFloatGeneratorFieldInputInstance,
isFloatGeneratorFieldInputTemplate,
isFluxKontextModelFieldInputInstance,
isFluxKontextModelFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isFluxReduxModelFieldInputInstance,
@@ -417,6 +420,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxKontextModelFieldInputTemplate(template)) {
if (!isFluxKontextModelFieldInputInstance(field)) {
return null;
}
return <FluxKontextModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isChatGPT4oModelFieldInputTemplate(template)) {
if (!isChatGPT4oModelFieldInputInstance(field)) {
return null;

View File

@@ -0,0 +1,49 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxKontextModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
FluxKontextModelFieldInputInstance,
FluxKontextModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxKontextModels } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const FluxKontextModelFieldInputComponent = (
props: FieldComponentProps<FluxKontextModelFieldInputInstance, FluxKontextModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxKontextModels();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxKontextModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(FluxKontextModelFieldInputComponent);

View File

@@ -22,7 +22,6 @@ import { NodeFieldElementOverlay } from 'features/nodes/components/sidePanel/bui
import { useDoesWorkflowHaveUnsavedChanges } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
import {
$isInPublishFlow,
$isPublishing,
$isReadyToDoValidationRun,
$isSelectingOutputNode,
$outputNodeId,
@@ -184,14 +183,13 @@ SelectOutputNodeButton.displayName = 'SelectOutputNodeButton';
const CancelPublishButton = memo(() => {
const { t } = useTranslation();
const isPublishing = useStore($isPublishing);
const onClick = useCallback(() => {
$isInPublishFlow.set(false);
$isSelectingOutputNode.set(false);
$outputNodeId.set(null);
}, []);
return (
<Button leftIcon={<PiXBold />} onClick={onClick} isDisabled={isPublishing}>
<Button leftIcon={<PiXBold />} onClick={onClick}>
{t('common.cancel')}
</Button>
);
@@ -200,7 +198,6 @@ CancelPublishButton.displayName = 'CancelDeployButton';
const PublishWorkflowButton = memo(() => {
const { t } = useTranslation();
const isPublishing = useStore($isPublishing);
const isReadyToDoValidationRun = useStore($isReadyToDoValidationRun);
const isReadyToEnqueue = useStore($isReadyToEnqueue);
const doesWorkflowHaveUnsavedChanges = useDoesWorkflowHaveUnsavedChanges();
@@ -214,7 +211,6 @@ const PublishWorkflowButton = memo(() => {
const enqueue = useEnqueueWorkflows();
const onClick = useCallback(async () => {
$isPublishing.set(true);
const result = await withResultAsync(() => enqueue(true, true));
if (result.isErr()) {
toast({
@@ -248,30 +244,8 @@ const PublishWorkflowButton = memo(() => {
});
log.debug(parseify(result.value), 'Enqueued batch');
}
$isPublishing.set(false);
}, [enqueue, projectUrl, t]);
const isDisabled = useMemo(() => {
return (
!allowPublishWorkflows ||
!isReadyToEnqueue ||
doesWorkflowHaveUnsavedChanges ||
hasUnpublishableNodes ||
!isReadyToDoValidationRun ||
!(outputNodeId !== null && !isSelectingOutputNode) ||
isPublishing
);
}, [
allowPublishWorkflows,
doesWorkflowHaveUnsavedChanges,
hasUnpublishableNodes,
isReadyToDoValidationRun,
isReadyToEnqueue,
isSelectingOutputNode,
outputNodeId,
isPublishing,
]);
return (
<PublishTooltip
isWorkflowSaved={!doesWorkflowHaveUnsavedChanges}
@@ -281,8 +255,19 @@ const PublishWorkflowButton = memo(() => {
hasPublishableInputs={inputs.publishable.length > 0}
hasUnpublishableInputs={inputs.unpublishable.length > 0}
>
<Button leftIcon={<PiLightningFill />} isDisabled={isDisabled} onClick={onClick}>
{isPublishing ? t('workflows.builder.publishing') : t('workflows.builder.publish')}
<Button
leftIcon={<PiLightningFill />}
isDisabled={
!allowPublishWorkflows ||
!isReadyToEnqueue ||
doesWorkflowHaveUnsavedChanges ||
hasUnpublishableNodes ||
!isReadyToDoValidationRun ||
!(outputNodeId !== null && !isSelectingOutputNode)
}
onClick={onClick}
>
{t('workflows.builder.publish')}
</Button>
</PublishTooltip>
);
@@ -352,10 +337,6 @@ export const StartPublishFlowButton = memo(() => {
$isInPublishFlow.set(true);
}, []);
const isDisabled = useMemo(() => {
return !allowPublishWorkflows || !isReadyToEnqueue || doesWorkflowHaveUnsavedChanges || hasUnpublishableNodes;
}, [allowPublishWorkflows, doesWorkflowHaveUnsavedChanges, hasUnpublishableNodes, isReadyToEnqueue]);
return (
<PublishTooltip
isWorkflowSaved={!doesWorkflowHaveUnsavedChanges}
@@ -365,7 +346,15 @@ export const StartPublishFlowButton = memo(() => {
hasPublishableInputs={inputs.publishable.length > 0}
hasUnpublishableInputs={inputs.unpublishable.length > 0}
>
<Button onClick={onClick} leftIcon={<PiLightningFill />} variant="ghost" size="sm" isDisabled={isDisabled}>
<Button
onClick={onClick}
leftIcon={<PiLightningFill />}
variant="ghost"
size="sm"
isDisabled={
!allowPublishWorkflows || !isReadyToEnqueue || doesWorkflowHaveUnsavedChanges || hasUnpublishableNodes
}
>
{t('workflows.builder.publish')}
</Button>
</PublishTooltip>

View File

@@ -19,6 +19,9 @@ import { useGetBatchStatusQuery } from 'services/api/endpoints/queue';
import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
import { assert } from 'tsafe';
type FieldIdentiferWithLabel = FieldIdentifier & { label: string | null };
type FieldIdentiferWithLabelAndType = FieldIdentiferWithLabel & { type: string };
export const $isPublishing = atom(false);
export const $isInPublishFlow = atom(false);
export const $outputNodeId = atom<string | null>(null);
@@ -54,21 +57,26 @@ export const selectFieldIdentifiersWithInvocationTypes = createSelector(
selectWorkflowFormNodeFieldFieldIdentifiersDeduped,
selectNodesSlice,
(fieldIdentifiers, nodes) => {
const result: { nodeId: string; fieldName: string; type: string }[] = [];
const result: FieldIdentiferWithLabelAndType[] = [];
for (const fieldIdentifier of fieldIdentifiers) {
const node = nodes.nodes.find((node) => node.id === fieldIdentifier.nodeId);
assert(isInvocationNode(node), `Node ${fieldIdentifier.nodeId} not found`);
result.push({ nodeId: fieldIdentifier.nodeId, fieldName: fieldIdentifier.fieldName, type: node.data.type });
result.push({
nodeId: fieldIdentifier.nodeId,
fieldName: fieldIdentifier.fieldName,
type: node.data.type,
label: node.data.inputs[fieldIdentifier.fieldName]?.label ?? null,
});
}
return result;
}
);
export const getPublishInputs = (fieldIdentifiers: (FieldIdentifier & { type: string })[], templates: Templates) => {
export const getPublishInputs = (fieldIdentifiers: FieldIdentiferWithLabelAndType[], templates: Templates) => {
// Certain field types are not allowed to be input fields on a published workflow
const publishable: FieldIdentifier[] = [];
const unpublishable: FieldIdentifier[] = [];
const publishable: FieldIdentiferWithLabel[] = [];
const unpublishable: FieldIdentiferWithLabel[] = [];
for (const fieldIdentifier of fieldIdentifiers) {
const fieldTemplate = templates[fieldIdentifier.type]?.inputs[fieldIdentifier.fieldName];
if (!fieldTemplate) {
@@ -122,11 +130,13 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
'metadata_to_controlnets',
'metadata_to_ip_adapters',
'metadata_to_t2i_adapters',
'google_imagen3_generate',
'google_imagen3_edit',
'google_imagen4_generate',
'chatgpt_create_image',
'chatgpt_edit_image',
'google_imagen3_generate_image',
'google_imagen3_edit_image',
'google_imagen4_generate_image',
'chatgpt_4o_generate_image',
'chatgpt_4o_edit_image',
'flux_kontext_generate_image',
'flux_kontext_edit_image',
];
export const selectHasUnpublishableNodes = createSelector(selectNodes, (nodes) => {

View File

@@ -34,6 +34,7 @@ import type {
FieldValue,
FloatFieldValue,
FloatGeneratorFieldValue,
FluxKontextModelFieldValue,
FluxReduxModelFieldValue,
FluxVAEModelFieldValue,
ImageFieldCollectionValue,
@@ -75,6 +76,7 @@ import {
zFloatFieldCollectionValue,
zFloatFieldValue,
zFloatGeneratorFieldValue,
zFluxKontextModelFieldValue,
zFluxReduxModelFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
@@ -527,6 +529,9 @@ export const nodesSlice = createSlice({
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
},
fieldFluxKontextModelValueChanged: (state, action: FieldValueAction<FluxKontextModelFieldValue>) => {
fieldValueReducer(state, action, zFluxKontextModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
@@ -697,6 +702,7 @@ export const {
fieldImagen3ModelValueChanged,
fieldImagen4ModelValueChanged,
fieldChatGPT4oModelValueChanged,
fieldFluxKontextModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,
fieldStringGeneratorValueChanged,

View File

@@ -78,6 +78,7 @@ const zBaseModel = z.enum([
'imagen3',
'imagen4',
'chatgpt-4o',
'flux-kontext',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum([
@@ -90,6 +91,7 @@ export const zMainModelBase = z.enum([
'imagen3',
'imagen4',
'chatgpt-4o',
'flux-kontext',
]);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;

View File

@@ -260,6 +260,10 @@ const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
name: z.literal('ChatGPT4oModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxKontextModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxKontextModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
@@ -313,6 +317,7 @@ const zStatefulFieldType = z.union([
zImagen3ModelFieldType,
zImagen4ModelFieldType,
zChatGPT4oModelFieldType,
zFluxKontextModelFieldType,
zColorFieldType,
zSchedulerFieldType,
zFloatGeneratorFieldType,
@@ -354,6 +359,7 @@ const modelFieldTypeNames = [
zImagen3ModelFieldType.shape.name.value,
zImagen4ModelFieldType.shape.name.value,
zChatGPT4oModelFieldType.shape.name.value,
zFluxKontextModelFieldType.shape.name.value,
// Stateless model fields
'UNetField',
'VAEField',
@@ -1231,6 +1237,24 @@ export const isImagen4ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen4ModelFieldInputTemplate>('Imagen4ModelField');
// #endregion
// #region FluxKontextModelField
export const zFluxKontextModelFieldValue = zModelIdentifierField.optional();
const zFluxKontextModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxKontextModelFieldValue,
});
const zFluxKontextModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxKontextModelFieldType,
originalType: zFieldType.optional(),
default: zFluxKontextModelFieldValue,
});
export type FluxKontextModelFieldValue = z.infer<typeof zFluxKontextModelFieldValue>;
export type FluxKontextModelFieldInputInstance = z.infer<typeof zFluxKontextModelFieldInputInstance>;
export type FluxKontextModelFieldInputTemplate = z.infer<typeof zFluxKontextModelFieldInputTemplate>;
export const isFluxKontextModelFieldInputInstance = buildInstanceTypeGuard(zFluxKontextModelFieldInputInstance);
export const isFluxKontextModelFieldInputTemplate =
buildTemplateTypeGuard<FluxKontextModelFieldInputTemplate>('FluxKontextModelField');
// #endregion
// #region ChatGPT4oModelField
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1882,6 +1906,7 @@ export const zStatefulFieldValue = z.union([
zFluxReduxModelFieldValue,
zImagen3ModelFieldValue,
zImagen4ModelFieldValue,
zFluxKontextModelFieldValue,
zChatGPT4oModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
@@ -1976,6 +2001,7 @@ const zStatefulFieldInputTemplate = z.union([
zImagen3ModelFieldInputTemplate,
zImagen4ModelFieldInputTemplate,
zChatGPT4oModelFieldInputTemplate,
zFluxKontextModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

View File

@@ -0,0 +1,92 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isFluxKontextReferenceImageConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { ImageField } from 'features/nodes/types/common';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
CANVAS_OUTPUT_PREFIX,
getBoardField,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildFluxKontextGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.fluxKontextIncompatibleGenerationMode'));
}
log.debug({ generationMode }, 'Building Flux Kontext graph');
const model = selectMainModelConfig(state);
const canvas = selectCanvasSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const { bbox } = canvas;
const { positivePrompt } = selectPresetModifiedPrompts(state);
assert(model, 'No model found in state');
assert(model.base === 'flux-kontext', 'Model is not a Flux Kontext model');
const is_intermediate = canvasSettings.sendToCanvas;
const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);
const validRefImages = canvas.referenceImages.entities
.filter((entity) => entity.isEnabled)
.filter((entity) => isFluxKontextReferenceImageConfig(entity.ipAdapter))
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
let input_image: ImageField | undefined = undefined;
if (validRefImages[0]) {
assert(validRefImages.length === 1, 'Flux Kontext can have at most one reference image');
assert(validRefImages[0].ipAdapter.image, 'Image is required for reference image');
input_image = {
image_name: validRefImages[0].ipAdapter.image.image_name,
};
}
if (generationMode === 'txt2img') {
const g = new Graph(getPrefixedId('flux_kontext_txt2img_graph'));
const fluxKontextImage = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: input_image ? 'flux_kontext_edit_image' : 'flux_kontext_generate_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
aspect_ratio: bbox.aspectRatio.id,
use_cache: false,
is_intermediate,
board,
input_image,
prompt_upsampling: true,
});
g.upsertMetadata({
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(model),
width: bbox.rect.width,
height: bbox.rect.height,
});
return {
g,
positivePromptFieldIdentifier: { nodeId: fluxKontextImage.id, fieldName: 'positive_prompt' },
};
}
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Flux Kontext');
};

View File

@@ -36,6 +36,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
Imagen3ModelField: undefined,
Imagen4ModelField: undefined,
ChatGPT4oModelField: undefined,
FluxKontextModelField: undefined,
FloatGeneratorField: undefined,
IntegerGeneratorField: undefined,
StringGeneratorField: undefined,

View File

@@ -16,6 +16,7 @@ import type {
FloatFieldCollectionInputTemplate,
FloatFieldInputTemplate,
FloatGeneratorFieldInputTemplate,
FluxKontextModelFieldInputTemplate,
FluxMainModelFieldInputTemplate,
FluxReduxModelFieldInputTemplate,
FluxVAEModelFieldInputTemplate,
@@ -613,6 +614,20 @@ const buildImagen4ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen4Mode
};
return template;
};
const buildFluxKontextModelFieldInputTemplate: FieldInputTemplateBuilder<FluxKontextModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxKontextModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildChatGPT4oModelFieldInputTemplate: FieldInputTemplateBuilder<ChatGPT4oModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -835,6 +850,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
Imagen3ModelField: buildImagen3ModelFieldInputTemplate,
Imagen4ModelField: buildImagen4ModelFieldInputTemplate,
ChatGPT4oModelField: buildChatGPT4oModelFieldInputTemplate,
FluxKontextModelField: buildFluxKontextModelFieldInputTemplate,
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
StringGeneratorField: buildStringGeneratorFieldInputTemplate,

View File

@@ -3,12 +3,18 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectIsChatGTP4o, selectIsImagen3, selectIsImagen4 } from 'features/controlLayers/store/paramsSlice';
import {
selectIsChatGTP4o,
selectIsFluxKontext,
selectIsImagen3,
selectIsImagen4,
} from 'features/controlLayers/store/paramsSlice';
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
import {
isAspectRatioID,
zAspectRatioID,
zChatGPT4oAspectRatioID,
zFluxKontextAspectRatioID,
zImagen3AspectRatioID,
} from 'features/controlLayers/store/types';
import type { ChangeEventHandler } from 'react';
@@ -24,6 +30,7 @@ export const BboxAspectRatioSelect = memo(() => {
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isImagen4 = useAppSelector(selectIsImagen4);
const isFluxKontext = useAppSelector(selectIsFluxKontext);
const options = useMemo(() => {
// Imagen3 and ChatGPT4o have different aspect ratio options, and do not support freeform sizes
if (isImagen3 || isImagen4) {
@@ -32,9 +39,12 @@ export const BboxAspectRatioSelect = memo(() => {
if (isChatGPT4o) {
return zChatGPT4oAspectRatioID.options;
}
if (isFluxKontext) {
return zFluxKontextAspectRatioID.options;
}
// All other models
return zAspectRatioID.options;
}, [isImagen3, isChatGPT4o, isImagen4]);
}, [isImagen3, isChatGPT4o, isImagen4, isFluxKontext]);
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
(e) => {

View File

@@ -1,6 +1,7 @@
import type { AspectRatioID } from 'features/controlLayers/store/types';
export const ASPECT_RATIO_MAP: Record<Exclude<AspectRatioID, 'Free'>, { ratio: number; inverseID: AspectRatioID }> = {
'21:9': { ratio: 21 / 9, inverseID: '9:21' },
'16:9': { ratio: 16 / 9, inverseID: '9:16' },
'3:2': { ratio: 3 / 2, inverseID: '2:3' },
'4:3': { ratio: 4 / 3, inverseID: '4:3' },
@@ -8,4 +9,5 @@ export const ASPECT_RATIO_MAP: Record<Exclude<AspectRatioID, 'Free'>, { ratio: n
'3:4': { ratio: 3 / 4, inverseID: '4:3' },
'2:3': { ratio: 2 / 3, inverseID: '3:2' },
'9:16': { ratio: 9 / 16, inverseID: '16:9' },
'9:21': { ratio: 9 / 21, inverseID: '21:9' },
};

View File

@@ -1,10 +1,16 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsChatGTP4o, selectIsImagen3, selectIsImagen4 } from 'features/controlLayers/store/paramsSlice';
import {
selectIsChatGTP4o,
selectIsFluxKontext,
selectIsImagen3,
selectIsImagen4,
} from 'features/controlLayers/store/paramsSlice';
export const useIsApiModel = () => {
const isImagen3 = useAppSelector(selectIsImagen3);
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isFluxKontext = useAppSelector(selectIsFluxKontext);
return isImagen3 || isImagen4 || isChatGPT4o;
return isImagen3 || isImagen4 || isChatGPT4o || isFluxKontext;
};

View File

@@ -16,6 +16,7 @@ export const MODEL_TYPE_MAP: Record<BaseModelType, string> = {
imagen3: 'Imagen3',
imagen4: 'Imagen4',
'chatgpt-4o': 'ChatGPT 4o',
'flux-kontext': 'Flux Kontext',
};
/**
@@ -33,6 +34,7 @@ export const MODEL_TYPE_SHORT_MAP: Record<BaseModelType, string> = {
imagen3: 'Imagen3',
imagen4: 'Imagen4',
'chatgpt-4o': 'ChatGPT 4o',
'flux-kontext': 'Flux Kontext',
};
/**
@@ -83,6 +85,10 @@ export const CLIP_SKIP_MAP: Record<BaseModelType, { maxClip: number; markers: nu
maxClip: 0,
markers: [],
},
'flux-kontext': {
maxClip: 0,
markers: [],
},
};
/**
@@ -124,4 +130,4 @@ export const SCHEDULER_OPTIONS: ComboboxOption[] = [
/**
* List of base models that make API requests
*/
export const API_BASE_MODELS = ['imagen3', 'imagen4', 'chatgpt-4o'];
export const API_BASE_MODELS = ['imagen3', 'imagen4', 'chatgpt-4o', 'flux-kontext'];

View File

@@ -21,6 +21,7 @@ export const getOptimalDimension = (base?: BaseModelType | null): number => {
case 'imagen3':
case 'imagen4':
case 'chatgpt-4o':
case 'flux-kontext':
default:
return 1024;
}
@@ -81,6 +82,7 @@ export const getGridSize = (base?: BaseModelType | null): number => {
case 'sdxl':
case 'imagen3':
case 'chatgpt-4o':
case 'flux-kontext':
default:
return 8;
}

View File

@@ -14,7 +14,7 @@ import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildW
import { groupBy } from 'lodash-es';
import { useCallback } from 'react';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Batch, EnqueueBatchArg } from 'services/api/types';
import type { Batch, EnqueueBatchArg, S } from 'services/api/types';
import { assert } from 'tsafe';
const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows');
@@ -106,12 +106,13 @@ export const useEnqueueWorkflows = () => {
// Derive the input fields from the builder's selected node field elements
const fieldIdentifiers = selectFieldIdentifiersWithInvocationTypes(state);
const inputs = getPublishInputs(fieldIdentifiers, templates);
const api_input_fields = inputs.publishable.map(({ nodeId, fieldName }) => {
const api_input_fields = inputs.publishable.map(({ nodeId, fieldName, label }) => {
return {
kind: 'input',
node_id: nodeId,
field_name: fieldName,
} as const;
user_label: label,
} satisfies S['FieldIdentifier'];
});
// Derive the output fields from the builder's selected output node
@@ -126,7 +127,8 @@ export const useEnqueueWorkflows = () => {
kind: 'output',
node_id: outputNodeId,
field_name: fieldName,
} as const;
user_label: null,
} satisfies S['FieldIdentifier'];
});
assert(nodesState.id, 'Workflow without ID cannot be used for API validation run');

View File

@@ -516,6 +516,17 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
}
});
const enabledGlobalReferenceLayers = canvas.referenceImages.entities.filter(
(referenceImage) => referenceImage.isEnabled
);
// Flux Kontext only supports 1x Reference Image at a time.
const referenceImageCount = enabledGlobalReferenceLayers.length;
if (model?.base === 'flux-kontext' && referenceImageCount > 1) {
reasons.push({ content: i18n.t('parameters.invoke.fluxKontextMultipleReferenceImages') });
}
canvas.referenceImages.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {

View File

@@ -16,6 +16,7 @@ import {
isControlLayerModelConfig,
isControlLoRAModelConfig,
isControlNetModelConfig,
isFluxKontextModelConfig,
isFluxMainModelModelConfig,
isFluxReduxModelConfig,
isFluxVAEModelConfig,
@@ -85,7 +86,11 @@ export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);
export const useSigLipModels = buildModelsHook(isSigLipModelConfig);
export const useFluxReduxModels = buildModelsHook(isFluxReduxModelConfig);
export const useGlobalReferenceImageModels = buildModelsHook(
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config) || isChatGPT4oModelConfig(config)
(config) =>
isIPAdapterModelConfig(config) ||
isFluxReduxModelConfig(config) ||
isChatGPT4oModelConfig(config) ||
isFluxKontextModelConfig(config)
);
export const useRegionalReferenceImageModels = buildModelsHook(
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config)
@@ -94,6 +99,7 @@ export const useLLaVAModels = buildModelsHook(isLLaVAModelConfig);
export const useImagen3Models = buildModelsHook(isImagen3ModelConfig);
export const useImagen4Models = buildModelsHook(isImagen4ModelConfig);
export const useChatGPT4oModels = buildModelsHook(isChatGPT4oModelConfig);
export const useFluxKontextModels = buildModelsHook(isFluxKontextModelConfig);
// const buildModelsSelector =
// <T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T): Selector<RootState, T[]> =>

View File

@@ -2075,7 +2075,7 @@ export type components = {
* @description Base model type.
* @enum {string}
*/
BaseModelType: "any" | "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4" | "imagen3" | "imagen4" | "chatgpt-4o";
BaseModelType: "any" | "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4" | "imagen3" | "imagen4" | "chatgpt-4o" | "flux-kontext";
/** Batch */
Batch: {
/**
@@ -6996,6 +6996,11 @@ export type components = {
* @description The name of the field
*/
field_name: string;
/**
* User Label
* @description The user label of the field, if any
*/
user_label: string | null;
};
/**
* FieldKind
@@ -11991,7 +11996,7 @@ export type components = {
* vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
* lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
* pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
* device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
* device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
* precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
* sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
* attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -12266,10 +12271,11 @@ export type components = {
pytorch_cuda_alloc_conf?: string | null;
/**
* Device
* @description Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
* @description Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
* @default auto
* @enum {string}
*/
device?: string;
device?: "auto" | "cpu" | "cuda" | "cuda:1" | "mps";
/**
* Precision
* @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
@@ -21118,7 +21124,7 @@ export type components = {
* used, and the type will be ignored. They are included here for backwards compatibility.
* @enum {string}
*/
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "Imagen3ModelField" | "Imagen4ModelField" | "ChatGPT4oModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "Imagen3ModelField" | "Imagen4ModelField" | "ChatGPT4oModelField" | "FluxKontextModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
/** UNetField */
UNetField: {
/** @description Info to load unet submodel */

View File

@@ -240,6 +240,10 @@ export const isImagen4ModelConfig = (config: AnyModelConfig): config is ApiModel
return config.type === 'main' && config.base === 'imagen4';
};
export const isFluxKontextModelConfig = (config: AnyModelConfig): config is ApiModelConfig => {
return config.type === 'main' && config.base === 'flux-kontext';
};
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sdxl-refiner';
};

View File

@@ -1 +1 @@
__version__ = "5.15.0"
__version__ = "5.14.0"

View File

@@ -1,458 +0,0 @@
state_dict_keys = {
"diffusion_model.double_blocks.0.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.0.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.0.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.0.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.0.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.0.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.0.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.0.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.0.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.1.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.1.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.1.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.1.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.1.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.1.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.1.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.1.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.1.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.1.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.10.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.10.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.10.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.10.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.10.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.10.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.10.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.10.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.10.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.10.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.11.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.11.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.11.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.11.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.11.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.11.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.11.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.11.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.11.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.11.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.12.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.12.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.12.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.12.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.12.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.12.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.12.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.12.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.12.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.12.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.13.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.13.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.13.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.13.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.13.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.13.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.13.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.13.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.13.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.13.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.14.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.14.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.14.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.14.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.14.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.14.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.14.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.14.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.14.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.14.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.15.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.15.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.15.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.15.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.15.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.15.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.15.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.15.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.15.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.15.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.16.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.16.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.16.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.16.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.16.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.16.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.16.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.16.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.16.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.16.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.17.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.17.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.17.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.17.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.17.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.17.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.17.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.17.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.17.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.17.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.18.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.18.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.18.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.18.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.18.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.18.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.18.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.18.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.18.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.18.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.2.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.2.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.2.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.2.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.2.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.2.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.2.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.2.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.2.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.2.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.3.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.3.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.3.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.3.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.3.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.3.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.3.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.3.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.3.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.3.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.4.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.4.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.4.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.4.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.4.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.4.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.4.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.4.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.4.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.4.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.5.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.5.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.5.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.5.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.5.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.5.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.5.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.5.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.5.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.5.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.6.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.6.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.6.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.6.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.6.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.6.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.6.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.6.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.6.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.6.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.7.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.7.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.7.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.7.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.7.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.7.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.7.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.7.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.7.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.7.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.8.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.8.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.8.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.8.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.8.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.8.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.8.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.8.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.8.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.8.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.9.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.9.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.9.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.9.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.9.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.9.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.9.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.9.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.9.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.9.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.0.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.0.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.0.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.0.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.1.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.1.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.1.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.1.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.10.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.10.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.10.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.10.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.11.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.11.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.11.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.11.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.12.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.12.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.12.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.12.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.13.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.13.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.13.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.13.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.14.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.14.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.14.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.14.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.15.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.15.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.15.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.15.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.16.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.16.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.16.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.16.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.17.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.17.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.17.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.17.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.18.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.18.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.18.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.18.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.19.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.19.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.19.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.19.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.2.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.2.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.2.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.2.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.20.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.20.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.20.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.20.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.21.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.21.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.21.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.21.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.22.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.22.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.22.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.22.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.23.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.23.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.23.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.23.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.24.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.24.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.24.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.24.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.25.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.25.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.25.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.25.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.26.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.26.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.26.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.26.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.27.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.27.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.27.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.27.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.28.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.28.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.28.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.28.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.29.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.29.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.29.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.29.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.3.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.3.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.3.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.3.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.30.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.30.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.30.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.30.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.31.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.31.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.31.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.31.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.32.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.32.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.32.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.32.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.33.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.33.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.33.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.33.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.34.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.34.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.34.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.34.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.35.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.35.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.35.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.35.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.36.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.36.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.36.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.36.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.37.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.37.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.37.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.37.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.4.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.4.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.4.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.4.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.5.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.5.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.5.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.5.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.6.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.6.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.6.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.6.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.7.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.7.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.7.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.7.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.8.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.8.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.8.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.8.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.9.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.9.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.9.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.9.linear2.lora_B.weight": [3072, 16],
}

View File

@@ -1,59 +0,0 @@
import accelerate
import pytest
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
_group_state_by_submodel,
is_state_dict_likely_in_flux_aitoolkit_format,
lora_model_from_flux_aitoolkit_state_dict,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
state_dict_keys as flux_onetrainer_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_aitoolkit_format import (
state_dict_keys as flux_aitoolkit_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict
def test_is_state_dict_likely_in_flux_aitoolkit_format():
state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
assert is_state_dict_likely_in_flux_aitoolkit_format(state_dict)
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys])
def test_is_state_dict_likely_in_flux_kohya_format_false(sd_keys: dict[str, list[int]]):
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_in_flux_aitoolkit_format(state_dict)
def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
converted_state_dict = _group_state_by_submodel(state_dict).transformer
# Extract the prefixes from the converted state dict (without the lora suffixes)
converted_key_prefixes: list[str] = []
for k in converted_state_dict.keys():
k = k.replace(".lora_A.weight", "")
k = k.replace(".lora_B.weight", "")
converted_key_prefixes.append(k)
# Initialize a FLUX model on the meta device.
with accelerate.init_empty_weights():
model = Flux(params["flux-schnell"])
model_keys = set(model.state_dict().keys())
for converted_key_prefix in converted_key_prefixes:
assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), (
f"'{converted_key_prefix}' did not match any model keys."
)
def test_lora_model_from_flux_aitoolkit_state_dict():
state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
assert lora_model_from_flux_aitoolkit_state_dict(state_dict)

View File

@@ -10,7 +10,7 @@ import torch
from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
devices = ["cpu", "cuda:0", "cuda:1", "cuda:2", "mps"]
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]