mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 23:08:05 -05:00
feat(ui): add IP adapters to FLUX in linear UI
This commit is contained in:
@@ -81,7 +81,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ui_order=-1,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
)
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G", "ViT-L"] = InputField(
|
||||
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
|
||||
default="ViT-H",
|
||||
ui_order=2,
|
||||
|
||||
@@ -34,7 +34,6 @@ export const CanvasAddEntityButtons = memo(() => {
|
||||
justifyContent="flex-start"
|
||||
leftIcon={<PiPlusBold />}
|
||||
onClick={addGlobalReferenceImage}
|
||||
isDisabled={isFLUX}
|
||||
>
|
||||
{t('controlLayers.globalReferenceImage')}
|
||||
</Button>
|
||||
|
||||
@@ -40,7 +40,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
|
||||
/>
|
||||
<MenuList>
|
||||
<MenuGroup title={t('controlLayers.global')}>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={isFLUX}>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage}>
|
||||
{t('controlLayers.globalReferenceImage')}
|
||||
</MenuItem>
|
||||
</MenuGroup>
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectBase, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
|
||||
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -11,9 +11,13 @@ import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, IPAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
|
||||
const FLUX_CLIP_VISION = 'ViT-L';
|
||||
|
||||
const CLIP_VISION_OPTIONS = [
|
||||
{ label: 'ViT-H', value: 'ViT-H' },
|
||||
{ label: 'ViT-G', value: 'ViT-G' },
|
||||
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
|
||||
];
|
||||
|
||||
type Props = {
|
||||
@@ -47,6 +51,8 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
|
||||
[onChangeCLIPVisionModel]
|
||||
);
|
||||
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: AnyModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === model.base;
|
||||
@@ -64,10 +70,20 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
const clipVisionModelValue = useMemo(
|
||||
() => CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel),
|
||||
[clipVisionModel]
|
||||
);
|
||||
const clipVisionOptions = useMemo(() => {
|
||||
if (isFLUX) {
|
||||
return CLIP_VISION_OPTIONS.map((option) => ({ ...option, isDisabled: option.value !== FLUX_CLIP_VISION }));
|
||||
} else {
|
||||
return CLIP_VISION_OPTIONS;
|
||||
}
|
||||
}, [isFLUX]);
|
||||
|
||||
const clipVisionModelValue = useMemo(() => {
|
||||
if (isFLUX) {
|
||||
return CLIP_VISION_OPTIONS.find((o) => o.value === FLUX_CLIP_VISION);
|
||||
}
|
||||
return CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel);
|
||||
}, [clipVisionModel, isFLUX]);
|
||||
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
@@ -85,7 +101,7 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
|
||||
{selectedModel?.format === 'checkpoint' && (
|
||||
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
|
||||
<Combobox
|
||||
options={CLIP_VISION_OPTIONS}
|
||||
options={clipVisionOptions}
|
||||
placeholder={t('common.placeholderSelectAModel')}
|
||||
value={clipVisionModelValue}
|
||||
onChange={_onChangeCLIPVisionModel}
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
referenceImageIPAdapterModelChanged,
|
||||
referenceImageIPAdapterWeightChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
|
||||
import type { IPAImageDropData } from 'features/dnd/types';
|
||||
@@ -90,6 +91,8 @@ export const IPAdapterSettings = memo(() => {
|
||||
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
|
||||
return (
|
||||
<CanvasEntitySettingsWrapper>
|
||||
<Flex flexDir="column" gap={2} position="relative" w="full">
|
||||
@@ -113,7 +116,7 @@ export const IPAdapterSettings = memo(() => {
|
||||
</Flex>
|
||||
<Flex gap={2} w="full" alignItems="center">
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
|
||||
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
|
||||
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
|
||||
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
</Flex>
|
||||
|
||||
@@ -46,7 +46,7 @@ const zControlModeV2 = z.enum(['balanced', 'more_prompt', 'more_control', 'unbal
|
||||
export type ControlModeV2 = z.infer<typeof zControlModeV2>;
|
||||
export const isControlModeV2 = (v: unknown): v is ControlModeV2 => zControlModeV2.safeParse(v).success;
|
||||
|
||||
const zCLIPVisionModelV2 = z.enum(['ViT-H', 'ViT-G']);
|
||||
const zCLIPVisionModelV2 = z.enum(['ViT-H', 'ViT-G', 'ViT-L']);
|
||||
export type CLIPVisionModelV2 = z.infer<typeof zCLIPVisionModelV2>;
|
||||
export const isCLIPVisionModelV2 = (v: unknown): v is CLIPVisionModelV2 => zCLIPVisionModelV2.safeParse(v).success;
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { addControlNets } from './addControlAdapters';
|
||||
import { addIPAdapters } from './addIPAdapters';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
@@ -198,6 +199,40 @@ export const buildFLUXGraph = async (
|
||||
g.deleteNode(controlNetCollector.id);
|
||||
}
|
||||
|
||||
const ipAdapterCollector = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('ip_adapter_collector'),
|
||||
});
|
||||
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
|
||||
|
||||
const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
|
||||
if (totalIPAdaptersAdded > 0) {
|
||||
assert(steps > 2);
|
||||
const cfg_scale_start_step = 1;
|
||||
const cfg_scale_end_step = Math.ceil(steps / 2);
|
||||
assert(cfg_scale_end_step > cfg_scale_start_step);
|
||||
|
||||
const negCond = g.addNode({
|
||||
type: 'flux_text_encoder',
|
||||
id: getPrefixedId('flux_text_encoder'),
|
||||
prompt: '',
|
||||
});
|
||||
|
||||
g.addEdge(modelLoader, 'clip', negCond, 'clip');
|
||||
g.addEdge(modelLoader, 't5_encoder', negCond, 't5_encoder');
|
||||
g.addEdge(modelLoader, 'max_seq_len', negCond, 't5_max_seq_len');
|
||||
g.addEdge(negCond, 'conditioning', noise, 'negative_text_conditioning');
|
||||
|
||||
g.updateNode(noise, {
|
||||
cfg_scale: 3,
|
||||
cfg_scale_start_step,
|
||||
cfg_scale_end_step,
|
||||
});
|
||||
g.addEdge(ipAdapterCollector, 'collection', noise, 'ip_adapter');
|
||||
} else {
|
||||
g.deleteNode(ipAdapterCollector.id);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
canvasOutput = addNSFWChecker(g, canvasOutput);
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user