feat(ui): add support for reference images for ChatGPT 4o on canvas

This commit is contained in:
psychedelicious
2025-04-30 12:41:41 +10:00
parent 7b446ee40d
commit 56cd839d5b
14 changed files with 229 additions and 105 deletions

View File

@@ -0,0 +1,65 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ApiModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
type Props = {
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => void;
};
/**
 * Model picker for a global (non-regional) reference image. Offers IP Adapter,
 * FLUX Redux and API (e.g. ChatGPT 4o) model configs; options that do not match
 * the currently-selected main model's base are disabled.
 */
export const GlobalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
  const { t } = useTranslation();
  const baseModel = useAppSelector(selectBase);
  const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();

  // Resolve the full config for the currently-selected model key, if any.
  const selectedModel = useMemo(() => {
    return modelConfigs.find((config) => config.key === modelKey);
  }, [modelConfigs, modelKey]);

  // Ignore the combobox's "cleared" (null) state; only propagate real selections.
  const handleChangeModel = useCallback(
    (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null) => {
      if (modelConfig) {
        onChangeModel(modelConfig);
      }
    },
    [onChangeModel]
  );

  // A model is selectable only when a main model is chosen and the bases match.
  const getIsDisabled = useCallback(
    (model: AnyModelConfig): boolean => {
      return !baseModel || baseModel !== model.base;
    },
    [baseModel]
  );

  const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
    modelConfigs,
    onChange: handleChangeModel,
    selectedModel,
    getIsDisabled,
    isLoading,
  });

  // Mark the control invalid when nothing is selected or the selection no longer
  // matches the main model's base.
  const isInvalid = !value || baseModel !== selectedModel?.base;

  return (
    <Tooltip label={selectedModel?.description}>
      <FormControl isInvalid={isInvalid} w="full">
        <Combobox
          options={options}
          placeholder={t('common.placeholderSelectAModel')}
          value={value}
          onChange={onChange}
          noOptionsMessage={noOptionsMessage}
        />
        <NavigateToModelManagerButton />
      </FormControl>
    </Tooltip>
  );
});
GlobalReferenceImageModel.displayName = 'GlobalReferenceImageModel';

View File

@@ -6,6 +6,7 @@ import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/c
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { GlobalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/GlobalReferenceImageModel';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -33,10 +34,9 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
import { IPAdapterModel } from './IPAdapterModel';
const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
createSelector(
@@ -80,7 +80,7 @@ const IPAdapterSettingsContent = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
},
[dispatch, entityIdentifier]
@@ -113,11 +113,7 @@ const IPAdapterSettingsContent = memo(() => {
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<IPAdapterModel
isRegionalGuidance={false}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
<GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}

View File

@@ -5,29 +5,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
import { useRegionalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
type Props = {
isRegionalGuidance: boolean;
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
};
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
const filter = (config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
};
export const RegionalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const filter = useCallback(
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
},
[isRegionalGuidance]
);
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
const [modelConfigs, { isLoading }] = useRegionalReferenceImageModels(filter);
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChangeModel = useCallback(
@@ -73,4 +70,4 @@ export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeMode
);
});
IPAdapterModel.displayName = 'IPAdapterModel';
RegionalReferenceImageModel.displayName = 'RegionalReferenceImageModel';

View File

@@ -7,7 +7,7 @@ import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLI
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
import { RegionalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/RegionalReferenceImageModel';
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
@@ -140,11 +140,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
</Flex>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<IPAdapterModel
isRegionalGuidance={true}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
<RegionalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}

View File

@@ -17,16 +17,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
import {
initialChatGPT4oReferenceImage,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import {
modelConfigsAdapterSelectors,
selectMainModelConfig,
selectModelConfigsQuery,
} from 'services/api/endpoints/models';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
@@ -64,6 +74,35 @@ export const selectDefaultControlAdapter = createSelector(
}
);
/**
 * Selects the default reference image config for a newly-added global reference image.
 * When the selected main model is ChatGPT 4o, a ChatGPT 4o ref image is built (reusing
 * the main model as its identifier); otherwise an IP Adapter config is built, preferring
 * a model whose base matches the current base model.
 */
const selectDefaultRefImageConfig = createSelector(
  selectMainModelConfig,
  selectModelConfigsQuery,
  selectBase,
  (mainModelConfig, modelConfigsQuery, baseModel): CanvasReferenceImageState['ipAdapter'] => {
    if (mainModelConfig?.base === 'chatgpt-4o') {
      // ChatGPT 4o ref images have no real model of their own - use the main model's identifier.
      const refImage = deepClone(initialChatGPT4oReferenceImage);
      refImage.model = zModelIdentifierField.parse(mainModelConfig);
      return refImage;
    }

    // Fall back to an IP Adapter. Prefer a model matching the current base, then any model.
    let ipAdapterModel: IPAdapterModelConfig | null = null;
    const { data } = modelConfigsQuery;
    if (data) {
      const allIPAdapters = modelConfigsAdapterSelectors.selectAll(data).filter(isIPAdapterModelConfig);
      const sameBase = allIPAdapters.filter((m) => (baseModel ? m.base === baseModel : true));
      ipAdapterModel = sameBase[0] ?? allIPAdapters[0] ?? null;
    }

    const ipAdapter = deepClone(initialIPAdapter);
    if (ipAdapterModel) {
      ipAdapter.model = zModelIdentifierField.parse(ipAdapterModel);
      if (ipAdapterModel.base === 'flux') {
        // FLUX-base IP Adapters use the ViT-L CLIP Vision model.
        ipAdapter.clipVisionModel = 'ViT-L';
      }
    }
    return ipAdapter;
  }
);
/**
* Selects the default IP adapter configuration based on the model configurations and the base.
*
@@ -146,11 +185,11 @@ export const useAddRegionalReferenceImage = () => {
export const useAddGlobalReferenceImage = () => {
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const defaultRefImage = useAppSelector(selectDefaultRefImageConfig);
const func = useCallback(() => {
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
const overrides = { ipAdapter: deepClone(defaultRefImage) };
dispatch(referenceImageAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);
}, [defaultRefImage, dispatch]);
return func;
};

View File

@@ -19,7 +19,7 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'control_layer':

View File

@@ -34,9 +34,10 @@ import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/com
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import { merge } from 'lodash-es';
import { isEqual, merge } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type {
ApiModelConfig,
ControlLoRAModelConfig,
ControlNetModelConfig,
FLUXReduxModelConfig,
@@ -76,6 +77,7 @@ import {
getReferenceImageState,
getRegionalGuidanceState,
imageDTOToImageWithDims,
initialChatGPT4oReferenceImage,
initialControlLoRA,
initialControlNet,
initialFLUXRedux,
@@ -644,7 +646,10 @@ export const canvasSlice = createSlice({
referenceImageIPAdapterModelChanged: (
state,
action: PayloadAction<
EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null }, 'reference_image'>
EntityIdentifierPayload<
{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null },
'reference_image'
>
>
) => {
const { entityIdentifier, modelConfig } = action.payload;
@@ -652,14 +657,36 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
const oldModel = entity.ipAdapter.model;
// First set the new model
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
if (!entity.ipAdapter.model) {
return;
}
if (entity.ipAdapter.type === 'ip_adapter' && entity.ipAdapter.model.type === 'flux_redux') {
// Switching from ip_adapter to flux_redux
if (isEqual(oldModel, entity.ipAdapter.model)) {
// Nothing changed, so we don't need to do anything
return;
}
// The type of ref image depends on the model. When the user switches the model, we rebuild the ref image.
// When we switch the model, we keep the image the same, but change the other parameters.
if (entity.ipAdapter.model.base === 'chatgpt-4o') {
// Switching to chatgpt-4o ref image
entity.ipAdapter = {
...initialChatGPT4oReferenceImage,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.model.type === 'flux_redux') {
// Switching to flux_redux
entity.ipAdapter = {
...initialFLUXRedux,
image: entity.ipAdapter.image,
@@ -668,17 +695,13 @@ export const canvasSlice = createSlice({
return;
}
if (entity.ipAdapter.type === 'flux_redux' && entity.ipAdapter.model.type === 'ip_adapter') {
// Switching from flux_redux to ip_adapter
if (entity.ipAdapter.model.type === 'ip_adapter') {
// Switching to ip_adapter
entity.ipAdapter = {
...initialIPAdapter,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.type === 'ip_adapter') {
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
@@ -686,6 +709,7 @@ export const canvasSlice = createSlice({
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
}
return;
}
},
referenceImageIPAdapterCLIPVisionModelChanged: (

View File

@@ -245,6 +245,18 @@ const zFLUXReduxConfig = z.object({
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
// Schema for a ChatGPT 4o reference image config. Distinguished from the other
// ref image configs (ip_adapter, flux_redux) by its literal `type` tag, which is
// the discriminator in the ref image discriminated union.
const zChatGPT4oReferenceImageConfig = z.object({
  type: z.literal('chatgpt_4o_reference_image'),
  // The reference image itself; null until the user selects one.
  image: zImageWithDims.nullable(),
  /**
   * TODO(psyche): Technically there is no model for ChatGPT 4o reference images - it's just a field in the API call.
   * But we use a model drop down to switch between different ref image types, so there needs to be a model here else
   * there will be no way to switch between ref image types.
   */
  model: zServerValidatedModelIdentifierField.nullable(),
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zCanvasEntityBase = z.object({
id: zId,
name: zName,
@@ -254,15 +266,19 @@ const zCanvasEntityBase = z.object({
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
type: z.literal('reference_image'),
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
// This should be named `referenceImage` but we need to keep it as `ipAdapter` for backwards compatibility
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig, zChatGPT4oReferenceImageConfig]),
});
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
export const isIPAdapterConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is IPAdapterConfig =>
export const isIPAdapterConfig = (config: CanvasReferenceImageState['ipAdapter']): config is IPAdapterConfig =>
config.type === 'ip_adapter';
export const isFLUXReduxConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is FLUXReduxConfig =>
export const isFLUXReduxConfig = (config: CanvasReferenceImageState['ipAdapter']): config is FLUXReduxConfig =>
config.type === 'flux_redux';
export const isChatGPT4oReferenceImageConfig = (
config: CanvasReferenceImageState['ipAdapter']
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
export type FillStyle = z.infer<typeof zFillStyle>;

View File

@@ -7,6 +7,7 @@ import type {
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
ChatGPT4oReferenceImageConfig,
ControlLoRAConfig,
ControlNetConfig,
FLUXReduxConfig,
@@ -77,6 +78,11 @@ export const initialFLUXRedux: FLUXReduxConfig = {
model: null,
imageInfluence: 'highest',
};
// Default (empty) ChatGPT 4o reference image config: no image and no model selected yet.
// Callers deep-clone this before mutating (see the ref image default selector/reducer).
export const initialChatGPT4oReferenceImage: ChatGPT4oReferenceImageConfig = {
  type: 'chatgpt_4o_reference_image',
  image: null,
  model: null,
};
export const initialT2IAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
model: null,

View File

@@ -4,7 +4,9 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isChatGPT4oAspectRatioID } from 'features/controlLayers/store/types';
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { ImageField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
CANVAS_OUTPUT_PREFIX,
@@ -13,6 +15,7 @@ import {
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
@@ -21,21 +24,40 @@ const log = logger('system');
export const buildChatGPT4oGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
assert(
generationMode === 'txt2img' || generationMode === 'img2img',
t('toast.gptImageIncompatibleWithInpaintAndOutpaint')
);
assert(generationMode === 'txt2img', t('toast.chatGPT4oIncompatibleGenerationMode'));
log.debug({ generationMode }, 'Building GPT Image graph');
const model = selectMainModelConfig(state);
const canvas = selectCanvasSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const { bbox } = canvas;
const { positivePrompt } = selectPresetModifiedPrompts(state);
assert(model, 'No model found in state');
assert(model.base === 'chatgpt-4o', 'Model is not a FLUX model');
assert(isChatGPT4oAspectRatioID(bbox.aspectRatio.id), 'ChatGPT 4o does not support this aspect ratio');
const validRefImages = canvas.referenceImages.entities
.filter((entity) => entity.isEnabled)
.filter((entity) => isChatGPT4oReferenceImageConfig(entity.ipAdapter))
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
let reference_images: ImageField[] | undefined = undefined;
if (validRefImages.length > 0) {
reference_images = [];
for (const entity of validRefImages) {
assert(entity.ipAdapter.image, 'Image is required for reference image');
reference_images.push({
image_name: entity.ipAdapter.image.image_name,
});
}
}
const is_intermediate = canvasSettings.sendToCanvas;
const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);
@@ -47,6 +69,7 @@ export const buildChatGPT4oGraph = async (state: RootState, manager: CanvasManag
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
positive_prompt: positivePrompt,
aspect_ratio: bbox.aspectRatio.id,
reference_images,
use_cache: false,
is_intermediate,
board,
@@ -57,28 +80,5 @@ export const buildChatGPT4oGraph = async (state: RootState, manager: CanvasManag
};
}
if (generationMode === 'img2img') {
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
const g = new Graph(getPrefixedId('chatgpt_4o_img2img_graph'));
const gptImage = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'chatgpt_4o_edit_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
positive_prompt: positivePrompt,
image: { image_name },
use_cache: false,
is_intermediate,
board,
});
return {
g,
positivePromptFieldIdentifier: { nodeId: gptImage.id, fieldName: 'positive_prompt' },
};
}
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for gpt image');
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for ChatGPT 4o');
};

View File

@@ -58,29 +58,5 @@ export const buildImagen3Graph = async (state: RootState, manager: CanvasManager
};
}
if (generationMode === 'img2img') {
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
const g = new Graph(getPrefixedId('imagen3_img2img_graph'));
const imagen3 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen3_edit_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
base_image: { image_name },
is_intermediate,
board,
});
return {
g,
seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' },
};
}
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for imagen3');
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen3');
};

View File

@@ -9,6 +9,7 @@ import {
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isChatGPT4oModelConfig,
isCLIPEmbedModelConfig,
isCLIPVisionModelConfig,
isCogView4MainModelModelConfig,
@@ -81,7 +82,10 @@ export const useFluxVAEModels = (args?: ModelHookArgs) =>
export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);
export const useSigLipModels = buildModelsHook(isSigLipModelConfig);
export const useFluxReduxModels = buildModelsHook(isFluxReduxModelConfig);
export const useIPAdapterOrFLUXReduxModels = buildModelsHook(
export const useGlobalReferenceImageModels = buildModelsHook(
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config) || isChatGPT4oModelConfig(config)
);
export const useRegionalReferenceImageModels = buildModelsHook(
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config)
);
export const useLLaVAModels = buildModelsHook(isLLaVAModelConfig);

View File

@@ -65,7 +65,7 @@ export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type SigLipModelConfig = S['SigLIPConfig'];
export type FLUXReduxModelConfig = S['FluxReduxConfig'];
type ApiModelConfig = S['ApiModelConfig'];
export type ApiModelConfig = S['ApiModelConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig | ApiModelConfig;
export type AnyModelConfig =
| ControlLoRAModelConfig
@@ -228,6 +228,10 @@ export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FLUXRe
return config.type === 'flux_redux';
};
/**
 * Type guard for ChatGPT 4o model configs: a "main" API model whose base is `chatgpt-4o`.
 */
export const isChatGPT4oModelConfig = (config: AnyModelConfig): config is ApiModelConfig =>
  config.type === 'main' && config.base === 'chatgpt-4o';
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sdxl-refiner';
};