refactor(backend): simplify segment anything APIs

The SAM pipeline classes had a confusing API: they accepted deeply
nested lists of differing dimensionality (bboxes, points, and labels).

The lengths of these lists are related: each point must have a
corresponding label, and if bboxes are provided alongside points, the
two lists must be the same length.

I've refactored the backend API to take a single list of SAMInput
objects. This class has a bbox and/or a list of points, making it much
simpler to provide the right shape of inputs.
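
For example (a rough sketch using the new shared classes from the diff
below; the coordinates are made up and `pipeline` stands for an
already-loaded SAM/SAM2 pipeline):

    from invokeai.backend.image_util.segment_anything.shared import (
        BoundingBox,
        SAMInput,
        SAMPoint,
        SAMPointLabel,
    )

    # One SAMInput per object to segment: a bounding box, a list of
    # points, or both.
    inputs = [
        # Object 1: bounding box only.
        SAMInput(bounding_box=BoundingBox(x_min=10, y_min=20, x_max=200, y_max=240)),
        # Object 2: two labelled points, no box.
        SAMInput(
            points=[
                SAMPoint(x=55, y=70, label=SAMPointLabel.positive),
                SAMPoint(x=60, y=180, label=SAMPointLabel.negative),
            ]
        ),
    ]

    # Both pipelines now share the same call shape:
    #   masks = pipeline.segment(image=image_pil, inputs=inputs)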

Internally, the pipeline classes rejigger these input objects into the
nesting that each model expects.
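
As a sketch of that conversion (condensed from the SAM2 pipeline
changes below; the helper name `to_processor_lists` is made up):

    from invokeai.backend.image_util.segment_anything.shared import SAMInput


    def to_processor_lists(inputs: list[SAMInput]):
        # Flatten SAMInput objects back into the parallel lists the HF
        # processor expects, then wrap each in a single-image batch dim.
        input_boxes: list[list[float]] = []
        input_points: list[list[list[float]]] = []
        input_labels: list[list[int]] = []
        for i in inputs:
            if i.bounding_box is not None:
                bb = i.bounding_box
                input_boxes.append([bb.x_min, bb.y_min, bb.x_max, bb.y_max])
            if i.points is not None:
                input_points.append([[p.x, p.y] for p in i.points])
                input_labels.append([p.label.value for p in i.points])
        return (
            [input_boxes] if input_boxes else None,
            [input_points] if input_points else None,
            [input_labels] if input_labels else None,
        )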

The Nodes still have an awkward API where you can provide both a bbox
list and a point list of different lengths, so I added a pydantic
validator that enforces matching lengths.
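
A minimal, standalone sketch of that check (the real validator lives
on SegmentAnythingInvocation in the diff below; the model here is
hypothetical, with stand-in element types):

    from pydantic import BaseModel, model_validator


    class LengthCheckSketch(BaseModel):
        # Stand-ins; the real node holds BoundingBoxField / SAMPointsField lists.
        bounding_boxes: list[int] | None = None
        point_lists: list[int] | None = None

        @model_validator(mode="after")
        def validate_points_and_boxes_len(self):
            if self.point_lists is not None and self.bounding_boxes is not None:
                if len(self.point_lists) != len(self.bounding_boxes):
                    raise ValueError(
                        "If both point_lists and bounding_boxes are provided, they must have the same length."
                    )
            return self


    LengthCheckSketch(bounding_boxes=[1, 2], point_lists=[1])  # raises a pydantic ValidationError
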
psychedelicious
2025-09-10 16:24:46 +10:00
parent 7a073b6de7
commit d828502bc8
5 changed files with 220 additions and 185 deletions

View File

@@ -1,11 +1,12 @@
from enum import Enum
from typing import Any, Callable, Optional, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
from invokeai.app.util.metaenum import MetaEnum
from invokeai.backend.image_util.segment_anything.shared import BoundingBox
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@@ -331,14 +332,9 @@ class ConditioningField(BaseModel):
)
class BoundingBoxField(BaseModel):
class BoundingBoxField(BoundingBox):
"""A bounding box primitive value."""
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
score: Optional[float] = Field(
default=None,
ge=0.0,
@@ -347,21 +343,6 @@ class BoundingBoxField(BaseModel):
"when the bounding box was produced by a detector and has an associated confidence score.",
)
@model_validator(mode="after")
def check_coords(self):
if self.x_min > self.x_max:
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
if self.y_min > self.y_max:
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
def tuple(self) -> Tuple[int, int, int, int]:
"""
Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.
This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
"""
return (self.x_min, self.y_min, self.x_max, self.y_max)
class MetadataField(RootModel[dict[str, Any]]):
"""

View File

@@ -1,15 +1,15 @@
from enum import Enum
from itertools import zip_longest
from pathlib import Path
from typing import Literal
import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoProcessor
from pydantic import BaseModel, Field, model_validator
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from transformers.models.sam2 import Sam2Model
from transformers.models.sam2.processing_sam2 import Sam2Processor
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
@@ -18,6 +18,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_2_pipeline import SegmentAnything2Pipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.image_util.segment_anything.shared import SAMInput, SAMPoint
SegmentAnythingModelKey = Literal[
"segment-anything-base",
@@ -39,22 +40,10 @@ SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
}
class SAMPointLabel(Enum):
negative = -1
neutral = 0
positive = 1
class SAMPoint(BaseModel):
x: int = Field(..., description="The x-coordinate of the point")
y: int = Field(..., description="The y-coordinate of the point")
label: SAMPointLabel = Field(..., description="The label of the point")
class SAMPointsField(BaseModel):
points: list[SAMPoint] = Field(..., description="The points of the object")
points: list[SAMPoint] = Field(..., description="The points of the object", min_length=1)
def to_list(self) -> list[list[int]]:
def to_list(self) -> list[list[float]]:
return [[point.x, point.y, point.label.value] for point in self.points]
@@ -91,14 +80,18 @@ class SegmentAnythingInvocation(BaseInvocation):
default="all",
)
@model_validator(mode="after")
def validate_points_and_boxes_len(self):
if self.point_lists is not None and self.bounding_boxes is not None:
if len(self.point_lists) != len(self.bounding_boxes):
raise ValueError("If both point_lists and bounding_boxes are provided, they must have the same length.")
return self
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
if self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
not self.point_lists or len(self.point_lists) == 0
):
@@ -118,91 +111,86 @@ class SegmentAnythingInvocation(BaseInvocation):
@staticmethod
def _load_sam_model(model_path: Path):
"""Load either SAM or SAM2 model based on the model path."""
model_path_str = str(model_path).lower()
sam_model = SamModel.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
sam_processor = SamProcessor.from_pretrained(model_path, local_files_only=True)
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
if "sam2" in model_path_str:
# Load SAM2 model
try:
sam2_model = Sam2Model.from_pretrained(
model_path,
local_files_only=True,
# TODO: Investigate whether fp16 is supported by SAM2
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
except Exception as e:
raise RuntimeError(f"Failed to load SAM2 model from {model_path}. Error: {str(e)}")
@staticmethod
def _load_sam_2_model(model_path: Path):
sam2_model = Sam2Model.from_pretrained(model_path, local_files_only=True)
sam2_processor = Sam2Processor.from_pretrained(model_path, local_files_only=True)
return SegmentAnything2Pipeline(sam2_model=sam2_model, sam2_processor=sam2_processor)
try:
sam2_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
# Log what type of processor we got for debugging
processor_type = type(sam2_processor).__name__
print(f"Loaded processor type: {processor_type} for model {model_path}")
except Exception as e:
raise RuntimeError(f"Failed to load processor from {model_path}. Error: {str(e)}")
def _get_bounding_boxes(self) -> list[list[list[float]]] | None:
if self.bounding_boxes is None:
return None
return [[[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]]
return SegmentAnything2Pipeline(sam2_model=sam2_model, sam2_processor=sam2_processor)
else:
# Load SAM model
sam_model = SamModel.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
def _get_sam_points_list(self) -> list[list[list[float]]] | None:
if not self.point_lists:
return None
return [p.to_list() for p in self.point_lists]
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
def _get_sam2_points_list_and_labels(
self,
) -> tuple[list[list[list[list[float]]]], list[list[list[int]]]] | tuple[None, None]:
if not self.point_lists:
return None, None
point_lists: list[list[list[list[float]]]] = []
point_labels: list[list[list[int]]] = []
for point_list in self.point_lists:
object_points: list[list[float]] = []
object_labels: list[int] = []
for point in point_list.points:
object_points.append([point.x, point.y])
object_labels.append(point.label.value)
point_lists.append([object_points])
point_labels.append([object_labels])
return point_lists, point_labels
def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
"""Use Segment Anything (SAM or SAM2) to generate masks given an image + a set of bounding boxes."""
# Convert the bounding boxes to the input format.
bounding_boxes = (
[[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
)
# Convert points to the format expected by the specific model
# We'll determine the format based on the actual pipeline type after loading
if self.point_lists:
# Prepare both formats - we'll use the appropriate one based on pipeline type
# SAM2 format: [[[[x, y]]]] and [[[label]]]
sam2_point_lists = []
sam2_point_labels = []
for point_list in self.point_lists:
object_points = []
object_labels = []
for point in point_list.points:
object_points.append([point.x, point.y])
object_labels.append(point.label.value)
sam2_point_lists.append([object_points])
sam2_point_labels.append([object_labels])
is_sam_2 = "segment-anything-2" in self.model
# SAM format: [[x, y, label]]
sam_point_lists = [p.to_list() for p in self.point_lists]
else:
sam2_point_lists = None
sam2_point_labels = None
sam_point_lists = None
with (
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
) as pipeline,
):
# Check pipeline type dynamically and use appropriate point format
if isinstance(pipeline, SegmentAnything2Pipeline):
masks = pipeline.segment(
image=image,
bounding_boxes=bounding_boxes,
point_lists=sam2_point_lists,
point_labels=sam2_point_labels,
if is_sam_2:
source = SEGMENT_ANYTHING_MODEL_IDS[self.model]
loader = SegmentAnythingInvocation._load_sam_2_model
inputs: list[SAMInput] = []
for bbox_field, point_field in zip_longest(
self.bounding_boxes or [], self.point_lists or [], fillvalue=None
):
inputs.append(
SAMInput(
bounding_box=bbox_field,
points=point_field.points if point_field else None,
)
)
elif isinstance(pipeline, SegmentAnythingPipeline):
masks = pipeline.segment(image=image, bounding_boxes=bounding_boxes, point_lists=sam_point_lists)
else:
raise RuntimeError(f"Unknown pipeline type: {type(pipeline)}")
with context.models.load_remote_model(source=source, loader=loader) as pipeline:
assert isinstance(pipeline, SegmentAnything2Pipeline)
masks = pipeline.segment(image=image, inputs=inputs)
else:
source = SEGMENT_ANYTHING_MODEL_IDS[self.model]
loader = SegmentAnythingInvocation._load_sam_model
inputs: list[SAMInput] = []
for bbox_field, point_field in zip_longest(
self.bounding_boxes or [], self.point_lists or [], fillvalue=None
):
inputs.append(
SAMInput(
bounding_box=bbox_field,
points=point_field.points if point_field else None,
)
)
with context.models.load_remote_model(source=source, loader=loader) as pipeline:
assert isinstance(pipeline, SegmentAnythingPipeline)
masks = pipeline.segment(image=image, inputs=inputs)
masks = self._process_masks(masks)
if self.apply_polygon_refinement:

View File

@@ -1,26 +1,20 @@
from typing import Optional, TypeAlias
from typing import Optional
import torch
from PIL import Image
# Import SAM2 components - these should be available in transformers 4.56.0+
from transformers.models.sam2 import Sam2Model
from transformers.models.sam2.processing_sam2 import Sam2Processor
from invokeai.backend.image_util.segment_anything.shared import SAMInput
from invokeai.backend.raw_model import RawModel
# Type aliases for the inputs to the SAM2 model.
ListOfBoundingBoxes: TypeAlias = list[list[int]]
"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
ListOfPoints: TypeAlias = list[list[list[list[int]]]]
"""A list of points in SAM2 4D format: [[[[x, y]]]] (image_dim, object_dim, point_per_object_dim, coordinates)"""
ListOfPointLabels: TypeAlias = list[list[list[int]]]
"""A list of SAM2 point labels in 3D format: [[[label]]] (image_dim, object_dim, point_label)"""
class SegmentAnything2Pipeline(RawModel):
"""A wrapper class for the transformers SAM2 model and processor that makes it compatible with the model manager."""
def __init__(self, sam2_model: Sam2Model, sam2_processor):
def __init__(self, sam2_model: Sam2Model, sam2_processor: Sam2Processor):
"""Initialize the SAM2 pipeline.
Args:
@@ -45,9 +39,7 @@ class SegmentAnything2Pipeline(RawModel):
def segment(
self,
image: Image.Image,
bounding_boxes: list[list[int]] | None = None,
point_lists: list[list[list[int]]] | None = None,
point_labels: list[list[int]] | None = None,
inputs: list[SAMInput],
) -> torch.Tensor:
"""Segment an image using the SAM2 model.
@@ -65,37 +57,57 @@ class SegmentAnything2Pipeline(RawModel):
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
# Prep the inputs for SAM2:
# - SAM2 expects 4D input format: [[[[x, y]]]] for points and [[[label]]] for labels
# - Bounding boxes remain in 2D format: [[x_min, y_min, x_max, y_max]]
if bounding_boxes:
input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
input_points: list[ListOfPoints] | None = None
input_labels: list[ListOfPointLabels] | None = None
elif point_lists and point_labels:
input_boxes: list[ListOfBoundingBoxes] | None = None
input_points = point_lists
input_labels = point_labels
else:
raise ValueError("Either bounding_boxes or (point_lists AND point_labels) must be provided.")
input_points: list[list[list[float]]] = []
input_labels: list[list[int]] = []
input_boxes: list[list[float]] = []
# Process the inputs using the SAM2 processor
inputs = self._sam2_processor(
for i in inputs:
box: list[float] | None = None
points: list[list[float]] | None = None
labels: list[int] | None = None
if i.bounding_box is not None:
box: list[float] | None = [
i.bounding_box.x_min,
i.bounding_box.y_min,
i.bounding_box.x_max,
i.bounding_box.y_max,
]
if i.points is not None:
points = []
labels = []
for point in i.points:
points.append([point.x, point.y])
labels.append(point.label.value)
if box is not None:
input_boxes.append(box)
if points is not None:
input_points.append(points)
if labels is not None:
input_labels.append(labels)
batched_input_boxes = [input_boxes] if input_boxes else None
batched_input_points = [input_points] if input_points else None
batched_input_labels = [input_labels] if input_labels else None
processed_inputs = self._sam2_processor(
images=image,
input_boxes=input_boxes,
input_points=input_points,
input_labels=input_labels,
input_boxes=batched_input_boxes,
input_points=batched_input_points,
input_labels=batched_input_labels,
return_tensors="pt",
).to(self._sam2_model.device)
# Generate masks using the SAM2 model
outputs = self._sam2_model(**inputs)
outputs = self._sam2_model(**processed_inputs)
# Post-process the masks to get the final segmentation
masks = self._sam2_processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
original_sizes=processed_inputs.original_sizes,
reshaped_input_sizes=processed_inputs.reshaped_input_sizes,
)
# There should be only one batch.

View File

@@ -1,20 +1,13 @@
from typing import Optional, TypeAlias
from typing import Optional
import torch
from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.image_util.segment_anything.shared import SAMInput
from invokeai.backend.raw_model import RawModel
# Type aliases for the inputs to the SAM model.
ListOfBoundingBoxes: TypeAlias = list[list[int]]
"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
ListOfPoints: TypeAlias = list[list[int]]
"""A list of points. Each point is in the format [x, y]."""
ListOfPointLabels: TypeAlias = list[int]
"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""
class SegmentAnythingPipeline(RawModel):
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -38,8 +31,7 @@ class SegmentAnythingPipeline(RawModel):
def segment(
self,
image: Image.Image,
bounding_boxes: list[list[int]] | None = None,
point_lists: list[list[list[int]]] | None = None,
inputs: list[SAMInput],
) -> torch.Tensor:
"""Run the SAM model.
@@ -57,36 +49,49 @@ class SegmentAnythingPipeline(RawModel):
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
# Prep the inputs:
# - Create a list of bounding boxes or points and labels.
# - Add a batch dimension of 1 to the inputs.
if bounding_boxes:
input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
input_points: list[ListOfPoints] | None = None
input_labels: list[ListOfPointLabels] | None = None
elif point_lists:
input_boxes: list[ListOfBoundingBoxes] | None = None
input_points: list[ListOfPoints] | None = []
input_labels: list[ListOfPointLabels] | None = []
for point_list in point_lists:
input_points.append([[p[0], p[1]] for p in point_list])
input_labels.append([p[2] for p in point_list])
input_boxes: list[list[list[float]]] = []
input_points: list[list[list[float]]] = []
input_labels: list[list[int]] = []
else:
raise ValueError("Either bounding_boxes or points and labels must be provided.")
for i in inputs:
box: list[float] | None = None
points: list[list[float]] | None = None
labels: list[int] | None = None
inputs = self._sam_processor(
if i.bounding_box is not None:
box: list[float] | None = [
i.bounding_box.x_min,
i.bounding_box.y_min,
i.bounding_box.x_max,
i.bounding_box.y_max,
]
if i.points is not None:
points = []
labels = []
for point in i.points:
points.append([point.x, point.y])
labels.append(point.label.value)
if box is not None:
input_boxes.append([box])
if points is not None:
input_points.append(points)
if labels is not None:
input_labels.append(labels)
processed_inputs = self._sam_processor(
images=image,
input_boxes=input_boxes,
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes if input_boxes else None,
input_points=input_points if input_points else None,
input_labels=input_labels if input_labels else None,
return_tensors="pt",
).to(self._sam_model.device)
outputs = self._sam_model(**inputs)
outputs = self._sam_model(**processed_inputs)
masks = self._sam_processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
original_sizes=processed_inputs.original_sizes,
reshaped_input_sizes=processed_inputs.reshaped_input_sizes,
)
# There should be only one batch.

View File

@@ -0,0 +1,49 @@
from enum import Enum
from pydantic import BaseModel, model_validator
from pydantic.fields import Field
class BoundingBox(BaseModel):
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
@model_validator(mode="after")
def check_coords(self):
if self.x_min > self.x_max:
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
if self.y_min > self.y_max:
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
def tuple(self) -> tuple[int, int, int, int]:
"""
Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.
This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
"""
return (self.x_min, self.y_min, self.x_max, self.y_max)
class SAMPointLabel(Enum):
negative = -1
neutral = 0
positive = 1
class SAMPoint(BaseModel):
x: int = Field(..., description="The x-coordinate of the point")
y: int = Field(..., description="The y-coordinate of the point")
label: SAMPointLabel = Field(..., description="The label of the point")
class SAMInput(BaseModel):
bounding_box: BoundingBox | None = Field(None, description="The bounding box to use for segmentation")
points: list[SAMPoint] | None = Field(None, description="The points to use for segmentation")
@model_validator(mode="after")
def check_input(self):
if not self.bounding_box and not self.points:
raise ValueError("Either bounding_box or points must be provided")
return self