tidy(ui): clean up and document CanvasSegmentAnythingModule

This commit is contained in:
psychedelicious
2024-10-23 11:52:00 +10:00
parent 38b09d73e4
commit bc85bd4bd4
2 changed files with 319 additions and 101 deletions

View File

@@ -6,19 +6,14 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import {
addCoords,
floorCoord,
getKonvaNodeDebugAttrs,
getPrefixedId,
offsetCoord,
} from 'features/controlLayers/konva/util';
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import type {
CanvasImageState,
Coordinate,
RgbaColor,
SAMPoint,
SAMPointLabel,
SAMPointLabelString,
} from 'features/controlLayers/store/types';
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
@@ -29,14 +24,40 @@ import type { Atom } from 'nanostores';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import type { ImageDTO } from 'services/api/types';
type CanvasSegmentAnythingModuleConfig = {
/**
* The radius of the SAM point Konva circle node.
*/
SAM_POINT_RADIUS: number;
/**
* The border width of the SAM point Konva circle node.
*/
SAM_POINT_BORDER_WIDTH: number;
/**
* The border color of the SAM point Konva circle node.
*/
SAM_POINT_BORDER_COLOR: RgbaColor;
/**
* The color of the SAM point Konva circle node when the label is 1.
*/
SAM_POINT_FOREGROUND_COLOR: RgbaColor;
/**
* The color of the SAM point Konva circle node when the label is -1.
*/
SAM_POINT_BACKGROUND_COLOR: RgbaColor;
/**
* The color of the SAM point Konva circle node when the label is 0.
*/
SAM_POINT_NEUTRAL_COLOR: RgbaColor;
/**
* The color to use for the mask preview overlay.
*/
MASK_COLOR: RgbaColor;
/**
* The debounce time in milliseconds for processing the points.
*/
PROCESS_DEBOUNCE_MS: number;
};
@@ -44,12 +65,21 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
SAM_POINT_RADIUS: 8,
SAM_POINT_BORDER_WIDTH: 2,
SAM_POINT_BORDER_COLOR: { r: 0, g: 0, b: 0, a: 1 },
SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // green-ish
SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // light green
SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan-ish
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan
MASK_COLOR: { r: 0, g: 200, b: 200, a: 0.5 }, // cyan with 50% opacity
PROCESS_DEBOUNCE_MS: 300,
};
/**
* The state of a SAM point.
* @property id - The unique identifier of the point.
* @property label - The label of the point. -1 is background, 0 is neutral, 1 is foreground.
* @property konva - The Konva node state of the point.
* @property konva.circle - The Konva circle node of the point. The x and y coordinates for the point are derived from
* this node.
*/
type SAMPointState = {
id: string;
label: SAMPointLabel;
@@ -75,30 +105,88 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
*/
abortController: AbortController | null = null;
/**
* Whether the module is currently segmenting an entity.
*/
$isSegmenting = atom<boolean>(false);
/**
* Whether the current set of points has been processed.
*/
$hasProcessed = atom<boolean>(false);
/**
* Whether the module is currently processing the points.
*/
$isProcessing = atom<boolean>(false);
/**
* The type of point to create when segmenting. This is a number representation of the SAMPointLabel enum.
*/
$pointType = atom<SAMPointLabel>(1);
$pointTypeEnglish = computed<(typeof SAM_POINT_LABEL_NUMBER_TO_STRING)[SAMPointLabel], Atom<SAMPointLabel>>(
/**
* The type of point to create when segmenting, as a string. This is a computed value based on $pointType.
*/
$pointTypeEnglish = computed<SAMPointLabelString, Atom<SAMPointLabel>>(
this.$pointType,
(pointType) => SAM_POINT_LABEL_NUMBER_TO_STRING[pointType]
);
/**
* Whether a point is currently being dragged. This is used to prevent the point additions and deletions during
* dragging.
*/
$isDraggingPoint = atom<boolean>(false);
/**
* The ephemeral image state of the processed image. Only used while segmenting.
*/
imageState: CanvasImageState | null = null;
/**
* The current input points.
*/
points: SAMPointState[] = [];
/**
* The masked image object, if it exists.
*/
maskedImage: CanvasObjectImage | null = null;
/**
* The Konva nodes for the module.
*/
konva: {
/**
* The main Konva group node for the module.
*/
group: Konva.Group;
/**
* The Konva group node for the SAM points.
*
* This is a child of the main group node, rendered above the mask group.
*/
pointGroup: Konva.Group;
/**
* The Konva group node for the mask image and compositing rect.
*
* This is a child of the main group node, rendered below the point group.
*/
maskGroup: Konva.Group;
/**
* The Konva rect node for compositing the mask image.
*
* It's rendered with a globalCompositeOperation of 'source-atop' to preview the mask as a semi-transparent overlay.
*/
compositingRect: Konva.Rect;
};
KONVA_CIRCLE_NAME = `${this.type}:circle`;
KONVA_GROUP_NAME = `${this.type}:group`;
KONVA_POINT_GROUP_NAME = `${this.type}:point_group`;
KONVA_MASK_GROUP_NAME = `${this.type}:mask_group`;
KONVA_COMPOSITING_RECT_NAME = `${this.type}:compositing_rect`;
constructor(parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer) {
super();
@@ -110,13 +198,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.log.debug('Creating module');
// Create all konva nodes
this.konva = {
group: new Konva.Group({ name: `${this.type}:group` }),
pointGroup: new Konva.Group({ name: `${this.type}:point_group` }),
maskGroup: new Konva.Group({ name: `${this.type}:mask_group` }),
group: new Konva.Group({ name: this.KONVA_GROUP_NAME }),
pointGroup: new Konva.Group({ name: this.KONVA_POINT_GROUP_NAME }),
maskGroup: new Konva.Group({ name: this.KONVA_MASK_GROUP_NAME }),
compositingRect: new Konva.Rect({
name: `${this.type}:compositingRect`,
fill: rgbaColorToString({ r: 0, g: 200, b: 200, a: 0.5 }),
name: this.KONVA_COMPOSITING_RECT_NAME,
fill: rgbaColorToString(this.config.MASK_COLOR),
globalCompositeOperation: 'source-atop',
listening: false,
strokeEnabled: false,
@@ -124,9 +213,16 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
visible: false,
}),
};
// Mask group is below the point group
this.konva.group.add(this.konva.maskGroup);
this.konva.group.add(this.konva.pointGroup);
// Compositing rect is added to the mask group - will also be above the mask image, but that doesn't get created
// until after processing
this.konva.maskGroup.add(this.konva.compositingRect);
// Scale the SAM points when the stage scale changes
this.subscriptions.add(
this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => {
if (stageAttrs.scale !== oldStageAttrs.scale) {
@@ -136,30 +232,43 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
);
}
syncCursorStyle = () => {
/**
* Synchronizes the cursor style to crosshair.
*/
syncCursorStyle = (): void => {
this.manager.stage.setCursor('crosshair');
};
/**
* Creates a SAM point at the given coordinate with the given label. -1 is background, 0 is neutral, 1 is foreground.
* @param coord The coordinate
* @param label The label.
* @returns The SAM point state.
*/
createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
const id = getPrefixedId('sam_point');
const circle = new Konva.Circle({
name: this.KONVA_CIRCLE_NAME,
x: Math.round(coord.x),
y: Math.round(coord.y),
radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS),
radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), // We will scale this as the stage scale changes
fill: rgbaColorToString(this.getSAMPointColor(label)),
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH),
strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH), // We will scale this as the stage scale changes
draggable: true,
perfectDrawEnabled: true,
perfectDrawEnabled: true, // Required for the stroke/fill to draw correctly w/ partial opacity
opacity: 0.6,
dragDistance: 3,
});
// When the point is clicked, remove it
circle.on('pointerup', (e) => {
// Ignore if we are dragging
if (this.$isDraggingPoint.get()) {
return;
}
// This event should not bubble up to the parent, stage or any other nodes
e.cancelBubble = true;
circle.destroy();
this.points = this.points.filter((point) => point.id !== id);
@@ -171,14 +280,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
circle.on('dragend', () => {
this.$isDraggingPoint.set(false);
// Point has changed!
this.$hasProcessed.set(false);
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'SAM point moved'
'Moved SAM point'
);
});
circle.dragBoundFunc((pos) => floorCoord(pos));
this.konva.pointGroup.add(circle);
this.log.trace(
@@ -193,22 +302,29 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
};
}
/**
* Synchronizes the scales of the SAM points to the stage scale.
*
* SAM points are always the same size, regardless of the stage scale.
*/
syncPointScales = () => {
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH);
for (const {
konva: { circle },
} of this.points) {
circle.radius(radius);
circle.strokeWidth(borderWidth);
for (const point of this.points) {
point.konva.circle.radius(radius);
point.konva.circle.strokeWidth(borderWidth);
}
};
/**
* Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers.
*/
getSAMPoints = (): SAMPoint[] => {
const points: SAMPoint[] = [];
for (const { konva, label } of this.points) {
points.push({
// Pull out and round the x and y values from Konva
x: Math.round(konva.circle.x()),
y: Math.round(konva.circle.y()),
label,
@@ -218,16 +334,31 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return points;
};
onPointerUp = (e: KonvaEventObject<PointerEvent>) => {
/**
* Handles the pointerup event on the stage. This is used to add a SAM point to the module.
*/
onStagePointerUp = (e: KonvaEventObject<PointerEvent>) => {
// Only handle left-clicks
if (e.evt.button !== 0) {
return;
}
// Ignore if the stage is dragging/panning
if (this.manager.stage.getIsDragging()) {
return;
}
// Ignore if a point is being dragged
if (this.$isDraggingPoint.get()) {
return;
}
// Ignore if we are already processing
if (this.$isProcessing.get()) {
return;
}
// Ignore if the cursor is not within the stage (should never happen)
const cursorPos = this.manager.tool.$cursorPos.get();
if (!cursorPos) {
return;
@@ -235,21 +366,36 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// We need to offset the cursor position by the parent entity's position + pixel rect to get the correct position
const pixelRect = this.parent.transformer.$pixelRect.get();
const position = addCoords(this.parent.state.position, pixelRect);
const parentPosition = addCoords(this.parent.state.position, pixelRect);
const normalizedPoint = offsetCoord(cursorPos.relative, position);
const samPoint = this.createPoint(normalizedPoint, this.$pointType.get());
this.points.push(samPoint);
// Normalize the cursor position to the parent entity's position
const normalizedPoint = offsetCoord(cursorPos.relative, parentPosition);
// Create a SAM point at the normalized position
const point = this.createPoint(normalizedPoint, this.$pointType.get());
this.points.push(point);
// Mark the module as having _not_ processed the points now that they have changed
this.$hasProcessed.set(false);
};
setSegmentingEventListeners = () => {
this.manager.stage.konva.stage.on('pointerup', this.onPointerUp);
/**
* Adds Konva stage event listeners for segmenting the entity.
*/
addStageEventListeners = () => {
this.manager.stage.konva.stage.on('pointerup', this.onStagePointerUp);
};
removeSegmentingEventListeners = () => {
this.manager.stage.konva.stage.off('pointerup', this.onPointerUp);
/**
* Removes Konva stage event listeners for segmenting the entity.
*/
removeStageEventListeners = () => {
this.manager.stage.konva.stage.off('pointerup', this.onStagePointerUp);
};
/**
* Starts the segmenting process.
*/
start = () => {
const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get();
if (segmentingAdapter) {
@@ -257,23 +403,35 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return;
}
this.log.trace('Starting segment anything');
this.$pointType.set(1);
// Reset the module's state
this.resetEphemeralState();
this.$isSegmenting.set(true);
this.manager.stateApi.$segmentingAdapter.set(this.parent);
for (const point of this.points) {
point.konva.circle.destroy();
}
this.points = [];
// Update the konva group's position to match the parent entity
const pixelRect = this.parent.transformer.$pixelRect.get();
const position = addCoords(this.parent.state.position, pixelRect);
this.konva.group.setAttrs(position);
// Add the module's Konva group to the parent adapter's layer so it is rendered
this.parent.konva.layer.add(this.konva.group);
// Enable listening on the parent adapter's layer so the module can receive pointer events
this.parent.konva.layer.listening(true);
this.setSegmentingEventListeners();
// Set up the segmenting event listeners (e.g. window pointerup)
this.addStageEventListeners();
// Set the global segmenting adapter to this module
this.manager.stateApi.$segmentingAdapter.set(this.parent);
// Sync the cursor style to crosshair
this.syncCursorStyle();
};
/**
* Processes the SAM points to segment the entity, updating the module's state and rendering the mask.
*/
process = async () => {
this.log.trace({ points: this.getSAMPoints() }, 'Segmenting');
const rect = this.parent.transformer.getRelativeRect();
@@ -293,25 +451,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
const controller = new AbortController();
this.abortController = controller;
const g = new Graph(getPrefixedId('canvas_segment_anything'));
const segmentAnything = g.addNode({
id: getPrefixedId('segment_anything_object_identifier'),
type: 'segment_anything_object_identifier',
model: 'segment-anything-huge',
image: { image_name: rasterizeResult.value.image_name },
object_identifiers: [{ points: this.getSAMPoints() }],
});
const applyMask = g.addNode({
id: getPrefixedId('apply_tensor_mask_to_image'),
type: 'apply_tensor_mask_to_image',
image: { image_name: rasterizeResult.value.image_name },
});
g.addEdge(segmentAnything, 'mask', applyMask, 'mask');
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value);
const segmentResult = await withResultAsync(() =>
this.manager.stateApi.runGraphAndReturnImageOutput({
graph: g,
outputNodeId: applyMask.id,
graph,
outputNodeId,
prepend: true,
signal: controller.signal,
})
@@ -345,13 +490,20 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.abortController = null;
};
/**
* Applies the segmented image to the entity.
*/
apply = () => {
if (!this.$hasProcessed.get()) {
this.log.error('Cannot apply unprocessed points');
return;
}
const imageState = this.imageState;
if (!imageState) {
this.log.error('No image state to apply');
return;
}
this.log.trace('Applying segment anything');
this.log.trace('Applying');
this.parent.bufferRenderer.commitBuffer();
const rect = this.parent.transformer.getRelativeRect();
this.manager.stateApi.rasterizeEntity({
@@ -363,58 +515,118 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
},
replaceObjects: true,
});
this.imageState = null;
for (const point of this.points) {
point.konva.circle.destroy();
}
this.points = [];
if (this.maskedImage) {
this.maskedImage.destroy();
}
this.konva.compositingRect.visible(false);
this.konva.maskGroup.clearCache();
this.$hasProcessed.set(false);
this.manager.stateApi.$segmentingAdapter.set(null);
this.konva.group.remove();
this.parent.konva.layer.listening(false);
this.removeSegmentingEventListeners();
this.$isSegmenting.set(false);
this.resetEphemeralState();
this.teardown();
};
/**
* Resets the module (e.g. remove all points and the mask image).
*
* Does not cancel or otherwise complete the segmenting process.
*/
reset = () => {
this.log.trace('Resetting segment anything');
this.log.trace('Resetting');
this.resetEphemeralState();
};
for (const point of this.points) {
point.konva.circle.destroy();
}
this.points = [];
if (this.maskedImage) {
this.maskedImage.destroy();
}
this.konva.compositingRect.visible(false);
this.konva.maskGroup.clearCache();
/**
* Cancels the segmenting process.
*/
cancel = () => {
this.log.trace('Canceling');
this.resetEphemeralState();
this.teardown();
};
/**
* Performs teardown of the module. This shared logic is used for canceling and applying - when the segmenting is
* complete and the module is deactivated.
*
* This method:
* - Removes the module's main Konva node from the parent adapter's layer
* - Removes segmenting event listeners (e.g. window pointerup)
* - Resets the segmenting state
* - Resets the global segmenting adapter
*/
teardown = () => {
this.konva.group.remove();
this.removeStageEventListeners();
this.$isSegmenting.set(false);
this.manager.stateApi.$segmentingAdapter.set(null);
};
/**
* Resets the module's ephemeral state. This shared logic is used for resetting, canceling, and applying.
*
* This method:
* - Aborts any processing
* - Destroys ephemeral Konva nodes
* - Resets internal module state
* - Resets non-ephemeral Konva nodes
* - Clears the parent module's buffer
*/
resetEphemeralState = () => {
// First we need to bail out of any processing
this.abortController?.abort();
this.abortController = null;
this.parent.bufferRenderer.clearBuffer();
this.parent.transformer.updatePosition();
this.parent.renderer.syncKonvaCache(true);
// Destroy ephemeral konva nodes
for (const point of this.points) {
point.konva.circle.destroy();
}
if (this.maskedImage) {
this.maskedImage.destroy();
}
// Empty internal module state
this.points = [];
this.imageState = null;
this.$pointType.set(1);
this.$hasProcessed.set(false);
};
cancel = () => {
this.log.trace('Stopping segment anything');
this.reset();
this.$isProcessing.set(false);
this.$hasProcessed.set(false);
this.manager.stateApi.$segmentingAdapter.set(null);
this.konva.group.remove();
this.parent.konva.layer.listening(false);
this.removeSegmentingEventListeners();
this.$isSegmenting.set(false);
// Reset non-ephemeral konva nodes
this.konva.compositingRect.visible(false);
this.konva.maskGroup.clearCache();
// The parent module's buffer should be reset & forcibly sync the cache
this.parent.bufferRenderer.clearBuffer();
this.parent.renderer.syncKonvaCache(true);
};
/**
* Builds a graph for segmenting an image with the given image DTO.
*/
buildGraph = ({ image_name }: ImageDTO): { graph: Graph; outputNodeId: string } => {
const graph = new Graph(getPrefixedId('canvas_segment_anything'));
// TODO(psyche): When SAM2 is available in transformers, use it here
// See: https://github.com/huggingface/transformers/pull/32317
const segmentAnything = graph.addNode({
id: getPrefixedId('segment_anything_object_identifier'),
type: 'segment_anything_object_identifier',
model: 'segment-anything-huge',
image: { image_name },
object_identifiers: [{ points: this.getSAMPoints() }],
});
// Apply the mask to the image, outputting an image w/ alpha transparency
const applyMask = graph.addNode({
id: getPrefixedId('apply_tensor_mask_to_image'),
type: 'apply_tensor_mask_to_image',
image: { image_name },
});
graph.addEdge(segmentAnything, 'mask', applyMask, 'mask');
return {
graph,
outputNodeId: applyMask.id,
};
};
/**
* Gets the color of a SAM point based on its label.
*/
getSAMPointColor(label: SAMPointLabel): RgbaColor {
if (label === 0) {
return this.config.SAM_POINT_NEUTRAL_COLOR;
@@ -451,7 +663,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
destroy = () => {
this.log.debug('Destroying module');
this.subscriptions.forEach((unsubscribe) => unsubscribe());
this.removeSegmentingEventListeners();
this.removeStageEventListeners();
this.konva.group.destroy();
};
}

View File

@@ -108,12 +108,18 @@ export type SAMPointLabel = z.infer<typeof zSAMPointLabel>;
export const zSAMPointLabelString = z.enum(['background', 'neutral', 'foreground']);
export type SAMPointLabelString = z.infer<typeof zSAMPointLabelString>;
/**
* A mapping of SAM point labels (as numbers) to their string representations.
*/
export const SAM_POINT_LABEL_NUMBER_TO_STRING: Record<SAMPointLabel, SAMPointLabelString> = {
'-1': 'background',
0: 'neutral',
1: 'foreground',
};
/**
* A mapping of SAM point labels (as strings) to their numeric representations.
*/
export const SAM_POINT_LABEL_STRING_TO_NUMBER: Record<SAMPointLabelString, SAMPointLabel> = {
background: -1,
neutral: 0,