feat(nodes): update pidinet node

Human-readable field names.
This commit is contained in:
psychedelicious
2024-09-11 17:43:55 +10:00
committed by Kent Keirsey
parent a4250e3ff2
commit ee4c0efbf7
2 changed files with 4 additions and 4 deletions

View File

@@ -17,7 +17,7 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an edge map using PiDiNet.""" """Generates an edge map using PiDiNet."""
image: ImageField = InputField(description="The image to process") image: ImageField = InputField(description="The image to process")
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) quantize_edges: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -27,7 +27,7 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
with loaded_model as model: with loaded_model as model:
assert isinstance(model, PiDiNet) assert isinstance(model, PiDiNet)
detector = PIDINetDetector(model) detector = PIDINetDetector(model)
edge_map = detector.run(image=image, safe=self.safe, scribble=self.scribble) edge_map = detector.run(image=image, quantize_edges=self.quantize_edges, scribble=self.scribble)
image_dto = context.images.save(image=edge_map) image_dto = context.images.save(image=edge_map)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)

View File

@@ -41,7 +41,7 @@ class PIDINetDetector:
return self return self
def run( def run(
self, image: Image.Image, safe: bool = False, scribble: bool = False, apply_filter: bool = False self, image: Image.Image, quantize_edges: bool = False, scribble: bool = False, apply_filter: bool = False
) -> Image.Image: ) -> Image.Image:
"""Processes an image and returns the detected edges.""" """Processes an image and returns the detected edges."""
@@ -62,7 +62,7 @@ class PIDINetDetector:
edge = edge.cpu().numpy() edge = edge.cpu().numpy()
if apply_filter: if apply_filter:
edge = edge > 0.5 edge = edge > 0.5
if safe: if quantize_edges:
edge = safe_step(edge) edge = safe_step(edge)
edge = (edge * 255.0).clip(0, 255).astype(np.uint8) edge = (edge * 255.0).clip(0, 255).astype(np.uint8)