mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tidy(backend) cleanup sam pipelines
This commit is contained in:
@@ -41,25 +41,19 @@ class SegmentAnything2Pipeline(RawModel):
|
||||
image: Image.Image,
|
||||
inputs: list[SAMInput],
|
||||
) -> torch.Tensor:
|
||||
"""Segment an image using the SAM2 model.
|
||||
|
||||
Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
|
||||
point_lists will be ignored.
|
||||
"""Segment the image using the provided inputs.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image to segment.
|
||||
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
|
||||
[xmin, ymin, xmax, ymax].
|
||||
point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
|
||||
`label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
|
||||
image: The image to segment.
|
||||
inputs: A list of SAMInput objects containing bounding boxes and/or point lists.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
|
||||
"""
|
||||
|
||||
input_boxes: list[list[float]] = []
|
||||
input_points: list[list[list[float]]] = []
|
||||
input_labels: list[list[int]] = []
|
||||
input_boxes: list[list[float]] = []
|
||||
|
||||
for i in inputs:
|
||||
box: list[float] | None = None
|
||||
|
||||
@@ -33,17 +33,11 @@ class SegmentAnythingPipeline(RawModel):
|
||||
image: Image.Image,
|
||||
inputs: list[SAMInput],
|
||||
) -> torch.Tensor:
|
||||
"""Run the SAM model.
|
||||
|
||||
Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
|
||||
point_lists will be ignored.
|
||||
"""Segment the image using the provided inputs.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image to segment.
|
||||
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
|
||||
[xmin, ymin, xmax, ymax].
|
||||
point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
|
||||
`label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
|
||||
image: The image to segment.
|
||||
inputs: A list of SAMInput objects containing bounding boxes and/or point lists.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
|
||||
@@ -80,11 +74,15 @@ class SegmentAnythingPipeline(RawModel):
|
||||
if labels is not None:
|
||||
input_labels.append(labels)
|
||||
|
||||
batched_input_boxes = [input_boxes] if input_boxes else None
|
||||
batched_input_points = input_points if input_points else None
|
||||
batched_input_labels = input_labels if input_labels else None
|
||||
|
||||
processed_inputs = self._sam_processor(
|
||||
images=image,
|
||||
input_boxes=[input_boxes] if input_boxes else None,
|
||||
input_points=input_points if input_points else None,
|
||||
input_labels=input_labels if input_labels else None,
|
||||
input_boxes=batched_input_boxes,
|
||||
input_points=batched_input_points,
|
||||
input_labels=batched_input_labels,
|
||||
return_tensors="pt",
|
||||
).to(self._sam_model.device)
|
||||
outputs = self._sam_model(**processed_inputs)
|
||||
|
||||
Reference in New Issue
Block a user