diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageModel.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageModel.tsx index 8fb8beb065..275772c5d3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageModel.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageModel.tsx @@ -5,12 +5,18 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType'; - -type RefImageModelConfig = ReturnType[0][number]; +import type { + ChatGPT4oModelConfig, + FLUXKontextModelConfig, + FLUXReduxModelConfig, + IPAdapterModelConfig, +} from 'services/api/types'; type Props = { modelKey: string | null; - onChangeModel: (modelConfig: RefImageModelConfig) => void; + onChangeModel: ( + modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig + ) => void; }; export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => { @@ -20,7 +26,9 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => { const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]); const _onChangeModel = useCallback( - (modelConfig: RefImageModelConfig | null) => { + ( + modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig | null + ) => { if (!modelConfig) { return; } @@ -30,7 +38,7 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => { ); const getIsDisabled = useCallback( - (model: RefImageModelConfig): boolean => { + (model: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig): boolean => { const hasMainModel = Boolean(currentBaseModel); const hasSameBase = currentBaseModel === model.base; return !hasMainModel || !hasSameBase; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageSettings.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageSettings.tsx index 540a9723e0..539ee47f89 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageSettings.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageSettings.tsx @@ -38,7 +38,13 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd'; import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { memo, useCallback, useMemo } from 'react'; -import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types'; +import type { + ChatGPT4oModelConfig, + FLUXKontextModelConfig, + FLUXReduxModelConfig, + ImageDTO, + IPAdapterModelConfig, +} from 'services/api/types'; import { RefImageImage } from './RefImageImage'; @@ -84,7 +90,7 @@ const RefImageSettingsContent = memo(() => { ); const onChangeModel = useCallback( - (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => { + (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig) => { dispatch(refImageModelChanged({ id, modelConfig })); }, [dispatch, id] diff --git a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts index 8c5a64a806..34525692cc 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts @@ -7,7 +7,13 @@ import { clamp } from 'es-toolkit/compat'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; -import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types'; +import type { + ChatGPT4oModelConfig, + FLUXKontextModelConfig, + FLUXReduxModelConfig, + ImageDTO, + IPAdapterModelConfig, +} from 'services/api/types'; import { assert } from 'tsafe'; import type { PartialDeep } from 'type-fest'; @@ -86,7 +92,9 @@ export const refImagesSlice = createSlice({ }, refImageModelChanged: ( state, - action: PayloadActionWithId<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null }> + action: PayloadActionWithId<{ + modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig | null; + }> ) => { const { id, modelConfig } = action.payload; const entity = selectRefImageEntity(state, id); diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index b7467724bc..7f1e992da3 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -69,6 +69,8 @@ export type SigLipModelConfig = S['SigLIPConfig']; export type FLUXReduxModelConfig = S['FluxReduxConfig']; export type ApiModelConfig = S['ApiModelConfig']; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig | ApiModelConfig; +export type FLUXKontextModelConfig = MainModelConfig; +export type ChatGPT4oModelConfig = ApiModelConfig; export type AnyModelConfig = | ControlLoRAModelConfig | LoRAModelConfig @@ -230,7 +232,7 @@ export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FLUXRe return config.type === 'flux_redux'; }; -export const isChatGPT4oModelConfig = (config: AnyModelConfig): config is ApiModelConfig => { +export const isChatGPT4oModelConfig = (config: AnyModelConfig): config is ChatGPT4oModelConfig => { return config.type === 'main' && config.base === 'chatgpt-4o'; }; @@ -246,7 +248,7 @@ export const isFluxKontextApiModelConfig = (config: AnyModelConfig): config is A return config.type === 'main' && config.base === 'flux-kontext'; }; -export const isFluxKontextModelConfig = (config: AnyModelConfig): config is MainModelConfig => { +export const isFluxKontextModelConfig = (config: AnyModelConfig): config is FLUXKontextModelConfig => { return config.type === 'main' && config.base === 'flux' && config.name?.toLowerCase().includes('kontext'); };