Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-04-23 03:00:31 -04:00)
Update frontend graph building logic to support FLUX LoRAs that modify the T5 encoder weights.
@@ -83,22 +83,33 @@ class FluxTextEncoderInvocation(BaseInvocation):
         assert isinstance(t5_text_encoder, T5EncoderModel)
         assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))

-        # Apply LoRA models to the T5 encoder.
-        # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
-        if t5_encoder_config.format == ModelFormat.T5Encoder:
-            # The model is non-quantized, so we can apply the LoRA weights directly into the model.
-            exit_stack.enter_context(
-                LayerPatcher.apply_smart_model_patches(
-                    model=t5_text_encoder,
-                    patches=self._t5_lora_iterator(context),
-                    prefix=FLUX_LORA_T5_PREFIX,
-                    dtype=t5_text_encoder.dtype,
-                    cached_weights=cached_weights,
-                )
-            )
+        # Determine if the model is quantized.
+        # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
+        # slower inference than direct patching, but is agnostic to the quantization format.
+        if t5_encoder_config.format in [ModelFormat.T5Encoder, ModelFormat.Diffusers]:
+            model_is_quantized = False
+        elif t5_encoder_config.format in [
+            ModelFormat.BnbQuantizedLlmInt8b,
+            ModelFormat.BnbQuantizednf4b,
+            ModelFormat.GGUFQuantized,
+        ]:
+            model_is_quantized = True
+        else:
+            raise ValueError(f"Unsupported model format: {t5_encoder_config.format}")
+
+        # Apply LoRA models to the T5 encoder.
+        # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+        exit_stack.enter_context(
+            LayerPatcher.apply_smart_model_patches(
+                model=t5_text_encoder,
+                patches=self._t5_lora_iterator(context),
+                prefix=FLUX_LORA_T5_PREFIX,
+                dtype=t5_text_encoder.dtype,
+                cached_weights=cached_weights,
+                force_sidecar_patching=model_is_quantized,
+            )
+        )

         t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

         context.util.signal_progress("Running T5 encoder")

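Aside (not part of the diff): the hunk above boils down to a format-based choice between merging LoRA weights directly into the model and forcing sidecar patching. Below is a minimal, self-contained sketch of that decision; the ModelFormat stand-in enum, its member values, and the helper name are hypothetical illustrations, not InvokeAI's actual definitions.

from enum import Enum


class ModelFormat(str, Enum):
    # Hypothetical stand-ins for the formats named in the diff; the real enum lives
    # in InvokeAI's model manager and its values may differ.
    T5Encoder = "t5_encoder"
    Diffusers = "diffusers"
    BnbQuantizedLlmInt8b = "bnb_quantized_llm_int8b"
    BnbQuantizednf4b = "bnb_quantized_nf4b"
    GGUFQuantized = "gguf_quantized"


def should_force_sidecar_patching(fmt: ModelFormat) -> bool:
    # Non-quantized formats can have LoRA weights patched directly into the model.
    if fmt in (ModelFormat.T5Encoder, ModelFormat.Diffusers):
        return False
    # Quantized formats keep packed weights, so the LoRA runs as sidecar layers:
    # slower at inference time, but independent of the quantization scheme.
    if fmt in (
        ModelFormat.BnbQuantizedLlmInt8b,
        ModelFormat.BnbQuantizednf4b,
        ModelFormat.GGUFQuantized,
    ):
        return True
    raise ValueError(f"Unsupported model format: {fmt}")


print(should_force_sidecar_patching(ModelFormat.Diffusers))      # False
print(should_force_sidecar_patching(ModelFormat.GGUFQuantized))  # True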
@@ -94,7 +94,7 @@ class DoRALayer(LoRALayerBase):
             # If any of the original parameters are on the 'meta' device, we assume this is because the base model is in
             # a quantization format that doesn't allow easy dequantization.
             raise RuntimeError(
-                "The base model quantization format (likely bitsandbytes) is not supported for DoRA patches."
+                "The base model quantization format (likely bitsandbytes) is not compatible with DoRA patches."
             )

         scale = self.scale()

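Aside (not part of the diff): the comment above treats parameters on the 'meta' device as a signal that the base weights sit inside a quantization container that cannot be dequantized in place. A hedged, standalone sketch of that kind of check follows; the helper name is made up for illustration and is not InvokeAI's API.

import torch


def has_meta_params(module: torch.nn.Module) -> bool:
    # A 'meta' parameter has no real storage. In practice this usually means the
    # weights are packed inside a quantized container (e.g. bitsandbytes), so an
    # in-place DoRA-style patch has nothing concrete to read or write.
    return any(p.device.type == "meta" for p in module.parameters())


# Example: a module materialized on the meta device reports True
# (uses the torch.device context manager available in PyTorch 2.x).
with torch.device("meta"):
    layer = torch.nn.Linear(8, 8)
print(has_meta_params(layer))  # True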
@@ -35,11 +35,13 @@ export const addFLUXLoRAs = (
   // Use model loader as transformer input
   g.addEdge(modelLoader, 'transformer', loraCollectionLoader, 'transformer');
   g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
+  g.addEdge(modelLoader, 't5_encoder', loraCollectionLoader, 't5_encoder');
   // Reroute model connections through the LoRA collection loader
   g.deleteEdgesTo(denoise, ['transformer']);
-  g.deleteEdgesTo(fluxTextEncoder, ['clip']);
+  g.deleteEdgesTo(fluxTextEncoder, ['clip', 't5_encoder']);
   g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
   g.addEdge(loraCollectionLoader, 'clip', fluxTextEncoder, 'clip');
+  g.addEdge(loraCollectionLoader, 't5_encoder', fluxTextEncoder, 't5_encoder');

   for (const lora of enabledLoRAs) {
     const { weight } = lora;

@@ -6194,6 +6194,12 @@ export type components = {
             * @default null
             */
            clip?: components["schemas"]["CLIPField"] | null;
+           /**
+            * T5 Encoder
+            * @description T5 tokenizer and text encoder
+            * @default null
+            */
+           t5_encoder?: components["schemas"]["T5EncoderField"] | null;
            /**
             * type
             * @default flux_lora_collection_loader
@@ -7336,6 +7342,12 @@ export type components = {
             * @default null
             */
            clip?: components["schemas"]["CLIPField"] | null;
+           /**
+            * T5 Encoder
+            * @description T5 tokenizer and text encoder
+            * @default null
+            */
+           t5_encoder?: components["schemas"]["T5EncoderField"] | null;
            /**
             * type
             * @default flux_lora_loader
@@ -7361,6 +7373,12 @@ export type components = {
             * @default null
             */
            clip: components["schemas"]["CLIPField"] | null;
+           /**
+            * T5 Encoder
+            * @description T5 tokenizer and text encoder
+            * @default null
+            */
+           t5_encoder: components["schemas"]["T5EncoderField"] | null;
            /**
             * type
             * @default flux_lora_loader_output
@@ -18345,6 +18363,11 @@ export type components = {
            tokenizer: components["schemas"]["ModelIdentifierField"];
            /** @description Info to load text_encoder submodel */
            text_encoder: components["schemas"]["ModelIdentifierField"];
+           /**
+            * Loras
+            * @description LoRAs to apply on model loading
+            */
+           loras: components["schemas"]["LoRAField"][];
        };
        /** TBLR */
        TBLR: {