Compare commits

...

8 Commits

10 changed files with 1718 additions and 56 deletions

View File

@@ -30,6 +30,7 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
@@ -208,7 +209,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 LoRAPatcher.apply_lora_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
-                    prefix="",
+                    prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
                     cached_weights=cached_weights,
                 )
             )
@@ -219,7 +220,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 LoRAPatcher.apply_lora_sidecar_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
-                    prefix="",
+                    prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
                     dtype=inference_dtype,
                 )
            )
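Note on the prefix change above: the conversion utilities (see the flux_kohya_lora_conversion_utils diff further down) now namespace LoRA layer keys with "lora_transformer-" and "lora_clip-" prefixes, so the denoise invocation patches the transformer with only its own subset of layers. Below is a minimal sketch of that filtering idea, assuming the patcher keeps only the layers whose keys start with the given prefix and strips it to recover the module path; this is an illustration, not the LoRAPatcher implementation.

# Illustrative only: assumes prefix-based routing of LoRA layers, as the key
# prefixes introduced in this compare suggest. Not the LoRAPatcher implementation.
def select_layers_for_model(layers: dict[str, object], prefix: str) -> dict[str, object]:
    """Keep the layers intended for this sub-model and drop the routing prefix."""
    return {k.removeprefix(prefix): v for k, v in layers.items() if k.startswith(prefix)}

# A combined FLUX Kohya LoRA holds both transformer and CLIP layers:
layers = {
    "lora_transformer-double_blocks.0.img_attn.proj": "transformer layer",
    "lora_clip-text_model.encoder.layers.0.mlp.fc1": "clip layer",
}
assert select_layers_for_model(layers, "lora_transformer-") == {"double_blocks.0.img_attn.proj": "transformer layer"}
assert select_layers_for_model(layers, "lora_clip-") == {"text_model.encoder.layers.0.mlp.fc1": "clip layer"}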

View File

@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation_output,
 )
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
-from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
+from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.config import BaseModelType
@@ -20,6 +20,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
     transformer: Optional[TransformerField] = OutputField(
         default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
     )
+    clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
 
 @invocation(
@@ -27,21 +28,28 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
     title="FLUX LoRA",
     tags=["lora", "model", "flux"],
     category="model",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class FluxLoRALoaderInvocation(BaseInvocation):
-    """Apply a LoRA model to a FLUX transformer."""
+    """Apply a LoRA model to a FLUX transformer and/or text encoder."""
 
     lora: ModelIdentifierField = InputField(
         description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
     )
     weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
-    transformer: TransformerField = InputField(
+    transformer: TransformerField | None = InputField(
+        default=None,
         description=FieldDescriptions.transformer,
         input=Input.Connection,
         title="FLUX Transformer",
     )
+    clip: CLIPField | None = InputField(
+        default=None,
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
 
     def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
         lora_key = self.lora.key
@@ -49,18 +57,33 @@ class FluxLoRALoaderInvocation(BaseInvocation):
         if not context.models.exists(lora_key):
             raise ValueError(f"Unknown lora: {lora_key}!")
 
-        if any(lora.lora.key == lora_key for lora in self.transformer.loras):
+        # Check for existing LoRAs with the same key.
+        if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
             raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
+        if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
+            raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
 
-        transformer = self.transformer.model_copy(deep=True)
-        transformer.loras.append(
-            LoRAField(
-                lora=self.lora,
-                weight=self.weight,
-            )
-        )
+        output = FluxLoRALoaderOutput()
 
-        return FluxLoRALoaderOutput(transformer=transformer)
+        # Attach LoRA layers to the models.
+        if self.transformer is not None:
+            output.transformer = self.transformer.model_copy(deep=True)
+            output.transformer.loras.append(
+                LoRAField(
+                    lora=self.lora,
+                    weight=self.weight,
+                )
+            )
+        if self.clip is not None:
+            output.clip = self.clip.model_copy(deep=True)
+            output.clip.loras.append(
+                LoRAField(
+                    lora=self.lora,
+                    weight=self.weight,
+                )
+            )
+
+        return output
 
 @invocation(
@@ -68,7 +91,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
     title="FLUX LoRA Collection Loader",
     tags=["lora", "model", "flux"],
     category="model",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class FLUXLoRACollectionLoader(BaseInvocation):
@@ -84,6 +107,12 @@ class FLUXLoRACollectionLoader(BaseInvocation):
         input=Input.Connection,
         title="Transformer",
     )
+    clip: CLIPField | None = InputField(
+        default=None,
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
 
     def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
         output = FluxLoRALoaderOutput()
@@ -106,4 +135,9 @@ class FLUXLoRACollectionLoader(BaseInvocation):
                 output.transformer = self.transformer.model_copy(deep=True)
                 output.transformer.loras.append(lora)
 
+            if self.clip is not None:
+                if output.clip is None:
+                    output.clip = self.clip.model_copy(deep=True)
+                output.clip.loras.append(lora)
+
         return output
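The loader changes above follow a copy-then-append pattern: each connected sub-model field is deep-copied, the LoRA reference is appended to its loras list, and unconnected fields stay None in the output. A minimal, self-contained sketch of that pattern using hypothetical stand-in types (LoRARef, SubModelField), not the InvokeAI field classes:

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class LoRARef:  # hypothetical stand-in for LoRAField
    key: str
    weight: float


@dataclass
class SubModelField:  # hypothetical stand-in for TransformerField / CLIPField
    loras: list[LoRARef] = field(default_factory=list)


def attach_lora(sub_model: SubModelField | None, lora_key: str, weight: float) -> SubModelField | None:
    """Return a deep copy with the LoRA appended, or None if the field is not connected."""
    if sub_model is None:
        return None
    if any(ref.key == lora_key for ref in sub_model.loras):
        raise ValueError(f'LoRA "{lora_key}" already applied.')
    patched = deepcopy(sub_model)
    patched.loras.append(LoRARef(key=lora_key, weight=weight))
    return patched


# The original input is left untouched; only the output copy carries the new LoRA.
transformer = SubModelField()
patched = attach_lora(transformer, "my-flux-lora", 0.75)
assert transformer.loras == [] and patched is not None and len(patched.loras) == 1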

View File

@@ -1,4 +1,5 @@
-from typing import Literal
+from contextlib import ExitStack
+from typing import Iterator, Literal, Tuple
 
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@@ -9,6 +10,10 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -17,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
     title="FLUX Text Encoding",
     tags=["prompt", "conditioning", "flux"],
     category="conditioning",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -78,15 +83,42 @@ class FluxTextEncoderInvocation(BaseInvocation):
         prompt = [self.prompt]
 
         with (
-            clip_text_encoder_info as clip_text_encoder,
+            clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
             clip_tokenizer_info as clip_tokenizer,
+            ExitStack() as exit_stack,
         ):
             assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(clip_tokenizer, CLIPTokenizer)
 
+            clip_text_encoder_config = clip_text_encoder_info.config
+            assert clip_text_encoder_config is not None
+
+            # Apply LoRA models to the CLIP encoder.
+            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+            if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
+                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_patches(
+                        model=clip_text_encoder,
+                        patches=self._clip_lora_iterator(context),
+                        prefix=FLUX_KOHYA_CLIP_PREFIX,
+                        cached_weights=cached_weights,
+                    )
+                )
+            else:
+                # There are currently no supported CLIP quantized models. Add support here if needed.
+                raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
+
             clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
 
             pooled_prompt_embeds = clip_encoder(prompt)
 
             assert isinstance(pooled_prompt_embeds, torch.Tensor)
             return pooled_prompt_embeds
+
+    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+        for lora in self.clip.loras:
+            lora_info = context.models.load(lora.lora)
+            assert isinstance(lora_info.model, LoRAModelRaw)
+            yield (lora_info.model, lora.weight)
+            del lora_info

View File

@@ -7,14 +7,25 @@ from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 
-# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
+# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
 # Example keys:
 # lora_unet_double_blocks_0_img_attn_proj.alpha
 # lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
 # lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
-FLUX_KOHYA_KEY_REGEX = (
+FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
     r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
 )
+
+# A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
+# Example keys:
+# lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
+# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight
+# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
+FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
+
+# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
+FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
+FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"
 
 def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
@@ -23,7 +34,10 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
     This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
     perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
     """
-    return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
+    return all(
+        re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
+        for k in state_dict.keys()
+    )
 
 def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
@@ -35,13 +49,27 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
             grouped_state_dict[layer_name] = {}
         grouped_state_dict[layer_name][param_name] = value
 
-    # Convert the state dict to the InvokeAI format.
-    grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
+    # Split the grouped state dict into transformer and CLIP state dicts.
+    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    for layer_name, layer_state_dict in grouped_state_dict.items():
+        if layer_name.startswith("lora_unet"):
+            transformer_grouped_sd[layer_name] = layer_state_dict
+        elif layer_name.startswith("lora_te1"):
+            clip_grouped_sd[layer_name] = layer_state_dict
+        else:
+            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
+
+    # Convert the state dicts to the InvokeAI format.
+    transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
+    clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
 
     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
-    for layer_key, layer_state_dict in grouped_state_dict.items():
-        layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+    for layer_key, layer_state_dict in transformer_grouped_sd.items():
+        layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+    for layer_key, layer_state_dict in clip_grouped_sd.items():
+        layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
 
     # Create and return the LoRAModelRaw.
     return LoRAModelRaw(layers=layers)
@@ -50,16 +78,34 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
 T = TypeVar("T")
 
-def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
-    """Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
+def _convert_flux_clip_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a CLIP LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by
+    InvokeAI.
+
+    Example key conversions:
+
+    "lora_te1_text_model_encoder_layers_0_mlp_fc1" -> "text_model.encoder.layers.0.mlp.fc1",
+    "lora_te1_text_model_encoder_layers_0_self_attn_k_proj" -> "text_model.encoder.layers.0.self_attn.k_proj"
+    """
+    converted_sd: dict[str, T] = {}
+    for k, v in state_dict.items():
+        match = re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
+        if match:
+            new_key = f"text_model.encoder.layers.{match.group(1)}.{match.group(2)}.{match.group(3)}"
+            converted_sd[new_key] = v
+        else:
+            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
+
+    return converted_sd
+
+
+def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a FLUX tranformer LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally
+    by InvokeAI.
 
     Example key conversions:
 
     "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
     "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
     """
 
     def replace_func(match: re.Match[str]) -> str:
@@ -70,9 +116,9 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) ->
     converted_dict: dict[str, T] = {}
     for k, v in state_dict.items():
-        match = re.match(FLUX_KOHYA_KEY_REGEX, k)
+        match = re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
         if match:
-            new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
+            new_key = re.sub(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, replace_func, k)
             converted_dict[new_key] = v
         else:
             raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")

View File

@@ -8,7 +8,8 @@ export const addFLUXLoRAs = (
   state: RootState,
   g: Graph,
   denoise: Invocation<'flux_denoise'>,
-  modelLoader: Invocation<'flux_model_loader'>
+  modelLoader: Invocation<'flux_model_loader'>,
+  fluxTextEncoder: Invocation<'flux_text_encoder'>
 ): void => {
   const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'flux');
   const loraCount = enabledLoRAs.length;
@@ -20,7 +21,7 @@ export const addFLUXLoRAs = (
   const loraMetadata: S['LoRAMetadataField'][] = [];
 
   // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
-  // each LoRA to the UNet and CLIP.
+  // each LoRA to the transformer and text encoders.
   const loraCollector = g.addNode({
     id: getPrefixedId('lora_collector'),
     type: 'collect',
@@ -33,10 +34,12 @@ export const addFLUXLoRAs = (
   g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
 
   // Use model loader as transformer input
   g.addEdge(modelLoader, 'transformer', loraCollectionLoader, 'transformer');
+  g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
 
-  // Reroute transformer connections through the LoRA collection loader
+  // Reroute model connections through the LoRA collection loader
   g.deleteEdgesTo(denoise, ['transformer']);
+  g.deleteEdgesTo(fluxTextEncoder, ['clip']);
   g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
+  g.addEdge(loraCollectionLoader, 'clip', fluxTextEncoder, 'clip');
 
   for (const lora of enabledLoRAs) {
     const { weight } = lora;

View File

@@ -6,6 +6,7 @@ import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSe
 import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
 import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
 import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
+import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
 import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
 import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
 import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
@@ -18,8 +19,6 @@ import type { Invocation } from 'services/api/types';
 import { isNonRefinerMainModelConfig } from 'services/api/types';
 import { assert } from 'tsafe';
 
-import { addFLUXLoRAs } from './addFLUXLoRAs';
-
 const log = logger('system');
 
 export const buildFLUXGraph = async (
@@ -96,12 +95,12 @@ export const buildFLUXGraph = async (
   g.addEdge(modelLoader, 'transformer', noise, 'transformer');
   g.addEdge(modelLoader, 'vae', l2i, 'vae');
 
-  addFLUXLoRAs(state, g, noise, modelLoader);
-
   g.addEdge(modelLoader, 'clip', posCond, 'clip');
   g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
   g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
 
+  addFLUXLoRAs(state, g, noise, modelLoader, posCond);
+
   g.addEdge(posCond, 'conditioning', noise, 'positive_text_conditioning');
 
   g.addEdge(noise, 'latents', l2i, 'latents');

View File

@@ -5707,6 +5707,12 @@ export type components = {
       * @default null
       */
      transformer?: components["schemas"]["TransformerField"] | null;
+     /**
+      * CLIP
+      * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+      * @default null
+      */
+     clip?: components["schemas"]["CLIPField"] | null;
      /**
       * type
       * @default flux_lora_collection_loader
@@ -6391,7 +6397,7 @@
    };
    /**
     * FLUX LoRA
-    * @description Apply a LoRA model to a FLUX transformer.
+    * @description Apply a LoRA model to a FLUX transformer and/or text encoder.
     */
    FluxLoRALoaderInvocation: {
      /**
@@ -6428,7 +6434,13 @@
       * @description Transformer
       * @default null
       */
-     transformer?: components["schemas"]["TransformerField"];
+     transformer?: components["schemas"]["TransformerField"] | null;
+     /**
+      * CLIP
+      * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+      * @default null
+      */
+     clip?: components["schemas"]["CLIPField"] | null;
      /**
       * type
       * @default flux_lora_loader
@@ -6448,6 +6460,12 @@
       * @default null
       */
      transformer: components["schemas"]["TransformerField"] | null;
+     /**
+      * CLIP
+      * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
+      * @default null
+      */
+     clip: components["schemas"]["CLIPField"] | null;
      /**
       * type
       * @default flux_lora_loader_output

View File

@@ -0,0 +1,386 @@
# A sample state dict in the Diffusers FLUX LoRA format with non-standard PEFT weight keys.
# I.e. key suffixes of `.lora.down.weight` and `.lora.up.weight` instead of `.lora_A.weight` and `.lora_B.weight`.
# These keys are based on the LoRA model here:
# https://civitai.com/models/684810/flux1-dev-cctv-mania
state_dict_keys = [
"transformer.single_transformer_blocks.0.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.0.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.0.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.0.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.1.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.1.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.1.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.1.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.1.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.1.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.10.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.10.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.10.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.10.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.10.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.10.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.11.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.11.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.11.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.11.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.11.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.11.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.12.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.12.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.12.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.12.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.12.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.12.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.13.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.13.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.13.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.13.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.13.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.13.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.14.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.14.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.14.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.14.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.14.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.14.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.15.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.15.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.15.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.15.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.15.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.15.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.16.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.16.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.16.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.16.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.16.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.16.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.17.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.17.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.17.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.17.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.17.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.17.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.18.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.18.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.18.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.18.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.18.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.18.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.19.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.19.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.19.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.19.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.19.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.19.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.2.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.2.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.2.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.2.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.2.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.2.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.20.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.20.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.20.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.20.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.20.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.20.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.21.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.21.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.21.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.21.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.21.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.21.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.22.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.22.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.22.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.22.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.22.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.22.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.23.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.23.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.23.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.23.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.23.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.23.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.24.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.24.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.24.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.24.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.24.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.24.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.25.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.25.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.25.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.25.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.25.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.25.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.26.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.26.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.26.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.26.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.26.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.26.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.27.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.27.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.27.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.27.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.27.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.27.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.28.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.28.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.28.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.28.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.28.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.28.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.29.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.29.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.29.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.29.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.29.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.29.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.3.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.3.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.3.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.3.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.3.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.3.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.30.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.30.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.30.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.30.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.30.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.30.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.31.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.31.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.31.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.31.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.31.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.31.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.32.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.32.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.32.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.32.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.32.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.32.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.33.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.33.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.33.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.33.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.33.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.33.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.34.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.34.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.34.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.34.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.34.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.34.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.35.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.35.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.35.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.35.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.35.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.35.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.36.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.36.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.36.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.36.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.36.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.36.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.37.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.37.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.37.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.37.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.37.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.37.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.4.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.4.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.4.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.4.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.4.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.4.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.5.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.5.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.5.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.5.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.5.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.5.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.6.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.6.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.6.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.6.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.6.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.6.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.7.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.7.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.7.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.7.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.7.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.7.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.8.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.8.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.8.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.8.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.8.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.8.attn.to_v.lora.up.weight",
"transformer.single_transformer_blocks.9.attn.to_k.lora.down.weight",
"transformer.single_transformer_blocks.9.attn.to_k.lora.up.weight",
"transformer.single_transformer_blocks.9.attn.to_q.lora.down.weight",
"transformer.single_transformer_blocks.9.attn.to_q.lora.up.weight",
"transformer.single_transformer_blocks.9.attn.to_v.lora.down.weight",
"transformer.single_transformer_blocks.9.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.0.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.0.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.0.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.0.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.0.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.0.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.0.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.0.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.1.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.1.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.1.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.1.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.1.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.1.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.1.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.1.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.10.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.10.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.10.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.10.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.10.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.10.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.10.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.10.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.11.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.11.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.11.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.11.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.11.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.11.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.11.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.11.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.12.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.12.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.12.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.12.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.12.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.12.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.12.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.12.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.13.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.13.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.13.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.13.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.13.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.13.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.13.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.13.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.14.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.14.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.14.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.14.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.14.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.14.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.14.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.14.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.15.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.15.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.15.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.15.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.15.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.15.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.15.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.15.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.16.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.16.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.16.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.16.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.16.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.16.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.16.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.16.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.17.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.17.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.17.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.17.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.17.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.17.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.17.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.17.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.18.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.18.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.18.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.18.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.18.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.18.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.18.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.18.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.2.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.2.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.2.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.2.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.2.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.2.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.2.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.2.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.3.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.3.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.3.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.3.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.3.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.3.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.3.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.3.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.4.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.4.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.4.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.4.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.4.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.4.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.4.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.4.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.5.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.5.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.5.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.5.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.5.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.5.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.5.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.5.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.6.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.6.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.6.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.6.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.6.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.6.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.6.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.6.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.7.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.7.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.7.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.7.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.7.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.7.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.7.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.7.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.8.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.8.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.8.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.8.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.8.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.8.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.8.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.8.attn.to_v.lora.up.weight",
"transformer.transformer_blocks.9.attn.to_k.lora.down.weight",
"transformer.transformer_blocks.9.attn.to_k.lora.up.weight",
"transformer.transformer_blocks.9.attn.to_out.0.lora.down.weight",
"transformer.transformer_blocks.9.attn.to_out.0.lora.up.weight",
"transformer.transformer_blocks.9.attn.to_q.lora.down.weight",
"transformer.transformer_blocks.9.attn.to_q.lora.up.weight",
"transformer.transformer_blocks.9.attn.to_v.lora.down.weight",
"transformer.transformer_blocks.9.attn.to_v.lora.up.weight",
]

View File

@@ -5,7 +5,9 @@ import torch
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.util import params
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-    convert_flux_kohya_state_dict_to_invoke_format,
+    FLUX_KOHYA_CLIP_PREFIX,
+    FLUX_KOHYA_TRANFORMER_PREFIX,
+    _convert_flux_transformer_kohya_state_dict_to_invoke_format,
     is_state_dict_likely_in_flux_kohya_format,
     lora_model_from_flux_kohya_state_dict,
 )
@@ -15,13 +17,17 @@ from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
     state_dict_keys as flux_kohya_state_dict_keys,
 )
+from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import (
+    state_dict_keys as flux_kohya_te1_state_dict_keys,
+)
 from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
 
-def test_is_state_dict_likely_in_flux_kohya_format_true():
+@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
+def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
     """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)
 
     assert is_state_dict_likely_in_flux_kohya_format(state_dict)
@@ -34,11 +40,11 @@ def test_is_state_dict_likely_in_flux_kohya_format_false():
     assert not is_state_dict_likely_in_flux_kohya_format(state_dict)
 
-def test_convert_flux_kohya_state_dict_to_invoke_format():
+def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
     # Construct state_dict from state_dict_keys.
     state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
 
-    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
+    converted_state_dict = _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)
 
     # Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
     # .alpha suffixes).
@@ -65,29 +71,33 @@ def test_convert_flux_kohya_state_dict_to_invoke_format():
         raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
 
-def test_convert_flux_kohya_state_dict_to_invoke_format_error():
-    """Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
-    unexpected keys.
+def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
+    """Test that an error is raised by _convert_flux_transformer_kohya_state_dict_to_invoke_format() if the input
+    state_dict contains unexpected keys.
     """
     state_dict = {
         "unexpected_key.lora_up.weight": torch.empty(1),
     }
 
     with pytest.raises(ValueError):
-        convert_flux_kohya_state_dict_to_invoke_format(state_dict)
+        _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict)
 
-def test_lora_model_from_flux_kohya_state_dict():
+@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
+def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
     """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)
 
     lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
 
     # Prepare expected layer keys.
     expected_layer_keys: set[str] = set()
-    for k in flux_kohya_state_dict_keys:
-        k = k.replace("lora_unet_", "")
+    for k in sd_keys:
+        # Replace prefixes.
+        k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
+        k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
+        # Remove suffixes.
         k = k.replace(".lora_up.weight", "")
         k = k.replace(".lora_down.weight", "")
         k = k.replace(".alpha", "")