fix(backend): issue w/ multiple bbox and sam1

This commit is contained in:
psychedelicious
2025-09-10 19:21:59 +10:00
parent 9e4d441e2e
commit ec1a058dbe

View File

@@ -49,7 +49,7 @@ class SegmentAnythingPipeline(RawModel):
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
input_boxes: list[list[list[float]]] = []
input_boxes: list[list[float]] = []
input_points: list[list[list[float]]] = []
input_labels: list[list[int]] = []
@@ -74,7 +74,7 @@ class SegmentAnythingPipeline(RawModel):
labels.append(point.label.value)
if box is not None:
input_boxes.append([box])
input_boxes.append(box)
if points is not None:
input_points.append(points)
if labels is not None:
@@ -82,7 +82,7 @@ class SegmentAnythingPipeline(RawModel):
processed_inputs = self._sam_processor(
images=image,
input_boxes=input_boxes if input_boxes else None,
input_boxes=[input_boxes] if input_boxes else None,
input_points=input_points if input_points else None,
input_labels=input_labels if input_labels else None,
return_tensors="pt",