mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): working segment anything flow
This commit is contained in:
@@ -1850,7 +1850,8 @@
|
||||
"neutral": "Neutral",
|
||||
"reset": "Reset",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel"
|
||||
"cancel": "Cancel",
|
||||
"process": "Process"
|
||||
},
|
||||
"settings": {
|
||||
"snapToGrid": {
|
||||
|
||||
@@ -8,7 +8,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { memo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiStarBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
const SegmentAnythingContent = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
@@ -58,6 +58,15 @@ const SegmentAnythingContent = memo(
|
||||
|
||||
<ButtonGroup isAttached={false} size="sm" w="full">
|
||||
<Spacer />
|
||||
<Button
|
||||
leftIcon={<PiStarBold />}
|
||||
onClick={adapter.segmentAnything.process}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.segment.process')}
|
||||
variant="ghost"
|
||||
>
|
||||
{t('controlLayers.segment.process')}
|
||||
</Button>
|
||||
<Button
|
||||
leftIcon={<PiArrowsCounterClockwiseBold />}
|
||||
onClick={adapter.segmentAnything.reset}
|
||||
|
||||
@@ -6,7 +6,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type { CanvasImageState, RgbaColor, SAMPoint, SAMPointLabel } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
@@ -15,6 +15,7 @@ import type { KonvaEventObject } from 'konva/lib/Node';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
type CanvasSegmentAnythingModuleConfig = {
|
||||
SAM_POINT_RADIUS: number;
|
||||
@@ -65,7 +66,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
$hasProcessed = atom<boolean>(false);
|
||||
$isProcessing = atom<boolean>(false);
|
||||
|
||||
$pointType = atom<SAMPointLabel>('background');
|
||||
$pointType = atom<SAMPointLabel>('foreground');
|
||||
|
||||
imageState: CanvasImageState | null = null;
|
||||
|
||||
@@ -74,6 +75,8 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
|
||||
konva: {
|
||||
group: Konva.Group;
|
||||
pointGroup: Konva.Group;
|
||||
maskGroup: Konva.Group;
|
||||
compositingRect: Konva.Rect;
|
||||
};
|
||||
|
||||
@@ -91,23 +94,29 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
|
||||
this.konva = {
|
||||
group: new Konva.Group({ name: `${this.type}:group` }),
|
||||
pointGroup: new Konva.Group({ name: `${this.type}:point_group` }),
|
||||
maskGroup: new Konva.Group({ name: `${this.type}:mask_group` }),
|
||||
compositingRect: new Konva.Rect({
|
||||
name: `${this.type}:compositingRect`,
|
||||
fill: rgbaColorToString({ r: 0, g: 0, b: 0, a: 0.5 }),
|
||||
globalCompositeOperation: 'source-in',
|
||||
fill: rgbaColorToString({ r: 0, g: 200, b: 200, a: 0.5 }),
|
||||
globalCompositeOperation: 'source-atop',
|
||||
listening: false,
|
||||
strokeEnabled: false,
|
||||
perfectDrawEnabled: false,
|
||||
visible: false,
|
||||
}),
|
||||
};
|
||||
this.konva.group.add(this.konva.maskGroup);
|
||||
this.konva.group.add(this.konva.pointGroup);
|
||||
this.konva.maskGroup.add(this.konva.compositingRect);
|
||||
}
|
||||
|
||||
createPoint(x: number, y: number, label: SAMPointLabel): SAMPointState {
|
||||
const id = getPrefixedId('sam_point');
|
||||
const circle = new Konva.Circle({
|
||||
name: this.KONVA_CIRCLE_NAME,
|
||||
x,
|
||||
y,
|
||||
x: Math.round(x),
|
||||
y: Math.round(y),
|
||||
radius: this.config.SAM_POINT_RADIUS,
|
||||
fill: rgbaColorToString(this.getSAMPointColor(label)),
|
||||
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
|
||||
@@ -130,7 +139,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
// y: Math.round(y),
|
||||
// }));
|
||||
|
||||
this.konva.group.add(circle);
|
||||
this.konva.pointGroup.add(circle);
|
||||
const state: SAMPointState = {
|
||||
id,
|
||||
label,
|
||||
@@ -154,7 +163,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
|
||||
this.createPoint(cursorPos.relative.x, cursorPos.relative.y, this.$pointType.get());
|
||||
this.points.push(this.createPoint(cursorPos.relative.x, cursorPos.relative.y, this.$pointType.get()));
|
||||
};
|
||||
|
||||
setSegmentingEventListeners = () => {
|
||||
@@ -203,10 +212,33 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
const controller = new AbortController();
|
||||
this.abortController = controller;
|
||||
|
||||
const g = new Graph(getPrefixedId('canvas_segment_anything'));
|
||||
const segmentAnything = g.addNode({
|
||||
id: getPrefixedId('segment_anything_object_identifier'),
|
||||
type: 'segment_anything_object_identifier',
|
||||
model: 'segment-anything-huge',
|
||||
image: { image_name: rasterizeResult.value.image_name },
|
||||
object_identifiers: [
|
||||
{
|
||||
points: this.getSAMPoints().map(({ x, y, label }): S['SAMPoint'] => ({
|
||||
x,
|
||||
y,
|
||||
label: label === 'foreground' ? 1 : -1,
|
||||
})),
|
||||
},
|
||||
],
|
||||
});
|
||||
const applyMask = g.addNode({
|
||||
id: getPrefixedId('apply_tensor_mask_to_image'),
|
||||
type: 'apply_tensor_mask_to_image',
|
||||
image: { image_name: rasterizeResult.value.image_name },
|
||||
});
|
||||
g.addEdge(segmentAnything, 'mask', applyMask, 'mask');
|
||||
|
||||
const segmentResult = await withResultAsync(() =>
|
||||
this.manager.stateApi.runGraphAndReturnImageOutput({
|
||||
graph: new Graph(),
|
||||
outputNodeId: 'TODO',
|
||||
graph: g,
|
||||
outputNodeId: applyMask.id,
|
||||
prepend: true,
|
||||
signal: controller.signal,
|
||||
})
|
||||
@@ -226,10 +258,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
}
|
||||
this.maskedImage = new CanvasObjectImage(this.imageState, this);
|
||||
await this.maskedImage.update(this.imageState, true);
|
||||
this.konva.compositingRect.width(this.imageState.image.width);
|
||||
this.konva.compositingRect.height(this.imageState.image.height);
|
||||
this.konva.group.add(this.maskedImage.konva.group);
|
||||
this.konva.compositingRect.setAttrs({
|
||||
width: this.imageState.image.width,
|
||||
height: this.imageState.image.height,
|
||||
visible: true,
|
||||
});
|
||||
this.konva.maskGroup.add(this.maskedImage.konva.group);
|
||||
this.konva.compositingRect.moveToTop();
|
||||
this.konva.maskGroup.cache();
|
||||
|
||||
this.$isProcessing.set(false);
|
||||
this.$hasProcessed.set(true);
|
||||
@@ -255,6 +291,16 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
replaceObjects: true,
|
||||
});
|
||||
this.imageState = null;
|
||||
for (const point of this.points) {
|
||||
point.konva.circle.destroy();
|
||||
}
|
||||
this.points = [];
|
||||
if (this.maskedImage) {
|
||||
this.maskedImage.destroy();
|
||||
}
|
||||
this.konva.compositingRect.visible(false);
|
||||
this.konva.maskGroup.clearCache();
|
||||
this.$pointType.set('foreground');
|
||||
this.$isSegmenting.set(false);
|
||||
this.$hasProcessed.set(false);
|
||||
this.manager.stateApi.$segmentingAdapter.set(null);
|
||||
@@ -265,8 +311,15 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
reset = () => {
|
||||
this.log.trace('Resetting segment anything');
|
||||
|
||||
for (const point of this.points) {
|
||||
point.konva.circle.destroy();
|
||||
}
|
||||
this.points = [];
|
||||
this.konva.group.destroyChildren();
|
||||
if (this.maskedImage) {
|
||||
this.maskedImage.destroy();
|
||||
}
|
||||
this.konva.compositingRect.visible(false);
|
||||
this.konva.maskGroup.clearCache();
|
||||
|
||||
this.abortController?.abort();
|
||||
this.abortController = null;
|
||||
@@ -308,6 +361,10 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
points: this.getSAMPoints(),
|
||||
config: deepClone(this.config),
|
||||
isSegmenting: this.$isSegmenting.get(),
|
||||
konva: {
|
||||
group: getKonvaNodeDebugAttrs(this.konva.group),
|
||||
compositingRect: getKonvaNodeDebugAttrs(this.konva.compositingRect),
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user