From 1e904d281a42cfd929bd2d16955e5986ed173c22 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 20 Apr 2024 14:44:00 +1000 Subject: [PATCH] feat(ui): hook up sd1.5 t2i graph to regional prompts --- .../util/graph/addRegionalPromptsToGraph.ts | 56 ++++++++++++------- .../util/graph/buildLinearTextToImageGraph.ts | 3 + 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts index b7903d9d01..cb96923a99 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts @@ -24,7 +24,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull } const { dispatch } = getStore(); // TODO: Handle non-SDXL - // const isSDXL = state.generation.model?.base === 'sdxl'; + const isSDXL = state.generation.model?.base === 'sdxl'; const layers = state.regionalPrompts.present.layers .filter(isRPLayer) // We only want the prompt region layers .filter((l) => l.isVisible) // Only visible layers are rendered on the canvas @@ -125,12 +125,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull if (layer.positivePrompt) { // The main positive conditioning node - const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] = { - type: 'sdxl_compel_prompt', - id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`, - prompt: layer.positivePrompt, - style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields? - }; + const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL + ? { + type: 'sdxl_compel_prompt', + id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`, + prompt: layer.positivePrompt, + style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields? + } + : { + type: 'compel', + id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`, + prompt: layer.positivePrompt, + }; graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode; // Connect the mask to the conditioning @@ -158,12 +164,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull if (layer.negativePrompt) { // The main negative conditioning node - const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] = { - type: 'sdxl_compel_prompt', - id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`, - prompt: layer.negativePrompt, - style: layer.negativePrompt, - }; + const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL + ? { + type: 'sdxl_compel_prompt', + id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`, + prompt: layer.negativePrompt, + style: layer.negativePrompt, + } + : { + type: 'compel', + id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`, + prompt: layer.negativePrompt, + }; graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode; // Connect the mask to the conditioning @@ -212,12 +224,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull // Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the // positive prompt - const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] = { - type: 'sdxl_compel_prompt', - id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`, - prompt: layer.positivePrompt, - style: layer.positivePrompt, - }; + const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL + ? { + type: 'sdxl_compel_prompt', + id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`, + prompt: layer.positivePrompt, + style: layer.positivePrompt, + } + : { + type: 'compel', + id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`, + prompt: layer.positivePrompt, + }; graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode; // Connect the inverted mask to the conditioning graph.edges.push({ diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts index aac1270e0d..90101add6d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; @@ -255,6 +256,8 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise