From c276b60af965a0d4e53b34f1ab5fb5cd3bcd3468 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 28 Nov 2024 14:26:24 +1000 Subject: [PATCH] tidy(ui): use object for addRegions graph builder util arg --- .../nodes/util/graph/generation/addRegions.ts | 39 +++++++++++++------ .../util/graph/generation/buildFLUXGraph.ts | 24 ++++++------ .../util/graph/generation/buildSD1Graph.ts | 20 +++++----- .../util/graph/generation/buildSDXLGraph.ts | 20 +++++----- 4 files changed, 59 insertions(+), 44 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index cdeb30a6f6..1c058c5f4c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -30,10 +30,25 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => return isEnabled && (hasTextPrompt || hasIPAdapter); }; +type AddRegionsArg = { + manager: CanvasManager; + regions: CanvasRegionalGuidanceState[]; + g: Graph; + bbox: Rect; + base: BaseModelType; + posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null; + posCondCollect: Invocation<'collect'>; + negCondCollect: Invocation<'collect'> | null; + ipAdapterCollect: Invocation<'collect'>; +}; + /** * Adds regional guidance to the graph + * @param manager The canvas manager * @param regions Array of regions to add * @param g The graph to add the layers to + * @param bbox The bounding box * @param base The base model type * @param posCond The positive conditioning node * @param negCond The negative conditioning node @@ -43,18 +58,18 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => * @returns A promise that resolves to the regions that were successfully added to the graph */ -export const addRegions = async ( - manager: CanvasManager, - regions: CanvasRegionalGuidanceState[], - g: Graph, - bbox: Rect, - base: BaseModelType, - posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>, - negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null, - posCondCollect: Invocation<'collect'>, - negCondCollect: Invocation<'collect'> | null, - ipAdapterCollect: Invocation<'collect'> -): Promise => { +export const addRegions = async ({ + manager, + regions, + g, + bbox, + base, + posCond, + negCond, + posCondCollect, + negCondCollect, + ipAdapterCollect, +}: AddRegionsArg): Promise => { const isSDXL = base === 'sdxl'; const isFLUX = base === 'flux'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index b1e5c60c2a..4b3cad0774 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -213,31 +213,31 @@ export const buildFLUXGraph = async ( g.deleteNode(controlNetCollector.id); } - const ipAdapterCollector = g.addNode({ + const ipAdapterCollect = g.addNode({ type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); + const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); - const regionsResult = await addRegions( + const regionsResult = await addRegions({ manager, - canvas.regionalGuidance.entities, + regions: canvas.regionalGuidance.entities, g, - canvas.bbox.rect, - modelConfig.base, + bbox: canvas.bbox.rect, + base: modelConfig.base, posCond, - null, + negCond: null, posCondCollect, - null, - ipAdapterCollector - ); + negCondCollect: null, + ipAdapterCollect, + }); const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { - g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); } else { - g.deleteNode(ipAdapterCollector.id); + g.deleteNode(ipAdapterCollect.id); } if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 69e145cbe2..ab38035b4a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -259,31 +259,31 @@ export const buildSD1Graph = async ( g.deleteNode(t2iAdapterCollector.id); } - const ipAdapterCollector = g.addNode({ + const ipAdapterCollect = g.addNode({ type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); + const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); - const regionsResult = await addRegions( + const regionsResult = await addRegions({ manager, - canvas.regionalGuidance.entities, + regions: canvas.regionalGuidance.entities, g, - canvas.bbox.rect, - modelConfig.base, + bbox: canvas.bbox.rect, + base: modelConfig.base, posCond, negCond, posCondCollect, negCondCollect, - ipAdapterCollector - ); + ipAdapterCollect, + }); const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { - g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); } else { - g.deleteNode(ipAdapterCollector.id); + g.deleteNode(ipAdapterCollect.id); } if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 8d7fb67c10..4d84c025ec 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -264,31 +264,31 @@ export const buildSDXLGraph = async ( g.deleteNode(t2iAdapterCollector.id); } - const ipAdapterCollector = g.addNode({ + const ipAdapterCollect = g.addNode({ type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); + const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); - const regionsResult = await addRegions( + const regionsResult = await addRegions({ manager, - canvas.regionalGuidance.entities, + regions: canvas.regionalGuidance.entities, g, - canvas.bbox.rect, - modelConfig.base, + bbox: canvas.bbox.rect, + base: modelConfig.base, posCond, negCond, posCondCollect, negCondCollect, - ipAdapterCollector - ); + ipAdapterCollect, + }); const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { - g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); } else { - g.deleteNode(ipAdapterCollector.id); + g.deleteNode(ipAdapterCollect.id); } if (state.system.shouldUseNSFWChecker) {