feat(ui): add FLUX LoRAs to linear UI graph

This commit is contained in:
Mary Hipp
2024-09-19 09:27:50 -04:00
committed by Mary Hipp Rogers
parent d4a7e48109
commit d651dfe138
3 changed files with 188 additions and 7 deletions

View File

@@ -0,0 +1,61 @@
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation, S } from 'services/api/types';
export const addFLUXLoRAs = (
state: RootState,
g: Graph,
denoise: Invocation<'flux_denoise'>,
modelLoader: Invocation<'flux_model_loader'>
): void => {
const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'flux');
const loraCount = enabledLoRAs.length;
if (loraCount === 0) {
return;
}
const loraMetadata: S['LoRAMetadataField'][] = [];
// We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
// each LoRA to the UNet and CLIP.
const loraCollector = g.addNode({
id: getPrefixedId('lora_collector'),
type: 'collect',
});
const loraCollectionLoader = g.addNode({
type: 'flux_lora_collection_loader',
id: getPrefixedId('flux_lora_collection_loader'),
});
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
// Use model loader as transformer input
g.addEdge(modelLoader, 'transformer', loraCollectionLoader, 'transformer');
// Reroute transformer connections through the LoRA collection loader
g.deleteEdgesTo(denoise, ['transformer']);
g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
for (const lora of enabledLoRAs) {
const { weight } = lora;
const parsedModel = zModelIdentifierField.parse(lora.model);
const loraSelector = g.addNode({
type: 'lora_selector',
id: getPrefixedId('lora_selector'),
lora: parsedModel,
weight,
});
loraMetadata.push({
model: parsedModel,
weight,
});
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
}
g.upsertMetadata({ loras: loraMetadata });
};

View File

@@ -18,6 +18,8 @@ import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addFLUXLoRAs } from './addFLUXLoRAs';
const log = logger('system');
export const buildFLUXGraph = async (
@@ -84,6 +86,8 @@ export const buildFLUXGraph = async (
g.addEdge(modelLoader, 'transformer', noise, 'transformer');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
addFLUXLoRAs(state, g, noise, modelLoader);
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');

File diff suppressed because one or more lines are too long