tidy(backend) cleanup sam pipelines

This commit is contained in:
psychedelicious
2025-09-10 19:25:53 +10:00
parent 03ae78bc7c
commit f8ad62b5eb
2 changed files with 14 additions and 22 deletions

View File

@@ -41,25 +41,19 @@ class SegmentAnything2Pipeline(RawModel):
image: Image.Image,
inputs: list[SAMInput],
) -> torch.Tensor:
"""Segment an image using the SAM2 model.
Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
point_lists will be ignored.
"""Segment the image using the provided inputs.
Args:
image (Image.Image): The image to segment.
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
[xmin, ymin, xmax, ymax].
point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
`label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
image: The image to segment.
inputs: A list of SAMInput objects containing bounding boxes and/or point lists.
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
input_boxes: list[list[float]] = []
input_points: list[list[list[float]]] = []
input_labels: list[list[int]] = []
input_boxes: list[list[float]] = []
for i in inputs:
box: list[float] | None = None

View File

@@ -33,17 +33,11 @@ class SegmentAnythingPipeline(RawModel):
image: Image.Image,
inputs: list[SAMInput],
) -> torch.Tensor:
"""Run the SAM model.
Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
point_lists will be ignored.
"""Segment the image using the provided inputs.
Args:
image (Image.Image): The image to segment.
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
[xmin, ymin, xmax, ymax].
point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
`label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
image: The image to segment.
inputs: A list of SAMInput objects containing bounding boxes and/or point lists.
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
@@ -80,11 +74,15 @@ class SegmentAnythingPipeline(RawModel):
if labels is not None:
input_labels.append(labels)
batched_input_boxes = [input_boxes] if input_boxes else None
batched_input_points = input_points if input_points else None
batched_input_labels = input_labels if input_labels else None
processed_inputs = self._sam_processor(
images=image,
input_boxes=[input_boxes] if input_boxes else None,
input_points=input_points if input_points else None,
input_labels=input_labels if input_labels else None,
input_boxes=batched_input_boxes,
input_points=batched_input_points,
input_labels=batched_input_labels,
return_tensors="pt",
).to(self._sam_model.device)
outputs = self._sam_model(**processed_inputs)