fix(ui): flux kontext special handlign for ref image models

This commit is contained in:
psychedelicious
2025-07-07 18:42:41 +10:00
parent 2e8db3cce3
commit 702cb2cb1e
3 changed files with 60 additions and 17 deletions

View File

@@ -1,7 +1,8 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { areBasesCompatibleForRefImage } from 'features/controlLayers/store/validators';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
@@ -21,7 +22,7 @@ type Props = {
export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const mainModelConfig = useAppSelector(selectMainModelConfig);
const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
@@ -39,11 +40,9 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const getIsDisabled = useCallback(
(model: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig): boolean => {
const hasMainModel = Boolean(currentBaseModel);
const hasSameBase = currentBaseModel === model.base;
return !hasMainModel || !hasSameBase;
return !areBasesCompatibleForRefImage(mainModelConfig, model);
},
[currentBaseModel]
[mainModelConfig]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
@@ -56,7 +55,11 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
return (
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full" minW={0}>
<FormControl
isInvalid={!value || !areBasesCompatibleForRefImage(mainModelConfig, selectedModel)}
w="full"
minW={0}
>
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}

View File

@@ -5,12 +5,14 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { round } from 'es-toolkit/compat';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import {
refImageSelected,
selectIsRefImagePanelOpen,
selectSelectedRefEntityId,
} from 'features/controlLayers/store/refImagesSlice';
import { isIPAdapterConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
@@ -20,7 +22,17 @@ const baseSx: SystemStyleObject = {
borderColor: 'invokeBlue.300',
},
'&[data-is-disabled="true"]': {
opacity: 0.4,
img: {
opacity: 0.4,
filter: 'grayscale(100%)',
},
},
'&[data-is-error="true"]': {
borderColor: 'error.500',
img: {
opacity: 0.4,
filter: 'grayscale(100%)',
},
},
};
@@ -57,6 +69,7 @@ export const RefImagePreview = memo(() => {
const dispatch = useAppDispatch();
const id = useRefImageIdContext();
const entity = useRefImageEntity(id);
const mainModelConfig = useAppSelector(selectMainModelConfig);
const selectedEntityId = useAppSelector(selectSelectedRefEntityId);
const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen);
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
@@ -82,6 +95,10 @@ export const RefImagePreview = memo(() => {
};
}, [entity.config]);
const isInvalid = useMemo(() => {
return getGlobalReferenceImageWarnings(entity, mainModelConfig).length > 0;
}, [entity, mainModelConfig]);
const onClick = useCallback(() => {
dispatch(refImageSelected({ id }));
}, [dispatch, id]);
@@ -120,7 +137,7 @@ export const RefImagePreview = memo(() => {
flexShrink={0}
sx={sx}
data-is-open={selectedEntityId === id && isPanelOpen}
data-is-error={!entity.config.model}
data-is-error={isInvalid}
data-is-disabled={!entity.isEnabled}
role="button"
onClick={onClick}
@@ -152,18 +169,19 @@ export const RefImagePreview = memo(() => {
</Text>
</Flex>
)}
{!entity.isEnabled ? (
{!entity.isEnabled && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="base.100"
boxSize={6}
color="base.300"
boxSize={8}
as={PiEyeSlashBold}
/>
) : !entity.config.model ? (
)}
{entity.isEnabled && isInvalid && (
<Icon
position="absolute"
top="50%"
@@ -171,10 +189,10 @@ export const RefImagePreview = memo(() => {
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="error.500"
boxSize={6}
boxSize={12}
as={PiExclamationMarkBold}
/>
) : null}
)}
</Flex>
);
});

View File

@@ -5,7 +5,8 @@ import type {
CanvasRegionalGuidanceState,
RefImageState,
} from 'features/controlLayers/store/types';
import type { MainModelConfig } from 'services/api/types';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { AnyModelConfig, MainModelConfig } from 'services/api/types';
const WARNINGS = {
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
@@ -77,6 +78,27 @@ export const getRegionalGuidanceWarnings = (
return warnings;
};
export const areBasesCompatibleForRefImage = (
first?: ModelIdentifierField | AnyModelConfig | null,
second?: ModelIdentifierField | AnyModelConfig | null
): boolean => {
if (!first || !second) {
return false;
}
if (first.base !== second.base) {
return false;
}
if (
first.base === 'flux' &&
(first.name.toLowerCase().includes('kontext') || second.name.toLowerCase().includes('kontext')) &&
first.key !== second.key
) {
// FLUX Kontext requires the main model and the reference image model to be the same model
return false;
}
return true;
};
export const getGlobalReferenceImageWarnings = (
entity: RefImageState,
model: MainModelConfig | null | undefined
@@ -95,7 +117,7 @@ export const getGlobalReferenceImageWarnings = (
if (!config.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (config.model.base !== model.base) {
} else if (!areBasesCompatibleForRefImage(config.model, model)) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}