mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): support inverted selection in auto-mask
This commit is contained in:
committed by
Kent Keirsey
parent
ab898a7180
commit
75f605ba1a
@@ -1889,6 +1889,7 @@
|
||||
"segment": {
|
||||
"autoMask": "Auto Mask",
|
||||
"pointType": "Point Type",
|
||||
"invertSelection": "Invert Selection",
|
||||
"include": "Include",
|
||||
"exclude": "Exclude",
|
||||
"neutral": "Neutral",
|
||||
@@ -1896,8 +1897,9 @@
|
||||
"saveAs": "Save As",
|
||||
"cancel": "Cancel",
|
||||
"process": "Process",
|
||||
"help1": "Auto-mask creates a mask for a single target object. Add <Bold>Include</Bold> and <Bold>Exclude</Bold> points to indicate which parts of the layer are part of the target object.",
|
||||
"help2": "Start with one <Bold>Include</Bold> point within the target object. Add more points to refine the mask. Fewer points typically produce better results.",
|
||||
"help1": "Select a single target object. Add <Bold>Include</Bold> and <Bold>Exclude</Bold> points to indicate which parts of the layer are part of the target object.",
|
||||
"help2": "Start with one <Bold>Include</Bold> point within the target object. Add more points to refine the selection. Fewer points typically produce better results.",
|
||||
"help3": "Invert the selection to select everything except the target object.",
|
||||
"clickToAdd": "Click on the layer to add a point",
|
||||
"dragToMove": "Drag a point to move it",
|
||||
"clickToRemove": "Click on a point to remove it"
|
||||
|
||||
@@ -19,6 +19,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { CanvasAutoProcessSwitch } from 'features/controlLayers/components/CanvasAutoProcessSwitch';
|
||||
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
|
||||
import { SegmentAnythingInvert } from 'features/controlLayers/components/SegmentAnything/SegmentAnythingInvert';
|
||||
import { SegmentAnythingPointType } from 'features/controlLayers/components/SegmentAnything/SegmentAnythingPointType';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
@@ -103,7 +104,10 @@ const SegmentAnythingContent = memo(
|
||||
<CanvasOperationIsolatedLayerPreviewSwitch />
|
||||
</Flex>
|
||||
|
||||
<SegmentAnythingPointType adapter={adapter} />
|
||||
<Flex w="full" justifyContent="space-between" py={2}>
|
||||
<SegmentAnythingPointType adapter={adapter} />
|
||||
<SegmentAnythingInvert adapter={adapter} />
|
||||
</Flex>
|
||||
|
||||
<ButtonGroup isAttached={false} size="sm" w="full">
|
||||
<Button
|
||||
@@ -197,6 +201,9 @@ const SegmentAnythingHelpTooltipContent = memo(() => {
|
||||
<Text>
|
||||
<Trans i18nKey="controlLayers.segment.help2" components={{ Bold: <Bold /> }} />
|
||||
</Text>
|
||||
<Text>
|
||||
<Trans i18nKey="controlLayers.segment.help3" />
|
||||
</Text>
|
||||
<UnorderedList>
|
||||
<ListItem>{t('controlLayers.segment.clickToAdd')}</ListItem>
|
||||
<ListItem>{t('controlLayers.segment.dragToMove')}</ListItem>
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const SegmentAnythingInvert = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const invert = useStore(adapter.segmentAnything.$invert);
|
||||
|
||||
const onChange = useCallback(() => {
|
||||
adapter.segmentAnything.$invert.set(!adapter.segmentAnything.$invert.get());
|
||||
}, [adapter.segmentAnything.$invert]);
|
||||
|
||||
return (
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('controlLayers.segment.invertSelection')}</FormLabel>
|
||||
<Switch size="sm" isChecked={invert} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SegmentAnythingInvert.displayName = 'SegmentAnythingInvert';
|
||||
@@ -21,8 +21,8 @@ export const SegmentAnythingPointType = memo(
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl w="full">
|
||||
<FormLabel>{t('controlLayers.segment.pointType')}</FormLabel>
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('controlLayers.segment.pointType')}</FormLabel>
|
||||
<RadioGroup value={pointType} onChange={onChange} w="full" size="md">
|
||||
<Flex alignItems="center" w="full" gap={4} fontWeight="semibold" color="base.300">
|
||||
<Radio value="foreground">
|
||||
|
||||
@@ -172,6 +172,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
*/
|
||||
$hasPoints = computed(this.$points, (points) => points.length > 0);
|
||||
|
||||
/**
|
||||
* Whether the module should invert the mask image.
|
||||
*/
|
||||
$invert = atom<boolean>(false);
|
||||
|
||||
/**
|
||||
* The masked image object, if it exists.
|
||||
*/
|
||||
@@ -456,6 +461,19 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
})
|
||||
);
|
||||
|
||||
// When the invert flag changes, process if autoProcess is enabled
|
||||
this.subscriptions.add(
|
||||
this.$invert.listen(() => {
|
||||
if (this.$points.get().length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.manager.stateApi.getSettings().autoProcess) {
|
||||
this.process();
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
// When auto-process is enabled, process the points if they have not been processed
|
||||
this.subscriptions.add(
|
||||
this.manager.stateApi.createStoreSubscription(selectAutoProcess, (autoProcess) => {
|
||||
@@ -529,7 +547,9 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
|
||||
const hash = stableHash(points);
|
||||
const invert = this.$invert.get();
|
||||
|
||||
const hash = stableHash({ points, invert });
|
||||
if (hash === this.$lastProcessedHash.get()) {
|
||||
this.log.trace('Already processed points');
|
||||
return;
|
||||
@@ -556,7 +576,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.abortController = controller;
|
||||
|
||||
// Build the graph for segmenting the image, using the rasterized image DTO
|
||||
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value, points);
|
||||
const { graph, outputNodeId } = CanvasSegmentAnythingModule.buildGraph(rasterizeResult.value, points, invert);
|
||||
|
||||
// Run the graph and get the segmented image output
|
||||
const segmentResult = await withResultAsync(() =>
|
||||
@@ -793,6 +813,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.$points.set([]);
|
||||
this.$imageState.set(null);
|
||||
this.$pointType.set(1);
|
||||
this.$invert.set(false);
|
||||
this.$lastProcessedHash.set('');
|
||||
this.$isProcessing.set(false);
|
||||
|
||||
@@ -808,7 +829,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
/**
|
||||
* Builds a graph for segmenting an image with the given image DTO.
|
||||
*/
|
||||
buildGraph = ({ image_name }: ImageDTO, points: SAMPointWithId[]): { graph: Graph; outputNodeId: string } => {
|
||||
static buildGraph = (
|
||||
{ image_name }: ImageDTO,
|
||||
points: SAMPointWithId[],
|
||||
invert: boolean
|
||||
): { graph: Graph; outputNodeId: string } => {
|
||||
const graph = new Graph(getPrefixedId('canvas_segment_anything'));
|
||||
|
||||
// TODO(psyche): When SAM2 is available in transformers, use it here
|
||||
@@ -827,6 +852,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
id: getPrefixedId('apply_tensor_mask_to_image'),
|
||||
type: 'apply_tensor_mask_to_image',
|
||||
image: { image_name },
|
||||
invert,
|
||||
});
|
||||
graph.addEdge(segmentAnything, 'mask', applyMask, 'mask');
|
||||
|
||||
|
||||
Reference in New Issue
Block a user