mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 19:05:13 -05:00
refactor(ui): graph building respects selected tab
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { canvasReset } from 'features/controlLayers/store/actions';
|
||||
|
||||
type CanvasStagingAreaState = {
|
||||
@@ -20,26 +19,16 @@ export const canvasSessionSlice = createSlice({
|
||||
name: 'canvasSession',
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
generateSessionIdCreated: {
|
||||
reducer: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
state.generateSessionId = id;
|
||||
},
|
||||
prepare: () => ({
|
||||
payload: { id: getPrefixedId('generate') },
|
||||
}),
|
||||
generateSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
state.generateSessionId = id;
|
||||
},
|
||||
generateSessionReset: (state) => {
|
||||
state.generateSessionId = null;
|
||||
},
|
||||
canvasSessionIdCreated: {
|
||||
reducer: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
state.canvasSessionId = id;
|
||||
},
|
||||
prepare: () => ({
|
||||
payload: { id: getPrefixedId('canvas') },
|
||||
}),
|
||||
canvasSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
state.canvasSessionId = id;
|
||||
},
|
||||
canvasSessionReset: (state) => {
|
||||
state.canvasSessionId = null;
|
||||
@@ -52,7 +41,7 @@ export const canvasSessionSlice = createSlice({
|
||||
},
|
||||
});
|
||||
|
||||
export const { generateSessionIdCreated, generateSessionReset, canvasSessionIdCreated, canvasSessionReset } =
|
||||
export const { generateSessionIdChanged, generateSessionReset, canvasSessionIdChanged, canvasSessionReset } =
|
||||
canvasSessionSlice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
@@ -8,23 +6,18 @@ import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import { type ImageField, zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { t } from 'i18next';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildChatGPT4oGraph = async (
|
||||
state: RootState,
|
||||
manager: CanvasManager | null
|
||||
): Promise<GraphBuilderReturn> => {
|
||||
const tab = selectActiveTab(state);
|
||||
const generationMode = await getGenerationMode(manager, tab);
|
||||
export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
|
||||
const { generationMode, state } = arg;
|
||||
|
||||
if (generationMode !== 'txt2img' && generationMode !== 'img2img') {
|
||||
throw new UnsupportedGenerationModeError(t('toast.chatGPT4oIncompatibleGenerationMode'));
|
||||
@@ -86,9 +79,8 @@ export const buildChatGPT4oGraph = async (
|
||||
}
|
||||
|
||||
if (generationMode === 'img2img') {
|
||||
assert(manager, 'Need manager to do img2img');
|
||||
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
|
||||
const adapters = arg.canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const { image_name } = await arg.canvasManager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
@@ -11,15 +9,13 @@ import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChec
|
||||
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
getSizes,
|
||||
selectCanvasOutputFields,
|
||||
selectPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
@@ -27,12 +23,8 @@ import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildCogView4Graph = async (
|
||||
state: RootState,
|
||||
manager: CanvasManager | null
|
||||
): Promise<GraphBuilderReturn> => {
|
||||
const tab = selectActiveTab(state);
|
||||
const generationMode = await getGenerationMode(manager, tab);
|
||||
export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
|
||||
const { generationMode, state } = arg;
|
||||
log.debug({ generationMode }, 'Building CogView4 graph');
|
||||
|
||||
const params = selectParamsSlice(state);
|
||||
@@ -112,10 +104,9 @@ export const buildCogView4Graph = async (
|
||||
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
|
||||
g.upsertMetadata({ generation_mode: 'cogview4_txt2img' });
|
||||
} else if (generationMode === 'img2img') {
|
||||
assert(manager, 'Need manager to do img2img');
|
||||
canvasOutput = await addImageToImage({
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'cogview4_i2l',
|
||||
denoise,
|
||||
@@ -128,11 +119,10 @@ export const buildCogView4Graph = async (
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'cogview4_img2img' });
|
||||
} else if (generationMode === 'inpaint') {
|
||||
assert(manager, 'Need manager to do inpaint');
|
||||
canvasOutput = await addInpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'cogview4_i2l',
|
||||
denoise,
|
||||
@@ -146,11 +136,10 @@ export const buildCogView4Graph = async (
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'cogview4_inpaint' });
|
||||
} else if (generationMode === 'outpaint') {
|
||||
assert(manager, 'Need manager to do outpaint');
|
||||
canvasOutput = await addOutpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'cogview4_i2l',
|
||||
denoise,
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
@@ -15,19 +13,14 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
getSizes,
|
||||
selectCanvasOutputFields,
|
||||
selectPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import {
|
||||
type GraphBuilderReturn,
|
||||
type ImageOutputNodes,
|
||||
UnsupportedGenerationModeError,
|
||||
} from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { t } from 'i18next';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
@@ -38,9 +31,8 @@ import { addIPAdapters } from './addIPAdapters';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | null): Promise<GraphBuilderReturn> => {
|
||||
const tab = selectActiveTab(state);
|
||||
const generationMode = await getGenerationMode(manager, tab);
|
||||
export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
|
||||
const { generationMode, state } = arg;
|
||||
log.debug({ generationMode }, 'Building FLUX graph');
|
||||
|
||||
const params = selectParamsSlice(state);
|
||||
@@ -171,12 +163,11 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
|
||||
|
||||
let canvasOutput: Invocation<ImageOutputNodes> = l2i;
|
||||
|
||||
if (isFLUXFill) {
|
||||
assert(manager, 'Need manager to do FLUX Fill');
|
||||
if (isFLUXFill && (generationMode === 'inpaint' || generationMode === 'outpaint')) {
|
||||
canvasOutput = await addFLUXFill({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
denoise,
|
||||
originalSize,
|
||||
@@ -186,10 +177,9 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
|
||||
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
|
||||
g.upsertMetadata({ generation_mode: 'flux_txt2img' });
|
||||
} else if (generationMode === 'img2img') {
|
||||
assert(manager, 'Need manager to do img2img');
|
||||
canvasOutput = await addImageToImage({
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'flux_vae_encode',
|
||||
denoise,
|
||||
@@ -202,11 +192,10 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'flux_img2img' });
|
||||
} else if (generationMode === 'inpaint') {
|
||||
assert(manager, 'Need manager to do inpaint');
|
||||
canvasOutput = await addInpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'flux_vae_encode',
|
||||
denoise,
|
||||
@@ -220,11 +209,10 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'flux_inpaint' });
|
||||
} else if (generationMode === 'outpaint') {
|
||||
assert(manager, 'Need manager to do outpaint');
|
||||
canvasOutput = await addOutpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'flux_vae_encode',
|
||||
denoise,
|
||||
@@ -241,13 +229,13 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
|
||||
assert<Equals<typeof generationMode, never>>(false);
|
||||
}
|
||||
|
||||
if (manager) {
|
||||
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
|
||||
const controlNetCollector = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('control_net_collector'),
|
||||
});
|
||||
const controlNetResult = await addControlNets({
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
rect: canvas.bbox.rect,
|
||||
@@ -261,7 +249,7 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
|
||||
}
|
||||
|
||||
await addControlLoRA({
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
rect: canvas.bbox.rect,
|
||||
@@ -295,9 +283,9 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
|
||||
});
|
||||
let totalReduxesAdded = fluxReduxResult.addedFLUXReduxes;
|
||||
|
||||
if (manager) {
|
||||
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
bbox: canvas.bbox.rect,
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
@@ -11,22 +9,15 @@ import type { ImageField } from 'features/nodes/types/common';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { t } from 'i18next';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { getGenerationMode } from './getGenerationMode';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildFluxKontextGraph = async (
|
||||
state: RootState,
|
||||
manager: CanvasManager | null
|
||||
): Promise<GraphBuilderReturn> => {
|
||||
const tab = selectActiveTab(state);
|
||||
const generationMode = await getGenerationMode(manager, tab);
|
||||
export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn => {
|
||||
const { generationMode, state } = arg;
|
||||
|
||||
if (generationMode !== 'txt2img') {
|
||||
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'FLUX Kontext' }));
|
||||
@@ -61,29 +52,25 @@ export const buildFluxKontextGraph = async (
|
||||
};
|
||||
}
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
const g = new Graph(getPrefixedId('flux_kontext_txt2img_graph'));
|
||||
const fluxKontextImage = g.addNode({
|
||||
// @ts-expect-error: These nodes are not available in the OSS application
|
||||
type: input_image ? 'flux_kontext_edit_image' : 'flux_kontext_generate_image',
|
||||
model: zModelIdentifierField.parse(model),
|
||||
positive_prompt: positivePrompt,
|
||||
aspect_ratio: bbox.aspectRatio.id,
|
||||
input_image,
|
||||
prompt_upsampling: true,
|
||||
...selectCanvasOutputFields(state),
|
||||
});
|
||||
g.upsertMetadata({
|
||||
positive_prompt: positivePrompt,
|
||||
model: Graph.getModelMetadataField(model),
|
||||
width: bbox.rect.width,
|
||||
height: bbox.rect.height,
|
||||
});
|
||||
return {
|
||||
g,
|
||||
positivePromptFieldIdentifier: { nodeId: fluxKontextImage.id, fieldName: 'positive_prompt' },
|
||||
};
|
||||
}
|
||||
|
||||
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Flux Kontext');
|
||||
const g = new Graph(getPrefixedId('flux_kontext_txt2img_graph'));
|
||||
const fluxKontextImage = g.addNode({
|
||||
// @ts-expect-error: These nodes are not available in the OSS application
|
||||
type: input_image ? 'flux_kontext_edit_image' : 'flux_kontext_generate_image',
|
||||
model: zModelIdentifierField.parse(model),
|
||||
positive_prompt: positivePrompt,
|
||||
aspect_ratio: bbox.aspectRatio.id,
|
||||
input_image,
|
||||
prompt_upsampling: true,
|
||||
...selectCanvasOutputFields(state),
|
||||
});
|
||||
g.upsertMetadata({
|
||||
positive_prompt: positivePrompt,
|
||||
model: Graph.getModelMetadataField(model),
|
||||
width: bbox.rect.width,
|
||||
height: bbox.rect.height,
|
||||
});
|
||||
return {
|
||||
g,
|
||||
positivePromptFieldIdentifier: { nodeId: fluxKontextImage.id, fieldName: 'positive_prompt' },
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,28 +1,20 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { t } from 'i18next';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildImagen3Graph = async (
|
||||
state: RootState,
|
||||
manager: CanvasManager | null
|
||||
): Promise<GraphBuilderReturn> => {
|
||||
const tab = selectActiveTab(state);
|
||||
const generationMode = await getGenerationMode(manager, tab);
|
||||
export const buildImagen3Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
|
||||
const { generationMode, state } = arg;
|
||||
|
||||
if (generationMode !== 'txt2img') {
|
||||
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen3' }));
|
||||
@@ -41,33 +33,30 @@ export const buildImagen3Graph = async (
|
||||
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
|
||||
assert(positivePrompt.length > 0, 'Imagen3 requires positive prompt to have at least one character');
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
const g = new Graph(getPrefixedId('imagen3_txt2img_graph'));
|
||||
const imagen3 = g.addNode({
|
||||
// @ts-expect-error: These nodes are not available in the OSS application
|
||||
type: 'google_imagen3_generate_image',
|
||||
model: zModelIdentifierField.parse(model),
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
aspect_ratio: bbox.aspectRatio.id,
|
||||
// When enhance_prompt is true, Imagen3 will return a new image every time, ignoring the seed.
|
||||
enhance_prompt: true,
|
||||
...selectCanvasOutputFields(state),
|
||||
});
|
||||
g.upsertMetadata({
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
width: bbox.rect.width,
|
||||
height: bbox.rect.height,
|
||||
model: Graph.getModelMetadataField(model),
|
||||
...selectCanvasMetadata(state),
|
||||
});
|
||||
return {
|
||||
g,
|
||||
seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' },
|
||||
positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' },
|
||||
};
|
||||
}
|
||||
const g = new Graph(getPrefixedId('imagen3_txt2img_graph'));
|
||||
const imagen3 = g.addNode({
|
||||
// @ts-expect-error: These nodes are not available in the OSS application
|
||||
type: 'google_imagen3_generate_image',
|
||||
model: zModelIdentifierField.parse(model),
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
aspect_ratio: bbox.aspectRatio.id,
|
||||
// When enhance_prompt is true, Imagen3 will return a new image every time, ignoring the seed.
|
||||
enhance_prompt: true,
|
||||
...selectCanvasOutputFields(state),
|
||||
});
|
||||
g.upsertMetadata({
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
width: bbox.rect.width,
|
||||
height: bbox.rect.height,
|
||||
model: Graph.getModelMetadataField(model),
|
||||
...selectCanvasMetadata(state),
|
||||
});
|
||||
|
||||
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen3');
|
||||
return {
|
||||
g,
|
||||
seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' },
|
||||
positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' },
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,28 +1,20 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { t } from 'i18next';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildImagen4Graph = async (
|
||||
state: RootState,
|
||||
manager: CanvasManager | null
|
||||
): Promise<GraphBuilderReturn> => {
|
||||
const tab = selectActiveTab(state);
|
||||
const generationMode = await getGenerationMode(manager, tab);
|
||||
export const buildImagen4Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
|
||||
const { generationMode, state } = arg;
|
||||
|
||||
if (generationMode !== 'txt2img') {
|
||||
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen4' }));
|
||||
@@ -41,34 +33,30 @@ export const buildImagen4Graph = async (
|
||||
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen4 does not support this aspect ratio');
|
||||
assert(positivePrompt.length > 0, 'Imagen4 requires positive prompt to have at least one character');
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
const g = new Graph(getPrefixedId('imagen4_txt2img_graph'));
|
||||
const imagen4 = g.addNode({
|
||||
// @ts-expect-error: These nodes are not available in the OSS application
|
||||
type: 'google_imagen4_generate_image',
|
||||
model: zModelIdentifierField.parse(model),
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
aspect_ratio: bbox.aspectRatio.id,
|
||||
// When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed.
|
||||
enhance_prompt: true,
|
||||
...selectCanvasOutputFields(state),
|
||||
});
|
||||
g.upsertMetadata({
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
width: bbox.rect.width,
|
||||
height: bbox.rect.height,
|
||||
model: Graph.getModelMetadataField(model),
|
||||
...selectCanvasMetadata(state),
|
||||
});
|
||||
const g = new Graph(getPrefixedId('imagen4_txt2img_graph'));
|
||||
const imagen4 = g.addNode({
|
||||
// @ts-expect-error: These nodes are not available in the OSS application
|
||||
type: 'google_imagen4_generate_image',
|
||||
model: zModelIdentifierField.parse(model),
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
aspect_ratio: bbox.aspectRatio.id,
|
||||
// When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed.
|
||||
enhance_prompt: true,
|
||||
...selectCanvasOutputFields(state),
|
||||
});
|
||||
g.upsertMetadata({
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
width: bbox.rect.width,
|
||||
height: bbox.rect.height,
|
||||
model: Graph.getModelMetadataField(model),
|
||||
...selectCanvasMetadata(state),
|
||||
});
|
||||
|
||||
return {
|
||||
g,
|
||||
seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' },
|
||||
positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' },
|
||||
};
|
||||
}
|
||||
|
||||
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen4');
|
||||
return {
|
||||
g,
|
||||
seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' },
|
||||
positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' },
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
@@ -16,15 +14,13 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
getSizes,
|
||||
selectCanvasOutputFields,
|
||||
selectPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -33,9 +29,8 @@ import { addRegions } from './addRegions';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildSD1Graph = async (state: RootState, manager: CanvasManager | null): Promise<GraphBuilderReturn> => {
|
||||
const tab = selectActiveTab(state);
|
||||
const generationMode = await getGenerationMode(manager, tab);
|
||||
export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
|
||||
const { generationMode, state } = arg;
|
||||
log.debug({ generationMode }, 'Building SD1/SD2 graph');
|
||||
|
||||
const params = selectParamsSlice(state);
|
||||
@@ -171,10 +166,9 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
|
||||
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
|
||||
g.upsertMetadata({ generation_mode: 'txt2img' });
|
||||
} else if (generationMode === 'img2img') {
|
||||
assert(manager, 'Need manager to do img2img');
|
||||
canvasOutput = await addImageToImage({
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'i2l',
|
||||
denoise,
|
||||
@@ -187,11 +181,10 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'img2img' });
|
||||
} else if (generationMode === 'inpaint') {
|
||||
assert(manager, 'Need manager to do inpaint');
|
||||
canvasOutput = await addInpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'i2l',
|
||||
denoise,
|
||||
@@ -205,11 +198,10 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'inpaint' });
|
||||
} else if (generationMode === 'outpaint') {
|
||||
assert(manager, 'Need manager to do outpaint');
|
||||
canvasOutput = await addOutpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'i2l',
|
||||
denoise,
|
||||
@@ -226,13 +218,13 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
|
||||
assert<Equals<typeof generationMode, never>>(false);
|
||||
}
|
||||
|
||||
if (manager) {
|
||||
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
|
||||
const controlNetCollector = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('control_net_collector'),
|
||||
});
|
||||
const controlNetResult = await addControlNets({
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
rect: canvas.bbox.rect,
|
||||
@@ -250,7 +242,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
|
||||
id: getPrefixedId('t2i_adapter_collector'),
|
||||
});
|
||||
const t2iAdapterResult = await addT2IAdapters({
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
rect: canvas.bbox.rect,
|
||||
@@ -276,9 +268,9 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
|
||||
});
|
||||
let totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
|
||||
|
||||
if (manager) {
|
||||
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
bbox: canvas.bbox.rect,
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
@@ -10,24 +8,21 @@ import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChec
|
||||
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
getSizes,
|
||||
selectCanvasOutputFields,
|
||||
selectPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildSD3Graph = async (state: RootState, manager: CanvasManager | null): Promise<GraphBuilderReturn> => {
|
||||
const tab = selectActiveTab(state);
|
||||
const generationMode = await getGenerationMode(manager, tab);
|
||||
export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
|
||||
const { generationMode, state } = arg;
|
||||
log.debug({ generationMode }, 'Building SD3 graph');
|
||||
|
||||
const model = selectMainModelConfig(state);
|
||||
@@ -133,10 +128,9 @@ export const buildSD3Graph = async (state: RootState, manager: CanvasManager | n
|
||||
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
|
||||
g.upsertMetadata({ generation_mode: 'sd3_txt2img' });
|
||||
} else if (generationMode === 'img2img') {
|
||||
assert(manager, 'Need manager to do img2img');
|
||||
canvasOutput = await addImageToImage({
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'sd3_i2l',
|
||||
denoise,
|
||||
@@ -149,11 +143,10 @@ export const buildSD3Graph = async (state: RootState, manager: CanvasManager | n
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'sd3_img2img' });
|
||||
} else if (generationMode === 'inpaint') {
|
||||
assert(manager, 'Need manager to do inpaint');
|
||||
canvasOutput = await addInpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'sd3_i2l',
|
||||
denoise,
|
||||
@@ -167,11 +160,10 @@ export const buildSD3Graph = async (state: RootState, manager: CanvasManager | n
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'sd3_inpaint' });
|
||||
} else if (generationMode === 'outpaint') {
|
||||
assert(manager, 'Need manager to do outpaint');
|
||||
canvasOutput = await addOutpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'sd3_i2l',
|
||||
denoise,
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
@@ -16,15 +14,13 @@ import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefi
|
||||
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
getSizes,
|
||||
selectCanvasOutputFields,
|
||||
selectPresetModifiedPrompts,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -33,9 +29,8 @@ import { addRegions } from './addRegions';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const buildSDXLGraph = async (state: RootState, manager: CanvasManager | null): Promise<GraphBuilderReturn> => {
|
||||
const tab = selectActiveTab(state);
|
||||
const generationMode = await getGenerationMode(manager, tab);
|
||||
export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
|
||||
const { generationMode, state } = arg;
|
||||
log.debug({ generationMode }, 'Building SDXL graph');
|
||||
|
||||
const model = selectMainModelConfig(state);
|
||||
@@ -178,10 +173,9 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
|
||||
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
|
||||
g.upsertMetadata({ generation_mode: 'sdxl_txt2img' });
|
||||
} else if (generationMode === 'img2img') {
|
||||
assert(manager, 'Need manager to do img2img');
|
||||
canvasOutput = await addImageToImage({
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'i2l',
|
||||
denoise,
|
||||
@@ -194,11 +188,10 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'sdxl_img2img' });
|
||||
} else if (generationMode === 'inpaint') {
|
||||
assert(manager, 'Need manager to do inpaint');
|
||||
canvasOutput = await addInpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'i2l',
|
||||
denoise,
|
||||
@@ -212,11 +205,10 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
|
||||
});
|
||||
g.upsertMetadata({ generation_mode: 'sdxl_inpaint' });
|
||||
} else if (generationMode === 'outpaint') {
|
||||
assert(manager, 'Need manager to do outpaint');
|
||||
canvasOutput = await addOutpaint({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
l2i,
|
||||
i2lNodeType: 'i2l',
|
||||
denoise,
|
||||
@@ -233,13 +225,13 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
|
||||
assert<Equals<typeof generationMode, never>>(false);
|
||||
}
|
||||
|
||||
if (manager) {
|
||||
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
|
||||
const controlNetCollector = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('control_net_collector'),
|
||||
});
|
||||
const controlNetResult = await addControlNets({
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
rect: canvas.bbox.rect,
|
||||
@@ -257,7 +249,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
|
||||
id: getPrefixedId('t2i_adapter_collector'),
|
||||
});
|
||||
const t2iAdapterResult = await addT2IAdapters({
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
entities: canvas.controlLayers.entities,
|
||||
g,
|
||||
rect: canvas.bbox.rect,
|
||||
@@ -283,9 +275,9 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
|
||||
});
|
||||
let totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
|
||||
|
||||
if (manager) {
|
||||
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
manager: arg.canvasManager,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
g,
|
||||
bbox: canvas.bbox.rect,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { pick } from 'es-toolkit/compat';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { CanvasState, ParamsState } from 'features/controlLayers/store/types';
|
||||
@@ -13,7 +14,7 @@ import { selectListStylePresetsRequestState } from 'services/api/endpoints/style
|
||||
import type { Invocation, S } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type { MainModelLoaderNodes } from './types';
|
||||
import type { GraphBuilderArg, MainModelLoaderNodes } from './types';
|
||||
|
||||
/**
|
||||
* Gets the board field, based on the autoAddBoardId setting.
|
||||
@@ -165,3 +166,19 @@ export const isCanvasOutputNodeId = (nodeId: string) => nodeId.split(':')[0] ===
|
||||
export const isCanvasOutputEvent = (data: S['InvocationCompleteEvent']) => {
|
||||
return isCanvasOutputNodeId(data.invocation_source_id);
|
||||
};
|
||||
|
||||
export const getGraphBuilderArg = async (
|
||||
state: RootState,
|
||||
canvasManager: CanvasManager | null
|
||||
): Promise<GraphBuilderArg> => {
|
||||
const tab = selectActiveTab(state);
|
||||
if (tab === 'generate') {
|
||||
return { generationMode: 'txt2img', state };
|
||||
} else if (tab === 'canvas') {
|
||||
assert(canvasManager !== null, 'CanvasManager should not be null in canvas tab');
|
||||
const generationMode = await canvasManager.compositor.getGenerationMode();
|
||||
return { generationMode, state, canvasManager };
|
||||
} else {
|
||||
assert(false, `Unknown tab: ${tab}`);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import type { GenerationMode } from 'features/controlLayers/store/types';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
|
||||
@@ -27,6 +30,17 @@ export type MainModelLoaderNodes =
|
||||
|
||||
export type VaeSourceNodes = 'seamless' | 'vae_loader';
|
||||
|
||||
export type GraphBuilderArg =
|
||||
| {
|
||||
generationMode: Extract<GenerationMode, 'txt2img'>;
|
||||
state: RootState;
|
||||
}
|
||||
| {
|
||||
generationMode: Exclude<GenerationMode, 'txt2img'>;
|
||||
state: RootState;
|
||||
canvasManager: CanvasManager;
|
||||
};
|
||||
|
||||
export type GraphBuilderReturn = {
|
||||
g: Graph;
|
||||
seedFieldIdentifier?: FieldIdentifier;
|
||||
|
||||
@@ -7,7 +7,8 @@ import { extractMessageFromAssertionError } from 'common/util/extractMessageFrom
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { canvasSessionIdCreated, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { canvasSessionIdChanged, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
|
||||
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
|
||||
@@ -18,6 +19,7 @@ import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildIma
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
@@ -33,40 +35,42 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
|
||||
|
||||
dispatch(enqueueRequestedCanvas());
|
||||
|
||||
let destination = selectCanvasSessionId(getState());
|
||||
if (!destination) {
|
||||
dispatch(canvasSessionIdCreated());
|
||||
destination = selectCanvasSessionId(getState());
|
||||
}
|
||||
assert(destination !== null);
|
||||
|
||||
const state = getState();
|
||||
|
||||
const model = state.params.model;
|
||||
assert(model, 'No model found in state');
|
||||
const base = model.base;
|
||||
let destination = selectCanvasSessionId(state);
|
||||
if (destination === null) {
|
||||
destination = getPrefixedId('canvas');
|
||||
dispatch(canvasSessionIdChanged({ id: destination }));
|
||||
}
|
||||
|
||||
const buildGraphResult = await withResultAsync(async () => {
|
||||
const model = state.params.model;
|
||||
assert(model, 'No model found in state');
|
||||
const base = model.base;
|
||||
|
||||
const generationMode = await canvasManager.compositor.getGenerationMode();
|
||||
const graphBuilderArg: GraphBuilderArg = { generationMode, state, canvasManager };
|
||||
|
||||
switch (base) {
|
||||
case 'sdxl':
|
||||
return await buildSDXLGraph(state, canvasManager);
|
||||
return await buildSDXLGraph(graphBuilderArg);
|
||||
case 'sd-1':
|
||||
case `sd-2`:
|
||||
return await buildSD1Graph(state, canvasManager);
|
||||
return await buildSD1Graph(graphBuilderArg);
|
||||
case `sd-3`:
|
||||
return await buildSD3Graph(state, canvasManager);
|
||||
return await buildSD3Graph(graphBuilderArg);
|
||||
case `flux`:
|
||||
return await buildFLUXGraph(state, canvasManager);
|
||||
return await buildFLUXGraph(graphBuilderArg);
|
||||
case 'cogview4':
|
||||
return await buildCogView4Graph(state, canvasManager);
|
||||
return await buildCogView4Graph(graphBuilderArg);
|
||||
case 'imagen3':
|
||||
return await buildImagen3Graph(state, canvasManager);
|
||||
return buildImagen3Graph(graphBuilderArg);
|
||||
case 'imagen4':
|
||||
return await buildImagen4Graph(state, canvasManager);
|
||||
return buildImagen4Graph(graphBuilderArg);
|
||||
case 'chatgpt-4o':
|
||||
return await buildChatGPT4oGraph(state, canvasManager);
|
||||
return await buildChatGPT4oGraph(graphBuilderArg);
|
||||
case 'flux-kontext':
|
||||
return await buildFluxKontextGraph(state, canvasManager);
|
||||
return buildFluxKontextGraph(graphBuilderArg);
|
||||
default:
|
||||
assert(false, `No graph builders for base ${base}`);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,8 @@ import type { AppStore } from 'app/store/store';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { generateSessionIdCreated, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { generateSessionIdChanged, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
|
||||
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
|
||||
@@ -16,6 +17,7 @@ import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildIma
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
@@ -32,40 +34,41 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
|
||||
|
||||
dispatch(enqueueRequestedGenerate());
|
||||
|
||||
let destination = selectGenerateSessionId(getState());
|
||||
if (!destination) {
|
||||
dispatch(generateSessionIdCreated());
|
||||
destination = selectGenerateSessionId(getState());
|
||||
}
|
||||
assert(destination !== null);
|
||||
|
||||
const state = getState();
|
||||
|
||||
const model = state.params.model;
|
||||
assert(model, 'No model found in state');
|
||||
const base = model.base;
|
||||
let destination = selectGenerateSessionId(state);
|
||||
if (destination === null) {
|
||||
destination = getPrefixedId('generate');
|
||||
dispatch(generateSessionIdChanged({ id: destination }));
|
||||
}
|
||||
|
||||
const buildGraphResult = await withResultAsync(async () => {
|
||||
const model = state.params.model;
|
||||
assert(model, 'No model found in state');
|
||||
const base = model.base;
|
||||
|
||||
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state };
|
||||
|
||||
switch (base) {
|
||||
case 'sdxl':
|
||||
return await buildSDXLGraph(state, null);
|
||||
return await buildSDXLGraph(graphBuilderArg);
|
||||
case 'sd-1':
|
||||
case `sd-2`:
|
||||
return await buildSD1Graph(state, null);
|
||||
return await buildSD1Graph(graphBuilderArg);
|
||||
case `sd-3`:
|
||||
return await buildSD3Graph(state, null);
|
||||
return await buildSD3Graph(graphBuilderArg);
|
||||
case `flux`:
|
||||
return await buildFLUXGraph(state, null);
|
||||
return await buildFLUXGraph(graphBuilderArg);
|
||||
case 'cogview4':
|
||||
return await buildCogView4Graph(state, null);
|
||||
return await buildCogView4Graph(graphBuilderArg);
|
||||
case 'imagen3':
|
||||
return await buildImagen3Graph(state, null);
|
||||
return buildImagen3Graph(graphBuilderArg);
|
||||
case 'imagen4':
|
||||
return await buildImagen4Graph(state, null);
|
||||
return buildImagen4Graph(graphBuilderArg);
|
||||
case 'chatgpt-4o':
|
||||
return await buildChatGPT4oGraph(state, null);
|
||||
return await buildChatGPT4oGraph(graphBuilderArg);
|
||||
case 'flux-kontext':
|
||||
return await buildFluxKontextGraph(state, null);
|
||||
return buildFluxKontextGraph(graphBuilderArg);
|
||||
default:
|
||||
assert(false, `No graph builders for base ${base}`);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user