mirror of https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-22 01:58:01 -05:00

Compare commits: controlnet...lstein/cho (9 commits)

- 3564c8f28c
- ba8f06c285
- a40fa8e83b
- 1e357bd21b
- cac36d9327
- 08d7bd2a0b
- c6dcbce043
- af274bedc1
- b000bc2f58
```diff
@@ -1607,7 +1607,7 @@ model configuration to `load_model_by_config()`. It may raise a
 Within invocations, the following methods are available from the
 `InvocationContext` object:
 
-### context.download_and_cache_model(source) -> Path
+### context.download_and_cache_model(source, [preserve_subfolders=False]) -> Path
 
 This method accepts a `source` of a remote model, downloads and caches
 it locally, and then returns a Path to the local model. The source can
@@ -1626,6 +1626,16 @@ directory using this syntax:
 
 * stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
 
+When requesting a huggingface repo, if the requested file(s) live in a
+nested subfolder, the nesting information will be discarded and the
+file(s) will be placed in the top level of the returned
+directory. Thus, when requesting
+`stabilityai/stable-diffusion-v4::vae`, the contents of `vae` will be
+found at the top level of the returned path and not in a subdirectory.
+This behavior can be changed by passing `preserve_subfolders=True`,
+which will preserve the subfolder structure and return the path to the
+subdirectory.
+
 ### context.load_local_model(model_path, [loader]) -> LoadedModel
 
 This method loads a local model from the indicated path, returning a
```
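To make the flattened vs. preserved behavior concrete, here is a minimal sketch of calling the documented method from inside an invocation. It assumes only the `InvocationContext` (`context`) that `invoke()` receives; the `stabilityai/sdxl-turbo::vae` source string is taken from the tests at the end of this diff.

```python
from pathlib import Path

from invokeai.app.services.shared.invocation_context import InvocationContext


def fetch_vae(context: InvocationContext) -> Path:
    # Default: nesting is discarded, so the files stored under "vae" in the
    # repo come back at the top level of the returned directory.
    flat: Path = context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")

    # preserve_subfolders=True keeps the hierarchy; the returned path is the
    # "vae" subdirectory itself.
    nested: Path = context.models.download_and_cache_model(
        "stabilityai/sdxl-turbo::vae", preserve_subfolders=True
    )
    assert nested.name == "vae" and flat.name != "vae"
    return nested
```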
```diff
@@ -3,7 +3,7 @@
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
 from pathlib import Path
-from typing import Dict, List, Literal, Union
+from typing import ClassVar, Dict, List, Literal, Union
 
 import cv2
 import numpy as np
@@ -46,10 +46,11 @@ from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNE
 from invokeai.backend.image_util.canny import get_canny_edges
 from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
 from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
-from invokeai.backend.image_util.hed import HEDProcessor
-from invokeai.backend.image_util.lineart import LineartProcessor
-from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
+from invokeai.backend.image_util.hed import HED_MODEL, HEDProcessor
+from invokeai.backend.image_util.lineart import COARSE_MODEL, LINEART_MODEL, LineartProcessor
+from invokeai.backend.image_util.lineart_anime import LINEART_ANIME_MODEL, LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
+from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
 from invokeai.backend.util.devices import TorchDevice
@@ -137,6 +138,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     image: ImageField = InputField(description="The image to process")
 
+    # Map controlnet_aux detector classes to model files in "lllyasviel/Annotators"
+    CONTROLNET_PROCESSORS: ClassVar[Dict[type, str]] = {
+        MidasDetector: "lllyasviel/Annotators::/dpt_hybrid-midas-501f0c75.pt",
+        MLSDdetector: "lllyasviel/Annotators::/mlsd_large_512_fp32.pth",
+        PidiNetDetector: "lllyasviel/Annotators::/table5_pidinet.pth",
+        ZoeDetector: "lllyasviel/Annotators::/ZoeD_M12_N.pt",
+    }
+
     def run_processor(self, image: Image.Image) -> Image.Image:
         # superclass just passes through image without processing
         return image
@@ -145,6 +154,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
         # allows override for any special formatting specific to the preprocessor
         return context.images.get_pil(self.image.image_name, "RGB")
 
+    def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
+        remote_source = self.CONTROLNET_PROCESSORS[processor]
+        assert hasattr(processor, "from_pretrained")  # no common base class for the controlnet processors!
+        model = self._context.models.load_remote_model(
+            source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name)
+        )
+        return model
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         self._context = context
         raw_image = self.load_image(context)
```
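The `loader` lambda in the hunk above is the glue between the new mapping and `controlnet_aux`: `load_remote_model()` downloads the checkpoint and then calls the loader with its local `Path`, and the detectors' `from_pretrained()` accepts a local directory plus a `filename` keyword, so no second trip to the Hub is needed. A standalone sketch of an equivalent loader (treat the exact `controlnet_aux` signature as an assumption taken from the lambda above); the hunks below show the same pattern applied per-processor:

```python
from pathlib import Path

from controlnet_aux import MidasDetector  # assumed import location


def midas_loader(checkpoint: Path) -> MidasDetector:
    # load_remote_model() invokes this with the cached checkpoint's path;
    # pointing from_pretrained() at the parent directory makes it read the
    # already-downloaded file instead of fetching from the HuggingFace Hub.
    return MidasDetector.from_pretrained(checkpoint.parent, filename=checkpoint.name)
```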
```diff
@@ -218,15 +235,18 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        hed_processor = HEDProcessor()
-        processed_image = hed_processor.run(
-            image,
-            detect_resolution=self.detect_resolution,
-            image_resolution=self.image_resolution,
-            # safe not supported in controlnet_aux v0.0.3
-            # safe=self.safe,
-            scribble=self.scribble,
-        )
+        hed_weights = self._context.models.load_remote_model(HED_MODEL)
+        with hed_weights as weights:
+            assert isinstance(weights, dict)
+            hed_processor = HEDProcessor(weights)
+            processed_image = hed_processor.run(
+                image,
+                detect_resolution=self.detect_resolution,
+                image_resolution=self.image_resolution,
+                # safe not supported in controlnet_aux v0.0.3
+                # safe=self.safe,
+                scribble=self.scribble,
+            )
         return processed_image
 
 
@@ -245,10 +265,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
     coarse: bool = InputField(default=False, description="Whether to use coarse mode")
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        lineart_processor = LineartProcessor()
-        processed_image = lineart_processor.run(
-            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
-        )
+        model_info = self._context.models.load_remote_model(LINEART_MODEL)
+        model_coarse_info = self._context.models.load_remote_model(COARSE_MODEL)
+        with model_info as model_sd, model_coarse_info as coarse_sd:
+            lineart_processor = LineartProcessor(model_sd=model_sd, coarse_sd=coarse_sd)
+            processed_image = lineart_processor.run(
+                image,
+                detect_resolution=self.detect_resolution,
+                image_resolution=self.image_resolution,
+                coarse=self.coarse,
+            )
         return processed_image
 
 
@@ -266,12 +292,13 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        processor = LineartAnimeProcessor()
-        processed_image = processor.run(
-            image,
-            detect_resolution=self.detect_resolution,
-            image_resolution=self.image_resolution,
-        )
+        with self._context.models.load_remote_model(LINEART_ANIME_MODEL) as model_sd:
+            processor = LineartAnimeProcessor(model_sd)
+            processed_image = processor.run(
+                image,
+                detect_resolution=self.detect_resolution,
+                image_resolution=self.image_resolution,
+            )
         return processed_image
 
 
@@ -293,17 +320,17 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
-        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = midas_processor(
-            image,
-            a=np.pi * self.a_mult,
-            bg_th=self.bg_th,
-            image_resolution=self.image_resolution,
-            detect_resolution=self.detect_resolution,
-            # dept_and_normal not supported in controlnet_aux v0.0.3
-            # depth_and_normal=self.depth_and_normal,
-        )
+        with self.load_processor(MidasDetector) as midas_processor:
+            assert isinstance(midas_processor, MidasDetector)
+            processed_image = midas_processor(
+                image,
+                a=np.pi * self.a_mult,
+                bg_th=self.bg_th,
+                image_resolution=self.image_resolution,
+                detect_resolution=self.detect_resolution,
+                # dept_and_normal not supported in controlnet_aux v0.0.3
+                # depth_and_normal=self.depth_and_normal,
+            )
         return processed_image
 
 
@@ -321,10 +348,11 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
    def run_processor(self, image: Image.Image) -> Image.Image:
-        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = normalbae_processor(
-            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
-        )
+        with self.load_processor(NormalBaeDetector) as normalbae_processor:
+            assert isinstance(normalbae_processor, NormalBaeDetector)
+            processed_image = normalbae_processor(
+                image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
+            )
         return processed_image
 
 
@@ -340,14 +368,15 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = mlsd_processor(
-            image,
-            detect_resolution=self.detect_resolution,
-            image_resolution=self.image_resolution,
-            thr_v=self.thr_v,
-            thr_d=self.thr_d,
-        )
+        with self.load_processor(MLSDdetector) as mlsd_processor:
+            assert isinstance(mlsd_processor, MLSDdetector)
+            processed_image = mlsd_processor(
+                image,
+                detect_resolution=self.detect_resolution,
+                image_resolution=self.image_resolution,
+                thr_v=self.thr_v,
+                thr_d=self.thr_d,
+            )
         return processed_image
 
 
@@ -363,14 +392,15 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = pidi_processor(
-            image,
-            detect_resolution=self.detect_resolution,
-            image_resolution=self.image_resolution,
-            safe=self.safe,
-            scribble=self.scribble,
-        )
+        with self.load_processor(PidiNetDetector) as pidi_processor:
+            assert isinstance(pidi_processor, PidiNetDetector)
+            processed_image = pidi_processor(
+                image,
+                detect_resolution=self.detect_resolution,
+                image_resolution=self.image_resolution,
+                safe=self.safe,
+                scribble=self.scribble,
+            )
         return processed_image
 
 
@@ -415,8 +445,9 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = zoe_depth_processor(image)
+        with self.load_processor(ZoeDetector) as zoe_depth_processor:
+            assert isinstance(zoe_depth_processor, ZoeDetector)
+            processed_image: Image.Image = zoe_depth_processor(image)
         return processed_image
 
 
@@ -464,6 +495,8 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
+        # LeresDetector requires two hard-coded models, which breaks the load_processor() pattern.
+        # TODO (LS): Modify download_and_cache() to accept multiple downloaded checkpoint files.
         leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = leres_processor(
             image,
@@ -530,14 +563,17 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
-        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
-            "ybelkada/segment-anything", subfolder="checkpoints"
-        )
-        np_img = np.array(image, dtype=np.uint8)
-        processed_image = segment_anything_processor(
-            np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
-        )
+        model_path = self._context.models.download_and_cache_model(
+            source="ybelkada/segment-anything::/checkpoints/sam_vit_h_4b8939.pth", preserve_subfolders=True
+        )
+        with self._context.models.load_local_model(
+            model_path, loader=lambda x: SamDetectorReproducibleColors.from_pretrained(x)
+        ) as segment_anything_processor:
+            assert isinstance(segment_anything_processor, SamDetectorReproducibleColors)
+            np_img = np.array(image, dtype=np.uint8)
+            processed_image: Image.Image = segment_anything_processor(
+                np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
+            )
         return processed_image
```
```diff
@@ -243,11 +243,17 @@ class ModelInstallServiceBase(ABC):
         """
 
     @abstractmethod
-    def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
+    def download_and_cache_model(
+        self,
+        source: str | AnyHttpUrl,
+        preserve_subfolders: bool = False,
+    ) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.
 
         :param source: A string representing a URL or repo_id.
+        :param preserve_subfolders: (optional) If True, the subfolder hierarchy will be preserved;
+            otherwise flattened.
 
         The model file will be downloaded into the system-wide model cache
         (`models/.cache`) if it isn't already there. Note that the model cache
@@ -372,8 +372,10 @@ class ModelInstallService(ModelInstallServiceBase):
-    def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
+    def download_and_cache_model(
+        self,
+        source: str | AnyHttpUrl,
+        preserve_subfolders: bool = False,
+    ) -> Path:
         """Download the model file located at source to the models cache and return its Path."""
-        model_source = self._guess_source(str(source))
         model_path = self._download_cache_path(str(source), self._app_config)
 
         # We expect the cache directory to contain one and only one downloaded file or directory.
@@ -385,12 +387,12 @@ class ModelInstallService(ModelInstallServiceBase):
             return contents[0]
 
         model_path.mkdir(parents=True, exist_ok=True)
+        model_source = self._guess_source(str(source))
         remote_files, _ = self._remote_files_from_source(model_source)
         job = self._multifile_download(
             dest=model_path,
             remote_files=remote_files,
             subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
+            preserve_subfolders=preserve_subfolders,
         )
         files_string = "file" if len(remote_files) == 1 else "files"
         self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
@@ -772,15 +774,22 @@ class ModelInstallService(ModelInstallServiceBase):
         subfolder: Optional[Path] = None,
         access_token: Optional[str] = None,
         submit_job: bool = True,
+        preserve_subfolders: bool = False,
     ) -> MultiFileDownloadJob:
         # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
         # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
         # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
         # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
+        # The exception is when preserve_subfolders is true, in which case we keep the hierarchy
+        # of subfolders and return the path to the last enclosing subfolder.
         if subfolder:
-            top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
-            path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae/
-            path_to_add = Path(f"{top}_{subfolder}")
+            if preserve_subfolders:
+                path_to_remove = remote_files[0].path.parts[0]
+                path_to_add = Path(".")
+            else:
+                top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
+                path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae/
+                path_to_add = Path(f"{top}_{subfolder}")
         else:
             path_to_remove = Path(".")
             path_to_add = Path(".")
```
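The path arithmetic in the comment above is easier to see with concrete values. A self-contained sketch using the `sdxl-turbo`/`vae` example from the comment (the checkpoint file name is illustrative):

```python
from pathlib import Path

remote_file = Path("sdxl-turbo/vae/diffusion_pytorch_model.safetensors")  # illustrative
subfolder = Path("vae")
top = Path(remote_file.parts[0])  # "sdxl-turbo"

# Default (flattened): strip "sdxl-turbo/vae" and re-root under the
# synthesized "sdxl-turbo_vae" folder.
path_to_remove = top / subfolder.parts[-1]
path_to_add = Path(f"{top}_{subfolder}")
print(path_to_add / remote_file.relative_to(path_to_remove))
# sdxl-turbo_vae/diffusion_pytorch_model.safetensors

# preserve_subfolders=True: strip only the repo folder, so "vae/" survives
# and the caller gets a path ending in the subfolder itself.
print(remote_file.relative_to(top))
# vae/diffusion_pytorch_model.safetensors
```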
```diff
@@ -447,6 +447,7 @@ class ModelsInterface(InvocationContextInterface):
     def download_and_cache_model(
         self,
         source: str | AnyHttpUrl,
+        preserve_subfolders: bool = False,
     ) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.
@@ -457,11 +458,14 @@ class ModelsInterface(InvocationContextInterface):
 
         Args:
             source: A URL that points to the model, or a huggingface repo_id.
+            preserve_subfolders: (optional, False) If True, then preserve subfolder structure.
 
         Returns:
-            Path to the downloaded model
+            Path to the downloaded model (file or directory)
         """
-        return self._services.model_manager.install.download_and_cache_model(source=source)
+        return self._services.model_manager.install.download_and_cache_model(
+            source=source, preserve_subfolders=preserve_subfolders
+        )
 
     def load_local_model(
         self,
```
```diff
@@ -1,10 +1,11 @@
 """Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
 
+from typing import Dict
+
 import cv2
 import numpy as np
 import torch
 from einops import rearrange
-from huggingface_hub import hf_hub_download
 from PIL import Image
 
 from invokeai.backend.image_util.util import (
@@ -16,6 +17,8 @@ from invokeai.backend.image_util.util import (
     safe_step,
 )
 
+HED_MODEL = "lllyasviel/Annotators::/ControlNetHED.pth"
+
 
 class DoubleConvBlock(torch.nn.Module):
     def __init__(self, input_channel, output_channel, layer_number):
@@ -76,16 +79,11 @@ class HEDProcessor:
     On instantiation, loads the HED model from the HuggingFace Hub.
     """
 
-    def __init__(self):
-        model_path = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
+    def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.network = ControlNetHED_Apache2()
-        self.network.load_state_dict(torch.load(model_path, map_location="cpu"))
+        self.network.load_state_dict(state_dict)
         self.network.float().eval()
 
     def to(self, device: torch.device):
         self.network.to(device)
         return self
 
     def run(
         self,
         input_image: Image.Image,
```
```diff
@@ -1,11 +1,12 @@
 """Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
 
+from typing import Dict
+
 import cv2
 import numpy as np
 import torch
 import torch.nn as nn
 from einops import rearrange
-from huggingface_hub import hf_hub_download
 from PIL import Image
 
 from invokeai.backend.image_util.util import (
@@ -15,6 +16,9 @@ from invokeai.backend.image_util.util import (
     resize_image_to_resolution,
 )
 
+LINEART_MODEL = "lllyasviel/Annotators::/sk_model.pth"
+COARSE_MODEL = "lllyasviel/Annotators::/sk_model2.pth"
+
 
 class ResidualBlock(nn.Module):
     def __init__(self, in_features):
@@ -97,22 +101,15 @@ class Generator(nn.Module):
 class LineartProcessor:
     """Processor for lineart detection."""
 
-    def __init__(self):
-        model_path = hf_hub_download("lllyasviel/Annotators", "sk_model.pth")
+    def __init__(self, model_sd: Dict[str, torch.Tensor], coarse_sd: Dict[str, torch.Tensor]):
         self.model = Generator(3, 1, 3)
-        self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+        self.model.load_state_dict(model_sd)
         self.model.eval()
 
-        coarse_model_path = hf_hub_download("lllyasviel/Annotators", "sk_model2.pth")
         self.model_coarse = Generator(3, 1, 3)
-        self.model_coarse.load_state_dict(torch.load(coarse_model_path, map_location=torch.device("cpu")))
+        self.model_coarse.load_state_dict(coarse_sd)
         self.model_coarse.eval()
 
     def to(self, device: torch.device):
         self.model.to(device)
         self.model_coarse.to(device)
         return self
 
     def run(
         self, input_image: Image.Image, coarse: bool = False, detect_resolution: int = 512, image_resolution: int = 512
     ) -> Image.Image:
```
```diff
@@ -1,14 +1,13 @@
 """Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
 
 import functools
-from typing import Optional
+from typing import Dict, Optional
 
 import cv2
 import numpy as np
 import torch
 import torch.nn as nn
 from einops import rearrange
-from huggingface_hub import hf_hub_download
 from PIL import Image
 
 from invokeai.backend.image_util.util import (
@@ -18,6 +17,8 @@ from invokeai.backend.image_util.util import (
     resize_image_to_resolution,
 )
 
+LINEART_ANIME_MODEL = "lllyasviel/Annotators::/netG.pth"
+
 
 class UnetGenerator(nn.Module):
     """Create a Unet-based generator"""
@@ -142,16 +143,14 @@ class UnetSkipConnectionBlock(nn.Module):
 class LineartAnimeProcessor:
     """Processes an image to detect lineart."""
 
-    def __init__(self):
-        model_path = hf_hub_download("lllyasviel/Annotators", "netG.pth")
+    def __init__(self, model_sd: Dict[str, torch.Tensor]):
         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
         self.model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
-        ckpt = torch.load(model_path)
-        for key in list(ckpt.keys()):
+        for key in list(model_sd.keys()):
             if "module." in key:
-                ckpt[key.replace("module.", "")] = ckpt[key]
-                del ckpt[key]
-        self.model.load_state_dict(ckpt)
+                model_sd[key.replace("module.", "")] = model_sd[key]
+                del model_sd[key]
+        self.model.load_state_dict(model_sd)
         self.model.eval()
 
     def to(self, device: torch.device):
```
```diff
@@ -292,7 +292,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 for k, v in cache_entry.state_dict.items():
                     new_dict[k] = v.to(target_device, copy=True)
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device)
+            try:
+                cache_entry.model.to(target_device)
+            except TypeError as e:
+                if "got an unexpected keyword argument 'non_blocking'" in str(e):
+                    cache_entry.model.to(target_device)
+                else:
+                    raise e
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
```
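The retry above exists because some model classes override `to()` without accepting torch's `non_blocking` keyword. A generic sketch of the same defensive pattern, separate from the cache code (the helper name is ours, and unlike the hunk above it passes `non_blocking` explicitly on the first attempt):

```python
import torch


def move_to_device(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    """Move a model to a device, tolerating to() overrides that reject non_blocking."""
    try:
        return model.to(device, non_blocking=True)
    except TypeError as e:
        if "non_blocking" in str(e):
            return model.to(device)  # retry without the offending keyword
        raise
```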
```diff
@@ -3,8 +3,9 @@
 
 import json
 import logging
+import sys
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -30,12 +31,14 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     elif isinstance(model, IAIOnnxRuntimeModel):
         return _calc_onnx_model_by_data(model)
     elif isinstance(model, SchedulerMixin):
-        return 0
+        assert hasattr(model, "config")  # size is dominated by config
+        return sys.getsizeof(model.config)
     elif isinstance(model, CLIPTokenizer):
         # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
-        return 0
+        return sys.getsizeof(model.get_vocab())  # size is dominated by the vocab dict
     elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
         return model.calc_size()
+    elif isinstance(model, dict):
+        return _calc_size_from_dict(model, logger)
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
         # supported model types.
@@ -70,6 +73,19 @@ def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
     return mem
 
 
+def _calc_size_from_dict(model: dict[str, Any] | torch.Tensor | torch.nn.Module, logger: logging.Logger) -> int:
+    total = sys.getsizeof(model)  # get python overhead for object
+    if isinstance(model, dict):
+        total += sum(_calc_size_from_dict(model[x], logger) for x in model.keys())
+    elif isinstance(model, torch.Tensor):
+        total += model.element_size() * model.nelement()
+    elif isinstance(model, torch.nn.Module):
+        total += calc_module_size(model)
+    else:
+        logger.warning(f"Failed to calculate model size for unexpected model type: {type(model)}.")
+    return total
+
+
 def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
     """Estimate the size of a model on disk in bytes."""
     if model_path.is_file():
```
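As a sanity check on the recursion in `_calc_size_from_dict()`, here is a minimal restatement of its sizing rule applied to a toy state dict (the real helper also handles `torch.nn.Module` via `calc_module_size()` and logs unexpected types):

```python
import sys

import torch


def approx_size(obj) -> int:
    # Python object overhead plus the payload, mirroring _calc_size_from_dict.
    total = sys.getsizeof(obj)
    if isinstance(obj, dict):
        total += sum(approx_size(v) for v in obj.values())
    elif isinstance(obj, torch.Tensor):
        total += obj.element_size() * obj.nelement()
    return total


state_dict = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
# 20 float32 elements contribute 80 bytes; the rest is per-object overhead.
print(approx_size(state_dict))
```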
```diff
@@ -26,7 +26,7 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
         value = unicodedata.normalize("NFKC", value)
     else:
         value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
-    value = re.sub(r"[/]", "_", value.lower())
+    value = re.sub(r"[/:]+", "_", value.lower())
     value = re.sub(r"[^.\w\s-]", "", value.lower())
     return re.sub(r"[-\s]+", "-", value).strip("-_")
```
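The download cache derives directory names from the source string via `slugify()` (presumably why this commit touches it), and the `repo::/file` syntax used throughout introduces `:` characters; collapsing runs of `/` and `:` keeps them to a single separator. A quick standalone check of the new substitutions:

```python
import re

source = "lllyasviel/Annotators::/ControlNetHED.pth"
value = re.sub(r"[/:]+", "_", source.lower())  # new rule: collapse runs of "/" and ":"
value = re.sub(r"[^.\w\s-]", "", value)
print(re.sub(r"[-\s]+", "-", value).strip("-_"))
# lllyasviel_annotators_controlnethed.pth
```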
```diff
@@ -81,8 +81,21 @@ def test_download_diffusers(mock_context: InvocationContext) -> None:
 
 
 def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
-    model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
+    model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::/vae")
     assert model_path.is_dir()
+    assert model_path.name != "vae"  # will not create the vae subfolder with preserve_subfolders False
     assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
         model_path / "diffusion_pytorch_model.safetensors"
     ).exists()
+
+
+def test_download_diffusers_preserve_subfolders(mock_context: InvocationContext) -> None:
+    model_path = mock_context.models.download_and_cache_model(
+        "stabilityai/sdxl-turbo::/vae",
+        preserve_subfolders=True,
+    )
+    assert model_path.is_dir()
+    assert model_path.name == "vae"  # will create the vae subfolder with preserve_subfolders True
+    assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
+        model_path / "diffusion_pytorch_model.safetensors"
+    ).exists()
```