diff --git a/invokeai/backend/model_manager/load/model_cache/cache_record.py b/invokeai/backend/model_manager/load/model_cache/cache_record.py
new file mode 100644
index 0000000000..2398eb9ab6
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/cache_record.py
@@ -0,0 +1,52 @@
+from dataclasses import dataclass
+from typing import Dict, Generic, Optional, TypeVar
+
+import torch
+
+T = TypeVar("T")
+
+
+@dataclass
+class CacheRecord(Generic[T]):
+    """
+    Elements of the cache:
+
+    key: Unique key for each model, same as used in the models database.
+    model: Model in memory.
+    state_dict: A read-only copy of the model's state dict in RAM. It will be
+        used as a template for creating a copy in the VRAM.
+    size: Size of the model
+    loaded: True if the model's state dict is currently in VRAM
+
+    Before a model is executed, the state_dict template is copied into VRAM,
+    and then injected into the model. When the model is finished, the VRAM
+    copy of the state dict is deleted, and the RAM version is reinjected
+    into the model.
+
+    The state_dict should be treated as a read-only attribute. Do not attempt
+    to patch or otherwise modify it. Instead, patch the copy of the state_dict
+    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+    context manager call `model_on_device()`.
+    """
+
+    key: str
+    model: T
+    device: torch.device
+    state_dict: Optional[Dict[str, torch.Tensor]]
+    size: int
+    loaded: bool = False
+    _locks: int = 0
+
+    def lock(self) -> None:
+        """Lock this record."""
+        self._locks += 1
+
+    def unlock(self) -> None:
+        """Unlock this record."""
+        self._locks -= 1
+        assert self._locks >= 0
+
+    @property
+    def locked(self) -> bool:
+        """Return true if record is locked."""
+        return self._locks > 0
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index cfc886a06a..1109296d3a 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -15,56 +15,9 @@ from typing import Dict, Generic, Optional, TypeVar
 import torch
 
 from invokeai.backend.model_manager.config import AnyModel, SubModelType
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
 from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
 
-T = TypeVar("T")
-
-
-@dataclass
-class CacheRecord(Generic[T]):
-    """
-    Elements of the cache:
-
-    key: Unique key for each model, same as used in the models database.
-    model: Model in memory.
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
-    size: Size of the model
-    loaded: True if the model's state dict is currently in VRAM
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-
-    The state_dict should be treated as a read-only attribute. Do not attempt
-    to patch or otherwise modify it. Instead, patch the copy of the state_dict
-    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
-    context manager call `model_on_device()`.
-    """
-
-    key: str
-    model: T
-    device: torch.device
-    state_dict: Optional[Dict[str, torch.Tensor]]
-    size: int
-    loaded: bool = False
-    _locks: int = 0
-
-    def lock(self) -> None:
-        """Lock this record."""
-        self._locks += 1
-
-    def unlock(self) -> None:
-        """Unlock this record."""
-        self._locks -= 1
-        assert self._locks >= 0
-
-    @property
-    def locked(self) -> bool:
-        """Return true if record is locked."""
-        return self._locks > 0
-
 
 @dataclass
 class CacheStats(object):
@@ -79,6 +32,9 @@ class CacheStats(object):
     loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
 
 
+T = TypeVar("T")
+
+
 class ModelCacheBase(ABC, Generic[T]):
     """Virtual base class for RAM model cache."""
 
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 3f8ed0d7fc..62475b878d 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -13,8 +13,8 @@ import torch
 
 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
-    CacheRecord,
     CacheStats,
     ModelCacheBase,
 )
diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py
index 5c0b0e8cc2..6b1dd0ed42 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_locker.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py
@@ -7,7 +7,8 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.model_manager import AnyModel
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheRecord, ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
 
 
 class ModelLocker: