From bc85bd4bd42903fe5f3799a09d6a1d52acca2bca Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 23 Oct 2024 11:52:00 +1000 Subject: [PATCH] tidy(ui): clean up and document CanvasSegmentAnythingModule --- .../konva/CanvasSegmentAnythingModule.ts | 414 +++++++++++++----- .../src/features/controlLayers/store/types.ts | 6 + 2 files changed, 319 insertions(+), 101 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasSegmentAnythingModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasSegmentAnythingModule.ts index 072893f271..88cd46ca78 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasSegmentAnythingModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasSegmentAnythingModule.ts @@ -6,19 +6,14 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; -import { - addCoords, - floorCoord, - getKonvaNodeDebugAttrs, - getPrefixedId, - offsetCoord, -} from 'features/controlLayers/konva/util'; +import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util'; import type { CanvasImageState, Coordinate, RgbaColor, SAMPoint, SAMPointLabel, + SAMPointLabelString, } from 'features/controlLayers/store/types'; import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types'; import { imageDTOToImageObject } from 'features/controlLayers/store/util'; @@ -29,14 +24,40 @@ import type { Atom } from 'nanostores'; import { atom, computed } from 'nanostores'; import type { Logger } from 'roarr'; import { serializeError } from 'serialize-error'; +import type { ImageDTO } from 
'services/api/types'; type CanvasSegmentAnythingModuleConfig = { + /** + * The radius of the SAM point Konva circle node. + */ SAM_POINT_RADIUS: number; + /** + * The border width of the SAM point Konva circle node. + */ SAM_POINT_BORDER_WIDTH: number; + /** + * The border color of the SAM point Konva circle node. + */ SAM_POINT_BORDER_COLOR: RgbaColor; + /** + * The color of the SAM point Konva circle node when the label is 1. + */ SAM_POINT_FOREGROUND_COLOR: RgbaColor; + /** + * The color of the SAM point Konva circle node when the label is -1. + */ SAM_POINT_BACKGROUND_COLOR: RgbaColor; + /** + * The color of the SAM point Konva circle node when the label is 0. + */ SAM_POINT_NEUTRAL_COLOR: RgbaColor; + /** + * The color to use for the mask preview overlay. + */ + MASK_COLOR: RgbaColor; + /** + * The debounce time in milliseconds for processing the points. + */ PROCESS_DEBOUNCE_MS: number; }; @@ -44,12 +65,21 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = { SAM_POINT_RADIUS: 8, SAM_POINT_BORDER_WIDTH: 2, SAM_POINT_BORDER_COLOR: { r: 0, g: 0, b: 0, a: 1 }, - SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // green-ish + SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // light green SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish - SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan-ish + SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan + MASK_COLOR: { r: 0, g: 200, b: 200, a: 0.5 }, // cyan with 50% opacity PROCESS_DEBOUNCE_MS: 300, }; +/** + * The state of a SAM point. + * @property id - The unique identifier of the point. + * @property label - The label of the point. -1 is background, 0 is neutral, 1 is foreground. + * @property konva - The Konva node state of the point. + * @property konva.circle - The Konva circle node of the point. The x and y coordinates for the point are derived from + * this node. 
+ */ type SAMPointState = { id: string; label: SAMPointLabel; @@ -75,30 +105,88 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { */ abortController: AbortController | null = null; + /** + * Whether the module is currently segmenting an entity. + */ $isSegmenting = atom(false); + + /** + * Whether the current set of points has been processed. + */ $hasProcessed = atom(false); + + /** + * Whether the module is currently processing the points. + */ $isProcessing = atom(false); + /** + * The type of point to create when segmenting. This is a number representation of the SAMPointLabel enum. + */ $pointType = atom<SAMPointLabel>(1); - $pointTypeEnglish = computed<(typeof SAM_POINT_LABEL_NUMBER_TO_STRING)[SAMPointLabel], Atom<SAMPointLabel>>( + + /** + * The type of point to create when segmenting, as a string. This is a computed value based on $pointType. + */ + $pointTypeEnglish = computed<SAMPointLabelString, Atom<SAMPointLabel>>( this.$pointType, (pointType) => SAM_POINT_LABEL_NUMBER_TO_STRING[pointType] ); + + /** + * Whether a point is currently being dragged. This is used to prevent the point additions and deletions during + * dragging. + */ $isDraggingPoint = atom(false); + /** + * The ephemeral image state of the processed image. Only used while segmenting. + */ imageState: CanvasImageState | null = null; + /** + * The current input points. + */ points: SAMPointState[] = []; + + /** + * The masked image object, if it exists. + */ maskedImage: CanvasObjectImage | null = null; + /** + * The Konva nodes for the module. + */ konva: { + /** + * The main Konva group node for the module. + */ group: Konva.Group; + /** + * The Konva group node for the SAM points. + * + * This is a child of the main group node, rendered above the mask group. + */ pointGroup: Konva.Group; + /** + * The Konva group node for the mask image and compositing rect. + * + * This is a child of the main group node, rendered below the point group. + */ maskGroup: Konva.Group; + /** + * The Konva rect node for compositing the mask image. 
+ * + * It's rendered with a globalCompositeOperation of 'source-atop' to preview the mask as a semi-transparent overlay. + */ compositingRect: Konva.Rect; }; KONVA_CIRCLE_NAME = `${this.type}:circle`; + KONVA_GROUP_NAME = `${this.type}:group`; + KONVA_POINT_GROUP_NAME = `${this.type}:point_group`; + KONVA_MASK_GROUP_NAME = `${this.type}:mask_group`; + KONVA_COMPOSITING_RECT_NAME = `${this.type}:compositing_rect`; constructor(parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer) { super(); @@ -110,13 +198,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { this.log.debug('Creating module'); + // Create all konva nodes this.konva = { - group: new Konva.Group({ name: `${this.type}:group` }), - pointGroup: new Konva.Group({ name: `${this.type}:point_group` }), - maskGroup: new Konva.Group({ name: `${this.type}:mask_group` }), + group: new Konva.Group({ name: this.KONVA_GROUP_NAME }), + pointGroup: new Konva.Group({ name: this.KONVA_POINT_GROUP_NAME }), + maskGroup: new Konva.Group({ name: this.KONVA_MASK_GROUP_NAME }), compositingRect: new Konva.Rect({ - name: `${this.type}:compositingRect`, - fill: rgbaColorToString({ r: 0, g: 200, b: 200, a: 0.5 }), + name: this.KONVA_COMPOSITING_RECT_NAME, + fill: rgbaColorToString(this.config.MASK_COLOR), globalCompositeOperation: 'source-atop', listening: false, strokeEnabled: false, @@ -124,9 +213,16 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { visible: false, }), }; + + // Mask group is below the point group this.konva.group.add(this.konva.maskGroup); this.konva.group.add(this.konva.pointGroup); + + // Compositing rect is added to the mask group - will also be above the mask image, but that doesn't get created + // until after processing this.konva.maskGroup.add(this.konva.compositingRect); + + // Scale the SAM points when the stage scale changes this.subscriptions.add( this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => { if (stageAttrs.scale 
!== oldStageAttrs.scale) { @@ -136,30 +232,43 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { ); } - syncCursorStyle = () => { + /** + * Synchronizes the cursor style to crosshair. + */ + syncCursorStyle = (): void => { this.manager.stage.setCursor('crosshair'); }; + /** + * Creates a SAM point at the given coordinate with the given label. -1 is background, 0 is neutral, 1 is foreground. + * @param coord The coordinate + * @param label The label. + * @returns The SAM point state. + */ createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState { const id = getPrefixedId('sam_point'); + const circle = new Konva.Circle({ name: this.KONVA_CIRCLE_NAME, x: Math.round(coord.x), y: Math.round(coord.y), - radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), + radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), // We will scale this as the stage scale changes fill: rgbaColorToString(this.getSAMPointColor(label)), stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR), - strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH), + strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH), // We will scale this as the stage scale changes draggable: true, - perfectDrawEnabled: true, + perfectDrawEnabled: true, // Required for the stroke/fill to draw correctly w/ partial opacity opacity: 0.6, dragDistance: 3, }); + // When the point is clicked, remove it circle.on('pointerup', (e) => { + // Ignore if we are dragging if (this.$isDraggingPoint.get()) { return; } + // This event should not bubble up to the parent, stage or any other nodes e.cancelBubble = true; circle.destroy(); this.points = this.points.filter((point) => point.id !== id); @@ -171,14 +280,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { circle.on('dragend', () => { this.$isDraggingPoint.set(false); + // Point has changed! 
+ this.$hasProcessed.set(false); this.log.trace( { x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, - 'SAM point moved' + 'Moved SAM point' ); }); - circle.dragBoundFunc((pos) => floorCoord(pos)); - this.konva.pointGroup.add(circle); this.log.trace( @@ -193,22 +302,29 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { }; } + /** + * Synchronizes the scales of the SAM points to the stage scale. + * + * SAM points are always the same size, regardless of the stage scale. + */ syncPointScales = () => { const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS); const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH); - for (const { - konva: { circle }, - } of this.points) { - circle.radius(radius); - circle.strokeWidth(borderWidth); + for (const point of this.points) { + point.konva.circle.radius(radius); + point.konva.circle.strokeWidth(borderWidth); } }; + /** + * Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers. + */ getSAMPoints = (): SAMPoint[] => { const points: SAMPoint[] = []; for (const { konva, label } of this.points) { points.push({ + // Pull out and round the x and y values from Konva x: Math.round(konva.circle.x()), y: Math.round(konva.circle.y()), label, @@ -218,16 +334,31 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { return points; }; - onPointerUp = (e: KonvaEventObject) => { + /** + * Handles the pointerup event on the stage. This is used to add a SAM point to the module. 
+ */ + onStagePointerUp = (e: KonvaEventObject) => { + // Only handle left-clicks if (e.evt.button !== 0) { return; } + + // Ignore if the stage is dragging/panning if (this.manager.stage.getIsDragging()) { return; } + + // Ignore if a point is being dragged if (this.$isDraggingPoint.get()) { return; } + + // Ignore if we are already processing + if (this.$isProcessing.get()) { + return; + } + + // Ignore if the cursor is not within the stage (should never happen) const cursorPos = this.manager.tool.$cursorPos.get(); if (!cursorPos) { return; @@ -235,21 +366,36 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { // We need to offset the cursor position by the parent entity's position + pixel rect to get the correct position const pixelRect = this.parent.transformer.$pixelRect.get(); - const position = addCoords(this.parent.state.position, pixelRect); + const parentPosition = addCoords(this.parent.state.position, pixelRect); - const normalizedPoint = offsetCoord(cursorPos.relative, position); - const samPoint = this.createPoint(normalizedPoint, this.$pointType.get()); - this.points.push(samPoint); + // Normalize the cursor position to the parent entity's position + const normalizedPoint = offsetCoord(cursorPos.relative, parentPosition); + + // Create a SAM point at the normalized position + const point = this.createPoint(normalizedPoint, this.$pointType.get()); + this.points.push(point); + + // Mark the module as having _not_ processed the points now that they have changed + this.$hasProcessed.set(false); }; - setSegmentingEventListeners = () => { - this.manager.stage.konva.stage.on('pointerup', this.onPointerUp); + /** + * Adds Konva stage event listeners for segmenting the entity. 
+ */ + addStageEventListeners = () => { + this.manager.stage.konva.stage.on('pointerup', this.onStagePointerUp); }; - removeSegmentingEventListeners = () => { - this.manager.stage.konva.stage.off('pointerup', this.onPointerUp); + /** + * Removes Konva stage event listeners for segmenting the entity. + */ + removeStageEventListeners = () => { + this.manager.stage.konva.stage.off('pointerup', this.onStagePointerUp); }; + /** + * Starts the segmenting process. + */ start = () => { const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get(); if (segmentingAdapter) { @@ -257,23 +403,35 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { return; } this.log.trace('Starting segment anything'); - this.$pointType.set(1); + + // Reset the module's state + this.resetEphemeralState(); this.$isSegmenting.set(true); - this.manager.stateApi.$segmentingAdapter.set(this.parent); - for (const point of this.points) { - point.konva.circle.destroy(); - } - this.points = []; + // Update the konva group's position to match the parent entity const pixelRect = this.parent.transformer.$pixelRect.get(); const position = addCoords(this.parent.state.position, pixelRect); this.konva.group.setAttrs(position); + + // Add the module's Konva group to the parent adapter's layer so it is rendered this.parent.konva.layer.add(this.konva.group); + + // Enable listening on the parent adapter's layer so the module can receive pointer events this.parent.konva.layer.listening(true); - this.setSegmentingEventListeners(); + // Set up the segmenting event listeners (e.g. window pointerup) + this.addStageEventListeners(); + + // Set the global segmenting adapter to this module + this.manager.stateApi.$segmentingAdapter.set(this.parent); + + // Sync the cursor style to crosshair + this.syncCursorStyle(); }; + /** + * Processes the SAM points to segment the entity, updating the module's state and rendering the mask. 
+ */ process = async () => { this.log.trace({ points: this.getSAMPoints() }, 'Segmenting'); const rect = this.parent.transformer.getRelativeRect(); @@ -293,25 +451,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { const controller = new AbortController(); this.abortController = controller; - const g = new Graph(getPrefixedId('canvas_segment_anything')); - const segmentAnything = g.addNode({ - id: getPrefixedId('segment_anything_object_identifier'), - type: 'segment_anything_object_identifier', - model: 'segment-anything-huge', - image: { image_name: rasterizeResult.value.image_name }, - object_identifiers: [{ points: this.getSAMPoints() }], - }); - const applyMask = g.addNode({ - id: getPrefixedId('apply_tensor_mask_to_image'), - type: 'apply_tensor_mask_to_image', - image: { image_name: rasterizeResult.value.image_name }, - }); - g.addEdge(segmentAnything, 'mask', applyMask, 'mask'); + const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value); const segmentResult = await withResultAsync(() => this.manager.stateApi.runGraphAndReturnImageOutput({ - graph: g, - outputNodeId: applyMask.id, + graph, + outputNodeId, prepend: true, signal: controller.signal, }) @@ -345,13 +490,20 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { this.abortController = null; }; + /** + * Applies the segmented image to the entity. 
+ */ apply = () => { + if (!this.$hasProcessed.get()) { + this.log.error('Cannot apply unprocessed points'); + return; + } const imageState = this.imageState; if (!imageState) { this.log.error('No image state to apply'); return; } - this.log.trace('Applying segment anything'); + this.log.trace('Applying'); this.parent.bufferRenderer.commitBuffer(); const rect = this.parent.transformer.getRelativeRect(); this.manager.stateApi.rasterizeEntity({ @@ -363,58 +515,118 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { }, replaceObjects: true, }); - this.imageState = null; - for (const point of this.points) { - point.konva.circle.destroy(); - } - this.points = []; - if (this.maskedImage) { - this.maskedImage.destroy(); - } - this.konva.compositingRect.visible(false); - this.konva.maskGroup.clearCache(); - this.$hasProcessed.set(false); - this.manager.stateApi.$segmentingAdapter.set(null); - this.konva.group.remove(); - this.parent.konva.layer.listening(false); - this.removeSegmentingEventListeners(); - this.$isSegmenting.set(false); + this.resetEphemeralState(); + this.teardown(); }; + /** + * Resets the module (e.g. remove all points and the mask image). + * + * Does not cancel or otherwise complete the segmenting process. + */ reset = () => { - this.log.trace('Resetting segment anything'); + this.log.trace('Resetting'); + this.resetEphemeralState(); + }; - for (const point of this.points) { - point.konva.circle.destroy(); - } - this.points = []; - if (this.maskedImage) { - this.maskedImage.destroy(); - } - this.konva.compositingRect.visible(false); - this.konva.maskGroup.clearCache(); + /** + * Cancels the segmenting process. + */ + cancel = () => { + this.log.trace('Canceling'); + this.resetEphemeralState(); + this.teardown(); + }; + /** + * Performs teardown of the module. This shared logic is used for canceling and applying - when the segmenting is + * complete and the module is deactivated. 
+ * + * This method: + * - Removes the module's main Konva node from the parent adapter's layer + * - Removes segmenting event listeners (e.g. window pointerup) + * - Resets the segmenting state + * - Resets the global segmenting adapter + */ + teardown = () => { + this.konva.group.remove(); + this.removeStageEventListeners(); + this.$isSegmenting.set(false); + this.manager.stateApi.$segmentingAdapter.set(null); + }; + + /** + * Resets the module's ephemeral state. This shared logic is used for resetting, canceling, and applying. + * + * This method: + * - Aborts any processing + * - Destroys ephemeral Konva nodes + * - Resets internal module state + * - Resets non-ephemeral Konva nodes + * - Clears the parent module's buffer + */ + resetEphemeralState = () => { + // First we need to bail out of any processing this.abortController?.abort(); this.abortController = null; - this.parent.bufferRenderer.clearBuffer(); - this.parent.transformer.updatePosition(); - this.parent.renderer.syncKonvaCache(true); + + // Destroy ephemeral konva nodes + for (const point of this.points) { + point.konva.circle.destroy(); + } + if (this.maskedImage) { + this.maskedImage.destroy(); + } + + // Empty internal module state + this.points = []; this.imageState = null; + this.$pointType.set(1); this.$hasProcessed.set(false); - }; - - cancel = () => { - this.log.trace('Stopping segment anything'); - this.reset(); this.$isProcessing.set(false); - this.$hasProcessed.set(false); - this.manager.stateApi.$segmentingAdapter.set(null); - this.konva.group.remove(); - this.parent.konva.layer.listening(false); - this.removeSegmentingEventListeners(); - this.$isSegmenting.set(false); + + // Reset non-ephemeral konva nodes + this.konva.compositingRect.visible(false); + this.konva.maskGroup.clearCache(); + + // The parent module's buffer should be reset & forcibly sync the cache + this.parent.bufferRenderer.clearBuffer(); + this.parent.renderer.syncKonvaCache(true); }; + /** + * Builds a graph for 
segmenting an image with the given image DTO. + */ + buildGraph = ({ image_name }: ImageDTO): { graph: Graph; outputNodeId: string } => { + const graph = new Graph(getPrefixedId('canvas_segment_anything')); + + // TODO(psyche): When SAM2 is available in transformers, use it here + // See: https://github.com/huggingface/transformers/pull/32317 + const segmentAnything = graph.addNode({ + id: getPrefixedId('segment_anything_object_identifier'), + type: 'segment_anything_object_identifier', + model: 'segment-anything-huge', + image: { image_name }, + object_identifiers: [{ points: this.getSAMPoints() }], + }); + + // Apply the mask to the image, outputting an image w/ alpha transparency + const applyMask = graph.addNode({ + id: getPrefixedId('apply_tensor_mask_to_image'), + type: 'apply_tensor_mask_to_image', + image: { image_name }, + }); + graph.addEdge(segmentAnything, 'mask', applyMask, 'mask'); + + return { + graph, + outputNodeId: applyMask.id, + }; + }; + + /** + * Gets the color of a SAM point based on its label. 
+ */ getSAMPointColor(label: SAMPointLabel): RgbaColor { if (label === 0) { return this.config.SAM_POINT_NEUTRAL_COLOR; @@ -451,7 +663,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase { destroy = () => { this.log.debug('Destroying module'); this.subscriptions.forEach((unsubscribe) => unsubscribe()); - this.removeSegmentingEventListeners(); + this.removeStageEventListeners(); this.konva.group.destroy(); }; } diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 15f4e63563..2a06ac3241 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -108,12 +108,18 @@ export type SAMPointLabel = z.infer; export const zSAMPointLabelString = z.enum(['background', 'neutral', 'foreground']); export type SAMPointLabelString = z.infer<typeof zSAMPointLabelString>; +/** + * A mapping of SAM point labels (as numbers) to their string representations. + */ export const SAM_POINT_LABEL_NUMBER_TO_STRING: Record<SAMPointLabel, SAMPointLabelString> = { '-1': 'background', 0: 'neutral', 1: 'foreground', }; +/** + * A mapping of SAM point labels (as strings) to their numeric representations. + */ export const SAM_POINT_LABEL_STRING_TO_NUMBER: Record<SAMPointLabelString, SAMPointLabel> = { background: -1, neutral: 0,