Fix frontend FLUX graph construction for FLUX control LoRAs.

This commit is contained in:
Ryan Dick
2024-12-16 19:57:28 +00:00
committed by Kent Keirsey
parent 516ffa641c
commit 5fcd76a712
2 changed files with 27 additions and 47 deletions

View File

@@ -109,53 +109,42 @@ export const addT2IAdapters = async ({
return result;
};
type AddControlLoRAsArg = {
type AddControlLoRAArg = {
manager: CanvasManager;
entities: CanvasControlLayerState[];
g: Graph;
rect: Rect;
collector: Invocation<'collect'>;
model: ParameterModel;
denoise: Invocation<'flux_denoise'>;
};
type AddControlLoRAsResult = {
addedControlLoRAs: number;
};
export const addControlLoRAs = async ({
manager,
entities,
g,
rect,
collector,
model,
}: AddControlLoRAsArg): Promise<AddControlLoRAsResult> => {
export const addControlLoRA = async ({ manager, entities, g, rect, model, denoise }: AddControlLoRAArg) => {
const validControlLayers = entities
.filter((entity) => entity.isEnabled)
.filter((entity) => entity.controlAdapter.type === 'control_lora')
.filter((entity) => getControlLayerWarnings(entity, model).length === 0);
const result: AddControlLoRAsResult = {
addedControlLoRAs: 0,
};
for (const layer of validControlLayers) {
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
});
if (getImageDTOResult.isErr()) {
log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer');
continue;
}
const imageDTO = getImageDTOResult.value;
addControlLoRAToGraph(g, layer, imageDTO, collector);
result.addedControlLoRAs++;
const validControlLayer = validControlLayers[0];
if (validControlLayer === undefined) {
// No valid control LoRA found
return;
}
if (validControlLayers.length > 1) {
throw new Error('Cannot add more than one FLUX control LoRA.');
}
return result;
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(validControlLayer.id);
assert(adapter, 'Adapter not found');
return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
});
if (getImageDTOResult.isErr()) {
log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer');
return;
}
const imageDTO = getImageDTOResult.value;
addControlLoRAToGraph(g, validControlLayer, imageDTO, denoise);
};
const addControlNetToGraph = (
@@ -214,7 +203,7 @@ const addControlLoRAToGraph = (
g: Graph,
layer: CanvasControlLayerState,
imageDTO: ImageDTO,
collector: Invocation<'collect'>
denoise: Invocation<'flux_denoise'>
) => {
const { id, controlAdapter } = layer;
assert(controlAdapter.type === 'control_lora');
@@ -229,5 +218,5 @@ const addControlLoRAToGraph = (
image: { image_name },
});
g.addEdge(controlLoRA, 'control_lora', collector, 'item');
g.addEdge(controlLoRA, 'control_lora', denoise, 'control_lora');
};

View File

@@ -26,7 +26,7 @@ import { isNonRefinerMainModelConfig } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { addControlLoRAs, addControlNets } from './addControlAdapters';
import { addControlLoRA, addControlNets } from './addControlAdapters';
import { addIPAdapters } from './addIPAdapters';
const log = logger('system');
@@ -213,23 +213,14 @@ export const buildFLUXGraph = async (
g.deleteNode(controlNetCollector.id);
}
const controlLoRACollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_lora_collector'),
});
const controlLoRAResult = await addControlLoRAs({
await addControlLoRA({
manager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
collector: controlLoRACollector,
denoise,
model: modelConfig,
});
if (controlLoRAResult.addedControlLoRAs > 0) {
g.addEdge(controlLoRACollector, 'collection', denoise, 'control_lora');
} else {
g.deleteNode(controlLoRACollector.id);
}
const ipAdapterCollect = g.addNode({
type: 'collect',