mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 20:58:04 -05:00
Add support for FLUX T5 text encoder LoRA models to invocations.
This commit is contained in:
@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
|
||||
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
|
||||
@@ -20,6 +20,9 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
transformer: Optional[TransformerField] = OutputField(
|
||||
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
|
||||
)
|
||||
t5_encoder: Optional[T5EncoderField] = OutputField(
|
||||
default=None, description=FieldDescriptions.t5_encoder, title="T5Encoder"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -27,21 +30,28 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
title="FLUX LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer."""
|
||||
"""Apply a LoRA model to a FLUX transformer and/or T5 encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
transformer: TransformerField = InputField(
|
||||
transformer: TransformerField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="FLUX Transformer",
|
||||
)
|
||||
t5_encoder: T5EncoderField | None = InputField(
|
||||
default=None,
|
||||
title="T5Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
@@ -49,18 +59,33 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(lora_key):
|
||||
raise ValueError(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
|
||||
# Check for existing LoRAs with the same key.
|
||||
if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
|
||||
if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.')
|
||||
|
||||
transformer = self.transformer.model_copy(deep=True)
|
||||
transformer.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
output = FluxLoRALoaderOutput()
|
||||
|
||||
# Attach LoRA layers to the models.
|
||||
if self.transformer is not None:
|
||||
output.transformer = self.transformer.model_copy(deep=True)
|
||||
output.transformer.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
if self.t5_encoder is not None:
|
||||
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
|
||||
output.t5_encoder.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return FluxLoRALoaderOutput(transformer=transformer)
|
||||
return output
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -68,7 +93,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
title="FLUX LoRA Collection Loader",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
@@ -84,6 +109,18 @@ class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
transformer: TransformerField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="FLUX Transformer",
|
||||
)
|
||||
t5_encoder: T5EncoderField | None = InputField(
|
||||
default=None,
|
||||
title="T5Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
|
||||
output = FluxLoRALoaderOutput()
|
||||
@@ -106,4 +143,9 @@ class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
output.transformer = self.transformer.model_copy(deep=True)
|
||||
output.transformer.loras.append(lora)
|
||||
|
||||
if self.t5_encoder is not None:
|
||||
if output.t5_encoder is None:
|
||||
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
|
||||
output.t5_encoder.loras.append(lora)
|
||||
|
||||
return output
|
||||
|
||||
@@ -75,6 +75,7 @@ class TransformerField(BaseModel):
|
||||
class T5EncoderField(BaseModel):
|
||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class VAEField(BaseModel):
|
||||
@@ -205,7 +206,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user