mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 17:18:11 -05:00
Compare commits
1 Commits
controlnet
...
maryhipp/w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bf8dc7043 |
@@ -2,7 +2,7 @@ import type { RootState } from 'app/store/store';
|
||||
import { generateSeeds } from 'common/util/generateSeeds';
|
||||
import { range } from 'es-toolkit/compat';
|
||||
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { Graph, GraphUIContract, GraphUIInput } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { API_BASE_MODELS, VIDEO_BASE_MODELS } from 'features/parameters/types/constants';
|
||||
import type { components } from 'services/api/schema';
|
||||
import type { BaseModelType, Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
|
||||
@@ -37,23 +37,54 @@ export const prepareLinearUIBatch = (arg: {
|
||||
const { iterations, shouldRandomizeSeed, seed } = state.params;
|
||||
const { prompts, seedBehaviour } = state.dynamicPrompts;
|
||||
|
||||
const graphWithContract = g.getGraph();
|
||||
const uiContract = graphWithContract.ui;
|
||||
|
||||
const data: Batch['data'] = [];
|
||||
const firstBatchDatumList: components['schemas']['BatchDatum'][] = [];
|
||||
const secondBatchDatumList: components['schemas']['BatchDatum'][] = [];
|
||||
|
||||
const resolvePromptInput = (contract?: GraphUIContract): { nodePath: string; fieldName: string } => {
|
||||
if (!contract) {
|
||||
return { nodePath: positivePromptNode.id, fieldName: 'value' };
|
||||
}
|
||||
const preferredId = contract.primary_input;
|
||||
const promptInput = preferredId
|
||||
? contract.inputs.find((i) => i.id === preferredId)
|
||||
: contract.inputs.find((i) => i.kind === 'string');
|
||||
if (promptInput) {
|
||||
return { nodePath: promptInput.node_id, fieldName: promptInput.field };
|
||||
}
|
||||
return { nodePath: positivePromptNode.id, fieldName: 'value' };
|
||||
};
|
||||
|
||||
const resolveSeedInput = (contract?: GraphUIContract): GraphUIInput | undefined => {
|
||||
if (!contract) {
|
||||
return undefined;
|
||||
}
|
||||
const preferredId = contract.inputs.find((i) => i.kind === 'seed');
|
||||
return preferredId;
|
||||
};
|
||||
|
||||
const promptTarget = resolvePromptInput(uiContract);
|
||||
const seedTarget = resolveSeedInput(uiContract);
|
||||
|
||||
const seedNodePath = seedTarget?.node_id ?? seedNode?.id;
|
||||
const seedFieldName = seedTarget?.field ?? 'value';
|
||||
|
||||
// add seeds first to ensure the output order groups the prompts
|
||||
if (seedNode && seedBehaviour === 'PER_PROMPT') {
|
||||
if (seedNodePath && seedBehaviour === 'PER_PROMPT') {
|
||||
const seeds = generateSeeds({
|
||||
count: prompts.length * iterations,
|
||||
start: shouldRandomizeSeed ? undefined : seed,
|
||||
});
|
||||
|
||||
firstBatchDatumList.push({
|
||||
node_path: seedNode.id,
|
||||
field_name: 'value',
|
||||
node_path: seedNodePath,
|
||||
field_name: seedFieldName,
|
||||
items: seeds,
|
||||
});
|
||||
} else if (seedNode && seedBehaviour === 'PER_ITERATION') {
|
||||
} else if (seedNodePath && seedBehaviour === 'PER_ITERATION') {
|
||||
// seedBehaviour = SeedBehaviour.PerRun
|
||||
const seeds = generateSeeds({
|
||||
count: iterations,
|
||||
@@ -61,8 +92,8 @@ export const prepareLinearUIBatch = (arg: {
|
||||
});
|
||||
|
||||
secondBatchDatumList.push({
|
||||
node_path: seedNode.id,
|
||||
field_name: 'value',
|
||||
node_path: seedNodePath,
|
||||
field_name: seedFieldName,
|
||||
items: seeds,
|
||||
});
|
||||
data.push(secondBatchDatumList);
|
||||
@@ -72,8 +103,8 @@ export const prepareLinearUIBatch = (arg: {
|
||||
|
||||
// zipped batch of prompts
|
||||
firstBatchDatumList.push({
|
||||
node_path: positivePromptNode.id,
|
||||
field_name: 'value',
|
||||
node_path: promptTarget.nodePath,
|
||||
field_name: promptTarget.fieldName,
|
||||
items: extendedPrompts,
|
||||
});
|
||||
|
||||
@@ -82,7 +113,7 @@ export const prepareLinearUIBatch = (arg: {
|
||||
const enqueueBatchArg: EnqueueBatchArg = {
|
||||
prepend,
|
||||
batch: {
|
||||
graph: g.getGraph(),
|
||||
graph: graphWithContract,
|
||||
runs: 1,
|
||||
data,
|
||||
origin,
|
||||
|
||||
@@ -30,7 +30,40 @@ type Edge = {
|
||||
};
|
||||
};
|
||||
|
||||
export type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
|
||||
export type GraphUIInputKind = 'string' | 'seed' | 'image' | 'number';
|
||||
|
||||
export type GraphUIInput = {
|
||||
id: string;
|
||||
node_id: string;
|
||||
field: string;
|
||||
kind: GraphUIInputKind;
|
||||
batchable?: boolean;
|
||||
ui?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export type GraphUIOutputKind = 'image' | 'latents';
|
||||
|
||||
export type GraphUIOutput = {
|
||||
id: string;
|
||||
node_id: string;
|
||||
field: string;
|
||||
kind: GraphUIOutputKind;
|
||||
};
|
||||
|
||||
export type GraphUIContract = {
|
||||
version: '0.1';
|
||||
inputs: GraphUIInput[];
|
||||
outputs: GraphUIOutput[];
|
||||
primary_input?: string;
|
||||
primary_output?: string;
|
||||
};
|
||||
|
||||
export type GraphType = {
|
||||
id: string;
|
||||
nodes: Record<string, AnyInvocation>;
|
||||
edges: Edge[];
|
||||
ui?: GraphUIContract;
|
||||
};
|
||||
|
||||
export class Graph {
|
||||
_graph: GraphType;
|
||||
@@ -365,6 +398,10 @@ export class Graph {
|
||||
getGraphSafe(): GraphType {
|
||||
return this._graph;
|
||||
}
|
||||
|
||||
setUIContract(contract: GraphUIContract): void {
|
||||
this._graph.ui = contract;
|
||||
}
|
||||
//#endregion
|
||||
|
||||
//#region Metadata
|
||||
|
||||
@@ -6,7 +6,7 @@ import { selectCanvasMetadata } from 'features/controlLayers/store/selectors';
|
||||
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import { type ImageField, zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { Graph, type GraphUIContract } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
getOriginalAndScaledSizesForOtherModes,
|
||||
getOriginalAndScaledSizesForTextToImage,
|
||||
@@ -86,6 +86,32 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
|
||||
width: originalSize.width,
|
||||
height: originalSize.height,
|
||||
});
|
||||
|
||||
const uiContract: GraphUIContract = {
|
||||
version: '0.1',
|
||||
primary_input: 'prompt',
|
||||
primary_output: 'image',
|
||||
inputs: [
|
||||
{
|
||||
id: 'prompt',
|
||||
node_id: positivePrompt.id,
|
||||
field: 'value',
|
||||
kind: 'string',
|
||||
batchable: true,
|
||||
ui: { component: 'textarea' },
|
||||
},
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
id: 'image',
|
||||
node_id: gptImage.id,
|
||||
field: 'image',
|
||||
kind: 'image',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
g.setUIContract(uiContract);
|
||||
return {
|
||||
g,
|
||||
positivePrompt,
|
||||
@@ -135,6 +161,32 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
|
||||
|
||||
g.setMetadataReceivingNode(gptImage);
|
||||
|
||||
const uiContract: GraphUIContract = {
|
||||
version: '0.1',
|
||||
primary_input: 'prompt',
|
||||
primary_output: 'image',
|
||||
inputs: [
|
||||
{
|
||||
id: 'prompt',
|
||||
node_id: positivePrompt.id,
|
||||
field: 'value',
|
||||
kind: 'string',
|
||||
batchable: false,
|
||||
ui: { component: 'textarea' },
|
||||
},
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
id: 'image',
|
||||
node_id: gptImage.id,
|
||||
field: 'image',
|
||||
kind: 'image',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
g.setUIContract(uiContract);
|
||||
|
||||
return {
|
||||
g,
|
||||
positivePrompt,
|
||||
|
||||
@@ -16,7 +16,7 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { Graph, type GraphUIContract } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
@@ -357,6 +357,40 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
|
||||
|
||||
g.setMetadataReceivingNode(canvasOutput);
|
||||
|
||||
const uiContract: GraphUIContract = {
|
||||
version: '0.1',
|
||||
primary_input: 'prompt',
|
||||
primary_output: 'image',
|
||||
inputs: [
|
||||
{
|
||||
id: 'prompt',
|
||||
node_id: positivePrompt.id,
|
||||
field: 'value',
|
||||
kind: 'string',
|
||||
batchable: true,
|
||||
ui: { component: 'textarea' },
|
||||
},
|
||||
{
|
||||
id: 'seed',
|
||||
node_id: seed.id,
|
||||
field: 'value',
|
||||
kind: 'seed',
|
||||
batchable: true,
|
||||
ui: { component: 'seed' },
|
||||
},
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
id: 'image',
|
||||
node_id: canvasOutput.id,
|
||||
field: 'image',
|
||||
kind: 'image',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
g.setUIContract(uiContract);
|
||||
|
||||
return {
|
||||
g,
|
||||
seed,
|
||||
|
||||
@@ -14,7 +14,7 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { Graph, type GraphUIContract } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
@@ -321,6 +321,40 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
|
||||
|
||||
g.setMetadataReceivingNode(canvasOutput);
|
||||
|
||||
const uiContract: GraphUIContract = {
|
||||
version: '0.1',
|
||||
primary_input: 'prompt',
|
||||
primary_output: 'image',
|
||||
inputs: [
|
||||
{
|
||||
id: 'prompt',
|
||||
node_id: positivePrompt.id,
|
||||
field: 'value',
|
||||
kind: 'string',
|
||||
batchable: true,
|
||||
ui: { component: 'textarea' },
|
||||
},
|
||||
{
|
||||
id: 'seed',
|
||||
node_id: seed.id,
|
||||
field: 'value',
|
||||
kind: 'seed',
|
||||
batchable: true,
|
||||
ui: { component: 'seed' },
|
||||
},
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
id: 'image',
|
||||
node_id: canvasOutput.id,
|
||||
field: 'image',
|
||||
kind: 'image',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
g.setUIContract(uiContract);
|
||||
|
||||
return {
|
||||
g,
|
||||
seed,
|
||||
|
||||
Reference in New Issue
Block a user