Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-02-01 15:54:58 -05:00
* introduce new abstraction layer for GPU devices
* add unit test for device abstraction
* fix ruff
* convert TorchDeviceSelect into a stateless class
* move logic to select context-specific execution device into context API
* add mock hardware environments to pytest
* remove dangling mocker fixture
* fix unit test for running on non-CUDA systems
* remove unimplemented get_execution_device() call
* remove autocast precision
* Multiple changes:
  1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to context.models.get_execution_device().
  2. Rename TorchDeviceSelect to TorchDevice.
  3. Added back the legacy public API defined in `invocation_api`, including choose_precision().
  4. Added a config file migration script to accommodate removal of precision=autocast.
* add deprecation warnings to choose_torch_device() and choose_precision()
* fix test crash
* remove app_config argument from choose_torch_device() and choose_torch_dtype()

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
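The device-selection changes listed above amount to calling classmethods on the new stateless TorchDevice class instead of the deprecated module-level helpers. The following is a minimal sketch of the new call pattern, not part of the file below; TorchDevice.choose_torch_device() does appear in the file, while choose_torch_dtype() is only named in the commit message, so its exact behavior is assumed here.

import torch

from invokeai.backend.util.devices import TorchDevice

# Select the execution device without an app_config argument (the argument
# was removed in this change); CPU is used when no GPU backend is available.
device = TorchDevice.choose_torch_device()

# Assumed companion call from the commit message: pick a working precision
# for that device, replacing the deprecated choose_precision()/autocast path.
dtype = TorchDevice.choose_torch_dtype()

x = torch.zeros(1, device=device, dtype=dtype)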
103 lines
3.6 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""

from typing import Optional

import torch
from typing_extensions import Self

from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase
from .model_manager_base import ModelManagerServiceBase


class ModelManagerService(ModelManagerServiceBase):
    """
    The ModelManagerService handles various aspects of model installation, maintenance and loading.

    It bundles three distinct services:
    model_manager.store   -- Routines to manage the database of model configuration records.
    model_manager.install -- Routines to install, move and delete models.
    model_manager.load    -- Routines to load models into memory.
    """

    def __init__(
        self,
        store: ModelRecordServiceBase,
        install: ModelInstallServiceBase,
        load: ModelLoadServiceBase,
    ):
        self._store = store
        self._install = install
        self._load = load

    @property
    def store(self) -> ModelRecordServiceBase:
        return self._store

    @property
    def install(self) -> ModelInstallServiceBase:
        return self._install

    @property
    def load(self) -> ModelLoadServiceBase:
        return self._load

    def start(self, invoker: Invoker) -> None:
        for service in [self._store, self._install, self._load]:
            if hasattr(service, "start"):
                service.start(invoker)

    def stop(self, invoker: Invoker) -> None:
        for service in [self._store, self._install, self._load]:
            if hasattr(service, "stop"):
                service.stop(invoker)

    @classmethod
    def build_model_manager(
        cls,
        app_config: InvokeAIAppConfig,
        model_record_service: ModelRecordServiceBase,
        download_queue: DownloadQueueServiceBase,
        events: EventServiceBase,
        execution_device: Optional[torch.device] = None,
    ) -> Self:
        """
        Construct the model manager service instance.

        For simplicity, use this class method rather than the __init__ constructor.
        """
        logger = InvokeAILogger.get_logger(cls.__name__)
        logger.setLevel(app_config.log_level.upper())

        ram_cache = ModelCache(
            max_cache_size=app_config.ram,
            max_vram_cache_size=app_config.vram,
            lazy_offloading=app_config.lazy_offload,
            logger=logger,
            execution_device=execution_device or TorchDevice.choose_torch_device(),
        )
        convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
        loader = ModelLoadService(
            app_config=app_config,
            ram_cache=ram_cache,
            convert_cache=convert_cache,
            registry=ModelLoaderRegistry,
        )
        installer = ModelInstallService(
            app_config=app_config,
            record_store=model_record_service,
            download_queue=download_queue,
            event_bus=events,
        )
        return cls(store=model_record_service, install=installer, load=loader)
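As the build_model_manager() docstring says, the classmethod is the intended way to construct the service. Below is a minimal wiring sketch, not part of the file above; it assumes app_config, record_store, download_queue and event_bus are concrete instances of InvokeAIAppConfig, ModelRecordServiceBase, DownloadQueueServiceBase and EventServiceBase created by the application's service setup, and that this module is importable as invokeai.app.services.model_manager.model_manager_default.

from invokeai.app.services.model_manager.model_manager_default import ModelManagerService

# app_config, record_store, download_queue and event_bus are assumed to be
# constructed elsewhere by the application's service wiring.
model_manager = ModelManagerService.build_model_manager(
    app_config=app_config,
    model_record_service=record_store,
    download_queue=download_queue,
    events=event_bus,
    # execution_device is omitted, so TorchDevice.choose_torch_device() decides.
)

# The three bundled services are exposed as read-only properties:
model_manager.store    # model configuration records
model_manager.install  # install, move and delete models
model_manager.load     # load models into memory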