Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-21 03:28:25 -05:00)

Compare commits: ryan/fix-d... → ryan/flux-... (8 commits)

| SHA1 |
|---|
| 766c6082a3 |
| ad5d528204 |
| 536ccf071c |
| ca13c3b12f |
| d54c1ef9ba |
| e9f722aa7d |
| 139da133bd |
| 57820929d5 |
@@ -30,6 +30,7 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
@@ -208,7 +209,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 LoRAPatcher.apply_lora_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
-                    prefix="",
+                    prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
                     cached_weights=cached_weights,
                 )
             )
@@ -219,7 +220,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 LoRAPatcher.apply_lora_sidecar_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
-                    prefix="",
+                    prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
                     dtype=inference_dtype,
                 )
             )
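The hunks above replace the empty prefix passed to the LoRA patcher with FLUX_KOHYA_TRANFORMER_PREFIX (the identifier's spelling, missing an "S", is as it appears in the source). Converted FLUX Kohya LoRA models now hold both transformer and CLIP layers under distinct key prefixes (see the conversion-utils diff further down), so each consumer selects only its own layers. The snippet below is a minimal illustrative sketch of prefix-based selection, not the InvokeAI API; select_layers_for_prefix is a hypothetical helper, and the assumption that LoRAPatcher matches layer keys by prefix is inferred from this changeset.

# Hypothetical sketch: one LoRA model carries layers for two submodels,
# keyed by prefix; each patcher keeps only the keys under its own prefix.
FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"

def select_layers_for_prefix(layers: dict[str, object], prefix: str) -> dict[str, object]:
    """Keep only the layers intended for one submodel, with the prefix stripped."""
    return {k[len(prefix):]: v for k, v in layers.items() if k.startswith(prefix)}

layers = {
    "lora_transformer-double_blocks.0.img_attn.proj": "transformer layer",
    "lora_clip-text_model.encoder.layers.0.mlp.fc1": "CLIP layer",
}
assert select_layers_for_prefix(layers, FLUX_KOHYA_TRANFORMER_PREFIX) == {
    "double_blocks.0.img_attn.proj": "transformer layer"
}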
@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation_output,
 )
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
-from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
+from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.config import BaseModelType
 
@@ -20,6 +20,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
     transformer: Optional[TransformerField] = OutputField(
         default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
     )
+    clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
 
 
 @invocation(
@@ -27,21 +28,28 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
     title="FLUX LoRA",
     tags=["lora", "model", "flux"],
     category="model",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class FluxLoRALoaderInvocation(BaseInvocation):
-    """Apply a LoRA model to a FLUX transformer."""
+    """Apply a LoRA model to a FLUX transformer and/or text encoder."""
 
     lora: ModelIdentifierField = InputField(
         description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
     )
     weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
-    transformer: TransformerField = InputField(
+    transformer: TransformerField | None = InputField(
+        default=None,
         description=FieldDescriptions.transformer,
         input=Input.Connection,
         title="FLUX Transformer",
     )
+    clip: CLIPField | None = InputField(
+        default=None,
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
 
     def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
         lora_key = self.lora.key
@@ -49,18 +57,33 @@ class FluxLoRALoaderInvocation(BaseInvocation):
         if not context.models.exists(lora_key):
             raise ValueError(f"Unknown lora: {lora_key}!")
 
-        if any(lora.lora.key == lora_key for lora in self.transformer.loras):
+        # Check for existing LoRAs with the same key.
+        if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
             raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
+        if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
+            raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
 
-        transformer = self.transformer.model_copy(deep=True)
-        transformer.loras.append(
-            LoRAField(
-                lora=self.lora,
-                weight=self.weight,
-            )
-        )
+        output = FluxLoRALoaderOutput()
 
-        return FluxLoRALoaderOutput(transformer=transformer)
+        # Attach LoRA layers to the models.
+        if self.transformer is not None:
+            output.transformer = self.transformer.model_copy(deep=True)
+            output.transformer.loras.append(
+                LoRAField(
+                    lora=self.lora,
+                    weight=self.weight,
+                )
+            )
+        if self.clip is not None:
+            output.clip = self.clip.model_copy(deep=True)
+            output.clip.loras.append(
+                LoRAField(
+                    lora=self.lora,
+                    weight=self.weight,
+                )
+            )
+
+        return output
 
 
 @invocation(
@@ -68,7 +91,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
     title="FLUX LoRA Collection Loader",
     tags=["lora", "model", "flux"],
     category="model",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class FLUXLoRACollectionLoader(BaseInvocation):
@@ -84,6 +107,12 @@ class FLUXLoRACollectionLoader(BaseInvocation):
         input=Input.Connection,
         title="Transformer",
     )
+    clip: CLIPField | None = InputField(
+        default=None,
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
 
     def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
         output = FluxLoRALoaderOutput()
@@ -106,4 +135,9 @@ class FLUXLoRACollectionLoader(BaseInvocation):
                 output.transformer = self.transformer.model_copy(deep=True)
             output.transformer.loras.append(lora)
 
+        if self.clip is not None:
+            if output.clip is None:
+                output.clip = self.clip.model_copy(deep=True)
+            output.clip.loras.append(lora)
+
         return output
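Both loaders above call model_copy(deep=True) before appending a LoRAField. A minimal sketch of why the deep copy matters, assuming Pydantic v2 (SimpleField is illustrative, not an InvokeAI type): without it, the loras list would be shared with the upstream node's output, and the append would leak into every other consumer of that field.

from pydantic import BaseModel

class SimpleField(BaseModel):
    loras: list[str] = []

upstream = SimpleField(loras=["lora_a"])
patched = upstream.model_copy(deep=True)
patched.loras.append("lora_b")

assert upstream.loras == ["lora_a"]             # upstream output unchanged
assert patched.loras == ["lora_a", "lora_b"]    # only the copy sees the append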
@@ -1,4 +1,5 @@
-from typing import Literal
+from contextlib import ExitStack
+from typing import Iterator, Literal, Tuple
 
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@@ -9,6 +10,10 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
 
 
@@ -17,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -78,15 +83,42 @@ class FluxTextEncoderInvocation(BaseInvocation):
         prompt = [self.prompt]
 
         with (
-            clip_text_encoder_info as clip_text_encoder,
+            clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
             clip_tokenizer_info as clip_tokenizer,
+            ExitStack() as exit_stack,
         ):
             assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(clip_tokenizer, CLIPTokenizer)
 
+            clip_text_encoder_config = clip_text_encoder_info.config
+            assert clip_text_encoder_config is not None
+
+            # Apply LoRA models to the CLIP encoder.
+            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+            if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
+                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_patches(
+                        model=clip_text_encoder,
+                        patches=self._clip_lora_iterator(context),
+                        prefix=FLUX_KOHYA_CLIP_PREFIX,
+                        cached_weights=cached_weights,
+                    )
+                )
+            else:
+                # There are currently no supported CLIP quantized models. Add support here if needed.
+                raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
+
             clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
 
             pooled_prompt_embeds = clip_encoder(prompt)
 
             assert isinstance(pooled_prompt_embeds, torch.Tensor)
             return pooled_prompt_embeds
+
+    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+        for lora in self.clip.loras:
+            lora_info = context.models.load(lora.lora)
+            assert isinstance(lora_info.model, LoRAModelRaw)
+            yield (lora_info.model, lora.weight)
+            del lora_info
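The hunk above adds an ExitStack so that the LoRA patcher context is entered only when the CLIP encoder is a non-quantized Diffusers-format model, while its cleanup remains tied to the enclosing with block. A minimal stdlib sketch of that pattern follows (apply_patches is illustrative, not the InvokeAI API):

from contextlib import ExitStack, contextmanager

@contextmanager
def apply_patches(name: str):
    print(f"patching {name}")
    try:
        yield
    finally:
        print(f"unpatching {name}")

def encode(format_is_diffusers: bool) -> None:
    with ExitStack() as exit_stack:
        if format_is_diffusers:
            # Entered conditionally; exited automatically when the stack closes.
            exit_stack.enter_context(apply_patches("clip_text_encoder"))
        print("running encoder")

encode(format_is_diffusers=True)   # patching -> running -> unpatching
encode(format_is_diffusers=False)  # running only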
@@ -7,14 +7,25 @@ from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 
-# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
+# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
 # Example keys:
 #   lora_unet_double_blocks_0_img_attn_proj.alpha
 #   lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
 #   lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
-FLUX_KOHYA_KEY_REGEX = (
+FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
     r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
 )
+# A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
+# Example keys:
+#   lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
+#   lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight
+#   lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
+FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
+
+
+# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
+FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
+FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"
 
 
 def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
@@ -23,7 +34,10 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
     This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
     perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
     """
-    return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
+    return all(
+        re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
+        for k in state_dict.keys()
+    )
 
 
 def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
@@ -35,13 +49,27 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
             grouped_state_dict[layer_name] = {}
         grouped_state_dict[layer_name][param_name] = value
 
-    # Convert the state dict to the InvokeAI format.
-    grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
+    # Split the grouped state dict into transformer and CLIP state dicts.
+    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    for layer_name, layer_state_dict in grouped_state_dict.items():
+        if layer_name.startswith("lora_unet"):
+            transformer_grouped_sd[layer_name] = layer_state_dict
+        elif layer_name.startswith("lora_te1"):
+            clip_grouped_sd[layer_name] = layer_state_dict
+        else:
+            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
+
+    # Convert the state dicts to the InvokeAI format.
+    transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
+    clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
 
     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
-    for layer_key, layer_state_dict in grouped_state_dict.items():
-        layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+    for layer_key, layer_state_dict in transformer_grouped_sd.items():
+        layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+    for layer_key, layer_state_dict in clip_grouped_sd.items():
+        layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
 
     # Create and return the LoRAModelRaw.
     return LoRAModelRaw(layers=layers)
@@ -50,16 +78,34 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
 T = TypeVar("T")
 
 
-def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
-    """Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
+def _convert_flux_clip_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a CLIP LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by
+    InvokeAI.
+
+    Example key conversions:
+
+    "lora_te1_text_model_encoder_layers_0_mlp_fc1" -> "text_model.encoder.layers.0.mlp.fc1",
+    "lora_te1_text_model_encoder_layers_0_self_attn_k_proj" -> "text_model.encoder.layers.0.self_attn.k_proj"
+    """
+    converted_sd: dict[str, T] = {}
+    for k, v in state_dict.items():
+        match = re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
+        if match:
+            new_key = f"text_model.encoder.layers.{match.group(1)}.{match.group(2)}.{match.group(3)}"
+            converted_sd[new_key] = v
+        else:
+            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
+
+    return converted_sd
+
+
+def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a FLUX tranformer LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally
+    by InvokeAI.
 
     Example key conversions:
     "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
-    "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
-    "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
     "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
-    "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
-    "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
     """
 
     def replace_func(match: re.Match[str]) -> str:
@@ -70,9 +116,9 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) ->
 
     converted_dict: dict[str, T] = {}
     for k, v in state_dict.items():
-        match = re.match(FLUX_KOHYA_KEY_REGEX, k)
+        match = re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
         if match:
-            new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
+            new_key = re.sub(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, replace_func, k)
             converted_dict[new_key] = v
         else:
             raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
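A quick, runnable check of the two key regexes introduced above, using the example keys from the comments in the diff. The CLIP conversion mirrors the f-string in _convert_flux_clip_kohya_state_dict_to_invoke_format; the regex strings are copied verbatim from the hunk.

import re

FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"

# Transformer keys match the transformer regex; CLIP keys match the CLIP regex.
assert re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, "lora_unet_double_blocks_0_img_attn_proj.alpha")
assert re.match(FLUX_KOHYA_CLIP_KEY_REGEX, "lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha")

# The CLIP key conversion from the docstring examples:
m = re.match(FLUX_KOHYA_CLIP_KEY_REGEX, "lora_te1_text_model_encoder_layers_0_self_attn_k_proj")
assert m is not None
new_key = f"text_model.encoder.layers.{m.group(1)}.{m.group(2)}.{m.group(3)}"
assert new_key == "text_model.encoder.layers.0.self_attn.k_proj"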
@@ -8,7 +8,8 @@ export const addFLUXLoRAs = (
   state: RootState,
   g: Graph,
   denoise: Invocation<'flux_denoise'>,
-  modelLoader: Invocation<'flux_model_loader'>
+  modelLoader: Invocation<'flux_model_loader'>,
+  fluxTextEncoder: Invocation<'flux_text_encoder'>
 ): void => {
   const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'flux');
   const loraCount = enabledLoRAs.length;
@@ -20,7 +21,7 @@ export const addFLUXLoRAs = (
   const loraMetadata: S['LoRAMetadataField'][] = [];
 
   // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
-  // each LoRA to the UNet and CLIP.
+  // each LoRA to the transformer and text encoders.
   const loraCollector = g.addNode({
     id: getPrefixedId('lora_collector'),
     type: 'collect',
@@ -33,10 +34,12 @@ export const addFLUXLoRAs = (
   g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
   // Use model loader as transformer input
   g.addEdge(modelLoader, 'transformer', loraCollectionLoader, 'transformer');
-  // Reroute transformer connections through the LoRA collection loader
+  g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
+  // Reroute model connections through the LoRA collection loader
   g.deleteEdgesTo(denoise, ['transformer']);
+  g.deleteEdgesTo(fluxTextEncoder, ['clip']);
   g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
+  g.addEdge(loraCollectionLoader, 'clip', fluxTextEncoder, 'clip');
 
   for (const lora of enabledLoRAs) {
     const { weight } = lora;
@@ -6,6 +6,7 @@ import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSe
 import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
 import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
 import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
+import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
 import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
 import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
 import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
@@ -18,8 +19,6 @@ import type { Invocation } from 'services/api/types';
 import { isNonRefinerMainModelConfig } from 'services/api/types';
 import { assert } from 'tsafe';
 
-import { addFLUXLoRAs } from './addFLUXLoRAs';
-
 const log = logger('system');
 
 export const buildFLUXGraph = async (
@@ -96,12 +95,12 @@ export const buildFLUXGraph = async (
   g.addEdge(modelLoader, 'transformer', noise, 'transformer');
   g.addEdge(modelLoader, 'vae', l2i, 'vae');
 
-  addFLUXLoRAs(state, g, noise, modelLoader);
-
   g.addEdge(modelLoader, 'clip', posCond, 'clip');
   g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
   g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
 
+  addFLUXLoRAs(state, g, noise, modelLoader, posCond);
+
   g.addEdge(posCond, 'conditioning', noise, 'positive_text_conditioning');
 
   g.addEdge(noise, 'latents', l2i, 'latents');
@@ -5707,6 +5707,12 @@ export type components = {
        * @default null
        */
       transformer?: components["schemas"]["TransformerField"] | null;
+      /**
+       * CLIP
+       * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+       * @default null
+       */
+      clip?: components["schemas"]["CLIPField"] | null;
       /**
        * type
        * @default flux_lora_collection_loader
@@ -6391,7 +6397,7 @@ export type components = {
     };
     /**
      * FLUX LoRA
-     * @description Apply a LoRA model to a FLUX transformer.
+     * @description Apply a LoRA model to a FLUX transformer and/or text encoder.
      */
     FluxLoRALoaderInvocation: {
       /**
@@ -6428,7 +6434,13 @@ export type components = {
        * @description Transformer
        * @default null
        */
-      transformer?: components["schemas"]["TransformerField"];
+      transformer?: components["schemas"]["TransformerField"] | null;
+      /**
+       * CLIP
+       * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+       * @default null
+       */
+      clip?: components["schemas"]["CLIPField"] | null;
       /**
        * type
        * @default flux_lora_loader
@@ -6448,6 +6460,12 @@ export type components = {
        * @default null
        */
       transformer: components["schemas"]["TransformerField"] | null;
+      /**
+       * CLIP
+       * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+       * @default null
+       */
+      clip: components["schemas"]["CLIPField"] | null;
       /**
        * type
        * @default flux_lora_loader_output
@@ -0,0 +1,386 @@
|
|||||||
|
# A sample state dict in the Diffusers FLUX LoRA format with non-standard PEFT weight keys.
|
||||||
|
# I.e. key suffixes of `.lora.down.weight` and `.lora.up.weight` instead of `.lora_A.weight` and `.lora_B.weight`.
|
||||||
|
# These keys are based on the LoRA model here:
|
||||||
|
# https://civitai.com/models/684810/flux1-dev-cctv-mania
|
||||||
|
state_dict_keys = [
|
||||||
|
"transformer.single_transformer_blocks.0.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.0.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.0.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.0.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.0.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.0.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.1.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.1.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.1.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.1.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.1.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.1.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.10.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.10.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.10.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.10.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.10.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.10.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.11.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.11.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.11.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.11.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.11.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.11.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.12.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.12.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.12.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.12.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.12.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.12.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.13.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.13.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.13.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.13.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.13.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.13.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.14.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.14.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.14.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.14.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.14.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.14.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.15.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.15.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.15.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.15.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.15.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.15.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.16.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.16.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.16.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.16.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.16.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.16.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.17.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.17.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.17.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.17.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.17.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.17.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.18.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.18.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.18.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.18.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.18.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.18.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.19.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.19.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.19.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.19.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.19.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.19.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.2.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.2.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.2.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.2.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.2.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.2.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.20.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.20.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.20.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.20.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.20.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.20.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.21.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.21.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.21.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.21.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.21.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.21.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.22.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.22.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.22.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.22.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.22.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.22.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.23.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.23.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.23.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.23.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.23.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.23.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.24.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.24.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.24.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.24.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.24.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.24.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.25.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.25.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.25.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.25.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.25.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.25.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.26.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.26.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.26.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.26.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.26.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.26.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.27.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.27.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.27.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.27.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.27.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.27.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.28.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.28.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.28.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.28.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.28.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.28.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.29.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.29.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.29.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.29.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.29.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.29.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.3.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.3.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.3.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.3.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.3.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.3.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.30.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.30.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.30.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.30.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.30.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.30.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.31.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.31.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.31.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.31.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.31.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.31.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.32.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.32.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.32.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.32.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.32.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.32.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.33.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.33.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.33.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.33.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.33.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.33.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.34.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.34.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.34.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.34.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.34.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.34.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.35.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.35.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.35.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.35.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.35.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.35.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.36.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.36.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.36.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.36.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.36.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.36.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.37.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.37.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.37.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.37.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.37.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.37.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.4.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.4.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.4.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.4.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.4.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.4.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.5.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.5.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.5.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.5.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.5.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.5.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.6.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.6.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.6.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.6.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.6.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.6.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.7.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.7.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.7.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.7.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.7.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.7.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.8.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.8.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.8.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.8.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.8.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.8.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.9.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.9.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.9.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.9.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.single_transformer_blocks.9.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.single_transformer_blocks.9.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.0.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.0.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.0.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.0.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.0.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.0.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.0.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.0.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.1.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.1.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.1.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.1.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.1.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.1.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.1.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.1.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.10.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.10.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.10.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.10.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.10.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.10.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.10.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.10.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.11.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.11.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.11.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.11.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.11.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.11.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.11.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.11.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.12.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.12.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.12.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.12.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.12.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.12.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.12.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.12.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.13.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.13.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.13.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.13.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.13.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.13.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.13.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.13.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.14.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.14.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.14.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.14.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.14.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.14.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.14.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.14.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.15.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.15.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.15.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.15.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.15.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.15.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.15.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.15.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.16.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.16.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.16.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.16.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.16.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.16.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.16.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.16.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.17.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.17.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.17.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.17.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.17.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.17.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.17.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.17.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.18.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.18.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.18.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.18.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.18.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.18.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.18.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.18.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.2.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.2.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.2.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.2.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.2.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.2.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.2.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.2.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.3.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.3.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.3.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.3.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.3.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.3.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.3.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.3.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.4.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.4.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.4.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.4.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.4.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.4.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.4.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.4.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.5.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.5.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.5.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.5.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.5.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.5.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.5.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.5.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.6.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.6.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.6.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.6.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.6.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.6.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.6.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.6.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.7.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.7.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.7.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.7.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.7.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.7.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.7.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.7.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.8.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.8.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.8.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.8.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.8.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.8.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.8.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.8.attn.to_v.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.9.attn.to_k.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.9.attn.to_k.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.9.attn.to_out.0.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.9.attn.to_out.0.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.9.attn.to_q.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.9.attn.to_q.lora.up.weight",
|
||||||
|
"transformer.transformer_blocks.9.attn.to_v.lora.down.weight",
|
||||||
|
"transformer.transformer_blocks.9.attn.to_v.lora.up.weight",
|
||||||
|
]
File diff suppressed because it is too large
@@ -5,7 +5,9 @@ import torch
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.util import params
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-    convert_flux_kohya_state_dict_to_invoke_format,
+    FLUX_KOHYA_CLIP_PREFIX,
+    FLUX_KOHYA_TRANFORMER_PREFIX,
+    _convert_flux_transformer_kohya_state_dict_to_invoke_format,
     is_state_dict_likely_in_flux_kohya_format,
     lora_model_from_flux_kohya_state_dict,
 )
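The transformer converter is now private (_convert_flux_transformer_kohya_state_dict_to_invoke_format), which suggests the public lora_model_from_flux_kohya_state_dict entry point first routes keys to the appropriate submodel converter by their Kohya prefix. A hedged sketch of that routing: only the "lora_unet_" and "lora_te1_" prefixes are confirmed by the tests in this diff, and the function name here is invented for illustration:

# Hypothetical sketch of prefix-based routing for a Kohya FLUX LoRA state
# dict. The grouping logic is an assumption, not the actual InvokeAI code.
import torch


def split_kohya_state_dict_by_submodel(
    state_dict: dict[str, torch.Tensor],
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Split a Kohya FLUX LoRA state dict into transformer and CLIP groups."""
    transformer_sd: dict[str, torch.Tensor] = {}
    clip_sd: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        if key.startswith("lora_unet_"):
            transformer_sd[key] = value
        elif key.startswith("lora_te1_"):
            clip_sd[key] = value
        else:
            raise ValueError(f"Unexpected key prefix: {key}")
    return transformer_sd, clip_sd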
@@ -15,13 +17,17 @@ from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
     state_dict_keys as flux_kohya_state_dict_keys,
 )
+from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import (
+    state_dict_keys as flux_kohya_te1_state_dict_keys,
+)
 from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
 
 
-def test_is_state_dict_likely_in_flux_kohya_format_true():
+@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
+def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
     """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)
 
     assert is_state_dict_likely_in_flux_kohya_format(state_dict)
 
@@ -34,11 +40,11 @@ def test_is_state_dict_likely_in_flux_kohya_format_false():
     assert not is_state_dict_likely_in_flux_kohya_format(state_dict)
 
 
-def test_convert_flux_kohya_state_dict_to_invoke_format():
+def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
     # Construct state_dict from state_dict_keys.
     state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
 
-    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
+    converted_state_dict = _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)
 
     # Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
     # .alpha suffixes).
@@ -65,29 +71,33 @@ def test_convert_flux_kohya_state_dict_to_invoke_format():
     raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
 
 
-def test_convert_flux_kohya_state_dict_to_invoke_format_error():
-    """Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
-    unexpected keys.
+def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
+    """Test that an error is raised by _convert_flux_transformer_kohya_state_dict_to_invoke_format() if the input
+    state_dict contains unexpected keys.
     """
     state_dict = {
         "unexpected_key.lora_up.weight": torch.empty(1),
     }
 
     with pytest.raises(ValueError):
-        convert_flux_kohya_state_dict_to_invoke_format(state_dict)
+        _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)
 
 
-def test_lora_model_from_flux_kohya_state_dict():
+@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
+def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
     """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)
 
     lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
 
     # Prepare expected layer keys.
     expected_layer_keys: set[str] = set()
-    for k in flux_kohya_state_dict_keys:
-        k = k.replace("lora_unet_", "")
+    for k in sd_keys:
+        # Replace prefixes.
+        k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
+        k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
+        # Remove suffixes.
         k = k.replace(".lora_up.weight", "")
         k = k.replace(".lora_down.weight", "")
         k = k.replace(".alpha", "")
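The expected-key derivation in this loop is a prefix swap plus suffix stripping, so the up/down/alpha entries for one layer collapse to a single layer key. As a standalone illustration (the prefix values below are placeholders; the real FLUX_KOHYA_TRANFORMER_PREFIX and FLUX_KOHYA_CLIP_PREFIX constants live in flux_kohya_lora_conversion_utils):

# Standalone illustration of the expected-layer-key derivation used in the
# test above. The two prefix values are placeholders, not the real constants.
FLUX_KOHYA_TRANFORMER_PREFIX = "<transformer-prefix>"  # placeholder value
FLUX_KOHYA_CLIP_PREFIX = "<clip-prefix>"  # placeholder value


def kohya_key_to_layer_key(key: str) -> str:
    """Map a Kohya-format parameter key to its LoRA-model layer key."""
    # Replace the Kohya submodel prefixes with the Invoke-side prefixes.
    key = key.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
    key = key.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
    # Strip the per-parameter suffixes so up/down/alpha map to the same layer.
    for suffix in (".lora_up.weight", ".lora_down.weight", ".alpha"):
        key = key.replace(suffix, "")
    return key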