from typing import Optional, TypeAlias

import torch
from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor

from invokeai.backend.raw_model import RawModel

# Type aliases for the inputs to the SAM model.
ListOfBoundingBoxes: TypeAlias = list[list[int]]
"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""

ListOfPoints: TypeAlias = list[list[int]]
"""A list of points. Each point is in the format [x, y]."""

ListOfPointLabels: TypeAlias = list[int]
"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""
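# Illustrative values for the aliases above (the numbers are examples for documentation only, not taken from the
# source and not used at runtime):
#   bounding_boxes: ListOfBoundingBoxes = [[10, 20, 200, 240]]  # one box: [xmin, ymin, xmax, ymax]
#   points: ListOfPoints = [[50, 60], [120, 80]]  # two points: [x, y]
#   point_labels: ListOfPointLabels = [1, -1]  # one foreground point, one background point
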
class SegmentAnythingPipeline(RawModel):
    """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""

    def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
        self._sam_model = sam_model
        self._sam_processor = sam_processor

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        # HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
        if device is not None and device.type not in {"cpu", "cuda"}:
            device = None
        self._sam_model.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        # HACK(ryand): Fix the circular import issue.
        from invokeai.backend.model_manager.load.model_util import calc_module_size

        return calc_module_size(self._sam_model)

    def segment(
        self,
        image: Image.Image,
        bounding_boxes: list[list[int]] | None = None,
        point_lists: list[list[list[int]]] | None = None,
    ) -> torch.Tensor:
        """Run the SAM model.

        Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
        point_lists will be ignored.

        Args:
            image (Image.Image): The image to segment.
            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
                [xmin, ymin, xmax, ymax].
            point_lists (list[list[list[int]]]): The point prompts. Each point is in the format [x, y, label].
                `label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.

        Returns:
            torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
        """

        # Prep the inputs:
        # - Create a list of bounding boxes or points and labels.
        # - Add a batch dimension of 1 to the inputs.
        if bounding_boxes:
            input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
            input_points: list[ListOfPoints] | None = None
            input_labels: list[ListOfPointLabels] | None = None
        elif point_lists:
            input_boxes: list[ListOfBoundingBoxes] | None = None
            input_points: list[ListOfPoints] | None = []
            input_labels: list[ListOfPointLabels] | None = []
            for point_list in point_lists:
                input_points.append([[p[0], p[1]] for p in point_list])
                input_labels.append([p[2] for p in point_list])
        else:
            raise ValueError("Either bounding_boxes or points and labels must be provided.")
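        # For illustration (example numbers assumed, not from the source): bounding_boxes=[[10, 20, 200, 240]]
        # yields input_boxes=[[[10, 20, 200, 240]]], while point_lists=[[[50, 60, 1], [120, 80, -1]]] yields
        # input_points=[[[50, 60], [120, 80]]] and input_labels=[[1, -1]].
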
        inputs = self._sam_processor(
            images=image,
            input_boxes=input_boxes,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(self._sam_model.device)
        outputs = self._sam_model(**inputs)
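        # `outputs.pred_masks` holds low-resolution mask logits; `post_process_masks` upsamples them to the original
        # image size and, with its default settings, binarizes them to boolean masks.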
        masks = self._sam_processor.post_process_masks(
            masks=outputs.pred_masks,
            original_sizes=inputs.original_sizes,
            reshaped_input_sizes=inputs.reshaped_input_sizes,
        )

        # There should be only one batch.
        assert len(masks) == 1
        return masks[0]
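
# Usage sketch (illustrative only; the checkpoint name and image path below are assumptions, not part of this module):
#
#   from PIL import Image
#   from transformers.models.sam import SamModel
#   from transformers.models.sam.processing_sam import SamProcessor
#
#   sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
#   sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
#   pipeline = SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
#   masks = pipeline.segment(image=Image.open("photo.png"), bounding_boxes=[[10, 20, 200, 240]])
#   # masks: torch.bool tensor of shape [num_masks, channels, height, width]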