add invocation_context.load_ckpt_from_url() method

This commit is contained in:
Lincoln Stein
2024-04-12 00:55:21 -04:00
committed by Lincoln Stein
parent 34438ce1af
commit c140d3b1df
8 changed files with 131 additions and 25 deletions

View File

@@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal
import cv2
@@ -11,7 +10,6 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
@@ -56,7 +54,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
rrdbnet_model = None
netscale = None
esrgan_model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
@@ -99,16 +96,13 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg)
raise ValueError(msg)
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
# Downloads the ESRGAN model if it doesn't already exist
download_with_progress_bar(
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
loadnet = context.models.load_ckpt_from_url(
source=ESRGAN_MODEL_URLS[self.model_name],
)
upscaler = RealESRGAN(
scale=netscale,
model_path=esrgan_model_path,
loadnet=loadnet.model,
model=rrdbnet_model,
half=False,
tile=self.tile_size,
@@ -118,6 +112,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()

View File

@@ -1,11 +1,14 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
from picklescan.scanner import scan_file_path
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from safetensors.torch import load_file as safetensors_load_file
from torch import Tensor
from torch import load as torch_load
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@@ -488,13 +491,14 @@ class ModelsInterface(InvocationContextInterface):
key: str = job.config_out.key
return key
def download_and_cache_model(
def download_and_cache_ckpt(
self,
source: Union[str, AnyHttpUrl],
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
) -> Path:
"""Download the model file located at source to the models cache and return its Path.
"""
Download the model file located at source to the models cache and return its Path.
This can be used to single-file install models and other resources of arbitrary types
which should not get registered with the database. If the model is already
@@ -522,10 +526,65 @@ class ModelsInterface(InvocationContextInterface):
)
return path
def load_ckpt_from_url(
    self,
    source: Union[str, AnyHttpUrl],
    access_token: Optional[str] = None,
    timeout: Optional[int] = 0,
    loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
) -> LoadedModel:
    """
    Load and cache the model file located at the indicated URL.

    This will check the RAM cache for a model stored under the provided URL and
    return it if present. Otherwise it fetches the file with
    download_and_cache_ckpt(), loads it, and stores the result in the RAM cache
    under the URL. If the optional loader argument is provided, the loader will
    be invoked to load the model into memory. Otherwise the method will call
    torch.load() for files whose suffix ends with .ckpt, .pt, .pth or .bin, and
    safetensors.torch.load_file() for everything else.

    Be aware that the LoadedModel object will have a `config` attribute of None.

    Args:
      source: A URL or a string that can be converted into one. Repo_ids
        do not work here.
      access_token: Optional access token for restricted resources.
      timeout: Wait up to the indicated number of seconds before timing
        out long downloads.
      loader: A Callable that expects a Path and returns a Dict[str|int, Any]

    Returns:
      A LoadedModel object.

    Raises:
      Exception: If the default torch loader is used and the pickle scan
        reports infected files in the checkpoint.
    """
    ram_cache = self._services.model_manager.load.ram_cache
    cache_key = str(source)

    # Fast path: an IndexError from the cache lookup is treated as a cache miss.
    try:
        return LoadedModel(_locker=ram_cache.get(key=cache_key))
    except IndexError:
        pass

    def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]:
        # torch.load() unpickles the file, which can execute arbitrary code,
        # so the checkpoint is scanned before it is ever deserialized.
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
        return torch_load(checkpoint, map_location="cpu")

    def safetensors_load(checkpoint: Path) -> Dict[str | int, Any]:
        return safetensors_load_file(checkpoint, device="cpu")

    path = self.download_and_cache_ckpt(source, access_token, timeout)
    if loader is None:
        # Pick a default loader from the file suffix.
        if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
            loader = torch_load_file
        else:
            loader = safetensors_load

    raw_model = loader(path)
    ram_cache.put(key=cache_key, model=raw_model)
    # Re-fetch from the cache so the caller receives the cache's own locker object.
    return LoadedModel(_locker=ram_cache.get(key=cache_key))
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
"""Gets the app's config.
"""
Gets the app's config.
Returns:
The app's config.