Mirror of https://github.com/invoke-ai/InvokeAI.git
replace load_and_cache_model() with load_remote_model() and load_local_model()
committed by psychedelicious
parent 9f9379682e
commit dc134935c8
@@ -611,7 +611,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
                 model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
             )
 
-        with self._context.models.load_and_cache_model(
+        with self._context.models.load_remote_model(
             source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
         ) as model:
             depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
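The call-site change here is purely a rename; the context-manager protocol is unchanged. As a usage sketch (not part of the commit, and with an illustrative URL): the object returned by load_remote_model() is a LoadedModelWithoutConfig, and entering its `with` block locks the model into the RAM cache and yields the raw model object.

    # Sketch: `self._context` is the invocation's context object, as in the
    # hunk above; the source URL is illustrative, not from this diff.
    with self._context.models.load_remote_model(
        source="https://example.com/depth_anything/model.safetensors",
    ) as model:
        # `model` remains locked in the RAM cache for the life of this block
        print(type(model))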
@@ -134,7 +134,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""
 
     def infill(self, image: Image.Image):
-        with self._context.models.load_and_cache_model(
+        with self._context.models.load_remote_model(
             source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
             loader=LaMA.load_jit_model,
         ) as model:
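After this commit the `loader` argument's contract is `Callable[[Path], AnyModel]`: it receives the path of the cached download and must return the constructed model. `LaMA.load_jit_model` itself is not shown in this diff, but a minimal conforming loader for a TorchScript checkpoint such as big-lama.pt might look like this (a hedged sketch, not the project's implementation):

    from pathlib import Path

    import torch


    def load_jit_checkpoint(checkpoint_path: Path) -> torch.nn.Module:
        # torch.jit.load deserializes a TorchScript archive; map_location
        # keeps the weights on the CPU until the caller moves them.
        model = torch.jit.load(checkpoint_path, map_location="cpu")
        model.eval()
        return model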
@@ -91,7 +91,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
             context.logger.error(msg)
             raise ValueError(msg)
 
-        loadnet = context.models.load_and_cache_model(
+        loadnet = context.models.load_remote_model(
             source=ESRGAN_MODEL_URLS[self.model_name],
         )
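One behavioral point, taken from the download_and_cache_model() docstring elsewhere in this commit, is worth spelling out: repeated calls with the same source are served from the models cache rather than re-downloaded. A sketch, with an illustrative URL and an invocation-style `context`:

    # First call downloads into the models cache, then loads the file;
    # the second call finds the cached file and skips the download.
    url = "https://example.com/esrgan/RealESRGAN_x4plus.pth"  # illustrative
    loadnet = context.models.load_remote_model(source=url)
    loadnet_again = context.models.load_remote_model(source=url)  # cache hit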
@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
+from pydantic.networks import AnyHttpUrl
+
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.download import DownloadQueueServiceBase
 from invokeai.app.services.events.events_base import EventServiceBase
@@ -241,7 +243,7 @@ class ModelInstallServiceBase(ABC):
         """
 
     @abstractmethod
-    def download_and_cache_model(self, source: str) -> Path:
+    def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.
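With the widened signature, callers may pass a plain string or an AnyHttpUrl value already validated by a pydantic model; the implementation (later in this diff) normalizes with str(source) before dispatching. A sketch, where `installer` stands in for a ModelInstallServiceBase implementation and the URL is illustrative:

    from pathlib import Path

    # A validated AnyHttpUrl would be accepted identically, since the
    # method stringifies its source before computing the cache path.
    cached: Path = installer.download_and_cache_model(
        source="https://example.com/models/foo.safetensors"
    )
    print(cached)  # a file under the models/.download_cache directory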
@@ -15,6 +15,7 @@ import torch
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl
+from pydantic_core import Url
 from requests import Session
 
 from invokeai.app.services.config import InvokeAIAppConfig
@@ -374,7 +375,7 @@ class ModelInstallService(ModelInstallServiceBase):
 
     def download_and_cache_model(
         self,
-        source: str,
+        source: str | AnyHttpUrl,
     ) -> Path:
         """Download the model file located at source to the models cache and return its Path."""
         model_path = self._download_cache_path(str(source), self._app_config)
@@ -388,7 +389,7 @@ class ModelInstallService(ModelInstallServiceBase):
                 return contents[0]
 
         model_path.mkdir(parents=True, exist_ok=True)
-        model_source = self._guess_source(source)
+        model_source = self._guess_source(str(source))
         remote_files, _ = self._remote_files_from_source(model_source)
         job = self._multifile_download(
             dest=model_path,
@@ -447,7 +448,7 @@ class ModelInstallService(ModelInstallServiceBase):
             )
         elif re.match(r"^https?://[^/]+", source):
             source_obj = URLModelSource(
-                url=AnyHttpUrl(source),
+                url=Url(source),
             )
         else:
             raise ValueError(f"Unsupported model source: '{source}'")
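The swap from AnyHttpUrl(source) to Url(source) relies on pydantic_core.Url being a concrete runtime class that validates on construction. A small sketch of that behavior (URL illustrative):

    from pydantic_core import Url

    url = Url("https://example.com/models/foo.pt")  # validates on construction
    print(url.scheme, url.host)  # https example.com

    try:
        Url("not a url")
    except Exception:
        print("malformed input is rejected at construction time")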
@@ -3,9 +3,7 @@
 
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Callable, Dict, Optional
-
-from torch import Tensor
+from typing import Callable, Optional
 
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
@@ -37,7 +35,7 @@ class ModelLoadServiceBase(ABC):
 
     @abstractmethod
     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         """
         Load the model file or directory located at the indicated Path.
@@ -2,11 +2,10 @@
 """Implementation of model loader service."""
 
 from pathlib import Path
-from typing import Callable, Dict, Optional, Type
+from typing import Callable, Optional, Type
 
 from picklescan.scanner import scan_file_path
 from safetensors.torch import load_file as safetensors_load_file
-from torch import Tensor
 from torch import load as torch_load
 
 from invokeai.app.services.config import InvokeAIAppConfig
@@ -86,7 +85,7 @@ class ModelLoadService(ModelLoadServiceBase):
         return loaded_model
 
     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         cache_key = str(model_path)
         ram_cache = self.ram_cache
@@ -95,11 +94,11 @@ class ModelLoadService(ModelLoadServiceBase):
         except IndexError:
             pass
 
-        def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
+        def torch_load_file(checkpoint: Path) -> AnyModel:
             scan_result = scan_file_path(checkpoint)
             if scan_result.infected_files != 0:
                 raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
-            result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
+            result = torch_load(checkpoint, map_location="cpu")
             return result
 
         def diffusers_load_directory(directory: Path) -> AnyModel:
@@ -109,18 +108,16 @@ class ModelLoadService(ModelLoadServiceBase):
                 ram_cache=self._ram_cache,
                 convert_cache=self.convert_cache,
             ).get_hf_load_class(directory)
-            return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
+            result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
+            return result
 
-        loader = loader or (
-            diffusers_load_directory
-            if model_path.is_dir()
-            else torch_load_file
-            if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
-            else lambda path: safetensors_load_file(path, device="cpu")
-        )
-        assert loader is not None
+        if loader is None:
+            loader = (
+                diffusers_load_directory
+                if model_path.is_dir()
+                else torch_load_file
+                if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
+                else lambda path: safetensors_load_file(path, device="cpu")
+            )
+
         raw_model = loader(model_path)
         ram_cache.put(key=cache_key, model=raw_model)
         return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
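The restructured default-loader selection is a three-way dispatch on the model path. A standalone paraphrase of that logic (names hypothetical; the real code closes over the torch_load_file and diffusers_load_directory helpers defined in the same method):

    from pathlib import Path


    def default_loader_kind(model_path: Path) -> str:
        # Mirrors the chained conditional above: directories go to diffusers,
        # pickle-style suffixes to torch.load, everything else to safetensors.
        if model_path.is_dir():
            return "diffusers"
        if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
            return "torch"
        return "safetensors"


    print(default_loader_kind(Path("model.pth")))          # torch
    print(default_loader_kind(Path("model.safetensors")))  # safetensors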
@@ -15,7 +15,14 @@ from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
 from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -449,21 +456,42 @@ class ModelsInterface(InvocationContextInterface):
         installed, the cached path will be returned. Otherwise it will be downloaded.
 
         Args:
-            source: A model path, URL or repo_id.
+            source: A URL that points to the model, or a huggingface repo_id.
 
         Returns:
             Path to the downloaded model
         """
         return self._services.model_manager.install.download_and_cache_model(source=source)
 
-    def load_and_cache_model(
+    def load_local_model(
         self,
-        source: Path | str | AnyHttpUrl,
-        loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
+        model_path: Path,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
     ) -> LoadedModelWithoutConfig:
         """
-        Download, cache, and load the model file located at the indicated URL.
+        Load the model file located at the indicated path
 
         If a loader callable is provided, it will be invoked to load the model. Otherwise,
         `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute
+
+        Args:
+            path: A model Path
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+    def load_remote_model(
+        self,
+        source: str | AnyHttpUrl,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+    ) -> LoadedModelWithoutConfig:
+        """
+        Download, cache, and load the model file located at the indicated URL or repo_id.
+
+        If the model is already downloaded, it will be loaded from the cache.
@@ -473,18 +501,14 @@ class ModelsInterface(InvocationContextInterface):
         Be aware that the LoadedModelWithoutConfig object has no `config` attribute
 
         Args:
-            source: A model Path, URL, or repoid.
+            source: A URL or huggingface repoid.
             loader: A Callable that expects a Path and returns a dict[str|int, Any]
 
         Returns:
             A LoadedModelWithoutConfig object.
         """
-
-        if isinstance(source, Path):
-            return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
-        else:
-            model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
-            return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+        model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
 
 
 class ConfigInterface(InvocationContextInterface):
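Taken together, these two hunks split the old path-or-URL polymorphism of load_and_cache_model() into two explicit entry points. A usage sketch from inside an invocation (`context` is the invocation's context object; the path and URL are illustrative):

    from pathlib import Path

    # Local file already on disk: no download machinery involved.
    loaded_local = context.models.load_local_model(
        model_path=Path("/opt/models/custom.safetensors"),
    )
    # Remote source: downloaded to the models cache on first use.
    loaded_remote = context.models.load_remote_model(
        source="https://example.com/weights/custom.safetensors",
    )
    with loaded_local as model_a, loaded_remote as model_b:
        print(type(model_a), type(model_b))  # both locked in the RAM cache here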
@@ -59,14 +59,12 @@ class Migration11Callback:
 
 def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
     """
-    Build the migration from database version 9 to 10.
+    Build the migration from database version 10 to 11.
 
     This migration does the following:
     - Moves "core" models previously downloaded with download_with_progress_bar() into new
       "models/.download_cache" directory.
     - Renames "models/.cache" to "models/.convert_cache".
-    - Adds `error_type` and `error_message` columns to the session queue table.
-    - Renames the `error` column to `error_traceback`.
     """
     migration_11 = Migration(
         from_version=10,