From 249da858df4fb480f510dc1053129008a1cfca9f Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 26 Sep 2024 21:28:25 +0000 Subject: [PATCH] Add support for FLUX T5 text encoder LoRA models to invocations. --- invokeai/app/invocations/flux_lora_loader.py | 68 ++++++++++++++++---- invokeai/app/invocations/model.py | 3 +- 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/invokeai/app/invocations/flux_lora_loader.py b/invokeai/app/invocations/flux_lora_loader.py index 46f593ea9f..e82556c74e 100644 --- a/invokeai/app/invocations/flux_lora_loader.py +++ b/invokeai/app/invocations/flux_lora_loader.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import ( invocation_output, ) from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType -from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField +from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, T5EncoderField, TransformerField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import BaseModelType @@ -20,6 +20,9 @@ class FluxLoRALoaderOutput(BaseInvocationOutput): transformer: Optional[TransformerField] = OutputField( default=None, description=FieldDescriptions.transformer, title="FLUX Transformer" ) + t5_encoder: Optional[T5EncoderField] = OutputField( + default=None, description=FieldDescriptions.t5_encoder, title="T5Encoder" + ) @invocation( @@ -27,21 +30,28 @@ class FluxLoRALoaderOutput(BaseInvocationOutput): title="FLUX LoRA", tags=["lora", "model", "flux"], category="model", - version="1.0.0", + version="1.1.0", classification=Classification.Prototype, ) class FluxLoRALoaderInvocation(BaseInvocation): - """Apply a LoRA model to a FLUX transformer.""" + """Apply a LoRA model to a FLUX transformer and/or T5 encoder.""" lora: ModelIdentifierField = InputField( description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) - transformer: TransformerField = InputField( + transformer: TransformerField | None = InputField( + default=None, description=FieldDescriptions.transformer, input=Input.Connection, title="FLUX Transformer", ) + t5_encoder: T5EncoderField | None = InputField( + default=None, + title="T5Encoder", + description=FieldDescriptions.t5_encoder, + input=Input.Connection, + ) def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput: lora_key = self.lora.key @@ -49,18 +59,33 @@ class FluxLoRALoaderInvocation(BaseInvocation): if not context.models.exists(lora_key): raise ValueError(f"Unknown lora: {lora_key}!") - if any(lora.lora.key == lora_key for lora in self.transformer.loras): + # Check for existing LoRAs with the same key. + if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras): raise ValueError(f'LoRA "{lora_key}" already applied to transformer.') + if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras): + raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.') - transformer = self.transformer.model_copy(deep=True) - transformer.loras.append( - LoRAField( - lora=self.lora, - weight=self.weight, + output = FluxLoRALoaderOutput() + + # Attach LoRA layers to the models. + if self.transformer is not None: + output.transformer = self.transformer.model_copy(deep=True) + output.transformer.loras.append( + LoRAField( + lora=self.lora, + weight=self.weight, + ) + ) + if self.t5_encoder is not None: + output.t5_encoder = self.t5_encoder.model_copy(deep=True) + output.t5_encoder.loras.append( + LoRAField( + lora=self.lora, + weight=self.weight, + ) ) - ) - return FluxLoRALoaderOutput(transformer=transformer) + return output @invocation( @@ -68,7 +93,7 @@ class FluxLoRALoaderInvocation(BaseInvocation): title="FLUX LoRA Collection Loader", tags=["lora", "model", "flux"], category="model", - version="1.0.0", + version="1.1.0", classification=Classification.Prototype, ) class FLUXLoRACollectionLoader(BaseInvocation): @@ -84,6 +109,18 @@ class FLUXLoRACollectionLoader(BaseInvocation): input=Input.Connection, title="Transformer", ) + transformer: TransformerField | None = InputField( + default=None, + description=FieldDescriptions.transformer, + input=Input.Connection, + title="FLUX Transformer", + ) + t5_encoder: T5EncoderField | None = InputField( + default=None, + title="T5Encoder", + description=FieldDescriptions.t5_encoder, + input=Input.Connection, + ) def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput: output = FluxLoRALoaderOutput() @@ -106,4 +143,9 @@ class FLUXLoRACollectionLoader(BaseInvocation): output.transformer = self.transformer.model_copy(deep=True) output.transformer.loras.append(lora) + if self.t5_encoder is not None: + if output.t5_encoder is None: + output.t5_encoder = self.t5_encoder.model_copy(deep=True) + output.t5_encoder.loras.append(lora) + return output diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c0d0a4a7f7..0b87a5cd34 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -75,6 +75,7 @@ class TransformerField(BaseModel): class T5EncoderField(BaseModel): tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") + loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") class VAEField(BaseModel): @@ -205,7 +206,7 @@ class FluxModelLoaderInvocation(BaseInvocation): return FluxModelLoaderOutput( transformer=TransformerField(transformer=transformer, loras=[]), clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0), - t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder), + t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]), vae=VAEField(vae=vae), max_seq_len=max_seq_lengths[transformer_config.config_path], )