Files
InvokeAI/invokeai/backend/image_util/pidi/__init__.py

81 lines
2.5 KiB
Python

# Adapted from https://github.com/huggingface/controlnet_aux
import pathlib
import cv2
import huggingface_hub
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class PIDINetDetector:
    """Simple wrapper around a PiDiNet model for edge detection.

    The wrapper stores an already-constructed ``PiDiNet`` instance. Use
    :meth:`load_model` to build one from a checkpoint file and :meth:`run` to
    turn a PIL image into an edge-map image.
    """

    # Hugging Face Hub repository and file name used to build the download URL.
    hf_repo_id = "lllyasviel/Annotators"
    hf_filename = "table5_pidinet.pth"

    @classmethod
    def get_model_url(cls) -> str:
        """Get the URL to download the model from the Hugging Face Hub."""
        return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename)

    @classmethod
    def load_model(cls, model_path: pathlib.Path) -> PiDiNet:
        """Load the model from a file.

        Args:
            model_path: Path to a checkpoint file whose top-level
                ``"state_dict"`` entry holds the model weights.

        Returns:
            A ``PiDiNet`` with the checkpoint weights loaded, switched to eval mode.
        """
        model = pidinet()

        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from a trusted source.
        # map_location="cpu" deserializes every tensor onto the CPU regardless of
        # the device it was saved from, so loading also works on machines that
        # lack that device. load_state_dict copies the values into the model's
        # existing parameters, so the resulting model is unaffected.
        checkpoint = torch.load(model_path, map_location="cpu")

        # Drop "module." from the parameter names so they match this model's keys
        # (presumably a DataParallel wrapper prefix -- TODO confirm).
        state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}

        model.load_state_dict(state_dict)
        model.eval()
        return model

    def __init__(self, model: PiDiNet) -> None:
        """Store the PiDiNet model that :meth:`run` will call."""
        self.model = model

    def to(self, device: torch.device) -> "PIDINetDetector":
        """Move the wrapped model to ``device`` and return ``self`` for chaining."""
        self.model.to(device)
        return self

    def run(
        self, image: Image.Image, quantize_edges: bool = False, scribble: bool = False, apply_filter: bool = False
    ) -> Image.Image:
        """Processes an image and returns the detected edges.

        Args:
            image: The input image.
            quantize_edges: If True, pass the edge map through ``safe_step``.
            scribble: If True, post-process the edge map with ``nms``, a Gaussian
                blur, and a final binarization to the values 0 and 255.
            apply_filter: If True, threshold the raw model output at 0.5 before
                any other post-processing.

        Returns:
            The edge map as a PIL image built from a uint8 array.
        """
        device = get_effective_device(self.model)

        np_img = pil_to_np(image)
        np_img = normalize_image_channel_count(np_img)

        # Sanity check: the channel indexing and rearrange below need an (h, w, c) array.
        assert np_img.ndim == 3

        # Reverse the channel axis; .copy() gives torch.from_numpy a contiguous
        # array (it cannot wrap a negatively-strided view).
        bgr_img = np_img[:, :, ::-1].copy()

        with torch.no_grad():
            image_pidi = torch.from_numpy(bgr_img).float().to(device)
            image_pidi = image_pidi / 255.0
            image_pidi = rearrange(image_pidi, "h w c -> 1 c h w")
            # The model returns a sequence of outputs; only the last one is used.
            edge = self.model(image_pidi)[-1]
            edge = edge.cpu().numpy()

        if apply_filter:
            edge = edge > 0.5
        if quantize_edges:
            edge = safe_step(edge)

        # Scale to the 8-bit range and clamp before the uint8 cast.
        edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        # Take the first channel of the first batch item as a 2-D map.
        detected_map = edge[0, 0]

        if scribble:
            detected_map = nms(detected_map, 127, 3.0)
            detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
            # Binarize: anything above 4 after blurring becomes 255, the rest 0.
            detected_map[detected_map > 4] = 255
            detected_map[detected_map < 255] = 0

        output_img = np_to_pil(detected_map)
        return output_img