mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-04 03:15:06 -05:00
feat(ui): add FLUX LoRAs to linear UI graph
This commit is contained in:
committed by
Mary Hipp Rogers
parent
d4a7e48109
commit
d651dfe138
@@ -0,0 +1,61 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { Invocation, S } from 'services/api/types';
|
||||
|
||||
export const addFLUXLoRAs = (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
denoise: Invocation<'flux_denoise'>,
|
||||
modelLoader: Invocation<'flux_model_loader'>
|
||||
): void => {
|
||||
const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'flux');
|
||||
const loraCount = enabledLoRAs.length;
|
||||
|
||||
if (loraCount === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const loraMetadata: S['LoRAMetadataField'][] = [];
|
||||
|
||||
// We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
|
||||
// each LoRA to the UNet and CLIP.
|
||||
const loraCollector = g.addNode({
|
||||
id: getPrefixedId('lora_collector'),
|
||||
type: 'collect',
|
||||
});
|
||||
const loraCollectionLoader = g.addNode({
|
||||
type: 'flux_lora_collection_loader',
|
||||
id: getPrefixedId('flux_lora_collection_loader'),
|
||||
});
|
||||
|
||||
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
|
||||
// Use model loader as transformer input
|
||||
g.addEdge(modelLoader, 'transformer', loraCollectionLoader, 'transformer');
|
||||
// Reroute transformer connections through the LoRA collection loader
|
||||
g.deleteEdgesTo(denoise, ['transformer']);
|
||||
|
||||
g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
|
||||
|
||||
for (const lora of enabledLoRAs) {
|
||||
const { weight } = lora;
|
||||
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||
|
||||
const loraSelector = g.addNode({
|
||||
type: 'lora_selector',
|
||||
id: getPrefixedId('lora_selector'),
|
||||
lora: parsedModel,
|
||||
weight,
|
||||
});
|
||||
|
||||
loraMetadata.push({
|
||||
model: parsedModel,
|
||||
weight,
|
||||
});
|
||||
|
||||
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
|
||||
}
|
||||
|
||||
g.upsertMetadata({ loras: loraMetadata });
|
||||
};
|
||||
@@ -18,6 +18,8 @@ import type { Invocation } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { addFLUXLoRAs } from './addFLUXLoRAs';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildFLUXGraph = async (
|
||||
@@ -84,6 +86,8 @@ export const buildFLUXGraph = async (
|
||||
g.addEdge(modelLoader, 'transformer', noise, 'transformer');
|
||||
g.addEdge(modelLoader, 'vae', l2i, 'vae');
|
||||
|
||||
addFLUXLoRAs(state, g, noise, modelLoader);
|
||||
|
||||
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
|
||||
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user