feat(ui): auto-process for segment anything

This commit is contained in:
psychedelicious
2024-10-23 14:27:14 +10:00
parent b044f31a61
commit 116d32fbbe
4 changed files with 113 additions and 28 deletions

View File

@@ -1,8 +1,9 @@
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Heading, Spacer, Switch } from '@invoke-ai/ui-library';
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
import { FilterAutoProcessSwitch } from 'features/controlLayers/components/Filters/FilterAutoProcessSwitch';
import { FilterSettings } from 'features/controlLayers/components/Filters/FilterSettings';
import { FilterTypeSelect } from 'features/controlLayers/components/Filters/FilterTypeSelect';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
@@ -10,7 +11,6 @@ import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/kon
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import {
selectAutoProcessFilter,
settingsAutoProcessFilterToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import type { FilterConfig } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS } from 'features/controlLayers/store/filters';
@@ -22,7 +22,6 @@ import { PiArrowsCounterClockwiseBold, PiCheckBold, PiShootingStarBold, PiXBold
const FilterContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const config = useStore(adapter.filterer.$filterConfig);
@@ -45,10 +44,6 @@ const FilterContent = memo(
[adapter.filterer.$filterConfig]
);
const onChangeAutoProcessFilter = useCallback(() => {
dispatch(settingsAutoProcessFilterToggled());
}, [dispatch]);
const isValid = useMemo(() => {
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
}, [config]);
@@ -88,10 +83,7 @@ const FilterContent = memo(
{t('controlLayers.filter.filter')}
</Heading>
<Spacer />
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.filter.autoProcess')}</FormLabel>
<Switch size="sm" isChecked={autoProcessFilter} onChange={onChangeAutoProcessFilter} />
</FormControl>
<FilterAutoProcessSwitch />
<CanvasOperationIsolatedLayerPreviewSwitch />
</Flex>
<FilterTypeSelect filterType={config.type} onChange={onChangeFilterType} />

View File

@@ -0,0 +1,27 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectAutoProcessFilter,
settingsAutoProcessFilterToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const FilterAutoProcessSwitch = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoProcessFilter = useAppSelector(selectAutoProcessFilter);
const onChangeAutoProcessFilter = useCallback(() => {
dispatch(settingsAutoProcessFilterToggled());
}, [dispatch]);
return (
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.filter.autoProcess')}</FormLabel>
<Switch size="sm" isChecked={autoProcessFilter} onChange={onChangeAutoProcessFilter} />
</FormControl>
);
});
FilterAutoProcessSwitch.displayName = 'FilterAutoProcessSwitch';

View File

@@ -1,11 +1,14 @@
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
import { FilterAutoProcessSwitch } from 'features/controlLayers/components/Filters/FilterAutoProcessSwitch';
import { SegmentAnythingPointType } from 'features/controlLayers/components/SegmentAnything/SegmentAnythingPointType';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { selectAutoProcessFilter } from 'features/controlLayers/store/canvasSettingsSlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
@@ -18,6 +21,8 @@ const SegmentAnythingContent = memo(
useFocusRegion('canvas', ref, { focusOnMount: true });
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.segmentAnything.$isProcessing);
const hasPoints = useStore(adapter.segmentAnything.$hasPoints);
const autoProcessFilter = useAppSelector(selectAutoProcessFilter);
useRegisteredHotkeys({
id: 'applySegmentAnything',
@@ -54,22 +59,24 @@ const SegmentAnythingContent = memo(
{t('controlLayers.segment.autoMask')}
</Heading>
<Spacer />
<FilterAutoProcessSwitch />
<CanvasOperationIsolatedLayerPreviewSwitch />
</Flex>
<SegmentAnythingPointType adapter={adapter} />
<ButtonGroup isAttached={false} size="sm" w="full">
<Spacer />
<Button
leftIcon={<PiStarBold />}
onClick={adapter.segmentAnything.process}
onClick={adapter.segmentAnything.processImmediate}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.process')}
variant="ghost"
isDisabled={!hasPoints || autoProcessFilter}
>
{t('controlLayers.segment.process')}
</Button>
<Spacer />
<Button
leftIcon={<PiArrowsCounterClockwiseBold />}
onClick={adapter.segmentAnything.reset}

View File

@@ -7,6 +7,7 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import { selectAutoProcessFilter } from 'features/controlLayers/store/canvasSettingsSlice';
import type {
CanvasImageState,
Coordinate,
@@ -20,6 +21,7 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { debounce } from 'lodash-es';
import type { Atom } from 'nanostores';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
@@ -69,7 +71,7 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan
MASK_COLOR: { r: 0, g: 200, b: 200, a: 0.5 }, // cyan with 50% opacity
PROCESS_DEBOUNCE_MS: 300,
PROCESS_DEBOUNCE_MS: 1000,
};
/**
@@ -147,7 +149,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
/**
* The current input points.
*/
points: SAMPointState[] = [];
$points = atom<SAMPointState[]>([]);
/**
* Whether the module has points.
*/
$hasPoints = computed(this.$points, (points) => points.length > 0);
/**
* The masked image object, if it exists.
@@ -222,6 +229,15 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// until after processing
this.konva.maskGroup.add(this.konva.compositingRect);
this.subscriptions.add(
this.$isProcessing.listen((isProcessing) => {
this.syncCursorStyle();
if (this.$isSegmenting.get()) {
this.parent.konva.layer.listening(!isProcessing);
}
})
);
// Scale the SAM points when the stage scale changes
this.subscriptions.add(
this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => {
@@ -230,13 +246,41 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
}
})
);
// When the points change, process them if autoProcessFilter is enabled
this.subscriptions.add(
this.$points.listen((points) => {
if (points.length === 0) {
return;
}
if (this.manager.stateApi.getSettings().autoProcessFilter && this.$isSegmenting.get()) {
this.process();
}
})
);
// When auto-process is enabled, process the points if they have not been processed
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectAutoProcessFilter, (autoProcessFilter) => {
if (this.$points.get().length === 0) {
return;
}
if (autoProcessFilter && this.$isSegmenting.get() && !this.$hasProcessed.get()) {
this.process();
}
})
);
}
/**
* Synchronizes the cursor style to crosshair.
*/
syncCursorStyle = (): void => {
this.manager.stage.setCursor('crosshair');
if (this.$isProcessing.get()) {
this.manager.stage.setCursor('wait');
} else if (this.$isSegmenting.get()) {
this.manager.stage.setCursor('crosshair');
}
};
/**
@@ -271,7 +315,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// This event should not bubble up to the parent, stage or any other nodes
e.cancelBubble = true;
circle.destroy();
this.points = this.points.filter((point) => point.id !== id);
this.$points.set(this.$points.get().filter((point) => point.id !== id));
});
circle.on('dragstart', () => {
@@ -282,6 +326,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.$isDraggingPoint.set(false);
// Point has changed!
this.$hasProcessed.set(false);
this.$points.notify();
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'Moved SAM point'
@@ -310,7 +355,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
syncPointScales = () => {
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH);
for (const point of this.points) {
for (const point of this.$points.get()) {
point.konva.circle.radius(radius);
point.konva.circle.strokeWidth(borderWidth);
}
@@ -322,7 +367,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
getSAMPoints = (): SAMPoint[] => {
const points: SAMPoint[] = [];
for (const { konva, label } of this.points) {
for (const { konva, label } of this.$points.get()) {
points.push({
// Pull out and round the x and y values from Konva
x: Math.round(konva.circle.x()),
@@ -373,7 +418,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// Create a SAM point at the normalized position
const point = this.createPoint(normalizedPoint, this.$pointType.get());
this.points.push(point);
this.$points.set([...this.$points.get(), point]);
// Mark the module as having _not_ processed the points now that they have changed
this.$hasProcessed.set(false);
@@ -432,8 +477,22 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
/**
* Processes the SAM points to segment the entity, updating the module's state and rendering the mask.
*/
process = async () => {
this.log.trace({ points: this.getSAMPoints() }, 'Segmenting');
processImmediate = async () => {
if (this.$isProcessing.get()) {
this.log.warn('Already processing');
return;
}
const points = this.getSAMPoints();
if (points.length === 0) {
this.log.trace('No points to segment');
return;
}
this.$isProcessing.set(true);
this.log.trace({ points }, 'Segmenting');
const rect = this.parent.transformer.getRelativeRect();
const rasterizeResult = await withResultAsync(() =>
@@ -446,8 +505,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return;
}
this.$isProcessing.set(true);
const controller = new AbortController();
this.abortController = controller;
@@ -490,6 +547,8 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.abortController = null;
};
process = debounce(this.processImmediate, this.config.PROCESS_DEBOUNCE_MS);
/**
* Applies the segmented image to the entity.
*/
@@ -571,7 +630,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.abortController = null;
// Destroy ephemeral konva nodes
for (const point of this.points) {
for (const point of this.$points.get()) {
point.konva.circle.destroy();
}
if (this.maskedImage) {
@@ -579,7 +638,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
}
// Empty internal module state
this.points = [];
this.$points.set([]);
this.imageState = null;
this.$pointType.set(1);
this.$hasProcessed.set(false);
@@ -645,7 +704,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
type: this.type,
path: this.path,
parent: this.parent.id,
points: this.points.map(({ id, konva, label }) => ({
points: this.$points.get().map(({ id, konva, label }) => ({
id,
label,
circle: getKonvaNodeDebugAttrs(konva.circle),