feat(ui): combine points and bbox in visual mode for SAM

Revised the Select Object feature to support two input modes:
- Visual mode: Combined points and bounding box input for paired SAM inputs
- Prompt mode: Text-based object selection (unchanged)

Key changes:
- Replaced three input types (points, prompt, bbox) with two (visual, prompt)
- Visual mode supports both point and bbox inputs simultaneously
- Click to add include points, Shift+click for exclude points
- Click and drag more than 5 px to draw a bounding box; a shorter movement is treated as a click and adds a point
- Fixed bbox visibility issues when adding points
- Fixed coordinate system issues for proper bbox positioning
- Added a bbox `dragend` handler and a `syncBboxVisibility` step that runs whenever the input data changes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
psychedelicious
2025-09-10 17:41:47 +10:00
parent e1e964bf0e
commit a0232b0e63
2 changed files with 236 additions and 194 deletions

View File

@@ -63,18 +63,14 @@ const SelectObjectContent = memo(
adapter.segmentAnything.saveAs('control_layer');
}, [adapter.segmentAnything]);
const setInputToPoints = useCallback(() => {
adapter.segmentAnything.setInputType('points');
const setInputToVisual = useCallback(() => {
adapter.segmentAnything.setInputType('visual');
}, [adapter.segmentAnything]);
const setInputToPrompt = useCallback(() => {
adapter.segmentAnything.setInputType('prompt');
}, [adapter.segmentAnything]);
const setInputToBoundingBox = useCallback(() => {
adapter.segmentAnything.setInputType('bounding_box');
}, [adapter.segmentAnything]);
useRegisteredHotkeys({
id: 'applySegmentAnything',
category: 'canvas',
@@ -123,11 +119,10 @@ const SelectObjectContent = memo(
<Flex w="full" justifyContent="space-between" py={2}>
<ButtonGroup>
<Button onClick={setInputToPoints}>Points</Button>
<Button onClick={setInputToVisual}>Visual</Button>
<Button onClick={setInputToPrompt}>Prompt</Button>
<Button onClick={setInputToBoundingBox}>Bounding Box</Button>
</ButtonGroup>
{inputData.type === 'points' && <SelectObjectPointType adapter={adapter} />}
{inputData.type === 'visual' && <SelectObjectPointType adapter={adapter} />}
<SelectObjectInvert adapter={adapter} />
</Flex>
@@ -232,20 +227,21 @@ const SelectObjectHelpTooltipContent = memo(() => {
<Text>
<Trans i18nKey="controlLayers.selectObject.help2" components={{ Bold: <Bold /> }} />
</Text>
<Text>
<Trans i18nKey="controlLayers.selectObject.help3" />
</Text>
<Text fontWeight="semibold">Visual Mode:</Text>
<UnorderedList>
<ListItem>{t('controlLayers.selectObject.clickToAdd')}</ListItem>
<ListItem>{t('controlLayers.selectObject.dragToMove')}</ListItem>
<ListItem>{t('controlLayers.selectObject.clickToRemove')}</ListItem>
</UnorderedList>
<Text fontWeight="semibold">Bounding Box Mode:</Text>
<UnorderedList>
<ListItem>Click and drag to draw a bounding box around an object</ListItem>
<ListItem>Click to add include points (green)</ListItem>
<ListItem>Shift + Click to add exclude points (red)</ListItem>
<ListItem>Click and drag to draw a bounding box</ListItem>
<ListItem>Click on points to remove them</ListItem>
<ListItem>Drag points to reposition them</ListItem>
<ListItem>Resize the box using the corner handles</ListItem>
<ListItem>Drag the box to reposition it</ListItem>
</UnorderedList>
<Text fontWeight="semibold">Prompt Mode:</Text>
<UnorderedList>
<ListItem>Type a text description of the object to select</ListItem>
<ListItem>The AI will find and segment matching objects</ListItem>
</UnorderedList>
</Flex>
);
});

View File

@@ -104,18 +104,14 @@ type SAMPointState = {
};
};
type PointsInputData = {
type: 'points';
points: SAMPointState[];
};
type PromptInputData = {
type: 'prompt';
prompt: string;
};
type BoundingBoxInputData = {
type: 'bounding_box';
type VisualInputData = {
type: 'visual';
points: SAMPointState[];
bbox: {
x: number;
y: number;
@@ -124,20 +120,19 @@ type BoundingBoxInputData = {
} | null;
};
const hasInputData = (data: PointsInputData | PromptInputData | BoundingBoxInputData): boolean => {
if (data.type === 'points') {
return data.points.length > 0;
} else if (data.type === 'prompt') {
const hasInputData = (data: PromptInputData | VisualInputData): boolean => {
if (data.type === 'prompt') {
return data.prompt.trim() !== '';
} else {
return data.bbox !== null;
// Visual mode has input if there are points OR a bbox
return data.points.length > 0 || data.bbox !== null;
}
};
/**
* Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers.
*/
const getSAMPoints = (data: PointsInputData): SAMPointWithId[] => {
const getSAMPoints = (data: VisualInputData): SAMPointWithId[] => {
const points: SAMPointWithId[] = [];
for (const { id, coord, label } of data.points) {
@@ -152,13 +147,11 @@ const getSAMPoints = (data: PointsInputData): SAMPointWithId[] => {
return points;
};
const getHashableInputData = (data: PointsInputData | PromptInputData | BoundingBoxInputData) => {
if (data.type === 'points') {
return { type: 'points', points: getSAMPoints(data) } as const;
} else if (data.type === 'prompt') {
const getHashableInputData = (data: PromptInputData | VisualInputData) => {
if (data.type === 'prompt') {
return { type: 'prompt', prompt: data.prompt } as const;
} else {
return { type: 'bounding_box', bbox: data.bbox } as const;
return { type: 'visual', points: getSAMPoints(data), bbox: data.bbox } as const;
}
};
@@ -228,7 +221,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
*/
$hasImageState = computed(this.$imageState, (imageState) => imageState !== null);
$inputData = atom<PointsInputData | PromptInputData | BoundingBoxInputData>({ type: 'points', points: [] });
$inputData = atom<PromptInputData | VisualInputData>({ type: 'visual', points: [], bbox: null });
/**
* Whether the module has points. This is a computed value based on $points.
@@ -376,7 +369,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// Add event handlers for bbox transformer
this.konva.bboxTransformer.on('transformend', () => {
const data = this.$inputData.get();
if (data.type !== 'bounding_box') {
if (data.type !== 'visual') {
return;
}
@@ -406,6 +399,29 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
});
});
// Handle bbox dragging: when a drag of the bbox rect ends, copy the rect's current geometry back into the
// input data so the stored bbox matches what is drawn.
// NOTE(review): another `dragend` listener on `this.konva.bboxRect` is registered shortly after this one
// (under the comment "Add event handler for bbox rect drag"). Its body is not fully visible from here —
// confirm the two listeners do not duplicate each other's work.
this.konva.bboxRect.on('dragend', () => {
const data = this.$inputData.get();
// Only the 'visual' input data carries a bbox; in any other mode there is nothing to update.
if (data.type !== 'visual') {
return;
}
// Read position directly from the rect's attributes.
const x = this.konva.bboxRect.x();
const y = this.konva.bboxRect.y();
// Multiply the rect's width/height by its scale factors so the stored size accounts for any scale on the node.
const width = this.konva.bboxRect.width() * this.konva.bboxRect.scaleX();
const height = this.konva.bboxRect.height() * this.konva.bboxRect.scaleY();
// Replace only the bbox; the spread keeps the rest of the input data (e.g. points) as-is.
this.$inputData.set({
...data,
bbox: {
x: x,
y: y,
width: width,
height: height,
},
});
});
// Stop event propagation when interacting with the bbox
this.konva.bboxRect.on('mousedown touchstart', (e) => {
e.cancelBubble = true;
@@ -418,7 +434,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// Add event handler for bbox rect drag
this.konva.bboxRect.on('dragend', () => {
const data = this.$inputData.get();
if (data.type !== 'bounding_box') {
if (data.type !== 'visual') {
return;
}
@@ -493,16 +509,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
circle.destroy();
const data = this.$inputData.get();
if (data.type !== 'points') {
if (data.type !== 'visual') {
return;
}
const newPoints = data.points.filter((point) => point.id !== id);
if (newPoints.length === 0) {
this.resetEphemeralState();
} else {
this.$inputData.set({ ...data, points: newPoints });
}
this.$inputData.set({ ...data, points: newPoints });
});
circle.on('dragstart', () => {
@@ -516,7 +528,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.$isDraggingPoint.set(false);
const data = this.$inputData.get();
if (data.type !== 'points') {
if (data.type !== 'visual') {
return;
}
@@ -549,7 +561,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
*/
syncPointScales = () => {
const data = this.$inputData.get();
if (data.type !== 'points') {
if (data.type !== 'visual') {
return;
}
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
@@ -561,11 +573,38 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
};
/**
* Handles the pointerdown event on the stage. This is used to start drawing a bounding box.
* Synchronizes the bbox visibility based on the current input data.
*/
syncBboxVisibility = () => {
  // Brings the bbox rect and its transformer in line with the current input data.
  // Does nothing unless the input data is of type 'visual'.
  const inputData = this.$inputData.get();
  if (inputData.type !== 'visual') {
    return;
  }

  const rect = this.konva.bboxRect;
  const transformer = this.konva.bboxTransformer;
  const box = inputData.bbox;

  // No bbox stored: hide both nodes and stop.
  if (!box) {
    rect.visible(false);
    transformer.visible(false);
    return;
  }

  // A bbox is stored: move/resize the rect to match it, then show the rect and the transformer.
  const { x, y, width, height } = box;
  rect.setAttrs({ x, y, width, height, visible: true });
  transformer.visible(true);
};
/**
* Handles the pointerdown event on the stage. This is used to start drawing a bounding box in visual mode.
* We'll start tracking the position but won't decide if it's a bbox or point until pointerup.
*/
onStagePointerDown = (e: KonvaEventObject<PointerEvent>) => {
const data = this.$inputData.get();
if (data.type !== 'bounding_box') {
if (data.type !== 'visual') {
return;
}
@@ -602,19 +641,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// Normalize the cursor position to the parent entity's position
const normalizedPoint = offsetCoord(cursorPos.relative, parentPosition);
// Start drawing the bounding box
// Start potential bbox drawing (we'll decide in pointerup if it's actually a bbox or a point)
this.$isBboxDrawing.set(true);
this.$bboxStartCoord.set(normalizedPoint);
// Reset the bbox rect position and size
this.konva.bboxRect.setAttrs({
x: normalizedPoint.x,
y: normalizedPoint.y,
width: 0,
height: 0,
visible: true,
});
this.konva.bboxTransformer.visible(false);
// Prepare for potential new bbox but don't hide existing one yet
// We'll only update visibility during drag if it's actually a new bbox
};
/**
@@ -622,7 +654,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
*/
onStagePointerMove = (e: KonvaEventObject<PointerEvent>) => {
const data = this.$inputData.get();
if (data.type !== 'bounding_box') {
if (data.type !== 'visual') {
return;
}
@@ -654,13 +686,20 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
const width = Math.abs(currentPoint.x - startCoord.x);
const height = Math.abs(currentPoint.y - startCoord.y);
// Update the bbox rect
this.konva.bboxRect.setAttrs({
x,
y,
width,
height,
});
// Only show the bbox and hide transformer if we've dragged more than a threshold (5 pixels)
if (width > 5 || height > 5) {
// Now we know it's a drag for a new bbox, hide the transformer
this.konva.bboxTransformer.visible(false);
// Update and show the new bbox rect
this.konva.bboxRect.setAttrs({
x,
y,
width,
height,
visible: true,
});
}
};
/**
@@ -669,109 +708,105 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
onStagePointerUp = (e: KonvaEventObject<PointerEvent>) => {
const data = this.$inputData.get();
// Handle bounding box mode
if (data.type === 'bounding_box') {
if (!this.$isBboxDrawing.get()) {
return;
}
// Handle visual mode
if (data.type === 'visual') {
// Only handle left-clicks
if (e.evt.button !== 0) {
return;
}
// Stop drawing
this.$isBboxDrawing.set(false);
this.$bboxStartCoord.set(null);
// Check if we started a potential bbox draw
if (this.$isBboxDrawing.get()) {
const startCoord = this.$bboxStartCoord.get();
// Check if we actually dragged by calculating from start position
const cursorPos = this.manager.tool.$cursorPos.get();
if (!cursorPos || !startCoord) {
// Stop tracking even if we don't have valid coords
this.$isBboxDrawing.set(false);
this.$bboxStartCoord.set(null);
return;
}
// Stop tracking (after we've used the values)
this.$isBboxDrawing.set(false);
this.$bboxStartCoord.set(null);
const pixelRect = this.parent.transformer.$pixelRect.get();
const parentPosition = addCoords(this.parent.state.position, pixelRect);
const currentPoint = offsetCoord(cursorPos.relative, parentPosition);
const dragWidth = Math.abs(currentPoint.x - startCoord.x);
const dragHeight = Math.abs(currentPoint.y - startCoord.y);
// Get the final bbox dimensions from the rect's attributes (not getClientRect which gives screen coords)
const x = this.konva.bboxRect.x();
const y = this.konva.bboxRect.y();
const width = this.konva.bboxRect.width() * this.konva.bboxRect.scaleX();
const height = this.konva.bboxRect.height() * this.konva.bboxRect.scaleY();
// Check if we actually dragged (moved more than threshold)
if (dragWidth > 5 || dragHeight > 5) {
// Get the final bbox dimensions from the rect's attributes
const x = this.konva.bboxRect.x();
const y = this.konva.bboxRect.y();
const width = this.konva.bboxRect.width() * this.konva.bboxRect.scaleX();
const height = this.konva.bboxRect.height() * this.konva.bboxRect.scaleY();
// It was a drag - save the bbox
this.$inputData.set({
...data,
bbox: {
x: x,
y: y,
width: width,
height: height,
},
});
// Only save the bbox if it has a reasonable size (not just a click)
if (width > 5 && height > 5) {
this.$inputData.set({
...data,
bbox: {
x: x,
y: y,
width: width,
height: height,
},
});
// Show the transformer for resizing
this.konva.bboxTransformer.visible(true);
} else {
// It was just a click, not a drag - add a point instead
// Make sure existing bbox stays visible
this.syncBboxVisibility();
// Ignore if the stage is dragging/panning
if (this.manager.stage.getIsDragging()) {
return;
}
// Show the transformer for resizing
this.konva.bboxTransformer.visible(true);
} else {
// Too small, hide the bbox
this.konva.bboxRect.visible(false);
this.konva.bboxTransformer.visible(false);
this.$inputData.set({
...data,
bbox: null,
});
// Ignore if a point is being dragged
if (this.$isDraggingPoint.get()) {
return;
}
// Ignore if we are already processing
if (this.$isProcessing.get()) {
return;
}
if (!startCoord) {
return;
}
// Create a SAM point at the click position
let pointType: -1 | 0 | 1;
// If shift key is held, create an exclude point
if (e.evt.shiftKey) {
pointType = -1;
} else {
// Default to include point
pointType = 1;
}
const point = this.createPoint(startCoord, pointType);
const newPoints = [...data.points, point];
this.$inputData.set({ ...data, points: newPoints });
// Ensure bbox remains visible if it exists
this.syncBboxVisibility();
}
return;
}
return;
}
// Handle points mode
if (data.type !== 'points') {
return;
}
// Only handle left-clicks
if (e.evt.button !== 0) {
return;
}
// Ignore if the stage is dragging/panning
if (this.manager.stage.getIsDragging()) {
return;
}
// Ignore if a point is being dragged
if (this.$isDraggingPoint.get()) {
return;
}
// Ignore if we are already processing
if (this.$isProcessing.get()) {
return;
}
// Ignore if the cursor is not within the stage (should never happen)
const cursorPos = this.manager.tool.$cursorPos.get();
if (!cursorPos) {
return;
}
// We need to offset the cursor position by the parent entity's position + pixel rect to get the correct position
const pixelRect = this.parent.transformer.$pixelRect.get();
const parentPosition = addCoords(this.parent.state.position, pixelRect);
// Normalize the cursor position to the parent entity's position
const normalizedPoint = offsetCoord(cursorPos.relative, parentPosition);
// Create a SAM point at the normalized position
let pointType: -1 | 0 | 1;
// If the shift key is held, invert the point type
if (e.evt.shiftKey) {
if (this.$pointType.get() === 1) {
pointType = -1;
} else if (this.$pointType.get() === -1) {
pointType = 1;
} else {
pointType = 0;
}
} else {
pointType = this.$pointType.get();
}
const point = this.createPoint(normalizedPoint, pointType);
const newPoints = [...data.points, point];
this.$inputData.set({ ...data, points: newPoints });
// Handle prompt mode - nothing to do on pointer up
};
/**
@@ -809,9 +844,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
})
);
// When the points change, process them if autoProcess is enabled
// When the input data changes, sync bbox visibility and process if autoProcess is enabled
this.subscriptions.add(
this.$inputData.listen((inputData) => {
// Always sync bbox visibility when input data changes
this.syncBboxVisibility();
if (!hasInputData(inputData)) {
return;
}
@@ -1133,19 +1171,17 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.resetEphemeralState();
};
setInputType = (type: PointsInputData['type'] | PromptInputData['type'] | BoundingBoxInputData['type']) => {
setInputType = (type: PromptInputData['type'] | VisualInputData['type']) => {
const data = this.$inputData.get();
if (data.type === type) {
return;
}
this.reset();
if (type === 'points') {
this.$inputData.set({ type: 'points', points: [] });
} else if (type === 'prompt') {
if (type === 'prompt') {
this.$inputData.set({ type: 'prompt', prompt: '' });
} else {
this.$inputData.set({ type: 'bounding_box', bbox: null });
// Hide bbox nodes when switching to bbox mode (they'll be shown when drawing)
this.$inputData.set({ type: 'visual', points: [], bbox: null });
// Hide bbox nodes when switching to visual mode (they'll be shown when drawing)
this.konva.bboxRect.visible(false);
this.konva.bboxTransformer.visible(false);
}
@@ -1197,11 +1233,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// Destroy ephemeral konva nodes
const data = this.$inputData.get();
if (data.type === 'points') {
if (data.type === 'visual') {
// Destroy all points
for (const point of data.points) {
point.konva.circle.destroy();
}
} else if (data.type === 'bounding_box') {
// Hide bounding box nodes
this.konva.bboxRect.visible(false);
this.konva.bboxTransformer.visible(false);
@@ -1220,8 +1256,8 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.konva.maskTween = null;
}
// Empty internal module state
this.$inputData.set({ type: 'points', points: [] });
// Empty internal module state - default to visual mode
this.$inputData.set({ type: 'visual', points: [], bbox: null });
this.$imageState.set(null);
this.$pointType.set(1);
this.$invert.set(false);
@@ -1238,7 +1274,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
*/
static buildGraph = (
{ image_name }: ImageDTO,
inputData: PointsInputData | PromptInputData | BoundingBoxInputData,
inputData: PromptInputData | VisualInputData,
invert: boolean,
model: SAMModel
): { graph: Graph; outputNodeId: string } => {
@@ -1250,25 +1286,33 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
image: { image_name },
});
// For visual mode, we may have points, bbox, or both
let pointLists = undefined;
let boundingBoxes = undefined;
if (inputData.type === 'visual') {
// If we have points, add them
if (inputData.points.length > 0) {
pointLists = [{ points: getSAMPoints(inputData).map(({ x, y, label }) => ({ x, y, label })) }];
}
// If we have a bbox, add it
if (inputData.bbox) {
boundingBoxes = [{
x_min: Math.round(inputData.bbox.x),
y_min: Math.round(inputData.bbox.y),
x_max: Math.round(inputData.bbox.x + inputData.bbox.width),
y_max: Math.round(inputData.bbox.y + inputData.bbox.height),
}];
}
}
const segmentAnything = graph.addNode({
id: getPrefixedId('segment_anything'),
type: 'segment_anything',
model: model === 'SAM1' ? 'segment-anything-huge' : 'segment-anything-2-large',
point_lists:
inputData.type === 'points'
? [{ points: getSAMPoints(inputData).map(({ x, y, label }) => ({ x, y, label })) }]
: undefined,
bounding_boxes:
inputData.type === 'bounding_box' && inputData.bbox
? [
{
x_min: Math.round(inputData.bbox.x),
y_min: Math.round(inputData.bbox.y),
x_max: Math.round(inputData.bbox.x + inputData.bbox.width),
y_max: Math.round(inputData.bbox.y + inputData.bbox.height),
},
]
: undefined,
point_lists: pointLists,
bounding_boxes: boundingBoxes,
mask_filter: 'largest',
apply_polygon_refinement: false,
});
@@ -1318,16 +1362,18 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
repr = () => {
const data = this.$inputData.get();
let inputData: any;
if (data.type === 'points') {
inputData = data.points.map(({ id, konva, label }) => ({
id,
label,
circle: getKonvaNodeDebugAttrs(konva.circle),
}));
} else if (data.type === 'prompt') {
if (data.type === 'prompt') {
inputData = { type: 'prompt', prompt: data.prompt };
} else {
inputData = { type: 'bounding_box', bbox: data.bbox || null };
inputData = {
type: 'visual',
points: data.points.map(({ id, konva, label }) => ({
id,
label,
circle: getKonvaNodeDebugAttrs(konva.circle),
})),
bbox: data.bbox || null,
};
}
return {