feat(ui): working segment anything flow

This commit is contained in:
psychedelicious
2024-10-18 17:11:08 +10:00
parent afa0661e55
commit 86a8476d97
3 changed files with 83 additions and 16 deletions

View File

@@ -1850,7 +1850,8 @@
"neutral": "Neutral",
"reset": "Reset",
"apply": "Apply",
"cancel": "Cancel"
"cancel": "Cancel",
"process": "Process"
},
"settings": {
"snapToGrid": {

View File

@@ -8,7 +8,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiXBold } from 'react-icons/pi';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiStarBold, PiXBold } from 'react-icons/pi';
const SegmentAnythingContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
@@ -58,6 +58,15 @@ const SegmentAnythingContent = memo(
<ButtonGroup isAttached={false} size="sm" w="full">
<Spacer />
<Button
leftIcon={<PiStarBold />}
onClick={adapter.segmentAnything.process}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.process')}
variant="ghost"
>
{t('controlLayers.segment.process')}
</Button>
<Button
leftIcon={<PiArrowsCounterClockwiseBold />}
onClick={adapter.segmentAnything.reset}

View File

@@ -6,7 +6,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasImageState, RgbaColor, SAMPoint, SAMPointLabel } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -15,6 +15,7 @@ import type { KonvaEventObject } from 'konva/lib/Node';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import type { S } from 'services/api/types';
type CanvasSegmentAnythingModuleConfig = {
SAM_POINT_RADIUS: number;
@@ -65,7 +66,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
$hasProcessed = atom<boolean>(false);
$isProcessing = atom<boolean>(false);
$pointType = atom<SAMPointLabel>('background');
$pointType = atom<SAMPointLabel>('foreground');
imageState: CanvasImageState | null = null;
@@ -74,6 +75,8 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
konva: {
group: Konva.Group;
pointGroup: Konva.Group;
maskGroup: Konva.Group;
compositingRect: Konva.Rect;
};
@@ -91,23 +94,29 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.konva = {
group: new Konva.Group({ name: `${this.type}:group` }),
pointGroup: new Konva.Group({ name: `${this.type}:point_group` }),
maskGroup: new Konva.Group({ name: `${this.type}:mask_group` }),
compositingRect: new Konva.Rect({
name: `${this.type}:compositingRect`,
fill: rgbaColorToString({ r: 0, g: 0, b: 0, a: 0.5 }),
globalCompositeOperation: 'source-in',
fill: rgbaColorToString({ r: 0, g: 200, b: 200, a: 0.5 }),
globalCompositeOperation: 'source-atop',
listening: false,
strokeEnabled: false,
perfectDrawEnabled: false,
visible: false,
}),
};
this.konva.group.add(this.konva.maskGroup);
this.konva.group.add(this.konva.pointGroup);
this.konva.maskGroup.add(this.konva.compositingRect);
}
createPoint(x: number, y: number, label: SAMPointLabel): SAMPointState {
const id = getPrefixedId('sam_point');
const circle = new Konva.Circle({
name: this.KONVA_CIRCLE_NAME,
x,
y,
x: Math.round(x),
y: Math.round(y),
radius: this.config.SAM_POINT_RADIUS,
fill: rgbaColorToString(this.getSAMPointColor(label)),
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
@@ -130,7 +139,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// y: Math.round(y),
// }));
this.konva.group.add(circle);
this.konva.pointGroup.add(circle);
const state: SAMPointState = {
id,
label,
@@ -154,7 +163,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return;
}
this.createPoint(cursorPos.relative.x, cursorPos.relative.y, this.$pointType.get());
this.points.push(this.createPoint(cursorPos.relative.x, cursorPos.relative.y, this.$pointType.get()));
};
setSegmentingEventListeners = () => {
@@ -203,10 +212,33 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
const controller = new AbortController();
this.abortController = controller;
const g = new Graph(getPrefixedId('canvas_segment_anything'));
const segmentAnything = g.addNode({
id: getPrefixedId('segment_anything_object_identifier'),
type: 'segment_anything_object_identifier',
model: 'segment-anything-huge',
image: { image_name: rasterizeResult.value.image_name },
object_identifiers: [
{
points: this.getSAMPoints().map(({ x, y, label }): S['SAMPoint'] => ({
x,
y,
label: label === 'foreground' ? 1 : -1,
})),
},
],
});
const applyMask = g.addNode({
id: getPrefixedId('apply_tensor_mask_to_image'),
type: 'apply_tensor_mask_to_image',
image: { image_name: rasterizeResult.value.image_name },
});
g.addEdge(segmentAnything, 'mask', applyMask, 'mask');
const segmentResult = await withResultAsync(() =>
this.manager.stateApi.runGraphAndReturnImageOutput({
graph: new Graph(),
outputNodeId: 'TODO',
graph: g,
outputNodeId: applyMask.id,
prepend: true,
signal: controller.signal,
})
@@ -226,10 +258,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
}
this.maskedImage = new CanvasObjectImage(this.imageState, this);
await this.maskedImage.update(this.imageState, true);
this.konva.compositingRect.width(this.imageState.image.width);
this.konva.compositingRect.height(this.imageState.image.height);
this.konva.group.add(this.maskedImage.konva.group);
this.konva.compositingRect.setAttrs({
width: this.imageState.image.width,
height: this.imageState.image.height,
visible: true,
});
this.konva.maskGroup.add(this.maskedImage.konva.group);
this.konva.compositingRect.moveToTop();
this.konva.maskGroup.cache();
this.$isProcessing.set(false);
this.$hasProcessed.set(true);
@@ -255,6 +291,16 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
replaceObjects: true,
});
this.imageState = null;
for (const point of this.points) {
point.konva.circle.destroy();
}
this.points = [];
if (this.maskedImage) {
this.maskedImage.destroy();
}
this.konva.compositingRect.visible(false);
this.konva.maskGroup.clearCache();
this.$pointType.set('foreground');
this.$isSegmenting.set(false);
this.$hasProcessed.set(false);
this.manager.stateApi.$segmentingAdapter.set(null);
@@ -265,8 +311,15 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
reset = () => {
this.log.trace('Resetting segment anything');
for (const point of this.points) {
point.konva.circle.destroy();
}
this.points = [];
this.konva.group.destroyChildren();
if (this.maskedImage) {
this.maskedImage.destroy();
}
this.konva.compositingRect.visible(false);
this.konva.maskGroup.clearCache();
this.abortController?.abort();
this.abortController = null;
@@ -308,6 +361,10 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
points: this.getSAMPoints(),
config: deepClone(this.config),
isSegmenting: this.$isSegmenting.get(),
konva: {
group: getKonvaNodeDebugAttrs(this.konva.group),
compositingRect: getKonvaNodeDebugAttrs(this.konva.compositingRect),
},
};
};