feat(ui): masking UX (wip - issue w/ positioning)

This commit is contained in:
psychedelicious
2024-10-22 19:09:50 +10:00
parent f666bac77f
commit 606c4ae88c
7 changed files with 125 additions and 51 deletions

View File

@@ -3,14 +3,14 @@ import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer'; import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer'; import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { zSAMPointLabel } from 'features/controlLayers/store/types'; import { SAM_POINT_LABEL_STRING_TO_NUMBER, zSAMPointLabelString } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
export const SegmentAnythingPointType = memo( export const SegmentAnythingPointType = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => { ({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation(); const { t } = useTranslation();
const pointType = useStore(adapter.segmentAnything.$pointType); const pointType = useStore(adapter.segmentAnything.$pointTypeEnglish);
const options = useMemo(() => { const options = useMemo(() => {
return [ return [
@@ -28,7 +28,9 @@ export const SegmentAnythingPointType = memo(
return; return;
} }
adapter.segmentAnything.$pointType.set(zSAMPointLabel.parse(v.value)); const labelAsString = zSAMPointLabelString.parse(v.value);
const labelAsNumber = SAM_POINT_LABEL_STRING_TO_NUMBER[labelAsString];
adapter.segmentAnything.$pointType.set(labelAsNumber);
}, },
[adapter.segmentAnything.$pointType] [adapter.segmentAnything.$pointType]
); );

View File

@@ -6,7 +6,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util'; import { floorCoord, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import type { import type {
CanvasImageState, CanvasImageState,
Coordinate, Coordinate,
@@ -14,14 +14,15 @@ import type {
SAMPoint, SAMPoint,
SAMPointLabel, SAMPointLabel,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util'; import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import Konva from 'konva'; import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node'; import type { KonvaEventObject } from 'konva/lib/Node';
import { atom } from 'nanostores'; import type { Atom } from 'nanostores';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr'; import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error'; import { serializeError } from 'serialize-error';
import type { S } from 'services/api/types';
type CanvasSegmentAnythingModuleConfig = { type CanvasSegmentAnythingModuleConfig = {
SAM_POINT_RADIUS: number; SAM_POINT_RADIUS: number;
@@ -34,12 +35,12 @@ type CanvasSegmentAnythingModuleConfig = {
}; };
const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = { const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
SAM_POINT_RADIUS: 5, SAM_POINT_RADIUS: 8,
SAM_POINT_BORDER_WIDTH: 2, SAM_POINT_BORDER_WIDTH: 2,
SAM_POINT_BORDER_COLOR: { r: 0, g: 0, b: 0, a: 1 }, SAM_POINT_BORDER_COLOR: { r: 0, g: 0, b: 0, a: 1 },
SAM_POINT_FOREGROUND_COLOR: { r: 0, g: 200, b: 0, a: 0.7 }, SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // green-ish
SAM_POINT_BACKGROUND_COLOR: { r: 200, g: 0, b: 0, a: 0.7 }, SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 0, b: 200, a: 0.7 }, SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan-ish
PROCESS_DEBOUNCE_MS: 300, PROCESS_DEBOUNCE_MS: 300,
}; };
@@ -72,7 +73,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
$hasProcessed = atom<boolean>(false); $hasProcessed = atom<boolean>(false);
$isProcessing = atom<boolean>(false); $isProcessing = atom<boolean>(false);
$pointType = atom<SAMPointLabel>('foreground'); $pointType = atom<SAMPointLabel>(1);
$pointTypeEnglish = computed<(typeof SAM_POINT_LABEL_NUMBER_TO_STRING)[SAMPointLabel], Atom<SAMPointLabel>>(
this.$pointType,
(pointType) => SAM_POINT_LABEL_NUMBER_TO_STRING[pointType]
);
$isDraggingPoint = atom<boolean>(false); $isDraggingPoint = atom<boolean>(false);
imageState: CanvasImageState | null = null; imageState: CanvasImageState | null = null;
@@ -116,20 +121,33 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.konva.group.add(this.konva.maskGroup); this.konva.group.add(this.konva.maskGroup);
this.konva.group.add(this.konva.pointGroup); this.konva.group.add(this.konva.pointGroup);
this.konva.maskGroup.add(this.konva.compositingRect); this.konva.maskGroup.add(this.konva.compositingRect);
this.subscriptions.add(
this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => {
if (stageAttrs.scale !== oldStageAttrs.scale) {
this.syncPointScales();
}
})
);
} }
syncCursorStyle = () => {
this.manager.stage.setCursor('crosshair');
};
createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState { createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
const id = getPrefixedId('sam_point'); const id = getPrefixedId('sam_point');
const circle = new Konva.Circle({ const circle = new Konva.Circle({
name: this.KONVA_CIRCLE_NAME, name: this.KONVA_CIRCLE_NAME,
x: Math.round(coord.x), x: Math.round(coord.x),
y: Math.round(coord.y), y: Math.round(coord.y),
radius: this.config.SAM_POINT_RADIUS, radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS),
fill: rgbaColorToString(this.getSAMPointColor(label)), fill: rgbaColorToString(this.getSAMPointColor(label)),
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR), stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
strokeWidth: this.config.SAM_POINT_BORDER_WIDTH, strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH),
draggable: true, draggable: true,
perfectDrawEnabled: false, perfectDrawEnabled: true,
opacity: 0.6,
dragDistance: 3,
}); });
circle.on('pointerup', (e) => { circle.on('pointerup', (e) => {
@@ -147,35 +165,60 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
circle.on('dragend', () => { circle.on('dragend', () => {
this.$isDraggingPoint.set(false); this.$isDraggingPoint.set(false);
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'SAM point moved'
);
}); });
circle.dragBoundFunc(({ x, y }) => ({ circle.dragBoundFunc((pos) => floorCoord(pos));
x: Math.round(x),
y: Math.round(y),
}));
this.konva.pointGroup.add(circle); this.konva.pointGroup.add(circle);
const state: SAMPointState = {
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'Created SAM point'
);
return {
id, id,
label, label,
konva: { circle }, konva: { circle },
}; };
return state;
} }
syncPointScales = () => {
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH);
for (const {
konva: { circle },
} of this.points) {
circle.radius(radius);
circle.strokeWidth(borderWidth);
}
};
getSAMPoints = (): SAMPoint[] => { getSAMPoints = (): SAMPoint[] => {
return this.points.map(({ konva: { circle }, label }) => ({ const points: SAMPoint[] = [];
x: circle.x(),
y: circle.y(), for (const { konva, label } of this.points) {
label, points.push({
})); x: Math.round(konva.circle.x()),
y: Math.round(konva.circle.y()),
label,
});
}
return points;
}; };
onPointerUp = (e: KonvaEventObject<PointerEvent>) => { onPointerUp = (e: KonvaEventObject<PointerEvent>) => {
if (e.evt.button !== 0) { if (e.evt.button !== 0) {
return; return;
} }
if (this.manager.stage.getIsDragging()) {
return;
}
if (this.$isDraggingPoint.get()) { if (this.$isDraggingPoint.get()) {
return; return;
} }
@@ -184,7 +227,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return; return;
} }
this.points.push(this.createPoint(cursorPos.relative, this.$pointType.get())); const pixelRect = this.parent.transformer.$pixelRect.get();
const normalizedPoint = offsetCoord(cursorPos.relative, { x: pixelRect.x, y: pixelRect.y });
const samPoint = this.createPoint(normalizedPoint, this.$pointType.get());
this.points.push(samPoint);
}; };
setSegmentingEventListeners = () => { setSegmentingEventListeners = () => {
@@ -202,6 +249,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return; return;
} }
this.log.trace('Starting segment anything'); this.log.trace('Starting segment anything');
this.$pointType.set(1);
this.$isSegmenting.set(true); this.$isSegmenting.set(true);
this.manager.stateApi.$segmentingAdapter.set(this.parent); this.manager.stateApi.$segmentingAdapter.set(this.parent);
for (const point of this.points) { for (const point of this.points) {
@@ -209,6 +257,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
} }
this.points = []; this.points = [];
this.parent.konva.layer.add(this.konva.group); this.parent.konva.layer.add(this.konva.group);
console.log({
position: this.parent.state.position,
pixelRect: this.parent.transformer.$pixelRect.get(),
nodeRect: this.parent.transformer.$nodeRect.get(),
getRelativeRect: this.parent.transformer.getRelativeRect(),
});
const pixelRect = this.parent.transformer.$pixelRect.get();
this.konva.group.setAttrs({ x: pixelRect.x, y: pixelRect.y });
this.parent.konva.layer.listening(true); this.parent.konva.layer.listening(true);
this.setSegmentingEventListeners(); this.setSegmentingEventListeners();
@@ -239,15 +295,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
type: 'segment_anything_object_identifier', type: 'segment_anything_object_identifier',
model: 'segment-anything-huge', model: 'segment-anything-huge',
image: { image_name: rasterizeResult.value.image_name }, image: { image_name: rasterizeResult.value.image_name },
object_identifiers: [ object_identifiers: [{ points: this.getSAMPoints() }],
{
points: this.getSAMPoints().map(({ x, y, label }): S['SAMPoint'] => ({
x,
y,
label: label === 'foreground' ? 1 : -1,
})),
},
],
}); });
const applyMask = g.addNode({ const applyMask = g.addNode({
id: getPrefixedId('apply_tensor_mask_to_image'), id: getPrefixedId('apply_tensor_mask_to_image'),
@@ -321,7 +369,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
} }
this.konva.compositingRect.visible(false); this.konva.compositingRect.visible(false);
this.konva.maskGroup.clearCache(); this.konva.maskGroup.clearCache();
this.$pointType.set('foreground'); this.$pointType.set(1);
this.$isSegmenting.set(false); this.$isSegmenting.set(false);
this.$hasProcessed.set(false); this.$hasProcessed.set(false);
this.manager.stateApi.$segmentingAdapter.set(null); this.manager.stateApi.$segmentingAdapter.set(null);
@@ -364,11 +412,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
}; };
getSAMPointColor(label: SAMPointLabel): RgbaColor { getSAMPointColor(label: SAMPointLabel): RgbaColor {
if (label === 'neutral') { if (label === 0) {
return this.config.SAM_POINT_NEUTRAL_COLOR; return this.config.SAM_POINT_NEUTRAL_COLOR;
} else if (label === 'foreground') { } else if (label === 1) {
return this.config.SAM_POINT_FOREGROUND_COLOR; return this.config.SAM_POINT_FOREGROUND_COLOR;
} else { } else {
// label === -1
return this.config.SAM_POINT_BACKGROUND_COLOR; return this.config.SAM_POINT_BACKGROUND_COLOR;
} }
} }

View File

@@ -311,7 +311,7 @@ export class CanvasStageModule extends CanvasModuleBase {
this.setIsDraggable(true); this.setIsDraggable(true);
// Then start dragging the stage if it's not already being dragged // Then start dragging the stage if it's not already being dragged
if (!this.konva.stage.isDragging()) { if (!this.getIsDragging()) {
this.konva.stage.startDrag(); this.konva.stage.startDrag();
} }
@@ -328,7 +328,7 @@ export class CanvasStageModule extends CanvasModuleBase {
this.setIsDraggable(this.manager.tool.$tool.get() === 'view'); this.setIsDraggable(this.manager.tool.$tool.get() === 'view');
// Stop dragging the stage if it's being dragged // Stop dragging the stage if it's being dragged
if (this.konva.stage.isDragging()) { if (this.getIsDragging()) {
this.konva.stage.stopDrag(); this.konva.stage.stopDrag();
} }
@@ -404,6 +404,10 @@ export class CanvasStageModule extends CanvasModuleBase {
this.konva.stage.draggable(isDraggable); this.konva.stage.draggable(isDraggable);
}; };
getIsDragging = () => {
return this.konva.stage.isDragging();
};
addLayer = (layer: Konva.Layer) => { addLayer = (layer: Konva.Layer) => {
this.konva.stage.add(layer); this.konva.stage.add(layer);
}; };

View File

@@ -203,12 +203,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
* Renders the bbox. The bbox is only visible when the tool is set to 'bbox'. * Renders the bbox. The bbox is only visible when the tool is set to 'bbox'.
*/ */
render = () => { render = () => {
this.log.trace('Rendering');
const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
const tool = this.manager.tool.$tool.get(); const tool = this.manager.tool.$tool.get();
this.konva.group.visible(true); const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
// We need to reach up to the preview layer to enable/disable listening so that the bbox can be interacted with. // We need to reach up to the preview layer to enable/disable listening so that the bbox can be interacted with.
// If the mangaer is busy, we disable listening so the bbox cannot be interacted with. // If the mangaer is busy, we disable listening so the bbox cannot be interacted with.

View File

@@ -158,11 +158,12 @@ export class CanvasToolModule extends CanvasModuleBase {
syncCursorStyle = () => { syncCursorStyle = () => {
const stage = this.manager.stage; const stage = this.manager.stage;
const tool = this.$tool.get(); const tool = this.$tool.get();
const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get();
if (this.manager.stage.konva.stage.isDragging() || tool === 'view') { if (this.manager.stage.getIsDragging() || tool === 'view') {
this.tools.view.syncCursorStyle(); this.tools.view.syncCursorStyle();
} else if (this.manager.stateApi.$isSegmenting.get()) { } else if (segmentingAdapter) {
stage.setCursor('default'); segmentingAdapter.segmentAnything.syncCursorStyle();
} else if (this.manager.stateApi.$isFiltering.get()) { } else if (this.manager.stateApi.$isFiltering.get()) {
stage.setCursor('not-allowed'); stage.setCursor('not-allowed');
} else if (this.manager.stagingArea.$isStaging.get()) { } else if (this.manager.stagingArea.$isStaging.get()) {
@@ -284,7 +285,7 @@ export class CanvasToolModule extends CanvasModuleBase {
return false; return false;
} }
if (this.manager.stage.konva.stage.isDragging()) { if (this.manager.stage.getIsDragging()) {
return false; return false;
} }

View File

@@ -24,6 +24,6 @@ export class CanvasViewToolModule extends CanvasModuleBase {
} }
syncCursorStyle = () => { syncCursorStyle = () => {
this.manager.stage.setCursor(this.manager.stage.konva.stage.isDragging() ? 'grabbing' : 'grab'); this.manager.stage.setCursor(this.manager.stage.getIsDragging() ? 'grabbing' : 'grab');
}; };
} }

View File

@@ -96,9 +96,30 @@ const zCoordinateWithPressure = z.object({
}); });
export type CoordinateWithPressure = z.infer<typeof zCoordinateWithPressure>; export type CoordinateWithPressure = z.infer<typeof zCoordinateWithPressure>;
export const zSAMPointLabel = z.enum(['foreground', 'background', 'neutral']); const SAM_POINT_LABELS = {
background: -1,
neutral: 0,
foreground: 1,
} as const;
export const zSAMPointLabel = z.nativeEnum(SAM_POINT_LABELS);
export type SAMPointLabel = z.infer<typeof zSAMPointLabel>; export type SAMPointLabel = z.infer<typeof zSAMPointLabel>;
export const zSAMPointLabelString = z.enum(['background', 'neutral', 'foreground']);
export type SAMPointLabelString = z.infer<typeof zSAMPointLabelString>;
export const SAM_POINT_LABEL_NUMBER_TO_STRING: Record<SAMPointLabel, SAMPointLabelString> = {
'-1': 'background',
0: 'neutral',
1: 'foreground',
};
export const SAM_POINT_LABEL_STRING_TO_NUMBER: Record<SAMPointLabelString, SAMPointLabel> = {
background: -1,
neutral: 0,
foreground: 1,
};
export const zSAMPoint = z.object({ export const zSAMPoint = z.object({
x: z.number().int().gte(0), x: z.number().int().gte(0),
y: z.number().int().gte(0), y: z.number().int().gte(0),