Compare commits

...

9 Commits

12 changed files with 200 additions and 106 deletions

View File

@@ -1607,7 +1607,7 @@ model configuration to `load_model_by_config()`. It may raise a
Within invocations, the following methods are available from the
`InvocationContext` object:
### context.download_and_cache_model(source) -> Path
### context.download_and_cache_model(source, [preserve_subfolders=False]) -> Path
This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can
@@ -1626,6 +1626,16 @@ directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
When requesting a huggingface repo, if the requested file(s) live in a
nested subfolder, the nesting information will be discarded and the
file(s) will be placed in the top level of the returned
directory. Thus, when requesting
`stabilityai/stable-diffusion-v4::vae`, the contents of `vae` will be
found at the top level of the returned path and not in a subdirectory.
This behavior can be changed by passing `preserve_subfolders=True`,
which will preserve the subfolder structure and return the path to the
subdirectory.
### context.load_local_model(model_path, [loader]) -> LoadedModel
This method loads a local model from the indicated path, returning a

View File

@@ -3,7 +3,7 @@
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from pathlib import Path
from typing import Dict, List, Literal, Union
from typing import ClassVar, Dict, List, Literal, Union
import cv2
import numpy as np
@@ -46,10 +46,11 @@ from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNE
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.hed import HED_MODEL, HEDProcessor
from invokeai.backend.image_util.lineart import COARSE_MODEL, LINEART_MODEL, LineartProcessor
from invokeai.backend.image_util.lineart_anime import LINEART_ANIME_MODEL, LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
from invokeai.backend.util.devices import TorchDevice
@@ -137,6 +138,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process")
# Map controlnet_aux detector classes to model files in "lllyasviel/Annotators"
CONTROLNET_PROCESSORS: ClassVar[Dict[type, str]] = {
MidasDetector: "lllyasviel/Annotators::/dpt_hybrid-midas-501f0c75.pt",
MLSDdetector: "lllyasviel/Annotators::/mlsd_large_512_fp32.pth",
PidiNetDetector: "lllyasviel/Annotators::/table5_pidinet.pth",
ZoeDetector: "lllyasviel/Annotators::/ZoeD_M12_N.pt",
}
def run_processor(self, image: Image.Image) -> Image.Image:
# superclass just passes through image without processing
return image
@@ -145,6 +154,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
# allows override for any special formatting specific to the preprocessor
return context.images.get_pil(self.image.image_name, "RGB")
def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
remote_source = self.CONTROLNET_PROCESSORS[processor]
assert hasattr(processor, "from_pretrained") # no common base class for the controlnet processors!
model = self._context.models.load_remote_model(
source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name)
)
return model
def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
raw_image = self.load_image(context)
@@ -218,15 +235,18 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image:
hed_processor = HEDProcessor()
processed_image = hed_processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
# safe not supported in controlnet_aux v0.0.3
# safe=self.safe,
scribble=self.scribble,
)
hed_weights = self._context.models.load_remote_model(HED_MODEL)
with hed_weights as weights:
assert isinstance(weights, dict)
hed_processor = HEDProcessor(weights)
processed_image = hed_processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
# safe not supported in controlnet_aux v0.0.3
# safe=self.safe,
scribble=self.scribble,
)
return processed_image
@@ -245,10 +265,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image: Image.Image) -> Image.Image:
lineart_processor = LineartProcessor()
processed_image = lineart_processor.run(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
)
model_info = self._context.models.load_remote_model(LINEART_MODEL)
model_coarse_info = self._context.models.load_remote_model(COARSE_MODEL)
with model_info as model_sd, model_coarse_info as coarse_sd:
lineart_processor = LineartProcessor(model_sd=model_sd, coarse_sd=coarse_sd)
processed_image = lineart_processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
coarse=self.coarse,
)
return processed_image
@@ -266,12 +292,13 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
processor = LineartAnimeProcessor()
processed_image = processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
with self._context.models.load_remote_model(LINEART_ANIME_MODEL) as model_sd:
processor = LineartAnimeProcessor(model_sd)
processed_image = processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image
@@ -293,17 +320,17 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image: Image.Image) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(
image,
a=np.pi * self.a_mult,
bg_th=self.bg_th,
image_resolution=self.image_resolution,
detect_resolution=self.detect_resolution,
# dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal,
)
with self.load_processor(MidasDetector) as midas_processor:
assert isinstance(midas_processor, MidasDetector)
processed_image = midas_processor(
image,
a=np.pi * self.a_mult,
bg_th=self.bg_th,
image_resolution=self.image_resolution,
detect_resolution=self.detect_resolution,
# dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal,
)
return processed_image
@@ -321,10 +348,11 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
)
with self.load_processor(NormalBaeDetector) as normalbae_processor:
assert isinstance(normalbae_processor, NormalBaeDetector)
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
)
return processed_image
@@ -340,14 +368,15 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image: Image.Image) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
thr_v=self.thr_v,
thr_d=self.thr_d,
)
with self.load_processor(MLSDdetector) as mlsd_processor:
assert isinstance(mlsd_processor, MLSDdetector)
processed_image = mlsd_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
thr_v=self.thr_v,
thr_d=self.thr_d,
)
return processed_image
@@ -363,14 +392,15 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
safe=self.safe,
scribble=self.scribble,
)
with self.load_processor(PidiNetDetector) as pidi_processor:
assert isinstance(pidi_processor, PidiNetDetector)
processed_image = pidi_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
safe=self.safe,
scribble=self.scribble,
)
return processed_image
@@ -415,8 +445,9 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
def run_processor(self, image: Image.Image) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
with self.load_processor(ZoeDetector) as zoe_depth_processor:
assert isinstance(zoe_depth_processor, ZoeDetector)
processed_image: Image.Image = zoe_depth_processor(image)
return processed_image
@@ -464,6 +495,8 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
# LeresDetector requires two hard-coded models, which breaks the load_processor() pattern.
# TODO (LS): Modify download_and_cache() to accept multiple downloaded checkpoint files.
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
@@ -530,14 +563,17 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
)
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
model_path = self._context.models.download_and_cache_model(
source="ybelkada/segment-anything::/checkpoints/sam_vit_h_4b8939.pth", preserve_subfolders=True
)
with self._context.models.load_local_model(
model_path, loader=lambda x: SamDetectorReproducibleColors.from_pretrained(x)
) as segment_anything_processor:
assert isinstance(segment_anything_processor, SamDetectorReproducibleColors)
np_img = np.array(image, dtype=np.uint8)
processed_image: Image.Image = segment_anything_processor(
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
)
return processed_image

View File

@@ -243,11 +243,17 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
def download_and_cache_model(
self,
source: str | AnyHttpUrl,
preserve_subfolders: bool = False,
) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
:param source: A string representing a URL or repo_id.
:param preserve_subfolders: (optional) If True, the subfolder hierarchy will be preserved;
otherwise flattened.
The model file will be downloaded into the system-wide model cache
(`models/.cache`) if it isn't already there. Note that the model cache

View File

@@ -372,8 +372,10 @@ class ModelInstallService(ModelInstallServiceBase):
def download_and_cache_model(
self,
source: str | AnyHttpUrl,
preserve_subfolders: bool = False,
) -> Path:
"""Download the model file located at source to the models cache and return its Path."""
model_source = self._guess_source(str(source))
model_path = self._download_cache_path(str(source), self._app_config)
# We expect the cache directory to contain one and only one downloaded file or directory.
@@ -385,12 +387,12 @@ class ModelInstallService(ModelInstallServiceBase):
return contents[0]
model_path.mkdir(parents=True, exist_ok=True)
model_source = self._guess_source(str(source))
remote_files, _ = self._remote_files_from_source(model_source)
job = self._multifile_download(
dest=model_path,
remote_files=remote_files,
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
preserve_subfolders=preserve_subfolders,
)
files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
@@ -772,15 +774,22 @@ class ModelInstallService(ModelInstallServiceBase):
subfolder: Optional[Path] = None,
access_token: Optional[str] = None,
submit_job: bool = True,
preserve_subfolders: bool = False,
) -> MultiFileDownloadJob:
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
# The exception is when preserve_subfolders is true, in which case we keep the hierarchy
# of subfolders and return the path to the last enclosing subfolder.
if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
if preserve_subfolders:
path_to_remove = remote_files[0].path.parts[0]
path_to_add = Path(".")
else:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")

View File

@@ -447,6 +447,7 @@ class ModelsInterface(InvocationContextInterface):
def download_and_cache_model(
self,
source: str | AnyHttpUrl,
preserve_subfolders: bool = False,
) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
@@ -457,11 +458,14 @@ class ModelsInterface(InvocationContextInterface):
Args:
source: A URL that points to the model, or a huggingface repo_id.
preserve_subfolders: (optional, False) If True, then preserve subfolder structure.
Returns:
Path to the downloaded model
Path to the downloaded model (file or directory)
"""
return self._services.model_manager.install.download_and_cache_model(source=source)
return self._services.model_manager.install.download_and_cache_model(
source=source, preserve_subfolders=preserve_subfolders
)
def load_local_model(
self,

View File

@@ -1,10 +1,11 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
from typing import Dict
import cv2
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from invokeai.backend.image_util.util import (
@@ -16,6 +17,8 @@ from invokeai.backend.image_util.util import (
safe_step,
)
HED_MODEL = "lllyasviel/Annotators::/ControlNetHED.pth"
class DoubleConvBlock(torch.nn.Module):
def __init__(self, input_channel, output_channel, layer_number):
@@ -76,16 +79,11 @@ class HEDProcessor:
On instantiation, loads the HED model from the HuggingFace Hub.
"""
def __init__(self):
model_path = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
def __init__(self, state_dict: Dict[str, torch.Tensor]):
self.network = ControlNetHED_Apache2()
self.network.load_state_dict(torch.load(model_path, map_location="cpu"))
self.network.load_state_dict(state_dict)
self.network.float().eval()
def to(self, device: torch.device):
self.network.to(device)
return self
def run(
self,
input_image: Image.Image,

View File

@@ -1,11 +1,12 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
from typing import Dict
import cv2
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from invokeai.backend.image_util.util import (
@@ -15,6 +16,9 @@ from invokeai.backend.image_util.util import (
resize_image_to_resolution,
)
LINEART_MODEL = "lllyasviel/Annotators::/sk_model.pth"
COARSE_MODEL = "lllyasviel/Annotators::/sk_model2.pth"
class ResidualBlock(nn.Module):
def __init__(self, in_features):
@@ -97,22 +101,15 @@ class Generator(nn.Module):
class LineartProcessor:
"""Processor for lineart detection."""
def __init__(self):
model_path = hf_hub_download("lllyasviel/Annotators", "sk_model.pth")
def __init__(self, model_sd: Dict[str, torch.Tensor], coarse_sd: Dict[str, torch.Tensor]):
self.model = Generator(3, 1, 3)
self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
self.model.load_state_dict(model_sd)
self.model.eval()
coarse_model_path = hf_hub_download("lllyasviel/Annotators", "sk_model2.pth")
self.model_coarse = Generator(3, 1, 3)
self.model_coarse.load_state_dict(torch.load(coarse_model_path, map_location=torch.device("cpu")))
self.model_coarse.load_state_dict(coarse_sd)
self.model_coarse.eval()
def to(self, device: torch.device):
self.model.to(device)
self.model_coarse.to(device)
return self
def run(
self, input_image: Image.Image, coarse: bool = False, detect_resolution: int = 512, image_resolution: int = 512
) -> Image.Image:

View File

@@ -1,14 +1,13 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
import functools
from typing import Optional
from typing import Dict, Optional
import cv2
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from invokeai.backend.image_util.util import (
@@ -18,6 +17,8 @@ from invokeai.backend.image_util.util import (
resize_image_to_resolution,
)
LINEART_ANIME_MODEL = "lllyasviel/Annotators::/netG.pth"
class UnetGenerator(nn.Module):
"""Create a Unet-based generator"""
@@ -142,16 +143,14 @@ class UnetSkipConnectionBlock(nn.Module):
class LineartAnimeProcessor:
"""Processes an image to detect lineart."""
def __init__(self):
model_path = hf_hub_download("lllyasviel/Annotators", "netG.pth")
def __init__(self, model_sd: Dict[str, torch.Tensor]):
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
self.model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
ckpt = torch.load(model_path)
for key in list(ckpt.keys()):
for key in list(model_sd.keys()):
if "module." in key:
ckpt[key.replace("module.", "")] = ckpt[key]
del ckpt[key]
self.model.load_state_dict(ckpt)
model_sd[key.replace("module.", "")] = model_sd[key]
del model_sd[key]
self.model.load_state_dict(model_sd)
self.model.eval()
def to(self, device: torch.device):

View File

@@ -292,7 +292,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(target_device, copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
try:
cache_entry.model.to(target_device)
except TypeError as e:
if "got an unexpected keyword argument 'non_blocking'" in str(e):
cache_entry.model.to(target_device)
else:
raise e
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)

View File

@@ -3,8 +3,9 @@
import json
import logging
import sys
from pathlib import Path
from typing import Optional
from typing import Any, Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -30,12 +31,14 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
elif isinstance(model, IAIOnnxRuntimeModel):
return _calc_onnx_model_by_data(model)
elif isinstance(model, SchedulerMixin):
return 0
assert hasattr(model, "config") # size is dominated by config
return sys.getsizeof(model.config)
elif isinstance(model, CLIPTokenizer):
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
return 0
return sys.getsizeof(model.get_vocab()) # size is dominated by the vocab dict
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
return model.calc_size()
elif isinstance(model, dict):
return _calc_size_from_dict(model, logger)
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
# supported model types.
@@ -70,6 +73,19 @@ def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
return mem
def _calc_size_from_dict(model: dict[str, Any] | torch.Tensor | torch.nn.Module, logger: logging.Logger) -> int:
total = sys.getsizeof(model) # get python overhead for object
if isinstance(model, dict):
total += sum(_calc_size_from_dict(model[x], logger) for x in model.keys())
elif isinstance(model, torch.Tensor):
total += model.element_size() * model.nelement()
elif isinstance(model, torch.nn.Module):
total += calc_module_size(model)
else:
logger.warning(f"Failed to calculate model size for unexpected model type: {type(model)}.")
return total
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():

View File

@@ -26,7 +26,7 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
value = unicodedata.normalize("NFKC", value)
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[/]", "_", value.lower())
value = re.sub(r"[/:]+", "_", value.lower())
value = re.sub(r"[^.\w\s-]", "", value.lower())
return re.sub(r"[-\s]+", "-", value).strip("-_")

View File

@@ -81,8 +81,21 @@ def test_download_diffusers(mock_context: InvocationContext) -> None:
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::/vae")
assert model_path.is_dir()
assert model_path.name != "vae" # will not create the vae subfolder with preserve_subfolders False
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
model_path / "diffusion_pytorch_model.safetensors"
).exists()
def test_download_diffusers_preserve_subfolders(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model(
"stabilityai/sdxl-turbo::/vae",
preserve_subfolders=True,
)
assert model_path.is_dir()
assert model_path.name == "vae" # will create the vae subfolder with preserve_subfolders True
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
model_path / "diffusion_pytorch_model.safetensors"
).exists()