Compare commits

...

1 Commits

Author SHA1 Message Date
Mary Hipp
3bf8dc7043 POC to call workflows from generate tab 2025-09-17 15:31:58 -04:00
5 changed files with 202 additions and 14 deletions

View File

@@ -2,7 +2,7 @@ import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import { range } from 'es-toolkit/compat';
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Graph, GraphUIContract, GraphUIInput } from 'features/nodes/util/graph/generation/Graph';
import { API_BASE_MODELS, VIDEO_BASE_MODELS } from 'features/parameters/types/constants';
import type { components } from 'services/api/schema';
import type { BaseModelType, Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
@@ -37,23 +37,54 @@ export const prepareLinearUIBatch = (arg: {
const { iterations, shouldRandomizeSeed, seed } = state.params;
const { prompts, seedBehaviour } = state.dynamicPrompts;
const graphWithContract = g.getGraph();
const uiContract = graphWithContract.ui;
const data: Batch['data'] = [];
const firstBatchDatumList: components['schemas']['BatchDatum'][] = [];
const secondBatchDatumList: components['schemas']['BatchDatum'][] = [];
const resolvePromptInput = (contract?: GraphUIContract): { nodePath: string; fieldName: string } => {
if (!contract) {
return { nodePath: positivePromptNode.id, fieldName: 'value' };
}
const preferredId = contract.primary_input;
const promptInput = preferredId
? contract.inputs.find((i) => i.id === preferredId)
: contract.inputs.find((i) => i.kind === 'string');
if (promptInput) {
return { nodePath: promptInput.node_id, fieldName: promptInput.field };
}
return { nodePath: positivePromptNode.id, fieldName: 'value' };
};
const resolveSeedInput = (contract?: GraphUIContract): GraphUIInput | undefined => {
if (!contract) {
return undefined;
}
const preferredId = contract.inputs.find((i) => i.kind === 'seed');
return preferredId;
};
const promptTarget = resolvePromptInput(uiContract);
const seedTarget = resolveSeedInput(uiContract);
const seedNodePath = seedTarget?.node_id ?? seedNode?.id;
const seedFieldName = seedTarget?.field ?? 'value';
// add seeds first to ensure the output order groups the prompts
if (seedNode && seedBehaviour === 'PER_PROMPT') {
if (seedNodePath && seedBehaviour === 'PER_PROMPT') {
const seeds = generateSeeds({
count: prompts.length * iterations,
start: shouldRandomizeSeed ? undefined : seed,
});
firstBatchDatumList.push({
node_path: seedNode.id,
field_name: 'value',
node_path: seedNodePath,
field_name: seedFieldName,
items: seeds,
});
} else if (seedNode && seedBehaviour === 'PER_ITERATION') {
} else if (seedNodePath && seedBehaviour === 'PER_ITERATION') {
// seedBehaviour = SeedBehaviour.PerRun
const seeds = generateSeeds({
count: iterations,
@@ -61,8 +92,8 @@ export const prepareLinearUIBatch = (arg: {
});
secondBatchDatumList.push({
node_path: seedNode.id,
field_name: 'value',
node_path: seedNodePath,
field_name: seedFieldName,
items: seeds,
});
data.push(secondBatchDatumList);
@@ -72,8 +103,8 @@ export const prepareLinearUIBatch = (arg: {
// zipped batch of prompts
firstBatchDatumList.push({
node_path: positivePromptNode.id,
field_name: 'value',
node_path: promptTarget.nodePath,
field_name: promptTarget.fieldName,
items: extendedPrompts,
});
@@ -82,7 +113,7 @@ export const prepareLinearUIBatch = (arg: {
const enqueueBatchArg: EnqueueBatchArg = {
prepend,
batch: {
graph: g.getGraph(),
graph: graphWithContract,
runs: 1,
data,
origin,

View File

@@ -30,7 +30,40 @@ type Edge = {
};
};
export type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
export type GraphUIInputKind = 'string' | 'seed' | 'image' | 'number';
export type GraphUIInput = {
id: string;
node_id: string;
field: string;
kind: GraphUIInputKind;
batchable?: boolean;
ui?: Record<string, unknown>;
};
export type GraphUIOutputKind = 'image' | 'latents';
export type GraphUIOutput = {
id: string;
node_id: string;
field: string;
kind: GraphUIOutputKind;
};
export type GraphUIContract = {
version: '0.1';
inputs: GraphUIInput[];
outputs: GraphUIOutput[];
primary_input?: string;
primary_output?: string;
};
export type GraphType = {
id: string;
nodes: Record<string, AnyInvocation>;
edges: Edge[];
ui?: GraphUIContract;
};
export class Graph {
_graph: GraphType;
@@ -365,6 +398,10 @@ export class Graph {
getGraphSafe(): GraphType {
return this._graph;
}
setUIContract(contract: GraphUIContract): void {
this._graph.ui = contract;
}
//#endregion
//#region Metadata

View File

@@ -6,7 +6,7 @@ import { selectCanvasMetadata } from 'features/controlLayers/store/selectors';
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { type ImageField, zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { Graph, type GraphUIContract } from 'features/nodes/util/graph/generation/Graph';
import {
getOriginalAndScaledSizesForOtherModes,
getOriginalAndScaledSizesForTextToImage,
@@ -86,6 +86,32 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
width: originalSize.width,
height: originalSize.height,
});
const uiContract: GraphUIContract = {
version: '0.1',
primary_input: 'prompt',
primary_output: 'image',
inputs: [
{
id: 'prompt',
node_id: positivePrompt.id,
field: 'value',
kind: 'string',
batchable: true,
ui: { component: 'textarea' },
},
],
outputs: [
{
id: 'image',
node_id: gptImage.id,
field: 'image',
kind: 'image',
},
],
};
g.setUIContract(uiContract);
return {
g,
positivePrompt,
@@ -135,6 +161,32 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
g.setMetadataReceivingNode(gptImage);
const uiContract: GraphUIContract = {
version: '0.1',
primary_input: 'prompt',
primary_output: 'image',
inputs: [
{
id: 'prompt',
node_id: positivePrompt.id,
field: 'value',
kind: 'string',
batchable: false,
ui: { component: 'textarea' },
},
],
outputs: [
{
id: 'image',
node_id: gptImage.id,
field: 'image',
kind: 'image',
},
],
};
g.setUIContract(uiContract);
return {
g,
positivePrompt,

View File

@@ -16,7 +16,7 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { Graph, type GraphUIContract } from 'features/nodes/util/graph/generation/Graph';
import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
@@ -357,6 +357,40 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
g.setMetadataReceivingNode(canvasOutput);
const uiContract: GraphUIContract = {
version: '0.1',
primary_input: 'prompt',
primary_output: 'image',
inputs: [
{
id: 'prompt',
node_id: positivePrompt.id,
field: 'value',
kind: 'string',
batchable: true,
ui: { component: 'textarea' },
},
{
id: 'seed',
node_id: seed.id,
field: 'value',
kind: 'seed',
batchable: true,
ui: { component: 'seed' },
},
],
outputs: [
{
id: 'image',
node_id: canvasOutput.id,
field: 'image',
kind: 'image',
},
],
};
g.setUIContract(uiContract);
return {
g,
seed,

View File

@@ -14,7 +14,7 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { Graph, type GraphUIContract } from 'features/nodes/util/graph/generation/Graph';
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
@@ -321,6 +321,40 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
g.setMetadataReceivingNode(canvasOutput);
const uiContract: GraphUIContract = {
version: '0.1',
primary_input: 'prompt',
primary_output: 'image',
inputs: [
{
id: 'prompt',
node_id: positivePrompt.id,
field: 'value',
kind: 'string',
batchable: true,
ui: { component: 'textarea' },
},
{
id: 'seed',
node_id: seed.id,
field: 'value',
kind: 'seed',
batchable: true,
ui: { component: 'seed' },
},
],
outputs: [
{
id: 'image',
node_id: canvasOutput.id,
field: 'image',
kind: 'image',
},
],
};
g.setUIContract(uiContract);
return {
g,
seed,