feat(ui): masking UX (wip - issue w/ positioning)

This commit is contained in:
psychedelicious
2024-10-22 19:09:50 +10:00
parent f666bac77f
commit 606c4ae88c
7 changed files with 125 additions and 51 deletions

View File

@@ -3,14 +3,14 @@ import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { zSAMPointLabel } from 'features/controlLayers/store/types';
import { SAM_POINT_LABEL_STRING_TO_NUMBER, zSAMPointLabelString } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const SegmentAnythingPointType = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const pointType = useStore(adapter.segmentAnything.$pointType);
const pointType = useStore(adapter.segmentAnything.$pointTypeEnglish);
const options = useMemo(() => {
return [
@@ -28,7 +28,9 @@ export const SegmentAnythingPointType = memo(
return;
}
adapter.segmentAnything.$pointType.set(zSAMPointLabel.parse(v.value));
const labelAsString = zSAMPointLabelString.parse(v.value);
const labelAsNumber = SAM_POINT_LABEL_STRING_TO_NUMBER[labelAsString];
adapter.segmentAnything.$pointType.set(labelAsNumber);
},
[adapter.segmentAnything.$pointType]
);

View File

@@ -6,7 +6,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { floorCoord, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import type {
CanvasImageState,
Coordinate,
@@ -14,14 +14,15 @@ import type {
SAMPoint,
SAMPointLabel,
} from 'features/controlLayers/store/types';
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { atom } from 'nanostores';
import type { Atom } from 'nanostores';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import type { S } from 'services/api/types';
type CanvasSegmentAnythingModuleConfig = {
SAM_POINT_RADIUS: number;
@@ -34,12 +35,12 @@ type CanvasSegmentAnythingModuleConfig = {
};
const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
SAM_POINT_RADIUS: 5,
SAM_POINT_RADIUS: 8,
SAM_POINT_BORDER_WIDTH: 2,
SAM_POINT_BORDER_COLOR: { r: 0, g: 0, b: 0, a: 1 },
SAM_POINT_FOREGROUND_COLOR: { r: 0, g: 200, b: 0, a: 0.7 },
SAM_POINT_BACKGROUND_COLOR: { r: 200, g: 0, b: 0, a: 0.7 },
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 0, b: 200, a: 0.7 },
SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // green-ish
SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan-ish
PROCESS_DEBOUNCE_MS: 300,
};
@@ -72,7 +73,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
$hasProcessed = atom<boolean>(false);
$isProcessing = atom<boolean>(false);
$pointType = atom<SAMPointLabel>('foreground');
$pointType = atom<SAMPointLabel>(1);
$pointTypeEnglish = computed<(typeof SAM_POINT_LABEL_NUMBER_TO_STRING)[SAMPointLabel], Atom<SAMPointLabel>>(
this.$pointType,
(pointType) => SAM_POINT_LABEL_NUMBER_TO_STRING[pointType]
);
$isDraggingPoint = atom<boolean>(false);
imageState: CanvasImageState | null = null;
@@ -116,20 +121,33 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.konva.group.add(this.konva.maskGroup);
this.konva.group.add(this.konva.pointGroup);
this.konva.maskGroup.add(this.konva.compositingRect);
this.subscriptions.add(
this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => {
if (stageAttrs.scale !== oldStageAttrs.scale) {
this.syncPointScales();
}
})
);
}
syncCursorStyle = () => {
this.manager.stage.setCursor('crosshair');
};
createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
const id = getPrefixedId('sam_point');
const circle = new Konva.Circle({
name: this.KONVA_CIRCLE_NAME,
x: Math.round(coord.x),
y: Math.round(coord.y),
radius: this.config.SAM_POINT_RADIUS,
radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS),
fill: rgbaColorToString(this.getSAMPointColor(label)),
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
strokeWidth: this.config.SAM_POINT_BORDER_WIDTH,
strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH),
draggable: true,
perfectDrawEnabled: false,
perfectDrawEnabled: true,
opacity: 0.6,
dragDistance: 3,
});
circle.on('pointerup', (e) => {
@@ -147,35 +165,60 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
circle.on('dragend', () => {
this.$isDraggingPoint.set(false);
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'SAM point moved'
);
});
circle.dragBoundFunc(({ x, y }) => ({
x: Math.round(x),
y: Math.round(y),
}));
circle.dragBoundFunc((pos) => floorCoord(pos));
this.konva.pointGroup.add(circle);
const state: SAMPointState = {
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'Created SAM point'
);
return {
id,
label,
konva: { circle },
};
return state;
}
syncPointScales = () => {
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH);
for (const {
konva: { circle },
} of this.points) {
circle.radius(radius);
circle.strokeWidth(borderWidth);
}
};
getSAMPoints = (): SAMPoint[] => {
return this.points.map(({ konva: { circle }, label }) => ({
x: circle.x(),
y: circle.y(),
label,
}));
const points: SAMPoint[] = [];
for (const { konva, label } of this.points) {
points.push({
x: Math.round(konva.circle.x()),
y: Math.round(konva.circle.y()),
label,
});
}
return points;
};
onPointerUp = (e: KonvaEventObject<PointerEvent>) => {
if (e.evt.button !== 0) {
return;
}
if (this.manager.stage.getIsDragging()) {
return;
}
if (this.$isDraggingPoint.get()) {
return;
}
@@ -184,7 +227,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return;
}
this.points.push(this.createPoint(cursorPos.relative, this.$pointType.get()));
const pixelRect = this.parent.transformer.$pixelRect.get();
const normalizedPoint = offsetCoord(cursorPos.relative, { x: pixelRect.x, y: pixelRect.y });
const samPoint = this.createPoint(normalizedPoint, this.$pointType.get());
this.points.push(samPoint);
};
setSegmentingEventListeners = () => {
@@ -202,6 +249,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return;
}
this.log.trace('Starting segment anything');
this.$pointType.set(1);
this.$isSegmenting.set(true);
this.manager.stateApi.$segmentingAdapter.set(this.parent);
for (const point of this.points) {
@@ -209,6 +257,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
}
this.points = [];
this.parent.konva.layer.add(this.konva.group);
console.log({
position: this.parent.state.position,
pixelRect: this.parent.transformer.$pixelRect.get(),
nodeRect: this.parent.transformer.$nodeRect.get(),
getRelativeRect: this.parent.transformer.getRelativeRect(),
});
const pixelRect = this.parent.transformer.$pixelRect.get();
this.konva.group.setAttrs({ x: pixelRect.x, y: pixelRect.y });
this.parent.konva.layer.listening(true);
this.setSegmentingEventListeners();
@@ -239,15 +295,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
type: 'segment_anything_object_identifier',
model: 'segment-anything-huge',
image: { image_name: rasterizeResult.value.image_name },
object_identifiers: [
{
points: this.getSAMPoints().map(({ x, y, label }): S['SAMPoint'] => ({
x,
y,
label: label === 'foreground' ? 1 : -1,
})),
},
],
object_identifiers: [{ points: this.getSAMPoints() }],
});
const applyMask = g.addNode({
id: getPrefixedId('apply_tensor_mask_to_image'),
@@ -321,7 +369,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
}
this.konva.compositingRect.visible(false);
this.konva.maskGroup.clearCache();
this.$pointType.set('foreground');
this.$pointType.set(1);
this.$isSegmenting.set(false);
this.$hasProcessed.set(false);
this.manager.stateApi.$segmentingAdapter.set(null);
@@ -364,11 +412,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
};
getSAMPointColor(label: SAMPointLabel): RgbaColor {
if (label === 'neutral') {
if (label === 0) {
return this.config.SAM_POINT_NEUTRAL_COLOR;
} else if (label === 'foreground') {
} else if (label === 1) {
return this.config.SAM_POINT_FOREGROUND_COLOR;
} else {
// label === -1
return this.config.SAM_POINT_BACKGROUND_COLOR;
}
}

View File

@@ -311,7 +311,7 @@ export class CanvasStageModule extends CanvasModuleBase {
this.setIsDraggable(true);
// Then start dragging the stage if it's not already being dragged
if (!this.konva.stage.isDragging()) {
if (!this.getIsDragging()) {
this.konva.stage.startDrag();
}
@@ -328,7 +328,7 @@ export class CanvasStageModule extends CanvasModuleBase {
this.setIsDraggable(this.manager.tool.$tool.get() === 'view');
// Stop dragging the stage if it's being dragged
if (this.konva.stage.isDragging()) {
if (this.getIsDragging()) {
this.konva.stage.stopDrag();
}
@@ -404,6 +404,10 @@ export class CanvasStageModule extends CanvasModuleBase {
this.konva.stage.draggable(isDraggable);
};
getIsDragging = () => {
return this.konva.stage.isDragging();
};
addLayer = (layer: Konva.Layer) => {
this.konva.stage.add(layer);
};

View File

@@ -203,12 +203,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
* Renders the bbox. The bbox is only visible when the tool is set to 'bbox'.
*/
render = () => {
this.log.trace('Rendering');
const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
const tool = this.manager.tool.$tool.get();
this.konva.group.visible(true);
const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
// We need to reach up to the preview layer to enable/disable listening so that the bbox can be interacted with.
// If the mangaer is busy, we disable listening so the bbox cannot be interacted with.

View File

@@ -158,11 +158,12 @@ export class CanvasToolModule extends CanvasModuleBase {
syncCursorStyle = () => {
const stage = this.manager.stage;
const tool = this.$tool.get();
const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get();
if (this.manager.stage.konva.stage.isDragging() || tool === 'view') {
if (this.manager.stage.getIsDragging() || tool === 'view') {
this.tools.view.syncCursorStyle();
} else if (this.manager.stateApi.$isSegmenting.get()) {
stage.setCursor('default');
} else if (segmentingAdapter) {
segmentingAdapter.segmentAnything.syncCursorStyle();
} else if (this.manager.stateApi.$isFiltering.get()) {
stage.setCursor('not-allowed');
} else if (this.manager.stagingArea.$isStaging.get()) {
@@ -284,7 +285,7 @@ export class CanvasToolModule extends CanvasModuleBase {
return false;
}
if (this.manager.stage.konva.stage.isDragging()) {
if (this.manager.stage.getIsDragging()) {
return false;
}

View File

@@ -24,6 +24,6 @@ export class CanvasViewToolModule extends CanvasModuleBase {
}
syncCursorStyle = () => {
this.manager.stage.setCursor(this.manager.stage.konva.stage.isDragging() ? 'grabbing' : 'grab');
this.manager.stage.setCursor(this.manager.stage.getIsDragging() ? 'grabbing' : 'grab');
};
}

View File

@@ -96,9 +96,30 @@ const zCoordinateWithPressure = z.object({
});
export type CoordinateWithPressure = z.infer<typeof zCoordinateWithPressure>;
export const zSAMPointLabel = z.enum(['foreground', 'background', 'neutral']);
const SAM_POINT_LABELS = {
background: -1,
neutral: 0,
foreground: 1,
} as const;
export const zSAMPointLabel = z.nativeEnum(SAM_POINT_LABELS);
export type SAMPointLabel = z.infer<typeof zSAMPointLabel>;
export const zSAMPointLabelString = z.enum(['background', 'neutral', 'foreground']);
export type SAMPointLabelString = z.infer<typeof zSAMPointLabelString>;
export const SAM_POINT_LABEL_NUMBER_TO_STRING: Record<SAMPointLabel, SAMPointLabelString> = {
'-1': 'background',
0: 'neutral',
1: 'foreground',
};
export const SAM_POINT_LABEL_STRING_TO_NUMBER: Record<SAMPointLabelString, SAMPointLabel> = {
background: -1,
neutral: 0,
foreground: 1,
};
export const zSAMPoint = z.object({
x: z.number().int().gte(0),
y: z.number().int().gte(0),