# Support FLUX LoRA models in kohya format with lora_te1 layers (i.e. CLIP LoRA layers) (#6967)
## Summary

This PR adds support for FLUX LoRA models in kohya format with `lora_te1` layers (i.e. CLIP LoRA layers). Previously, only transformer LoRA layers were supported.

Example LoRA model in this format: https://huggingface.co/cocktailpeanut/optimus

### Example

Prompt: `optimus is playing tennis in a tennis court`

Seed: 0

Without LoRA: (image omitted)

With LoRA: (image omitted)

## QA Instructions

I tested the following:

- [x] The optimus LoRA (with CLIP layers) can be applied.
- [x] FLUX LoRAs without CLIP layers still work.
- [x] Loading the optimus LoRA, but applying it to the transformer _only_, produces a different result, i.e. verified that patching the CLIP layers is doing _something_. Ironically, the results seem better without applying the CLIP layers. The CLIP layers seem to pull in more background concepts. Regardless, it works.
- [x] The optimus LoRA can be applied via the Linear UI, and the output matches results from manually constructing the workflow graph.
- [x] FLUX LoRAs without CLIP layers still work via the Linear UI.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
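Background for the diff below: a kohya-format FLUX LoRA of this kind mixes two key families in a single state dict, transformer keys prefixed `lora_unet_` and CLIP text encoder keys prefixed `lora_te1_`. A minimal, hypothetical sketch of the split this PR performs (the key names are illustrative; the real logic lives in the conversion utilities changed below):

```python
# Hypothetical keys as they might appear in a kohya FLUX LoRA that includes CLIP layers.
state_dict_keys = [
    "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight",       # transformer layer
    "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight",  # CLIP text encoder layer
]

# The conversion utilities route each key family to its own sub-state-dict.
transformer_keys = [k for k in state_dict_keys if k.startswith("lora_unet")]
clip_keys = [k for k in state_dict_keys if k.startswith("lora_te1")]
```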
@@ -30,6 +30,7 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
@@ -208,7 +209,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                     LoRAPatcher.apply_lora_patches(
                         model=transformer,
                         patches=self._lora_iterator(context),
-                        prefix="",
+                        prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
                         cached_weights=cached_weights,
                     )
                 )
@@ -219,7 +220,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                     LoRAPatcher.apply_lora_sidecar_patches(
                         model=transformer,
                         patches=self._lora_iterator(context),
-                        prefix="",
+                        prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
                         dtype=inference_dtype,
                     )
                 )
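The denoise invocation now passes `FLUX_KOHYA_TRANFORMER_PREFIX` instead of an empty prefix, so that only the transformer layers of a mixed LoRA are patched into the FLUX model. The sketch below is a hypothetical illustration, assuming `LoRAPatcher` filters layers by key prefix (its internals are not part of this diff), of how such a prefix can select the right subset:

```python
FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"

# Hypothetical internal layer keys after conversion (see the conversion utils below).
layers = {
    "lora_transformer-double_blocks.0.img_attn.proj": "...",
    "lora_clip-text_model.encoder.layers.0.mlp.fc1": "...",
}

# Keep only the layers intended for the transformer and strip the routing prefix.
transformer_layers = {
    k.removeprefix(FLUX_KOHYA_TRANFORMER_PREFIX): v
    for k, v in layers.items()
    if k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX)
}
assert list(transformer_layers) == ["double_blocks.0.img_attn.proj"]
```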
@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation_output,
 )
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
-from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
+from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.config import BaseModelType

@@ -20,6 +20,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
     transformer: Optional[TransformerField] = OutputField(
         default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
     )
+    clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")


 @invocation(
@@ -27,21 +28,28 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
     title="FLUX LoRA",
     tags=["lora", "model", "flux"],
     category="model",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class FluxLoRALoaderInvocation(BaseInvocation):
-    """Apply a LoRA model to a FLUX transformer."""
+    """Apply a LoRA model to a FLUX transformer and/or text encoder."""

     lora: ModelIdentifierField = InputField(
         description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
     )
     weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
-    transformer: TransformerField = InputField(
+    transformer: TransformerField | None = InputField(
+        default=None,
         description=FieldDescriptions.transformer,
         input=Input.Connection,
         title="FLUX Transformer",
     )
+    clip: CLIPField | None = InputField(
+        default=None,
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )

     def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
         lora_key = self.lora.key
@@ -49,18 +57,33 @@ class FluxLoRALoaderInvocation(BaseInvocation):
         if not context.models.exists(lora_key):
             raise ValueError(f"Unknown lora: {lora_key}!")

-        if any(lora.lora.key == lora_key for lora in self.transformer.loras):
+        # Check for existing LoRAs with the same key.
+        if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
             raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
+        if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
+            raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')

-        transformer = self.transformer.model_copy(deep=True)
-        transformer.loras.append(
-            LoRAField(
-                lora=self.lora,
-                weight=self.weight,
+        output = FluxLoRALoaderOutput()
+
+        # Attach LoRA layers to the models.
+        if self.transformer is not None:
+            output.transformer = self.transformer.model_copy(deep=True)
+            output.transformer.loras.append(
+                LoRAField(
+                    lora=self.lora,
+                    weight=self.weight,
+                )
+            )
+        if self.clip is not None:
+            output.clip = self.clip.model_copy(deep=True)
+            output.clip.loras.append(
+                LoRAField(
+                    lora=self.lora,
+                    weight=self.weight,
+                )
+            )
-        )

-        return FluxLoRALoaderOutput(transformer=transformer)
+        return output


 @invocation(
@@ -68,7 +91,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
     title="FLUX LoRA Collection Loader",
     tags=["lora", "model", "flux"],
     category="model",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class FLUXLoRACollectionLoader(BaseInvocation):
@@ -84,6 +107,12 @@ class FLUXLoRACollectionLoader(BaseInvocation):
         input=Input.Connection,
         title="Transformer",
     )
+    clip: CLIPField | None = InputField(
+        default=None,
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )

     def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
         output = FluxLoRALoaderOutput()
@@ -106,4 +135,9 @@ class FLUXLoRACollectionLoader(BaseInvocation):
                 output.transformer = self.transformer.model_copy(deep=True)
             output.transformer.loras.append(lora)

+            if self.clip is not None:
+                if output.clip is None:
+                    output.clip = self.clip.model_copy(deep=True)
+                output.clip.loras.append(lora)
+
         return output
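A design note on the loader change above: the invocation deep-copies the incoming `TransformerField`/`CLIPField` before appending a `LoRAField`, so the upstream node's output is never mutated. A small stand-alone sketch of that pattern, using a hypothetical dataclass in place of the real pydantic fields:

```python
from copy import deepcopy
from dataclasses import dataclass, field

@dataclass
class FakeTransformerField:
    # Hypothetical stand-in for the real pydantic TransformerField.
    loras: list = field(default_factory=list)

upstream = FakeTransformerField()
patched = deepcopy(upstream)      # analogous to self.transformer.model_copy(deep=True)
patched.loras.append("optimus")   # only the copy carries the new LoRA
assert upstream.loras == []
```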
@@ -1,4 +1,5 @@
-from typing import Literal
+from contextlib import ExitStack
+from typing import Iterator, Literal, Tuple

 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@@ -9,6 +10,10 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo


@@ -17,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -78,15 +83,42 @@ class FluxTextEncoderInvocation(BaseInvocation):
         prompt = [self.prompt]

         with (
-            clip_text_encoder_info as clip_text_encoder,
+            clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
             clip_tokenizer_info as clip_tokenizer,
+            ExitStack() as exit_stack,
         ):
             assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(clip_tokenizer, CLIPTokenizer)

+            clip_text_encoder_config = clip_text_encoder_info.config
+            assert clip_text_encoder_config is not None
+
+            # Apply LoRA models to the CLIP encoder.
+            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+            if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
+                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_patches(
+                        model=clip_text_encoder,
+                        patches=self._clip_lora_iterator(context),
+                        prefix=FLUX_KOHYA_CLIP_PREFIX,
+                        cached_weights=cached_weights,
+                    )
+                )
+            else:
+                # There are currently no supported CLIP quantized models. Add support here if needed.
+                raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
+
             clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)

             pooled_prompt_embeds = clip_encoder(prompt)

             assert isinstance(pooled_prompt_embeds, torch.Tensor)
             return pooled_prompt_embeds

+    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+        for lora in self.clip.loras:
+            lora_info = context.models.load(lora.lora)
+            assert isinstance(lora_info.model, LoRAModelRaw)
+            yield (lora_info.model, lora.weight)
+            del lora_info
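Two details worth noting in the text encoder change: the CLIP model is opened via `model_on_device()` so LoRA weights can be patched after the model is already on its target device, and an `ExitStack` is used so the patcher context is only entered when applicable. A minimal illustration of the conditional-context idiom, using only the standard library (`fake_patcher` is a made-up stand-in for `LoRAPatcher.apply_lora_patches`):

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def fake_patcher(name: str):
    # Stand-in for a patching context manager; prints instead of mutating weights.
    print(f"patching {name}")
    try:
        yield
    finally:
        print(f"unpatching {name}")

def encode(apply_lora: bool) -> None:
    with ExitStack() as stack:
        if apply_lora:
            # Entered only when needed; ExitStack unwinds it when the with-block exits.
            stack.enter_context(fake_patcher("clip_text_encoder"))
        print("running the text encoder")

encode(apply_lora=True)
```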
@@ -7,14 +7,25 @@ from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw

-# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
+# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
 # Example keys:
 # lora_unet_double_blocks_0_img_attn_proj.alpha
 # lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
 # lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
-FLUX_KOHYA_KEY_REGEX = (
+FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
     r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
 )
+# A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
+# Example keys:
+# lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
+# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight
+# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
+FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
+
+
+# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
+FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
+FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"


 def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
@@ -23,7 +34,10 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
     This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
     perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
     """
-    return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
+    return all(
+        re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
+        for k in state_dict.keys()
+    )


 def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
@@ -35,13 +49,27 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
             grouped_state_dict[layer_name] = {}
         grouped_state_dict[layer_name][param_name] = value

-    # Convert the state dict to the InvokeAI format.
-    grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
+    # Split the grouped state dict into transformer and CLIP state dicts.
+    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    for layer_name, layer_state_dict in grouped_state_dict.items():
+        if layer_name.startswith("lora_unet"):
+            transformer_grouped_sd[layer_name] = layer_state_dict
+        elif layer_name.startswith("lora_te1"):
+            clip_grouped_sd[layer_name] = layer_state_dict
+        else:
+            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
+
+    # Convert the state dicts to the InvokeAI format.
+    transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
+    clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)

     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
-    for layer_key, layer_state_dict in grouped_state_dict.items():
-        layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+    for layer_key, layer_state_dict in transformer_grouped_sd.items():
+        layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+    for layer_key, layer_state_dict in clip_grouped_sd.items():
+        layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

     # Create and return the LoRAModelRaw.
     return LoRAModelRaw(layers=layers)
@@ -50,16 +78,34 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
 T = TypeVar("T")


-def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
-    """Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
+def _convert_flux_clip_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a CLIP LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by
+    InvokeAI.
+
+    Example key conversions:
+
+    "lora_te1_text_model_encoder_layers_0_mlp_fc1" -> "text_model.encoder.layers.0.mlp.fc1",
+    "lora_te1_text_model_encoder_layers_0_self_attn_k_proj" -> "text_model.encoder.layers.0.self_attn.k_proj"
+    """
+    converted_sd: dict[str, T] = {}
+    for k, v in state_dict.items():
+        match = re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
+        if match:
+            new_key = f"text_model.encoder.layers.{match.group(1)}.{match.group(2)}.{match.group(3)}"
+            converted_sd[new_key] = v
+        else:
+            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")

+    return converted_sd
+
+
+def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a FLUX tranformer LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally
+    by InvokeAI.

     Example key conversions:
     "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
     "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
     "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
     "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
     "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
     "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
     """

     def replace_func(match: re.Match[str]) -> str:
@@ -70,9 +116,9 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:

     converted_dict: dict[str, T] = {}
     for k, v in state_dict.items():
-        match = re.match(FLUX_KOHYA_KEY_REGEX, k)
+        match = re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
         if match:
-            new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
+            new_key = re.sub(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, replace_func, k)
             converted_dict[new_key] = v
         else:
             raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
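To make the key handling above concrete, here is a worked example of the CLIP key conversion using the regex and format string introduced in this file; the final line shows the internal key after the routing prefix is applied (expected output is given in the comments):

```python
import re

FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"

key = "lora_te1_text_model_encoder_layers_0_mlp_fc1"
m = re.match(FLUX_KOHYA_CLIP_KEY_REGEX, key)
assert m is not None
new_key = f"text_model.encoder.layers.{m.group(1)}.{m.group(2)}.{m.group(3)}"
print(new_key)                           # text_model.encoder.layers.0.mlp.fc1
print(FLUX_KOHYA_CLIP_PREFIX + new_key)  # lora_clip-text_model.encoder.layers.0.mlp.fc1
```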
@@ -8,7 +8,8 @@ export const addFLUXLoRAs = (
   state: RootState,
   g: Graph,
   denoise: Invocation<'flux_denoise'>,
-  modelLoader: Invocation<'flux_model_loader'>
+  modelLoader: Invocation<'flux_model_loader'>,
+  fluxTextEncoder: Invocation<'flux_text_encoder'>
 ): void => {
   const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'flux');
   const loraCount = enabledLoRAs.length;
@@ -20,7 +21,7 @@ export const addFLUXLoRAs = (
   const loraMetadata: S['LoRAMetadataField'][] = [];

   // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
-  // each LoRA to the UNet and CLIP.
+  // each LoRA to the transformer and text encoders.
   const loraCollector = g.addNode({
     id: getPrefixedId('lora_collector'),
     type: 'collect',
@@ -33,10 +34,12 @@ export const addFLUXLoRAs = (
   g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
-  // Use model loader as transformer input
   g.addEdge(modelLoader, 'transformer', loraCollectionLoader, 'transformer');
-  // Reroute transformer connections through the LoRA collection loader
+  g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
+  // Reroute model connections through the LoRA collection loader
   g.deleteEdgesTo(denoise, ['transformer']);

+  g.deleteEdgesTo(fluxTextEncoder, ['clip']);
   g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
+  g.addEdge(loraCollectionLoader, 'clip', fluxTextEncoder, 'clip');

   for (const lora of enabledLoRAs) {
     const { weight } = lora;
@@ -6,6 +6,7 @@ import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
 import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
 import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
 import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
+import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
 import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
 import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
 import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
@@ -18,8 +19,6 @@ import type { Invocation } from 'services/api/types';
 import { isNonRefinerMainModelConfig } from 'services/api/types';
 import { assert } from 'tsafe';

-import { addFLUXLoRAs } from './addFLUXLoRAs';
-
 const log = logger('system');

 export const buildFLUXGraph = async (
@@ -96,12 +95,12 @@ export const buildFLUXGraph = async (
   g.addEdge(modelLoader, 'transformer', noise, 'transformer');
   g.addEdge(modelLoader, 'vae', l2i, 'vae');

-  addFLUXLoRAs(state, g, noise, modelLoader);
-
   g.addEdge(modelLoader, 'clip', posCond, 'clip');
   g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
   g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');

+  addFLUXLoRAs(state, g, noise, modelLoader, posCond);
+
   g.addEdge(posCond, 'conditioning', noise, 'positive_text_conditioning');

   g.addEdge(noise, 'latents', l2i, 'latents');
@@ -5707,6 +5707,12 @@ export type components = {
       * @default null
       */
      transformer?: components["schemas"]["TransformerField"] | null;
+     /**
+      * CLIP
+      * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+      * @default null
+      */
+     clip?: components["schemas"]["CLIPField"] | null;
      /**
       * type
       * @default flux_lora_collection_loader
@@ -6391,7 +6397,7 @@ export type components = {
     };
     /**
      * FLUX LoRA
-     * @description Apply a LoRA model to a FLUX transformer.
+     * @description Apply a LoRA model to a FLUX transformer and/or text encoder.
      */
     FluxLoRALoaderInvocation: {
       /**
@@ -6428,7 +6434,13 @@ export type components = {
       * @description Transformer
       * @default null
       */
-     transformer?: components["schemas"]["TransformerField"];
+     transformer?: components["schemas"]["TransformerField"] | null;
+     /**
+      * CLIP
+      * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+      * @default null
+      */
+     clip?: components["schemas"]["CLIPField"] | null;
      /**
       * type
       * @default flux_lora_loader
@@ -6448,6 +6460,12 @@ export type components = {
       * @default null
       */
      transformer: components["schemas"]["TransformerField"] | null;
+     /**
+      * CLIP
+      * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+      * @default null
+      */
+     clip: components["schemas"]["CLIPField"] | null;
      /**
       * type
       * @default flux_lora_loader_output
(File diff suppressed because it is too large.)
@@ -5,7 +5,9 @@ import torch
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.util import params
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-    convert_flux_kohya_state_dict_to_invoke_format,
+    FLUX_KOHYA_CLIP_PREFIX,
+    FLUX_KOHYA_TRANFORMER_PREFIX,
+    _convert_flux_transformer_kohya_state_dict_to_invoke_format,
     is_state_dict_likely_in_flux_kohya_format,
     lora_model_from_flux_kohya_state_dict,
 )
@@ -15,13 +17,17 @@ from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
     state_dict_keys as flux_kohya_state_dict_keys,
 )
+from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import (
+    state_dict_keys as flux_kohya_te1_state_dict_keys,
+)
 from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


-def test_is_state_dict_likely_in_flux_kohya_format_true():
+@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
+def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
     """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)

     assert is_state_dict_likely_in_flux_kohya_format(state_dict)

@@ -34,11 +40,11 @@ def test_is_state_dict_likely_in_flux_kohya_format_false():
     assert not is_state_dict_likely_in_flux_kohya_format(state_dict)


-def test_convert_flux_kohya_state_dict_to_invoke_format():
+def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
     # Construct state_dict from state_dict_keys.
     state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

-    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
+    converted_state_dict = _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)

     # Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
     # .alpha suffixes).
@@ -65,29 +71,33 @@ def test_convert_flux_kohya_state_dict_to_invoke_format():
         raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")


-def test_convert_flux_kohya_state_dict_to_invoke_format_error():
-    """Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
-    unexpected keys.
+def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
+    """Test that an error is raised by _convert_flux_transformer_kohya_state_dict_to_invoke_format() if the input
+    state_dict contains unexpected keys.
     """
     state_dict = {
         "unexpected_key.lora_up.weight": torch.empty(1),
     }

     with pytest.raises(ValueError):
-        convert_flux_kohya_state_dict_to_invoke_format(state_dict)
+        _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)


-def test_lora_model_from_flux_kohya_state_dict():
+@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
+def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
     """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)

     lora_model = lora_model_from_flux_kohya_state_dict(state_dict)

     # Prepare expected layer keys.
     expected_layer_keys: set[str] = set()
-    for k in flux_kohya_state_dict_keys:
-        k = k.replace("lora_unet_", "")
+    for k in sd_keys:
+        # Replace prefixes.
+        k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
+        k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
+        # Remove suffixes.
         k = k.replace(".lora_up.weight", "")
         k = k.replace(".lora_down.weight", "")
         k = k.replace(".alpha", "")