mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 19:07:59 -05:00
Compare commits
17 Commits
main
...
ryan/flux-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67afa7e339 | ||
|
|
92c6a7d658 | ||
|
|
caa9ecafae | ||
|
|
f21464e972 | ||
|
|
eeff9d3df5 | ||
|
|
d9c1c7d63d | ||
|
|
420f6feef9 | ||
|
|
8d09a36c90 | ||
|
|
6f82be4dc4 | ||
|
|
9dfbd6a422 | ||
|
|
a10db807ca | ||
|
|
edc0b63612 | ||
|
|
d44a6b2ca1 | ||
|
|
0d2c1b9d8f | ||
|
|
9952b19c5d | ||
|
|
2335b70dba | ||
|
|
010383faef |
@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
|
||||
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
|
||||
@@ -21,6 +21,9 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
|
||||
)
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
t5_encoder: Optional[T5EncoderField] = OutputField(
|
||||
default=None, description=FieldDescriptions.t5_encoder, title="T5 Encoder"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -28,7 +31,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
title="FLUX LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
@@ -50,6 +53,12 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_encoder: T5EncoderField | None = InputField(
|
||||
default=None,
|
||||
title="T5 Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
@@ -62,6 +71,8 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
|
||||
if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
|
||||
if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.')
|
||||
|
||||
output = FluxLoRALoaderOutput()
|
||||
|
||||
@@ -82,6 +93,14 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
if self.t5_encoder is not None:
|
||||
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
|
||||
output.t5_encoder.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@@ -91,7 +110,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
title="FLUX LoRA Collection Loader",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
@@ -113,6 +132,12 @@ class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_encoder: T5EncoderField | None = InputField(
|
||||
default=None,
|
||||
title="T5 Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
|
||||
output = FluxLoRALoaderOutput()
|
||||
@@ -140,4 +165,9 @@ class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(lora)
|
||||
|
||||
if self.t5_encoder is not None:
|
||||
if output.t5_encoder is None:
|
||||
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
|
||||
output.t5_encoder.loras.append(lora)
|
||||
|
||||
return output
|
||||
|
||||
@@ -40,7 +40,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
version="1.0.5",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
@@ -87,7 +87,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
|
||||
@@ -71,13 +71,34 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
prompt = [self.prompt]
|
||||
|
||||
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
t5_encoder_config = t5_encoder_info.config
|
||||
assert t5_encoder_config is not None
|
||||
|
||||
with (
|
||||
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
|
||||
t5_encoder_info.model_on_device() as (cached_weights, t5_text_encoder),
|
||||
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
|
||||
ExitStack() as exit_stack,
|
||||
):
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
|
||||
|
||||
# Apply LoRA models to the T5 encoder.
|
||||
# Note: We apply the LoRA after the T5 encoder has been moved to its target device for faster patching.
|
||||
if t5_encoder_config.format == ModelFormat.T5Encoder:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=t5_text_encoder,
|
||||
patches=self._t5_lora_iterator(context),
|
||||
prefix=FLUX_LORA_T5_PREFIX,
|
||||
dtype=t5_text_encoder.dtype,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model format: {t5_encoder_config.format}")
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
context.util.signal_progress("Running T5 encoder")
|
||||
@@ -132,3 +153,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _t5_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
    """Yield one `(patch, weight)` pair for each entry in `self.t5_encoder.loras`.

    This is a generator: each LoRA is loaded through `context.models.load()` only when the consumer
    advances to it, not up front.
    """
    for lora in self.t5_encoder.loras:
        lora_info = context.models.load(lora.lora)
        # The loaded object must be a raw patch model; anything else is a programming error.
        assert isinstance(lora_info.model, ModelPatchRaw)
        yield (lora_info.model, lora.weight)
        # Runs once the consumer resumes the generator: drop this frame's reference to the loaded
        # model record (presumably so it can be released before the next LoRA is loaded - confirm).
        del lora_info
|
||||
|
||||
@@ -68,6 +68,7 @@ class CLIPField(BaseModel):
|
||||
# Bundles the identifiers of the T5 tokenizer and text-encoder submodels with the LoRAs to apply to them.
# NOTE: documented with `#` comments rather than a class docstring so the model's generated schema
# description is left unchanged.
class T5EncoderField(BaseModel):
    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
    # No default is declared here, so constructors must pass `loras` explicitly (e.g. `loras=[]`).
    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class VAEField(BaseModel):
|
||||
|
||||
@@ -31,6 +31,10 @@ from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_onetrainer_format,
|
||||
lora_model_from_flux_onetrainer_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
|
||||
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
|
||||
@@ -84,8 +88,12 @@ class LoRALoader(ModelLoader):
|
||||
elif config.format == ModelFormat.LyCORIS:
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_flux_control(state_dict=state_dict):
|
||||
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
|
||||
@@ -46,6 +46,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_onetrainer_format,
|
||||
)
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
@@ -283,7 +286,7 @@ class ModelProbe(object):
|
||||
return ModelType.Main
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
elif key.startswith(("lora_te_", "lora_unet_")):
|
||||
elif key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
|
||||
return ModelType.LoRA
|
||||
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
|
||||
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
|
||||
@@ -632,6 +635,7 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if (
|
||||
is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
|
||||
or is_state_dict_likely_in_flux_onetrainer_format(self.checkpoint)
|
||||
or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
|
||||
or is_state_dict_likely_flux_control(self.checkpoint)
|
||||
):
|
||||
|
||||
@@ -81,11 +81,11 @@ class LayerPatcher:
|
||||
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
|
||||
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
|
||||
# without searching, but some legacy code still uses flattened keys.
|
||||
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
|
||||
layer_keys_are_flattened = "." not in next(iter(patch.layers))[0]
|
||||
|
||||
prefix_len = len(prefix)
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
for layer_key, layer in patch.layers:
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
@@ -14,7 +13,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
|
||||
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_layers: Sequence[LoRALayer], concat_axis: int = 0):
|
||||
def __init__(self, lora_layers: Sequence[LoRALayerBase], concat_axis: int = 0):
|
||||
super().__init__(alpha=None, bias=None)
|
||||
|
||||
self.lora_layers = lora_layers
|
||||
|
||||
92
invokeai/backend/patches/layers/dora_layer.py
Normal file
92
invokeai/backend/patches/layers/dora_layer.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
|
||||
|
||||
|
||||
class DoRALayer(LoRALayerBase):
    """A DoRA layer. As defined in https://arxiv.org/pdf/2402.09353.

    Stores the low-rank `up` / `down` factors together with the `dora_scale` tensor and converts them
    into a weight diff relative to the original layer weight (see `get_weight()`).
    """

    def __init__(
        self,
        up: torch.Tensor,
        down: torch.Tensor,
        dora_scale: torch.Tensor,
        alpha: float | None,
        bias: Optional[torch.Tensor],
    ):
        """Store the DoRA tensors; `alpha` and `bias` are forwarded to the `LoRALayerBase` constructor."""
        super().__init__(alpha, bias)
        self.up = up
        self.down = down
        self.dora_scale = dora_scale

    @classmethod
    def from_state_dict_values(cls, values: Dict[str, torch.Tensor]):
        """Build a layer from the per-layer entries of a LoRA state dict.

        Requires the keys "lora_up.weight", "lora_down.weight" and "dora_scale" (a missing one raises
        `KeyError`). "alpha" and the "bias_*" keys are optional and are parsed by the base-class helpers.
        """
        alpha = cls._parse_alpha(values.get("alpha", None))
        bias = cls._parse_bias(
            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
        )

        layer = cls(
            up=values["lora_up.weight"],
            down=values["lora_down.weight"],
            dora_scale=values["dora_scale"],
            alpha=alpha,
            bias=bias,
        )

        # Report any keys in `values` that this class does not consume (behaviour of the base-class
        # helper is presumed from its name - confirm in LoRALayerBase).
        cls.warn_on_unhandled_keys(
            values=values,
            handled_keys={
                # Default keys.
                "alpha",
                "bias_indices",
                "bias_values",
                "bias_size",
                # Layer-specific keys.
                "lora_up.weight",
                "lora_down.weight",
                "dora_scale",
            },
        )

        return layer

    def _rank(self) -> int:
        """Return the layer rank, taken as the size of the first dimension of `down`."""
        return self.down.shape[0]

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the diff that turns `orig_weight` into the DoRA-adjusted weight.

        The result has the same shape as `orig_weight`. `orig_weight` itself is not modified: all in-place
        work happens on the fresh tensor produced by `orig_weight + delta_v`.
        """
        # Note: Variable names (e.g. delta_v) are based on the paper.
        # Flatten `up` and `down` to 2D, multiply them, and view the product in the original weight's shape.
        delta_v = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
        delta_v = delta_v.reshape(orig_weight.shape)

        # TODO(ryand): Should alpha be applied to delta_v here rather than the final diff?
        # TODO(ryand): I expect this to fail if the original weight is BnB Quantized. This class shouldn't have to worry
        # about that, but we should add a clear error message further up the stack.

        # At this point, out_weight is the unnormalized direction matrix.
        out_weight = orig_weight + delta_v

        # TODO(ryand): Simplify this logic.
        # Norm of out_weight taken over every dimension except dim 1: dim 1 is moved to the front, the
        # remaining dims are flattened and reduced, and the result is reshaped/transposed back so that it
        # has size 1 everywhere except dim 1 and broadcasts against out_weight.
        direction_norm = (
            out_weight.transpose(0, 1)
            .reshape(out_weight.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(out_weight.shape[1], *[1] * (out_weight.dim() - 1))
            .transpose(0, 1)
        )

        # Normalize by the norm computed above and rescale by dora_scale (in place, on the new tensor).
        out_weight *= self.dora_scale / direction_norm

        # Callers get a diff relative to the original weight, not the full adjusted weight.
        return out_weight - orig_weight

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        """Move/cast the base-class state and the `up`, `down` and `dora_scale` tensors."""
        super().to(device=device, dtype=dtype)
        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)
        self.dora_scale = self.dora_scale.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        """Return the base-class size plus the size of the three DoRA tensors, as computed by `calc_tensors_size`."""
        return super().calc_size() + calc_tensors_size([self.up, self.down, self.dora_scale])
|
||||
83
invokeai/backend/patches/layers/partial_layer.py
Normal file
83
invokeai/backend/patches/layers/partial_layer.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
|
||||
|
||||
@dataclass
class Range:
    """A `start` / `end` index pair.

    `PartialLayer` uses the two values as `start : end` slice bounds, so `end` is exclusive there.
    """

    start: int
    end: int
|
||||
|
||||
|
||||
class PartialLayer(BaseLayerPatch):
    """A layer patch that only modifies a sub-range of the weights in the original layer.

    This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
    Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
    stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
    """

    def __init__(self, lora_layer: LoRALayerBase, range: tuple[Range, Range]):
        """Wrap `lora_layer` so that it is applied only to the region described by `range`.

        Args:
            lora_layer: The layer whose parameter diffs are computed on the sliced region.
            range: `range[i]` gives the range to be modified in the original layer for the i'th dimension.
        """
        super().__init__()

        self.lora_layer = lora_layer
        # self.range[i] gives the range to be modified in the original layer for the i'th dimension.
        self.range = range

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        """Compute full-shape parameter diffs that are non-zero only inside `self.range`.

        Args:
            orig_parameters: The original layer parameters. Only the names "weight" and "bias" are accepted.
                This dict is not modified.
            weight: Patch weight, forwarded unchanged to the wrapped layer's `get_parameters()`.

        Returns:
            A dict with one tensor per parameter returned by the wrapped layer. Each tensor has the shape of
            the corresponding original parameter, is zero outside `self.range`, and holds the wrapped layer's
            result inside it.

        Raises:
            ValueError: If a parameter name other than "weight" or "bias" is encountered.
        """
        # HACK(ryand): If the original parameters are in a quantized format that can't be sliced, we replace them with
        # dummy tensors on the 'meta' device. This allows sub-layers to access the shapes of the sliced parameters. But,
        # of course, any sub-layers that need to access the actual values of the parameters will fail.
        #
        # The substitution is made in a shallow copy: writing the placeholders into `orig_parameters` itself
        # would silently swap the caller's real tensors for meta tensors.
        local_parameters: dict[str, torch.Tensor] = dict(orig_parameters)
        for param_name, param in orig_parameters.items():
            # Exact type checks (not isinstance) are deliberate: tensor subclasses other than the two
            # accepted here fall through to the placeholder branch.
            if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
                continue
            if type(param) is GGMLTensor:
                continue
            local_parameters[param_name] = torch.empty(get_param_shape(param), device="meta")

        # Slice the original parameters to the specified range.
        sliced_parameters: dict[str, torch.Tensor] = {}
        for param_name, param_weight in local_parameters.items():
            if param_name == "weight":
                sliced_parameters[param_name] = param_weight[
                    self.range[0].start : self.range[0].end, self.range[1].start : self.range[1].end
                ]
            elif param_name == "bias":
                # The bias only has the first dimension, so only range[0] applies.
                sliced_parameters[param_name] = param_weight[self.range[0].start : self.range[0].end]
            else:
                raise ValueError(f"Unexpected parameter name: {param_name}")

        # Apply the LoRA layer to the sliced parameters.
        params = self.lora_layer.get_parameters(sliced_parameters, weight)

        # Expand the parameters diffs to match the original parameter shape.
        out_params: dict[str, torch.Tensor] = {}
        for param_name, param_weight in params.items():
            orig_param = local_parameters[param_name]
            out_params[param_name] = torch.zeros(
                get_param_shape(orig_param), dtype=param_weight.dtype, device=param_weight.device
            )

            if param_name == "weight":
                out_params[param_name][
                    self.range[0].start : self.range[0].end, self.range[1].start : self.range[1].end
                ] = param_weight
            elif param_name == "bias":
                out_params[param_name][self.range[0].start : self.range[0].end] = param_weight
            else:
                raise ValueError(f"Unexpected parameter name: {param_name}")

        return out_params

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        """Forward the device/dtype move to the wrapped layer."""
        self.lora_layer.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        """Return the size reported by the wrapped layer."""
        return self.lora_layer.calc_size()
|
||||
@@ -3,6 +3,7 @@ from typing import Dict
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.dora_layer import DoRALayer
|
||||
from invokeai.backend.patches.layers.full_layer import FullLayer
|
||||
from invokeai.backend.patches.layers.ia3_layer import IA3Layer
|
||||
from invokeai.backend.patches.layers.loha_layer import LoHALayer
|
||||
@@ -14,8 +15,9 @@ from invokeai.backend.patches.layers.norm_layer import NormLayer
|
||||
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
if "lora_up.weight" in state_dict:
|
||||
if "dora_scale" in state_dict:
|
||||
return DoRALayer.from_state_dict_values(state_dict)
|
||||
elif "lora_up.weight" in state_dict:
|
||||
# LoRA a.k.a LoCon
|
||||
return LoRALayer.from_state_dict_values(state_dict)
|
||||
elif "hada_w1_a" in state_dict:
|
||||
|
||||
@@ -56,28 +56,38 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
|
||||
grouped_state_dict[layer_name][param_name] = value
|
||||
|
||||
# Create LoRA layers.
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
layers: list[tuple[str, BaseLayerPatch]] = []
|
||||
for layer_key, layer_state_dict in grouped_state_dict.items():
|
||||
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
|
||||
if layer_key == "img_in":
|
||||
# img_in is a special case because it changes the shape of the original weight.
|
||||
layers[prefixed_key] = FluxControlLoRALayer(
|
||||
layer_state_dict["lora_B.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_A.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_B.bias"],
|
||||
layers.append(
|
||||
(
|
||||
prefixed_key,
|
||||
FluxControlLoRALayer(
|
||||
layer_state_dict["lora_B.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_A.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_B.bias"],
|
||||
),
|
||||
)
|
||||
)
|
||||
elif all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
|
||||
layers[prefixed_key] = LoRALayer(
|
||||
layer_state_dict["lora_B.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_A.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_B.bias"],
|
||||
layers.append(
|
||||
(
|
||||
prefixed_key,
|
||||
LoRALayer(
|
||||
layer_state_dict["lora_B.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_A.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_B.bias"],
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "scale" in layer_state_dict:
|
||||
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
|
||||
layers.append((prefixed_key, SetParameterLayer("scale", layer_state_dict["scale"])))
|
||||
else:
|
||||
raise ValueError(f"{layer_key} not expected")
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Dict
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.patches.layers.partial_layer import PartialLayer, Range
|
||||
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
@@ -33,13 +33,21 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
|
||||
def lora_model_from_flux_diffusers_state_dict(
|
||||
state_dict: Dict[str, torch.Tensor], alpha: float | None
|
||||
) -> ModelPatchRaw:
|
||||
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
|
||||
layers = lora_layers_from_flux_diffusers_grouped_state_dict(grouped_state_dict, alpha)
|
||||
return ModelPatchRaw(layers=layers)
|
||||
|
||||
|
||||
def lora_layers_from_flux_diffusers_grouped_state_dict(
|
||||
grouped_state_dict: Dict[str, Dict[str, torch.Tensor]], alpha: float | None
|
||||
) -> list[tuple[str, BaseLayerPatch]]:
|
||||
"""Converts a grouped state dict with Diffusers FLUX LoRA keys to LoRA layers with BFL keys (i.e. the module key
|
||||
format used by Invoke).
|
||||
|
||||
This function is based on:
|
||||
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
|
||||
"""
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
|
||||
|
||||
# Remove the "transformer." prefix from all keys.
|
||||
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
|
||||
@@ -51,19 +59,28 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
mlp_ratio = 4.0
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
layers: list[tuple[str, BaseLayerPatch]] = []
|
||||
|
||||
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
|
||||
if src_key in grouped_state_dict:
|
||||
src_layer_dict = grouped_state_dict.pop(src_key)
|
||||
value = {
|
||||
def get_lora_layer_values(src_layer_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
if "lora_A.weight" in src_layer_dict:
|
||||
# The LoRA keys are in PEFT format.
|
||||
values = {
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
}
|
||||
if alpha is not None:
|
||||
value["alpha"] = torch.tensor(alpha)
|
||||
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
|
||||
values["alpha"] = torch.tensor(alpha)
|
||||
assert len(src_layer_dict) == 0
|
||||
return values
|
||||
else:
|
||||
# Assume that the LoRA keys are in Kohya format.
|
||||
return src_layer_dict
|
||||
|
||||
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
|
||||
if src_key in grouped_state_dict:
|
||||
src_layer_dict = grouped_state_dict.pop(src_key)
|
||||
values = get_lora_layer_values(src_layer_dict)
|
||||
layers.append((dst_key, any_lora_layer_from_state_dict(values)))
|
||||
|
||||
def add_qkv_lora_layer_if_present(
|
||||
src_keys: list[str],
|
||||
@@ -79,29 +96,29 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
if not any(keys_present):
|
||||
return
|
||||
|
||||
sub_layers: list[LoRALayer] = []
|
||||
dim_0_offset = 0
|
||||
for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
|
||||
src_layer_dict = grouped_state_dict.pop(src_key, None)
|
||||
if src_layer_dict is not None:
|
||||
values = {
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
}
|
||||
if alpha is not None:
|
||||
values["alpha"] = torch.tensor(alpha)
|
||||
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
|
||||
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
|
||||
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
|
||||
assert len(src_layer_dict) == 0
|
||||
values = get_lora_layer_values(src_layer_dict)
|
||||
# assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
|
||||
# assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
|
||||
layers.append(
|
||||
(
|
||||
dst_qkv_key,
|
||||
PartialLayer(
|
||||
any_lora_layer_from_state_dict(values),
|
||||
(
|
||||
Range(dim_0_offset, dim_0_offset + src_weight_shape[0]),
|
||||
Range(0, src_weight_shape[1]),
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
if not allow_missing_keys:
|
||||
raise ValueError(f"Missing LoRA layer: '{src_key}'.")
|
||||
values = {
|
||||
"lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
|
||||
"lora_down.weight": torch.zeros((1, src_weight_shape[1])),
|
||||
}
|
||||
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
|
||||
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
|
||||
dim_0_offset += src_weight_shape[0]
|
||||
|
||||
# time_text_embed.timestep_embedder -> time_in.
|
||||
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
|
||||
@@ -215,9 +232,9 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
# Assert that all keys were processed.
|
||||
assert len(grouped_state_dict) == 0
|
||||
|
||||
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
|
||||
layers_with_prefix = [(f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}", v) for k, v in layers]
|
||||
|
||||
return ModelPatchRaw(layers=layers_with_prefix)
|
||||
return layers_with_prefix
|
||||
|
||||
|
||||
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
|
||||
@@ -7,6 +7,7 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
|
||||
FLUX_LORA_CLIP_PREFIX,
|
||||
FLUX_LORA_T5_PREFIX,
|
||||
FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
)
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -26,6 +27,14 @@ FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
|
||||
# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
|
||||
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
|
||||
|
||||
# A regex pattern that matches all of the T5 keys in the Kohya FLUX LoRA format.
|
||||
# Example keys:
|
||||
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.alpha
|
||||
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.dora_scale
|
||||
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_down.weight
|
||||
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_up.weight
|
||||
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
|
||||
@@ -34,7 +43,9 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
return all(
|
||||
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
|
||||
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
|
||||
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
|
||||
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
|
||||
for k in state_dict.keys()
|
||||
)
|
||||
|
||||
@@ -48,27 +59,34 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
|
||||
grouped_state_dict[layer_name] = {}
|
||||
grouped_state_dict[layer_name][param_name] = value
|
||||
|
||||
# Split the grouped state dict into transformer and CLIP state dicts.
|
||||
# Split the grouped state dict into transformer, CLIP, and T5 state dicts.
|
||||
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
|
||||
clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
|
||||
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
|
||||
for layer_name, layer_state_dict in grouped_state_dict.items():
|
||||
if layer_name.startswith("lora_unet"):
|
||||
transformer_grouped_sd[layer_name] = layer_state_dict
|
||||
elif layer_name.startswith("lora_te1"):
|
||||
clip_grouped_sd[layer_name] = layer_state_dict
|
||||
elif layer_name.startswith("lora_te2"):
|
||||
t5_grouped_sd[layer_name] = layer_state_dict
|
||||
else:
|
||||
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
|
||||
|
||||
# Convert the state dicts to the InvokeAI format.
|
||||
transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
|
||||
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
|
||||
t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
|
||||
|
||||
# Create LoRA layers.
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
for layer_key, layer_state_dict in transformer_grouped_sd.items():
|
||||
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
|
||||
for layer_key, layer_state_dict in clip_grouped_sd.items():
|
||||
layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
|
||||
layers: list[tuple[str, BaseLayerPatch]] = []
|
||||
for model_prefix, grouped_sd in [
|
||||
(FLUX_LORA_TRANSFORMER_PREFIX, transformer_grouped_sd),
|
||||
(FLUX_LORA_CLIP_PREFIX, clip_grouped_sd),
|
||||
(FLUX_LORA_T5_PREFIX, t5_grouped_sd),
|
||||
]:
|
||||
for layer_key, layer_state_dict in grouped_sd.items():
|
||||
layers.append((model_prefix + layer_key, any_lora_layer_from_state_dict(layer_state_dict)))
|
||||
|
||||
# Create and return the ModelPatchRaw.
|
||||
return ModelPatchRaw(layers=layers)
|
||||
@@ -123,3 +141,31 @@ def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict
|
||||
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
|
||||
|
||||
return converted_dict
|
||||
|
||||
|
||||
def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
    """Rename the keys of a Kohya-format FLUX T5 LoRA state dict to the key format used internally by InvokeAI.

    Every key must match FLUX_KOHYA_T5_KEY_REGEX. The values are passed through untouched; only the keys are
    rewritten.

    Example key conversions:

    "lora_te2_encoder_block_0_layer_0_SelfAttention_k" -> "encoder.block.0.layer.0.SelfAttention.k"
    "lora_te2_encoder_block_0_layer_1_DenseReluDense_wi_0" -> "encoder.block.0.layer.1.DenseReluDense.wi_0"

    NOTE(review): with FLUX_KOHYA_T5_KEY_REGEX as currently written, the greedy `(\\w+)` of capture group 4 appears
    to absorb underscores as well, so names such as "wi_0" stay intact and capture group 5 is expected to be empty.
    Confirm before relying on group 5.

    Raises:
        ValueError: If a key does not match FLUX_KOHYA_T5_KEY_REGEX.
    """

    def _to_invoke_key(m: re.Match[str]) -> str:
        # Groups: (block index, layer index, module name, parameter name, optional trailing name).
        block_idx, layer_idx, module_name, param_name, trailing = m.groups()
        invoke_key = f"encoder.block.{block_idx}.layer.{layer_idx}.{module_name}.{param_name}"
        if trailing:
            invoke_key += f".{trailing}"
        return invoke_key

    renamed: dict[str, T] = {}
    for kohya_key, value in state_dict.items():
        if not re.match(FLUX_KOHYA_T5_KEY_REGEX, kohya_key):
            raise ValueError(f"Key '{kohya_key}' does not match the expected pattern for FLUX LoRA weights.")
        renamed[re.sub(FLUX_KOHYA_T5_KEY_REGEX, _to_invoke_key, kohya_key)] = value

    return renamed
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Prefixes used to distinguish between transformer, CLIP text encoder, and T5 text encoder keys in the FLUX InvokeAI LoRA format.
|
||||
FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
|
||||
FLUX_LORA_CLIP_PREFIX = "lora_clip-"
|
||||
FLUX_LORA_T5_PREFIX = "lora_t5-"
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
|
||||
lora_layers_from_flux_diffusers_grouped_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
FLUX_KOHYA_CLIP_KEY_REGEX,
|
||||
FLUX_KOHYA_T5_KEY_REGEX,
|
||||
_convert_flux_clip_kohya_state_dict_to_invoke_format,
|
||||
_convert_flux_t5_kohya_state_dict_to_invoke_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
|
||||
FLUX_LORA_CLIP_PREFIX,
|
||||
FLUX_LORA_T5_PREFIX,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.kohya_key_utils import (
|
||||
INDEX_PLACEHOLDER,
|
||||
ParsingTree,
|
||||
insert_periods_into_kohya_key,
|
||||
)
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
# Regex matching every transformer key in the OneTrainer FLUX LoRA format.
# OneTrainer mixes the Kohya and Diffusers conventions:
# - Base model module names follow the Diffusers format.
# - Periods in the module path are replaced with underscores, as in Kohya.
# - The LoRA parameter suffixes (e.g. .alpha, .lora_down.weight, .lora_up.weight) are the Kohya ones.
# Example keys:
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.alpha"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.dora_scale"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_down.weight"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_up.weight"
FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = (
    r"lora_transformer_(single_transformer_blocks|transformer_blocks)_(\d+)_(\w+)\.(.*)"
)


def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool:
    """Return True if every key of `state_dict` looks like a OneTrainer FLUX LoRA key.

    A key is accepted when it matches the OneTrainer transformer pattern, or one of the Kohya CLIP / T5 patterns
    (OneTrainer uses the Kohya format for the CLIP and T5 models). An empty state dict is vacuously accepted.

    This is intended to be a high-precision detector, but perfect precision is not guaranteed: that would require
    checking all keys against a whitelist and verifying tensor shapes.
    """
    for key in state_dict:
        # Patterns are tried in order; the first hit is enough for this key.
        recognized = (
            re.match(FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX, key)
            or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, key)
            or re.match(FLUX_KOHYA_T5_KEY_REGEX, key)
        )
        if not recognized:
            return False
    return True
|
||||
|
||||
|
||||
def lora_model_from_flux_onetrainer_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:  # type: ignore
    """Build a ModelPatchRaw from a FLUX LoRA state dict in the OneTrainer format.

    Keys are split at the first "." into a layer name and a parameter name. Layers are then routed by prefix:
    "lora_transformer" (transformer), "lora_te1" (CLIP) and "lora_te2" (T5). The resulting layer list holds the CLIP
    layers first, then the T5 layers, then the transformer layers.

    Raises:
        ValueError: If a key contains no "." (the split cannot be unpacked), or if a layer name has none of the
            expected prefixes.
    """
    # Collect the parameters of each layer under the layer's name.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
    for full_key, tensor in state_dict.items():
        layer_name, param_name = full_key.split(".", 1)
        grouped_state_dict.setdefault(layer_name, {})[param_name] = tensor

    # Route each layer to the sub-model it belongs to, based on its name prefix.
    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
    clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
    t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
    buckets_by_prefix = (
        ("lora_transformer", transformer_grouped_sd),
        ("lora_te1", clip_grouped_sd),
        ("lora_te2", t5_grouped_sd),
    )
    for layer_name, layer_state_dict in grouped_state_dict.items():
        for name_prefix, bucket in buckets_by_prefix:
            if layer_name.startswith(name_prefix):
                bucket[layer_name] = layer_state_dict
                break
        else:
            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")

    # The text encoder keys share the Kohya format, so the Kohya converters are reused for them.
    clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
    t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)

    # Text encoder layers: one patch layer per grouped entry, keyed with the sub-model prefix.
    layers: list[tuple[str, BaseLayerPatch]] = []
    for layer_key, layer_state_dict in clip_grouped_sd.items():
        layers.append((FLUX_LORA_CLIP_PREFIX + layer_key, any_lora_layer_from_state_dict(layer_state_dict)))
    for layer_key, layer_state_dict in t5_grouped_sd.items():
        layers.append((FLUX_LORA_T5_PREFIX + layer_key, any_lora_layer_from_state_dict(layer_state_dict)))

    # Transformer layers go through a dedicated conversion that already returns (key, layer) pairs.
    layers.extend(_convert_flux_transformer_onetrainer_state_dict_to_invoke_format(transformer_grouped_sd))

    # Create and return the ModelPatchRaw.
    return ModelPatchRaw(layers=layers)
|
||||
|
||||
|
||||
# Parsing tree describing the module hierarchy of the FLUX transformer in Diffusers naming. It is used to decide where
# the underscores of a flattened Kohya-style key are path separators and where they belong to a module name.
# This parsing tree was generated by calling `generate_kohya_parsing_tree_from_keys()` on the keys in
# flux_lora_diffusers_format.py.
# Note that some sibling names are underscore-prefixes of others ("ff" / "ff_context", "norm1" / "norm1_context"), so
# a parser walking this tree cannot simply commit to the shortest matching name.
flux_transformer_kohya_parsing_tree: ParsingTree = {
    "transformer": {
        "single_transformer_blocks": {
            INDEX_PLACEHOLDER: {
                "attn": {"to_k": {}, "to_q": {}, "to_v": {}},
                "norm": {"linear": {}},
                "proj_mlp": {},
                "proj_out": {},
            }
        },
        "transformer_blocks": {
            INDEX_PLACEHOLDER: {
                "attn": {
                    "add_k_proj": {},
                    "add_q_proj": {},
                    "add_v_proj": {},
                    "to_add_out": {},
                    "to_k": {},
                    "to_out": {INDEX_PLACEHOLDER: {}},
                    "to_q": {},
                    "to_v": {},
                },
                "ff": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
                "ff_context": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
                "norm1": {"linear": {}},
                "norm1_context": {"linear": {}},
            }
        },
    }
}
|
||||
|
||||
|
||||
def _convert_flux_transformer_onetrainer_state_dict_to_invoke_format(
    state_dict: Dict[str, Dict[str, torch.Tensor]],
) -> list[tuple[str, BaseLayerPatch]]:
    """Convert grouped FLUX transformer LoRA weights from the OneTrainer format into InvokeAI patch layers.

    Args:
        state_dict: Mapping from an underscore-delimited layer name (starting with "lora_") to that layer's
            parameter dict.

    Returns:
        The (key, layer) pairs produced by `lora_layers_from_flux_diffusers_grouped_state_dict()`.
    """
    kohya_prefix = "lora_"

    # Step 1: turn the underscore-delimited layer names into period-delimited module paths.
    # Example:
    # "lora_transformer_single_transformer_blocks_0_attn_to_k" -> "transformer.single_transformer_blocks.0.attn.to_k"
    diffusers_grouped_sd: dict[str, Dict[str, torch.Tensor]] = {}
    for kohya_key, layer_state_dict in state_dict.items():
        # Every layer name handed to this function is expected to carry the "lora_" prefix.
        assert kohya_key.startswith(kohya_prefix)
        flattened_path = kohya_key[len(kohya_prefix) :]

        # Restore the periods that were flattened to underscores.
        dotted_path = insert_periods_into_kohya_key(flattened_path, flux_transformer_kohya_parsing_tree)
        diffusers_grouped_sd[dotted_path] = layer_state_dict

    # Step 2: hand the Diffusers-style keys to the Diffusers conversion (per the original comment, this maps the
    # diffusers module names to the BFL module names).
    return lora_layers_from_flux_diffusers_grouped_state_dict(diffusers_grouped_sd, alpha=None)
|
||||
102
invokeai/backend/patches/lora_conversions/kohya_key_utils.py
Normal file
102
invokeai/backend/patches/lora_conversions/kohya_key_utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import Iterable
|
||||
|
||||
INDEX_PLACEHOLDER = "index_placeholder"


# Type alias for a 'ParsingTree', which is a recursive dict with string keys.
ParsingTree = dict[str, "ParsingTree"]


def insert_periods_into_kohya_key(key: str, parsing_tree: ParsingTree) -> str:
    """Insert periods into a Kohya key based on a parsing tree.

    Kohya format keys are produced by replacing periods with underscores in the original key. Because module names
    may themselves contain underscores, the parsing tree is needed to tell separators from name characters.

    The key is split on underscores and consecutive parts are merged until the merged name is a child of the current
    tree node (or is numeric while the node has an INDEX_PLACEHOLDER child). Shorter names are tried first. If the
    remainder of the key cannot be parsed after a match, the match is undone and a longer name is tried, so sibling
    names where one is an underscore-prefix of the other (e.g. "ff" and "ff_context") are both handled.

    Example:
    ```
    key = "module_a_module_b_0_attn_to_k"
    parsing_tree = {
        "module_a": {
            "module_b": {
                INDEX_PLACEHOLDER: {
                    "attn": {
                        "to_k": {},
                    },
                },
            },
        },
    }
    result = insert_periods_into_kohya_key(key, parsing_tree)
    > "module_a.module_b.0.attn.to_k"
    ```

    Args:
        key: The underscore-delimited key.
        parsing_tree: The tree of known module names; numeric path components are represented by INDEX_PLACEHOLDER.

    Returns:
        The key with its path components joined by periods.

    Raises:
        ValueError: If the key cannot be fully consumed by walking the parsing tree.
    """
    # Split key into parts by underscore.
    parts = key.split("_")

    def _parse(idx: int, tree: ParsingTree, pending: str) -> "list[str] | None":
        """Parse parts[idx:] under `tree`; `pending` holds merged parts that have not matched a name yet."""
        if idx == len(parts):
            # Success only if nothing is left over that failed to match a tree entry.
            return [] if not pending else None

        candidate = f"{pending}_{parts[idx]}" if pending else parts[idx]

        subtree = None
        if candidate in tree:
            # Match found.
            subtree = tree[candidate]
        elif candidate.isnumeric() and INDEX_PLACEHOLDER in tree:
            # Match found with index placeholder.
            subtree = tree[INDEX_PLACEHOLDER]

        if subtree is not None:
            remainder = _parse(idx + 1, subtree, "")
            if remainder is not None:
                return [candidate, *remainder]
            # The shortest match led to a dead end; fall through and try a longer name.

        return _parse(idx + 1, tree, candidate)

    result_parts = _parse(0, parsing_tree, "")
    if result_parts is None:
        raise ValueError(f"Key {key} does not match parsing tree {parsing_tree}.")

    return ".".join(result_parts)
|
||||
|
||||
|
||||
def generate_kohya_parsing_tree_from_keys(keys: Iterable[str]) -> ParsingTree:
    """Build a parsing tree from period-delimited keys.

    Each key is split on "." and inserted into a nested dict, one level per component. Numeric components are all
    stored under the single INDEX_PLACEHOLDER entry, so keys that differ only in an index share a subtree.

    Example:
    ```
    keys = [
        "module_a.module_b.0.attn.to_k",
        "module_a.module_b.1.attn.to_k",
        "module_a.module_c.proj",
    ]

    tree = generate_kohya_parsing_tree_from_keys(keys)
    > {
    >     "module_a": {
    >         "module_b": {
    >             INDEX_PLACEHOLDER: {
    >                 "attn": {
    >                     "to_k": {},
    >                 },
    >             }
    >         },
    >         "module_c": {
    >             "proj": {},
    >         }
    >     }
    > }
    ```
    """
    root: ParsingTree = {}
    for dotted_key in keys:
        node = root
        for component in dotted_key.split("."):
            # Collapse every numeric component onto the shared placeholder entry.
            child_name = INDEX_PLACEHOLDER if component.isnumeric() else component
            node = node.setdefault(child_name, {})
    return root
|
||||
@@ -10,9 +10,9 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
|
||||
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
layers: list[tuple[str, BaseLayerPatch]] = []
|
||||
for layer_key, values in grouped_state_dict.items():
|
||||
layers[layer_key] = any_lora_layer_from_state_dict(values)
|
||||
layers.append((layer_key, any_lora_layer_from_state_dict(values)))
|
||||
|
||||
return ModelPatchRaw(layers=layers)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
from typing import Mapping, Optional
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
@@ -8,12 +8,12 @@ from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class ModelPatchRaw(RawModel):
|
||||
def __init__(self, layers: Mapping[str, BaseLayerPatch]):
|
||||
def __init__(self, layers: Sequence[tuple[str, BaseLayerPatch]]):
|
||||
self.layers = layers
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
for layer in self.layers.values():
|
||||
for _, layer in self.layers:
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
return sum(layer.calc_size() for layer in self.layers.values())
|
||||
return sum(layer.calc_size() for _, layer in self.layers)
|
||||
|
||||
@@ -54,7 +54,9 @@ GGML_TENSOR_OP_TABLE = {
|
||||
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.sub.Tensor: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.allclose.default: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.slice.Tensor: dequantize_and_run, # pyright: ignore
|
||||
}
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -51,10 +51,8 @@ def test_lora_model_from_flux_control_state_dict(sd_keys: dict[str, list[int]]):
|
||||
k = k.replace("lora_B.bias", "")
|
||||
k = k.replace(".scale", "")
|
||||
expected_lora_layers.add(k)
|
||||
# Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
|
||||
# the Q weights so that we count these layers once).
|
||||
assert len(model.layers) == len(expected_lora_layers)
|
||||
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
|
||||
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k, _ in model.layers)
|
||||
|
||||
|
||||
def test_lora_model_from_flux_control_state_dict_extra_keys_error():
|
||||
|
||||
@@ -6,6 +6,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
|
||||
state_dict_keys as flux_onetrainer_state_dict_keys,
|
||||
)
|
||||
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
|
||||
state_dict_keys as flux_diffusers_state_dict_keys,
|
||||
)
|
||||
@@ -27,12 +30,13 @@ def test_is_state_dict_likely_in_flux_diffusers_format_true(sd_keys: dict[str, l
|
||||
assert is_state_dict_likely_in_flux_diffusers_format(state_dict)
|
||||
|
||||
|
||||
def test_is_state_dict_likely_in_flux_diffusers_format_false():
|
||||
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_onetrainer_state_dict_keys])
|
||||
def test_is_state_dict_likely_in_flux_diffusers_format_false(sd_keys: dict[str, list[int]]):
|
||||
"""Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya
|
||||
FLUX LoRA format.
|
||||
"""
|
||||
# Construct a state dict that is not in the Kohya FLUX LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
|
||||
state_dict = keys_to_mock_state_dict(sd_keys)
|
||||
|
||||
assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)
|
||||
|
||||
@@ -51,12 +55,8 @@ def test_lora_model_from_flux_diffusers_state_dict(sd_keys: dict[str, list[int]]
|
||||
k = k.replace("lora_A.weight", "")
|
||||
k = k.replace("lora_B.weight", "")
|
||||
expected_lora_layers.add(k)
|
||||
# Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
|
||||
# the Q weights so that we count these layers once).
|
||||
concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
|
||||
expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
|
||||
assert len(model.layers) == len(expected_lora_layers)
|
||||
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
|
||||
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k, _ in model.layers)
|
||||
|
||||
|
||||
def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
|
||||
|
||||
@@ -13,6 +13,9 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
|
||||
FLUX_LORA_CLIP_PREFIX,
|
||||
FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
)
|
||||
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
|
||||
state_dict_keys as flux_onetrainer_state_dict_keys,
|
||||
)
|
||||
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
|
||||
state_dict_keys as flux_diffusers_state_dict_keys,
|
||||
)
|
||||
@@ -34,11 +37,12 @@ def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: dict[str, list[
|
||||
assert is_state_dict_likely_in_flux_kohya_format(state_dict)
|
||||
|
||||
|
||||
def test_is_state_dict_likely_in_flux_kohya_format_false():
|
||||
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys])
|
||||
def test_is_state_dict_likely_in_flux_kohya_format_false(sd_keys: dict[str, list[int]]):
|
||||
"""Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is in the Diffusers
|
||||
FLUX LoRA format.
|
||||
"""
|
||||
state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
|
||||
state_dict = keys_to_mock_state_dict(sd_keys)
|
||||
assert not is_state_dict_likely_in_flux_kohya_format(state_dict)
|
||||
|
||||
|
||||
@@ -106,6 +110,6 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: dict[str, list[int]]):
|
||||
expected_layer_keys.add(k)
|
||||
|
||||
# Assert that the lora_model has the expected layers.
|
||||
lora_model_keys = set(lora_model.layers.keys())
|
||||
lora_model_keys = {k for k, _ in lora_model.layers}
|
||||
lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
|
||||
assert lora_model_keys == expected_layer_keys
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
|
||||
FLUX_LORA_CLIP_PREFIX,
|
||||
FLUX_LORA_T5_PREFIX,
|
||||
FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_onetrainer_format,
|
||||
lora_model_from_flux_onetrainer_state_dict,
|
||||
)
|
||||
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
|
||||
state_dict_keys as flux_onetrainer_state_dict_keys,
|
||||
)
|
||||
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
|
||||
state_dict_keys as flux_diffusers_state_dict_keys,
|
||||
)
|
||||
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_format import (
|
||||
state_dict_keys as flux_kohya_state_dict_keys,
|
||||
)
|
||||
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import (
|
||||
state_dict_keys as flux_kohya_te1_state_dict_keys,
|
||||
)
|
||||
from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict
|
||||
|
||||
|
||||
def test_is_state_dict_likely_in_flux_onetrainer_format_true():
    """A state dict with OneTrainer FLUX LoRA keys should be recognized by
    is_state_dict_likely_in_flux_onetrainer_format().
    """
    # Mock state dict built from the OneTrainer FLUX LoRA key fixture.
    onetrainer_state_dict = keys_to_mock_state_dict(flux_onetrainer_state_dict_keys)

    assert is_state_dict_likely_in_flux_onetrainer_format(onetrainer_state_dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sd_keys",
    [
        flux_kohya_state_dict_keys,
        flux_kohya_te1_state_dict_keys,
        flux_diffusers_state_dict_keys,
    ],
)
def test_is_state_dict_likely_in_flux_onetrainer_format_false(sd_keys: dict[str, list[int]]):
    """is_state_dict_likely_in_flux_onetrainer_format() should return False for state dicts in other FLUX LoRA
    formats (Kohya, Kohya with te1 keys, Diffusers).
    """
    other_format_state_dict = keys_to_mock_state_dict(sd_keys)

    assert not is_state_dict_likely_in_flux_onetrainer_format(other_format_state_dict)
|
||||
|
||||
|
||||
def test_lora_model_from_flux_onetrainer_state_dict():
    """Converting the OneTrainer fixture should yield one layer per distinct layer name, each with a known prefix."""
    state_dict = keys_to_mock_state_dict(flux_onetrainer_state_dict_keys)

    lora_model = lora_model_from_flux_onetrainer_state_dict(state_dict)

    # Strip the parameter suffixes so that all parameters of one layer collapse to a single layer name.
    param_suffixes = (".lora_up.weight", ".lora_down.weight", ".alpha", ".dora_scale")
    expected_lora_layers: set[str] = set()
    for layer_name in flux_onetrainer_state_dict_keys:
        for suffix in param_suffixes:
            layer_name = layer_name.replace(suffix, "")
        expected_lora_layers.add(layer_name)

    # Check that the model has the correct number of LoRA layers.
    assert len(lora_model.layers) == len(expected_lora_layers)

    # Check that all of the layers have the expected prefix.
    known_prefixes = (FLUX_LORA_TRANSFORMER_PREFIX, FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX)
    keys_without_known_prefix = [k for k, _ in lora_model.layers if not k.startswith(known_prefixes)]
    assert not keys_without_known_prefix
|
||||
@@ -0,0 +1,96 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.patches.lora_conversions.kohya_key_utils import (
|
||||
INDEX_PLACEHOLDER,
|
||||
ParsingTree,
|
||||
generate_kohya_parsing_tree_from_keys,
|
||||
insert_periods_into_kohya_key,
|
||||
)
|
||||
|
||||
|
||||
def test_insert_periods_into_kohya_key():
    """insert_periods_into_kohya_key() should restore the periods of a key that is fully described by the tree."""
    parsing_tree: ParsingTree = {"module_a": {"module_b": {INDEX_PLACEHOLDER: {"attn": {"to_k": {}}}}}}

    dotted_key = insert_periods_into_kohya_key("module_a_module_b_0_attn_to_k", parsing_tree)

    assert dotted_key == "module_a.module_b.0.attn.to_k"
|
||||
|
||||
|
||||
def test_insert_periods_into_kohya_key_invalid_key():
    """insert_periods_into_kohya_key() should raise ValueError for a key that matches nothing in the tree."""
    parsing_tree: ParsingTree = {"module_a": {"module_b": {INDEX_PLACEHOLDER: {"attn": {"to_k": {}}}}}}
    unknown_key = "invalid_key_format"

    with pytest.raises(ValueError):
        insert_periods_into_kohya_key(unknown_key, parsing_tree)
|
||||
|
||||
|
||||
def test_insert_periods_into_kohya_key_too_long():
    """Test that insert_periods_into_kohya_key() raises ValueError for a key that has a valid prefix, but is too long.

    The key is underscore-delimited (Kohya style), so that the "module_a_module_b_0_attn_to_k" prefix really is
    consumed by the parsing tree and only the trailing "invalid_suffix" is left unmatched. A period-delimited key
    would already fail on its first component and would not exercise this case.
    """
    key = "module_a_module_b_0_attn_to_k_invalid_suffix"
    parsing_tree: ParsingTree = {
        "module_a": {
            "module_b": {
                INDEX_PLACEHOLDER: {
                    "attn": {
                        "to_k": {},
                    },
                },
            },
        },
    }
    with pytest.raises(ValueError):
        insert_periods_into_kohya_key(key, parsing_tree)
|
||||
|
||||
|
||||
def test_generate_kohya_parsing_tree_from_keys():
    """generate_kohya_parsing_tree_from_keys() should merge indexed keys under INDEX_PLACEHOLDER and keep
    non-indexed branches separate.
    """
    # The two module_b keys differ only in their index, so they are expected to share one subtree.
    dotted_keys = [
        "module_a.module_b.0.attn.to_k",
        "module_a.module_b.1.attn.to_k",
        "module_a.module_c.proj",
    ]
    indexed_branch: ParsingTree = {INDEX_PLACEHOLDER: {"attn": {"to_k": {}}}}
    expected_tree: ParsingTree = {
        "module_a": {
            "module_b": indexed_branch,
            "module_c": {"proj": {}},
        }
    }

    assert generate_kohya_parsing_tree_from_keys(dotted_keys) == expected_tree
|
||||
|
||||
|
||||
def test_generate_kohya_parsing_tree_from_empty_keys():
    """generate_kohya_parsing_tree_from_keys() should return an empty tree when given no keys."""
    no_keys: list[str] = []

    assert generate_kohya_parsing_tree_from_keys(no_keys) == {}
|
||||
@@ -58,14 +58,21 @@ def test_apply_smart_model_patches(
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[ModelPatchRaw, float]] = []
|
||||
for _ in range(num_loras):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
lora_layers = [
|
||||
(
|
||||
"linear_layer_1",
|
||||
LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones(
|
||||
(lora_rank, linear_in_features), device="cpu", dtype=torch.float16
|
||||
),
|
||||
"lora_up.weight": torch.ones(
|
||||
(linear_out_features, lora_rank), device="cpu", dtype=torch.float16
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
]
|
||||
lora = ModelPatchRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
@@ -104,8 +111,8 @@ def test_apply_smart_model_patches(
|
||||
|
||||
# After patching, all LoRA layer weights should have been moved back to the cpu.
|
||||
for lora, _ in lora_models:
|
||||
assert lora.layers["linear_layer_1"].up.device.type == "cpu"
|
||||
assert lora.layers["linear_layer_1"].down.device.type == "cpu"
|
||||
assert lora.layers[0][1].up.device.type == "cpu"
|
||||
assert lora.layers[0][1].down.device.type == "cpu"
|
||||
|
||||
output_during_patch = model(input)
|
||||
|
||||
@@ -150,20 +157,34 @@ def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int):
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[ModelPatchRaw, float]] = []
|
||||
for _ in range(num_loras):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
lora_layers = [
|
||||
(
|
||||
"linear_layer_1",
|
||||
LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones(
|
||||
(lora_rank, linear_in_features), device="cpu", dtype=torch.float16
|
||||
),
|
||||
"lora_up.weight": torch.ones(
|
||||
(linear_out_features, lora_rank), device="cpu", dtype=torch.float16
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
"linear_layer_2": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
(
|
||||
"linear_layer_2",
|
||||
LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones(
|
||||
(lora_rank, linear_out_features), device="cpu", dtype=torch.float16
|
||||
),
|
||||
"lora_up.weight": torch.ones(
|
||||
(linear_out_features, lora_rank), device="cpu", dtype=torch.float16
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
}
|
||||
]
|
||||
lora = ModelPatchRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
@@ -204,14 +225,21 @@ def test_all_patching_methods_produce_same_output(num_loras: int):
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[ModelPatchRaw, float]] = []
|
||||
for _ in range(num_loras):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
lora_layers = [
|
||||
(
|
||||
"linear_layer_1",
|
||||
LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones(
|
||||
(lora_rank, linear_in_features), device="cpu", dtype=torch.float16
|
||||
),
|
||||
"lora_up.weight": torch.ones(
|
||||
(linear_out_features, lora_rank), device="cpu", dtype=torch.float16
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
]
|
||||
lora = ModelPatchRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
@@ -249,14 +277,17 @@ def test_apply_smart_model_patches_change_device():
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
|
||||
apply_custom_layers_to_model(model)
|
||||
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
lora_layers = [
|
||||
(
|
||||
"linear_layer_1",
|
||||
LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
]
|
||||
lora = ModelPatchRaw(lora_layers)
|
||||
|
||||
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
|
||||
@@ -265,8 +296,8 @@ def test_apply_smart_model_patches_change_device():
|
||||
model=model, patches=[(lora, 0.5)], prefix="", dtype=torch.float16, force_direct_patching=True
|
||||
):
|
||||
# After patching, all LoRA layer weights should have been moved back to the cpu.
|
||||
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
|
||||
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
|
||||
assert lora_layers[0][1].up.device.type == "cpu"
|
||||
assert lora_layers[0][1].down.device.type == "cpu"
|
||||
|
||||
# After patching, the patched model should still be on the CPU.
|
||||
assert model.linear_layer_1.weight.data.device.type == "cpu"
|
||||
@@ -292,14 +323,17 @@ def test_apply_smart_model_patches_force_sidecar_and_direct_patching():
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
|
||||
apply_custom_layers_to_model(model)
|
||||
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
lora_layers = [
|
||||
(
|
||||
"linear_layer_1",
|
||||
LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
]
|
||||
lora = ModelPatchRaw(lora_layers)
|
||||
with pytest.raises(ValueError, match="Cannot force both direct and sidecar patching."):
|
||||
with LayerPatcher.apply_smart_model_patches(
|
||||
|
||||
Reference in New Issue
Block a user