diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageModel.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageModel.tsx index 275772c5d3..e87cff00a1 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageModel.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageModel.tsx @@ -1,7 +1,8 @@ import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; -import { selectBase } from 'features/controlLayers/store/paramsSlice'; +import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice'; +import { areBasesCompatibleForRefImage } from 'features/controlLayers/store/validators'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType'; @@ -21,7 +22,7 @@ type Props = { export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => { const { t } = useTranslation(); - const currentBaseModel = useAppSelector(selectBase); + const mainModelConfig = useAppSelector(selectMainModelConfig); const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels(); const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]); @@ -39,11 +40,9 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => { const getIsDisabled = useCallback( (model: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig): boolean => { - const hasMainModel = Boolean(currentBaseModel); - const hasSameBase = currentBaseModel === model.base; - return !hasMainModel || !hasSameBase; + return !areBasesCompatibleForRefImage(mainModelConfig, model); }, - [currentBaseModel] + [mainModelConfig] ); const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({ @@ -56,7 +55,11 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => { return ( - + { const dispatch = useAppDispatch(); const id = useRefImageIdContext(); const entity = useRefImageEntity(id); + const mainModelConfig = useAppSelector(selectMainModelConfig); const selectedEntityId = useAppSelector(selectSelectedRefEntityId); const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen); const [showWeightDisplay, setShowWeightDisplay] = useState(false); @@ -82,6 +95,10 @@ export const RefImagePreview = memo(() => { }; }, [entity.config]); + const isInvalid = useMemo(() => { + return getGlobalReferenceImageWarnings(entity, mainModelConfig).length > 0; + }, [entity, mainModelConfig]); + const onClick = useCallback(() => { dispatch(refImageSelected({ id })); }, [dispatch, id]); @@ -120,7 +137,7 @@ export const RefImagePreview = memo(() => { flexShrink={0} sx={sx} data-is-open={selectedEntityId === id && isPanelOpen} - data-is-error={!entity.config.model} + data-is-error={isInvalid} data-is-disabled={!entity.isEnabled} role="button" onClick={onClick} @@ -152,18 +169,19 @@ export const RefImagePreview = memo(() => { )} - {!entity.isEnabled ? ( + {!entity.isEnabled && ( - ) : !entity.config.model ? ( + )} + {entity.isEnabled && isInvalid && ( { transform="translateX(-50%) translateY(-50%)" filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))" color="error.500" - boxSize={6} + boxSize={12} as={PiExclamationMarkBold} /> - ) : null} + )} ); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts index 534c9a337a..03ef5404a6 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts @@ -5,7 +5,8 @@ import type { CanvasRegionalGuidanceState, RefImageState, } from 'features/controlLayers/store/types'; -import type { MainModelConfig } from 'services/api/types'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; +import type { AnyModelConfig, MainModelConfig } from 'services/api/types'; const WARNINGS = { UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel', @@ -77,6 +78,27 @@ export const getRegionalGuidanceWarnings = ( return warnings; }; +export const areBasesCompatibleForRefImage = ( + first?: ModelIdentifierField | AnyModelConfig | null, + second?: ModelIdentifierField | AnyModelConfig | null +): boolean => { + if (!first || !second) { + return false; + } + if (first.base !== second.base) { + return false; + } + if ( + first.base === 'flux' && + (first.name.toLowerCase().includes('kontext') || second.name.toLowerCase().includes('kontext')) && + first.key !== second.key + ) { + // FLUX Kontext requires the main model and the reference image model to be the same model + return false; + } + return true; +}; + export const getGlobalReferenceImageWarnings = ( entity: RefImageState, model: MainModelConfig | null | undefined @@ -95,7 +117,7 @@ export const getGlobalReferenceImageWarnings = ( if (!config.model) { // No model selected warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED); - } else if (config.model.base !== model.base) { + } else if (!areBasesCompatibleForRefImage(config.model, model)) { // Supported model architecture but doesn't match warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL); }