mirror of https://github.com/invoke-ai/InvokeAI.git
Compare commits
1 Commit
controlnet ... psyche/fea
| Author | SHA1 | Date |
|---|---|---|
| | d729b37b36 | |
@@ -30,7 +30,10 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+        self,
+        model_path: Path,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+        cache_key_extra: Optional[str] = None,
     ) -> LoadedModelWithoutConfig:
         """
         Load the model file or directory located at the indicated Path.
@@ -46,6 +49,8 @@ class ModelLoadServiceBase(ABC):
         Args:
           model_path: A pathlib.Path to a checkpoint-style models file
           loader: A Callable that expects a Path and returns a Dict[str, Tensor]
+          cache_key_extra: A string to append to the cache key. This is useful for
+            differentiating instances of the same model with different parameters.

         Returns:
           A LoadedModel object.
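For orientation, here is a minimal sketch of how a caller might exercise the widened signature; the `read_state_dict` helper, the checkpoint path, and the `model_load` handle are illustrative assumptions, not part of this commit:

```python
from pathlib import Path
from typing import Any

import torch


def read_state_dict(path: Path) -> dict[str, Any]:
    # Illustrative loader: any Callable[[Path], AnyModel] can be supplied.
    return torch.load(path, map_location="cpu")


def load_fp16_variant(model_load, checkpoint: Path):
    # `model_load` stands in for a concrete ModelLoadServiceBase implementation.
    return model_load.load_model_from_path(
        model_path=checkpoint,
        loader=read_state_dict,
        cache_key_extra="fp16",  # keeps this load distinct from other loads of the same file
    )
```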
@@ -76,9 +76,12 @@ class ModelLoadService(ModelLoadServiceBase):
         return loaded_model

     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+        self,
+        model_path: Path,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+        cache_key_extra: Optional[str] = None,
     ) -> LoadedModelWithoutConfig:
-        cache_key = str(model_path)
+        cache_key = f"{model_path}:{cache_key_extra}" if cache_key_extra else str(model_path)
         try:
             return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
         except IndexError:
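This cache-key line is the substance of the change: previously the RAM cache keyed entries by path alone, so two loads of the same file intended for different uses resolved to the same entry. A standalone sketch of the new key composition, using a hypothetical path:

```python
from pathlib import Path
from typing import Optional


def compose_cache_key(model_path: Path, cache_key_extra: Optional[str] = None) -> str:
    # Mirrors the expression in ModelLoadService.load_model_from_path.
    return f"{model_path}:{cache_key_extra}" if cache_key_extra else str(model_path)


checkpoint = Path("/models/main/model.safetensors")  # hypothetical path
print(compose_cache_key(checkpoint))           # /models/main/model.safetensors
print(compose_cache_key(checkpoint, "fp16"))   # /models/main/model.safetensors:fp16
print(compose_cache_key(checkpoint, "fp32"))   # /models/main/model.safetensors:fp32 (a separate RAM-cache entry)
```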
@@ -497,6 +497,7 @@ class ModelsInterface(InvocationContextInterface):
         self,
         model_path: Path,
         loader: Optional[Callable[[Path], AnyModel]] = None,
+        cache_key_extra: Optional[str] = None,
     ) -> LoadedModelWithoutConfig:
         """
         Load the model file located at the indicated path
@@ -509,18 +510,25 @@ class ModelsInterface(InvocationContextInterface):
         Args:
           path: A model Path
           loader: A Callable that expects a Path and returns a dict[str|int, Any]
+          cache_key_extra: A string to append to the cache key. This is useful for
+            differentiating instances of the same model with different parameters.

         Returns:
           A LoadedModelWithoutConfig object.
         """

         self._util.signal_progress(f"Loading model {model_path.name}")
-        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+        return self._services.model_manager.load.load_model_from_path(
+            model_path=model_path,
+            loader=loader,
+            cache_key_extra=cache_key_extra,
+        )

     def load_remote_model(
         self,
         source: str | AnyHttpUrl,
         loader: Optional[Callable[[Path], AnyModel]] = None,
+        cache_key_extra: Optional[str] = None,
     ) -> LoadedModelWithoutConfig:
         """
         Download, cache, and load the model file located at the indicated URL or repo_id.
@@ -535,6 +543,8 @@ class ModelsInterface(InvocationContextInterface):
         Args:
           source: A URL or huggingface repoid.
           loader: A Callable that expects a Path and returns a dict[str|int, Any]
+          cache_key_extra: A string to append to the cache key. This is useful for
+            differentiating instances of the same model with different parameters.

         Returns:
           A LoadedModelWithoutConfig object.
@@ -542,7 +552,11 @@ class ModelsInterface(InvocationContextInterface):
         model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))

         self._util.signal_progress(f"Loading model {source}")
-        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+        return self._services.model_manager.load.load_model_from_path(
+            model_path=model_path,
+            loader=loader,
+            cache_key_extra=cache_key_extra,
+        )

     def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path:
         """Gets the absolute path for a given model config or path.
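A hedged usage sketch from the node side, assuming the ModelsInterface is reachable as `context.models` inside an invocation; the URL, loader, and cache-key suffix are illustrative only:

```python
from pathlib import Path

import torch


def load_weights(path: Path):
    # Illustrative loader; returns the raw checkpoint contents.
    return torch.load(path, map_location="cpu")


def load_custom_vae(context):
    # `context` is assumed to be the InvocationContext passed to an invocation's invoke().
    loaded = context.models.load_remote_model(
        source="https://example.com/models/custom-vae.ckpt",  # illustrative URL
        loader=load_weights,
        cache_key_extra="vae-fp32",  # differentiates this variant in the model RAM cache
    )
    return loaded  # a LoadedModelWithoutConfig, per the docstring above
```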