mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): more resilient auto-masking processing
- Use a hash of the last processed points instead of a `hasProcessed` flag to determine whether or not we should re-process a given set of points. - Store point coords in state instead of pulling them out of the konva node positions. This makes moving a point a more explicit action in code. - Add a `roundCoord` util to round the x and y values of a coordinate. - Ensure we always re-process when $points changes.
This commit is contained in:
committed by
Kent Keirsey
parent
5764e4f7f2
commit
175a9dc28d
@@ -6,15 +6,21 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
|
||||
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
addCoords,
|
||||
getKonvaNodeDebugAttrs,
|
||||
getPrefixedId,
|
||||
offsetCoord,
|
||||
roundCoord,
|
||||
} from 'features/controlLayers/konva/util';
|
||||
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import type {
|
||||
CanvasImageState,
|
||||
Coordinate,
|
||||
RgbaColor,
|
||||
SAMPoint,
|
||||
SAMPointLabel,
|
||||
SAMPointLabelString,
|
||||
SAMPointWithId,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
@@ -27,6 +33,7 @@ import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import stableHash from 'stable-hash';
|
||||
|
||||
type CanvasSegmentAnythingModuleConfig = {
|
||||
/**
|
||||
@@ -85,6 +92,7 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
|
||||
type SAMPointState = {
|
||||
id: string;
|
||||
label: SAMPointLabel;
|
||||
coord: Coordinate;
|
||||
konva: {
|
||||
circle: Konva.Circle;
|
||||
};
|
||||
@@ -113,9 +121,9 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
$isSegmenting = atom<boolean>(false);
|
||||
|
||||
/**
|
||||
* Whether the current set of points has been processed.
|
||||
* The hash of the last processed points. This is used to prevent re-processing the same points.
|
||||
*/
|
||||
$hasProcessed = atom<boolean>(false);
|
||||
$lastProcessedHash = atom<string>('');
|
||||
|
||||
/**
|
||||
* Whether the module is currently processing the points.
|
||||
@@ -147,7 +155,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
imageState: CanvasImageState | null = null;
|
||||
|
||||
/**
|
||||
* The current input points.
|
||||
* The current input points. A listener is added to this atom to process the points when they change.
|
||||
*/
|
||||
$points = atom<SAMPointState[]>([]);
|
||||
|
||||
@@ -250,10 +258,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
|
||||
const id = getPrefixedId('sam_point');
|
||||
|
||||
const roundedCoord = roundCoord(coord);
|
||||
|
||||
const circle = new Konva.Circle({
|
||||
name: this.KONVA_CIRCLE_NAME,
|
||||
x: Math.round(coord.x),
|
||||
y: Math.round(coord.y),
|
||||
x: roundedCoord.x,
|
||||
y: roundedCoord.y,
|
||||
radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), // We will scale this as the stage scale changes
|
||||
fill: rgbaColorToString(this.getSAMPointColor(label)),
|
||||
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
|
||||
@@ -273,11 +283,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
// This event should not bubble up to the parent, stage or any other nodes
|
||||
e.cancelBubble = true;
|
||||
circle.destroy();
|
||||
this.$points.set(this.$points.get().filter((point) => point.id !== id));
|
||||
if (this.$points.get().length === 0) {
|
||||
|
||||
const newPoints = this.$points.get().filter((point) => point.id !== id);
|
||||
if (newPoints.length === 0) {
|
||||
this.resetEphemeralState();
|
||||
} else {
|
||||
this.$hasProcessed.set(false);
|
||||
this.$points.set(newPoints);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -286,25 +297,28 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
});
|
||||
|
||||
circle.on('dragend', () => {
|
||||
const roundedCoord = roundCoord(circle.position());
|
||||
|
||||
this.log.trace({ ...roundedCoord, label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, 'Moved SAM point');
|
||||
this.$isDraggingPoint.set(false);
|
||||
// Point has changed!
|
||||
this.$hasProcessed.set(false);
|
||||
this.$points.notify();
|
||||
this.log.trace(
|
||||
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
|
||||
'Moved SAM point'
|
||||
);
|
||||
|
||||
const newPoints = this.$points.get().map((point) => {
|
||||
if (point.id === id) {
|
||||
return { ...point, coord: roundedCoord };
|
||||
}
|
||||
return point;
|
||||
});
|
||||
|
||||
this.$points.set(newPoints);
|
||||
});
|
||||
|
||||
this.konva.pointGroup.add(circle);
|
||||
|
||||
this.log.trace(
|
||||
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
|
||||
'Created SAM point'
|
||||
);
|
||||
this.log.trace({ ...roundedCoord, label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, 'Created SAM point');
|
||||
|
||||
return {
|
||||
id,
|
||||
coord: roundedCoord,
|
||||
label,
|
||||
konva: { circle },
|
||||
};
|
||||
@@ -327,14 +341,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
/**
|
||||
* Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers.
|
||||
*/
|
||||
getSAMPoints = (): SAMPoint[] => {
|
||||
const points: SAMPoint[] = [];
|
||||
getSAMPoints = (): SAMPointWithId[] => {
|
||||
const points: SAMPointWithId[] = [];
|
||||
|
||||
for (const { konva, label } of this.$points.get()) {
|
||||
for (const { id, coord, label } of this.$points.get()) {
|
||||
points.push({
|
||||
// Pull out and round the x and y values from Konva
|
||||
x: Math.round(konva.circle.x()),
|
||||
y: Math.round(konva.circle.y()),
|
||||
id,
|
||||
x: coord.x,
|
||||
y: coord.y,
|
||||
label,
|
||||
});
|
||||
}
|
||||
@@ -381,10 +395,8 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
|
||||
// Create a SAM point at the normalized position
|
||||
const point = this.createPoint(normalizedPoint, this.$pointType.get());
|
||||
this.$points.set([...this.$points.get(), point]);
|
||||
|
||||
// Mark the module as having _not_ processed the points now that they have changed
|
||||
this.$hasProcessed.set(false);
|
||||
const newPoints = [...this.$points.get(), point];
|
||||
this.$points.set(newPoints);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -421,6 +433,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
if (points.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.manager.stateApi.getSettings().autoProcess) {
|
||||
this.process();
|
||||
}
|
||||
@@ -433,7 +446,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
if (this.$points.get().length === 0) {
|
||||
return;
|
||||
}
|
||||
if (autoProcess && !this.$hasProcessed.get()) {
|
||||
if (autoProcess) {
|
||||
this.process();
|
||||
}
|
||||
})
|
||||
@@ -500,6 +513,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
|
||||
const hash = stableHash(points);
|
||||
if (hash === this.$lastProcessedHash.get()) {
|
||||
this.log.trace('Already processed points');
|
||||
return;
|
||||
}
|
||||
|
||||
this.$isProcessing.set(true);
|
||||
|
||||
this.log.trace({ points }, 'Segmenting');
|
||||
@@ -521,7 +540,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.abortController = controller;
|
||||
|
||||
// Build the graph for segmenting the image, using the rasterized image DTO
|
||||
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value);
|
||||
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value, points);
|
||||
|
||||
// Run the graph and get the segmented image output
|
||||
const segmentResult = await withResultAsync(() =>
|
||||
@@ -574,12 +593,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
// Cache the group to ensure the mask is rendered correctly w/ opacity
|
||||
this.konva.maskGroup.cache();
|
||||
|
||||
this.$lastProcessedHash.set(hash);
|
||||
|
||||
// We are done processing (still segmenting though!)
|
||||
this.$isProcessing.set(false);
|
||||
|
||||
// The current points have been processed
|
||||
this.$hasProcessed.set(true);
|
||||
|
||||
// Clean up the abort controller as needed
|
||||
if (!this.abortController.signal.aborted) {
|
||||
this.abortController.abort();
|
||||
@@ -596,10 +614,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
* Applies the segmented image to the entity.
|
||||
*/
|
||||
apply = () => {
|
||||
if (!this.$hasProcessed.get()) {
|
||||
this.log.error('Cannot apply unprocessed points');
|
||||
return;
|
||||
}
|
||||
const imageState = this.imageState;
|
||||
if (!imageState) {
|
||||
this.log.error('No image state to apply');
|
||||
@@ -691,7 +705,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.$points.set([]);
|
||||
this.imageState = null;
|
||||
this.$pointType.set(1);
|
||||
this.$hasProcessed.set(false);
|
||||
this.$lastProcessedHash.set('');
|
||||
this.$isProcessing.set(false);
|
||||
|
||||
// Reset non-ephemeral konva nodes
|
||||
@@ -706,7 +720,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
/**
|
||||
* Builds a graph for segmenting an image with the given image DTO.
|
||||
*/
|
||||
buildGraph = ({ image_name }: ImageDTO): { graph: Graph; outputNodeId: string } => {
|
||||
buildGraph = ({ image_name }: ImageDTO, points: SAMPointWithId[]): { graph: Graph; outputNodeId: string } => {
|
||||
const graph = new Graph(getPrefixedId('canvas_segment_anything'));
|
||||
|
||||
// TODO(psyche): When SAM2 is available in transformers, use it here
|
||||
@@ -716,7 +730,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
type: 'segment_anything',
|
||||
model: 'segment-anything-huge',
|
||||
image: { image_name },
|
||||
point_lists: [{ points: this.getSAMPoints() }],
|
||||
point_lists: [{ points: points.map(({ x, y, label }) => ({ x, y, label })) }],
|
||||
mask_filter: 'largest',
|
||||
});
|
||||
|
||||
@@ -763,7 +777,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
maskedImage: this.maskedImage?.repr(),
|
||||
config: deepClone(this.config),
|
||||
$isSegmenting: this.$isSegmenting.get(),
|
||||
$hasProcessed: this.$hasProcessed.get(),
|
||||
$lastProcessedHash: this.$lastProcessedHash.get(),
|
||||
$isProcessing: this.$isProcessing.get(),
|
||||
$pointType: this.$pointType.get(),
|
||||
$pointTypeString: this.$pointTypeString.get(),
|
||||
|
||||
@@ -126,6 +126,13 @@ export const floorCoord = (coord: Coordinate): Coordinate => {
|
||||
};
|
||||
};
|
||||
|
||||
export const roundCoord = (coord: Coordinate): Coordinate => {
|
||||
return {
|
||||
x: Math.round(coord.x),
|
||||
y: Math.round(coord.y),
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Snaps a position to the edge of the given rect if within a threshold of the edge
|
||||
* @param pos The position to snap
|
||||
|
||||
@@ -132,6 +132,7 @@ const zSAMPoint = z.object({
|
||||
label: zSAMPointLabel,
|
||||
});
|
||||
export type SAMPoint = z.infer<typeof zSAMPoint>;
|
||||
export type SAMPointWithId = SAMPoint & { id: string };
|
||||
|
||||
const zRect = z.object({
|
||||
x: z.number(),
|
||||
|
||||
Reference in New Issue
Block a user