mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Improve MaskOutput dimension consistency (#7591)
## Summary This PR fixes an issue with mask dimension consistency. Prior to this change, the following workflow would fail with `tuple out of range` error: <img width="1072" alt="image" src="https://github.com/user-attachments/assets/d0a9e658-1d64-4db4-adee-973bbdaca745" /> ### Before this PR Dimension compatibility for invocations that take a mask input: - `ApplyMaskTensorToImageInvocation`: 2 or 3 - `MaskTensorToImageInvocation`: 2 or 3 - `InvertTensorMaskInvocation`: 3 Mask dimension for invocations that produce a MaskOutput: - `RectangleMaskInvocation`: 3 - `AlphaMaskToTensorInvocation`: 3 - `InvertTensorMaskInvocation`: 3 - `ImageMaskToTensorInvocation`: 3 - `SegmentAnythingInvocation`: 2 ### After this PR (changes in bold) Dimension compatibility for invocations that take a mask input: - `ApplyMaskTensorToImageInvocation`: 2 or 3 - `MaskTensorToImageInvocation`: 2 or 3 - `InvertTensorMaskInvocation`: **2 or 3** <---------------- Mask dimension for invocations that produce a MaskOutput: - `RectangleMaskInvocation`: 3 - `AlphaMaskToTensorInvocation`: 3 - `InvertTensorMaskInvocation`: 3 - `ImageMaskToTensorInvocation`: 3 - `SegmentAnythingInvocation`: **3** <------------------- ## QA Instructions I tested the workflow in the PR description and this workflow: <img width="872" alt="image" src="https://github.com/user-attachments/assets/20496860-ce81-47c0-a46a-a611b73faa22" /> ## Merge Plan No special instructions. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [x] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
@@ -86,7 +86,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
title="Invert Tensor Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class InvertTensorMaskInvocation(BaseInvocation):
|
||||
@@ -96,6 +96,15 @@ class InvertTensorMaskInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
mask = context.tensors.load(self.mask.tensor_name)
|
||||
|
||||
# Verify dtype and shape.
|
||||
assert mask.dtype == torch.bool
|
||||
assert mask.dim() in [2, 3]
|
||||
|
||||
# Unsqueeze the channel dimension if it is missing. The MaskOutput type expects a single channel.
|
||||
if mask.dim() == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
inverted = ~mask
|
||||
|
||||
return MaskOutput(
|
||||
|
||||
@@ -416,6 +416,7 @@ class ColorInvocation(BaseInvocation):
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""A torch mask tensor."""
|
||||
|
||||
# shape: [1, H, W], dtype: bool
|
||||
mask: TensorField = OutputField(description="The mask.")
|
||||
width: int = OutputField(description="The width of the mask in pixels.")
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
@@ -49,7 +49,7 @@ class SAMPointsField(BaseModel):
|
||||
title="Segment Anything",
|
||||
tags=["prompt", "segmentation"],
|
||||
category="segmentation",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class SegmentAnythingInvocation(BaseInvocation):
|
||||
"""Runs a Segment Anything Model."""
|
||||
@@ -96,8 +96,10 @@ class SegmentAnythingInvocation(BaseInvocation):
|
||||
# masks contains bool values, so we merge them via max-reduce.
|
||||
combined_mask, _ = torch.stack(masks).max(dim=0)
|
||||
|
||||
# Unsqueeze the channel dimension.
|
||||
combined_mask = combined_mask.unsqueeze(0)
|
||||
mask_tensor_name = context.tensors.save(combined_mask)
|
||||
height, width = combined_mask.shape
|
||||
_, height, width = combined_mask.shape
|
||||
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -7716,11 +7716,6 @@ export type components = {
|
||||
* @description Gets the bounding box of the given mask image.
|
||||
*/
|
||||
GetMaskBoundingBoxInvocation: {
|
||||
/**
|
||||
* @description Optional metadata to be saved with the image
|
||||
* @default null
|
||||
*/
|
||||
metadata?: components["schemas"]["MetadataField"] | null;
|
||||
/**
|
||||
* Id
|
||||
* @description The id of this instance of an invocation. Must be unique among all instances of invocations.
|
||||
|
||||
Reference in New Issue
Block a user