mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): support for ref images for chatgpt on canvas
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, ApiModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
modelKey: string | null;
|
||||
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => void;
|
||||
};
|
||||
|
||||
export const GlobalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
|
||||
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
|
||||
|
||||
const _onChangeModel = useCallback(
|
||||
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null) => {
|
||||
if (!modelConfig) {
|
||||
return;
|
||||
}
|
||||
onChangeModel(modelConfig);
|
||||
},
|
||||
[onChangeModel]
|
||||
);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: AnyModelConfig): boolean => {
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
const hasSameBase = currentBaseModel === model.base;
|
||||
return !hasMainModel || !hasSameBase;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChangeModel,
|
||||
selectedModel,
|
||||
getIsDisabled,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={selectedModel?.description}>
|
||||
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('common.placeholderSelectAModel')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
<NavigateToModelManagerButton />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
GlobalReferenceImageModel.displayName = 'GlobalReferenceImageModel';
|
||||
@@ -6,6 +6,7 @@ import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/c
|
||||
import { Weight } from 'features/controlLayers/components/common/Weight';
|
||||
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
|
||||
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
|
||||
import { GlobalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/GlobalReferenceImageModel';
|
||||
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
|
||||
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
@@ -33,10 +34,9 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold } from 'react-icons/pi';
|
||||
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
|
||||
import { IPAdapterModel } from './IPAdapterModel';
|
||||
|
||||
const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
|
||||
createSelector(
|
||||
@@ -80,7 +80,7 @@ const IPAdapterSettingsContent = memo(() => {
|
||||
);
|
||||
|
||||
const onChangeModel = useCallback(
|
||||
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
|
||||
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
|
||||
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
|
||||
},
|
||||
[dispatch, entityIdentifier]
|
||||
@@ -113,11 +113,7 @@ const IPAdapterSettingsContent = memo(() => {
|
||||
<CanvasEntitySettingsWrapper>
|
||||
<Flex flexDir="column" gap={2} position="relative" w="full">
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<IPAdapterModel
|
||||
isRegionalGuidance={false}
|
||||
modelKey={ipAdapter.model?.key ?? null}
|
||||
onChangeModel={onChangeModel}
|
||||
/>
|
||||
<GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
|
||||
{ipAdapter.type === 'ip_adapter' && (
|
||||
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
|
||||
)}
|
||||
|
||||
@@ -5,29 +5,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
|
||||
import { useRegionalReferenceImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
isRegionalGuidance: boolean;
|
||||
modelKey: string | null;
|
||||
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
|
||||
};
|
||||
|
||||
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
|
||||
const filter = (config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
|
||||
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
|
||||
if (config.base === 'flux' && config.type === 'ip_adapter') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
export const RegionalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const filter = useCallback(
|
||||
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
|
||||
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
|
||||
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[isRegionalGuidance]
|
||||
);
|
||||
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
|
||||
const [modelConfigs, { isLoading }] = useRegionalReferenceImageModels(filter);
|
||||
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
|
||||
|
||||
const _onChangeModel = useCallback(
|
||||
@@ -73,4 +70,4 @@ export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeMode
|
||||
);
|
||||
});
|
||||
|
||||
IPAdapterModel.displayName = 'IPAdapterModel';
|
||||
RegionalReferenceImageModel.displayName = 'RegionalReferenceImageModel';
|
||||
@@ -7,7 +7,7 @@ import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLI
|
||||
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
|
||||
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
|
||||
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
|
||||
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
|
||||
import { RegionalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/RegionalReferenceImageModel';
|
||||
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
|
||||
@@ -140,11 +140,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
|
||||
</Flex>
|
||||
<Flex flexDir="column" gap={2} position="relative" w="full">
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<IPAdapterModel
|
||||
isRegionalGuidance={true}
|
||||
modelKey={ipAdapter.model?.key ?? null}
|
||||
onChangeModel={onChangeModel}
|
||||
/>
|
||||
<RegionalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
|
||||
{ipAdapter.type === 'ip_adapter' && (
|
||||
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
|
||||
)}
|
||||
|
||||
@@ -17,16 +17,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
|
||||
import type {
|
||||
CanvasEntityIdentifier,
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
ControlLoRAConfig,
|
||||
ControlNetConfig,
|
||||
IPAdapterConfig,
|
||||
T2IAdapterConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
|
||||
import {
|
||||
initialChatGPT4oReferenceImage,
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlLayers/store/util';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { useCallback } from 'react';
|
||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import {
|
||||
modelConfigsAdapterSelectors,
|
||||
selectMainModelConfig,
|
||||
selectModelConfigsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type {
|
||||
ControlLoRAModelConfig,
|
||||
ControlNetModelConfig,
|
||||
@@ -64,6 +74,35 @@ export const selectDefaultControlAdapter = createSelector(
|
||||
}
|
||||
);
|
||||
|
||||
const selectDefaultRefImageConfig = createSelector(
|
||||
selectMainModelConfig,
|
||||
selectModelConfigsQuery,
|
||||
selectBase,
|
||||
(selectedMainModel, query, base): CanvasReferenceImageState['ipAdapter'] => {
|
||||
if (selectedMainModel?.base === 'chatgpt-4o') {
|
||||
const referenceImage = deepClone(initialChatGPT4oReferenceImage);
|
||||
referenceImage.model = zModelIdentifierField.parse(selectedMainModel);
|
||||
return referenceImage;
|
||||
}
|
||||
|
||||
const { data } = query;
|
||||
let model: IPAdapterModelConfig | null = null;
|
||||
if (data) {
|
||||
const modelConfigs = modelConfigsAdapterSelectors.selectAll(data).filter(isIPAdapterModelConfig);
|
||||
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
|
||||
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
|
||||
}
|
||||
const ipAdapter = deepClone(initialIPAdapter);
|
||||
if (model) {
|
||||
ipAdapter.model = zModelIdentifierField.parse(model);
|
||||
if (model.base === 'flux') {
|
||||
ipAdapter.clipVisionModel = 'ViT-L';
|
||||
}
|
||||
}
|
||||
return ipAdapter;
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Selects the default IP adapter configuration based on the model configurations and the base.
|
||||
*
|
||||
@@ -146,11 +185,11 @@ export const useAddRegionalReferenceImage = () => {
|
||||
|
||||
export const useAddGlobalReferenceImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
|
||||
const defaultRefImage = useAppSelector(selectDefaultRefImageConfig);
|
||||
const func = useCallback(() => {
|
||||
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
|
||||
const overrides = { ipAdapter: deepClone(defaultRefImage) };
|
||||
dispatch(referenceImageAdded({ isSelected: true, overrides }));
|
||||
}, [defaultIPAdapter, dispatch]);
|
||||
}, [defaultRefImage, dispatch]);
|
||||
|
||||
return func;
|
||||
};
|
||||
|
||||
@@ -19,7 +19,7 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
|
||||
const isEntityTypeEnabled = useMemo<boolean>(() => {
|
||||
switch (entityType) {
|
||||
case 'reference_image':
|
||||
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
|
||||
return !isSD3 && !isCogView4 && !isImagen3;
|
||||
case 'regional_guidance':
|
||||
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
|
||||
case 'control_layer':
|
||||
|
||||
@@ -34,9 +34,10 @@ import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/com
|
||||
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
|
||||
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { merge } from 'lodash-es';
|
||||
import { isEqual, merge } from 'lodash-es';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import type {
|
||||
ApiModelConfig,
|
||||
ControlLoRAModelConfig,
|
||||
ControlNetModelConfig,
|
||||
FLUXReduxModelConfig,
|
||||
@@ -76,6 +77,7 @@ import {
|
||||
getReferenceImageState,
|
||||
getRegionalGuidanceState,
|
||||
imageDTOToImageWithDims,
|
||||
initialChatGPT4oReferenceImage,
|
||||
initialControlLoRA,
|
||||
initialControlNet,
|
||||
initialFLUXRedux,
|
||||
@@ -644,7 +646,10 @@ export const canvasSlice = createSlice({
|
||||
referenceImageIPAdapterModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null }, 'reference_image'>
|
||||
EntityIdentifierPayload<
|
||||
{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null },
|
||||
'reference_image'
|
||||
>
|
||||
>
|
||||
) => {
|
||||
const { entityIdentifier, modelConfig } = action.payload;
|
||||
@@ -652,14 +657,36 @@ export const canvasSlice = createSlice({
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
|
||||
const oldModel = entity.ipAdapter.model;
|
||||
|
||||
// First set the new model
|
||||
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
|
||||
|
||||
if (!entity.ipAdapter.model) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.type === 'ip_adapter' && entity.ipAdapter.model.type === 'flux_redux') {
|
||||
// Switching from ip_adapter to flux_redux
|
||||
if (isEqual(oldModel, entity.ipAdapter.model)) {
|
||||
// Nothing changed, so we don't need to do anything
|
||||
return;
|
||||
}
|
||||
|
||||
// The type of ref image depends on the model. When the user switches the model, we rebuild the ref image.
|
||||
// When we switch the model, we keep the image the same, but change the other parameters.
|
||||
|
||||
if (entity.ipAdapter.model.base === 'chatgpt-4o') {
|
||||
// Switching to chatgpt-4o ref image
|
||||
entity.ipAdapter = {
|
||||
...initialChatGPT4oReferenceImage,
|
||||
image: entity.ipAdapter.image,
|
||||
model: entity.ipAdapter.model,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.model.type === 'flux_redux') {
|
||||
// Switching to flux_redux
|
||||
entity.ipAdapter = {
|
||||
...initialFLUXRedux,
|
||||
image: entity.ipAdapter.image,
|
||||
@@ -668,17 +695,13 @@ export const canvasSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.type === 'flux_redux' && entity.ipAdapter.model.type === 'ip_adapter') {
|
||||
// Switching from flux_redux to ip_adapter
|
||||
if (entity.ipAdapter.model.type === 'ip_adapter') {
|
||||
// Switching to ip_adapter
|
||||
entity.ipAdapter = {
|
||||
...initialIPAdapter,
|
||||
image: entity.ipAdapter.image,
|
||||
model: entity.ipAdapter.model,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.type === 'ip_adapter') {
|
||||
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
|
||||
if (entity.ipAdapter.model?.base === 'flux') {
|
||||
entity.ipAdapter.clipVisionModel = 'ViT-L';
|
||||
@@ -686,6 +709,7 @@ export const canvasSlice = createSlice({
|
||||
// Fall back to ViT-H (ViT-G would also work)
|
||||
entity.ipAdapter.clipVisionModel = 'ViT-H';
|
||||
}
|
||||
return;
|
||||
}
|
||||
},
|
||||
referenceImageIPAdapterCLIPVisionModelChanged: (
|
||||
|
||||
@@ -245,6 +245,18 @@ const zFLUXReduxConfig = z.object({
|
||||
});
|
||||
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
|
||||
|
||||
const zChatGPT4oReferenceImageConfig = z.object({
|
||||
type: z.literal('chatgpt_4o_reference_image'),
|
||||
image: zImageWithDims.nullable(),
|
||||
/**
|
||||
* TODO(psyche): Technically there is no model for ChatGPT 4o reference images - it's just a field in the API call.
|
||||
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
|
||||
* there will be no way to switch between ref image types.
|
||||
*/
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
});
|
||||
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
|
||||
|
||||
const zCanvasEntityBase = z.object({
|
||||
id: zId,
|
||||
name: zName,
|
||||
@@ -254,15 +266,19 @@ const zCanvasEntityBase = z.object({
|
||||
|
||||
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
|
||||
type: z.literal('reference_image'),
|
||||
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
|
||||
// This should be named `referenceImage` but we need to keep it as `ipAdapter` for backwards compatibility
|
||||
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig, zChatGPT4oReferenceImageConfig]),
|
||||
});
|
||||
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
|
||||
|
||||
export const isIPAdapterConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is IPAdapterConfig =>
|
||||
export const isIPAdapterConfig = (config: CanvasReferenceImageState['ipAdapter']): config is IPAdapterConfig =>
|
||||
config.type === 'ip_adapter';
|
||||
|
||||
export const isFLUXReduxConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is FLUXReduxConfig =>
|
||||
export const isFLUXReduxConfig = (config: CanvasReferenceImageState['ipAdapter']): config is FLUXReduxConfig =>
|
||||
config.type === 'flux_redux';
|
||||
export const isChatGPT4oReferenceImageConfig = (
|
||||
config: CanvasReferenceImageState['ipAdapter']
|
||||
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
|
||||
|
||||
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
|
||||
export type FillStyle = z.infer<typeof zFillStyle>;
|
||||
|
||||
@@ -7,6 +7,7 @@ import type {
|
||||
CanvasRasterLayerState,
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
ChatGPT4oReferenceImageConfig,
|
||||
ControlLoRAConfig,
|
||||
ControlNetConfig,
|
||||
FLUXReduxConfig,
|
||||
@@ -77,6 +78,11 @@ export const initialFLUXRedux: FLUXReduxConfig = {
|
||||
model: null,
|
||||
imageInfluence: 'highest',
|
||||
};
|
||||
export const initialChatGPT4oReferenceImage: ChatGPT4oReferenceImageConfig = {
|
||||
type: 'chatgpt_4o_reference_image',
|
||||
image: null,
|
||||
model: null,
|
||||
};
|
||||
export const initialT2IAdapter: T2IAdapterConfig = {
|
||||
type: 't2i_adapter',
|
||||
model: null,
|
||||
|
||||
@@ -4,7 +4,9 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { isChatGPT4oAspectRatioID } from 'features/controlLayers/store/types';
|
||||
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
CANVAS_OUTPUT_PREFIX,
|
||||
@@ -13,6 +15,7 @@ import {
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderReturn } from 'features/nodes/util/graph/types';
|
||||
import { t } from 'i18next';
|
||||
import { selectMainModelConfig } from 'services/api/endpoints/models';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
@@ -21,21 +24,40 @@ const log = logger('system');
|
||||
export const buildChatGPT4oGraph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
|
||||
const generationMode = await manager.compositor.getGenerationMode();
|
||||
|
||||
assert(
|
||||
generationMode === 'txt2img' || generationMode === 'img2img',
|
||||
t('toast.gptImageIncompatibleWithInpaintAndOutpaint')
|
||||
);
|
||||
assert(generationMode === 'txt2img', t('toast.chatGPT4oIncompatibleGenerationMode'));
|
||||
|
||||
log.debug({ generationMode }, 'Building GPT Image graph');
|
||||
|
||||
const model = selectMainModelConfig(state);
|
||||
|
||||
const canvas = selectCanvasSlice(state);
|
||||
const canvasSettings = selectCanvasSettingsSlice(state);
|
||||
|
||||
const { bbox } = canvas;
|
||||
const { positivePrompt } = selectPresetModifiedPrompts(state);
|
||||
|
||||
assert(model, 'No model found in state');
|
||||
assert(model.base === 'chatgpt-4o', 'Model is not a FLUX model');
|
||||
|
||||
assert(isChatGPT4oAspectRatioID(bbox.aspectRatio.id), 'ChatGPT 4o does not support this aspect ratio');
|
||||
|
||||
const validRefImages = canvas.referenceImages.entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.filter((entity) => isChatGPT4oReferenceImageConfig(entity.ipAdapter))
|
||||
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
|
||||
|
||||
let reference_images: ImageField[] | undefined = undefined;
|
||||
|
||||
if (validRefImages.length > 0) {
|
||||
reference_images = [];
|
||||
for (const entity of validRefImages) {
|
||||
assert(entity.ipAdapter.image, 'Image is required for reference image');
|
||||
reference_images.push({
|
||||
image_name: entity.ipAdapter.image.image_name,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const is_intermediate = canvasSettings.sendToCanvas;
|
||||
const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);
|
||||
|
||||
@@ -47,6 +69,7 @@ export const buildChatGPT4oGraph = async (state: RootState, manager: CanvasManag
|
||||
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
|
||||
positive_prompt: positivePrompt,
|
||||
aspect_ratio: bbox.aspectRatio.id,
|
||||
reference_images,
|
||||
use_cache: false,
|
||||
is_intermediate,
|
||||
board,
|
||||
@@ -57,28 +80,5 @@ export const buildChatGPT4oGraph = async (state: RootState, manager: CanvasManag
|
||||
};
|
||||
}
|
||||
|
||||
if (generationMode === 'img2img') {
|
||||
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
const g = new Graph(getPrefixedId('chatgpt_4o_img2img_graph'));
|
||||
const gptImage = g.addNode({
|
||||
// @ts-expect-error: These nodes are not available in the OSS application
|
||||
type: 'chatgpt_4o_edit_image',
|
||||
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
|
||||
positive_prompt: positivePrompt,
|
||||
image: { image_name },
|
||||
use_cache: false,
|
||||
is_intermediate,
|
||||
board,
|
||||
});
|
||||
return {
|
||||
g,
|
||||
positivePromptFieldIdentifier: { nodeId: gptImage.id, fieldName: 'positive_prompt' },
|
||||
};
|
||||
}
|
||||
|
||||
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for gpt image');
|
||||
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for ChatGPT 4o');
|
||||
};
|
||||
|
||||
@@ -58,29 +58,5 @@ export const buildImagen3Graph = async (state: RootState, manager: CanvasManager
|
||||
};
|
||||
}
|
||||
|
||||
if (generationMode === 'img2img') {
|
||||
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
const g = new Graph(getPrefixedId('imagen3_img2img_graph'));
|
||||
const imagen3 = g.addNode({
|
||||
// @ts-expect-error: These nodes are not available in the OSS application
|
||||
type: 'google_imagen3_edit_image',
|
||||
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
base_image: { image_name },
|
||||
is_intermediate,
|
||||
board,
|
||||
});
|
||||
return {
|
||||
g,
|
||||
seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' },
|
||||
positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' },
|
||||
};
|
||||
}
|
||||
|
||||
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for imagen3');
|
||||
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen3');
|
||||
};
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isChatGPT4oModelConfig,
|
||||
isCLIPEmbedModelConfig,
|
||||
isCLIPVisionModelConfig,
|
||||
isCogView4MainModelModelConfig,
|
||||
@@ -81,7 +82,10 @@ export const useFluxVAEModels = (args?: ModelHookArgs) =>
|
||||
export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);
|
||||
export const useSigLipModels = buildModelsHook(isSigLipModelConfig);
|
||||
export const useFluxReduxModels = buildModelsHook(isFluxReduxModelConfig);
|
||||
export const useIPAdapterOrFLUXReduxModels = buildModelsHook(
|
||||
export const useGlobalReferenceImageModels = buildModelsHook(
|
||||
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config) || isChatGPT4oModelConfig(config)
|
||||
);
|
||||
export const useRegionalReferenceImageModels = buildModelsHook(
|
||||
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config)
|
||||
);
|
||||
export const useLLaVAModels = buildModelsHook(isLLaVAModelConfig);
|
||||
|
||||
@@ -65,7 +65,7 @@ export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||
export type SigLipModelConfig = S['SigLIPConfig'];
|
||||
export type FLUXReduxModelConfig = S['FluxReduxConfig'];
|
||||
type ApiModelConfig = S['ApiModelConfig'];
|
||||
export type ApiModelConfig = S['ApiModelConfig'];
|
||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig | ApiModelConfig;
|
||||
export type AnyModelConfig =
|
||||
| ControlLoRAModelConfig
|
||||
@@ -228,6 +228,10 @@ export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FLUXRe
|
||||
return config.type === 'flux_redux';
|
||||
};
|
||||
|
||||
export const isChatGPT4oModelConfig = (config: AnyModelConfig): config is ApiModelConfig => {
|
||||
return config.type === 'main' && config.base === 'chatgpt-4o';
|
||||
};
|
||||
|
||||
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base !== 'sdxl-refiner';
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user