From 8b969053e78ff64892a071759656fa2fc4893afa Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 27 Sep 2023 04:14:32 +0530 Subject: [PATCH] fix: SDXL Refiner using the incorrect node during inpainting --- .../graphBuilders/addSDXLRefinerToGraph.ts | 52 ++++++++++++++----- .../buildCanvasSDXLInpaintGraph.ts | 3 +- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts index 6bd44db197..a6ee6a091d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts @@ -32,7 +32,8 @@ export const addSDXLRefinerToGraph = ( graph: NonNullableGraph, baseNodeId: string, modelLoaderNodeId?: string, - canvasInitImage?: ImageDTO + canvasInitImage?: ImageDTO, + canvasMaskImage?: ImageDTO ): void => { const { refinerModel, @@ -257,8 +258,30 @@ export const addSDXLRefinerToGraph = ( }; } - graph.edges.push( - { + if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) { + if (isUsingScaledDimensions) { + graph.edges.push({ + source: { + node_id: MASK_RESIZE_UP, + field: 'image', + }, + destination: { + node_id: SDXL_REFINER_INPAINT_CREATE_MASK, + field: 'mask', + }, + }); + } else { + graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = { + ...(graph.nodes[ + SDXL_REFINER_INPAINT_CREATE_MASK + ] as CreateDenoiseMaskInvocation), + mask: canvasMaskImage, + }; + } + } + + if (graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) { + graph.edges.push({ source: { node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE, field: 'image', @@ -267,18 +290,19 @@ export const addSDXLRefinerToGraph = ( node_id: SDXL_REFINER_INPAINT_CREATE_MASK, field: 'mask', }, + }); + } + + graph.edges.push({ + source: { + node_id: SDXL_REFINER_INPAINT_CREATE_MASK, + field: 'denoise_mask', }, - { - source: { - node_id: SDXL_REFINER_INPAINT_CREATE_MASK, - field: 'denoise_mask', - }, - destination: { - node_id: SDXL_REFINER_DENOISE_LATENTS, - field: 'denoise_mask', - }, - } - ); + destination: { + node_id: SDXL_REFINER_DENOISE_LATENTS, + field: 'denoise_mask', + }, + }); } if ( diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts index 389d510ac7..a245953c8e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts @@ -663,7 +663,8 @@ export const buildCanvasSDXLInpaintGraph = ( graph, CANVAS_COHERENCE_DENOISE_LATENTS, modelLoaderNodeId, - canvasInitImage + canvasInitImage, + canvasMaskImage ); if (seamlessXAxis || seamlessYAxis) { modelLoaderNodeId = SDXL_REFINER_SEAMLESS;