Files
InvokeAI/invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py

95 lines
3.9 KiB
Python

from typing import Optional, TypeAlias
import torch
from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.raw_model import RawModel
# Type aliases for the inputs to the SAM model.
ListOfBoundingBoxes: TypeAlias = list[list[int]]
"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
ListOfPoints: TypeAlias = list[list[int]]
"""A list of points. Each point is in the format [x, y]."""
ListOfPointLabels: TypeAlias = list[int]
"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""
class SegmentAnythingPipeline(RawModel):
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
self._sam_model = sam_model
self._sam_processor = sam_processor
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._sam_model.to(device=device, dtype=dtype)
def calc_size(self) -> int:
# HACK(ryand): Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._sam_model)
def segment(
self,
image: Image.Image,
bounding_boxes: list[list[int]] | None = None,
point_lists: list[list[list[int]]] | None = None,
) -> torch.Tensor:
"""Run the SAM model.
Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
point_lists will be ignored.
Args:
image (Image.Image): The image to segment.
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
[xmin, ymin, xmax, ymax].
point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
`label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
# Prep the inputs:
# - Create a list of bounding boxes or points and labels.
# - Add a batch dimension of 1 to the inputs.
if bounding_boxes:
input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
input_points: list[ListOfPoints] | None = None
input_labels: list[ListOfPointLabels] | None = None
elif point_lists:
input_boxes: list[ListOfBoundingBoxes] | None = None
input_points: list[ListOfPoints] | None = []
input_labels: list[ListOfPointLabels] | None = []
for point_list in point_lists:
input_points.append([[p[0], p[1]] for p in point_list])
input_labels.append([p[2] for p in point_list])
else:
raise ValueError("Either bounding_boxes or points and labels must be provided.")
inputs = self._sam_processor(
images=image,
input_boxes=input_boxes,
input_points=input_points,
input_labels=input_labels,
return_tensors="pt",
).to(self._sam_model.device)
outputs = self._sam_model(**inputs)
masks = self._sam_processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
)
# There should be only one batch.
assert len(masks) == 1
return masks[0]