Maryhipp/chatgpt UI (#7969)

* add GPTimage1 as allowed base model

* fix for non-disabled inpaint layers

* lots of boilerplate for adding gpt-image base model and disabling things along with imagen

* handle gpt-image dimensions

* build graph for gpt-image

* lint

* feat(ui): make chatgpt model naming consistent

* feat(ui): graph builder naming

* feat(ui): disable img2img for imagen3

* feat(ui): more naming

* feat(ui): support presigned url prefetch

* feat(ui): disable neg prompt for chatgpt

* docs(ui): update docstring

* feat(ui): fix graph building issues for chatgpt

* fix(ui): node ids for chatgpt/imagen

* chore(ui): typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Mary Hipp Rogers
2025-04-29 09:38:03 -04:00
committed by GitHub
parent 13d44f47ce
commit 17027c4070
31 changed files with 282 additions and 443 deletions

View File

@@ -1322,7 +1322,7 @@
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
"unableToCopyDesc_theseSteps": "these steps",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"image3IncompatibleWithInpaintAndOutpaint": "Imagen3 does not support Inpainting or Outpainting. Use other models for these tasks.",
"imagen3IncompatibleGenerationMode": "Imagen3 only supports Text to Image. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished"

View File

@@ -6,6 +6,7 @@ import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
@@ -51,6 +52,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return await buildCogView4Graph(state, manager);
case 'imagen3':
return await buildImagen3Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}
@@ -76,15 +79,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch(
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
'canvas',
destination
)
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {

View File

@@ -20,15 +20,15 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch(
const batchConfig = prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
'upscaling',
'gallery'
);
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
try {

View File

@@ -24,6 +24,7 @@ export const CanvasAddEntityButtons = memo(() => {
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
@@ -52,6 +53,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addInpaintMask}
isDisabled={!isInpaintLayerEnabled}
>
{t('controlLayers.inpaintMask')}
</Button>

View File

@@ -25,6 +25,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
return (
<Menu>
@@ -46,7 +47,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.regional')}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask} isDisabled={!isInpaintLayerEnabled}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={!isRegionalGuidanceEnabled}>

View File

@@ -1,5 +1,10 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsCogView4, selectIsImagen3, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
import type { Equals } from 'tsafe';
@@ -9,23 +14,24 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
return !isSD3 && !isCogView4 && !isImagen3;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4 && !isImagen3;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'inpaint_mask':
return !isImagen3;
return !isImagen3 && !isChatGPT4o;
case 'raster_layer':
return !isImagen3;
return !isImagen3 && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4, isImagen3]);
}, [entityType, isSD3, isCogView4, isImagen3, isChatGPT4o]);
return isEntityTypeEnabled;
};

View File

@@ -112,7 +112,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
return;
}
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url));
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url, true));
if (imageElementResult.isErr()) {
// Image loading failed (e.g. the URL to the "physical" image is invalid)
this.onFailedToLoadImage(t('controlLayers.unableToLoadImage', 'Unable to load image'));

View File

@@ -235,8 +235,8 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
if (tool !== 'bbox') {
return NO_ANCHORS;
}
if (model?.base === 'imagen3') {
// The bbox is not resizable in imagen3 mode
if (model?.base === 'imagen3' || model?.base === 'chatgpt-4o') {
// The bbox is not resizable in these modes
return NO_ANCHORS;
}
return ALL_ANCHORS;

View File

@@ -476,15 +476,24 @@ export function getImageDataTransparency(imageData: ImageData): Transparency {
/**
* Loads an image from a URL and returns a promise that resolves with the loaded image element.
* @param src The image source URL
* @param fetchUrlFirst Whether to fetch the image's URL first, assuming the provided `src` will redirect to a different URL. This addresses an issue where CORS headers are dropped during a redirect.
* @returns A promise that resolves with the loaded image element
*/
export function loadImage(src: string): Promise<HTMLImageElement> {
export async function loadImage(src: string, fetchUrlFirst?: boolean): Promise<HTMLImageElement> {
const authToken = $authToken.get();
let url = src;
if (authToken && fetchUrlFirst) {
const response = await fetch(`${src}?url_only=true`, { credentials: 'include' });
const data = await response.json();
url = data.url;
}
return new Promise((resolve, reject) => {
const imageElement = new Image();
imageElement.onload = () => resolve(imageElement);
imageElement.onerror = (error) => reject(error);
imageElement.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
imageElement.src = src;
imageElement.src = url;
});
}

View File

@@ -67,7 +67,7 @@ import type {
IPMethodV2,
T2IAdapterConfig,
} from './types';
import { getEntityIdentifier, isImagen3AspectRatioID, isRenderableEntity } from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagen3AspectRatioID, isRenderableEntity } from './types';
import {
converters,
getControlLayerState,
@@ -1232,6 +1232,20 @@ export const canvasSlice = createSlice({
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else if (state.bbox.modelBase === 'chatgpt-4o' && isChatGPT4oAspectRatioID(id)) {
// gpt-image has specific output sizes that do not exactly match the nominal aspect ratios, so they need special handling.
if (id === '3:2') {
state.bbox.rect.width = 1536;
state.bbox.rect.height = 1024;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
} else if (id === '2:3') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1536;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
@@ -1704,7 +1718,7 @@ export const canvasSlice = createSlice({
const base = model?.base;
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
state.bbox.modelBase = base;
if (base === 'imagen3') {
if (base === 'imagen3' || base === 'chatgpt-4o') {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = 1;
state.bbox.aspectRatio.id = '1:1';
@@ -1843,7 +1857,7 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
};
const syncScaledSize = (state: CanvasState) => {
if (state.bbox.modelBase === 'imagen3') {
if (state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'chatgpt-4o') {
// Imagen3 and ChatGPT 4o have fixed sizes. Scaled bbox is not supported.
return;
}

View File

@@ -381,6 +381,7 @@ export const selectIsFLUX = createParamsSelector((params) => params.model?.base
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);

View File

@@ -388,9 +388,15 @@ export type StagingAreaImage = {
};
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
zImagen3AspectRatioID.safeParse(v).success;
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
zChatGPT4oAspectRatioID.safeParse(v).success;
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;

View File

@@ -4,7 +4,7 @@ import type { PersistConfig, RootState } from 'app/store/store';
import { z } from 'zod';
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export const isSeedBehaviour = (v: unknown): v is SeedBehaviour => zSeedBehaviour.safeParse(v).success;
export interface DynamicPromptsState {

View File

@@ -123,6 +123,8 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
'metadata_to_t2i_adapters',
'google_imagen3_generate',
'google_imagen3_edit',
'chatgpt_create_image',
'chatgpt_edit_image',
];
export const selectHasUnpublishableNodes = createSelector(selectNodes, (nodes) => {

View File

@@ -2,35 +2,57 @@ import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import randomInt from 'common/util/randomInt';
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { range } from 'lodash-es';
import type { components } from 'services/api/schema';
import type { Batch, EnqueueBatchArg } from 'services/api/types';
import { assert } from 'tsafe';
export const prepareLinearUIBatch = (
state: RootState,
g: Graph,
prepend: boolean,
seedFieldIdentifier: FieldIdentifier,
positivePromptFieldIdentifier: FieldIdentifier,
origin: 'canvas' | 'workflows' | 'upscaling',
destination: 'canvas' | 'gallery'
): EnqueueBatchArg => {
/**
 * Expands the prompt list to the full batch size when required.
 *
 * Normally, the seed behaviour implicitly determines the batch size. But when we use models without seeds (like
 * ChatGPT 4o) in conjunction with the per-prompt seed behaviour, we lose out on that implicit batch size. To rectify
 * this, we need to create a batch of the right size by repeating the prompts.
 */
const getExtendedPrompts = (arg: {
  seedBehaviour: SeedBehaviour;
  iterations: number;
  prompts: string[];
  model: ModelIdentifierField;
}): string[] => {
  const { seedBehaviour, iterations, prompts, model } = arg;
  const needsRepetition = seedBehaviour === 'PER_PROMPT' || model.base === 'chatgpt-4o';
  if (!needsRepetition) {
    return prompts;
  }
  // Repeat the full prompt list once per iteration to reach the desired batch size.
  return range(iterations).flatMap(() => prompts);
};
export const prepareLinearUIBatch = (arg: {
state: RootState;
g: Graph;
prepend: boolean;
seedFieldIdentifier?: FieldIdentifier;
positivePromptFieldIdentifier: FieldIdentifier;
origin: 'canvas' | 'workflows' | 'upscaling';
destination: 'canvas' | 'gallery';
}): EnqueueBatchArg => {
const { state, g, prepend, seedFieldIdentifier, positivePromptFieldIdentifier, origin, destination } = arg;
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;
const { prompts, seedBehaviour } = state.dynamicPrompts;
assert(model, 'No model found in state when preparing batch');
const data: Batch['data'] = [];
const firstBatchDatumList: components['schemas']['BatchDatum'][] = [];
const secondBatchDatumList: components['schemas']['BatchDatum'][] = [];
// add seeds first to ensure the output order groups the prompts
if (seedBehaviour === 'PER_PROMPT') {
if (seedFieldIdentifier && seedBehaviour === 'PER_PROMPT') {
const seeds = generateSeeds({
count: prompts.length * iterations,
// Imagen3's support for seeded generation is iffy, we are just not going too use in in linear UI generations.
// Imagen3's support for seeded generation is iffy, we are just not going to use it in linear UI generations.
start:
model?.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
});
firstBatchDatumList.push({
@@ -46,13 +68,13 @@ export const prepareLinearUIBatch = (
field_name: 'seed',
items: seeds,
});
} else {
} else if (seedFieldIdentifier && seedBehaviour === 'PER_ITERATION') {
// seedBehaviour = SeedBehaviour.PerRun
const seeds = generateSeeds({
count: iterations,
// Imagen3's support for seeded generation is iffy, we are just not going to use it in linear UI generations.
start:
model?.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
});
secondBatchDatumList.push({
@@ -71,7 +93,7 @@ export const prepareLinearUIBatch = (
data.push(secondBatchDatumList);
}
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
const extendedPrompts = getExtendedPrompts({ seedBehaviour, iterations, prompts, model });
// zipped batch of prompts
firstBatchDatumList.push({
@@ -88,7 +110,7 @@ export const prepareLinearUIBatch = (
items: extendedPrompts,
});
if (shouldConcatPrompts && model?.base === 'sdxl') {
if (shouldConcatPrompts && model.base === 'sdxl') {
firstBatchDatumList.push({
node_path: positivePromptFieldIdentifier.nodeId,
field_name: 'style',

View File

@@ -0,0 +1,84 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isChatGPT4oAspectRatioID } from 'features/controlLayers/store/types';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
CANVAS_OUTPUT_PREFIX,
getBoardField,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
/**
 * Builds a graph for ChatGPT 4o (GPT Image) generation.
 *
 * Only txt2img and img2img generation modes are supported; inpaint/outpaint assert early.
 * Returns the graph plus the identifier of the positive prompt field (no seed field — this model is unseeded).
 */
export const buildChatGPT4oGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
  const generationMode = await manager.compositor.getGenerationMode();

  // GPT Image cannot do masked inpaint/outpaint - bail out with a user-facing toast message.
  assert(
    generationMode === 'txt2img' || generationMode === 'img2img',
    t('toast.gptImageIncompatibleWithInpaintAndOutpaint')
  );

  log.debug({ generationMode }, 'Building GPT Image graph');

  const { bbox } = selectCanvasSlice(state);
  const canvasSettings = selectCanvasSettingsSlice(state);
  const { positivePrompt } = selectPresetModifiedPrompts(state);

  assert(isChatGPT4oAspectRatioID(bbox.aspectRatio.id), 'ChatGPT 4o does not support this aspect ratio');

  // When sending to canvas, the output stays intermediate (staged); otherwise it lands on a gallery board.
  const is_intermediate = canvasSettings.sendToCanvas;
  const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);

  if (generationMode === 'txt2img') {
    const g = new Graph(getPrefixedId('chatgpt_4o_txt2img_graph'));
    const outputNode = g.addNode({
      // @ts-expect-error: These nodes are not available in the OSS application
      type: 'chatgpt_4o_generate_image',
      id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
      positive_prompt: positivePrompt,
      aspect_ratio: bbox.aspectRatio.id,
      use_cache: false,
      is_intermediate,
      board,
    });
    return {
      g,
      positivePromptFieldIdentifier: { nodeId: outputNode.id, fieldName: 'positive_prompt' },
    };
  }

  if (generationMode === 'img2img') {
    // Composite the visible raster layers within the bbox into a single image for the edit endpoint.
    const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
    const { image_name } = await manager.compositor.getCompositeImageDTO(rasterAdapters, bbox.rect, {
      is_intermediate: true,
      silent: true,
    });
    const g = new Graph(getPrefixedId('chatgpt_4o_img2img_graph'));
    const outputNode = g.addNode({
      // @ts-expect-error: These nodes are not available in the OSS application
      type: 'chatgpt_4o_edit_image',
      id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
      positive_prompt: positivePrompt,
      image: { image_name },
      use_cache: false,
      is_intermediate,
      board,
    });
    return {
      g,
      positivePromptFieldIdentifier: { nodeId: outputNode.id, fieldName: 'positive_prompt' },
    };
  }

  // Exhaustiveness check - the assert above narrowed generationMode to the two handled cases.
  assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for gpt image');
};

View File

@@ -6,7 +6,6 @@ import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSe
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
@@ -20,7 +19,7 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -28,10 +27,7 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildCogView4Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; seedFieldIdentifier: FieldIdentifier; positivePromptFieldIdentifier: FieldIdentifier }> => {
export const buildCogView4Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building CogView4 graph');

View File

@@ -5,7 +5,6 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill';
import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
import { addFLUXReduxes } from 'features/nodes/util/graph/generation/addFLUXRedux';
@@ -23,7 +22,7 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
@@ -35,10 +34,7 @@ import { addIPAdapters } from './addIPAdapters';
const log = logger('system');
export const buildFLUXGraph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; seedFieldIdentifier: FieldIdentifier; positivePromptFieldIdentifier: FieldIdentifier }> => {
export const buildFLUXGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building FLUX graph');

View File

@@ -5,29 +5,23 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isImagen3AspectRatioID } from 'features/controlLayers/store/types';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
CANVAS_OUTPUT_PREFIX,
getBoardField,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildImagen3Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; seedFieldIdentifier: FieldIdentifier; positivePromptFieldIdentifier: FieldIdentifier }> => {
export const buildImagen3Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
assert(
generationMode === 'txt2img' || generationMode === 'img2img',
t('toast.image3IncompatibleWithInpaintAndOutpaint')
);
assert(generationMode === 'txt2img', t('toast.imagen3IncompatibleGenerationMode'));
log.debug({ generationMode }, 'Building Imagen3 graph');
@@ -46,7 +40,7 @@ export const buildImagen3Graph = async (
const g = new Graph(getPrefixedId('imagen3_txt2img_graph'));
const imagen3 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen3_generate',
type: 'google_imagen3_generate_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
@@ -73,7 +67,7 @@ export const buildImagen3Graph = async (
const g = new Graph(getPrefixedId('imagen3_img2img_graph'));
const imagen3 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen3_edit',
type: 'google_imagen3_edit_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,

View File

@@ -5,7 +5,6 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
@@ -24,7 +23,7 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -34,10 +33,7 @@ import { addRegions } from './addRegions';
const log = logger('system');
export const buildSD1Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; seedFieldIdentifier: FieldIdentifier; positivePromptFieldIdentifier: FieldIdentifier }> => {
export const buildSD1Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SD1/SD2 graph');

View File

@@ -5,7 +5,6 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
@@ -19,7 +18,7 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -27,10 +26,7 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildSD3Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; seedFieldIdentifier: FieldIdentifier; positivePromptFieldIdentifier: FieldIdentifier }> => {
export const buildSD3Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SD3 graph');

View File

@@ -5,7 +5,6 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
@@ -24,7 +23,7 @@ import {
getSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -34,10 +33,7 @@ import { addRegions } from './addRegions';
const log = logger('system');
export const buildSDXLGraph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; seedFieldIdentifier: FieldIdentifier; positivePromptFieldIdentifier: FieldIdentifier }> => {
export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SDXL graph');

View File

@@ -1,3 +1,6 @@
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
export type ImageOutputNodes =
| 'l2i'
| 'img_nsfw'
@@ -23,3 +26,9 @@ export type MainModelLoaderNodes =
| 'cogview4_model_loader';
export type VaeSourceNodes = 'seamless' | 'vae_loader';
export type GraphBuilderReturn = {
g: Graph;
seedFieldIdentifier?: FieldIdentifier;
positivePromptFieldIdentifier: FieldIdentifier;
};

View File

@@ -3,9 +3,14 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsChatGTP4o, selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
import { isAspectRatioID, zAspectRatioID, zImagen3AspectRatioID } from 'features/controlLayers/store/types';
import {
isAspectRatioID,
zAspectRatioID,
zChatGPT4oAspectRatioID,
zImagen3AspectRatioID,
} from 'features/controlLayers/store/types';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -17,13 +22,19 @@ export const BboxAspectRatioSelect = memo(() => {
const id = useAppSelector(selectAspectRatioID);
const isStaging = useAppSelector(selectIsStaging);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const options = useMemo(() => {
if (!isImagen3) {
return zAspectRatioID.options;
// Imagen3 and ChatGPT4o have different aspect ratio options, and do not support freeform sizes
if (isImagen3) {
return zImagen3AspectRatioID.options;
}
return zImagen3AspectRatioID.options;
}, [isImagen3]);
if (isChatGPT4o) {
return zChatGPT4oAspectRatioID.options;
}
// All other models
return zAspectRatioID.options;
}, [isImagen3, isChatGPT4o]);
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
(e) => {

View File

@@ -1,9 +1,10 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsChatGTP4o, selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
export const useIsBboxSizeLocked = () => {
const isStaging = useAppSelector(selectIsStaging);
const isImagen3 = useAppSelector(selectIsImagen3);
return isImagen3 || isStaging;
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
return isImagen3 || isChatGPT4o || isStaging;
};

View File

@@ -1,6 +1,6 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { createParamsSelector, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { createParamsSelector, selectIsChatGTP4o, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { ParamNegativePrompt } from 'features/parameters/components/Core/ParamNegativePrompt';
import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPositivePrompt';
import { ParamSDXLNegativeStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt';
@@ -16,11 +16,12 @@ const selectWithStylePrompts = createParamsSelector((params) => {
export const Prompts = memo(() => {
const withStylePrompts = useAppSelector(selectWithStylePrompts);
const isFLUX = useAppSelector(selectIsFLUX);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
return (
<Flex flexDir="column" gap={2}>
<ParamPositivePrompt />
{withStylePrompts && <ParamSDXLPositiveStylePrompt />}
{!isFLUX && <ParamNegativePrompt />}
{!isFLUX && !isChatGPT4o && <ParamNegativePrompt />}
{withStylePrompts && <ParamSDXLNegativeStylePrompt />}
</Flex>
);

View File

@@ -19,6 +19,7 @@ export const getOptimalDimension = (base?: BaseModelType | null): number => {
case 'sd-3':
case 'cogview4':
case 'imagen3':
case 'chatgpt-4o':
default:
return 1024;
}
@@ -44,6 +45,7 @@ export const getGridSize = (base?: BaseModelType | null): number => {
case 'sd-2':
case 'sdxl':
case 'imagen3':
case 'chatgpt-4o':
default:
return 8;
}

View File

@@ -4,7 +4,13 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectIsCogView4, selectIsFLUX, selectIsImagen3, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsFLUX,
selectIsImagen3,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect';
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
@@ -34,6 +40,11 @@ export const GenerationSettingsAccordion = memo(() => {
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isApiModel = useMemo(() => {
return isImagen3 || isChatGPT4o;
}, [isImagen3, isChatGPT4o]);
const isUpscaling = useMemo(() => {
return activeTabName === 'upscaling';
@@ -44,7 +55,7 @@ export const GenerationSettingsAccordion = memo(() => {
const enabledLoRAsCount = loras.loras.filter((l) => l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
const accordionBadges =
modelConfig?.base === 'imagen3'
modelConfig?.base === 'imagen3' || modelConfig?.base === 'chatgpt-4o'
? [modelConfig.name]
: modelConfig
? [modelConfig.name, modelConfig.base]
@@ -71,12 +82,12 @@ export const GenerationSettingsAccordion = memo(() => {
onToggle={onToggleAccordion}
>
<Box px={4} pt={4} data-testid="generation-accordion">
<Flex gap={4} flexDir="column" pb={isImagen3 ? 4 : 0}>
<Flex gap={4} flexDir="column" pb={isApiModel ? 4 : 0}>
<MainModelPicker />
{!isImagen3 && <LoRASelect />}
{!isImagen3 && <LoRAList />}
{!isApiModel && <LoRASelect />}
{!isApiModel && <LoRAList />}
</Flex>
{!isImagen3 && (
{!isApiModel && (
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} flexDir="column" pb={4}>
<FormControlGroup formLabelProps={formLabelProps}>

View File

@@ -4,6 +4,7 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectIsChatGTP4o,
selectIsFLUX,
selectIsImagen3,
selectIsSD3,
@@ -18,7 +19,7 @@ import { BboxSettings } from 'features/parameters/components/Bbox/BboxSettings';
import { ParamSeed } from 'features/parameters/components/Seed/ParamSeed';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const selectBadges = createMemoizedSelector([selectCanvasSlice, selectParamsSlice], (canvas, params) => {
@@ -65,6 +66,11 @@ export const ImageSettingsAccordion = memo(() => {
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isApiModel = useMemo(() => {
return isImagen3 || isChatGPT4o;
}, [isImagen3, isChatGPT4o]);
return (
<StandaloneAccordion
@@ -76,15 +82,15 @@ export const ImageSettingsAccordion = memo(() => {
<Flex
px={4}
pt={4}
pb={isImagen3 ? 4 : 0}
pb={isApiModel ? 4 : 0}
w="full"
h="full"
flexDir="column"
data-testid="image-settings-accordion"
>
<BboxSettings />
{!isImagen3 && <ParamSeed py={3} />}
{!isImagen3 && (
{!isApiModel && <ParamSeed py={3} />}
{!isApiModel && (
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} pb={4} flexDir="column">
{(isFLUX || isSD3) && <ParamOptimizedDenoisingToggle />}

View File

@@ -2,7 +2,12 @@ import { Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import { selectIsCogView4, selectIsImagen3, selectIsSDXL } from 'features/controlLayers/store/paramsSlice';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsSDXL,
} from 'features/controlLayers/store/paramsSlice';
import { Prompts } from 'features/parameters/components/Prompts/Prompts';
import { AdvancedSettingsAccordion } from 'features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion';
import { CompositingSettingsAccordion } from 'features/settingsAccordions/components/CompositingSettingsAccordion/CompositingSettingsAccordion';
@@ -14,7 +19,7 @@ import { StylePresetMenuTrigger } from 'features/stylePresets/components/StylePr
import { $isStylePresetsMenuOpen } from 'features/stylePresets/store/stylePresetSlice';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo } from 'react';
import { memo, useMemo } from 'react';
const overlayScrollbarsStyles: CSSProperties = {
height: '100%',
@@ -25,8 +30,13 @@ const ParametersPanelTextToImage = () => {
const isSDXL = useAppSelector(selectIsSDXL);
const isCogview4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isStylePresetsMenuOpen = useStore($isStylePresetsMenuOpen);
const isApiModel = useMemo(() => {
return isImagen3 || isChatGPT4o;
}, [isImagen3, isChatGPT4o]);
return (
<Flex w="full" h="full" flexDir="column" gap={2}>
<StylePresetMenuTrigger />
@@ -44,9 +54,9 @@ const ParametersPanelTextToImage = () => {
<Prompts />
<ImageSettingsAccordion />
<GenerationSettingsAccordion />
{!isImagen3 && <CompositingSettingsAccordion />}
{!isApiModel && <CompositingSettingsAccordion />}
{isSDXL && <RefinerSettingsAccordion />}
{!isCogview4 && !isImagen3 && <AdvancedSettingsAccordion />}
{!isCogview4 && !isApiModel && <AdvancedSettingsAccordion />}
</Flex>
</OverlayScrollbarsComponent>
</Box>

File diff suppressed because it is too large Load Diff