mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): combine points and bbox in visual mode for SAM
Revised the Select Object feature to support two input modes:

- Visual mode: Combined points and bounding box input for paired SAM inputs
- Prompt mode: Text-based object selection (unchanged)

Key changes:

- Replaced three input types (points, prompt, bbox) with two (visual, prompt)
- Visual mode supports both point and bbox inputs simultaneously
- Click to add include points, Shift+click for exclude points
- Click and drag to draw a bounding box
- Fixed bbox visibility issues when adding points
- Fixed coordinate system issues for proper bbox positioning
- Added proper event handling and interaction controls

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -63,18 +63,14 @@ const SelectObjectContent = memo(
|
||||
adapter.segmentAnything.saveAs('control_layer');
|
||||
}, [adapter.segmentAnything]);
|
||||
|
||||
const setInputToPoints = useCallback(() => {
|
||||
adapter.segmentAnything.setInputType('points');
|
||||
const setInputToVisual = useCallback(() => {
|
||||
adapter.segmentAnything.setInputType('visual');
|
||||
}, [adapter.segmentAnything]);
|
||||
|
||||
const setInputToPrompt = useCallback(() => {
|
||||
adapter.segmentAnything.setInputType('prompt');
|
||||
}, [adapter.segmentAnything]);
|
||||
|
||||
const setInputToBoundingBox = useCallback(() => {
|
||||
adapter.segmentAnything.setInputType('bounding_box');
|
||||
}, [adapter.segmentAnything]);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'applySegmentAnything',
|
||||
category: 'canvas',
|
||||
@@ -123,11 +119,10 @@ const SelectObjectContent = memo(
|
||||
|
||||
<Flex w="full" justifyContent="space-between" py={2}>
|
||||
<ButtonGroup>
|
||||
<Button onClick={setInputToPoints}>Points</Button>
|
||||
<Button onClick={setInputToVisual}>Visual</Button>
|
||||
<Button onClick={setInputToPrompt}>Prompt</Button>
|
||||
<Button onClick={setInputToBoundingBox}>Bounding Box</Button>
|
||||
</ButtonGroup>
|
||||
{inputData.type === 'points' && <SelectObjectPointType adapter={adapter} />}
|
||||
{inputData.type === 'visual' && <SelectObjectPointType adapter={adapter} />}
|
||||
<SelectObjectInvert adapter={adapter} />
|
||||
</Flex>
|
||||
|
||||
@@ -232,20 +227,21 @@ const SelectObjectHelpTooltipContent = memo(() => {
|
||||
<Text>
|
||||
<Trans i18nKey="controlLayers.selectObject.help2" components={{ Bold: <Bold /> }} />
|
||||
</Text>
|
||||
<Text>
|
||||
<Trans i18nKey="controlLayers.selectObject.help3" />
|
||||
</Text>
|
||||
<Text fontWeight="semibold">Visual Mode:</Text>
|
||||
<UnorderedList>
|
||||
<ListItem>{t('controlLayers.selectObject.clickToAdd')}</ListItem>
|
||||
<ListItem>{t('controlLayers.selectObject.dragToMove')}</ListItem>
|
||||
<ListItem>{t('controlLayers.selectObject.clickToRemove')}</ListItem>
|
||||
</UnorderedList>
|
||||
<Text fontWeight="semibold">Bounding Box Mode:</Text>
|
||||
<UnorderedList>
|
||||
<ListItem>Click and drag to draw a bounding box around an object</ListItem>
|
||||
<ListItem>Click to add include points (green)</ListItem>
|
||||
<ListItem>Shift + Click to add exclude points (red)</ListItem>
|
||||
<ListItem>Click and drag to draw a bounding box</ListItem>
|
||||
<ListItem>Click on points to remove them</ListItem>
|
||||
<ListItem>Drag points to reposition them</ListItem>
|
||||
<ListItem>Resize the box using the corner handles</ListItem>
|
||||
<ListItem>Drag the box to reposition it</ListItem>
|
||||
</UnorderedList>
|
||||
<Text fontWeight="semibold">Prompt Mode:</Text>
|
||||
<UnorderedList>
|
||||
<ListItem>Type a text description of the object to select</ListItem>
|
||||
<ListItem>The AI will find and segment matching objects</ListItem>
|
||||
</UnorderedList>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -104,18 +104,14 @@ type SAMPointState = {
|
||||
};
|
||||
};
|
||||
|
||||
type PointsInputData = {
|
||||
type: 'points';
|
||||
points: SAMPointState[];
|
||||
};
|
||||
|
||||
type PromptInputData = {
|
||||
type: 'prompt';
|
||||
prompt: string;
|
||||
};
|
||||
|
||||
type BoundingBoxInputData = {
|
||||
type: 'bounding_box';
|
||||
type VisualInputData = {
|
||||
type: 'visual';
|
||||
points: SAMPointState[];
|
||||
bbox: {
|
||||
x: number;
|
||||
y: number;
|
||||
@@ -124,20 +120,19 @@ type BoundingBoxInputData = {
|
||||
} | null;
|
||||
};
|
||||
|
||||
const hasInputData = (data: PointsInputData | PromptInputData | BoundingBoxInputData): boolean => {
|
||||
if (data.type === 'points') {
|
||||
return data.points.length > 0;
|
||||
} else if (data.type === 'prompt') {
|
||||
const hasInputData = (data: PromptInputData | VisualInputData): boolean => {
|
||||
if (data.type === 'prompt') {
|
||||
return data.prompt.trim() !== '';
|
||||
} else {
|
||||
return data.bbox !== null;
|
||||
// Visual mode has input if there are points OR a bbox
|
||||
return data.points.length > 0 || data.bbox !== null;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers.
|
||||
*/
|
||||
const getSAMPoints = (data: PointsInputData): SAMPointWithId[] => {
|
||||
const getSAMPoints = (data: VisualInputData): SAMPointWithId[] => {
|
||||
const points: SAMPointWithId[] = [];
|
||||
|
||||
for (const { id, coord, label } of data.points) {
|
||||
@@ -152,13 +147,11 @@ const getSAMPoints = (data: PointsInputData): SAMPointWithId[] => {
|
||||
return points;
|
||||
};
|
||||
|
||||
const getHashableInputData = (data: PointsInputData | PromptInputData | BoundingBoxInputData) => {
|
||||
if (data.type === 'points') {
|
||||
return { type: 'points', points: getSAMPoints(data) } as const;
|
||||
} else if (data.type === 'prompt') {
|
||||
const getHashableInputData = (data: PromptInputData | VisualInputData) => {
|
||||
if (data.type === 'prompt') {
|
||||
return { type: 'prompt', prompt: data.prompt } as const;
|
||||
} else {
|
||||
return { type: 'bounding_box', bbox: data.bbox } as const;
|
||||
return { type: 'visual', points: getSAMPoints(data), bbox: data.bbox } as const;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -228,7 +221,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
*/
|
||||
$hasImageState = computed(this.$imageState, (imageState) => imageState !== null);
|
||||
|
||||
$inputData = atom<PointsInputData | PromptInputData | BoundingBoxInputData>({ type: 'points', points: [] });
|
||||
$inputData = atom<PromptInputData | VisualInputData>({ type: 'visual', points: [], bbox: null });
|
||||
|
||||
/**
|
||||
* Whether the module has points. This is a computed value based on $points.
|
||||
@@ -376,7 +369,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
// Add event handlers for bbox transformer
|
||||
this.konva.bboxTransformer.on('transformend', () => {
|
||||
const data = this.$inputData.get();
|
||||
if (data.type !== 'bounding_box') {
|
||||
if (data.type !== 'visual') {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -406,6 +399,29 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
});
|
||||
});
|
||||
|
||||
// Handle bbox dragging
|
||||
this.konva.bboxRect.on('dragend', () => {
|
||||
const data = this.$inputData.get();
|
||||
if (data.type !== 'visual') {
|
||||
return;
|
||||
}
|
||||
|
||||
const x = this.konva.bboxRect.x();
|
||||
const y = this.konva.bboxRect.y();
|
||||
const width = this.konva.bboxRect.width() * this.konva.bboxRect.scaleX();
|
||||
const height = this.konva.bboxRect.height() * this.konva.bboxRect.scaleY();
|
||||
|
||||
this.$inputData.set({
|
||||
...data,
|
||||
bbox: {
|
||||
x: x,
|
||||
y: y,
|
||||
width: width,
|
||||
height: height,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
// Stop event propagation when interacting with the bbox
|
||||
this.konva.bboxRect.on('mousedown touchstart', (e) => {
|
||||
e.cancelBubble = true;
|
||||
@@ -418,7 +434,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
// Add event handler for bbox rect drag
|
||||
this.konva.bboxRect.on('dragend', () => {
|
||||
const data = this.$inputData.get();
|
||||
if (data.type !== 'bounding_box') {
|
||||
if (data.type !== 'visual') {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -493,16 +509,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
circle.destroy();
|
||||
|
||||
const data = this.$inputData.get();
|
||||
if (data.type !== 'points') {
|
||||
if (data.type !== 'visual') {
|
||||
return;
|
||||
}
|
||||
|
||||
const newPoints = data.points.filter((point) => point.id !== id);
|
||||
if (newPoints.length === 0) {
|
||||
this.resetEphemeralState();
|
||||
} else {
|
||||
this.$inputData.set({ ...data, points: newPoints });
|
||||
}
|
||||
this.$inputData.set({ ...data, points: newPoints });
|
||||
});
|
||||
|
||||
circle.on('dragstart', () => {
|
||||
@@ -516,7 +528,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.$isDraggingPoint.set(false);
|
||||
|
||||
const data = this.$inputData.get();
|
||||
if (data.type !== 'points') {
|
||||
if (data.type !== 'visual') {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -549,7 +561,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
*/
|
||||
syncPointScales = () => {
|
||||
const data = this.$inputData.get();
|
||||
if (data.type !== 'points') {
|
||||
if (data.type !== 'visual') {
|
||||
return;
|
||||
}
|
||||
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
|
||||
@@ -561,11 +573,38 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
};
|
||||
|
||||
/**
|
||||
* Handles the pointerdown event on the stage. This is used to start drawing a bounding box.
|
||||
* Synchronizes the bbox visibility based on the current input data.
|
||||
*/
|
||||
syncBboxVisibility = () => {
  const inputData = this.$inputData.get();

  // Only visual mode carries a bbox; for any other input type the nodes are left untouched.
  if (inputData.type !== 'visual') {
    return;
  }

  const { bbox } = inputData;
  const { bboxRect, bboxTransformer } = this.konva;

  if (!bbox) {
    // No bbox stored: hide both the rect and its transformer.
    bboxRect.visible(false);
    bboxTransformer.visible(false);
    return;
  }

  // A bbox is stored: move/resize the rect to match it and show the rect together with its transformer.
  bboxRect.setAttrs({
    x: bbox.x,
    y: bbox.y,
    width: bbox.width,
    height: bbox.height,
    visible: true,
  });
  bboxTransformer.visible(true);
};
|
||||
|
||||
/**
|
||||
* Handles the pointerdown event on the stage. This is used to start drawing a bounding box in visual mode.
|
||||
* We'll start tracking the position but won't decide if it's a bbox or point until pointerup.
|
||||
*/
|
||||
onStagePointerDown = (e: KonvaEventObject<PointerEvent>) => {
|
||||
const data = this.$inputData.get();
|
||||
if (data.type !== 'bounding_box') {
|
||||
if (data.type !== 'visual') {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -602,19 +641,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
// Normalize the cursor position to the parent entity's position
|
||||
const normalizedPoint = offsetCoord(cursorPos.relative, parentPosition);
|
||||
|
||||
// Start drawing the bounding box
|
||||
// Start potential bbox drawing (we'll decide in pointerup if it's actually a bbox or a point)
|
||||
this.$isBboxDrawing.set(true);
|
||||
this.$bboxStartCoord.set(normalizedPoint);
|
||||
|
||||
// Reset the bbox rect position and size
|
||||
this.konva.bboxRect.setAttrs({
|
||||
x: normalizedPoint.x,
|
||||
y: normalizedPoint.y,
|
||||
width: 0,
|
||||
height: 0,
|
||||
visible: true,
|
||||
});
|
||||
this.konva.bboxTransformer.visible(false);
|
||||
// Prepare for potential new bbox but don't hide existing one yet
|
||||
// We'll only update visibility during drag if it's actually a new bbox
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -622,7 +654,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
*/
|
||||
onStagePointerMove = (e: KonvaEventObject<PointerEvent>) => {
|
||||
const data = this.$inputData.get();
|
||||
if (data.type !== 'bounding_box') {
|
||||
if (data.type !== 'visual') {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -654,13 +686,20 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
const width = Math.abs(currentPoint.x - startCoord.x);
|
||||
const height = Math.abs(currentPoint.y - startCoord.y);
|
||||
|
||||
// Update the bbox rect
|
||||
this.konva.bboxRect.setAttrs({
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
});
|
||||
// Only show the bbox and hide transformer if we've dragged more than a threshold (5 pixels)
|
||||
if (width > 5 || height > 5) {
|
||||
// Now we know it's a drag for a new bbox, hide the transformer
|
||||
this.konva.bboxTransformer.visible(false);
|
||||
|
||||
// Update and show the new bbox rect
|
||||
this.konva.bboxRect.setAttrs({
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
visible: true,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -669,109 +708,105 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
onStagePointerUp = (e: KonvaEventObject<PointerEvent>) => {
|
||||
const data = this.$inputData.get();
|
||||
|
||||
// Handle bounding box mode
|
||||
if (data.type === 'bounding_box') {
|
||||
if (!this.$isBboxDrawing.get()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle visual mode
|
||||
if (data.type === 'visual') {
|
||||
// Only handle left-clicks
|
||||
if (e.evt.button !== 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Stop drawing
|
||||
this.$isBboxDrawing.set(false);
|
||||
this.$bboxStartCoord.set(null);
|
||||
// Check if we started a potential bbox draw
|
||||
if (this.$isBboxDrawing.get()) {
|
||||
const startCoord = this.$bboxStartCoord.get();
|
||||
|
||||
// Check if we actually dragged by calculating from start position
|
||||
const cursorPos = this.manager.tool.$cursorPos.get();
|
||||
if (!cursorPos || !startCoord) {
|
||||
// Stop tracking even if we don't have valid coords
|
||||
this.$isBboxDrawing.set(false);
|
||||
this.$bboxStartCoord.set(null);
|
||||
return;
|
||||
}
|
||||
|
||||
// Stop tracking (after we've used the values)
|
||||
this.$isBboxDrawing.set(false);
|
||||
this.$bboxStartCoord.set(null);
|
||||
|
||||
const pixelRect = this.parent.transformer.$pixelRect.get();
|
||||
const parentPosition = addCoords(this.parent.state.position, pixelRect);
|
||||
const currentPoint = offsetCoord(cursorPos.relative, parentPosition);
|
||||
|
||||
const dragWidth = Math.abs(currentPoint.x - startCoord.x);
|
||||
const dragHeight = Math.abs(currentPoint.y - startCoord.y);
|
||||
|
||||
// Get the final bbox dimensions from the rect's attributes (not getClientRect which gives screen coords)
|
||||
const x = this.konva.bboxRect.x();
|
||||
const y = this.konva.bboxRect.y();
|
||||
const width = this.konva.bboxRect.width() * this.konva.bboxRect.scaleX();
|
||||
const height = this.konva.bboxRect.height() * this.konva.bboxRect.scaleY();
|
||||
// Check if we actually dragged (moved more than threshold)
|
||||
if (dragWidth > 5 || dragHeight > 5) {
|
||||
// Get the final bbox dimensions from the rect's attributes
|
||||
const x = this.konva.bboxRect.x();
|
||||
const y = this.konva.bboxRect.y();
|
||||
const width = this.konva.bboxRect.width() * this.konva.bboxRect.scaleX();
|
||||
const height = this.konva.bboxRect.height() * this.konva.bboxRect.scaleY();
|
||||
// It was a drag - save the bbox
|
||||
this.$inputData.set({
|
||||
...data,
|
||||
bbox: {
|
||||
x: x,
|
||||
y: y,
|
||||
width: width,
|
||||
height: height,
|
||||
},
|
||||
});
|
||||
|
||||
// Only save the bbox if it has a reasonable size (not just a click)
|
||||
if (width > 5 && height > 5) {
|
||||
this.$inputData.set({
|
||||
...data,
|
||||
bbox: {
|
||||
x: x,
|
||||
y: y,
|
||||
width: width,
|
||||
height: height,
|
||||
},
|
||||
});
|
||||
// Show the transformer for resizing
|
||||
this.konva.bboxTransformer.visible(true);
|
||||
} else {
|
||||
// It was just a click, not a drag - add a point instead
|
||||
// Make sure existing bbox stays visible
|
||||
this.syncBboxVisibility();
|
||||
|
||||
// Ignore if the stage is dragging/panning
|
||||
if (this.manager.stage.getIsDragging()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Show the transformer for resizing
|
||||
this.konva.bboxTransformer.visible(true);
|
||||
} else {
|
||||
// Too small, hide the bbox
|
||||
this.konva.bboxRect.visible(false);
|
||||
this.konva.bboxTransformer.visible(false);
|
||||
this.$inputData.set({
|
||||
...data,
|
||||
bbox: null,
|
||||
});
|
||||
// Ignore if a point is being dragged
|
||||
if (this.$isDraggingPoint.get()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore if we are already processing
|
||||
if (this.$isProcessing.get()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!startCoord) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a SAM point at the click position
|
||||
let pointType: -1 | 0 | 1;
|
||||
// If shift key is held, create an exclude point
|
||||
if (e.evt.shiftKey) {
|
||||
pointType = -1;
|
||||
} else {
|
||||
// Default to include point
|
||||
pointType = 1;
|
||||
}
|
||||
|
||||
const point = this.createPoint(startCoord, pointType);
|
||||
const newPoints = [...data.points, point];
|
||||
this.$inputData.set({ ...data, points: newPoints });
|
||||
|
||||
// Ensure bbox remains visible if it exists
|
||||
this.syncBboxVisibility();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle points mode
|
||||
if (data.type !== 'points') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only handle left-clicks
|
||||
if (e.evt.button !== 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore if the stage is dragging/panning
|
||||
if (this.manager.stage.getIsDragging()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore if a point is being dragged
|
||||
if (this.$isDraggingPoint.get()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore if we are already processing
|
||||
if (this.$isProcessing.get()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore if the cursor is not within the stage (should never happen)
|
||||
const cursorPos = this.manager.tool.$cursorPos.get();
|
||||
if (!cursorPos) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We need to offset the cursor position by the parent entity's position + pixel rect to get the correct position
|
||||
const pixelRect = this.parent.transformer.$pixelRect.get();
|
||||
const parentPosition = addCoords(this.parent.state.position, pixelRect);
|
||||
|
||||
// Normalize the cursor position to the parent entity's position
|
||||
const normalizedPoint = offsetCoord(cursorPos.relative, parentPosition);
|
||||
|
||||
// Create a SAM point at the normalized position
|
||||
let pointType: -1 | 0 | 1;
|
||||
// If the shift key is held, invert the point type
|
||||
if (e.evt.shiftKey) {
|
||||
if (this.$pointType.get() === 1) {
|
||||
pointType = -1;
|
||||
} else if (this.$pointType.get() === -1) {
|
||||
pointType = 1;
|
||||
} else {
|
||||
pointType = 0;
|
||||
}
|
||||
} else {
|
||||
pointType = this.$pointType.get();
|
||||
}
|
||||
const point = this.createPoint(normalizedPoint, pointType);
|
||||
const newPoints = [...data.points, point];
|
||||
this.$inputData.set({ ...data, points: newPoints });
|
||||
// Handle prompt mode - nothing to do on pointer up
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -809,9 +844,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
})
|
||||
);
|
||||
|
||||
// When the points change, process them if autoProcess is enabled
|
||||
// When the input data changes, sync bbox visibility and process if autoProcess is enabled
|
||||
this.subscriptions.add(
|
||||
this.$inputData.listen((inputData) => {
|
||||
// Always sync bbox visibility when input data changes
|
||||
this.syncBboxVisibility();
|
||||
|
||||
if (!hasInputData(inputData)) {
|
||||
return;
|
||||
}
|
||||
@@ -1133,19 +1171,17 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.resetEphemeralState();
|
||||
};
|
||||
|
||||
setInputType = (type: PointsInputData['type'] | PromptInputData['type'] | BoundingBoxInputData['type']) => {
|
||||
setInputType = (type: PromptInputData['type'] | VisualInputData['type']) => {
|
||||
const data = this.$inputData.get();
|
||||
if (data.type === type) {
|
||||
return;
|
||||
}
|
||||
this.reset();
|
||||
if (type === 'points') {
|
||||
this.$inputData.set({ type: 'points', points: [] });
|
||||
} else if (type === 'prompt') {
|
||||
if (type === 'prompt') {
|
||||
this.$inputData.set({ type: 'prompt', prompt: '' });
|
||||
} else {
|
||||
this.$inputData.set({ type: 'bounding_box', bbox: null });
|
||||
// Hide bbox nodes when switching to bbox mode (they'll be shown when drawing)
|
||||
this.$inputData.set({ type: 'visual', points: [], bbox: null });
|
||||
// Hide bbox nodes when switching to visual mode (they'll be shown when drawing)
|
||||
this.konva.bboxRect.visible(false);
|
||||
this.konva.bboxTransformer.visible(false);
|
||||
}
|
||||
@@ -1197,11 +1233,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
|
||||
// Destroy ephemeral konva nodes
|
||||
const data = this.$inputData.get();
|
||||
if (data.type === 'points') {
|
||||
if (data.type === 'visual') {
|
||||
// Destroy all points
|
||||
for (const point of data.points) {
|
||||
point.konva.circle.destroy();
|
||||
}
|
||||
} else if (data.type === 'bounding_box') {
|
||||
// Hide bounding box nodes
|
||||
this.konva.bboxRect.visible(false);
|
||||
this.konva.bboxTransformer.visible(false);
|
||||
@@ -1220,8 +1256,8 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
this.konva.maskTween = null;
|
||||
}
|
||||
|
||||
// Empty internal module state
|
||||
this.$inputData.set({ type: 'points', points: [] });
|
||||
// Empty internal module state - default to visual mode
|
||||
this.$inputData.set({ type: 'visual', points: [], bbox: null });
|
||||
this.$imageState.set(null);
|
||||
this.$pointType.set(1);
|
||||
this.$invert.set(false);
|
||||
@@ -1238,7 +1274,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
*/
|
||||
static buildGraph = (
|
||||
{ image_name }: ImageDTO,
|
||||
inputData: PointsInputData | PromptInputData | BoundingBoxInputData,
|
||||
inputData: PromptInputData | VisualInputData,
|
||||
invert: boolean,
|
||||
model: SAMModel
|
||||
): { graph: Graph; outputNodeId: string } => {
|
||||
@@ -1250,25 +1286,33 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
image: { image_name },
|
||||
});
|
||||
|
||||
// For visual mode, we may have points, bbox, or both
|
||||
let pointLists = undefined;
|
||||
let boundingBoxes = undefined;
|
||||
|
||||
if (inputData.type === 'visual') {
|
||||
// If we have points, add them
|
||||
if (inputData.points.length > 0) {
|
||||
pointLists = [{ points: getSAMPoints(inputData).map(({ x, y, label }) => ({ x, y, label })) }];
|
||||
}
|
||||
|
||||
// If we have a bbox, add it
|
||||
if (inputData.bbox) {
|
||||
boundingBoxes = [{
|
||||
x_min: Math.round(inputData.bbox.x),
|
||||
y_min: Math.round(inputData.bbox.y),
|
||||
x_max: Math.round(inputData.bbox.x + inputData.bbox.width),
|
||||
y_max: Math.round(inputData.bbox.y + inputData.bbox.height),
|
||||
}];
|
||||
}
|
||||
}
|
||||
|
||||
const segmentAnything = graph.addNode({
|
||||
id: getPrefixedId('segment_anything'),
|
||||
type: 'segment_anything',
|
||||
model: model === 'SAM1' ? 'segment-anything-huge' : 'segment-anything-2-large',
|
||||
point_lists:
|
||||
inputData.type === 'points'
|
||||
? [{ points: getSAMPoints(inputData).map(({ x, y, label }) => ({ x, y, label })) }]
|
||||
: undefined,
|
||||
bounding_boxes:
|
||||
inputData.type === 'bounding_box' && inputData.bbox
|
||||
? [
|
||||
{
|
||||
x_min: Math.round(inputData.bbox.x),
|
||||
y_min: Math.round(inputData.bbox.y),
|
||||
x_max: Math.round(inputData.bbox.x + inputData.bbox.width),
|
||||
y_max: Math.round(inputData.bbox.y + inputData.bbox.height),
|
||||
},
|
||||
]
|
||||
: undefined,
|
||||
point_lists: pointLists,
|
||||
bounding_boxes: boundingBoxes,
|
||||
mask_filter: 'largest',
|
||||
apply_polygon_refinement: false,
|
||||
});
|
||||
@@ -1318,16 +1362,18 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
repr = () => {
|
||||
const data = this.$inputData.get();
|
||||
let inputData: any;
|
||||
if (data.type === 'points') {
|
||||
inputData = data.points.map(({ id, konva, label }) => ({
|
||||
id,
|
||||
label,
|
||||
circle: getKonvaNodeDebugAttrs(konva.circle),
|
||||
}));
|
||||
} else if (data.type === 'prompt') {
|
||||
if (data.type === 'prompt') {
|
||||
inputData = { type: 'prompt', prompt: data.prompt };
|
||||
} else {
|
||||
inputData = { type: 'bounding_box', bbox: data.bbox || null };
|
||||
inputData = {
|
||||
type: 'visual',
|
||||
points: data.points.map(({ id, konva, label }) => ({
|
||||
id,
|
||||
label,
|
||||
circle: getKonvaNodeDebugAttrs(konva.circle),
|
||||
})),
|
||||
bbox: data.bbox || null,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user