Add support for FLUX T5 text encoder LoRA models to invocations.

This commit is contained in:
Ryan Dick
2024-09-26 21:28:25 +00:00
committed by Kent Keirsey
parent d332d81866
commit 249da858df
2 changed files with 57 additions and 14 deletions

View File

@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@@ -20,6 +20,9 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
t5_encoder: Optional[T5EncoderField] = OutputField(
default=None, description=FieldDescriptions.t5_encoder, title="T5Encoder"
)
@invocation(
@@ -27,21 +30,28 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer."""
"""Apply a LoRA model to a FLUX transformer and/or T5 encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField = InputField(
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
@@ -49,18 +59,33 @@ class FluxLoRALoaderInvocation(BaseInvocation):
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
# Check for existing LoRAs with the same key.
if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.')
transformer = self.transformer.model_copy(deep=True)
transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
output = FluxLoRALoaderOutput()
# Attach LoRA layers to the models.
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.t5_encoder is not None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
)
return FluxLoRALoaderOutput(transformer=transformer)
return output
@invocation(
@@ -68,7 +93,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
title="FLUX LoRA Collection Loader",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):
@@ -84,6 +109,18 @@ class FLUXLoRACollectionLoader(BaseInvocation):
input=Input.Connection,
title="Transformer",
)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
output = FluxLoRALoaderOutput()
@@ -106,4 +143,9 @@ class FLUXLoRACollectionLoader(BaseInvocation):
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(lora)
if self.t5_encoder is not None:
if output.t5_encoder is None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(lora)
return output

View File

@@ -75,6 +75,7 @@ class TransformerField(BaseModel):
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class VAEField(BaseModel):
@@ -205,7 +206,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)