Minor tidy of FLUX control LoRA implementation. (mostly documentation)

Ryan Dick
2024-12-17 00:43:13 +00:00
committed by Kent Keirsey
parent 5fcd76a712
commit a4bed7aee3
5 changed files with 9 additions and 18 deletions

View File

@@ -36,10 +36,8 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
     image: ImageField = InputField(description="The image to encode.")
     def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
-        lora_key = self.lora.key
-        if not context.models.exists(lora_key):
-            raise ValueError(f"Unknown lora: {lora_key}!")
+        if not context.models.exists(self.lora.key):
+            raise ValueError(f"Unknown lora: {self.lora.key}!")
         return FluxControlLoRALoaderOutput(
             control_lora=ControlLoRAField(

View File

@@ -718,6 +718,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
         loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
         if self.control_lora:
+            # Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
+            # applied last.
             loras.append(self.control_lora)
         for lora in loras:
             lora_info = context.models.load(lora.lora)
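The new note above is about ordering: a structural control LoRA does not just add a low-rank delta, it can replace a weight with one of a different shape, so ordinary additive LoRAs have to be merged while the original shape still holds. A minimal sketch of the failure mode, using made-up tensor shapes rather than the real FLUX dimensions:

import torch

base = torch.zeros(64, 16)        # original layer weight (made-up shape)
lora_delta = torch.zeros(64, 16)  # ordinary additive LoRA delta, same shape as the base weight
control = torch.zeros(64, 32)     # structural control LoRA supplies a wider replacement weight

# Correct order: merge the additive delta while shapes still match, then apply the
# shape-changing control patch last.
patched = base + lora_delta
patched = control

# Reversed order breaks: once the weight is (64, 32), the (64, 16) delta no longer fits.
# patched = control + lora_delta  # would raise a shape-mismatch RuntimeError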

View File

@@ -82,9 +82,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer, loras=[]),
-            clip=CLIPField(
-                tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0
-            ),
+            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
             max_seq_len=max_seq_lengths[transformer_config.config_path],

View File

@@ -23,20 +23,14 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
     This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
     perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
     """
-    return all(
-        re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k)
-        for k in state_dict.keys()
-    )
+    return all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys())
 def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
-    # converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
     # Group keys by layer.
     grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
     for key, value in state_dict.items():
         key_props = key.split(".")
-        # Got it loading using lora_down and lora_up but it didn't seem to match this lora's structure
-        # Leaving this in since it doesn't hurt anything and may be better
         layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
         layer_name = ".".join(key_props[:layer_prop_size])
         param_name = ".".join(key_props[layer_prop_size:])
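For reference, the grouping loop above derives a layer name and a parameter name from each flat checkpoint key, keeping two trailing components for lora_A/lora_B keys and one component otherwise. A standalone illustration of that split, using made-up FLUX-style keys:

example_keys = [
    "double_blocks.0.img_attn.qkv.lora_A.weight",  # -> ("double_blocks.0.img_attn.qkv", "lora_A.weight")
    "double_blocks.0.img_attn.qkv.lora_B.bias",    # -> ("double_blocks.0.img_attn.qkv", "lora_B.bias")
    "single_blocks.0.modulation.lin.scale",        # -> ("single_blocks.0.modulation.lin", "scale")
]
for key in example_keys:
    key_props = key.split(".")
    layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
    print(".".join(key_props[:layer_prop_size]), "->", ".".join(key_props[layer_prop_size:]))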
@@ -47,7 +41,6 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
     for layer_key, layer_state_dict in grouped_state_dict.items():
-        # Convert to a full layer diff
         prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
         if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
             layers[prefixed_key] = LoRALayer(
@@ -60,6 +53,6 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
elif "scale" in layer_state_dict:
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
else:
raise AssertionError(f"{layer_key} not expected")
# Create and return the LoRAModelRaw.
raise ValueError(f"{layer_key} not expected")
return LoRAModelRaw(layers=layers)
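A hypothetical end-to-end use of the two helpers in this file; the checkpoint path and the safetensors loader shown here are assumptions for illustration, not taken from this diff:

from safetensors.torch import load_file

state_dict = load_file("flux-control-lora.safetensors")  # placeholder path
if is_state_dict_likely_flux_control(state_dict):
    lora_model = lora_model_from_flux_control_state_dict(state_dict)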