mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Minor tidy of FLUX control LoRA implementation. (mostly documentation)
This commit is contained in:
@@ -36,10 +36,8 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise ValueError(f"Unknown lora: {lora_key}!")
|
||||
if not context.models.exists(self.lora.key):
|
||||
raise ValueError(f"Unknown lora: {self.lora.key}!")
|
||||
|
||||
return FluxControlLoRALoaderOutput(
|
||||
control_lora=ControlLoRAField(
|
||||
|
||||
@@ -718,6 +718,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
|
||||
if self.control_lora:
|
||||
# Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
|
||||
# applied last.
|
||||
loras.append(self.control_lora)
|
||||
for lora in loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
|
||||
@@ -82,9 +82,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(
|
||||
tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0
|
||||
),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
|
||||
@@ -23,20 +23,14 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
|
||||
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
return all(
|
||||
re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k)
|
||||
for k in state_dict.keys()
|
||||
)
|
||||
return all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys())
|
||||
|
||||
|
||||
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
|
||||
# converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
for key, value in state_dict.items():
|
||||
key_props = key.split(".")
|
||||
# Got it loading using lora_down and lora_up but it didn't seem to match this lora's structure
|
||||
# Leaving this in since it doesn't hurt anything and may be better
|
||||
layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
|
||||
layer_name = ".".join(key_props[:layer_prop_size])
|
||||
param_name = ".".join(key_props[layer_prop_size:])
|
||||
@@ -47,7 +41,6 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
|
||||
# Create LoRA layers.
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, layer_state_dict in grouped_state_dict.items():
|
||||
# Convert to a full layer diff
|
||||
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
|
||||
if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
|
||||
layers[prefixed_key] = LoRALayer(
|
||||
@@ -60,6 +53,6 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
|
||||
elif "scale" in layer_state_dict:
|
||||
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
|
||||
else:
|
||||
raise AssertionError(f"{layer_key} not expected")
|
||||
# Create and return the LoRAModelRaw.
|
||||
raise ValueError(f"{layer_key} not expected")
|
||||
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
||||
Reference in New Issue
Block a user