feat(ui): refactor filter module

- Each entity gets its own `CanvasEntityFilterer`
- Add auto-preview feature to filter, debounced by 1000ms leading + trailing
- Fix flash when preview updates
This commit is contained in:
psychedelicious
2024-09-07 18:14:51 +10:00
parent 0abd81ac80
commit 670e054fe0
18 changed files with 304 additions and 240 deletions

View File

@@ -1656,6 +1656,7 @@
"storeNotInitialized": "Store is not initialized"
},
"controlLayers": {
"autoPreviewFilter": "Auto Preview",
"bookmark": "Bookmark for Quick Switch",
"fitBboxToLayers": "Fit Bbox To Layers",
"removeBookmark": "Remove Bookmark",

View File

@@ -23,8 +23,11 @@ export const EntityListSelectedEntityActionBarFilterButton = memo(() => {
if (!isFilterableEntityIdentifier(selectedEntityIdentifier)) {
return;
}
canvasManager.filter.startFilter(selectedEntityIdentifier);
const adapter = canvasManager.getAdapter(selectedEntityIdentifier);
if (!adapter) {
return;
}
adapter.filterer.startFilter();
}, [canvasManager, selectedEntityIdentifier]);
if (!selectedEntityIdentifier) {

View File

@@ -4,12 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import {
IMAGE_FILTERS,
isControlLayerEntityIdentifier,
isFilterType,
isRasterLayerEntityIdentifier,
} from 'features/controlLayers/store/types';
import { IMAGE_FILTERS, isFilterableEntityIdentifier, isFilterType } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
@@ -43,26 +38,25 @@ export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onCha
}
// Open the filter popup by setting this entity as the filtering entity
if (!canvasManager.filter.$adapter.get()) {
// Can only filter raster and control layers
if (!isRasterLayerEntityIdentifier(entityIdentifier) && !isControlLayerEntityIdentifier(entityIdentifier)) {
if (!canvasManager.stateApi.$isFiltering.get()) {
if (!isFilterableEntityIdentifier(entityIdentifier)) {
return;
}
const adapter = canvasManager.getAdapter(entityIdentifier);
if (!adapter) {
return;
}
// Update the filter, preferring the model's default
if (isFilterType(modelConfig.default_settings?.preprocessor)) {
canvasManager.filter.$config.set(
IMAGE_FILTERS[modelConfig.default_settings.preprocessor].buildDefaults(modelConfig.base)
);
} else {
canvasManager.filter.$config.set(IMAGE_FILTERS.canny_image_processor.buildDefaults(modelConfig.base));
}
const filterConfig = isFilterType(modelConfig.default_settings?.preprocessor)
? IMAGE_FILTERS[modelConfig.default_settings.preprocessor].buildDefaults(modelConfig.base)
: IMAGE_FILTERS.canny_image_processor.buildDefaults(modelConfig.base);
canvasManager.filter.startFilter(entityIdentifier);
canvasManager.filter.previewFilter();
adapter.filterer.startFilter(filterConfig);
adapter.filterer.previewFilter();
}
},
[canvasManager.filter, entityIdentifier, modelKey, onChangeModel]
[canvasManager, entityIdentifier, modelKey, onChangeModel]
);
const getIsDisabled = useCallback(

View File

@@ -1,37 +1,44 @@
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Heading, Spacer, Switch } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { FilterSettings } from 'features/controlLayers/components/Filters/FilterSettings';
import { FilterTypeSelect } from 'features/controlLayers/components/Filters/FilterTypeSelect';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
import {
selectAutoPreviewFilter,
settingsAutoPreviewFilterToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { type FilterConfig, IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCheckBold, PiShootingStarBold, PiXBold } from 'react-icons/pi';
export const Filter = memo(() => {
const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const canvasManager = useCanvasManager();
const config = useStore(canvasManager.filter.$config);
const isFiltering = useStore(canvasManager.filter.$isFiltering);
const isProcessing = useStore(canvasManager.filter.$isProcessing);
const dispatch = useAppDispatch();
const config = useStore(adapter.filterer.$filterConfig);
const isProcessing = useStore(adapter.filterer.$isProcessing);
const autoPreviewFilter = useAppSelector(selectAutoPreviewFilter);
const onChangeFilterConfig = useCallback(
(filterConfig: FilterConfig) => {
canvasManager.filter.$config.set(filterConfig);
adapter.filterer.$filterConfig.set(filterConfig);
},
[canvasManager.filter.$config]
[adapter.filterer.$filterConfig]
);
const onChangeFilterType = useCallback(
(filterType: FilterConfig['type']) => {
canvasManager.filter.$config.set(IMAGE_FILTERS[filterType].buildDefaults());
adapter.filterer.$filterConfig.set(IMAGE_FILTERS[filterType].buildDefaults());
},
[canvasManager.filter.$config]
[adapter.filterer.$filterConfig]
);
if (!isFiltering) {
return null;
}
const onChangeAutoPreviewFilter = useCallback(() => {
dispatch(settingsAutoPreviewFilterToggled());
}, [dispatch]);
return (
<Flex
@@ -46,16 +53,23 @@ export const Filter = memo(() => {
transitionProperty="height"
transitionDuration="normal"
>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.filter.filter')}
</Heading>
<Flex w="full">
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.filter.filter')}
</Heading>
<Spacer />
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.autoPreviewFilter')}</FormLabel>
<Switch size="sm" isChecked={autoPreviewFilter} onChange={onChangeAutoPreviewFilter} />
</FormControl>
</Flex>
<FilterTypeSelect filterType={config.type} onChange={onChangeFilterType} />
<FilterSettings filterConfig={config} onChange={onChangeFilterConfig} />
<ButtonGroup isAttached={false} size="sm" w="full">
<Button
variant="ghost"
leftIcon={<PiShootingStarBold />}
onClick={canvasManager.filter.previewFilter}
onClick={adapter.filterer.previewFilter}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.preview')}
>
@@ -65,7 +79,7 @@ export const Filter = memo(() => {
<Button
variant="ghost"
leftIcon={<PiCheckBold />}
onClick={canvasManager.filter.applyFilter}
onClick={adapter.filterer.applyFilter}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.apply')}
>
@@ -74,7 +88,7 @@ export const Filter = memo(() => {
<Button
variant="ghost"
leftIcon={<PiXBold />}
onClick={canvasManager.filter.cancelFilter}
onClick={adapter.filterer.cancelFilter}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.cancel')}
>
@@ -85,4 +99,16 @@ export const Filter = memo(() => {
);
});
FilterBox.displayName = 'FilterBox';
export const Filter = () => {
const canvasManager = useCanvasManager();
const adapter = useStore(canvasManager.stateApi.$filteringAdapter);
if (!adapter) {
return null;
}
return <FilterBox adapter={adapter} />;
};
Filter.displayName = 'Filter';

View File

@@ -23,8 +23,12 @@ export const CanvasEntityMenuItemsFilter = memo(() => {
if (!isFilterableEntityIdentifier(entityIdentifier)) {
return;
}
canvasManager.filter.startFilter(entityIdentifier);
}, [canvasManager.filter, entityIdentifier]);
const adapter = canvasManager.getAdapter(entityIdentifier);
if (!adapter) {
return;
}
adapter.filterer.startFilter();
}, [canvasManager, entityIdentifier]);
return (
<MenuItem onClick={onClick} icon={<PiShootingStarBold />} isDisabled={isBusy || isStaging}>

View File

@@ -1,6 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntityFilterer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import type { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -25,15 +26,23 @@ export abstract class CanvasEntityAdapterBase<
readonly entityIdentifier: CanvasEntityIdentifier<T['type']>;
/**
* The transformer for this entity adapter.
* The transformer for this entity adapter. All entities must have a transformer.
*/
abstract transformer: CanvasEntityTransformer;
/**
* The renderer for this entity adapter.
* The renderer for this entity adapter. All entities must have a renderer.
*/
abstract renderer: CanvasEntityObjectRenderer;
/**
* The filterer for this entity adapter. Entities that support filtering should implement this property.
*/
// TODO(psyche): This is in the ABC and not in the concrete classes to allow all adapters to share the `destroy`
// method. If it wasn't in this ABC, we'd get a TS error in `destroy`. Maybe there's a better way to handle this
// without requiring all adapters to implement this property and their own `destroy`?
abstract filterer?: CanvasEntityFilterer;
/**
* Synchronizes the entity state with the canvas. This includes rendering the entity's objects, handling visibility,
* positioning, opacity, locked state, and any other properties.
@@ -201,8 +210,8 @@ export abstract class CanvasEntityAdapterBase<
this.transformer.stopTransform();
}
this.transformer.destroy();
if (this.manager.filter.$adapter.get()?.id === this.id) {
this.manager.filter.cancelFilter();
if (this.filterer?.$isFiltering.get()) {
this.filterer.cancelFilter();
}
this.konva.layer.destroy();
this.manager.deleteAdapter(this.entityIdentifier);

View File

@@ -1,5 +1,6 @@
import type { SerializableObject } from 'common/types';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterBase';
import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntityFilterer';
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -12,12 +13,14 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<Can
transformer: CanvasEntityTransformer;
renderer: CanvasEntityObjectRenderer;
filterer: CanvasEntityFilterer;
constructor(entityIdentifier: CanvasEntityIdentifier<'control_layer'>, manager: CanvasManager) {
super(entityIdentifier, manager, CanvasEntityAdapterControlLayer.TYPE);
this.renderer = new CanvasEntityObjectRenderer(this);
this.transformer = new CanvasEntityTransformer(this);
this.filterer = new CanvasEntityFilterer(this);
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectState, this.sync));
}

View File

@@ -12,6 +12,7 @@ export class CanvasEntityAdapterInpaintMask extends CanvasEntityAdapterBase<Canv
transformer: CanvasEntityTransformer;
renderer: CanvasEntityObjectRenderer;
filterer = undefined;
constructor(entityIdentifier: CanvasEntityIdentifier<'inpaint_mask'>, manager: CanvasManager) {
super(entityIdentifier, manager, CanvasEntityAdapterInpaintMask.TYPE);

View File

@@ -1,5 +1,6 @@
import type { SerializableObject } from 'common/types';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterBase';
import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntityFilterer';
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -12,12 +13,14 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<Canv
transformer: CanvasEntityTransformer;
renderer: CanvasEntityObjectRenderer;
filterer: CanvasEntityFilterer;
constructor(entityIdentifier: CanvasEntityIdentifier<'raster_layer'>, manager: CanvasManager) {
super(entityIdentifier, manager, CanvasEntityAdapterRasterLayer.TYPE);
this.renderer = new CanvasEntityObjectRenderer(this);
this.transformer = new CanvasEntityTransformer(this);
this.filterer = new CanvasEntityFilterer(this);
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectState, this.sync));
}

View File

@@ -12,6 +12,7 @@ export class CanvasEntityAdapterRegionalGuidance extends CanvasEntityAdapterBase
transformer: CanvasEntityTransformer;
renderer: CanvasEntityObjectRenderer;
filterer = undefined;
constructor(entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>, manager: CanvasManager) {
super(entityIdentifier, manager, CanvasEntityAdapterRegionalGuidance.TYPE);

View File

@@ -0,0 +1,174 @@
import type { SerializableObject } from 'common/types';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectAutoPreviewFilter } from 'features/controlLayers/store/canvasSettingsSlice';
import type { CanvasImageState, FilterConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS, imageDTOToImageObject } from 'features/controlLayers/store/types';
import { debounce } from 'lodash-es';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { getImageDTO } from 'services/api/endpoints/images';
import type { BatchConfig, ImageDTO, S } from 'services/api/types';
import { assert } from 'tsafe';
export class CanvasEntityFilterer extends CanvasModuleBase {
readonly type = 'canvas_filterer';
readonly id: string;
readonly path: string[];
readonly parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer;
readonly manager: CanvasManager;
readonly log: Logger;
imageState: CanvasImageState | null = null;
subscriptions = new Set<() => void>();
$isFiltering = atom<boolean>(false);
$isProcessing = atom<boolean>(false);
$filterConfig = atom<FilterConfig>(IMAGE_FILTERS.canny_image_processor.buildDefaults());
constructor(parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating filter module');
this.subscriptions.add(
this.$filterConfig.listen(() => {
if (this.manager.stateApi.getSettings().autoPreviewFilter && this.$isFiltering.get()) {
this.previewFilter();
}
})
);
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectAutoPreviewFilter, (autoPreviewFilter) => {
if (autoPreviewFilter && this.$isFiltering.get()) {
this.previewFilter();
}
})
);
}
startFilter = (config?: FilterConfig) => {
this.log.trace('Initializing filter');
if (config) {
this.$filterConfig.set(config);
}
this.$isFiltering.set(true);
this.manager.stateApi.$filteringAdapter.set(this.parent);
this.previewFilter();
};
previewFilter = debounce(
async () => {
const config = this.$filterConfig.get();
this.log.trace({ config }, 'Previewing filter');
const rect = this.parent.transformer.getRelativeRect();
const imageDTO = await this.parent.renderer.rasterize({ rect, attrs: { filters: [] } });
const nodeId = getPrefixedId('filter_node');
const batch = this.buildBatchConfig(imageDTO, config, nodeId);
// Listen for the filter processing completion event
const listener = async (event: S['InvocationCompleteEvent']) => {
if (event.origin !== this.id || event.invocation_source_id !== nodeId) {
return;
}
this.manager.socket.off('invocation_complete', listener);
this.log.trace({ event } as SerializableObject, 'Handling filter processing completion');
const { result } = event;
assert(result.type === 'image_output', `Processor did not return an image output, got: ${result}`);
const imageDTO = await getImageDTO(result.image.image_name);
assert(imageDTO, "Failed to fetch processor output's image DTO");
this.imageState = imageDTOToImageObject(imageDTO);
await this.parent.renderer.setBuffer(this.imageState, true);
this.parent.renderer.hideObjects();
this.$isProcessing.set(false);
};
this.manager.socket.on('invocation_complete', listener);
this.log.trace({ batch } as SerializableObject, 'Enqueuing filter batch');
this.$isProcessing.set(true);
this.manager.stateApi.enqueueBatch(batch);
},
1000,
{ leading: true, trailing: true }
);
applyFilter = () => {
const imageState = this.imageState;
if (!imageState) {
this.log.warn('No image state to apply filter to');
return;
}
this.log.trace('Applying filter');
this.parent.renderer.commitBuffer();
const rect = this.parent.transformer.getRelativeRect();
this.manager.stateApi.rasterizeEntity({
entityIdentifier: this.parent.entityIdentifier,
imageObject: imageState,
rect: {
x: Math.round(rect.x),
y: Math.round(rect.y),
width: imageState.image.height,
height: imageState.image.width,
},
replaceObjects: true,
});
this.parent.renderer.showObjects();
this.imageState = null;
this.$isFiltering.set(false);
this.manager.stateApi.$filteringAdapter.set(null);
};
cancelFilter = () => {
this.log.trace('Cancelling filter');
this.parent.renderer.clearBuffer();
this.parent.renderer.showObjects();
this.parent.transformer.updatePosition();
this.parent.renderer.syncCache(true);
this.imageState = null;
this.$isProcessing.set(false);
this.$isFiltering.set(false);
this.manager.stateApi.$filteringAdapter.set(null);
};
buildBatchConfig = (imageDTO: ImageDTO, config: FilterConfig, id: string): BatchConfig => {
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const node = IMAGE_FILTERS[config.type].buildNode(imageDTO, config as never);
node.id = id;
const batch: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[node.id]: {
...node,
// filtered images are always intermediate - do not save to gallery
is_intermediate: true,
},
},
edges: [],
},
origin: this.id,
runs: 1,
},
};
return batch;
};
}

View File

@@ -454,7 +454,8 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
syncInteractionState = () => {
this.log.trace('Syncing interaction state');
if (this.manager.filter.$isFiltering.get()) {
// Not all entities have a filterer - only raster layer and control layer adapters
if (this.parent.filterer?.$isFiltering.get()) {
// May not interact with the entity when the filter is active
this.parent.konva.layer.listening(false);
this._setInteractionMode('off');

View File

@@ -1,178 +0,0 @@
import type { SerializableObject } from 'common/types';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRegionalGuidance';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasEntityIdentifier, CanvasImageState, FilterConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS, imageDTOToImageObject } from 'features/controlLayers/store/types';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { getImageDTO } from 'services/api/endpoints/images';
import type { BatchConfig, ImageDTO, S } from 'services/api/types';
import { assert } from 'tsafe';
export class CanvasFilterModule extends CanvasModuleBase {
readonly type = 'canvas_filter';
readonly id: string;
readonly path: string[];
readonly parent: CanvasManager;
readonly manager: CanvasManager;
readonly log: Logger;
imageState: CanvasImageState | null = null;
$adapter = atom<
| CanvasEntityAdapterRasterLayer
| CanvasEntityAdapterControlLayer
| CanvasEntityAdapterInpaintMask
| CanvasEntityAdapterRegionalGuidance
| null
>(null);
$isFiltering = computed(this.$adapter, (adapter) => Boolean(adapter));
$isProcessing = atom<boolean>(false);
$config = atom<FilterConfig>(IMAGE_FILTERS.canny_image_processor.buildDefaults());
constructor(manager: CanvasManager) {
super();
this.id = getPrefixedId(this.type);
this.parent = manager;
this.manager = manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating filter module');
}
startFilter = (entityIdentifier: CanvasEntityIdentifier<'raster_layer' | 'control_layer'>) => {
this.log.trace('Initializing filter');
const adapter = this.manager.getAdapter(entityIdentifier);
if (!adapter) {
this.log.warn({ entityIdentifier }, 'Unable to find entity');
return;
}
if (adapter.entityIdentifier.type !== 'raster_layer' && adapter.entityIdentifier.type !== 'control_layer') {
this.log.warn({ entityIdentifier }, 'Unsupported entity type');
return;
}
this.$adapter.set(adapter);
};
previewFilter = async () => {
const adapter = this.$adapter.get();
if (!adapter) {
this.log.warn('Cannot preview filter without an adapter');
return;
}
const config = this.$config.get();
this.log.trace({ config }, 'Previewing filter');
const rect = adapter.transformer.getRelativeRect();
const imageDTO = await adapter.renderer.rasterize({ rect, attrs: { filters: [] } });
const nodeId = getPrefixedId('filter_node');
const batch = this.buildBatchConfig(imageDTO, config, nodeId);
// Listen for the filter processing completion event
const listener = async (event: S['InvocationCompleteEvent']) => {
if (event.origin !== this.id || event.invocation_source_id !== nodeId) {
return;
}
this.manager.socket.off('invocation_complete', listener);
this.log.trace({ event } as SerializableObject, 'Handling filter processing completion');
const { result } = event;
assert(result.type === 'image_output', `Processor did not return an image output, got: ${result}`);
const imageDTO = await getImageDTO(result.image.image_name);
assert(imageDTO, "Failed to fetch processor output's image DTO");
this.imageState = imageDTOToImageObject(imageDTO);
adapter.renderer.clearBuffer();
await adapter.renderer.setBuffer(this.imageState, true);
adapter.renderer.hideObjects();
this.$isProcessing.set(false);
};
this.manager.socket.on('invocation_complete', listener);
this.log.trace({ batch } as SerializableObject, 'Enqueuing filter batch');
this.$isProcessing.set(true);
this.manager.stateApi.enqueueBatch(batch);
};
applyFilter = () => {
const imageState = this.imageState;
const adapter = this.$adapter.get();
if (!imageState) {
this.log.warn('No image state to apply filter to');
return;
}
if (!adapter) {
this.log.warn('Cannot apply filter without an adapter');
return;
}
this.log.trace('Applying filter');
adapter.renderer.commitBuffer();
const rect = adapter.transformer.getRelativeRect();
this.manager.stateApi.rasterizeEntity({
entityIdentifier: adapter.entityIdentifier,
imageObject: imageState,
rect: {
x: Math.round(rect.x),
y: Math.round(rect.y),
width: imageState.image.height,
height: imageState.image.width,
},
replaceObjects: true,
});
adapter.renderer.showObjects();
this.imageState = null;
this.$adapter.set(null);
};
cancelFilter = () => {
this.log.trace('Cancelling filter');
const adapter = this.$adapter.get();
if (adapter) {
adapter.renderer.clearBuffer();
adapter.renderer.showObjects();
adapter.transformer.updatePosition();
adapter.renderer.syncCache(true);
this.$adapter.set(null);
}
this.imageState = null;
this.$isProcessing.set(false);
};
buildBatchConfig = (imageDTO: ImageDTO, config: FilterConfig, id: string): BatchConfig => {
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const node = IMAGE_FILTERS[config.type].buildNode(imageDTO, config as never);
node.id = id;
const batch: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[node.id]: {
...node,
// filtered images are always intermediate - do not save to gallery
is_intermediate: true,
},
},
edges: [],
},
origin: this.id,
runs: 1,
},
};
return batch;
};
}

View File

@@ -12,7 +12,6 @@ import { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/Can
import { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRegionalGuidance';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntityAdapter/types';
import { CanvasEntityRendererModule } from 'features/controlLayers/konva/CanvasEntityRendererModule';
import { CanvasFilterModule } from 'features/controlLayers/konva/CanvasFilterModule';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasProgressImageModule } from 'features/controlLayers/konva/CanvasProgressImageModule';
import { CanvasStageModule } from 'features/controlLayers/konva/CanvasStageModule';
@@ -58,7 +57,6 @@ export class CanvasManager extends CanvasModuleBase {
stateApi: CanvasStateApiModule;
background: CanvasBackgroundModule;
filter: CanvasFilterModule;
stage: CanvasStageModule;
worker: CanvasWorkerModule;
cache: CanvasCacheModule;
@@ -105,7 +103,6 @@ export class CanvasManager extends CanvasModuleBase {
this.worker = new CanvasWorkerModule(this);
this.cache = new CanvasCacheModule(this);
this.entityRenderer = new CanvasEntityRendererModule(this);
this.filter = new CanvasFilterModule(this);
this.compositor = new CanvasCompositorModule(this);
@@ -128,9 +125,12 @@ export class CanvasManager extends CanvasModuleBase {
this.konva.previewLayer.add(this.bbox.konva.group);
this.konva.previewLayer.add(this.tool.konva.group);
this.$isBusy = computed([this.filter.$isFiltering, this.stateApi.$isTranforming], (isFiltering, isTransforming) => {
return isFiltering || isTransforming;
});
this.$isBusy = computed(
[this.stateApi.$isFiltering, this.stateApi.$isTranforming],
(isFiltering, isTransforming) => {
return isFiltering || isTransforming;
}
);
}
getAdapter = <T extends CanvasEntityType = CanvasEntityType>(
@@ -233,7 +233,6 @@ export class CanvasManager extends CanvasModuleBase {
this.progressImage,
this.stateApi,
this.background,
this.filter,
this.worker,
this.entityRenderer,
this.compositor,
@@ -280,7 +279,6 @@ export class CanvasManager extends CanvasModuleBase {
tool: this.tool.repr(),
progressImage: this.progressImage.repr(),
background: this.background.repr(),
filter: this.filter.repr(),
worker: this.worker.repr(),
entityRenderer: this.entityRenderer.repr(),
compositor: this.compositor.repr(),

View File

@@ -1,7 +1,7 @@
import { Mutex } from 'async-mutex';
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntityFilterer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntityObjectRenderer';
import type { CanvasFilterModule } from 'features/controlLayers/konva/CanvasFilterModule';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasStagingAreaModule } from 'features/controlLayers/konva/CanvasStagingAreaModule';
@@ -16,7 +16,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
readonly type = 'object_image';
readonly id: string;
readonly path: string[];
readonly parent: CanvasEntityObjectRenderer | CanvasStagingAreaModule | CanvasFilterModule;
readonly parent: CanvasEntityObjectRenderer | CanvasStagingAreaModule | CanvasEntityFilterer;
readonly manager: CanvasManager;
readonly log: Logger;
@@ -33,7 +33,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
constructor(
state: CanvasImageState,
parent: CanvasEntityObjectRenderer | CanvasStagingAreaModule | CanvasFilterModule
parent: CanvasEntityObjectRenderer | CanvasStagingAreaModule | CanvasEntityFilterer
) {
super();
this.id = state.id;

View File

@@ -1,6 +1,8 @@
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Selector } from '@reduxjs/toolkit';
import type { AppStore, RootState } from 'app/store/store';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { SubscriptionHandler } from 'features/controlLayers/konva/util';
@@ -307,6 +309,16 @@ export class CanvasStateApiModule extends CanvasModuleBase {
}
};
/**
* The entity adapter being filtered, if any.
*/
$filteringAdapter = atom<CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer | null>(null);
/**
* Whether an entity is currently being filtered. Derived from `$filteringAdapter`.
*/
$isFiltering = computed(this.$filteringAdapter, (filteringAdapter) => Boolean(filteringAdapter));
/**
* The entity adapter being transformed, if any.
*/

View File

@@ -107,7 +107,7 @@ export class CanvasToolModule extends CanvasModuleBase {
this.subscriptions.add(this.manager.stage.$stageAttrs.listen(this.render));
this.subscriptions.add(this.manager.stateApi.$isTranforming.listen(this.render));
this.subscriptions.add(this.manager.filter.$isFiltering.listen(this.render));
this.subscriptions.add(this.manager.stateApi.$isFiltering.listen(this.render));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSettingsSlice, this.render));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSlice, this.syncCursorStyle));
this.subscriptions.add(
@@ -146,7 +146,7 @@ export class CanvasToolModule extends CanvasModuleBase {
stage.setCursor('not-allowed');
} else if (this.manager.stateApi.$isTranforming.get()) {
stage.setCursor('default');
} else if (this.manager.filter.$isFiltering.get()) {
} else if (this.manager.stateApi.$isFiltering.get()) {
stage.setCursor('not-allowed');
} else if (!this.manager.stateApi.getSelectedEntityAdapter()?.getIsInteractable()) {
stage.setCursor('not-allowed');
@@ -282,7 +282,7 @@ export class CanvasToolModule extends CanvasModuleBase {
return false;
} else if (this.manager.stateApi.$isTranforming.get()) {
return false;
} else if (this.manager.filter.$isFiltering.get()) {
} else if (this.manager.stateApi.$isFiltering.get()) {
return false;
} else if (!this.manager.stateApi.getSelectedEntityAdapter()?.getIsInteractable()) {
return false;

View File

@@ -51,7 +51,10 @@ type CanvasSettingsState = {
* When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited.
*/
compositeMaskedRegions: boolean;
/**
* Whether to automatically preview the filter when the filter configuration changes.
*/
autoPreviewFilter: boolean;
// TODO(psyche): These are copied from old canvas state, need to be implemented
// imageSmoothing: boolean;
// preserveMaskedArea: boolean;
@@ -69,6 +72,7 @@ const initialState: CanvasSettingsState = {
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
sendToCanvas: false,
compositeMaskedRegions: false,
autoPreviewFilter: true,
};
export const canvasSettingsSlice = createSlice({
@@ -105,6 +109,9 @@ export const canvasSettingsSlice = createSlice({
settingsCompositeMaskedRegionsChanged: (state, action: PayloadAction<boolean>) => {
state.compositeMaskedRegions = action.payload;
},
settingsAutoPreviewFilterToggled: (state) => {
state.autoPreviewFilter = !state.autoPreviewFilter;
},
},
});
@@ -119,6 +126,7 @@ export const {
settingsInvertScrollForToolWidthChanged,
settingsSendToCanvasChanged,
settingsCompositeMaskedRegionsChanged,
settingsAutoPreviewFilterToggled,
} = canvasSettingsSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -141,3 +149,7 @@ export const selectDynamicGrid = createSelector(
);
export const selectShowHUD = createSelector(selectCanvasSettingsSlice, (canvasSettings) => canvasSettings.showHUD);
export const selectAutoPreviewFilter = createSelector(
selectCanvasSettingsSlice,
(canvasSettings) => canvasSettings.autoPreviewFilter
);