feat(ui): add simple mode filtering

This commit is contained in:
psychedelicious
2024-11-06 15:22:12 +10:00
committed by Kent Keirsey
parent 999809b4c7
commit 4d0837541b
7 changed files with 188 additions and 52 deletions

View File

@@ -1819,6 +1819,9 @@
"process": "Process",
"apply": "Apply",
"cancel": "Cancel",
"advanced": "Advanced",
"processingLayerWith": "Processing layer with the {{type}} filter.",
"forMoreControl": "For more control, click Advanced below.",
"spandrel_filter": {
"label": "Image-to-Image Model",
"description": "Run an image-to-image model on the selected layer.",

View File

@@ -5,7 +5,7 @@ import { atom } from 'nanostores';
* A fallback non-writable atom that always returns `false`, used when a nanostores atom is only conditionally available
* in a hook or component.
*/
// export const $false: ReadableAtom<boolean> = atom(false);
export const $false: ReadableAtom<boolean> = atom(false);
/**
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
* in a hook or component.

View File

@@ -6,6 +6,7 @@ import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginE
import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlLayerControlAdapterControlMode } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode';
import { ControlLayerControlAdapterModel } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel';
import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
@@ -16,6 +17,7 @@ import {
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { getFilterForModel } from 'features/controlLayers/store/filters';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
@@ -44,6 +46,7 @@ export const ControlLayerControlAdapter = memo(() => {
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const filter = useEntityFilter(entityIdentifier);
const isFLUX = useAppSelector(selectIsFLUX);
const adapter = useEntityAdapterContext('control_layer');
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
@@ -69,8 +72,43 @@ export const ControlLayerControlAdapter = memo(() => {
const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig }));
// When we change the model, we need may need to start filtering w/ the simplified filter mode, and/or change the
// filter config.
const isFiltering = adapter.filterer.$isFiltering.get();
const isSimple = adapter.filterer.$simple.get();
// If we are filtering and _not_ in simple mode, that means the user has clicked Advanced. They want to be in control
// of the settings. Bail early without doing anything else.
if (isFiltering && !isSimple) {
return;
}
// Else, we are in simple mode and will take care of some things for the user.
// First, check if the newly-selected model has a default filter. It may not - for example, Tile controlnet models
// don't have a default filter.
const defaultFilterForNewModel = getFilterForModel(modelConfig);
if (!defaultFilterForNewModel) {
// The user has chosen a model that doesn't have a default filter - cancel any in-progress filtering and bail.
if (isFiltering) {
adapter.filterer.cancel();
}
return;
}
// At this point, we know the user has selected a model that has a default filter. We need to either start filtering
// with that default filter, or update the existing filter config to match the new model's default filter.
const filterConfig = defaultFilterForNewModel.buildDefaults();
if (isFiltering) {
adapter.filterer.$filterConfig.set(filterConfig);
} else {
adapter.filterer.start(filterConfig);
}
// The user may have disabled auto-processing, so we should process the filter manually. This is essentially a
// no-op if auto-processing is already enabled, because the process method is debounced.
adapter.filterer.process();
},
[dispatch, entityIdentifier]
[adapter.filterer, dispatch, entityIdentifier]
);
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);

View File

@@ -9,6 +9,7 @@ import {
MenuList,
Spacer,
Spinner,
Text,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
@@ -28,13 +29,10 @@ import { memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
const FilterContent = memo(
const FilterContentAdvanced = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const config = useStore(adapter.filterer.$filterConfig);
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.filterer.$isProcessing);
const hasImageState = useStore(adapter.filterer.$hasImageState);
const autoProcess = useAppSelector(selectAutoProcess);
@@ -73,36 +71,8 @@ const FilterContent = memo(
adapter.filterer.saveAs('control_layer');
}, [adapter.filterer]);
useRegisteredHotkeys({
id: 'applyFilter',
category: 'canvas',
callback: adapter.filterer.apply,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
useRegisteredHotkeys({
id: 'cancelFilter',
category: 'canvas',
callback: adapter.filterer.cancel,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
return (
<Flex
ref={ref}
bg="base.800"
borderRadius="base"
p={4}
flexDir="column"
gap={4}
w={420}
h="auto"
shadow="dark-lg"
transitionProperty="height"
transitionDuration="normal"
>
<>
<Flex w="full" gap={4}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.filter.filter')}
@@ -169,12 +139,67 @@ const FilterContent = memo(
{t('controlLayers.filter.cancel')}
</Button>
</ButtonGroup>
</Flex>
</>
);
}
);
FilterContent.displayName = 'FilterContent';
FilterContentAdvanced.displayName = 'FilterContentAdvanced';
const FilterContentSimple = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const config = useStore(adapter.filterer.$filterConfig);
const isProcessing = useStore(adapter.filterer.$isProcessing);
const hasImageState = useStore(adapter.filterer.$hasImageState);
const isValid = useMemo(() => {
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
}, [config]);
const onClickAdvanced = useCallback(() => {
adapter.filterer.$simple.set(false);
}, [adapter.filterer.$simple]);
return (
<>
<Flex w="full" gap={4}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.filter.filter')}
</Heading>
<Spacer />
</Flex>
<Flex flexDir="column" w="full" gap={2} pb={2}>
<Text color="base.500" textAlign="center">
{t('controlLayers.filter.processingLayerWith', { type: t(`controlLayers.filter.${config.type}.label`) })}
</Text>
<Text color="base.500" textAlign="center">
{t('controlLayers.filter.forMoreControl')}
</Text>
</Flex>
<ButtonGroup isAttached={false} size="sm" w="full">
<Button variant="ghost" onClick={onClickAdvanced}>
{t('controlLayers.filter.advanced')}
</Button>
<Spacer />
<Button
onClick={adapter.filterer.apply}
loadingText={t('controlLayers.filter.apply')}
variant="ghost"
isDisabled={isProcessing || !isValid || !hasImageState}
>
{t('controlLayers.filter.apply')}
</Button>
<Button variant="ghost" onClick={adapter.filterer.cancel} loadingText={t('controlLayers.filter.cancel')}>
{t('controlLayers.filter.cancel')}
</Button>
</ButtonGroup>
</>
);
}
);
FilterContentSimple.displayName = 'FilterContentSimple';
export const Filter = () => {
const canvasManager = useCanvasManager();
@@ -182,8 +207,54 @@ export const Filter = () => {
if (!adapter) {
return null;
}
return <FilterContent adapter={adapter} />;
};
Filter.displayName = 'Filter';
const FilterContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const simplified = useStore(adapter.filterer.$simple);
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.filterer.$isProcessing);
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
useRegisteredHotkeys({
id: 'applyFilter',
category: 'canvas',
callback: adapter.filterer.apply,
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
useRegisteredHotkeys({
id: 'cancelFilter',
category: 'canvas',
callback: adapter.filterer.cancel,
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
return (
<Flex
ref={ref}
bg="base.800"
borderRadius="base"
p={4}
flexDir="column"
gap={4}
w={420}
h="auto"
shadow="dark-lg"
transitionProperty="height"
transitionDuration="normal"
>
{simplified && <FilterContentSimple adapter={adapter} />}
{!simplified && <FilterContentAdvanced adapter={adapter} />}
</Flex>
);
}
);
FilterContent.displayName = 'FilterContent';

View File

@@ -4,9 +4,13 @@ import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/kon
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type { CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
import type {
CanvasEntityIdentifier,
CanvasRenderableEntityType,
} from 'features/controlLayers/store/types';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useMemo, useSyncExternalStore } from 'react';
import { createContext, memo, useContext, useMemo, useSyncExternalStore } from 'react';
import { assert } from 'tsafe';
const EntityAdapterContext = createContext<
@@ -95,6 +99,17 @@ export const RegionalGuidanceAdapterGate = memo(({ children }: PropsWithChildren
return <EntityAdapterContext.Provider value={adapter}>{children}</EntityAdapterContext.Provider>;
});
export const useEntityAdapterContext = <T extends CanvasRenderableEntityType | undefined = CanvasRenderableEntityType>(
type?: T
): CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T> => {
const adapter = useContext(EntityAdapterContext);
assert(adapter, 'useEntityIdentifier must be used within a EntityIdentifierProvider');
if (type) {
assert(adapter.entityIdentifier.type === type, 'useEntityIdentifier must be used with the correct type');
}
return adapter as CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T>;
};
RegionalGuidanceAdapterGate.displayName = 'RegionalGuidanceAdapterGate';
export const useEntityAdapterSafe = (

View File

@@ -83,6 +83,13 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
* Whether the module has an image state. This is a computed value based on $imageState.
*/
$hasImageState = computed(this.$imageState, (imageState) => imageState !== null);
/**
* Whether the filter is in simple mode. In simple mode, the filter is started with a default filter config and the
* user is not presented with filter settings.
*/
$simple = atom<boolean>(false);
/**
* The filtered image object module, if it exists.
*/
@@ -147,7 +154,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
/**
* Starts the filter module.
* @param config The filter config to start with. If omitted, the default filter config is used.
* @param config The filter config to use. If omitted, the default filter config is used.
*/
start = (config?: FilterConfig) => {
const filteringAdapter = this.manager.stateApi.$filteringAdapter.get();
@@ -174,12 +181,14 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
// If a config is provided, use it
this.$filterConfig.set(config);
this.$initialFilterConfig.set(config);
this.$simple.set(true);
} else {
this.$filterConfig.set(this.createInitialFilterConfig());
const initialConfig = this.createInitialFilterConfig();
this.$filterConfig.set(initialConfig);
this.$initialFilterConfig.set(initialConfig);
this.$simple.set(false);
}
this.$initialFilterConfig.set(this.$filterConfig.get());
this.subscribe();
this.manager.stateApi.$filteringAdapter.set(this.parent);
@@ -198,7 +207,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
);
const modelConfig = this.manager.stateApi.runSelector(selectModelConfig);
// This always returns a filter
const filter = getFilterForModel(modelConfig);
const filter = getFilterForModel(modelConfig) ?? IMAGE_FILTERS.canny_edge_detection;
return filter.buildDefaults();
} else {
// Otherwise, used the default filter
@@ -404,7 +413,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.imageModule.destroy();
this.imageModule = null;
}
const initialFilterConfig = this.$initialFilterConfig.get() ?? this.createInitialFilterConfig();
const initialFilterConfig = deepClone(this.$initialFilterConfig.get() ?? this.createInitialFilterConfig());
this.$filterConfig.set(initialFilterConfig);
this.$imageState.set(null);
this.$lastProcessedHash.set('');

View File

@@ -456,14 +456,14 @@ const PROCESSOR_TO_FILTER_MAP: Record<string, FilterType> = {
*/
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
if (!modelConfig) {
// No model, use the default filter
return IMAGE_FILTERS.canny_edge_detection;
// No model
return null;
}
const preprocessor = modelConfig?.default_settings?.preprocessor;
if (!preprocessor) {
// No preprocessor, use the default filter
return IMAGE_FILTERS.canny_edge_detection;
// No preprocessor
return null;
}
if (isFilterType(preprocessor)) {
@@ -473,8 +473,8 @@ export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapte
const filterName = PROCESSOR_TO_FILTER_MAP[preprocessor];
if (!filterName) {
// No filter found, use the default filter
return IMAGE_FILTERS.canny_edge_detection;
// No filter found
return null;
}
// Found a filter, use it