feat(ui): hook up sd1.5 t2i graph to regional prompts

This commit is contained in:
psychedelicious
2024-04-20 14:44:00 +10:00
parent 03d9a75720
commit 1e904d281a
2 changed files with 40 additions and 19 deletions

View File

@@ -24,7 +24,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
} }
const { dispatch } = getStore(); const { dispatch } = getStore();
// TODO: Handle non-SDXL // TODO: Handle non-SDXL
// const isSDXL = state.generation.model?.base === 'sdxl'; const isSDXL = state.generation.model?.base === 'sdxl';
const layers = state.regionalPrompts.present.layers const layers = state.regionalPrompts.present.layers
.filter(isRPLayer) // We only want the prompt region layers .filter(isRPLayer) // We only want the prompt region layers
.filter((l) => l.isVisible) // Only visible layers are rendered on the canvas .filter((l) => l.isVisible) // Only visible layers are rendered on the canvas
@@ -125,12 +125,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
if (layer.positivePrompt) { if (layer.positivePrompt) {
// The main positive conditioning node // The main positive conditioning node
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] = { const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
type: 'sdxl_compel_prompt', ? {
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`, type: 'sdxl_compel_prompt',
prompt: layer.positivePrompt, id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields? prompt: layer.positivePrompt,
}; style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode; graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
// Connect the mask to the conditioning // Connect the mask to the conditioning
@@ -158,12 +164,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
if (layer.negativePrompt) { if (layer.negativePrompt) {
// The main negative conditioning node // The main negative conditioning node
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] = { const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
type: 'sdxl_compel_prompt', ? {
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`, type: 'sdxl_compel_prompt',
prompt: layer.negativePrompt, id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
style: layer.negativePrompt, prompt: layer.negativePrompt,
}; style: layer.negativePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
};
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode; graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
// Connect the mask to the conditioning // Connect the mask to the conditioning
@@ -212,12 +224,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the // Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
// positive prompt // positive prompt
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] = { const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
type: 'sdxl_compel_prompt', ? {
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`, type: 'sdxl_compel_prompt',
prompt: layer.positivePrompt, id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
style: layer.positivePrompt, prompt: layer.positivePrompt,
}; style: layer.positivePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode; graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
// Connect the inverted mask to the conditioning // Connect the inverted mask to the conditioning
graph.edges.push({ graph.edges.push({

View File

@@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
@@ -255,6 +256,8 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addRegionalPromptsToGraph(state, graph, DENOISE_LATENTS);
// High resolution fix. // High resolution fix.
if (state.hrf.hrfEnabled) { if (state.hrf.hrfEnabled) {
addHrfToGraph(state, graph); addHrfToGraph(state, graph);