mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-14 01:14:56 -05:00
fix(ui): flux kontext special handlign for ref image models
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import { areBasesCompatibleForRefImage } from 'features/controlLayers/store/validators';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -21,7 +22,7 @@ type Props = {
|
||||
|
||||
export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const mainModelConfig = useAppSelector(selectMainModelConfig);
|
||||
const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
|
||||
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
|
||||
|
||||
@@ -39,11 +40,9 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig): boolean => {
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
const hasSameBase = currentBaseModel === model.base;
|
||||
return !hasMainModel || !hasSameBase;
|
||||
return !areBasesCompatibleForRefImage(mainModelConfig, model);
|
||||
},
|
||||
[currentBaseModel]
|
||||
[mainModelConfig]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||
@@ -56,7 +55,11 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
|
||||
|
||||
return (
|
||||
<Tooltip label={selectedModel?.description}>
|
||||
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full" minW={0}>
|
||||
<FormControl
|
||||
isInvalid={!value || !areBasesCompatibleForRefImage(mainModelConfig, selectedModel)}
|
||||
w="full"
|
||||
minW={0}
|
||||
>
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('common.placeholderSelectAModel')}
|
||||
|
||||
@@ -5,12 +5,14 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { round } from 'es-toolkit/compat';
|
||||
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
|
||||
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
|
||||
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
refImageSelected,
|
||||
selectIsRefImagePanelOpen,
|
||||
selectSelectedRefEntityId,
|
||||
} from 'features/controlLayers/store/refImagesSlice';
|
||||
import { isIPAdapterConfig } from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
@@ -20,7 +22,17 @@ const baseSx: SystemStyleObject = {
|
||||
borderColor: 'invokeBlue.300',
|
||||
},
|
||||
'&[data-is-disabled="true"]': {
|
||||
opacity: 0.4,
|
||||
img: {
|
||||
opacity: 0.4,
|
||||
filter: 'grayscale(100%)',
|
||||
},
|
||||
},
|
||||
'&[data-is-error="true"]': {
|
||||
borderColor: 'error.500',
|
||||
img: {
|
||||
opacity: 0.4,
|
||||
filter: 'grayscale(100%)',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -57,6 +69,7 @@ export const RefImagePreview = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const id = useRefImageIdContext();
|
||||
const entity = useRefImageEntity(id);
|
||||
const mainModelConfig = useAppSelector(selectMainModelConfig);
|
||||
const selectedEntityId = useAppSelector(selectSelectedRefEntityId);
|
||||
const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen);
|
||||
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
|
||||
@@ -82,6 +95,10 @@ export const RefImagePreview = memo(() => {
|
||||
};
|
||||
}, [entity.config]);
|
||||
|
||||
const isInvalid = useMemo(() => {
|
||||
return getGlobalReferenceImageWarnings(entity, mainModelConfig).length > 0;
|
||||
}, [entity, mainModelConfig]);
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(refImageSelected({ id }));
|
||||
}, [dispatch, id]);
|
||||
@@ -120,7 +137,7 @@ export const RefImagePreview = memo(() => {
|
||||
flexShrink={0}
|
||||
sx={sx}
|
||||
data-is-open={selectedEntityId === id && isPanelOpen}
|
||||
data-is-error={!entity.config.model}
|
||||
data-is-error={isInvalid}
|
||||
data-is-disabled={!entity.isEnabled}
|
||||
role="button"
|
||||
onClick={onClick}
|
||||
@@ -152,18 +169,19 @@ export const RefImagePreview = memo(() => {
|
||||
</Text>
|
||||
</Flex>
|
||||
)}
|
||||
{!entity.isEnabled ? (
|
||||
{!entity.isEnabled && (
|
||||
<Icon
|
||||
position="absolute"
|
||||
top="50%"
|
||||
left="50%"
|
||||
transform="translateX(-50%) translateY(-50%)"
|
||||
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
|
||||
color="base.100"
|
||||
boxSize={6}
|
||||
color="base.300"
|
||||
boxSize={8}
|
||||
as={PiEyeSlashBold}
|
||||
/>
|
||||
) : !entity.config.model ? (
|
||||
)}
|
||||
{entity.isEnabled && isInvalid && (
|
||||
<Icon
|
||||
position="absolute"
|
||||
top="50%"
|
||||
@@ -171,10 +189,10 @@ export const RefImagePreview = memo(() => {
|
||||
transform="translateX(-50%) translateY(-50%)"
|
||||
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
|
||||
color="error.500"
|
||||
boxSize={6}
|
||||
boxSize={12}
|
||||
as={PiExclamationMarkBold}
|
||||
/>
|
||||
) : null}
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -5,7 +5,8 @@ import type {
|
||||
CanvasRegionalGuidanceState,
|
||||
RefImageState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { AnyModelConfig, MainModelConfig } from 'services/api/types';
|
||||
|
||||
const WARNINGS = {
|
||||
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
|
||||
@@ -77,6 +78,27 @@ export const getRegionalGuidanceWarnings = (
|
||||
return warnings;
|
||||
};
|
||||
|
||||
export const areBasesCompatibleForRefImage = (
|
||||
first?: ModelIdentifierField | AnyModelConfig | null,
|
||||
second?: ModelIdentifierField | AnyModelConfig | null
|
||||
): boolean => {
|
||||
if (!first || !second) {
|
||||
return false;
|
||||
}
|
||||
if (first.base !== second.base) {
|
||||
return false;
|
||||
}
|
||||
if (
|
||||
first.base === 'flux' &&
|
||||
(first.name.toLowerCase().includes('kontext') || second.name.toLowerCase().includes('kontext')) &&
|
||||
first.key !== second.key
|
||||
) {
|
||||
// FLUX Kontext requires the main model and the reference image model to be the same model
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
export const getGlobalReferenceImageWarnings = (
|
||||
entity: RefImageState,
|
||||
model: MainModelConfig | null | undefined
|
||||
@@ -95,7 +117,7 @@ export const getGlobalReferenceImageWarnings = (
|
||||
if (!config.model) {
|
||||
// No model selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
|
||||
} else if (config.model.base !== model.base) {
|
||||
} else if (!areBasesCompatibleForRefImage(config.model, model)) {
|
||||
// Supported model architecture but doesn't match
|
||||
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user