From 606c4ae88c19fd1edd8021bad545773b32b6ce25 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 22 Oct 2024 19:09:50 +1000 Subject: [PATCH] feat(ui): masking UX (wip - issue w/ positioning) --- .../SegmentAnythingPointType.tsx | 8 +- .../konva/CanvasSegmentAnythingModule.ts | 121 ++++++++++++------ .../controlLayers/konva/CanvasStageModule.ts | 8 +- .../konva/CanvasTool/CanvasBboxToolModule.ts | 5 +- .../konva/CanvasTool/CanvasToolModule.ts | 9 +- .../konva/CanvasTool/CanvasViewToolModule.ts | 2 +- .../src/features/controlLayers/store/types.ts | 23 +++- 7 files changed, 125 insertions(+), 51 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SegmentAnything/SegmentAnythingPointType.tsx b/invokeai/frontend/web/src/features/controlLayers/components/SegmentAnything/SegmentAnythingPointType.tsx index 7a76c76ecf..9f597b885d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/SegmentAnything/SegmentAnythingPointType.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/SegmentAnything/SegmentAnythingPointType.tsx @@ -3,14 +3,14 @@ import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer'; import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer'; -import { zSAMPointLabel } from 'features/controlLayers/store/types'; +import { SAM_POINT_LABEL_STRING_TO_NUMBER, zSAMPointLabelString } from 'features/controlLayers/store/types'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; export const SegmentAnythingPointType = memo( ({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => { const { t } = useTranslation(); - const pointType = useStore(adapter.segmentAnything.$pointType); + const pointType = useStore(adapter.segmentAnything.$pointTypeEnglish); const options = useMemo(() => { return [ @@ -28,7 +28,9 @@ export const SegmentAnythingPointType = memo( return; } - adapter.segmentAnything.$pointType.set(zSAMPointLabel.parse(v.value)); + const labelAsString = zSAMPointLabelString.parse(v.value); + const labelAsNumber = SAM_POINT_LABEL_STRING_TO_NUMBER[labelAsString]; + adapter.segmentAnything.$pointType.set(labelAsNumber); }, [adapter.segmentAnything.$pointType] ); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasSegmentAnythingModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasSegmentAnythingModule.ts index db47365532..f54c528f46 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasSegmentAnythingModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasSegmentAnythingModule.ts @@ -6,7 +6,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; -import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util'; +import { floorCoord, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util'; import type { CanvasImageState, Coordinate, @@ -14,14 +14,15 @@ import type { SAMPoint, SAMPointLabel, } from 'features/controlLayers/store/types'; +import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types'; import { imageDTOToImageObject } from 'features/controlLayers/store/util'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import Konva from 'konva'; import type { KonvaEventObject } from 'konva/lib/Node'; -import { atom } from 'nanostores'; +import type { Atom } from 'nanostores'; +import { atom, computed } from 'nanostores'; import type { Logger } from 'roarr'; import { serializeError } from 'serialize-error'; -import type { S } from 'services/api/types'; type CanvasSegmentAnythingModuleConfig = { SAM_POINT_RADIUS: number; @@ -34,12 +35,12 @@ type CanvasSegmentAnythingModuleConfig = { }; const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = { - SAM_POINT_RADIUS: 5, + SAM_POINT_RADIUS: 8, SAM_POINT_BORDER_WIDTH: 2, SAM_POINT_BORDER_COLOR: { r: 0, g: 0, b: 0, a: 1 }, - SAM_POINT_FOREGROUND_COLOR: { r: 0, g: 200, b: 0, a: 0.7 }, - SAM_POINT_BACKGROUND_COLOR: { r: 200, g: 0, b: 0, a: 0.7 }, - SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 0, b: 200, a: 0.7 }, + SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // green-ish + SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish + SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan-ish PROCESS_DEBOUNCE_MS: 300, }; @@ -72,7 +73,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { $hasProcessed = atom(false); $isProcessing = atom(false); - $pointType = atom('foreground'); + $pointType = atom(1); + $pointTypeEnglish = computed<(typeof SAM_POINT_LABEL_NUMBER_TO_STRING)[SAMPointLabel], Atom>( + this.$pointType, + (pointType) => SAM_POINT_LABEL_NUMBER_TO_STRING[pointType] + ); $isDraggingPoint = atom(false); imageState: CanvasImageState | null = null; @@ -116,20 +121,33 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { this.konva.group.add(this.konva.maskGroup); this.konva.group.add(this.konva.pointGroup); this.konva.maskGroup.add(this.konva.compositingRect); + this.subscriptions.add( + this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => { + if (stageAttrs.scale !== oldStageAttrs.scale) { + this.syncPointScales(); + } + }) + ); } + syncCursorStyle = () => { + this.manager.stage.setCursor('crosshair'); + }; + createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState { const id = getPrefixedId('sam_point'); const circle = new Konva.Circle({ name: this.KONVA_CIRCLE_NAME, x: Math.round(coord.x), y: Math.round(coord.y), - radius: this.config.SAM_POINT_RADIUS, + radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), fill: rgbaColorToString(this.getSAMPointColor(label)), stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR), - strokeWidth: this.config.SAM_POINT_BORDER_WIDTH, + strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH), draggable: true, - perfectDrawEnabled: false, + perfectDrawEnabled: true, + opacity: 0.6, + dragDistance: 3, }); circle.on('pointerup', (e) => { @@ -147,35 +165,60 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { circle.on('dragend', () => { this.$isDraggingPoint.set(false); + this.log.trace( + { x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, + 'SAM point moved' + ); }); - circle.dragBoundFunc(({ x, y }) => ({ - x: Math.round(x), - y: Math.round(y), - })); + circle.dragBoundFunc((pos) => floorCoord(pos)); this.konva.pointGroup.add(circle); - const state: SAMPointState = { + + this.log.trace( + { x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, + 'Created SAM point' + ); + + return { id, label, konva: { circle }, }; - - return state; } + syncPointScales = () => { + const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS); + const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH); + for (const { + konva: { circle }, + } of this.points) { + circle.radius(radius); + circle.strokeWidth(borderWidth); + } + }; + getSAMPoints = (): SAMPoint[] => { - return this.points.map(({ konva: { circle }, label }) => ({ - x: circle.x(), - y: circle.y(), - label, - })); + const points: SAMPoint[] = []; + + for (const { konva, label } of this.points) { + points.push({ + x: Math.round(konva.circle.x()), + y: Math.round(konva.circle.y()), + label, + }); + } + + return points; }; onPointerUp = (e: KonvaEventObject) => { if (e.evt.button !== 0) { return; } + if (this.manager.stage.getIsDragging()) { + return; + } if (this.$isDraggingPoint.get()) { return; } @@ -184,7 +227,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { return; } - this.points.push(this.createPoint(cursorPos.relative, this.$pointType.get())); + const pixelRect = this.parent.transformer.$pixelRect.get(); + + const normalizedPoint = offsetCoord(cursorPos.relative, { x: pixelRect.x, y: pixelRect.y }); + const samPoint = this.createPoint(normalizedPoint, this.$pointType.get()); + this.points.push(samPoint); }; setSegmentingEventListeners = () => { @@ -202,6 +249,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { return; } this.log.trace('Starting segment anything'); + this.$pointType.set(1); this.$isSegmenting.set(true); this.manager.stateApi.$segmentingAdapter.set(this.parent); for (const point of this.points) { @@ -209,6 +257,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { } this.points = []; this.parent.konva.layer.add(this.konva.group); + console.log({ + position: this.parent.state.position, + pixelRect: this.parent.transformer.$pixelRect.get(), + nodeRect: this.parent.transformer.$nodeRect.get(), + getRelativeRect: this.parent.transformer.getRelativeRect(), + }); + const pixelRect = this.parent.transformer.$pixelRect.get(); + this.konva.group.setAttrs({ x: pixelRect.x, y: pixelRect.y }); this.parent.konva.layer.listening(true); this.setSegmentingEventListeners(); @@ -239,15 +295,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { type: 'segment_anything_object_identifier', model: 'segment-anything-huge', image: { image_name: rasterizeResult.value.image_name }, - object_identifiers: [ - { - points: this.getSAMPoints().map(({ x, y, label }): S['SAMPoint'] => ({ - x, - y, - label: label === 'foreground' ? 1 : -1, - })), - }, - ], + object_identifiers: [{ points: this.getSAMPoints() }], }); const applyMask = g.addNode({ id: getPrefixedId('apply_tensor_mask_to_image'), @@ -321,7 +369,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { } this.konva.compositingRect.visible(false); this.konva.maskGroup.clearCache(); - this.$pointType.set('foreground'); + this.$pointType.set(1); this.$isSegmenting.set(false); this.$hasProcessed.set(false); this.manager.stateApi.$segmentingAdapter.set(null); @@ -364,11 +412,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { }; getSAMPointColor(label: SAMPointLabel): RgbaColor { - if (label === 'neutral') { + if (label === 0) { return this.config.SAM_POINT_NEUTRAL_COLOR; - } else if (label === 'foreground') { + } else if (label === 1) { return this.config.SAM_POINT_FOREGROUND_COLOR; } else { + // label === -1 return this.config.SAM_POINT_BACKGROUND_COLOR; } } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStageModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStageModule.ts index 0468b9a2b4..9c5351d6e7 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStageModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStageModule.ts @@ -311,7 +311,7 @@ export class CanvasStageModule extends CanvasModuleBase { this.setIsDraggable(true); // Then start dragging the stage if it's not already being dragged - if (!this.konva.stage.isDragging()) { + if (!this.getIsDragging()) { this.konva.stage.startDrag(); } @@ -328,7 +328,7 @@ export class CanvasStageModule extends CanvasModuleBase { this.setIsDraggable(this.manager.tool.$tool.get() === 'view'); // Stop dragging the stage if it's being dragged - if (this.konva.stage.isDragging()) { + if (this.getIsDragging()) { this.konva.stage.stopDrag(); } @@ -404,6 +404,10 @@ export class CanvasStageModule extends CanvasModuleBase { this.konva.stage.draggable(isDraggable); }; + getIsDragging = () => { + return this.konva.stage.isDragging(); + }; + addLayer = (layer: Konva.Layer) => { this.konva.stage.add(layer); }; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts index 8d0ffee675..7d0b2cf518 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts @@ -203,12 +203,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase { * Renders the bbox. The bbox is only visible when the tool is set to 'bbox'. */ render = () => { - this.log.trace('Rendering'); - - const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect; const tool = this.manager.tool.$tool.get(); - this.konva.group.visible(true); + const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect; // We need to reach up to the preview layer to enable/disable listening so that the bbox can be interacted with. // If the mangaer is busy, we disable listening so the bbox cannot be interacted with. diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts index 5268483cdc..811205dbf1 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts @@ -158,11 +158,12 @@ export class CanvasToolModule extends CanvasModuleBase { syncCursorStyle = () => { const stage = this.manager.stage; const tool = this.$tool.get(); + const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get(); - if (this.manager.stage.konva.stage.isDragging() || tool === 'view') { + if (this.manager.stage.getIsDragging() || tool === 'view') { this.tools.view.syncCursorStyle(); - } else if (this.manager.stateApi.$isSegmenting.get()) { - stage.setCursor('default'); + } else if (segmentingAdapter) { + segmentingAdapter.segmentAnything.syncCursorStyle(); } else if (this.manager.stateApi.$isFiltering.get()) { stage.setCursor('not-allowed'); } else if (this.manager.stagingArea.$isStaging.get()) { @@ -284,7 +285,7 @@ export class CanvasToolModule extends CanvasModuleBase { return false; } - if (this.manager.stage.konva.stage.isDragging()) { + if (this.manager.stage.getIsDragging()) { return false; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasViewToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasViewToolModule.ts index d796831b9d..211178e0d5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasViewToolModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasViewToolModule.ts @@ -24,6 +24,6 @@ export class CanvasViewToolModule extends CanvasModuleBase { } syncCursorStyle = () => { - this.manager.stage.setCursor(this.manager.stage.konva.stage.isDragging() ? 'grabbing' : 'grab'); + this.manager.stage.setCursor(this.manager.stage.getIsDragging() ? 'grabbing' : 'grab'); }; } diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 12b81ac225..15f4e63563 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -96,9 +96,30 @@ const zCoordinateWithPressure = z.object({ }); export type CoordinateWithPressure = z.infer; -export const zSAMPointLabel = z.enum(['foreground', 'background', 'neutral']); +const SAM_POINT_LABELS = { + background: -1, + neutral: 0, + foreground: 1, +} as const; + +export const zSAMPointLabel = z.nativeEnum(SAM_POINT_LABELS); export type SAMPointLabel = z.infer; +export const zSAMPointLabelString = z.enum(['background', 'neutral', 'foreground']); +export type SAMPointLabelString = z.infer; + +export const SAM_POINT_LABEL_NUMBER_TO_STRING: Record = { + '-1': 'background', + 0: 'neutral', + 1: 'foreground', +}; + +export const SAM_POINT_LABEL_STRING_TO_NUMBER: Record = { + background: -1, + neutral: 0, + foreground: 1, +}; + export const zSAMPoint = z.object({ x: z.number().int().gte(0), y: z.number().int().gte(0),