mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-14 11:14:57 -05:00
feat(ui): masking UX (wip - issue w/ positioning)
This commit is contained in:
@@ -3,14 +3,14 @@ import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import { zSAMPointLabel } from 'features/controlLayers/store/types';
|
||||
import { SAM_POINT_LABEL_STRING_TO_NUMBER, zSAMPointLabelString } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const SegmentAnythingPointType = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const pointType = useStore(adapter.segmentAnything.$pointType);
|
||||
const pointType = useStore(adapter.segmentAnything.$pointTypeEnglish);
|
||||
|
||||
const options = useMemo(() => {
|
||||
return [
|
||||
@@ -28,7 +28,9 @@ export const SegmentAnythingPointType = memo(
|
||||
return;
|
||||
}
|
||||
|
||||
adapter.segmentAnything.$pointType.set(zSAMPointLabel.parse(v.value));
|
||||
const labelAsString = zSAMPointLabelString.parse(v.value);
|
||||
const labelAsNumber = SAM_POINT_LABEL_STRING_TO_NUMBER[labelAsString];
|
||||
adapter.segmentAnything.$pointType.set(labelAsNumber);
|
||||
},
|
||||
[adapter.segmentAnything.$pointType]
|
||||
);
|
||||
|
||||
@@ -6,7 +6,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
|
||||
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { floorCoord, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
|
||||
import type {
|
||||
CanvasImageState,
|
||||
Coordinate,
|
||||
@@ -14,14 +14,15 @@ import type {
|
||||
SAMPoint,
|
||||
SAMPointLabel,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import Konva from 'konva';
|
||||
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Atom } from 'nanostores';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
type CanvasSegmentAnythingModuleConfig = {
|
||||
SAM_POINT_RADIUS: number;
|
||||
@@ -34,12 +35,12 @@ type CanvasSegmentAnythingModuleConfig = {
|
||||
};
|
||||
|
||||
const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
|
||||
SAM_POINT_RADIUS: 5,
|
||||
SAM_POINT_RADIUS: 8,
|
||||
SAM_POINT_BORDER_WIDTH: 2,
|
||||
SAM_POINT_BORDER_COLOR: { r: 0, g: 0, b: 0, a: 1 },
|
||||
SAM_POINT_FOREGROUND_COLOR: { r: 0, g: 200, b: 0, a: 0.7 },
|
||||
SAM_POINT_BACKGROUND_COLOR: { r: 200, g: 0, b: 0, a: 0.7 },
|
||||
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 0, b: 200, a: 0.7 },
|
||||
SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // green-ish
|
||||
SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
|
||||
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan-ish
|
||||
PROCESS_DEBOUNCE_MS: 300,
|
||||
};
|
||||
|
||||
@@ -72,7 +73,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
$hasProcessed = atom<boolean>(false);
|
||||
$isProcessing = atom<boolean>(false);
|
||||
|
||||
$pointType = atom<SAMPointLabel>('foreground');
|
||||
$pointType = atom<SAMPointLabel>(1);
|
||||
$pointTypeEnglish = computed<(typeof SAM_POINT_LABEL_NUMBER_TO_STRING)[SAMPointLabel], Atom<SAMPointLabel>>(
|
||||
this.$pointType,
|
||||
(pointType) => SAM_POINT_LABEL_NUMBER_TO_STRING[pointType]
|
||||
);
|
||||
$isDraggingPoint = atom<boolean>(false);
|
||||
|
||||
imageState: CanvasImageState | null = null;
|
||||
@@ -116,20 +121,33 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.konva.group.add(this.konva.maskGroup);
|
||||
this.konva.group.add(this.konva.pointGroup);
|
||||
this.konva.maskGroup.add(this.konva.compositingRect);
|
||||
this.subscriptions.add(
|
||||
this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => {
|
||||
if (stageAttrs.scale !== oldStageAttrs.scale) {
|
||||
this.syncPointScales();
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
syncCursorStyle = () => {
|
||||
this.manager.stage.setCursor('crosshair');
|
||||
};
|
||||
|
||||
createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
|
||||
const id = getPrefixedId('sam_point');
|
||||
const circle = new Konva.Circle({
|
||||
name: this.KONVA_CIRCLE_NAME,
|
||||
x: Math.round(coord.x),
|
||||
y: Math.round(coord.y),
|
||||
radius: this.config.SAM_POINT_RADIUS,
|
||||
radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS),
|
||||
fill: rgbaColorToString(this.getSAMPointColor(label)),
|
||||
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
|
||||
strokeWidth: this.config.SAM_POINT_BORDER_WIDTH,
|
||||
strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH),
|
||||
draggable: true,
|
||||
perfectDrawEnabled: false,
|
||||
perfectDrawEnabled: true,
|
||||
opacity: 0.6,
|
||||
dragDistance: 3,
|
||||
});
|
||||
|
||||
circle.on('pointerup', (e) => {
|
||||
@@ -147,35 +165,60 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
|
||||
circle.on('dragend', () => {
|
||||
this.$isDraggingPoint.set(false);
|
||||
this.log.trace(
|
||||
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
|
||||
'SAM point moved'
|
||||
);
|
||||
});
|
||||
|
||||
circle.dragBoundFunc(({ x, y }) => ({
|
||||
x: Math.round(x),
|
||||
y: Math.round(y),
|
||||
}));
|
||||
circle.dragBoundFunc((pos) => floorCoord(pos));
|
||||
|
||||
this.konva.pointGroup.add(circle);
|
||||
const state: SAMPointState = {
|
||||
|
||||
this.log.trace(
|
||||
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
|
||||
'Created SAM point'
|
||||
);
|
||||
|
||||
return {
|
||||
id,
|
||||
label,
|
||||
konva: { circle },
|
||||
};
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
syncPointScales = () => {
|
||||
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
|
||||
const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH);
|
||||
for (const {
|
||||
konva: { circle },
|
||||
} of this.points) {
|
||||
circle.radius(radius);
|
||||
circle.strokeWidth(borderWidth);
|
||||
}
|
||||
};
|
||||
|
||||
getSAMPoints = (): SAMPoint[] => {
|
||||
return this.points.map(({ konva: { circle }, label }) => ({
|
||||
x: circle.x(),
|
||||
y: circle.y(),
|
||||
label,
|
||||
}));
|
||||
const points: SAMPoint[] = [];
|
||||
|
||||
for (const { konva, label } of this.points) {
|
||||
points.push({
|
||||
x: Math.round(konva.circle.x()),
|
||||
y: Math.round(konva.circle.y()),
|
||||
label,
|
||||
});
|
||||
}
|
||||
|
||||
return points;
|
||||
};
|
||||
|
||||
onPointerUp = (e: KonvaEventObject<PointerEvent>) => {
|
||||
if (e.evt.button !== 0) {
|
||||
return;
|
||||
}
|
||||
if (this.manager.stage.getIsDragging()) {
|
||||
return;
|
||||
}
|
||||
if (this.$isDraggingPoint.get()) {
|
||||
return;
|
||||
}
|
||||
@@ -184,7 +227,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
|
||||
this.points.push(this.createPoint(cursorPos.relative, this.$pointType.get()));
|
||||
const pixelRect = this.parent.transformer.$pixelRect.get();
|
||||
|
||||
const normalizedPoint = offsetCoord(cursorPos.relative, { x: pixelRect.x, y: pixelRect.y });
|
||||
const samPoint = this.createPoint(normalizedPoint, this.$pointType.get());
|
||||
this.points.push(samPoint);
|
||||
};
|
||||
|
||||
setSegmentingEventListeners = () => {
|
||||
@@ -202,6 +249,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
this.log.trace('Starting segment anything');
|
||||
this.$pointType.set(1);
|
||||
this.$isSegmenting.set(true);
|
||||
this.manager.stateApi.$segmentingAdapter.set(this.parent);
|
||||
for (const point of this.points) {
|
||||
@@ -209,6 +257,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
}
|
||||
this.points = [];
|
||||
this.parent.konva.layer.add(this.konva.group);
|
||||
console.log({
|
||||
position: this.parent.state.position,
|
||||
pixelRect: this.parent.transformer.$pixelRect.get(),
|
||||
nodeRect: this.parent.transformer.$nodeRect.get(),
|
||||
getRelativeRect: this.parent.transformer.getRelativeRect(),
|
||||
});
|
||||
const pixelRect = this.parent.transformer.$pixelRect.get();
|
||||
this.konva.group.setAttrs({ x: pixelRect.x, y: pixelRect.y });
|
||||
this.parent.konva.layer.listening(true);
|
||||
|
||||
this.setSegmentingEventListeners();
|
||||
@@ -239,15 +295,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
type: 'segment_anything_object_identifier',
|
||||
model: 'segment-anything-huge',
|
||||
image: { image_name: rasterizeResult.value.image_name },
|
||||
object_identifiers: [
|
||||
{
|
||||
points: this.getSAMPoints().map(({ x, y, label }): S['SAMPoint'] => ({
|
||||
x,
|
||||
y,
|
||||
label: label === 'foreground' ? 1 : -1,
|
||||
})),
|
||||
},
|
||||
],
|
||||
object_identifiers: [{ points: this.getSAMPoints() }],
|
||||
});
|
||||
const applyMask = g.addNode({
|
||||
id: getPrefixedId('apply_tensor_mask_to_image'),
|
||||
@@ -321,7 +369,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
}
|
||||
this.konva.compositingRect.visible(false);
|
||||
this.konva.maskGroup.clearCache();
|
||||
this.$pointType.set('foreground');
|
||||
this.$pointType.set(1);
|
||||
this.$isSegmenting.set(false);
|
||||
this.$hasProcessed.set(false);
|
||||
this.manager.stateApi.$segmentingAdapter.set(null);
|
||||
@@ -364,11 +412,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
};
|
||||
|
||||
getSAMPointColor(label: SAMPointLabel): RgbaColor {
|
||||
if (label === 'neutral') {
|
||||
if (label === 0) {
|
||||
return this.config.SAM_POINT_NEUTRAL_COLOR;
|
||||
} else if (label === 'foreground') {
|
||||
} else if (label === 1) {
|
||||
return this.config.SAM_POINT_FOREGROUND_COLOR;
|
||||
} else {
|
||||
// label === -1
|
||||
return this.config.SAM_POINT_BACKGROUND_COLOR;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -311,7 +311,7 @@ export class CanvasStageModule extends CanvasModuleBase {
|
||||
this.setIsDraggable(true);
|
||||
|
||||
// Then start dragging the stage if it's not already being dragged
|
||||
if (!this.konva.stage.isDragging()) {
|
||||
if (!this.getIsDragging()) {
|
||||
this.konva.stage.startDrag();
|
||||
}
|
||||
|
||||
@@ -328,7 +328,7 @@ export class CanvasStageModule extends CanvasModuleBase {
|
||||
this.setIsDraggable(this.manager.tool.$tool.get() === 'view');
|
||||
|
||||
// Stop dragging the stage if it's being dragged
|
||||
if (this.konva.stage.isDragging()) {
|
||||
if (this.getIsDragging()) {
|
||||
this.konva.stage.stopDrag();
|
||||
}
|
||||
|
||||
@@ -404,6 +404,10 @@ export class CanvasStageModule extends CanvasModuleBase {
|
||||
this.konva.stage.draggable(isDraggable);
|
||||
};
|
||||
|
||||
getIsDragging = () => {
|
||||
return this.konva.stage.isDragging();
|
||||
};
|
||||
|
||||
addLayer = (layer: Konva.Layer) => {
|
||||
this.konva.stage.add(layer);
|
||||
};
|
||||
|
||||
@@ -203,12 +203,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
|
||||
* Renders the bbox. The bbox is only visible when the tool is set to 'bbox'.
|
||||
*/
|
||||
render = () => {
|
||||
this.log.trace('Rendering');
|
||||
|
||||
const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
|
||||
const tool = this.manager.tool.$tool.get();
|
||||
|
||||
this.konva.group.visible(true);
|
||||
const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
|
||||
|
||||
// We need to reach up to the preview layer to enable/disable listening so that the bbox can be interacted with.
|
||||
// If the mangaer is busy, we disable listening so the bbox cannot be interacted with.
|
||||
|
||||
@@ -158,11 +158,12 @@ export class CanvasToolModule extends CanvasModuleBase {
|
||||
syncCursorStyle = () => {
|
||||
const stage = this.manager.stage;
|
||||
const tool = this.$tool.get();
|
||||
const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get();
|
||||
|
||||
if (this.manager.stage.konva.stage.isDragging() || tool === 'view') {
|
||||
if (this.manager.stage.getIsDragging() || tool === 'view') {
|
||||
this.tools.view.syncCursorStyle();
|
||||
} else if (this.manager.stateApi.$isSegmenting.get()) {
|
||||
stage.setCursor('default');
|
||||
} else if (segmentingAdapter) {
|
||||
segmentingAdapter.segmentAnything.syncCursorStyle();
|
||||
} else if (this.manager.stateApi.$isFiltering.get()) {
|
||||
stage.setCursor('not-allowed');
|
||||
} else if (this.manager.stagingArea.$isStaging.get()) {
|
||||
@@ -284,7 +285,7 @@ export class CanvasToolModule extends CanvasModuleBase {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (this.manager.stage.konva.stage.isDragging()) {
|
||||
if (this.manager.stage.getIsDragging()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,6 @@ export class CanvasViewToolModule extends CanvasModuleBase {
|
||||
}
|
||||
|
||||
syncCursorStyle = () => {
|
||||
this.manager.stage.setCursor(this.manager.stage.konva.stage.isDragging() ? 'grabbing' : 'grab');
|
||||
this.manager.stage.setCursor(this.manager.stage.getIsDragging() ? 'grabbing' : 'grab');
|
||||
};
|
||||
}
|
||||
|
||||
@@ -96,9 +96,30 @@ const zCoordinateWithPressure = z.object({
|
||||
});
|
||||
export type CoordinateWithPressure = z.infer<typeof zCoordinateWithPressure>;
|
||||
|
||||
export const zSAMPointLabel = z.enum(['foreground', 'background', 'neutral']);
|
||||
const SAM_POINT_LABELS = {
|
||||
background: -1,
|
||||
neutral: 0,
|
||||
foreground: 1,
|
||||
} as const;
|
||||
|
||||
export const zSAMPointLabel = z.nativeEnum(SAM_POINT_LABELS);
|
||||
export type SAMPointLabel = z.infer<typeof zSAMPointLabel>;
|
||||
|
||||
export const zSAMPointLabelString = z.enum(['background', 'neutral', 'foreground']);
|
||||
export type SAMPointLabelString = z.infer<typeof zSAMPointLabelString>;
|
||||
|
||||
export const SAM_POINT_LABEL_NUMBER_TO_STRING: Record<SAMPointLabel, SAMPointLabelString> = {
|
||||
'-1': 'background',
|
||||
0: 'neutral',
|
||||
1: 'foreground',
|
||||
};
|
||||
|
||||
export const SAM_POINT_LABEL_STRING_TO_NUMBER: Record<SAMPointLabelString, SAMPointLabel> = {
|
||||
background: -1,
|
||||
neutral: 0,
|
||||
foreground: 1,
|
||||
};
|
||||
|
||||
export const zSAMPoint = z.object({
|
||||
x: z.number().int().gte(0),
|
||||
y: z.number().int().gte(0),
|
||||
|
||||
Reference in New Issue
Block a user