Compare commits

..

15 Commits

Author SHA1 Message Date
Ryan Dick
3ed6e65a6e Enable LoRAPatcher.apply_smart_lora_patches(...) throughout the stack. 2024-12-12 22:41:50 +00:00
Ryan Dick
52c9646f84 (minor) Rename num_layers -> num_loras in unit tests. 2024-12-12 22:41:50 +00:00
Ryan Dick
7662f0522b Add test_apply_smart_lora_patches_to_partially_loaded_model(...). 2024-12-12 22:41:50 +00:00
Ryan Dick
e50fe69839 Add LoRAPatcher.smart_apply_lora_patches() 2024-12-12 22:41:50 +00:00
Ryan Dick
5a9f884620 Refactor LoRAPatcher slightly in preparation for a 'smart' patcher. 2024-12-12 22:41:46 +00:00
Ryan Dick
edc72d1739 Fix LoRAPatcher.apply_lora_wrapper_patches(...) 2024-12-12 22:33:07 +00:00
Ryan Dick
23f521dc7c Finish consolidating LoRA sidecar wrapper implementations. 2024-12-12 22:33:07 +00:00
Ryan Dick
3d6b93efdd Begin to consolidate the LoRA sidecar and LoRA layer wrapper implementations. 2024-12-12 22:33:07 +00:00
Ryan Dick
3f28d3afad Fix bias handling in LoRAModuleWrapper and add unit test that checks that all LoRA patching methods produce the same outputs. 2024-12-12 22:33:07 +00:00
Ryan Dick
9353bfbdd6 Add LoRA wrapper patching to LoRAPatcher. 2024-12-12 22:33:07 +00:00
Ryan Dick
93f2bc6118 Add LoRA wrapper layer. 2024-12-12 22:33:07 +00:00
Ryan Dick
9019026d6d Fixes to get FLUX Control LoRA working. 2024-12-12 00:19:39 +00:00
Brandon Rising
c195b326ec Lots of updates centered around using the lora patcher rather than changing the modules in the transformer model 2024-12-11 14:14:50 -05:00
Brandon Rising
2f460d2a45 Support bnb quantized nf4 flux models, Use controlnet vae, only support 1 structural lora per transformer. various other refractors and bugfixes 2024-12-10 03:26:29 -05:00
Brandon Rising
4473cba512 Initial setup for flux tools control loras 2024-12-09 16:01:29 -05:00
133 changed files with 2090 additions and 2352 deletions

View File

@@ -59,32 +59,11 @@ logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
# We may change the port if the default is in use, this global variable is used to store the port so that we can log
# the correct port when the server starts in the lifespan handler.
port = app_config.port
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
# Log the server address when it starts - in case the network log level is not high enough to see the startup log
proto = "https" if app_config.ssl_certfile else "http"
msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
# Logging this way ignores the logger's log level and _always_ logs the message
record = logger.makeRecord(
name=logger.name,
level=logging.INFO,
fn="",
lno=0,
msg=msg,
args=(),
exc_info=None,
)
logger.handle(record)
yield
# Shut down threads
ApiDependencies.shutdown()
@@ -227,7 +206,6 @@ def invoke_api() -> None:
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
global port
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")
@@ -239,17 +217,18 @@ def invoke_api() -> None:
host=app_config.host,
port=port,
loop="asyncio",
log_level=app_config.log_level_network,
log_level=app_config.log_level,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
uvicorn_logger.handlers.clear()
for hdlr in logger.handlers:
uvicorn_logger.addHandler(hdlr)
for logname in ["uvicorn.access", "uvicorn"]:
log = InvokeAILogger.get_logger(logname)
log.handlers.clear()
for ch in logger.handlers:
log.addHandler(ch)
loop.run_until_complete(server.serve())

View File

@@ -19,9 +19,9 @@ from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
@@ -66,10 +66,10 @@ class CompelInvocation(BaseInvocation):
tokenizer_info = context.models.load(self.clip.tokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
@@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LayerPatcher.apply_model_patches(
LoRAPatcher.apply_smart_lora_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -162,11 +163,11 @@ class SDXLPromptInvocationBase:
c_pooled = None
return c, c_pooled
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(lora.lora)
lora_model = lora_info.model
assert isinstance(lora_model, ModelPatchRaw)
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
del lora_info
return
@@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LayerPatcher.apply_model_patches(
LoRAPatcher.apply_smart_lora_patches(
text_encoder,
patches=_lora_loader(),
prefix=lora_prefix,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.

View File

@@ -6,6 +6,7 @@ from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import VAEField
@@ -28,7 +29,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32, ui_order=4)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=4,
)
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L":

View File

@@ -7,6 +7,7 @@ from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
@@ -75,7 +76,11 @@ class CreateGradientMaskInvocation(BaseInvocation):
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32, ui_order=9)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:

View File

@@ -37,10 +37,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -987,10 +987,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
@@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
LayerPatcher.apply_model_patches(
LoRAPatcher.apply_smart_lora_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
dtype=unet.dtype,
cached_weights=cached_weights,
),
):

View File

@@ -56,7 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
ControlLoRAModel = "ControlLoRAModelField"
StructuralLoRAModel = "StructuralLoRAModelField"
# endregion
# region Misc Field Types
@@ -144,7 +144,7 @@ class FieldDescriptions:
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
control_lora_model = "Control LoRA model to load"
structural_lora_model = "Structural LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"

View File

@@ -1,49 +0,0 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_control_lora_loader_output")
class FluxControlLoRALoaderOutput(BaseInvocationOutput):
"""Flux Control LoRA Loader Output"""
control_lora: ControlLoRAField = OutputField(
title="Flux Control LoRA", description="Control LoRAs to apply on model loading", default=None
)
@invocation(
"flux_control_lora_loader",
title="Flux Control LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxControlLoRALoaderInvocation(BaseInvocation):
"""LoRA model and Image to use with FLUX transformer generation."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
)
image: ImageField = InputField(description="The image to encode.")
weight: float = InputField(description="The weight of the LoRA.", default=1.0)
def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
if not context.models.exists(self.lora.key):
raise ValueError(f"Unknown lora: {self.lora.key}!")
return FluxControlLoRALoaderOutput(
control_lora=ControlLoRAField(
lora=self.lora,
img=self.image,
weight=self.weight,
)
)

View File

@@ -1,15 +1,15 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple, Union
import einops
import numpy as np
import numpy.typing as npt
import torch
import torchvision.transforms as tv_transforms
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
@@ -23,9 +23,8 @@ from invokeai.app.invocations.fields import (
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -46,11 +45,13 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -92,9 +93,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
title="Transformer",
)
control_lora: Optional[ControlLoRAField] = InputField(
description=FieldDescriptions.control_lora_model, input=Input.Connection, title="Control LoRA", default=None
)
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
@@ -200,7 +198,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
)
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
timesteps = get_schedule(
@@ -240,12 +238,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if len(timesteps) <= 1:
return x
if is_schnell and self.control_lora:
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
img_cond = self._prep_structural_control_img_cond(context)
inpaint_mask = self._prep_inpaint_mask(context, x)
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
@@ -253,7 +245,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
img_cond = pack(img_cond) if img_cond is not None else None
noise = pack(noise)
x = pack(x)
@@ -297,6 +288,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
device=x.device,
)
img_cond = None
if struct_lora := self.transformer.structural_lora:
# What should we do when we have multiple of these?
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a strutural lora")
ae_info = context.models.load(self.controlnet_vae.vae)
img = context.images.get_pil(struct_lora.img.image_name)
with ae_info as ae:
assert isinstance(ae, AutoEncoder)
img_cond = prepare_control(self.height, self.width, self.seed, ae, img)
# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
@@ -309,10 +310,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_model_patches(
LoRAPatcher.apply_smart_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
)
)
@@ -324,7 +326,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
LayerPatcher.apply_model_sidecar_patches(
LoRAPatcher.apply_lora_wrapper_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -358,7 +360,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
img_cond=img_cond
)
x = unpack(x.float(), self.height, self.width)
@@ -589,29 +591,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return controlnet_extensions
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
if self.control_lora is None:
return None
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a FLUX Control LoRA.")
# Load the conditioning image and resize it to the target image size.
cond_img = context.images.get_pil(self.control_lora.img.image_name)
cond_img = cond_img.convert("RGB")
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
cond_img = np.array(cond_img)
# Normalize the conditioning image to the range [-1, 1].
# This normalization is based on the original implementations here:
# https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L34
# https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L60
img_cond = torch.from_numpy(cond_img).float() / 127.5 - 1.0
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
vae_info = context.models.load(self.controlnet_vae.vae)
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
if self.ip_adapter is None:
return []
@@ -718,15 +697,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
if self.control_lora:
# Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
# applied last.
loras.append(self.control_lora)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
if self.transformer.structural_lora:
loras.append(self.transformer.structural_lora)
for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -81,8 +81,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],

View File

@@ -0,0 +1,70 @@
from typing import Optional, Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_structural_lora_loader_output")
class FluxStructuralLoRALoaderOutput(BaseInvocationOutput):
"""Flux Structural LoRA Loader Output"""
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_structural_lora_loader",
title="Flux Structural LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxStructuralLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.structural_lora_model, title="Structural LoRA", ui_type=UIType.StructuralLoRAModel
)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
image: ImageField = InputField(
description="The image to encode.",
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
def invoke(self, context: InvocationContext) -> FluxStructuralLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
# Check for existing LoRAs with the same key.
if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
output = FluxStructuralLoRALoaderOutput()
# Attach LoRA layers to the models.
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.structural_lora = StructuralLoRAField(
lora=self.lora,
img=self.image,
weight=self.weight,
)
return output

View File

@@ -17,11 +17,12 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -111,10 +112,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_model_patches(
LoRAPatcher.apply_smart_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
)
)
@@ -130,9 +132,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -13,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@@ -49,7 +49,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(

View File

@@ -12,7 +12,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -51,7 +51,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@@ -1,5 +1,5 @@
import copy
from typing import List, Optional
from typing import List, Optional, Literal
from pydantic import BaseModel, Field
@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
@@ -74,15 +74,13 @@ class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
class ControlLoRAField(LoRAField):
class StructuralLoRAField(LoRAField):
img: ImageField = Field(description="Image to use in structural conditioning")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):

View File

@@ -16,11 +16,12 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@@ -150,10 +151,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_model_patches(
LoRAPatcher.apply_smart_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
)
)
@@ -193,9 +195,9 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[ModelPatchRaw, float]]:
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@@ -194,10 +194,10 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
with (
ExitStack() as exit_stack,
unet_info as unet,
LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
LoRAPatcher.apply_smart_lora_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)

View File

@@ -97,7 +97,6 @@ class InvokeAIAppConfig(BaseSettings):
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
log_sql: Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.
log_level_network: Log level for network-related messages. 'info' and 'debug' are very verbose.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
use_memory_db: Use in-memory database. Useful for development.
dev_reload: Automatically reload when Python sources are changed. Does not reload node definitions.
profile_graphs: Enable graph profiling using `cProfile`.
@@ -164,7 +163,6 @@ class InvokeAIAppConfig(BaseSettings):
log_format: LOG_FORMAT = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.')
log_level: LOG_LEVEL = Field(default="info", description="Emit logging messages at this level or higher.")
log_sql: bool = Field(default=False, description="Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.")
log_level_network: LOG_LEVEL = Field(default='warning', description="Log level for network-related messages. 'info' and 'debug' are very verbose.")
# Development
use_memory_db: bool = Field(default=False, description="Use in-memory database. Useful for development.")

View File

@@ -438,10 +438,9 @@ class ModelInstallService(ModelInstallServiceBase):
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
source_stripped = source.strip('"')
if Path(source_stripped).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source_stripped))
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),

View File

@@ -439,9 +439,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
self._invoker.services.logger.info(
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# Run the graph

View File

@@ -31,7 +31,7 @@ def denoise(
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
# extra img tokens
img_cond: torch.Tensor | None,
img_cond: torch.Tensor | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1

View File

@@ -0,0 +1,27 @@
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
def prepare_control(
height: int,
width: int,
seed: int,
ae: AutoEncoder,
cond_image: Image.Image,
) -> torch.Tensor:
# load and encode the conditioning image
img_cond = cond_image.convert("RGB")
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float()
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
ae_dtype = next(iter(ae.parameters())).dtype
ae_device = next(iter(ae.parameters())).device
img_cond = img_cond.to(device=ae_device, dtype=ae_dtype)
generator = torch.Generator(device=ae_device).manual_seed(seed)
img_cond = ae.encode(img_cond, sample=True, generator=generator)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return img_cond

View File

@@ -1,10 +1,10 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor, nn
from typing import Optional
from invokeai.backend.flux.custom_block_processor import (
CustomDoubleStreamBlockProcessor,

View File

@@ -0,0 +1,50 @@
import os
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
class DepthImageEncoder:
depth_model_name = "LiheYoung/depth-anything-large-hf"
def __init__(self, device):
self.device = device
self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
def __call__(self, img: torch.Tensor) -> torch.Tensor:
hw = img.shape[-2:]
img = torch.clamp(img, -1.0, 1.0)
img_byte = ((img + 1.0) * 127.5).byte()
img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
depth = self.depth_model(img.to(self.device)).predicted_depth
depth = repeat(depth, "b h w -> b 3 h w")
depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
depth = depth / 127.5 - 1.0
return depth
class CannyImageEncoder:
def __init__(
self,
device,
min_t: int = 50,
max_t: int = 200,
):
self.device = device
self.min_t = min_t
self.max_t = max_t
def __call__(self, img: torch.Tensor) -> torch.Tensor:
assert img.shape[0] == 1, "Only batch size 1 is supported"
img = rearrange(img[0], "c h w -> h w c")
img = torch.clamp(img, -1.0, 1.0)
img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
# Apply Canny edge detection
canny = cv2.Canny(img_np, self.min_t, self.max_t)
# Convert back to torch tensor and reshape
canny = torch.from_numpy(canny).float() / 127.5 - 1.0
canny = rearrange(canny, "h w -> 1 1 h w")
canny = repeat(canny, "b 1 ... -> b 3 ...")
return canny.to(self.device)

View File

@@ -0,0 +1,65 @@
import re
import torch
from typing import Any, Dict
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
# guidance_in.in_layer.lora_B.bias
# single_blocks.0.linear1.lora_A.weight
# double_blocks.0.img_attn.norm.key_norm.scale
FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(
re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k)
for k in state_dict.keys()
)
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
key_props = key.split(".")
# Got it loading using lora_down and lora_up but it didn't seem to match this lora's structure
# Leaving this in since it doesn't hurt anything and may be better
layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
layer_name = ".".join(key_props[:layer_prop_size])
param_name = ".".join(key_props[layer_prop_size:])
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
# Convert to a full layer diff
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
layers[prefixed_key] = LoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"]
)
elif "scale" in layer_state_dict:
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
else:
raise AssertionError(f"{layer_key} not expected")
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)

View File

@@ -2,11 +2,11 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
@@ -30,9 +30,7 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
return all_keys_in_peft_format and all_expected_keys_present
def lora_model_from_flux_diffusers_state_dict(
state_dict: Dict[str, torch.Tensor], alpha: float | None
) -> ModelPatchRaw:
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> LoRAModelRaw:
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
@@ -51,7 +49,7 @@ def lora_model_from_flux_diffusers_state_dict(
mlp_ratio = 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)
layers: dict[str, BaseLayerPatch] = {}
layers: dict[str, AnyLoRALayer] = {}
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
@@ -217,7 +215,7 @@ def lora_model_from_flux_diffusers_state_dict(
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
return ModelPatchRaw(layers=layers_with_prefix)
return LoRAModelRaw(layers=layers_with_prefix)
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:

View File

@@ -3,13 +3,10 @@ from typing import Any, Dict, TypeVar
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
# Example keys:
@@ -39,7 +36,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
)
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
@@ -64,14 +61,14 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
# Create LoRA layers.
layers: dict[str, BaseLayerPatch] = {}
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in transformer_grouped_sd.items():
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for layer_key, layer_state_dict in clip_grouped_sd.items():
layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
return ModelPatchRaw(layers=layers)
return LoRAModelRaw(layers=layers)
T = TypeVar("T")

View File

@@ -2,19 +2,19 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
layers: dict[str, BaseLayerPatch] = {}
layers: dict[str, AnyLoRALayer] = {}
for layer_key, values in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(values)
return ModelPatchRaw(layers=layers)
return LoRAModelRaw(layers=layers)
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:

View File

@@ -0,0 +1,12 @@
from typing import Union
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]

View File

@@ -2,8 +2,8 @@ from typing import Optional, Sequence
import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class ConcatenatedLoRALayer(LoRALayerBase):
@@ -20,7 +20,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
self.lora_layers = lora_layers
self.concat_axis = concat_axis
def _rank(self) -> int | None:
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:

View File

@@ -2,7 +2,7 @@ from typing import Dict, Optional
import torch
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -20,7 +20,7 @@ class FullLayer(LoRALayerBase):
cls.warn_on_unhandled_keys(values=values, handled_keys={"diff", "diff_b"})
return layer
def _rank(self) -> int | None:
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:

View File

@@ -2,7 +2,7 @@ from typing import Dict, Optional
import torch
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class IA3Layer(LoRALayerBase):
@@ -16,7 +16,7 @@ class IA3Layer(LoRALayerBase):
self.weight = weight
self.on_input = on_input
def _rank(self) -> int | None:
def rank(self) -> int | None:
return None
@classmethod

View File

@@ -2,7 +2,7 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
@@ -32,7 +32,7 @@ class LoHALayer(LoRALayerBase):
self.t2 = t2
assert (self.t1 is None) == (self.t2 is None)
def _rank(self) -> int | None:
def rank(self) -> int | None:
return self.w1_b.shape[0]
@classmethod

View File

@@ -2,7 +2,7 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
@@ -39,7 +39,7 @@ class LoKRLayer(LoRALayerBase):
assert (self.w2 is None) != (self.w2_a is None)
assert (self.w2_a is None) == (self.w2_b is None)
def _rank(self) -> int | None:
def rank(self) -> int | None:
if self.w1_b is not None:
return self.w1_b.shape[0]
elif self.w2_b is not None:

View File

@@ -2,7 +2,7 @@ from typing import Dict, Optional
import torch
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
@@ -55,7 +55,7 @@ class LoRALayer(LoRALayerBase):
return layer
def _rank(self) -> int:
def rank(self) -> int:
return self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:

View File

@@ -1,13 +1,12 @@
from typing import Optional
from typing import Dict, Optional, Set
import torch
import invokeai.backend.util.logging as logger
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class LoRALayerBase(BaseLayerPatch):
class LoRALayerBase:
"""Base class for all LoRA-like patching layers."""
# Note: It is tempting to make this a torch.nn.Module sub-class and make all tensors 'torch.nn.Parameter's. Then we
@@ -24,7 +23,6 @@ class LoRALayerBase(BaseLayerPatch):
def _parse_bias(
cls, bias_indices: torch.Tensor | None, bias_values: torch.Tensor | None, bias_size: torch.Tensor | None
) -> torch.Tensor | None:
"""Helper function to parse a bias tensor from a state dict in LyCORIS format."""
assert (bias_indices is None) == (bias_values is None) == (bias_size is None)
bias = None
@@ -39,14 +37,11 @@ class LoRALayerBase(BaseLayerPatch):
) -> float | None:
return alpha.item() if alpha is not None else None
def _rank(self) -> int | None:
"""Return the rank of the LoRA-like layer. Or None if the layer does not have a rank. This value is used to
calculate the scale.
"""
def rank(self) -> int | None:
raise NotImplementedError()
def scale(self) -> float:
rank = self._rank()
rank = self.rank()
if self._alpha is None or rank is None:
return 1.0
return self._alpha / rank
@@ -57,23 +52,15 @@ class LoRALayerBase(BaseLayerPatch):
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
scale = self.scale()
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias * (weight * scale)
# Reshape all params to match the original module's shape.
for param_name, param_weight in params.items():
orig_param = orig_module.get_parameter(param_name)
if param_weight.shape != orig_param.shape:
params[param_name] = param_weight.reshape(orig_param.shape)
params["bias"] = bias
return params
@classmethod
def warn_on_unhandled_keys(cls, values: dict[str, torch.Tensor], handled_keys: set[str]):
def warn_on_unhandled_keys(cls, values: Dict[str, torch.Tensor], handled_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
unknown_keys = set(values.keys()) - handled_keys
if unknown_keys:

View File

@@ -2,7 +2,7 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -20,7 +20,7 @@ class NormLayer(LoRALayerBase):
cls.warn_on_unhandled_keys(values, {"w_norm", "b_norm"})
return layer
def _rank(self) -> int | None:
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:

View File

@@ -0,0 +1,34 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class ReshapeWeightLayer(LoRALayerBase):
# TODO: Just everything in this class
def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight) if weight is not None else None
self.bias = torch.nn.Parameter(bias) if bias is not None else None
self.manual_scale = scale
def scale(self):
return self.manual_scale.float() if self.manual_scale is not None else super().scale()
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return orig_weight
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
if self.weight is not None:
self.weight = self.weight.to(device=device, dtype=dtype)
if self.manual_scale is not None:
self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensor_size(self.manual_scale)

View File

@@ -0,0 +1,29 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class SetParameterLayer(LoRALayerBase):
def __init__(self, param_name: str, weight: torch.Tensor):
super().__init__(None, None)
self.weight = weight
self.param_name = param_name
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight - orig_weight
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensor_size(self.weight)

View File

@@ -2,16 +2,17 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.full_layer import FullLayer
from invokeai.backend.patches.layers.ia3_layer import IA3Layer
from invokeai.backend.patches.layers.loha_layer import LoHALayer
from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules

View File

@@ -0,0 +1,133 @@
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRASidecarWrapper(torch.nn.Module):
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
super().__init__()
self._orig_module = orig_module
self._lora_layers = lora_layers
self._lora_weights = lora_weights
@property
def orig_module(self) -> torch.nn.Module:
return self._orig_module
def add_lora_layer(self, lora_layer: AnyLoRALayer, lora_weight: float):
self._lora_layers.append(lora_layer)
self._lora_weights.append(lora_weight)
@torch.no_grad()
def _get_lora_patched_parameters(
self, orig_params: dict[str, torch.Tensor], lora_layers: list[AnyLoRALayer], lora_weights: list[float]
) -> dict[str, torch.Tensor]:
params: dict[str, torch.Tensor] = {}
for lora_layer, lora_weight in zip(lora_layers, lora_weights, strict=True):
layer_params = lora_layer.get_parameters(self._orig_module)
for param_name, param_weight in layer_params.items():
if orig_params[param_name].shape != param_weight.shape:
param_weight = param_weight.reshape(orig_params[param_name].shape)
if param_name not in params:
params[param_name] = param_weight * (lora_layer.scale() * lora_weight)
else:
params[param_name] += param_weight * (lora_layer.scale() * lora_weight)
return params
class LoRALinearWrapper(LoRASidecarWrapper):
def _lora_linear_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear LoRALayer."""
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
def _concatenated_lora_forward(
self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Split the LoRA layers into those that have optimized implementations and those that don't.
optimized_layer_types = (LoRALayer, ConcatenatedLoRALayer)
optimized_layers = [
(layer, weight)
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
if isinstance(layer, optimized_layer_types)
]
non_optimized_layers = [
(layer, weight)
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
if not isinstance(layer, optimized_layer_types)
]
# First, calculate the residual for LoRA layers for which there is an optimized implementation.
residual = None
for lora_layer, lora_weight in optimized_layers:
if isinstance(lora_layer, LoRALayer):
added_residual = self._lora_linear_forward(input, lora_layer, lora_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
added_residual = self._concatenated_lora_forward(input, lora_layer, lora_weight)
else:
raise ValueError(f"Unsupported LoRA layer type: {type(lora_layer)}")
if residual is None:
residual = added_residual
else:
residual += added_residual
# Next, calculate the residuals for the LoRA layers for which there is no optimized implementation.
if non_optimized_layers:
unoptimized_layers, unoptimized_weights = zip(*non_optimized_layers, strict=True)
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=unoptimized_layers,
lora_weights=unoptimized_weights,
)
added_residual = torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
if residual is None:
residual = added_residual
else:
residual += added_residual
return self.orig_module(input) + residual
class LoRAConv1dWrapper(LoRASidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=self._lora_layers,
lora_weights=self._lora_weights,
)
return self.orig_module(input) + torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))
class LoRAConv2dWrapper(LoRASidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=self._lora_layers,
lora_weights=self._lora_weights,
)
return self.orig_module(input) + torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))

View File

@@ -3,17 +3,20 @@ from typing import Mapping, Optional
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.raw_model import RawModel
class ModelPatchRaw(RawModel):
def __init__(self, layers: Mapping[str, BaseLayerPatch]):
class LoRAModelRaw(RawModel): # (torch.nn.Module):
def __init__(self, layers: Mapping[str, AnyLoRALayer]):
self.layers = layers
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for layer in self.layers.values():
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return sum(layer.calc_size() for layer in self.layers.values())
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size

View File

@@ -3,23 +3,133 @@ from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.lora_layer_wrappers import (
LoRAConv1dWrapper,
LoRAConv2dWrapper,
LoRALinearWrapper,
LoRASidecarWrapper,
)
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class LayerPatcher:
class LoRAPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_model_patches(
def apply_smart_lora_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[ModelPatchRaw, float]],
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply 'smart' LoRA patching that chooses whether to use direct patching or a sidecar wrapper for each module."""
# original_weights are stored for unpatching layers that are directly patched.
original_weights = OriginalWeightsStorage(cached_weights)
# original_modules are stored for unpatching layers that are wrapped in a LoRASidecarWrapper.
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoRAPatcher._apply_smart_lora_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
original_modules=original_modules,
dtype=dtype,
)
yield
finally:
# Restore directly patched layers.
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
# Restore LoRASidecarWrapper modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
@torch.no_grad()
def _apply_smart_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
patching or a sidecar wrapper for each module.
"""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Decide whether to use direct patching or a sidecar wrapper.
# Direct patching is preferred, because it results in better runtime speed.
# Reasons to use sidecar patching:
# - The module is already wrapped in a LoRASidecarWrapper.
# - The module is quantized.
# - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
# CPU, since this would double the RAM usage)
# NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
# and that the caller will use the 'apply_lora_wrapper_patches' method if the layer is quantized.
# TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
# forcing full patching even on the CPU?
if isinstance(module, LoRASidecarWrapper) or LoRAPatcher._is_any_part_of_layer_on_cpu(module):
LoRAPatcher._apply_lora_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
else:
LoRAPatcher._apply_lora_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_weights=original_weights,
)
@staticmethod
def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
return any(p.device.type == "cpu" for p in layer.parameters())
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
@@ -37,7 +147,7 @@ class LayerPatcher:
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
LayerPatcher.apply_model_patch(
LoRAPatcher._apply_lora_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -54,10 +164,10 @@ class LayerPatcher:
@staticmethod
@torch.no_grad()
def apply_model_patch(
def _apply_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: ModelPatchRaw,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
@@ -85,11 +195,11 @@ class LayerPatcher:
if not layer_key.startswith(prefix):
continue
module_key, module = LayerPatcher._get_submodule(
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
LayerPatcher._apply_model_layer_patch(
LoRAPatcher._apply_lora_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
@@ -99,10 +209,10 @@ class LayerPatcher:
@staticmethod
@torch.no_grad()
def _apply_model_layer_patch(
def _apply_lora_layer_patch(
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
patch: BaseLayerPatch,
patch: AnyLoRALayer,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
@@ -112,6 +222,8 @@ class LayerPatcher:
device = first_param.device
dtype = first_param.dtype
layer_scale = patch.scale()
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
@@ -120,41 +232,51 @@ class LayerPatcher:
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, param_weight in patch.get_parameters(module_to_patch, weight=patch_weight).items():
for param_name, lora_param_weight in patch.get_parameters(module_to_patch).items():
param_key = module_to_patch_key + "." + param_name
module_param = module_to_patch.get_parameter(param_name)
# Save original weight
original_weights.save(param_key, module_param)
# HACK(ryand): This condition is only necessary to handle layers in FLUX control LoRAs that change the
# shape of the original layer.
if module_param.nelement() != param_weight.nelement():
assert isinstance(patch, FluxControlLoRALayer)
expanded_weight = pad_with_zeros(module_param, param_weight.shape)
setattr(
module_to_patch,
param_name,
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
)
module_param = expanded_weight
if module_param.shape != lora_param_weight.shape:
if module_param.nelement() == lora_param_weight.nelement():
lora_param_weight = lora_param_weight.reshape(module_param.shape)
else:
# This condition was added to handle layers in FLUX control LoRAs.
# TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
# to worry about this?
expanded_weight = torch.zeros_like(
lora_param_weight, dtype=module_param.dtype, device=module_param.device
)
slices = tuple(slice(0, dim) for dim in module_param.shape)
expanded_weight[slices] = module_param
setattr(
module,
param_name,
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
)
module_param = expanded_weight
module_param += param_weight.to(dtype=dtype)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
patch.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
@torch.no_grad()
@contextmanager
def apply_model_sidecar_patches(
def apply_lora_wrapper_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[ModelPatchRaw, float]],
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
):
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
quantization format.
"""Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
runtime overhead compared to normal LoRA patching, but they enable:
- LoRA layers to be applied to quantized models
- LoRA layers to be applied to CPU layers without needing to store a full copy of the original weights (i.e.
avoid doubling the memory requirements).
Args:
model (torch.nn.Module): The model to patch.
@@ -162,14 +284,11 @@ class LayerPatcher:
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
different from their compute dtype.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LayerPatcher._apply_model_sidecar_patch(
LoRAPatcher._apply_lora_wrapper_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -182,20 +301,20 @@ class LayerPatcher:
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = LayerPatcher._split_parent_key(module_key)
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
LayerPatcher._set_submodule(parent_module, module_name, orig_module)
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
def _apply_model_sidecar_patch(
def _apply_lora_wrapper_patch(
model: torch.nn.Module,
patch: ModelPatchRaw,
patch: LoRAModelRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA sidecar patch to a model."""
"""Apply a single LoRA wrapper patch to a model."""
if patch_weight == 0:
return
@@ -212,11 +331,11 @@ class LayerPatcher:
if not layer_key.startswith(prefix):
continue
module_key, module = LayerPatcher._get_submodule(
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
LayerPatcher._apply_model_layer_wrapper_patch(
LoRAPatcher._apply_lora_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
@@ -228,34 +347,35 @@ class LayerPatcher:
@staticmethod
@torch.no_grad()
def _apply_model_layer_wrapper_patch(
def _apply_lora_layer_wrapper_patch(
model: torch.nn.Module,
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
patch: BaseLayerPatch,
patch: AnyLoRALayer,
patch_weight: float,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA wrapper patch to a model."""
# Replace the original module with a BaseSidecarWrapper if it has not already been done.
if not isinstance(module_to_patch, BaseSidecarWrapper):
wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch)
# Replace the original module with a LoRASidecarWrapper if it has not already been done.
if not isinstance(module_to_patch, LoRASidecarWrapper):
lora_wrapper_layer = LoRAPatcher._initialize_lora_wrapper_layer(module_to_patch)
original_modules[module_to_patch_key] = module_to_patch
module_parent_key, module_name = LayerPatcher._split_parent_key(module_to_patch_key)
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_to_patch_key)
module_parent = model.get_submodule(module_parent_key)
LayerPatcher._set_submodule(module_parent, module_name, wrapped_module)
LoRAPatcher._set_submodule(module_parent, module_name, lora_wrapper_layer)
orig_module = module_to_patch
else:
assert module_to_patch_key in original_modules
wrapped_module = module_to_patch
lora_wrapper_layer = module_to_patch
orig_module = module_to_patch.orig_module
# Move the LoRA layer to the same device/dtype as the orig module.
first_param = next(module_to_patch.parameters())
device = first_param.device
patch.to(device=device, dtype=dtype)
patch.to(device=orig_module.weight.device, dtype=dtype)
# Add the patch to the sidecar wrapper.
wrapped_module.add_patch(patch, patch_weight)
# Add the LoRA wrapper layer to the LoRASidecarWrapper.
lora_wrapper_layer.add_lora_layer(patch, patch_weight)
@staticmethod
def _split_parent_key(module_key: str) -> tuple[str, str]:
@@ -275,6 +395,17 @@ class LayerPatcher:
else:
raise ValueError(f"Invalid module key: {module_key}")
@staticmethod
def _initialize_lora_wrapper_layer(orig_layer: torch.nn.Module):
if isinstance(orig_layer, torch.nn.Linear):
return LoRALinearWrapper(orig_layer, [], [])
elif isinstance(orig_layer, torch.nn.Conv1d):
return LoRAConv1dWrapper(orig_layer, [], [])
elif isinstance(orig_layer, torch.nn.Conv2d):
return LoRAConv2dWrapper(orig_layer, [], [])
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
@staticmethod
def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
try:

View File

@@ -67,7 +67,7 @@ class ModelType(str, Enum):
Main = "main"
VAE = "vae"
LoRA = "lora"
ControlLoRa = "control_lora"
StructuralLoRa = "structural_lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
@@ -274,34 +274,16 @@ class LoRALyCORISConfig(LoRAConfigBase):
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class ControlAdapterConfigBase(BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)
class StructuralLoRALyCORISConfig(ModelConfigBase):
"""Model config for Structural LoRA/Lycoris models."""
class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
type: Literal[ModelType.StructuralLoRa] = ModelType.StructuralLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")
class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.Diffusers.value}")
return Tag(f"{ModelType.StructuralLoRa.value}.{ModelFormat.LyCORIS.value}")
class LoRADiffusersConfig(LoRAConfigBase):
@@ -335,6 +317,12 @@ class VAEDiffusersConfig(ModelConfigBase):
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlAdapterConfigBase(BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
"""Model config for ControlNet models (diffusers version)."""
@@ -560,8 +548,7 @@ AnyModelConfig = Annotated[
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
Annotated[StructuralLoRALyCORISConfig, StructuralLoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],

View File

@@ -9,6 +9,15 @@ import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -20,25 +29,11 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.StructuralLoRa, format=ModelFormat.LyCORIS)
class LoRALoader(ModelLoader):
"""Class to load LoRA models."""

View File

@@ -15,9 +15,9 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -43,7 +43,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
(
TextualInversionModelRaw,
IPAdapter,
ModelPatchRaw,
LoRAModelRaw,
SpandrelImageToImageModel,
GroundingDinoPipeline,
SegmentAnythingPipeline,

View File

@@ -15,6 +15,11 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@@ -39,13 +44,6 @@ from invokeai.backend.model_manager.util.model_util import (
lora_token_vector_length,
read_checkpoint_meta,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -202,8 +200,8 @@ class ModelProbe(object):
fields["default_settings"] = fields.get("default_settings")
if not fields["default_settings"]:
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}:
fields["default_settings"] = get_default_settings_control_adapters(fields["name"])
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
elif fields["type"] is ModelType.Main:
fields["default_settings"] = get_default_settings_main(fields["base"])
@@ -261,8 +259,17 @@ class ModelProbe(object):
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
if isinstance(ckpt, dict) and is_state_dict_likely_flux_control(ckpt):
return ModelType.ControlLoRa
if isinstance(ckpt, dict) and "img_in.lora_A.weight" in ckpt and "img_in.lora_B.weight" in ckpt:
tensor_a, tensor_b = ckpt["img_in.lora_A.weight"], ckpt["img_in.lora_B.weight"]
if (
tensor_a is not None
and isinstance(tensor_a, torch.Tensor)
and tensor_a.shape[1] == 128
and tensor_b is not None
and isinstance(tensor_b, torch.Tensor)
and tensor_b.shape[0] == 3072
):
return ModelType.StructuralLoRa
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(
@@ -503,7 +510,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
}
def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
model_name_lower = model_name.lower()
if k in model_name_lower:
@@ -1042,7 +1049,6 @@ class T2IAdapterFolderProbe(FolderProbeBase):
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
@@ -1055,7 +1061,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.StructuralLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

View File

@@ -488,22 +488,6 @@ union_cnet_flux = StarterModel(
type=ModelType.ControlNet,
)
# endregion
# region Control LoRA
flux_canny_control_lora = StarterModel(
name="Hard Edge Detection (Canny)",
base=BaseModelType.Flux,
source="black-forest-labs/FLUX.1-Canny-dev-lora::flux1-canny-dev-lora.safetensors",
description="Uses detected edges in the image to control composition.",
type=ModelType.ControlLoRa,
)
flux_depth_control_lora = StarterModel(
name="Depth Map",
base=BaseModelType.Flux,
source="black-forest-labs/FLUX.1-Depth-dev-lora::flux1-depth-dev-lora.safetensors",
description="Uses depth information in the image to control the depth in the generation.",
type=ModelType.ControlLoRa,
)
# endregion
# region T2I Adapter
t2i_canny_sd1 = StarterModel(
name="Hard Edge Detection (canny)",
@@ -646,8 +630,6 @@ STARTER_MODELS: list[StarterModel] = [
tile_sdxl,
union_cnet_sdxl,
union_cnet_flux,
flux_canny_control_lora,
flux_depth_control_lora,
t2i_canny_sd1,
t2i_sketch_sd1,
t2i_depth_sd1,
@@ -706,8 +688,6 @@ flux_bundle: list[StarterModel] = [
clip_l_encoder,
union_cnet_flux,
ip_adapter_flux,
flux_canny_control_lora,
flux_depth_control_lora,
]
STARTER_BUNDLES: dict[str, list[StarterModel]] = {

View File

@@ -5,14 +5,17 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
@@ -173,3 +176,180 @@ class ModelPatcher:
assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute?
if did_apply_freeu:
unet.disable_freeu()
class ONNXModelPatcher:
# based on
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
@classmethod
@contextmanager
def apply_lora(
cls,
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> None:
from invokeai.backend.models.base import IAIOnnxRuntimeModel
if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_weights = {}
try:
blended_loras: Dict[str, torch.Tensor] = {}
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
layer.to(dtype=torch.float32)
layer_key = layer_key.replace(prefix, "")
# TODO: rewrite to pass original tensor weight(required by ia3)
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
if layer_key in blended_loras:
blended_loras[layer_key] += layer_weight
else:
blended_loras[layer_key] = layer_weight
node_names = {}
for node in model.nodes.values():
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
for layer_key, lora_weight in blended_loras.items():
conv_key = layer_key + "_Conv"
gemm_key = layer_key + "_Gemm"
matmul_key = layer_key + "_MatMul"
if conv_key in node_names or gemm_key in node_names:
if conv_key in node_names:
conv_node = model.nodes[node_names[conv_key]]
else:
conv_node = model.nodes[node_names[gemm_key]]
weight_name = [n for n in conv_node.input if ".weight" in n][0]
orig_weight = model.tensors[weight_name]
if orig_weight.shape[-2:] == (1, 1):
if lora_weight.shape[-2:] == (1, 1):
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
else:
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
new_weight = np.expand_dims(new_weight, (2, 3))
else:
if orig_weight.shape != lora_weight.shape:
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
else:
new_weight = orig_weight + lora_weight
orig_weights[weight_name] = orig_weight
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
elif matmul_key in node_names:
weight_node = model.nodes[node_names[matmul_key]]
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
orig_weight = model.tensors[matmul_name]
new_weight = orig_weight + lora_weight.transpose()
orig_weights[matmul_name] = orig_weight
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
else:
# warn? err?
pass
yield
finally:
# restore original weights
for name, orig_weight in orig_weights.items():
model.tensors[name] = orig_weight
@classmethod
@contextmanager
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Tuple[str, Any]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
from invokeai.backend.models.base import IAIOnnxRuntimeModel
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_embeddings = None
try:
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti_name: str, index: int) -> str:
trigger = ti_name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
# modify tokenizer
new_tokens_added = 0
for ti_name, ti in ti_list:
if ti.embedding_2 is not None:
ti_embedding = (
ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
)
else:
ti_embedding = ti.embedding
for i in range(ti_embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
embeddings = np.concatenate(
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
axis=0,
)
for ti_name, _ in ti_list:
ti_tokens = []
for i in range(ti_embedding.shape[0]):
embedding = ti_embedding[i].detach().numpy()
trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
if embeddings[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
f" {embedding.shape[0]}, but the current model has token dimension"
f" {embeddings[token_id].shape[0]}."
)
embeddings[token_id] = embedding
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
orig_embeddings.dtype
)
yield ti_tokenizer, ti_manager
finally:
# restore
if orig_embeddings is not None:
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings

View File

@@ -1,22 +0,0 @@
from abc import ABC, abstractmethod
import torch
class BaseLayerPatch(ABC):
@abstractmethod
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
"""Get the parameter residual updates that should be applied to the original parameters. Parameters omitted
from the returned dict are not updated.
"""
...
@abstractmethod
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
"""Move all internal tensors to the specified device and dtype."""
...
@abstractmethod
def calc_size(self) -> int:
"""Calculate the total size of all internal tensors in bytes."""
...

View File

@@ -1,19 +0,0 @@
import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
class FluxControlLoRALayer(LoRALayer):
"""A special case of LoRALayer for use with FLUX Control LoRAs that pads the target parameter with zeros if the
shapes don't match.
"""
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
"""This overrides the base class behavior to skip the reshaping step."""
scale = self.scale()
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias * (weight * scale)
return params

View File

@@ -1,27 +0,0 @@
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class SetParameterLayer(BaseLayerPatch):
"""A layer that sets a single parameter to a new target value.
(The diff between the target value and current value is calculated internally.)
"""
def __init__(self, param_name: str, weight: torch.Tensor):
super().__init__()
self.weight = weight
self.param_name = param_name
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
# Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
# Control LoRA implementation.
diff = self.weight - orig_module.get_parameter(self.param_name)
return {self.param_name: diff}
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.weight = self.weight.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return calc_tensor_size(self.weight)

View File

@@ -1,84 +0,0 @@
import re
from typing import Any, Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
# guidance_in.in_layer.lora_B.bias
# single_blocks.0.linear1.lora_A.weight
# double_blocks.0.img_attn.norm.key_norm.scale
FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)"
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys())
# Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs.
lora_a_weight = state_dict.get("img_in.lora_A.weight", None)
lora_b_bias = state_dict.get("img_in.lora_B.bias", None)
lora_b_weight = state_dict.get("img_in.lora_B.weight", None)
return (
all_keys_match
and lora_a_weight is not None
and lora_b_bias is not None
and lora_b_weight is not None
and lora_a_weight.shape[1] == 128
and lora_b_weight.shape[0] == 3072
and lora_b_bias.shape[0] == 3072
)
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
key_props = key.split(".")
layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
layer_name = ".".join(key_props[:layer_prop_size])
param_name = ".".join(key_props[layer_prop_size:])
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Create LoRA layers.
layers: dict[str, BaseLayerPatch] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
if layer_key == "img_in":
# img_in is a special case because it changes the shape of the original weight.
layers[prefixed_key] = FluxControlLoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"],
)
elif all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
layers[prefixed_key] = LoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"],
)
elif "scale" in layer_state_dict:
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
else:
raise ValueError(f"{layer_key} not expected")
return ModelPatchRaw(layers=layers)

View File

@@ -1,9 +0,0 @@
import torch
def pad_with_zeros(orig_weight: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
"""Pad a weight tensor with zeros to match the target shape."""
expanded_weight = torch.zeros(target_shape, dtype=orig_weight.dtype, device=orig_weight.device)
slices = tuple(slice(0, dim) for dim in orig_weight.shape)
expanded_weight[slices] = orig_weight
return expanded_weight

View File

@@ -1,54 +0,0 @@
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
class BaseSidecarWrapper(torch.nn.Module):
"""A base class for sidecar wrappers.
A sidecar wrapper is a wrapper for an existing torch.nn.Module that applies a
list of patches as 'sidecar' patches. I.e. it applies the sidecar patches during forward inference without modifying
the original module.
Sidecar wrappers are typically used over regular patches when:
- The original module is quantized and so the weights can't be patched in the usual way.
- The original module is on the CPU and modifying the weights would require backing up the original weights and
doubling the CPU memory usage.
"""
def __init__(
self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None
):
super().__init__()
self._orig_module = orig_module
self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights
@property
def orig_module(self) -> torch.nn.Module:
return self._orig_module
def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
"""Add a patch to the sidecar wrapper."""
self._patches_and_weights.append((patch, patch_weight))
def _aggregate_patch_parameters(
self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> dict[str, torch.Tensor]:
"""Helper function that aggregates the parameters from all patches into a single dict."""
params: dict[str, torch.Tensor] = {}
for patch, patch_weight in patches_and_weights:
# TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
# module, this might fail or return incorrect results.
layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)
for param_name, param_weight in layer_params.items():
if param_name not in params:
params[param_name] = param_weight
else:
params[param_name] += param_weight
return params
def forward(self, *args, **kwargs): # type: ignore
raise NotImplementedError()

View File

@@ -1,11 +0,0 @@
import torch
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
class Conv1dSidecarWrapper(BaseSidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
return self.orig_module(input) + torch.nn.functional.conv1d(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)

View File

@@ -1,11 +0,0 @@
import torch
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
class Conv2dSidecarWrapper(BaseSidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
return self.orig_module(input) + torch.nn.functional.conv1d(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)

View File

@@ -1,24 +0,0 @@
import torch
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
class FluxRMSNormSidecarWrapper(BaseSidecarWrapper):
"""A sidecar wrapper for a FLUX RMSNorm layer.
This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite
the RMSNorm scale parameters.
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Given the narrow focus of this wrapper, we only support a very particular patch configuration:
assert len(self._patches_and_weights) == 1
patch, _patch_weight = self._patches_and_weights[0]
assert isinstance(patch, SetParameterLayer)
assert patch.param_name == "scale"
# Apply the patch.
# NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
# be handled.
return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6)

View File

@@ -1,66 +0,0 @@
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
class LinearSidecarWrapper(BaseSidecarWrapper):
def _lora_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear LoRALayer."""
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
def _concatenated_lora_forward(
self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def forward(self, input: torch.Tensor) -> torch.Tensor:
# First, apply the original linear layer.
# NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
# change the linear layer's in_features.
orig_input = input
input = orig_input[..., : self.orig_module.in_features]
output = self.orig_module(input)
# Then, apply layers for which we have optimized implementations.
unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
for patch, patch_weight in self._patches_and_weights:
if isinstance(patch, FluxControlLoRALayer):
# Note that we use the original input here, not the sliced input.
output += self._lora_forward(orig_input, patch, patch_weight)
elif isinstance(patch, LoRALayer):
output += self._lora_forward(input, patch, patch_weight)
elif isinstance(patch, ConcatenatedLoRALayer):
output += self._concatenated_lora_forward(input, patch, patch_weight)
else:
unprocessed_patches_and_weights.append((patch, patch_weight))
# Finally, apply any remaining patches.
if len(unprocessed_patches_and_weights) > 0:
aggregated_param_residuals = self._aggregate_patch_parameters(unprocessed_patches_and_weights)
output += torch.nn.functional.linear(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)
return output

View File

@@ -1,20 +0,0 @@
import torch
from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
if isinstance(orig_module, torch.nn.Linear):
return LinearSidecarWrapper(orig_module)
elif isinstance(orig_module, torch.nn.Conv1d):
return Conv1dSidecarWrapper(orig_module)
elif isinstance(orig_module, torch.nn.Conv2d):
return Conv2dSidecarWrapper(orig_module)
elif isinstance(orig_module, RMSNorm):
return FluxRMSNormSidecarWrapper(orig_module)
else:
raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")

View File

@@ -52,7 +52,6 @@ GGML_TENSOR_OP_TABLE = {
torch.ops.aten.t.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore
}
if torch.backends.mps.is_available():

View File

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
from diffusers import UNet2DConditionModel
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
@@ -30,8 +30,8 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, ModelPatchRaw)
LayerPatcher.apply_model_patch(
assert isinstance(lora_model, LoRAModelRaw)
LoRAPatcher.apply_lora_patch(
model=unet,
prefix="lora_unet_",
patch=lora_model,

View File

@@ -98,8 +98,7 @@
"close": "Schließen",
"clipboard": "Zwischenablage",
"generating": "Generieren",
"loadingModel": "Lade Modell",
"warnings": "Warnungen"
"loadingModel": "Lade Modell"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -271,7 +270,7 @@
"title": "Ansichts-Tool"
},
"quickSwitch": {
"title": "Ebenen Schnell-Umschalten",
"title": "Ebenen schnell umschalten",
"desc": "Wechseln Sie zwischen den beiden zuletzt gewählten Ebenen. Wenn eine Ebene mit einem Lesezeichen versehen ist, wird zwischen ihr und der letzten nicht markierten Ebene gewechselt."
},
"applyFilter": {
@@ -602,9 +601,7 @@
"hfTokenLabel": "HuggingFace-Token (für einige Modelle erforderlich)",
"hfTokenHelperText": "Für die Nutzung einiger Modelle ist ein HF-Token erforderlich. Klicken Sie hier, um Ihr Token zu erstellen oder zu erhalten.",
"hfForbidden": "Sie haben keinen Zugriff auf dieses HF-Modell",
"hfTokenInvalid": "Ungültiges oder fehlendes HF-Token",
"restoreDefaultSettings": "Klicken, um die Standardeinstellungen des Modells zu verwenden.",
"usingDefaultSettings": "Die Standardeinstellungen des Modells werden verwendet"
"hfTokenInvalid": "Ungültiges oder fehlendes HF-Token"
},
"parameters": {
"images": "Bilder",
@@ -648,19 +645,8 @@
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), Skalierte Bbox-Breite ist {{width}}",
"fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), Skalierte Bbox-Höhe ist {{height}}",
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), Bbox-Breite ist {{width}}",
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), Bbox-Höhe ist {{height}}",
"noNodesInGraph": "Keine Knoten im Graphen",
"canvasIsTransforming": "Leinwand ist beschäftigt (wird transformiert)",
"canvasIsRasterizing": "Leinwand ist beschäftigt (wird gerastert)",
"canvasIsCompositing": "Leinwand ist beschäftigt (wird zusammengesetzt)",
"canvasIsFiltering": "Leinwand ist beschäftigt (wird gefiltert)",
"canvasIsSelectingObject": "Leinwand ist beschäftigt (wird Objekt ausgewählt)",
"noPrompts": "Keine Eingabeaufforderungen generiert"
},
"seed": "Seed",
"patchmatchDownScaleSize": "Herunterskalieren",
"seamlessXAxis": "Nahtlose X Achse",
"seamlessYAxis": "Nahtlose Y Achse"
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), Bbox-Höhe ist {{height}}"
}
},
"settings": {
"displayInProgress": "Zwischenbilder anzeigen",
@@ -1115,18 +1101,6 @@
},
"paramHrf": {
"heading": "High Resolution Fix aktivieren"
},
"seamlessTilingYAxis": {
"heading": "Nahtlose Kachelung Y Achse",
"paragraphs": [
"Nahtloses Kacheln eines Bildes entlang der vertikalen Achse."
]
},
"seamlessTilingXAxis": {
"paragraphs": [
"Nahtloses Kacheln eines Bildes entlang der horizontalen Achse."
],
"heading": "Nahtlose Kachelung X Achse"
}
},
"invocationCache": {
@@ -1338,15 +1312,7 @@
"loadWorkflow": "Arbeitsablauf $t(common.load)",
"updated": "Aktualisiert",
"created": "Erstellt",
"descending": "Absteigend",
"edit": "Bearbeiten",
"loadFromGraph": "Arbeitsablauf aus dem Graph laden",
"delete": "Löschen",
"copyShareLinkForWorkflow": "Teilen-Link für Arbeitsablauf kopieren",
"autoLayout": "Auto Layout",
"copyShareLink": "Teilen-Link kopieren",
"download": "Herunterladen",
"convertGraph": "Graph konvertieren"
"descending": "Absteigend"
},
"sdxl": {
"concatPromptStyle": "Verknüpfen von Prompt & Stil",
@@ -1596,8 +1562,7 @@
"filters": "Filter",
"filterType": "Filtertyp",
"filter": "Filter"
},
"bookmark": "Lesezeichen für Schnell-Umschalten"
}
},
"upsell": {
"shareAccess": "Zugang teilen",

View File

@@ -809,7 +809,7 @@
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"controlLora": "Control LoRA",
"structuralLora": "Structural LoRA",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",
@@ -1033,7 +1033,6 @@
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), bbox height is {{height}}",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), scaled bbox width is {{width}}",
"fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), scaled bbox height is {{height}}",
"fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time",
"canvasIsFiltering": "Canvas is busy (filtering)",
"canvasIsTransforming": "Canvas is busy (transforming)",
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
@@ -2112,7 +2111,6 @@
},
"logNamespaces": {
"logNamespaces": "Log Namespaces",
"dnd": "Drag and Drop",
"gallery": "Gallery",
"models": "Models",
"config": "Config",
@@ -2136,7 +2134,8 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"<StrongComponent>Flux Control Layers</StrongComponent>: New control models for edge detection and depth mapping are now supported for Flux dev models."
"<StrongComponent>FLUX Regional Guidance (beta)</StrongComponent>: Our beta release of FLUX Regional Guidance is live for regional prompt control.",
"<StrongComponent>Various UX Improvements</StrongComponent>: A number of small UX and Quality of Life improvements throughout the app."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",

View File

@@ -507,7 +507,8 @@
"watchUiUpdatesOverview": "Descripción general de las actualizaciones de la interfaz de usuario de Watch",
"whatsNewInInvoke": "Novedades en Invoke",
"items": [
"<StrongComponent>SD 3.5</StrongComponent>: compatibilidad con SD 3.5 Medium y Large."
"<StrongComponent>SD 3.5</StrongComponent>: compatibilidad con SD 3.5 Medium y Large.",
"<StrongComponent>Lienzo</StrongComponent>: Se ha simplificado el procesamiento de la capa de control y se ha mejorado la configuración predeterminada del control."
]
},
"invocationCache": {

View File

@@ -96,9 +96,7 @@
"negativePrompt": "Prompt Négatif",
"ok": "Ok",
"close": "Fermer",
"clipboard": "Presse-papier",
"loadingModel": "Chargement du modèle",
"generating": "En Génération"
"clipboard": "Presse-papier"
},
"gallery": {
"galleryImageSize": "Taille de l'image",
@@ -289,20 +287,7 @@
"noDefaultSettings": "Aucun paramètre par défaut configuré pour ce modèle. Visitez le Gestionnaire de Modèles pour ajouter des paramètres par défaut.",
"usingDefaultSettings": "Utilisation des paramètres par défaut du modèle",
"defaultSettingsOutOfSync": "Certain paramètres ne correspondent pas aux valeurs par défaut du modèle :",
"restoreDefaultSettings": "Cliquez pour utiliser les paramètres par défaut du modèle.",
"hfForbiddenErrorMessage": "Nous vous recommandons de visiter la page du modèle sur HuggingFace.com. Le propriétaire peut exiger l'acceptation des conditions pour pouvoir télécharger.",
"hfTokenRequired": "Vous essayez de télécharger un modèle qui nécessite un token HuggingFace valide.",
"clipLEmbed": "CLIP-L Embed",
"hfTokenSaved": "Token HF enregistré",
"hfTokenUnableToVerifyErrorMessage": "Impossible de vérifier le token HuggingFace. Cela est probablement dû à une erreur réseau. Veuillez réessayer plus tard.",
"clipGEmbed": "CLIP-G Embed",
"hfTokenUnableToVerify": "Impossible de vérifier le token HF",
"hfTokenInvalidErrorMessage": "Token HuggingFace invalide ou manquant.",
"hfTokenLabel": "Token HuggingFace (Requis pour certains modèles)",
"hfTokenHelperText": "Un token HF est requis pour utiliser certains modèles. Cliquez ici pour créer ou obtenir votre token.",
"hfTokenInvalid": "Token HF invalide ou manquant",
"hfForbidden": "Vous n'avez pas accès à ce modèle HF.",
"hfTokenInvalidErrorMessage2": "Mettre à jour dans le "
"restoreDefaultSettings": "Cliquez pour utiliser les paramètres par défaut du modèle."
},
"parameters": {
"images": "Images",
@@ -351,11 +336,7 @@
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), la largeur de la bounding box est {{width}}",
"noT5EncoderModelSelected": "Aucun modèle T5 Encoder sélectionné pour la génération FLUX",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), la largeur de la bounding box mise à l'échelle est {{width}}",
"canvasIsCompositing": "La toile est en train de composer",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} collection vide",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}} : trop peu d'éléments, minimum {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}} : trop d'éléments, maximum {{maxItems}}",
"canvasIsSelectingObject": "La toile est occupée (sélection d'objet)"
"canvasIsCompositing": "La toile est en train de composer"
},
"negativePromptPlaceholder": "Prompt Négatif",
"positivePromptPlaceholder": "Prompt Positif",
@@ -396,9 +377,7 @@
"sendToUpscale": "Envoyer à Agrandir",
"guidance": "Guidage",
"postProcessing": "Post-traitement (Maj + U)",
"processImage": "Traiter l'image",
"disabledNoRasterContent": "Désactivé (Aucun contenu raster)",
"recallMetadata": "Rappeler les métadonnées"
"processImage": "Traiter l'image"
},
"settings": {
"models": "Modèles",
@@ -436,8 +415,7 @@
"confirmOnNewSession": "Confirmer lors d'une nouvelle session",
"modelDescriptionsDisabledDesc": "Les descriptions des modèles dans les menus déroulants ont été désactivées. Activez-les dans les paramètres.",
"enableModelDescriptions": "Activer les descriptions de modèle dans les menus déroulants",
"modelDescriptionsDisabled": "Descriptions de modèle dans les menus déroulants désactivés",
"showDetailedInvocationProgress": "Afficher les détails de progression"
"modelDescriptionsDisabled": "Descriptions de modèle dans les menus déroulants désactivés"
},
"toast": {
"uploadFailed": "Importation échouée",
@@ -656,8 +634,7 @@
"iterations_one": "Itération",
"iterations_many": "Itérations",
"iterations_other": "Itérations",
"back": "fin",
"batchSize": "Taille de lot"
"back": "fin"
},
"prompt": {
"noMatchingTriggers": "Pas de déclancheurs correspondants",
@@ -1175,8 +1152,7 @@
"heading": "Force de débruitage",
"paragraphs": [
"Intensité du bruit ajouté à l'image d'entrée.",
"0 produira une image identique, tandis que 1 produira une image complètement différente.",
"Lorsque aucune couche raster avec du contenu visible n'est présente, ce paramètre est ignoré."
"0 produira une image identique, tandis que 1 produira une image complètement différente."
]
},
"lora": {
@@ -1471,9 +1447,7 @@
"parsingFailed": "L'analyse a échoué",
"recallParameter": "Rappeler {{label}}",
"canvasV2Metadata": "Toile",
"guidance": "Guide",
"seamlessXAxis": "Axe X sans bords",
"seamlessYAxis": "Axe Y sans bords"
"guidance": "Guide"
},
"sdxl": {
"freePromptStyle": "Écriture de Prompt manuelle",
@@ -1694,13 +1668,7 @@
"delete": "Supprimer"
},
"whatsNew": {
"whatsNewInInvoke": "Quoi de neuf dans Invoke",
"watchRecentReleaseVideos": "Regarder les vidéos des dernières versions",
"items": [
"<StrongComponent>FLUX Guidage Régional (bêta)</StrongComponent> : Notre version bêta de FLUX Guidage Régional est en ligne pour le contrôle des prompt régionaux."
],
"readReleaseNotes": "Notes de version",
"watchUiUpdatesOverview": "Aperçu des mises à jour de l'interface utilisateur"
"whatsNewInInvoke": "Quoi de neuf dans Invoke"
},
"ui": {
"tabs": {
@@ -1808,10 +1776,7 @@
},
"process": "Traiter",
"apply": "Appliquer",
"cancel": "Annuler",
"advanced": "Avancé",
"processingLayerWith": "Calque de traitement avec le filtre {{type}}.",
"forMoreControl": "Pour plus de contrôle, cliquez sur Avancé ci-dessous."
"cancel": "Annuler"
},
"canvasContextMenu": {
"saveToGalleryGroup": "Enregistrer dans la galerie",
@@ -2064,17 +2029,7 @@
"convertInpaintMaskTo": "Convertir $t(controlLayers.inpaintMask) vers",
"copyControlLayerTo": "Copier $t(controlLayers.controlLayer) vers",
"newInpaintMask": "Nouveau $t(controlLayers.inpaintMask)",
"newRasterLayer": "Nouveau $t(controlLayers.rasterLayer)",
"mergingLayers": "Fusionner les couches",
"resetCanvasLayers": "Réinitialiser les couches de la toile",
"resetGenerationSettings": "Réinitialiser les paramètres de génération",
"mergeDown": "Fusionner",
"controlLayerEmptyState": "<UploadButton>Télécharger une image</UploadButton>, faites glisser une image depuis la <GalleryButton>galerie</GalleryButton> sur ce calque, ou dessinez sur la toile pour commencer.",
"asRasterLayer": "En tant que $t(controlLayers.rasterLayer)",
"asRasterLayerResize": "En tant que $t(controlLayers.rasterLayer) (Redimensionner)",
"asControlLayer": "En tant que $t(controlLayers.controlLayer)",
"asControlLayerResize": "En $t(controlLayers.controlLayer) (Redimensionner)",
"newSession": "Nouvelle session"
"newRasterLayer": "Nouveau $t(controlLayers.rasterLayer)"
},
"upscaling": {
"exceedsMaxSizeDetails": "La limite maximale d'agrandissement est de {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixels. Veuillez essayer une image plus petite ou réduire votre sélection d'échelle.",
@@ -2092,9 +2047,7 @@
"postProcessingModel": "Modèle de post-traitement",
"missingUpscaleModel": "Modèle d'agrandissement manquant",
"missingUpscaleInitialImage": "Image initiale manquante pour l'agrandissement",
"missingTileControlNetModel": "Aucun modèle ControlNet valide installé",
"incompatibleBaseModelDesc": "L'upscaling est pris en charge uniquement pour les modèles d'architecture SD1.5 et SDXL. Changez le modèle principal pour activer l'upscaling.",
"incompatibleBaseModel": "Modèle principal non pris en charge pour l'upscaling"
"missingTileControlNetModel": "Aucun modèle ControlNet valide installé"
},
"stylePresets": {
"deleteTemplate": "Supprimer le template",
@@ -2180,62 +2133,5 @@
"inviteTeammates": "Inviter des collègues",
"professionalUpsell": "Disponible dans l'édition professionnelle d'Invoke. Cliquez ici ou visitez invoke.com/pricing pour plus de détails.",
"professional": "Professionnel"
},
"supportVideos": {
"watch": "Regarder",
"videos": {
"upscaling": {
"description": "Comment améliorer la résolution des images avec les outils d'Invoke pour les agrandir.",
"title": "Upscaling"
},
"howDoIGenerateAndSaveToTheGallery": {
"description": "Étapes pour générer et enregistrer des images dans la galerie.",
"title": "Comment générer et enregistrer dans la galerie?"
},
"usingControlLayersAndReferenceGuides": {
"title": "Utilisation des couche de contrôle et des guides de référence",
"description": "Apprenez à guider la création de vos images avec des couche de contrôle et des images de référence."
},
"exploringAIModelsAndConceptAdapters": {
"description": "Plongez dans les modèles d'IA et découvrez comment utiliser les adaptateurs de concepts pour un contrôle créatif.",
"title": "Exploration des modèles d'IA et des adaptateurs de concepts"
},
"howDoIUseControlNetsAndControlLayers": {
"title": "Comment utiliser les réseaux de contrôle et les couches de contrôle?",
"description": "Apprenez à appliquer des couches de contrôle et des ControlNets à vos images."
},
"creatingAndComposingOnInvokesControlCanvas": {
"description": "Apprenez à composer des images en utilisant le canvas de contrôle d'Invoke.",
"title": "Créer et composer sur le canvas de contrôle d'Invoke"
},
"howDoIEditOnTheCanvas": {
"title": "Comment puis-je modifier sur la toile?",
"description": "Guide pour éditer des images directement sur la toile."
},
"howDoIDoImageToImageTransformation": {
"title": "Comment effectuer une transformation d'image à image?",
"description": "Tutoriel sur la réalisation de transformations d'image à image dans Invoke."
},
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
"title": "Comment utiliser les IP Adapters globaux et les images de référence?",
"description": "Introduction à l'ajout d'images de référence et IP Adapters globaux."
},
"howDoIUseInpaintMasks": {
"title": "Comment utiliser les masques d'inpainting?"
},
"creatingYourFirstImage": {
"title": "Créer votre première image",
"description": "Introduction à la création d'une image à partir de zéro en utilisant les outils d'Invoke."
},
"understandingImageToImageAndDenoising": {
"title": "Comprendre l'Image-à-Image et le Débruitage",
"description": "Aperçu des transformations d'image à image et du débruitage dans Invoke."
}
},
"gettingStarted": "Commencer",
"studioSessionsDesc1": "Consultez le <StudioSessionsPlaylistLink /> pour des approfondissements sur Invoke.",
"studioSessionsDesc2": "Rejoignez notre <DiscordLink /> pour participer aux sessions en direct et poser vos questions. Les sessions sont ajoutée dans la playlist la semaine suivante.",
"supportVideos": "Vidéos d'assistance",
"controlCanvas": "Contrôler la toile"
}
}

View File

@@ -611,8 +611,7 @@
"hfForbidden": "Non hai accesso a questo modello HF",
"hfTokenLabel": "Gettone HuggingFace (richiesto per alcuni modelli)",
"hfForbiddenErrorMessage": "Consigliamo di visitare la pagina del repository su HuggingFace.com. Il proprietario potrebbe richiedere l'accettazione dei termini per poter effettuare il download.",
"hfTokenInvalidErrorMessage2": "Aggiornalo in ",
"controlLora": "Controllo LoRA"
"hfTokenInvalidErrorMessage2": "Aggiornalo in "
},
"parameters": {
"images": "Immagini",
@@ -680,8 +679,7 @@
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi elementi, massimo {{maxItems}}",
"canvasIsSelectingObject": "La tela è occupata (selezione dell'oggetto)",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi pochi elementi, minimo {{minItems}}",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} raccolta vuota",
"fluxModelMultipleControlLoRAs": "È possibile utilizzare solo 1 Controllo LoRA alla volta"
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} raccolta vuota"
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@@ -2173,7 +2171,8 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"<StrongComponent>Livelli di controllo Flux</StrongComponent>: nuovi modelli di controllo per il rilevamento dei bordi e la mappatura della profondità sono ora supportati per i modelli di Flux dev."
"<StrongComponent>FLUX Regional Guidance (beta)</StrongComponent>: la nostra versione beta di FLUX Regional Guidance è attiva per il controllo dei prompt regionali.",
"<StrongComponent>Vari miglioramenti dell'esperienza utente</StrongComponent>: numerosi piccoli miglioramenti dell'esperienza utente e della qualità della vita in tutta l'app."
]
},
"system": {
@@ -2197,8 +2196,7 @@
"events": "Eventi",
"system": "Sistema",
"metadata": "Metadati",
"logNamespaces": "Elementi del registro",
"dnd": "Trascina e rilascia"
"logNamespaces": "Elementi del registro"
},
"enableLogging": "Abilita la registrazione"
},

View File

@@ -46,7 +46,7 @@
"menuItemAutoAdd": "Tự động thêm cho Bảng này",
"move": "Di Chuyển",
"topMessage": "Bảng này chứa ảnh được dùng với những tính năng sau:",
"uncategorized": "Chưa Sắp Xếp",
"uncategorized": "Chưa Phân Loại",
"archived": "Được Lưu Trữ",
"loading": "Đang Tải...",
"selectBoard": "Chọn Bảng",
@@ -91,7 +91,7 @@
"sideBySide": "Cạnh Nhau",
"alwaysShowImageSizeBadge": "Luôn Hiển Thị Kích Thước Ảnh",
"autoAssignBoardOnClick": "Tự Động Gán Vào Bảng Khi Nhấp Chuột",
"jump": "Nhảy Đến",
"jump": "Nhảy Vào",
"go": "Đi",
"autoSwitchNewImages": "Tự Động Đổi Sang Hình Ảnh Mới",
"featuresWillReset": "Nếu bạn xoá hình ảnh này, những tính năng đó sẽ lập tức được khởi động lại.",
@@ -133,7 +133,7 @@
"alpha": "Alpha",
"edit": "Sửa",
"nodes": "Workflow",
"format": "Định Dạng",
"format": "định dạng",
"delete": "Xoá",
"details": "Chi Tiết",
"imageFailedToLoad": "Không Thể Tải Hình Ảnh",
@@ -164,7 +164,7 @@
"discordLabel": "Discord",
"back": "Trở Về",
"advanced": "Nâng Cao",
"batch": "Quản Lý ",
"batch": "Quản Lý Hàng Loạt",
"modelManager": "Quản Lý Model",
"dontShowMeThese": "Không hiển thị thứ này",
"ok": "OK",
@@ -196,7 +196,7 @@
"areYouSure": "Bạn chắc chứ?",
"ai": "ai",
"aboutDesc": "Sử dụng Invoke cho công việc? Xem thử:",
"aboutHeading": "Quyền Năng Sáng Tạo Của Riêng",
"aboutHeading": "Sở Hữu Khả Năng Sáng Tạo Cho Riêng Mình",
"enabled": "Đã Bật",
"close": "Đóng",
"data": "Dữ Liệu",
@@ -232,24 +232,24 @@
"enqueueing": "Xếp Vào Hàng Hàng Loạt",
"prompts_other": "Lệnh",
"iterations_other": "Lặp Lại",
"total": "Tổng",
"total": "Tổng Cộng",
"pruneFailed": "Có Vấn Đề Khi Cắt Bớt Mục Khỏi Hàng",
"clearSucceeded": "Hàng Đã Được Dọn Sạch",
"cancel": "Huỷ Bỏ",
"clearQueueAlertDialog2": "Bạn chắc chắn muốn dọn sạch hàng không?",
"queueEmpty": "Hàng Trống",
"queueBack": "Thêm Vào Hàng",
"batchFieldValues": "Giá Trị Vùng Theo Lô",
"batchFieldValues": "Giá Trị Vùng Hàng Loạt",
"openQueue": "Mở Queue",
"pause": "Dừng Lại",
"pauseFailed": "Có Vấn Đề Khi Dừng Lại Bộ Xử Lý",
"batchQueued": " Đã Vào Hàng",
"batchFailedToQueue": "Lỗi Khi Xếp Vào Hàng",
"batchQueued": "Hàng Loạt Đã Vào hàng",
"batchFailedToQueue": "Lỗi Khi Xếp Hàng Loạt Vào Hàng",
"next": "Tiếp Theo",
"in_progress": "Đang Chạy",
"in_progress": "Đang Tiến Hành",
"failed": "Thất Bại",
"canceled": "Bị Huỷ",
"cancelBatchFailed": "Có Vấn Đề Khi Huỷ Bỏ ",
"cancelBatchFailed": "Có Vấn Đề Khi Huỷ Bỏ Hàng Loạt",
"workflows": "Workflow (Luồng làm việc)",
"canvas": "Canvas (Vùng ảnh)",
"upscaling": "Upscale (Nâng Cấp Chất Lượng Hình Ảnh)",
@@ -272,19 +272,19 @@
"resumeTooltip": "Tiếp Tục Bộ Xử Lý",
"clearFailed": "Có Vấn Đề Khi Dọn Dẹp Hàng",
"generations_other": "Máy Tạo Sinh",
"cancelBatch": "Huỷ Bỏ ",
"cancelBatch": "Huỷ Bỏ Hàng Loạt",
"status": "Trạng Thái",
"pending": "Đang Chờ",
"gallery": "Thư Viện",
"front": "trước",
"batch": "",
"batch": "Hàng Loạt",
"origin": "Nguồn Gốc",
"destination": "Điểm Đến",
"other": "Khác",
"graphFailedToQueue": "Lỗi Khi Xếp Đồ Thị Vào Hàng",
"notReady": "Không Thể Xếp Hàng",
"cancelItem": "Huỷ Bỏ Mục",
"cancelBatchSucceeded": " Đã Huỷ Bỏ",
"cancelBatchSucceeded": "Mục Hàng Loạt Đã Huỷ Bỏ",
"current": "Hiện Tại",
"time": "Thời Gian",
"completed": "Hoàn Tất",
@@ -294,7 +294,7 @@
"completedIn": "Hoàn tất trong",
"graphQueued": "Đồ Thị Đã Vào Hàng",
"batchQueuedDesc_other": "Thêm {{count}} phiên vào {{direction}} của hàng",
"batchSize": "Kích Thước "
"batchSize": "Kích Thước Vùng Hàng Loạt"
},
"hotkeys": {
"canvas": {
@@ -457,8 +457,8 @@
"title": "Gợi Lại Tất Cả Metadata"
},
"recallSeed": {
"title": "Gợi Lại Hạt Giống",
"desc": "Gợi lại hạt giống của ảnh hiện tại."
"title": "Gợi Lại Tham Số Hạt Giống",
"desc": "Gợi lại tham số hạt giống của ảnh hiện tại."
},
"useSize": {
"title": "Dùng Kích Thước",
@@ -490,7 +490,7 @@
"title": "Đổi Ảnh So Sánh"
},
"remix": {
"desc": "Gợi lại tất cả metadata cho hạt giống của ảnh hiện tại.",
"desc": "Gợi lại tất cả metadata cho tham số hạt giống của ảnh hiện tại.",
"title": "Phối Lại"
},
"runPostprocessing": {
@@ -633,7 +633,7 @@
"modelManager": "Quản Lý Model",
"name": "Tên",
"noModelSelected": "Không Có Model Được Chọn",
"installQueue": "Danh Sách Tải Xuống",
"installQueue": "Tải Xuống Danh Sách Đợi",
"modelDeleteFailed": "Xoá model thất bại",
"inplaceInstallDesc": "Tải xuống model mà không sao chép toàn bộ tài nguyên. Khi sử dụng model, nó được sẽ tải từ vị trí được đặt. Nếu bị tắt, toàn bộ tài nguyên của model sẽ được sao chép vào thư mục quản lý model của Invoke trong quá trình tải xuống.",
"modelType": "Loại Model",
@@ -705,7 +705,7 @@
"spandrelImageToImage": "Hình Ảnh Sang Hình Ảnh (Spandrel)",
"starterBundles": "Quà Tân Thủ",
"vae": "VAE",
"urlOrLocalPath": "URL / Đường Dẫn",
"urlOrLocalPath": "URL Hoặc Đường Dẫn Trên Máy Chủ",
"triggerPhrases": "Từ Ngữ Kích Hoạt",
"variant": "Biến Thể",
"urlOrLocalPathHelper": "Url cần chỉ vào một tệp duy nhất. Còn đường dẫn trên máy chủ có thể chỉ vào một tệp hoặc một thư mục cho chỉ một model diffusers.",
@@ -723,7 +723,7 @@
"starterModels": "Model Khởi Đầu",
"typePhraseHere": "Thêm từ ngữ ở đây",
"upcastAttention": "Upcast Attention",
"vaePrecision": "Độ Chuẩn VAE",
"vaePrecision": "VAE Precision",
"installingBundle": "Đang Tải Nguyên Bộ",
"installingModel": "Đang Tải Model",
"installingXModels_other": "Đang tải {{count}} model",
@@ -751,18 +751,18 @@
"parameterSet": "Dữ liệu tham số {{parameter}}",
"positivePrompt": "Lệnh Tích Cực",
"recallParameter": "Gợi Nhớ {{label}}",
"seed": "Hạt Giống",
"seed": "Tham Số Hạt Giống",
"negativePrompt": "Lệnh Tiêu Cực",
"noImageDetails": "Không tìm thấy chí tiết ảnh",
"strength": "Mức độ mạnh từ ảnh sang ảnh",
"Threshold": "Ngưỡng Nhiễu",
"width": "Chiều Rộng",
"steps": "Số Bước",
"steps": "Tham Số Bước",
"vae": "VAE",
"workflow": "Workflow",
"seamlessXAxis": "Trục X Liền Mạch",
"seamlessYAxis": "Trục Y Liền Mạch",
"cfgScale": "Thang CFG",
"cfgScale": "Thước Đo CFG",
"allPrompts": "Tất Cả Lệnh",
"generationMode": "Chế Độ Tạo Sinh",
"height": "Chiều Dài",
@@ -776,7 +776,7 @@
},
"accordions": {
"generation": {
"title": "Generation (Tạo Sinh)"
"title": "Generation (Máy Tạo Sinh)"
},
"image": {
"title": "Hình Ảnh"
@@ -786,7 +786,7 @@
"options": "Lựa Chọn $t(accordions.advanced.title)"
},
"compositing": {
"coherenceTab": "Coherence Pass (Liên Kết)",
"coherenceTab": "Coherence Pass (Lớp Kết Hợp)",
"title": "Kết Hợp",
"infillTab": "Infill (Lấp Đầy)"
},
@@ -795,21 +795,21 @@
}
},
"invocationCache": {
"disableSucceeded": "Bộ Nhớ Đệm Đã Tắt",
"disableFailed": "Có Vấn Đề Khi Tắt Bộ Nhớ Đệm",
"hits": "Số Lần Trúng",
"maxCacheSize": "Tối Đa",
"cacheSize": "Tổng Cache",
"enableFailed": "Có Vấn Đề Khi Bật Bộ Nhớ Đệm",
"disableSucceeded": "Bộ Nhớ Đệm Kích Hoạt Đã Tắt",
"disableFailed": "Có Vấn Đề Khi Tắt Bộ Nhớ Đệm Kích Hoạt",
"hits": "Truy Cập Bộ Nhớ Đệm",
"maxCacheSize": "Kích Thước Tối Đa Bộ Nhớ Đệm",
"cacheSize": "Kích Thước Bộ Nhớ Đệm",
"enableFailed": "Có Vấn Đề Khi Bật Bộ Nhớ Đệm Kích Hoạt",
"disable": "Tắt",
"invocationCache": "Bộ Nhớ Đệm",
"clearSucceeded": "Bộ Nhớ Đệm Đã Được Dọn",
"enableSucceeded": "Bộ Nhớ Đệm Đã Bật",
"invocationCache": "Bộ Nhớ Đệm Kích Hoạt",
"clearSucceeded": "Bộ Nhớ Đệm Kích Hoạt Đã Được Dọn",
"enableSucceeded": "Bộ Nhớ Đệm Kích Hoạt Đã Bật",
"useCache": "Dùng Bộ Nhớ Đệm",
"enable": "Bật",
"misses": "Số Lần Trật",
"misses": "Không Truy Cập Bộ Nhớ Đệm",
"clear": "Dọn Dẹp",
"clearFailed": "Có Vấn Đề Khi Dọn Dẹp Bộ Nhớ Đệm"
"clearFailed": "Có Vấn Đề Khi Dọn Dẹp Bộ Nhớ Đệm Kích Hoạt"
},
"hrf": {
"metadata": {
@@ -964,9 +964,9 @@
},
"popovers": {
"paramCFGRescaleMultiplier": {
"heading": "Hệ Số Nhân Thang CFG",
"heading": "CFG Rescale Multiplier",
"paragraphs": [
"Hệ số nhân điều chỉnh để hướng dẫn cho CFG, dùng cho model được huấn luyện bằng zero-terminal SNR (ztsnr).",
"Hệ số nhân điều chỉnh cho hướng dẫn CFG, dùng cho model được huấn luyện bằng zero-terminal SNR (ztsnr).",
"Giá trị khuyến cáo là 0.7 cho những model này."
]
},
@@ -978,10 +978,10 @@
]
},
"paramCFGScale": {
"heading": "Thang CFG",
"heading": "Thước Đo CFG",
"paragraphs": [
"Điều khiển mức độ lệnh tác động lên quá trình tạo sinh.",
"Giá trị của Thang CFG quá cao có thể tạo độ bão hoà quá mức và khiến ảnh tạo sinh bị méo mó. "
"Giá trị của Thước đo CFG quá cao có thể tạo độ bão hoà quá mức và khiến ảnh tạo sinh bị méo mó. "
]
},
"paramScheduler": {
@@ -992,13 +992,13 @@
]
},
"compositingCoherencePass": {
"heading": "Coherence Pass (Liên Kết)",
"heading": "Coherence Pass (Lớp Kết Hợp)",
"paragraphs": [
"Bước thứ hai trong quá trình khử nhiễu để hợp nhất với ảnh inpaint/outpaint."
]
},
"refinerNegativeAestheticScore": {
"heading": "Điểm Khác Tiêu Chuẩn",
"heading": "Điểm Tiêu Cực Cho Tiêu Chuẩn",
"paragraphs": [
"Trọng lượng để tạo sinh ảnh giống với ảnh có điểm tiêu chuẩn thấp, dựa vào dữ liệu huấn luyện."
]
@@ -1006,26 +1006,26 @@
"refinerCfgScale": {
"paragraphs": [
"Điều khiển mức độ lệnh tác động lên quá trình tạo sinh.",
"Giống với thang CFG để tạo sinh."
"Giống với thước đo CFG để tạo sinh."
],
"heading": "Thang CFG"
"heading": "Thước Đo CFG"
},
"refinerSteps": {
"heading": "Số Bước",
"heading": "Tham Số Bước",
"paragraphs": [
"Số bước diễn ra trong khi tinh chế các phần nhỏ của quá trình tạo sinh.",
"Giống với số bước để tạo sinh."
"Giống với tham số bước để tạo sinh."
]
},
"paramSteps": {
"heading": "Số Bước",
"heading": "Tham Số Bước",
"paragraphs": [
"Số bước dùng để biểu diễn trong mỗi lần tạo sinh.",
"Số bước càng cao thường sẽ tạo ra ảnh tốt hơn nhưng ngốn nhiều thời gian hơn."
]
},
"paramWidth": {
"heading": "Rộng",
"heading": "Chiều Rộng",
"paragraphs": [
"Chiều rộng của ảnh tạo sinh. Phải là bội số của 8."
]
@@ -1052,14 +1052,14 @@
"paragraphs": [
"Trọng lượng để tạo sinh ảnh giống với ảnh có điểm tiêu chuẩn cao, dựa vào dữ liệu huấn luyện."
],
"heading": "Điểm Giống Tiêu Chuẩn"
"heading": "Điểm Tích Cực Cho Tiêu Chuẩn"
},
"paramVAEPrecision": {
"paragraphs": [
"Độ chính xác dùng trong khi mã hoá và giải mã VAE.",
"Chính xác một nửa/Fp16 sẽ hiệu quả hơn, đổi lại cho những thay đổi nhỏ với ảnh."
],
"heading": "Độ Chuẩn VAE"
"heading": "VAE Precision"
},
"fluxDevLicense": {
"heading": "Giấy Phép Phi Thương Mại",
@@ -1078,14 +1078,14 @@
"paragraphs": [
"Chiều dài của ảnh tạo sinh. Phải là bội số của 8."
],
"heading": "Dài"
"heading": "Chiều Dài"
},
"paramRatio": {
"paragraphs": [
"Tỉ lệ khung hình của kích thước của ảnh được tạo ra.",
"Kích thước ảnh (theo số lượng pixel) tương đương với 512x512 được khuyến nghị cho model SD1.5 và kích thước tương đương với 1024x1024 được khuyến nghị cho model SDXL."
],
"heading": "Tỉ Lệ"
"heading": "Tỉ Lệ Khung Hình"
},
"seamlessTilingYAxis": {
"paragraphs": [
@@ -1140,19 +1140,19 @@
},
"compositingCoherenceMinDenoise": {
"paragraphs": [
"Độ khử nhiễu nhỏ nhất cho chế độ liên kết",
"Sức mạnh khử nhiễu nhỏ nhất cho vùng liên kết khi inpaint/outpaint"
"Sức mạnh khử nhiễu nhỏ nhất cho chế độ kết hợp",
"Sức mạnh khử nhiễu nhỏ nhất cho vùng kết hợp khi inpaint/outpaint"
],
"heading": "Min Khử Nhiễu"
"heading": "Độ Khử Nhiễu Tối Thiểu"
},
"compositingCoherenceEdgeSize": {
"paragraphs": [
"Kích c cạnh dùng cho coherence pass."
"Kích thước cạnh cho lớp kết hợp."
],
"heading": "Kích Cỡ Cạnh"
"heading": "Kích Thước Cạnh"
},
"compositingMaskBlur": {
"heading": "Độ Mờ Vùng",
"heading": "Mask Blur",
"paragraphs": [
"Độ mờ của phần được phủ."
]
@@ -1180,7 +1180,7 @@
"noiseUseCPU": {
"paragraphs": [
"Điều chỉnh độ nhiễu được tạo ra trên CPU hay GPU.",
"Với Độ nhiễu CPU được bật, một hạt giống cụ thể sẽ tạo ra hình ảnh giống nhau trên mọi máy.",
"Với Độ nhiễu CPU được bật, một tham số hạt giống cụ thể sẽ tạo ra hình ảnh giống nhau trên mọi máy.",
"Không có tác động nào đến hiệu suất khi bật Độ nhiễu CPU."
],
"heading": "Dùng Độ Nhiễu CPU"
@@ -1210,7 +1210,7 @@
"• Bước Bắt Đầu (%): Chỉ định lúc bắt đầu áp dụng chỉ dẫn từ layer này trong quá trình tạo sinh.",
"• Bước Kết Thúc (%): Chỉ định lúc dừng áp dụng chỉ dẫn của layer này và trở về chỉ dẫn chung từ model và các thiết lập khác."
],
"heading": "Phần Trăm Số Bước Khi Bắt Đầu/Kết Thúc"
"heading": "Phần Trăm Tham Số Bước Khi Bắt Đầu/Kết Thúc"
},
"scale": {
"heading": "Tỉ Lệ",
@@ -1232,12 +1232,12 @@
},
"dynamicPromptsSeedBehaviour": {
"paragraphs": [
"Điều khiển cách hạt giống được dùng khi tạo sinh từ lệnh.",
"Cứ mỗi lần lặp, một hạt giống mới sẽ được dùng. Dùng nó để khám phá những biến thể từ lệnh trên mỗi hạt giống.",
"Ví dụ, nếu bạn có 5 lệnh, mỗi ảnh sẽ dùng cùng hạt giống.",
"Một hạt giống mới sẽ được dùng cho từng ảnh. Nó tạo ra nhiều biến thể."
"Điều khiển cách tham số hạt giống được dùng khi tạo sinh từ lệnh.",
"Cứ mỗi lần lặp, một tham số hạt giống mới sẽ được dùng. Dùng nó để khám phá những biến thể từ lệnh trên mỗi tham số hạt giống.",
"Ví dụ, nếu bạn có 5 lệnh, mỗi ảnh sẽ dùng cùng tham số hạt giống.",
"Một tham số hạt giống mới sẽ được dùng cho từng ảnh. Nó tạo ra nhiều biến thể."
],
"heading": "Hành Vi Của Hạt Giống"
"heading": "Hành Động Cho Tham Số Hạt Giống"
},
"paramGuidance": {
"heading": "Hướng Dẫn",
@@ -1266,10 +1266,10 @@
},
"paramAspect": {
"paragraphs": [
"Tỉ lệ khung hành của ảnh tạo sinh. Điều chỉnh tỉ lệ s cập nhật chiều rộng và chiều dài tương ứng.",
"\"Tối ưu hoá\" sẽ đặt chiều rộng và chiều dài vào kích thước tối ưu cho model được chọn."
"Tỉ lệ khung hành của ảnh tạo sinh. Điều chỉnh tỉ lệ se cập nhật Chiều Rộng và Chiều Dài tương ứng.",
"\"Tối ưu hoá\" sẽ đặt Chiều Rộng và Chiều Dài vào kích thước tối ưu cho model được chọn."
],
"heading": "Tỉ Lệ"
"heading": "Khung Hình"
},
"paramNegativeConditioning": {
"heading": "Lệnh Tiêu Cực",
@@ -1289,7 +1289,7 @@
"Nơi trong quá trình xử lý tạo sinh mà refiner bắt đầu được dùng.",
"0 nghĩa là bộ refiner sẽ được dùng trong toàn bộ quá trình tạo sinh , 0.8 nghĩa là refiner sẽ được dùng trong 20% cuối cùng quá trình tạo sinh."
],
"heading": "Bắt Đầu Refiner"
"heading": "Nơi Bắt Đầu Refiner"
},
"paramUpscaleMethod": {
"paragraphs": [
@@ -1310,9 +1310,9 @@
"heading": "Độ Cấu Trúc"
},
"infillMethod": {
"heading": "Cách Infill",
"heading": "Cách Thức Infill",
"paragraphs": [
"Cách thức infill trong quá trình inpaint/outpaint."
"Cách thức làm infill trong quá trình inpaint/outpaint."
]
},
"paramDenoisingStrength": {
@@ -1341,7 +1341,7 @@
"Điều khiển độ nhiễu ban đầu được dùng để tạo sinh.",
"Tắt lựa chọn \"Ngẫu Nhiên\" để tạo ra kết quá y hệt nhau với cùng một thiết lập tạo sinh."
],
"heading": "Hạt Giống"
"heading": "Tham Số Hạt Giống"
},
"clipSkip": {
"heading": "CLIP Skip",
@@ -1366,7 +1366,7 @@
"compositingCoherenceMode": {
"heading": "Chế Độ",
"paragraphs": [
"Cách thức được dùng để liên kết ảnh với vùng bao phủ vừa được tạo sinh."
"Cách thức được dùng để kết hợp ảnh với vùng bao phủ vừa được tạo sinh."
]
},
"paramModel": {
@@ -1391,7 +1391,7 @@
},
"models": {
"addLora": "Thêm LoRA",
"concepts": "LoRA",
"concepts": "Khái Niệm",
"loading": "đang tải",
"lora": "LoRA",
"noMatchingLoRAs": "Không có LoRA phù hợp",
@@ -1406,7 +1406,7 @@
"postProcessing": "Xử Lý Hậu Kỳ (Shift + U)",
"symmetry": "Tính Đối Xứng",
"type": "Loại",
"seed": "Hạt Giống",
"seed": "Tham Số Hạt Giống",
"processImage": "Xử Lý Hình Ảnh",
"useSize": "Dùng Kích Thước",
"invoke": {
@@ -1435,19 +1435,19 @@
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: quá nhiều mục, tối đa {{maxItems}}",
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)"
},
"cfgScale": "Thang CFG",
"useSeed": "Dùng Hạt Giống",
"cfgScale": "Thước Đo CFG",
"useSeed": "Dùng Tham Số Hạt Giống",
"imageActions": "Hành Động Với Hình Ảnh",
"steps": "Số Bước",
"aspect": "Tỉ Lệ",
"steps": "Tham Số Bước",
"aspect": "Khung Hình",
"coherenceMode": "Chế Độ",
"coherenceEdgeSize": "Kích Cỡ Cạnh",
"coherenceMinDenoise": "Min Khử Nhiễu",
"coherenceEdgeSize": "Kích Thước Cạnh",
"coherenceMinDenoise": "Tham Số Khử Nhiễu Nhỏ Nhất",
"denoisingStrength": "Sức Mạnh Khử Nhiễu",
"infillMethod": "Cách Infill",
"infillMethod": "Cách Thức Infill",
"setToOptimalSize": "Tối ưu hoá kích cỡ cho model",
"maskBlur": "Độ Mờ Vùng",
"width": "Rộng",
"maskBlur": "Mask Blur",
"width": "Chiều Rộng",
"scale": "Tỉ Lệ",
"recallMetadata": "Gợi Lại Metadata",
"clipSkip": "CLIP Skip",
@@ -1455,7 +1455,7 @@
"boxBlur": "Box Blur",
"gaussianBlur": "Gaussian Blur",
"staged": "Staged (Tăng khử nhiễu có hệ thống)",
"scaledHeight": "Tỉ Lệ Dài",
"scaledHeight": "Tỉ Lệ Chiều Dài",
"cancel": {
"cancel": "Huỷ"
},
@@ -1463,12 +1463,12 @@
"optimizedImageToImage": "Tối Ưu Hoá Hình Ảnh Sang Hình Ảnh",
"sendToCanvas": "Gửi Vào Canvas",
"sendToUpscale": "Gửi Vào Upscale",
"scaledWidth": "Tỉ Lệ Rộng",
"scaledWidth": "Tỉ Lệ Chiều Rộng",
"scheduler": "Scheduler",
"seamlessXAxis": "Trục X Liền Mạch",
"seamlessYAxis": "Trục Y Liền Mạch",
"guidance": "Hướng Dẫn",
"height": "Dài",
"height": "Chiều Cao",
"noiseThreshold": "Ngưỡng Nhiễu",
"negativePromptPlaceholder": "Lệnh Tiêu Cực",
"iterations": "Lặp Lại",
@@ -1481,13 +1481,13 @@
"useCpuNoise": "Dùng Độ Nhiễu CPU",
"remixImage": "Phối Lại Hình Ảnh",
"showOptionsPanel": "Hiển Thị Bảng Bên Cạnh (O hoặc T)",
"shuffle": "Xáo Trộn",
"shuffle": "Xáo Trộn Tham Số Hạt Giống",
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (lớn quá)",
"cfgRescaleMultiplier": "Hệ Số Nhân Thang CFG",
"cfgRescaleMultiplier": "CFG Rescale Multiplier",
"setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (nhỏ quá)",
"images": "Ảnh Ban Đầu",
"controlNetControlMode": "Chế Độ Điều Khiển",
"lockAspectRatio": "Khoá Tỉ Lệ",
"lockAspectRatio": "Khoá Tỉ Lệ Khung Hình",
"swapDimensions": "Hoán Đổi Kích Thước",
"copyImage": "Sao Chép Hình Ảnh",
"downloadImage": "Tải Xuống Hình Ảnh",
@@ -1500,11 +1500,11 @@
},
"dynamicPrompts": {
"seedBehaviour": {
"perIterationDesc": "Sử dụng hạt giống khác nhau cho mỗi lần lặp lại",
"perPromptDesc": "Sử dụng hạt giống khác nhau cho mỗi hình ảnh",
"label": "Hành Động Cho Hạt Giống",
"perPromptLabel": "Một Hạt Giống Mỗi Ảnh",
"perIterationLabel": "Hạt Giống Mỗi Lần Lặp Lại"
"perIterationDesc": "Sử dụng tham số hạt giống khác nhau cho mỗi lần lặp lại",
"perPromptDesc": "Sử dụng tham số hạt giống khác nhau cho mỗi hình ảnh",
"label": "Hành Động Cho Tham Số Hạt Giống",
"perPromptLabel": "Tham Số Hạt Giống Mỗi Hình Ảnh",
"perIterationLabel": "Tham Số Hạt Giống Mỗi Lần Lặp Lại"
},
"loading": "Tạo Sinh Dùng Dynamic Prompt...",
"showDynamicPrompts": "HIện Dynamic Prompt",
@@ -1515,9 +1515,9 @@
"settings": {
"beta": "Beta",
"general": "Cài Đặt Chung",
"confirmOnDelete": "Xác Nhận Khi Xoá",
"confirmOnDelete": "Chắp Nhận Xoá",
"developer": "Nhà Phát Triển",
"confirmOnNewSession": "Xác Nhận Khi Mở Phiên Mới",
"confirmOnNewSession": "Chắp Nhận Mở Phiên Mới",
"antialiasProgressImages": "Xử Lý Khử Răng Cưa Hình Ảnh",
"models": "Models",
"informationalPopoversDisabledDesc": "Hộp thoại hỗ trợ thông tin đã tắt. Bật lại trong Cài đặt.",
@@ -1549,20 +1549,20 @@
},
"sdxl": {
"loading": "Đang Tải...",
"posAestheticScore": "Điểm Giống Tiêu Chuẩn",
"steps": "Số Bước",
"refinerSteps": "Số Bước Refiner",
"posAestheticScore": "Điểm Tích Cực Cho Tiêu Chuẩn",
"steps": "Tham Số Bước",
"refinerSteps": "Tham Số Bước Refiner",
"refinermodel": "Model Refiner",
"refinerStart": "Bắt Đầu Refiner",
"refinerStart": "Nơi Bắt Đầu Refiner",
"denoisingStrength": "Sức Mạnh Khử Nhiễu",
"posStylePrompt": "Điểm Tích Cực Cho Lệnh Phong Cách",
"scheduler": "Scheduler",
"refiner": "Refiner",
"cfgScale": "Thang CFG",
"cfgScale": "Thước Đo CFG",
"concatPromptStyle": "Liên Kết Lệnh & Phong Cách",
"freePromptStyle": "Viết Lệnh Thủ Công Cho Phong Cách",
"negStylePrompt": "Điểm Tiêu Cực Cho Lệnh Phong Cách",
"negAestheticScore": "Điểm Khác Tiêu Chuẩn",
"negAestheticScore": "Điểm Tiêu Cực Cho Tiêu Chuẩn",
"noModelsAvailable": "Không có sẵn model"
},
"controlLayers": {
@@ -1992,8 +1992,7 @@
"generation": "Generation",
"system": "Hệ Thống",
"canvas": "Canvas",
"logNamespaces": "Vùng Ghi Log",
"dnd": "Kéo Thả"
"logNamespaces": "Nơi Được Log"
},
"logLevel": {
"logLevel": "Cấp Độ Log",
@@ -2157,7 +2156,8 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"<StrongComponent>Hướng Dẫn Khu Vực FLUX (beta)</StrongComponent>: Bản beta của Hướng Dẫn Khu Vực FLUX của chúng ta đã có mắt tại bảng điều khiển lệnh khu vực."
"<StrongComponent>Hướng Dẫn Khu Vực FLUX (beta)</StrongComponent>: Bản beta của Hướng Dẫn Khu Vực FLUX của chúng ta đã có mắt tại bảng điều khiển lệnh khu vực.",
"<StrongComponent>Nhiều Cải Tiến Ở UX</StrongComponent>: Một số nâng cấp nhỏ ở trải nghiệm và chất lượng người dùng trên toàn bộ ứng dụng."
]
},
"upsell": {

View File

@@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import type { Result } from 'common/util/result';
import { withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
@@ -9,9 +10,11 @@ import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGr
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { toast } from 'features/toast/toast';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Invocation } from 'services/api/types';
import { assert, AssertionError } from 'tsafe';
import type { JsonObject } from 'type-fest';
@@ -22,32 +25,42 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
effect: async (action, { getState, dispatch }) => {
log.debug('Enqueue requested');
const state = getState();
const model = state.params.model;
const { prepend } = action.payload;
const manager = $canvasManager.get();
assert(manager, 'No canvas manager');
assert(manager, 'No model found in state');
let buildGraphResult: Result<
{
g: Graph;
noise: Invocation<'noise' | 'flux_denoise' | 'sd3_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'sd3_text_encoder'>;
},
Error
>;
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, manager);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, manager);
case `sd-3`:
return await buildSD3Graph(state, manager);
case `flux`:
return await buildFLUXGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}
});
switch (base) {
case 'sdxl':
buildGraphResult = await withResultAsync(() => buildSDXLGraph(state, manager));
break;
case 'sd-1':
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
case `sd-3`:
buildGraphResult = await withResultAsync(() => buildSD3Graph(state, manager));
break;
case `flux`:
buildGraphResult = await withResultAsync(() => buildFLUXGraph(state, manager));
break;
default:
assert(false, `No graph builders for base ${base}`);
}
if (buildGraphResult.isErr()) {
let description: string | null = null;

View File

@@ -30,7 +30,7 @@ import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/
import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlLayerModelConfig,
isControlNetOrT2IAdapterModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
@@ -190,7 +190,7 @@ const handleLoRAModels: ModelHandler = (models, state, dispatch, log) => {
};
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const caModels = models.filter(isControlLayerModelConfig);
const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
const selectedControlAdapterModel = entity.controlAdapter.model;
// `null` is a valid control adapter model - no need to do anything.

View File

@@ -17,6 +17,8 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
startAppListening({
actionCreator: socketConnected,
effect: async (action, { dispatch, getState, cancelActiveListeners, delay }) => {
log.debug('Connected');
/**
* The rest of this listener has recovery logic for when the socket disconnects and reconnects.
*

View File

@@ -57,7 +57,7 @@ export class Err<E> {
* @template T The type of the value in the `Ok` case.
* @template E The type of the error in the `Err` case.
*/
type Result<T, E = Error> = Ok<T> | Err<E>;
export type Result<T, E = Error> = Ok<T> | Err<E>;
/**
* Creates a successful result.

View File

@@ -26,12 +26,7 @@ import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actio
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
ImageDTO,
T2IAdapterModelConfig,
} from 'services/api/types';
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
const buildSelectControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) =>
createMemoizedAppSelector(selectCanvasSlice, (canvas) => {
@@ -71,7 +66,7 @@ export const ControlLayerControlAdapter = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => {
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig }));
// When we change the model, we need may need to start filtering w/ the simplified filter mode, and/or change the
// filter config.
@@ -102,12 +97,12 @@ export const ControlLayerControlAdapter = memo(() => {
const filterConfig = defaultFilterForNewModel.buildDefaults();
if (isFiltering) {
adapter.filterer.$filterConfig.set(filterConfig);
// The user may have disabled auto-processing, so we should process the filter manually. This is essentially a
// no-op if auto-processing is already enabled, because the process method is debounced.
adapter.filterer.process();
} else {
adapter.filterer.start(filterConfig);
}
// The user may have disabled auto-processing, so we should process the filter manually. This is essentially a
// no-op if auto-processing is already enabled, because the process method is debounced.
adapter.filterer.process();
},
[adapter.filterer, dispatch, entityIdentifier]
);
@@ -163,9 +158,7 @@ export const ControlLayerControlAdapter = memo(() => {
<input {...uploadApi.getUploadInputProps()} />
</Flex>
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
{controlAdapter.type !== 'control_lora' && (
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
)}
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
{controlAdapter.type === 'controlnet' && !isFLUX && (
<ControlLayerControlAdapterControlMode
controlMode={controlAdapter.controlMode}

View File

@@ -4,27 +4,22 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlLayerModels } from 'services/api/hooks/modelsByType';
import type {
AnyModelConfig,
ControlLoRAModelConfig,
ControlNetModelConfig,
T2IAdapterModelConfig,
} from 'services/api/types';
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
type Props = {
modelKey: string | null;
onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => void;
onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void;
};
export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useControlLayerModels();
const [modelConfigs, { isLoading }] = useControlNetAndT2IAdapterModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChange = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => {
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
if (!modelConfig) {
return;
}

View File

@@ -18,7 +18,6 @@ import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/se
import type {
CanvasEntityIdentifier,
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
@@ -27,13 +26,8 @@ import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
IPAdapterModelConfig,
T2IAdapterModelConfig,
} from 'services/api/types';
import { isControlLayerModelConfig, isIPAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
/**
* Selects the default control adapter configuration based on the model configurations and the base.
@@ -45,13 +39,13 @@ import { isControlLayerModelConfig, isIPAdapterModelConfig } from 'services/api/
export const selectDefaultControlAdapter = createSelector(
selectModelConfigsQuery,
selectBase,
(query, base): ControlNetConfig | T2IAdapterConfig | ControlLoRAConfig => {
(query, base): ControlNetConfig | T2IAdapterConfig => {
const { data } = query;
let model: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null = null;
let model: ControlNetModelConfig | T2IAdapterModelConfig | null = null;
if (data) {
const modelConfigs = modelConfigsAdapterSelectors
.selectAll(data)
.filter(isControlLayerModelConfig)
.filter(isControlNetOrT2IAdapterModelConfig)
.sort((a) => (a.type === 'controlnet' ? -1 : 1)); // Prefer ControlNet models
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
model = compatibleModels[0] ?? modelConfigs[0] ?? null;

View File

@@ -1,4 +1,4 @@
import { withResult, withResultAsync } from 'common/util/result';
import { withResultAsync } from 'common/util/result';
import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule';
import type { CanvasEntityAdapter, CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -277,19 +277,13 @@ export class CanvasCompositorModule extends CanvasModuleBase {
this.log.warn({ rect, imageName: cachedImageName }, 'Cached image name not found, recompositing');
}
const getCompositeCanvasResult = withResult(() => this.getCompositeCanvas(adapters, rect, compositingOptions));
if (getCompositeCanvasResult.isErr()) {
this.log.error({ error: serializeError(getCompositeCanvasResult.error) }, 'Failed to get composite canvas');
throw getCompositeCanvasResult.error;
}
const canvas = this.getCompositeCanvas(adapters, rect, compositingOptions);
this.$isProcessing.set(true);
const blobResult = await withResultAsync(() => canvasToBlob(getCompositeCanvasResult.value));
const blobResult = await withResultAsync(() => canvasToBlob(canvas));
this.$isProcessing.set(false);
if (blobResult.isErr()) {
this.log.error({ error: serializeError(blobResult.error) }, 'Failed to convert composite canvas to blob');
throw blobResult.error;
}
const blob = blobResult.value;
@@ -495,45 +489,37 @@ export class CanvasCompositorModule extends CanvasModuleBase {
this.log.debug({ rect }, 'Calculating generation mode');
this.$isProcessing.set(true);
const generationModeResult = await withResultAsync(async () => {
const compositeRasterLayerTransparency = await this.getTransparency(
rasterLayerAdapters,
rect,
compositeRasterLayerHash
);
const compositeInpaintMaskTransparency = await this.getTransparency(
inpaintMaskAdapters,
rect,
compositeInpaintMaskHash
);
let generationMode: GenerationMode;
if (compositeRasterLayerTransparency === 'FULLY_TRANSPARENT') {
// When the initial image is fully transparent, we are always doing txt2img
generationMode = 'txt2img';
} else if (compositeRasterLayerTransparency === 'PARTIALLY_TRANSPARENT') {
// When the initial image is partially transparent, we are always outpainting
generationMode = 'outpaint';
} else if (compositeInpaintMaskTransparency === 'FULLY_TRANSPARENT') {
// compositeLayerTransparency === 'OPAQUE'
// When the inpaint mask is fully transparent, we are doing img2img
generationMode = 'img2img';
} else {
// Else at least some of the inpaint mask is opaque, so we are inpainting
generationMode = 'inpaint';
}
this.manager.cache.generationModeCache.set(hash, generationMode);
return generationMode;
});
const compositeRasterLayerTransparency = await this.getTransparency(
rasterLayerAdapters,
rect,
compositeRasterLayerHash
);
const compositeInpaintMaskTransparency = await this.getTransparency(
inpaintMaskAdapters,
rect,
compositeInpaintMaskHash
);
this.$isProcessing.set(false);
if (generationModeResult.isErr()) {
this.log.error({ error: serializeError(generationModeResult.error) }, 'Failed to calculate generation mode');
throw generationModeResult.error;
let generationMode: GenerationMode;
if (compositeRasterLayerTransparency === 'FULLY_TRANSPARENT') {
// When the initial image is fully transparent, we are always doing txt2img
generationMode = 'txt2img';
} else if (compositeRasterLayerTransparency === 'PARTIALLY_TRANSPARENT') {
// When the initial image is partially transparent, we are always outpainting
generationMode = 'outpaint';
} else if (compositeInpaintMaskTransparency === 'FULLY_TRANSPARENT') {
// compositeLayerTransparency === 'OPAQUE'
// When the inpaint mask is fully transparent, we are doing img2img
generationMode = 'img2img';
} else {
// Else at least some of the inpaint mask is opaque, so we are inpainting
generationMode = 'inpaint';
}
return generationModeResult.value;
this.manager.cache.generationModeCache.set(hash, generationMode);
return generationMode;
};
repr = () => {

View File

@@ -2,7 +2,6 @@ import type { Selector } from '@reduxjs/toolkit';
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
@@ -30,7 +29,6 @@ import {
isRasterLayerEntityIdentifier,
type Rect,
} from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import Konva from 'konva';
import { atom } from 'nanostores';
import rafThrottle from 'raf-throttle';
@@ -576,16 +574,9 @@ export abstract class CanvasEntityAdapterBase<
return stableHash(arg);
};
cropToBbox = async (): Promise<ImageDTO> => {
cropToBbox = (): Promise<ImageDTO> => {
const { rect } = this.manager.stateApi.getBbox();
const rasterizeResult = await withResultAsync(() =>
this.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } })
);
if (rasterizeResult.isErr()) {
toast({ status: 'error', title: 'Failed to crop to bbox' });
throw rasterizeResult.error;
}
return rasterizeResult.value;
return this.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } });
};
destroy = (): void => {

View File

@@ -11,14 +11,13 @@ import type { FilterConfig } from 'features/controlLayers/store/filters';
import { getFilterForModel, IMAGE_FILTERS } from 'features/controlLayers/store/filters';
import type { CanvasImageState, CanvasRenderableEntityType } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { toast } from 'features/toast/toast';
import Konva from 'konva';
import { debounce } from 'lodash-es';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
import { isControlLayerModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
import stableHash from 'stable-hash';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
@@ -204,7 +203,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
// If the parent is a control layer adapter, we should check if the model has a default filter and set it if so
const selectModelConfig = buildSelectModelConfig(
this.parent.state.controlAdapter.model.key,
isControlLayerModelConfig
isControlNetOrT2IAdapterModelConfig
);
const modelConfig = this.manager.stateApi.runSelector(selectModelConfig);
// This always returns a filter
@@ -247,7 +246,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.parent.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } })
);
if (rasterizeResult.isErr()) {
toast({ status: 'error', title: 'Failed to process filter' });
this.log.error({ error: serializeError(rasterizeResult.error) }, 'Error rasterizing entity');
this.$isProcessing.set(false);
return;

View File

@@ -27,7 +27,6 @@ import Konva from 'konva';
import type { GroupConfig } from 'konva/lib/Group';
import { throttle } from 'lodash-es';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
@@ -486,36 +485,30 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
this.log.trace({ rasterizeArgs }, 'Rasterizing entity');
this.manager.stateApi.$rasterizingAdapter.set(this.parent);
try {
const blob = await this.getBlob(rasterizeArgs);
if (this.manager._isDebugging) {
previewBlob(blob, 'Rasterized entity');
}
imageDTO = await uploadImage({
file: new File([blob], `${this.id}_rasterized.png`, { type: 'image/png' }),
image_category: 'other',
is_intermediate: true,
silent: true,
});
const imageObject = imageDTOToImageObject(imageDTO);
if (replaceObjects) {
await this.parent.bufferRenderer.setBuffer(imageObject);
this.parent.bufferRenderer.commitBuffer({ pushToState: false });
}
this.manager.stateApi.rasterizeEntity({
entityIdentifier: this.parent.entityIdentifier,
imageObject,
position: { x: Math.round(rect.x), y: Math.round(rect.y) },
replaceObjects,
});
this.manager.cache.imageNameCache.set(hash, imageDTO.image_name);
return imageDTO;
} catch (error) {
this.log.error({ rasterizeArgs, error: serializeError(error) }, 'Failed to rasterize entity');
throw error;
} finally {
this.manager.stateApi.$rasterizingAdapter.set(null);
const blob = await this.getBlob(rasterizeArgs);
if (this.manager._isDebugging) {
previewBlob(blob, 'Rasterized entity');
}
imageDTO = await uploadImage({
file: new File([blob], `${this.id}_rasterized.png`, { type: 'image/png' }),
image_category: 'other',
is_intermediate: true,
silent: true,
});
const imageObject = imageDTOToImageObject(imageDTO);
if (replaceObjects) {
await this.parent.bufferRenderer.setBuffer(imageObject);
this.parent.bufferRenderer.commitBuffer({ pushToState: false });
}
this.manager.stateApi.rasterizeEntity({
entityIdentifier: this.parent.entityIdentifier,
imageObject,
position: { x: Math.round(rect.x), y: Math.round(rect.y) },
replaceObjects,
});
this.manager.cache.imageNameCache.set(hash, imageDTO.image_name);
this.manager.stateApi.$rasterizingAdapter.set(null);
return imageDTO;
};
cloneObjectGroup = (arg: { attrs?: GroupConfig } = {}): Konva.Group => {

View File

@@ -13,7 +13,6 @@ import {
} from 'features/controlLayers/konva/util';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, RectWithRotation } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import Konva from 'konva';
import type { GroupConfig } from 'konva/lib/Group';
import { clamp, debounce, get } from 'lodash-es';
@@ -780,7 +779,6 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
})
);
if (rasterizeResult.isErr()) {
toast({ status: 'error', title: 'Failed to apply transform' });
this.log.error({ error: serializeError(rasterizeResult.error) }, 'Failed to rasterize entity');
}
this.requestRectCalculation();

View File

@@ -26,7 +26,6 @@ import type {
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { toast } from 'features/toast/toast';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { debounce } from 'lodash-es';
@@ -572,7 +571,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
);
if (rasterizeResult.isErr()) {
toast({ status: 'error', title: 'Failed to select object' });
this.log.error({ error: serializeError(rasterizeResult.error) }, 'Error rasterizing entity');
this.$isProcessing.set(false);
return;

View File

@@ -18,7 +18,6 @@ import type {
CanvasEntityType,
CanvasInpaintMaskState,
CanvasMetadata,
ControlLoRAConfig,
EntityMovedByPayload,
FillStyle,
RegionalGuidanceReferenceImageState,
@@ -35,13 +34,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
import type { IRect } from 'konva/lib/types';
import { merge } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
ImageDTO,
IPAdapterModelConfig,
T2IAdapterModelConfig,
} from 'services/api/types';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type {
@@ -74,10 +67,7 @@ import {
getReferenceImageState,
getRegionalGuidanceState,
imageDTOToImageWithDims,
initialControlLoRA,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from './util';
const getInitialState = (): CanvasState => {
@@ -446,7 +436,7 @@ export const canvasSlice = createSlice({
action: PayloadAction<
EntityIdentifierPayload<
{
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
},
'control_layer'
>
@@ -463,69 +453,20 @@ export const canvasSlice = createSlice({
}
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
// When converting between control layer types, we may need to add or remove properties. For example, ControlNet
// has a control mode, while T2I Adapter does not - otherwise they are the same.
switch (layer.controlAdapter.model.type) {
// Converting to T2I adapter from...
case 't2i_adapter': {
if (layer.controlAdapter.type === 'controlnet') {
// T2I Adapters have all the ControlNet properties, minus control mode - strip it
const { controlMode: _, ...rest } = layer.controlAdapter;
const t2iAdapterConfig: T2IAdapterConfig = { ...initialT2IAdapter, ...rest, type: 't2i_adapter' };
layer.controlAdapter = t2iAdapterConfig;
} else if (layer.controlAdapter.type === 'control_lora') {
// Control LoRAs have only model and weight
const t2iAdapterConfig: T2IAdapterConfig = {
...initialT2IAdapter,
...layer.controlAdapter,
type: 't2i_adapter',
};
layer.controlAdapter = t2iAdapterConfig;
}
break;
}
// Converting to ControlNet from...
case 'controlnet': {
if (layer.controlAdapter.type === 't2i_adapter') {
// ControlNets have all the T2I Adapter properties, plus control mode
const controlNetConfig: ControlNetConfig = {
...initialControlNet,
...layer.controlAdapter,
type: 'controlnet',
};
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 'control_lora') {
// ControlNets have all the Control LoRA properties, plus control mode and begin/end step pct
const controlNetConfig: ControlNetConfig = {
...initialControlNet,
...layer.controlAdapter,
type: 'controlnet',
};
layer.controlAdapter = controlNetConfig;
}
break;
}
// Converting to ControlLoRA from...
case 'control_lora': {
if (layer.controlAdapter.type === 'controlnet') {
// We only need the model and weight for Control LoRA
const { model, weight } = layer.controlAdapter;
const controlNetConfig: ControlLoRAConfig = { ...initialControlLoRA, model, weight };
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 't2i_adapter') {
// We only need the model and weight for Control LoRA
const { model, weight } = layer.controlAdapter;
const t2iAdapterConfig: ControlLoRAConfig = { ...initialControlLoRA, model, weight };
layer.controlAdapter = t2iAdapterConfig;
}
break;
}
default:
break;
// We may need to convert the CA to match the model
if (layer.controlAdapter.type === 't2i_adapter' && layer.controlAdapter.model.type === 'controlnet') {
// Converting from T2I Adapter to ControlNet - add `controlMode`
const controlNetConfig: ControlNetConfig = {
...layer.controlAdapter,
type: 'controlnet',
controlMode: 'balanced',
};
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 'controlnet' && layer.controlAdapter.model.type === 't2i_adapter') {
// Converting from ControlNet to T2I Adapter - remove `controlMode`
const { controlMode: _, ...rest } = layer.controlAdapter;
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
layer.controlAdapter = t2iAdapterConfig;
}
},
controlLayerControlModeChanged: (
@@ -556,7 +497,7 @@ export const canvasSlice = createSlice({
) => {
const { entityIdentifier, beginEndStepPct } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer || !layer.controlAdapter || layer.controlAdapter.type === 'control_lora') {
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.beginEndStepPct = beginEndStepPct;

View File

@@ -2,7 +2,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { z } from 'zod';
@@ -454,9 +454,7 @@ const PROCESSOR_TO_FILTER_MAP: Record<string, FilterType> = {
* Gets the default filter for a control model. If the model has a default, it will be used, otherwise the default
* filter for the model type will be used.
*/
export const getFilterForModel = (
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null
) => {
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
if (!modelConfig) {
// No model
return null;

View File

@@ -11,7 +11,6 @@ import type {
ParameterCLIPEmbedModel,
ParameterCLIPGEmbedModel,
ParameterCLIPLEmbedModel,
ParameterControlLoRAModel,
ParameterGuidance,
ParameterMaskBlurMethod,
ParameterModel,
@@ -25,6 +24,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterStructuralLoRAModel,
ParameterT5EncoderModel,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
@@ -76,7 +76,7 @@ export type ParamsState = {
clipEmbedModel: ParameterCLIPEmbedModel | null;
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
controlLora: ParameterControlLoRAModel | null;
structuralLora: ParameterStructuralLoRAModel | null;
};
const initialState: ParamsState = {
@@ -123,7 +123,7 @@ const initialState: ParamsState = {
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
controlLora: null,
structuralLora: null,
};
export const paramsSlice = createSlice({
@@ -198,8 +198,8 @@ export const paramsSlice = createSlice({
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload;
},
controlLoRAModelSelected: (state, action: PayloadAction<ParameterControlLoRAModel | null>) => {
state.controlLora = action.payload;
structuralLoRAModelSelected: (state, action: PayloadAction<ParameterStructuralLoRAModel | null>) => {
state.structuralLora = action.payload;
},
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload;

View File

@@ -296,13 +296,6 @@ const zT2IAdapterConfig = z.object({
});
export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
const zControlLoRAConfig = z.object({
type: z.literal('control_lora'),
weight: z.number().gte(-1).lte(2),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
export const zCanvasRasterLayerState = zCanvasEntityBase.extend({
type: z.literal('raster_layer'),
position: zCoordinate,
@@ -314,7 +307,7 @@ export type CanvasRasterLayerState = z.infer<typeof zCanvasRasterLayerState>;
const zCanvasControlLayerState = zCanvasRasterLayerState.extend({
type: z.literal('control_layer'),
withTransparencyEffect: z.boolean(),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig, zControlLoRAConfig]),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]),
});
export type CanvasControlLayerState = z.infer<typeof zCanvasControlLayerState>;

View File

@@ -7,7 +7,6 @@ import type {
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
ImageWithDims,
IPAdapterConfig,
@@ -83,11 +82,6 @@ export const initialControlNet: ControlNetConfig = {
beginEndStepPct: [0, 0.75],
controlMode: 'balanced',
};
export const initialControlLoRA: ControlLoRAConfig = {
type: 'control_lora',
model: null,
weight: 0.75,
};
export const getReferenceImageState = (
id: string,

View File

@@ -77,7 +77,7 @@ export const ImageMenuItemNewLayerFromImageSubMenu = memo(() => {
});
}, [imageDTO, imageViewer, store, t]);
const onClickNewGlobalReferenceImageFromImage = useCallback(() => {
const onClickNewRegionalReferenceImageFromImage = useCallback(() => {
const { dispatch, getState } = store;
createNewCanvasEntityFromImage({ imageDTO, type: 'reference_image', dispatch, getState });
dispatch(sentImageToCanvas());
@@ -90,7 +90,7 @@ export const ImageMenuItemNewLayerFromImageSubMenu = memo(() => {
});
}, [imageDTO, imageViewer, store, t]);
const onClickNewRegionalReferenceImageFromImage = useCallback(() => {
const onClickNewGlobalReferenceImageFromImage = useCallback(() => {
const { dispatch, getState } = store;
createNewCanvasEntityFromImage({ imageDTO, type: 'regional_guidance_with_reference_image', dispatch, getState });
dispatch(sentImageToCanvas());

View File

@@ -27,7 +27,6 @@ import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField
import type {
ParameterCFGRescaleMultiplier,
ParameterCFGScale,
ParameterControlLoRAModel,
ParameterGuidance,
ParameterHeight,
ParameterHRFEnabled,
@@ -47,6 +46,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterStructuralLoRAModel,
ParameterVAEModel,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
@@ -76,12 +76,12 @@ import {
import { get, isArray, isString } from 'lodash-es';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import {
isControlLoRAModelConfig,
isControlNetModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isStructuralLoRAModelConfig,
isT2IAdapterModelConfig,
isVAEModelConfig,
} from 'services/api/types';
@@ -228,10 +228,10 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
return modelIdentifier;
};
const parseControlLoRAModel: MetadataParseFunc<ParameterControlLoRAModel> = async (metadata) => {
const slora = await getProperty(metadata, 'control_lora', undefined);
const key = await getModelKey(slora, 'control_lora');
const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isControlLoRAModelConfig);
const parseStructuralLoRAModel: MetadataParseFunc<ParameterStructuralLoRAModel> = async (metadata) => {
const slora = await getProperty(metadata, 'structural_lora', undefined);
const key = await getModelKey(slora, 'structural_lora');
const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isStructuralLoRAModelConfig);
const modelIdentifier = zModelIdentifierField.parse(sloraModelConfig);
return modelIdentifier;
};
@@ -681,7 +681,7 @@ export const parsers = {
mainModel: parseMainModel,
refinerModel: parseRefinerModel,
vaeModel: parseVAEModel,
controlLora: parseControlLoRAModel,
structuralLora: parseStructuralLoRAModel,
lora: parseLoRA,
loras: parseAllLoRAs,
controlNet: parseControlNet,

View File

@@ -1,9 +1,9 @@
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export const useControlAdapterModelDefaultSettings = (
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig
export const useControlNetOrT2IAdapterDefaultSettings = (
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
) => {
const defaultSettingsDefaults = useMemo(() => {
return {

View File

@@ -11,7 +11,6 @@ import { useTranslation } from 'react-i18next';
import {
useCLIPEmbedModels,
useCLIPVisionModels,
useControlLoRAModel,
useControlNetModels,
useEmbeddingModels,
useIPAdapterModels,
@@ -19,6 +18,7 @@ import {
useMainModels,
useRefinerModels,
useSpandrelImageToImageModels,
useStructuralLoRAModel,
useT2IAdapterModels,
useT5EncoderModels,
useVAEModels,
@@ -93,10 +93,10 @@ const ModelList = () => {
[t5EncoderModels, searchTerm, filteredModelType]
);
const [controlLoRAModels, { isLoading: isLoadingControlLoRAModels }] = useControlLoRAModel();
const filteredControlLoRAModels = useMemo(
() => modelsFilter(controlLoRAModels, searchTerm, filteredModelType),
[controlLoRAModels, searchTerm, filteredModelType]
const [structuralLoRAModels, { isLoading: isLoadingStructuralLoRAModels }] = useStructuralLoRAModel();
const filteredStructuralLoRAModels = useMemo(
() => modelsFilter(structuralLoRAModels, searchTerm, filteredModelType),
[structuralLoRAModels, searchTerm, filteredModelType]
);
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
@@ -126,7 +126,7 @@ const ModelList = () => {
filteredSpandrelImageToImageModels.length +
t5EncoderModels.length +
clipEmbedModels.length +
controlLoRAModels.length
structuralLoRAModels.length
);
}, [
filteredControlNetModels.length,
@@ -141,7 +141,7 @@ const ModelList = () => {
filteredSpandrelImageToImageModels.length,
t5EncoderModels.length,
clipEmbedModels.length,
controlLoRAModels.length,
structuralLoRAModels.length,
]);
return (
@@ -204,13 +204,13 @@ const ModelList = () => {
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
)}
{/* Control Lora List */}
{isLoadingControlLoRAModels && <FetchingModelsLoader loadingMessage="Loading Control Loras..." />}
{!isLoadingControlLoRAModels && filteredControlLoRAModels.length > 0 && (
{/* Structural Lora List */}
{isLoadingStructuralLoRAModels && <FetchingModelsLoader loadingMessage="Loading Structural Loras..." />}
{!isLoadingStructuralLoRAModels && filteredStructuralLoRAModels.length > 0 && (
<ModelListWrapper
title={t('modelManager.controlLora')}
modelList={filteredControlLoRAModels}
key="control-lora"
title={t('modelManager.structuralLora')}
modelList={filteredStructuralLoRAModels}
key="structural-lora"
/>
)}
{/* Clip Embed List */}

View File

@@ -24,7 +24,7 @@ export const ModelTypeFilter = memo(() => {
ip_adapter: t('common.ipAdapter'),
clip_vision: 'CLIP Vision',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
control_lora: t('modelManager.controlLora'),
structural_lora: t('modelManager.structuralLora'),
}),
[t]
);

View File

@@ -1,6 +1,6 @@
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
import { useControlAdapterModelDefaultSettings } from 'features/modelManagerV2/hooks/useControlAdapterModelDefaultSettings';
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/DefaultPreprocessor';
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
@@ -9,28 +9,28 @@ import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { PiCheckBold } from 'react-icons/pi';
import { useUpdateModelMutation } from 'services/api/endpoints/models';
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export type ControlAdapterModelDefaultSettingsFormData = {
export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
preprocessor: FormField<string>;
};
type Props = {
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig;
};
export const ControlAdapterModelDefaultSettings = memo(({ modelConfig }: Props) => {
export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Props) => {
const { t } = useTranslation();
const defaultSettingsDefaults = useControlAdapterModelDefaultSettings(modelConfig);
const defaultSettingsDefaults = useControlNetOrT2IAdapterDefaultSettings(modelConfig);
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
const { handleSubmit, control, formState, reset } = useForm<ControlAdapterModelDefaultSettingsFormData>({
const { handleSubmit, control, formState, reset } = useForm<ControlNetOrT2IAdapterDefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
});
const onSubmit = useCallback<SubmitHandler<ControlAdapterModelDefaultSettingsFormData>>(
const onSubmit = useCallback<SubmitHandler<ControlNetOrT2IAdapterDefaultSettingsFormData>>(
(data) => {
const body = {
preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null,
@@ -85,4 +85,4 @@ export const ControlAdapterModelDefaultSettings = memo(({ modelConfig }: Props)
);
});
ControlAdapterModelDefaultSettings.displayName = 'ControlAdapterModelDefaultSettings';
ControlNetOrT2IAdapterDefaultSettings.displayName = 'ControlNetOrT2IAdapterDefaultSettings';

View File

@@ -1,7 +1,7 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type { ControlAdapterModelDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
import type { ControlNetOrT2IAdapterDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { memo, useCallback, useMemo } from 'react';
@@ -26,9 +26,9 @@ const OPTIONS = [
{ label: 'None', value: 'none' },
] as const;
type DefaultSchedulerType = ControlAdapterModelDefaultSettingsFormData['preprocessor'];
type DefaultSchedulerType = ControlNetOrT2IAdapterDefaultSettingsFormData['preprocessor'];
export const DefaultPreprocessor = memo((props: UseControllerProps<ControlAdapterModelDefaultSettingsFormData>) => {
export const DefaultPreprocessor = memo((props: UseControllerProps<ControlNetOrT2IAdapterDefaultSettingsFormData>) => {
const { t } = useTranslation();
const { field } = useController(props);

View File

@@ -1,5 +1,5 @@
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
import { ControlAdapterModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
@@ -21,11 +21,7 @@ export const ModelView = memo(({ modelConfig }: Props) => {
if (modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner') {
return true;
}
if (
modelConfig.type === 'controlnet' ||
modelConfig.type === 't2i_adapter' ||
modelConfig.type === 'control_lora'
) {
if (modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') {
return true;
}
if (modelConfig.type === 'main' || modelConfig.type === 'lora') {
@@ -73,9 +69,9 @@ export const ModelView = memo(({ modelConfig }: Props) => {
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
<MainModelDefaultSettings modelConfig={modelConfig} />
)}
{(modelConfig.type === 'controlnet' ||
modelConfig.type === 't2i_adapter' ||
modelConfig.type === 'control_lora') && <ControlAdapterModelDefaultSettings modelConfig={modelConfig} />}
{(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
<ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
)}
{(modelConfig.type === 'main' || modelConfig.type === 'lora') && (
<TriggerPhrases modelConfig={modelConfig} />
)}

View File

@@ -15,8 +15,6 @@ import {
isCLIPLEmbedModelFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlLoRAModelFieldInputInstance,
isControlLoRAModelFieldInputTemplate,
isControlNetModelFieldInputInstance,
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
@@ -53,6 +51,8 @@ import {
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isStructuralLoRAModelFieldInputInstance,
isStructuralLoRAModelFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
@@ -68,7 +68,6 @@ import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInput
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
@@ -84,6 +83,7 @@ import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComp
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import StructuralLoRAModelFieldInputComponent from './inputs/StructuralLoraModelFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@@ -159,8 +159,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isControlLoRAModelFieldInputInstance(fieldInstance) && isControlLoRAModelFieldInputTemplate(fieldTemplate)) {
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
if (
isStructuralLoRAModelFieldInputInstance(fieldInstance) &&
isStructuralLoRAModelFieldInputTemplate(fieldTemplate)
) {
return (
<StructuralLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />
);
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {

Some files were not shown because too many files have changed in this diff Show More