mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): auto-process for segment anything
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Heading, Spacer, Switch } from '@invoke-ai/ui-library';
|
||||
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
|
||||
import { FilterAutoProcessSwitch } from 'features/controlLayers/components/Filters/FilterAutoProcessSwitch';
|
||||
import { FilterSettings } from 'features/controlLayers/components/Filters/FilterSettings';
|
||||
import { FilterTypeSelect } from 'features/controlLayers/components/Filters/FilterTypeSelect';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
@@ -10,7 +11,6 @@ import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/kon
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import {
|
||||
selectAutoProcessFilter,
|
||||
settingsAutoProcessFilterToggled,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import type { FilterConfig } from 'features/controlLayers/store/filters';
|
||||
import { IMAGE_FILTERS } from 'features/controlLayers/store/filters';
|
||||
@@ -22,7 +22,6 @@ import { PiArrowsCounterClockwiseBold, PiCheckBold, PiShootingStarBold, PiXBold
|
||||
const FilterContent = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useFocusRegion('canvas', ref, { focusOnMount: true });
|
||||
const config = useStore(adapter.filterer.$filterConfig);
|
||||
@@ -45,10 +44,6 @@ const FilterContent = memo(
|
||||
[adapter.filterer.$filterConfig]
|
||||
);
|
||||
|
||||
const onChangeAutoProcessFilter = useCallback(() => {
|
||||
dispatch(settingsAutoProcessFilterToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
const isValid = useMemo(() => {
|
||||
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
|
||||
}, [config]);
|
||||
@@ -88,10 +83,7 @@ const FilterContent = memo(
|
||||
{t('controlLayers.filter.filter')}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('controlLayers.filter.autoProcess')}</FormLabel>
|
||||
<Switch size="sm" isChecked={autoProcessFilter} onChange={onChangeAutoProcessFilter} />
|
||||
</FormControl>
|
||||
<FilterAutoProcessSwitch />
|
||||
<CanvasOperationIsolatedLayerPreviewSwitch />
|
||||
</Flex>
|
||||
<FilterTypeSelect filterType={config.type} onChange={onChangeFilterType} />
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectAutoProcessFilter,
|
||||
settingsAutoProcessFilterToggled,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const FilterAutoProcessSwitch = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const autoProcessFilter = useAppSelector(selectAutoProcessFilter);
|
||||
|
||||
const onChangeAutoProcessFilter = useCallback(() => {
|
||||
dispatch(settingsAutoProcessFilterToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('controlLayers.filter.autoProcess')}</FormLabel>
|
||||
<Switch size="sm" isChecked={autoProcessFilter} onChange={onChangeAutoProcessFilter} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
FilterAutoProcessSwitch.displayName = 'FilterAutoProcessSwitch';
|
||||
@@ -1,11 +1,14 @@
|
||||
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
|
||||
import { FilterAutoProcessSwitch } from 'features/controlLayers/components/Filters/FilterAutoProcessSwitch';
|
||||
import { SegmentAnythingPointType } from 'features/controlLayers/components/SegmentAnything/SegmentAnythingPointType';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import { selectAutoProcessFilter } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { memo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -18,6 +21,8 @@ const SegmentAnythingContent = memo(
|
||||
useFocusRegion('canvas', ref, { focusOnMount: true });
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
const isProcessing = useStore(adapter.segmentAnything.$isProcessing);
|
||||
const hasPoints = useStore(adapter.segmentAnything.$hasPoints);
|
||||
const autoProcessFilter = useAppSelector(selectAutoProcessFilter);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'applySegmentAnything',
|
||||
@@ -54,22 +59,24 @@ const SegmentAnythingContent = memo(
|
||||
{t('controlLayers.segment.autoMask')}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
<FilterAutoProcessSwitch />
|
||||
<CanvasOperationIsolatedLayerPreviewSwitch />
|
||||
</Flex>
|
||||
|
||||
<SegmentAnythingPointType adapter={adapter} />
|
||||
|
||||
<ButtonGroup isAttached={false} size="sm" w="full">
|
||||
<Spacer />
|
||||
<Button
|
||||
leftIcon={<PiStarBold />}
|
||||
onClick={adapter.segmentAnything.process}
|
||||
onClick={adapter.segmentAnything.processImmediate}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.segment.process')}
|
||||
variant="ghost"
|
||||
isDisabled={!hasPoints || autoProcessFilter}
|
||||
>
|
||||
{t('controlLayers.segment.process')}
|
||||
</Button>
|
||||
<Spacer />
|
||||
<Button
|
||||
leftIcon={<PiArrowsCounterClockwiseBold />}
|
||||
onClick={adapter.segmentAnything.reset}
|
||||
|
||||
@@ -7,6 +7,7 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
|
||||
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
|
||||
import { selectAutoProcessFilter } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import type {
|
||||
CanvasImageState,
|
||||
Coordinate,
|
||||
@@ -20,6 +21,7 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import Konva from 'konva';
|
||||
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||
import { debounce } from 'lodash-es';
|
||||
import type { Atom } from 'nanostores';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
@@ -69,7 +71,7 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
|
||||
SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
|
||||
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan
|
||||
MASK_COLOR: { r: 0, g: 200, b: 200, a: 0.5 }, // cyan with 50% opacity
|
||||
PROCESS_DEBOUNCE_MS: 300,
|
||||
PROCESS_DEBOUNCE_MS: 1000,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -147,7 +149,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
/**
|
||||
* The current input points.
|
||||
*/
|
||||
points: SAMPointState[] = [];
|
||||
$points = atom<SAMPointState[]>([]);
|
||||
|
||||
/**
|
||||
* Whether the module has points.
|
||||
*/
|
||||
$hasPoints = computed(this.$points, (points) => points.length > 0);
|
||||
|
||||
/**
|
||||
* The masked image object, if it exists.
|
||||
@@ -222,6 +229,15 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
// until after processing
|
||||
this.konva.maskGroup.add(this.konva.compositingRect);
|
||||
|
||||
this.subscriptions.add(
|
||||
this.$isProcessing.listen((isProcessing) => {
|
||||
this.syncCursorStyle();
|
||||
if (this.$isSegmenting.get()) {
|
||||
this.parent.konva.layer.listening(!isProcessing);
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
// Scale the SAM points when the stage scale changes
|
||||
this.subscriptions.add(
|
||||
this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => {
|
||||
@@ -230,13 +246,41 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
// When the points change, process them if autoProcessFilter is enabled
|
||||
this.subscriptions.add(
|
||||
this.$points.listen((points) => {
|
||||
if (points.length === 0) {
|
||||
return;
|
||||
}
|
||||
if (this.manager.stateApi.getSettings().autoProcessFilter && this.$isSegmenting.get()) {
|
||||
this.process();
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
// When auto-process is enabled, process the points if they have not been processed
|
||||
this.subscriptions.add(
|
||||
this.manager.stateApi.createStoreSubscription(selectAutoProcessFilter, (autoProcessFilter) => {
|
||||
if (this.$points.get().length === 0) {
|
||||
return;
|
||||
}
|
||||
if (autoProcessFilter && this.$isSegmenting.get() && !this.$hasProcessed.get()) {
|
||||
this.process();
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Synchronizes the cursor style to crosshair.
|
||||
*/
|
||||
syncCursorStyle = (): void => {
|
||||
this.manager.stage.setCursor('crosshair');
|
||||
if (this.$isProcessing.get()) {
|
||||
this.manager.stage.setCursor('wait');
|
||||
} else if (this.$isSegmenting.get()) {
|
||||
this.manager.stage.setCursor('crosshair');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -271,7 +315,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
// This event should not bubble up to the parent, stage or any other nodes
|
||||
e.cancelBubble = true;
|
||||
circle.destroy();
|
||||
this.points = this.points.filter((point) => point.id !== id);
|
||||
this.$points.set(this.$points.get().filter((point) => point.id !== id));
|
||||
});
|
||||
|
||||
circle.on('dragstart', () => {
|
||||
@@ -282,6 +326,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.$isDraggingPoint.set(false);
|
||||
// Point has changed!
|
||||
this.$hasProcessed.set(false);
|
||||
this.$points.notify();
|
||||
this.log.trace(
|
||||
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
|
||||
'Moved SAM point'
|
||||
@@ -310,7 +355,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
syncPointScales = () => {
|
||||
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
|
||||
const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH);
|
||||
for (const point of this.points) {
|
||||
for (const point of this.$points.get()) {
|
||||
point.konva.circle.radius(radius);
|
||||
point.konva.circle.strokeWidth(borderWidth);
|
||||
}
|
||||
@@ -322,7 +367,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
getSAMPoints = (): SAMPoint[] => {
|
||||
const points: SAMPoint[] = [];
|
||||
|
||||
for (const { konva, label } of this.points) {
|
||||
for (const { konva, label } of this.$points.get()) {
|
||||
points.push({
|
||||
// Pull out and round the x and y values from Konva
|
||||
x: Math.round(konva.circle.x()),
|
||||
@@ -373,7 +418,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
|
||||
// Create a SAM point at the normalized position
|
||||
const point = this.createPoint(normalizedPoint, this.$pointType.get());
|
||||
this.points.push(point);
|
||||
this.$points.set([...this.$points.get(), point]);
|
||||
|
||||
// Mark the module as having _not_ processed the points now that they have changed
|
||||
this.$hasProcessed.set(false);
|
||||
@@ -432,8 +477,22 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
/**
|
||||
* Processes the SAM points to segment the entity, updating the module's state and rendering the mask.
|
||||
*/
|
||||
process = async () => {
|
||||
this.log.trace({ points: this.getSAMPoints() }, 'Segmenting');
|
||||
processImmediate = async () => {
|
||||
if (this.$isProcessing.get()) {
|
||||
this.log.warn('Already processing');
|
||||
return;
|
||||
}
|
||||
|
||||
const points = this.getSAMPoints();
|
||||
|
||||
if (points.length === 0) {
|
||||
this.log.trace('No points to segment');
|
||||
return;
|
||||
}
|
||||
|
||||
this.$isProcessing.set(true);
|
||||
|
||||
this.log.trace({ points }, 'Segmenting');
|
||||
const rect = this.parent.transformer.getRelativeRect();
|
||||
|
||||
const rasterizeResult = await withResultAsync(() =>
|
||||
@@ -446,8 +505,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
|
||||
this.$isProcessing.set(true);
|
||||
|
||||
const controller = new AbortController();
|
||||
this.abortController = controller;
|
||||
|
||||
@@ -490,6 +547,8 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.abortController = null;
|
||||
};
|
||||
|
||||
process = debounce(this.processImmediate, this.config.PROCESS_DEBOUNCE_MS);
|
||||
|
||||
/**
|
||||
* Applies the segmented image to the entity.
|
||||
*/
|
||||
@@ -571,7 +630,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.abortController = null;
|
||||
|
||||
// Destroy ephemeral konva nodes
|
||||
for (const point of this.points) {
|
||||
for (const point of this.$points.get()) {
|
||||
point.konva.circle.destroy();
|
||||
}
|
||||
if (this.maskedImage) {
|
||||
@@ -579,7 +638,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
}
|
||||
|
||||
// Empty internal module state
|
||||
this.points = [];
|
||||
this.$points.set([]);
|
||||
this.imageState = null;
|
||||
this.$pointType.set(1);
|
||||
this.$hasProcessed.set(false);
|
||||
@@ -645,7 +704,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
type: this.type,
|
||||
path: this.path,
|
||||
parent: this.parent.id,
|
||||
points: this.points.map(({ id, konva, label }) => ({
|
||||
points: this.$points.get().map(({ id, konva, label }) => ({
|
||||
id,
|
||||
label,
|
||||
circle: getKonvaNodeDebugAttrs(konva.circle),
|
||||
|
||||
Reference in New Issue
Block a user