Mirror of https://github.com/invoke-ai/InvokeAI.git
refactor(backend): simplify segment anything APIs
There was a really confusing aspect of the SAM pipeline classes: they accepted deeply nested lists of different dimensions (bboxes, points, and labels). The lengths of these lists are related; each point must have a corresponding label, and if bboxes are provided alongside points, the two lists must have the same length. I've refactored the backend API to take a single list of SAMInput objects. Each SAMInput has a bbox and/or a list of points, making it much simpler to provide inputs of the right shape. Internally, the pipeline classes rejigger these input objects into the correctly nested lists. The Nodes still have an awkward API where you can provide bboxes and points of different lengths, so I added a pydantic validator that enforces matching lengths.
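For illustration, here is a minimal sketch (not part of the commit) of how callers provide prompts under the new API; the literal coordinates are invented, and the image/pipeline objects are assumed to already exist:

from invokeai.backend.image_util.segment_anything.shared import (
    BoundingBox,
    SAMInput,
    SAMPoint,
    SAMPointLabel,
)

# One SAMInput per object to segment: a bounding box, prompt points, or both.
inputs = [
    SAMInput(bounding_box=BoundingBox(x_min=10, y_min=20, x_max=200, y_max=180)),
    SAMInput(
        points=[
            SAMPoint(x=50, y=60, label=SAMPointLabel.positive),
            SAMPoint(x=120, y=90, label=SAMPointLabel.negative),
        ]
    ),
]

# Both pipeline classes now accept the same shape of input:
# masks = pipeline.segment(image=image, inputs=inputs)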
invokeai/app/invocations/fields.py
@@ -1,11 +1,12 @@
 from enum import Enum
 from typing import Any, Callable, Optional, Tuple

-from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
+from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
 from pydantic.fields import _Unset
 from pydantic_core import PydanticUndefined

 from invokeai.app.util.metaenum import MetaEnum
+from invokeai.backend.image_util.segment_anything.shared import BoundingBox
 from invokeai.backend.util.logging import InvokeAILogger

 logger = InvokeAILogger.get_logger()
@@ -331,14 +332,9 @@ class ConditioningField(BaseModel):
     )


-class BoundingBoxField(BaseModel):
+class BoundingBoxField(BoundingBox):
     """A bounding box primitive value."""

-    x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
-    x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
-    y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
-    y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
-
     score: Optional[float] = Field(
         default=None,
         ge=0.0,
@@ -347,21 +343,6 @@ class BoundingBoxField(BaseModel):
         "when the bounding box was produced by a detector and has an associated confidence score.",
     )
-
-    @model_validator(mode="after")
-    def check_coords(self):
-        if self.x_min > self.x_max:
-            raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
-        if self.y_min > self.y_max:
-            raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
-        return self
-
-    def tuple(self) -> Tuple[int, int, int, int]:
-        """
-        Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.
-        This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
-        """
-        return (self.x_min, self.y_min, self.x_max, self.y_max)


 class MetadataField(RootModel[dict[str, Any]]):
     """
invokeai/app/invocations/segment_anything.py
@@ -1,15 +1,15 @@
 from enum import Enum
+from itertools import zip_longest
 from pathlib import Path
 from typing import Literal

 import numpy as np
 import torch
 from PIL import Image
-from pydantic import BaseModel, Field
-from transformers import AutoProcessor
+from pydantic import BaseModel, Field, model_validator
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 from transformers.models.sam2 import Sam2Model
 from transformers.models.sam2.processing_sam2 import Sam2Processor

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
@@ -18,6 +18,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.segment_anything.segment_anything_2_pipeline import SegmentAnything2Pipeline
 from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
+from invokeai.backend.image_util.segment_anything.shared import SAMInput, SAMPoint

 SegmentAnythingModelKey = Literal[
     "segment-anything-base",
@@ -39,22 +40,10 @@ SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
 }


-class SAMPointLabel(Enum):
-    negative = -1
-    neutral = 0
-    positive = 1
-
-
-class SAMPoint(BaseModel):
-    x: int = Field(..., description="The x-coordinate of the point")
-    y: int = Field(..., description="The y-coordinate of the point")
-    label: SAMPointLabel = Field(..., description="The label of the point")
-
-
 class SAMPointsField(BaseModel):
-    points: list[SAMPoint] = Field(..., description="The points of the object")
+    points: list[SAMPoint] = Field(..., description="The points of the object", min_length=1)

-    def to_list(self) -> list[list[int]]:
+    def to_list(self) -> list[list[float]]:
         return [[point.x, point.y, point.label.value] for point in self.points]


@@ -91,14 +80,18 @@ class SegmentAnythingInvocation(BaseInvocation):
         default="all",
     )

+    @model_validator(mode="after")
+    def validate_points_and_boxes_len(self):
+        if self.point_lists is not None and self.bounding_boxes is not None:
+            if len(self.point_lists) != len(self.bounding_boxes):
+                raise ValueError("If both point_lists and bounding_boxes are provided, they must have the same length.")
+        return self
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> MaskOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

-        if self.point_lists is not None and self.bounding_boxes is not None:
-            raise ValueError("Only one of point_lists or bounding_box can be provided.")
-
         if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
             not self.point_lists or len(self.point_lists) == 0
         ):
@@ -118,91 +111,86 @@

     @staticmethod
     def _load_sam_model(model_path: Path):
-        """Load either SAM or SAM2 model based on the model path."""
-        model_path_str = str(model_path).lower()
+        sam_model = SamModel.from_pretrained(
+            model_path,
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        sam_processor = SamProcessor.from_pretrained(model_path, local_files_only=True)
+        return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)

-        if "sam2" in model_path_str:
-            # Load SAM2 model
-            try:
-                sam2_model = Sam2Model.from_pretrained(
-                    model_path,
-                    local_files_only=True,
-                    # TODO: Investigate whether fp16 is supported by SAM2
-                    # torch_dtype=TorchDevice.choose_torch_dtype(),
-                )
-            except Exception as e:
-                raise RuntimeError(f"Failed to load SAM2 model from {model_path}. Error: {str(e)}")
+    @staticmethod
+    def _load_sam_2_model(model_path: Path):
+        sam2_model = Sam2Model.from_pretrained(model_path, local_files_only=True)
+        sam2_processor = Sam2Processor.from_pretrained(model_path, local_files_only=True)
+        return SegmentAnything2Pipeline(sam2_model=sam2_model, sam2_processor=sam2_processor)

-            try:
-                sam2_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
-                # Log what type of processor we got for debugging
-                processor_type = type(sam2_processor).__name__
-                print(f"Loaded processor type: {processor_type} for model {model_path}")
-            except Exception as e:
-                raise RuntimeError(f"Failed to load processor from {model_path}. Error: {str(e)}")
+    def _get_bounding_boxes(self) -> list[list[list[float]]] | None:
+        if self.bounding_boxes is None:
+            return None
+        return [[[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]]

-            return SegmentAnything2Pipeline(sam2_model=sam2_model, sam2_processor=sam2_processor)
-        else:
-            # Load SAM model
-            sam_model = SamModel.from_pretrained(
-                model_path,
-                local_files_only=True,
-                # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-                # model, and figure out how to make it work in the pipeline.
-                # torch_dtype=TorchDevice.choose_torch_dtype(),
-            )
+    def _get_sam_points_list(self) -> list[list[list[float]]] | None:
+        if not self.point_lists:
+            return None
+        return [p.to_list() for p in self.point_lists]

-            sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
-            assert isinstance(sam_processor, SamProcessor)
-            return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
+    def _get_sam2_points_list_and_labels(
+        self,
+    ) -> tuple[list[list[list[list[float]]]], list[list[list[int]]]] | tuple[None, None]:
+        if not self.point_lists:
+            return None, None
+        point_lists: list[list[list[list[float]]]] = []
+        point_labels: list[list[list[int]]] = []
+        for point_list in self.point_lists:
+            object_points: list[list[float]] = []
+            object_labels: list[int] = []
+            for point in point_list.points:
+                object_points.append([point.x, point.y])
+                object_labels.append(point.label.value)
+            point_lists.append([object_points])
+            point_labels.append([object_labels])
+        return point_lists, point_labels

     def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
         """Use Segment Anything (SAM or SAM2) to generate masks given an image + a set of bounding boxes."""
-        # Convert the bounding boxes to the input format.
-        bounding_boxes = (
-            [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
-        )
-
-        # Convert points to the format expected by the specific model
-        # We'll determine the format based on the actual pipeline type after loading
-        if self.point_lists:
-            # Prepare both formats - we'll use the appropriate one based on pipeline type
-            # SAM2 format: [[[[x, y]]]] and [[[label]]]
-            sam2_point_lists = []
-            sam2_point_labels = []
-            for point_list in self.point_lists:
-                object_points = []
-                object_labels = []
-                for point in point_list.points:
-                    object_points.append([point.x, point.y])
-                    object_labels.append(point.label.value)
-                sam2_point_lists.append([object_points])
-                sam2_point_labels.append([object_labels])
+        is_sam_2 = "segment-anything-2" in self.model

-            # SAM format: [[x, y, label]]
-            sam_point_lists = [p.to_list() for p in self.point_lists]
-        else:
-            sam2_point_lists = None
-            sam2_point_labels = None
-            sam_point_lists = None
-
-        with (
-            context.models.load_remote_model(
-                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
-            ) as pipeline,
-        ):
-            # Check pipeline type dynamically and use appropriate point format
-            if isinstance(pipeline, SegmentAnything2Pipeline):
-                masks = pipeline.segment(
-                    image=image,
-                    bounding_boxes=bounding_boxes,
-                    point_lists=sam2_point_lists,
-                    point_labels=sam2_point_labels,
+        if is_sam_2:
+            source = SEGMENT_ANYTHING_MODEL_IDS[self.model]
+            loader = SegmentAnythingInvocation._load_sam_2_model
+            inputs: list[SAMInput] = []
+            for bbox_field, point_field in zip_longest(
+                self.bounding_boxes or [], self.point_lists or [], fillvalue=None
+            ):
+                inputs.append(
+                    SAMInput(
+                        bounding_box=bbox_field,
+                        points=point_field.points if point_field else None,
+                    )
+                )
-            elif isinstance(pipeline, SegmentAnythingPipeline):
-                masks = pipeline.segment(image=image, bounding_boxes=bounding_boxes, point_lists=sam_point_lists)
-            else:
-                raise RuntimeError(f"Unknown pipeline type: {type(pipeline)}")
+            with context.models.load_remote_model(source=source, loader=loader) as pipeline:
+                assert isinstance(pipeline, SegmentAnything2Pipeline)
+                masks = pipeline.segment(image=image, inputs=inputs)
+        else:
+            source = SEGMENT_ANYTHING_MODEL_IDS[self.model]
+            loader = SegmentAnythingInvocation._load_sam_model
+            inputs: list[SAMInput] = []
+            for bbox_field, point_field in zip_longest(
+                self.bounding_boxes or [], self.point_lists or [], fillvalue=None
+            ):
+                inputs.append(
+                    SAMInput(
+                        bounding_box=bbox_field,
+                        points=point_field.points if point_field else None,
+                    )
+                )
+            with context.models.load_remote_model(source=source, loader=loader) as pipeline:
+                assert isinstance(pipeline, SegmentAnythingPipeline)
+                masks = pipeline.segment(image=image, inputs=inputs)

         masks = self._process_masks(masks)
         if self.apply_polygon_refinement:
invokeai/backend/image_util/segment_anything/segment_anything_2_pipeline.py
@@ -1,26 +1,20 @@
-from typing import Optional, TypeAlias
+from typing import Optional

 import torch
 from PIL import Image

 # Import SAM2 components - these should be available in transformers 4.56.0+
 from transformers.models.sam2 import Sam2Model
+from transformers.models.sam2.processing_sam2 import Sam2Processor

+from invokeai.backend.image_util.segment_anything.shared import SAMInput
 from invokeai.backend.raw_model import RawModel
-
-# Type aliases for the inputs to the SAM2 model.
-ListOfBoundingBoxes: TypeAlias = list[list[int]]
-"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
-ListOfPoints: TypeAlias = list[list[list[list[int]]]]
-"""A list of points in SAM2 4D format: [[[[x, y]]]] (image_dim, object_dim, point_per_object_dim, coordinates)"""
-ListOfPointLabels: TypeAlias = list[list[list[int]]]
-"""A list of SAM2 point labels in 3D format: [[[label]]] (image_dim, object_dim, point_label)"""


 class SegmentAnything2Pipeline(RawModel):
     """A wrapper class for the transformers SAM2 model and processor that makes it compatible with the model manager."""

-    def __init__(self, sam2_model: Sam2Model, sam2_processor):
+    def __init__(self, sam2_model: Sam2Model, sam2_processor: Sam2Processor):
         """Initialize the SAM2 pipeline.

         Args:
@@ -45,9 +39,7 @@ class SegmentAnything2Pipeline(RawModel):
     def segment(
         self,
         image: Image.Image,
-        bounding_boxes: list[list[int]] | None = None,
-        point_lists: list[list[list[int]]] | None = None,
-        point_labels: list[list[int]] | None = None,
+        inputs: list[SAMInput],
     ) -> torch.Tensor:
         """Segment an image using the SAM2 model.

@@ -65,37 +57,57 @@
             torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
         """

-        # Prep the inputs for SAM2:
-        # - SAM2 expects 4D input format: [[[[x, y]]]] for points and [[[label]]] for labels
-        # - Bounding boxes remain in 2D format: [[x_min, y_min, x_max, y_max]]
-        if bounding_boxes:
-            input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
-            input_points: list[ListOfPoints] | None = None
-            input_labels: list[ListOfPointLabels] | None = None
-        elif point_lists and point_labels:
-            input_boxes: list[ListOfBoundingBoxes] | None = None
-            input_points = point_lists
-            input_labels = point_labels
-        else:
-            raise ValueError("Either bounding_boxes or (point_lists AND point_labels) must be provided.")
+        input_points: list[list[list[float]]] = []
+        input_labels: list[list[int]] = []
+        input_boxes: list[list[float]] = []

-        # Process the inputs using the SAM2 processor
-        inputs = self._sam2_processor(
+        for i in inputs:
+            box: list[float] | None = None
+            points: list[list[float]] | None = None
+            labels: list[int] | None = None
+
+            if i.bounding_box is not None:
+                box: list[float] | None = [
+                    i.bounding_box.x_min,
+                    i.bounding_box.y_min,
+                    i.bounding_box.x_max,
+                    i.bounding_box.y_max,
+                ]
+
+            if i.points is not None:
+                points = []
+                labels = []
+                for point in i.points:
+                    points.append([point.x, point.y])
+                    labels.append(point.label.value)
+
+            if box is not None:
+                input_boxes.append(box)
+            if points is not None:
+                input_points.append(points)
+            if labels is not None:
+                input_labels.append(labels)
+
+        batched_input_boxes = [input_boxes] if input_boxes else None
+        batched_input_points = [input_points] if input_points else None
+        batched_input_labels = [input_labels] if input_labels else None
+
+        processed_inputs = self._sam2_processor(
             images=image,
-            input_boxes=input_boxes,
-            input_points=input_points,
-            input_labels=input_labels,
+            input_boxes=batched_input_boxes,
+            input_points=batched_input_points,
+            input_labels=batched_input_labels,
             return_tensors="pt",
         ).to(self._sam2_model.device)

         # Generate masks using the SAM2 model
-        outputs = self._sam2_model(**inputs)
+        outputs = self._sam2_model(**processed_inputs)

         # Post-process the masks to get the final segmentation
         masks = self._sam2_processor.post_process_masks(
             masks=outputs.pred_masks,
-            original_sizes=inputs.original_sizes,
-            reshaped_input_sizes=inputs.reshaped_input_sizes,
+            original_sizes=processed_inputs.original_sizes,
+            reshaped_input_sizes=processed_inputs.reshaped_input_sizes,
         )

         # There should be only one batch.
invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py
@@ -1,20 +1,13 @@
-from typing import Optional, TypeAlias
+from typing import Optional

 import torch
 from PIL import Image
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor

+from invokeai.backend.image_util.segment_anything.shared import SAMInput
 from invokeai.backend.raw_model import RawModel
-
-# Type aliases for the inputs to the SAM model.
-ListOfBoundingBoxes: TypeAlias = list[list[int]]
-"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
-ListOfPoints: TypeAlias = list[list[int]]
-"""A list of points. Each point is in the format [x, y]."""
-ListOfPointLabels: TypeAlias = list[int]
-"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""


 class SegmentAnythingPipeline(RawModel):
     """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -38,8 +31,7 @@ class SegmentAnythingPipeline(RawModel):
     def segment(
         self,
         image: Image.Image,
-        bounding_boxes: list[list[int]] | None = None,
-        point_lists: list[list[list[int]]] | None = None,
+        inputs: list[SAMInput],
    ) -> torch.Tensor:
         """Run the SAM model.

@@ -57,36 +49,49 @@
             torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
         """

-        # Prep the inputs:
-        # - Create a list of bounding boxes or points and labels.
-        # - Add a batch dimension of 1 to the inputs.
-        if bounding_boxes:
-            input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
-            input_points: list[ListOfPoints] | None = None
-            input_labels: list[ListOfPointLabels] | None = None
-        elif point_lists:
-            input_boxes: list[ListOfBoundingBoxes] | None = None
-            input_points: list[ListOfPoints] | None = []
-            input_labels: list[ListOfPointLabels] | None = []
-            for point_list in point_lists:
-                input_points.append([[p[0], p[1]] for p in point_list])
-                input_labels.append([p[2] for p in point_list])
+        input_boxes: list[list[list[float]]] = []
+        input_points: list[list[list[float]]] = []
+        input_labels: list[list[int]] = []

-        else:
-            raise ValueError("Either bounding_boxes or points and labels must be provided.")
+        for i in inputs:
+            box: list[float] | None = None
+            points: list[list[float]] | None = None
+            labels: list[int] | None = None

-        inputs = self._sam_processor(
+            if i.bounding_box is not None:
+                box: list[float] | None = [
+                    i.bounding_box.x_min,
+                    i.bounding_box.y_min,
+                    i.bounding_box.x_max,
+                    i.bounding_box.y_max,
+                ]
+
+            if i.points is not None:
+                points = []
+                labels = []
+                for point in i.points:
+                    points.append([point.x, point.y])
+                    labels.append(point.label.value)
+
+            if box is not None:
+                input_boxes.append([box])
+            if points is not None:
+                input_points.append(points)
+            if labels is not None:
+                input_labels.append(labels)
+
+        processed_inputs = self._sam_processor(
             images=image,
-            input_boxes=input_boxes,
-            input_points=input_points,
-            input_labels=input_labels,
+            input_boxes=input_boxes if input_boxes else None,
+            input_points=input_points if input_points else None,
+            input_labels=input_labels if input_labels else None,
             return_tensors="pt",
         ).to(self._sam_model.device)
-        outputs = self._sam_model(**inputs)
+        outputs = self._sam_model(**processed_inputs)
         masks = self._sam_processor.post_process_masks(
             masks=outputs.pred_masks,
-            original_sizes=inputs.original_sizes,
-            reshaped_input_sizes=inputs.reshaped_input_sizes,
+            original_sizes=processed_inputs.original_sizes,
+            reshaped_input_sizes=processed_inputs.reshaped_input_sizes,
         )

         # There should be only one batch.
invokeai/backend/image_util/segment_anything/shared.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+from enum import Enum
+
+from pydantic import BaseModel, model_validator
+from pydantic.fields import Field
+
+
+class BoundingBox(BaseModel):
+    x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
+    x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
+    y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
+    y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
+
+    @model_validator(mode="after")
+    def check_coords(self):
+        if self.x_min > self.x_max:
+            raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
+        if self.y_min > self.y_max:
+            raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
+        return self
+
+    def tuple(self) -> tuple[int, int, int, int]:
+        """
+        Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.
+        This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
+        """
+        return (self.x_min, self.y_min, self.x_max, self.y_max)
+
+
+class SAMPointLabel(Enum):
+    negative = -1
+    neutral = 0
+    positive = 1
+
+
+class SAMPoint(BaseModel):
+    x: int = Field(..., description="The x-coordinate of the point")
+    y: int = Field(..., description="The y-coordinate of the point")
+    label: SAMPointLabel = Field(..., description="The label of the point")
+
+
+class SAMInput(BaseModel):
+    bounding_box: BoundingBox | None = Field(None, description="The bounding box to use for segmentation")
+    points: list[SAMPoint] | None = Field(None, description="The points to use for segmentation")
+
+    @model_validator(mode="after")
+    def check_input(self):
+        if not self.bounding_box and not self.points:
+            raise ValueError("Either bounding_box or points must be provided")
+        return self
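As a rough illustration of the node-side behavior described in the commit message, the sketch below pairs the node's two optional lists into SAMInput objects the way the new _segment code does with zip_longest. The literal values are invented, and plain lists of shared.py objects stand in for the node's BoundingBoxField / SAMPointsField values:

from itertools import zip_longest

from invokeai.backend.image_util.segment_anything.shared import (
    BoundingBox,
    SAMInput,
    SAMPoint,
    SAMPointLabel,
)

# Hypothetical example values standing in for the node's validated list fields.
bounding_boxes = [BoundingBox(x_min=0, y_min=0, x_max=64, y_max=64)]
point_lists = [[SAMPoint(x=32, y=32, label=SAMPointLabel.positive)]]

# Pair the two lists element-wise; a missing entry on either side becomes None,
# and SAMInput's validator ensures at least one of the two is present per input.
inputs = [
    SAMInput(bounding_box=bb, points=points)
    for bb, points in zip_longest(bounding_boxes, point_lists, fillvalue=None)
]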