Compare commits

...

1 Commits

Author SHA1 Message Date
psychedelicious
d729b37b36 feat(mm): support customizable cache key for load_model_from_path() 2025-04-22 14:52:22 +10:00
3 changed files with 27 additions and 5 deletions

View File

@@ -30,7 +30,10 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
self,
model_path: Path,
loader: Optional[Callable[[Path], AnyModel]] = None,
cache_key_extra: Optional[str] = None,
) -> LoadedModelWithoutConfig:
"""
Load the model file or directory located at the indicated Path.
@@ -46,6 +49,8 @@ class ModelLoadServiceBase(ABC):
Args:
model_path: A pathlib.Path to a checkpoint-style models file
loader: A Callable that expects a Path and returns a Dict[str, Tensor]
cache_key_extra: A string to append to the cache key. This is useful for
differentiating an instances of the same model with different parameters.
Returns:
A LoadedModel object.

View File

@@ -76,9 +76,12 @@ class ModelLoadService(ModelLoadServiceBase):
return loaded_model
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
self,
model_path: Path,
loader: Optional[Callable[[Path], AnyModel]] = None,
cache_key_extra: Optional[str] = None,
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
cache_key = f"{model_path}:{cache_key_extra}" if cache_key_extra else str(model_path)
try:
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
except IndexError:

View File

@@ -497,6 +497,7 @@ class ModelsInterface(InvocationContextInterface):
self,
model_path: Path,
loader: Optional[Callable[[Path], AnyModel]] = None,
cache_key_extra: Optional[str] = None,
) -> LoadedModelWithoutConfig:
"""
Load the model file located at the indicated path
@@ -509,18 +510,25 @@ class ModelsInterface(InvocationContextInterface):
Args:
path: A model Path
loader: A Callable that expects a Path and returns a dict[str|int, Any]
cache_key_extra: A string to append to the cache key. This is useful for
differentiating an instances of the same model with different parameters.
Returns:
A LoadedModelWithoutConfig object.
"""
self._util.signal_progress(f"Loading model {model_path.name}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
return self._services.model_manager.load.load_model_from_path(
model_path=model_path,
loader=loader,
cache_key_extra=cache_key_extra,
)
def load_remote_model(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], AnyModel]] = None,
cache_key_extra: Optional[str] = None,
) -> LoadedModelWithoutConfig:
"""
Download, cache, and load the model file located at the indicated URL or repo_id.
@@ -535,6 +543,8 @@ class ModelsInterface(InvocationContextInterface):
Args:
source: A URL or huggingface repoid.
loader: A Callable that expects a Path and returns a dict[str|int, Any]
cache_key_extra: A string to append to the cache key. This is useful for
differentiating an instances of the same model with different parameters.
Returns:
A LoadedModelWithoutConfig object.
@@ -542,7 +552,11 @@ class ModelsInterface(InvocationContextInterface):
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
return self._services.model_manager.load.load_model_from_path(
model_path=model_path,
loader=loader,
cache_key_extra=cache_key_extra,
)
def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path:
"""Gets the absolute path for a given model config or path.