mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Fix frontend FLUX graph construction for FLUX control LoRAs.
This commit is contained in:
@@ -109,53 +109,42 @@ export const addT2IAdapters = async ({
|
||||
return result;
|
||||
};
|
||||
|
||||
type AddControlLoRAsArg = {
|
||||
type AddControlLoRAArg = {
|
||||
manager: CanvasManager;
|
||||
entities: CanvasControlLayerState[];
|
||||
g: Graph;
|
||||
rect: Rect;
|
||||
collector: Invocation<'collect'>;
|
||||
model: ParameterModel;
|
||||
denoise: Invocation<'flux_denoise'>;
|
||||
};
|
||||
|
||||
type AddControlLoRAsResult = {
|
||||
addedControlLoRAs: number;
|
||||
};
|
||||
|
||||
export const addControlLoRAs = async ({
|
||||
manager,
|
||||
entities,
|
||||
g,
|
||||
rect,
|
||||
collector,
|
||||
model,
|
||||
}: AddControlLoRAsArg): Promise<AddControlLoRAsResult> => {
|
||||
export const addControlLoRA = async ({ manager, entities, g, rect, model, denoise }: AddControlLoRAArg) => {
|
||||
const validControlLayers = entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.filter((entity) => entity.controlAdapter.type === 'control_lora')
|
||||
.filter((entity) => getControlLayerWarnings(entity, model).length === 0);
|
||||
|
||||
const result: AddControlLoRAsResult = {
|
||||
addedControlLoRAs: 0,
|
||||
};
|
||||
|
||||
for (const layer of validControlLayers) {
|
||||
const getImageDTOResult = await withResultAsync(() => {
|
||||
const adapter = manager.adapters.controlLayers.get(layer.id);
|
||||
assert(adapter, 'Adapter not found');
|
||||
return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
|
||||
});
|
||||
if (getImageDTOResult.isErr()) {
|
||||
log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer');
|
||||
continue;
|
||||
}
|
||||
|
||||
const imageDTO = getImageDTOResult.value;
|
||||
addControlLoRAToGraph(g, layer, imageDTO, collector);
|
||||
result.addedControlLoRAs++;
|
||||
const validControlLayer = validControlLayers[0];
|
||||
if (validControlLayer === undefined) {
|
||||
// No valid control LoRA found
|
||||
return;
|
||||
}
|
||||
if (validControlLayers.length > 1) {
|
||||
throw new Error('Cannot add more than one FLUX control LoRA.');
|
||||
}
|
||||
|
||||
return result;
|
||||
const getImageDTOResult = await withResultAsync(() => {
|
||||
const adapter = manager.adapters.controlLayers.get(validControlLayer.id);
|
||||
assert(adapter, 'Adapter not found');
|
||||
return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' });
|
||||
});
|
||||
if (getImageDTOResult.isErr()) {
|
||||
log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer');
|
||||
return;
|
||||
}
|
||||
|
||||
const imageDTO = getImageDTOResult.value;
|
||||
addControlLoRAToGraph(g, validControlLayer, imageDTO, denoise);
|
||||
};
|
||||
|
||||
const addControlNetToGraph = (
|
||||
@@ -214,7 +203,7 @@ const addControlLoRAToGraph = (
|
||||
g: Graph,
|
||||
layer: CanvasControlLayerState,
|
||||
imageDTO: ImageDTO,
|
||||
collector: Invocation<'collect'>
|
||||
denoise: Invocation<'flux_denoise'>
|
||||
) => {
|
||||
const { id, controlAdapter } = layer;
|
||||
assert(controlAdapter.type === 'control_lora');
|
||||
@@ -229,5 +218,5 @@ const addControlLoRAToGraph = (
|
||||
image: { image_name },
|
||||
});
|
||||
|
||||
g.addEdge(controlLoRA, 'control_lora', collector, 'item');
|
||||
g.addEdge(controlLoRA, 'control_lora', denoise, 'control_lora');
|
||||
};
|
||||
|
||||
@@ -26,7 +26,7 @@ import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { addControlLoRAs, addControlNets } from './addControlAdapters';
|
||||
import { addControlLoRA, addControlNets } from './addControlAdapters';
|
||||
import { addIPAdapters } from './addIPAdapters';
|
||||
|
||||
const log = logger('system');
|
||||
@@ -213,23 +213,14 @@ export const buildFLUXGraph = async (
|
||||
g.deleteNode(controlNetCollector.id);
|
||||
}
|
||||
|
||||
const controlLoRACollector = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('control_lora_collector'),
|
||||
});
|
||||
const controlLoRAResult = await addControlLoRAs({
|
||||
await addControlLoRA({
|
||||
manager,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
rect: canvas.bbox.rect,
|
||||
collector: controlLoRACollector,
|
||||
denoise,
|
||||
model: modelConfig,
|
||||
});
|
||||
if (controlLoRAResult.addedControlLoRAs > 0) {
|
||||
g.addEdge(controlLoRACollector, 'collection', denoise, 'control_lora');
|
||||
} else {
|
||||
g.deleteNode(controlLoRACollector.id);
|
||||
}
|
||||
|
||||
const ipAdapterCollect = g.addNode({
|
||||
type: 'collect',
|
||||
|
||||
Reference in New Issue
Block a user