feat(ui): more resilient auto-masking processing

- Use a hash of the last processed points instead of a `hasProcessed` flag to determine whether or not we should re-process a given set of points.
- Store point coords in state instead of pulling them out of the konva node positions. This makes moving a point a more explicit action in code.
- Add a `roundCoord` util to round the x and y values of a coordinate.
- Ensure we always re-process when $points changes.
This commit is contained in:
psychedelicious
2024-10-24 07:50:31 +10:00
committed by Kent Keirsey
parent 5764e4f7f2
commit 175a9dc28d
3 changed files with 66 additions and 44 deletions

View File

@@ -6,15 +6,21 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import {
addCoords,
getKonvaNodeDebugAttrs,
getPrefixedId,
offsetCoord,
roundCoord,
} from 'features/controlLayers/konva/util';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import type {
CanvasImageState,
Coordinate,
RgbaColor,
SAMPoint,
SAMPointLabel,
SAMPointLabelString,
SAMPointWithId,
} from 'features/controlLayers/store/types';
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
@@ -27,6 +33,7 @@ import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
type CanvasSegmentAnythingModuleConfig = {
/**
@@ -85,6 +92,7 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
type SAMPointState = {
id: string;
label: SAMPointLabel;
coord: Coordinate;
konva: {
circle: Konva.Circle;
};
@@ -113,9 +121,9 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
$isSegmenting = atom<boolean>(false);
/**
* Whether the current set of points has been processed.
* The hash of the last processed points. This is used to prevent re-processing the same points.
*/
$hasProcessed = atom<boolean>(false);
$lastProcessedHash = atom<string>('');
/**
* Whether the module is currently processing the points.
@@ -147,7 +155,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
imageState: CanvasImageState | null = null;
/**
* The current input points.
* The current input points. A listener is added to this atom to process the points when they change.
*/
$points = atom<SAMPointState[]>([]);
@@ -250,10 +258,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
const id = getPrefixedId('sam_point');
const roundedCoord = roundCoord(coord);
const circle = new Konva.Circle({
name: this.KONVA_CIRCLE_NAME,
x: Math.round(coord.x),
y: Math.round(coord.y),
x: roundedCoord.x,
y: roundedCoord.y,
radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), // We will scale this as the stage scale changes
fill: rgbaColorToString(this.getSAMPointColor(label)),
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
@@ -273,11 +283,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// This event should not bubble up to the parent, stage or any other nodes
e.cancelBubble = true;
circle.destroy();
this.$points.set(this.$points.get().filter((point) => point.id !== id));
if (this.$points.get().length === 0) {
const newPoints = this.$points.get().filter((point) => point.id !== id);
if (newPoints.length === 0) {
this.resetEphemeralState();
} else {
this.$hasProcessed.set(false);
this.$points.set(newPoints);
}
});
@@ -286,25 +297,28 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
});
circle.on('dragend', () => {
const roundedCoord = roundCoord(circle.position());
this.log.trace({ ...roundedCoord, label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, 'Moved SAM point');
this.$isDraggingPoint.set(false);
// Point has changed!
this.$hasProcessed.set(false);
this.$points.notify();
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'Moved SAM point'
);
const newPoints = this.$points.get().map((point) => {
if (point.id === id) {
return { ...point, coord: roundedCoord };
}
return point;
});
this.$points.set(newPoints);
});
this.konva.pointGroup.add(circle);
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'Created SAM point'
);
this.log.trace({ ...roundedCoord, label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, 'Created SAM point');
return {
id,
coord: roundedCoord,
label,
konva: { circle },
};
@@ -327,14 +341,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
/**
* Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers.
*/
getSAMPoints = (): SAMPoint[] => {
const points: SAMPoint[] = [];
getSAMPoints = (): SAMPointWithId[] => {
const points: SAMPointWithId[] = [];
for (const { konva, label } of this.$points.get()) {
for (const { id, coord, label } of this.$points.get()) {
points.push({
// Pull out and round the x and y values from Konva
x: Math.round(konva.circle.x()),
y: Math.round(konva.circle.y()),
id,
x: coord.x,
y: coord.y,
label,
});
}
@@ -381,10 +395,8 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// Create a SAM point at the normalized position
const point = this.createPoint(normalizedPoint, this.$pointType.get());
this.$points.set([...this.$points.get(), point]);
// Mark the module as having _not_ processed the points now that they have changed
this.$hasProcessed.set(false);
const newPoints = [...this.$points.get(), point];
this.$points.set(newPoints);
};
/**
@@ -421,6 +433,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
if (points.length === 0) {
return;
}
if (this.manager.stateApi.getSettings().autoProcess) {
this.process();
}
@@ -433,7 +446,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
if (this.$points.get().length === 0) {
return;
}
if (autoProcess && !this.$hasProcessed.get()) {
if (autoProcess) {
this.process();
}
})
@@ -500,6 +513,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return;
}
const hash = stableHash(points);
if (hash === this.$lastProcessedHash.get()) {
this.log.trace('Already processed points');
return;
}
this.$isProcessing.set(true);
this.log.trace({ points }, 'Segmenting');
@@ -521,7 +540,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.abortController = controller;
// Build the graph for segmenting the image, using the rasterized image DTO
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value);
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value, points);
// Run the graph and get the segmented image output
const segmentResult = await withResultAsync(() =>
@@ -574,12 +593,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// Cache the group to ensure the mask is rendered correctly w/ opacity
this.konva.maskGroup.cache();
this.$lastProcessedHash.set(hash);
// We are done processing (still segmenting though!)
this.$isProcessing.set(false);
// The current points have been processed
this.$hasProcessed.set(true);
// Clean up the abort controller as needed
if (!this.abortController.signal.aborted) {
this.abortController.abort();
@@ -596,10 +614,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
* Applies the segmented image to the entity.
*/
apply = () => {
if (!this.$hasProcessed.get()) {
this.log.error('Cannot apply unprocessed points');
return;
}
const imageState = this.imageState;
if (!imageState) {
this.log.error('No image state to apply');
@@ -691,7 +705,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.$points.set([]);
this.imageState = null;
this.$pointType.set(1);
this.$hasProcessed.set(false);
this.$lastProcessedHash.set('');
this.$isProcessing.set(false);
// Reset non-ephemeral konva nodes
@@ -706,7 +720,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
/**
* Builds a graph for segmenting an image with the given image DTO.
*/
buildGraph = ({ image_name }: ImageDTO): { graph: Graph; outputNodeId: string } => {
buildGraph = ({ image_name }: ImageDTO, points: SAMPointWithId[]): { graph: Graph; outputNodeId: string } => {
const graph = new Graph(getPrefixedId('canvas_segment_anything'));
// TODO(psyche): When SAM2 is available in transformers, use it here
@@ -716,7 +730,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
type: 'segment_anything',
model: 'segment-anything-huge',
image: { image_name },
point_lists: [{ points: this.getSAMPoints() }],
point_lists: [{ points: points.map(({ x, y, label }) => ({ x, y, label })) }],
mask_filter: 'largest',
});
@@ -763,7 +777,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
maskedImage: this.maskedImage?.repr(),
config: deepClone(this.config),
$isSegmenting: this.$isSegmenting.get(),
$hasProcessed: this.$hasProcessed.get(),
$lastProcessedHash: this.$lastProcessedHash.get(),
$isProcessing: this.$isProcessing.get(),
$pointType: this.$pointType.get(),
$pointTypeString: this.$pointTypeString.get(),

View File

@@ -126,6 +126,13 @@ export const floorCoord = (coord: Coordinate): Coordinate => {
};
};
export const roundCoord = (coord: Coordinate): Coordinate => {
return {
x: Math.round(coord.x),
y: Math.round(coord.y),
};
};
/**
* Snaps a position to the edge of the given rect if within a threshold of the edge
* @param pos The position to snap

View File

@@ -132,6 +132,7 @@ const zSAMPoint = z.object({
label: zSAMPointLabel,
});
export type SAMPoint = z.infer<typeof zSAMPoint>;
export type SAMPointWithId = SAMPoint & { id: string };
const zRect = z.object({
x: z.number(),