Move CacheRecord out to its own file.

commit ce11a1952e
parent e48dee4c4a
Author: Ryan Dick
Date:   2024-12-04 21:53:19 +00:00

4 changed files with 59 additions and 50 deletions

invokeai/backend/model_manager/load/model_cache/cache_record.py (new file)

@@ -0,0 +1,52 @@
+from dataclasses import dataclass
+from typing import Dict, Generic, Optional, TypeVar
+
+import torch
+
+T = TypeVar("T")
+
+
+@dataclass
+class CacheRecord(Generic[T]):
+    """
+    Elements of the cache:
+
+    key: Unique key for each model, same as used in the models database.
+    model: Model in memory.
+    state_dict: A read-only copy of the model's state dict in RAM. It will be
+        used as a template for creating a copy in the VRAM.
+    size: Size of the model
+    loaded: True if the model's state dict is currently in VRAM
+
+    Before a model is executed, the state_dict template is copied into VRAM,
+    and then injected into the model. When the model is finished, the VRAM
+    copy of the state dict is deleted, and the RAM version is reinjected
+    into the model.
+
+    The state_dict should be treated as a read-only attribute. Do not attempt
+    to patch or otherwise modify it. Instead, patch the copy of the state_dict
+    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+    context manager call `model_on_device()`.
+    """
+
+    key: str
+    model: T
+    device: torch.device
+    state_dict: Optional[Dict[str, torch.Tensor]]
+    size: int
+    loaded: bool = False
+    _locks: int = 0
+
+    def lock(self) -> None:
+        """Lock this record."""
+        self._locks += 1
+
+    def unlock(self) -> None:
+        """Unlock this record."""
+        self._locks -= 1
+        assert self._locks >= 0
+
+    @property
+    def locked(self) -> bool:
+        """Return true if record is locked."""
+        return self._locks > 0
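
As context for review (not part of the diff): a minimal sketch of how the reference-counted locking on the new CacheRecord is meant to be used. The field values here are hypothetical placeholders; the import path is the one introduced by this commit.

import torch

from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord

# Hypothetical record; CacheRecord is generic, so `model` can be any object.
record = CacheRecord(
    key="example-model-key",  # placeholder key, for illustration only
    model=torch.nn.Linear(4, 4),
    device=torch.device("cpu"),
    state_dict=None,
    size=256,
)

record.lock()    # first caller takes a reference; _locks == 1
record.lock()    # a second caller; _locks == 2
assert record.locked

record.unlock()
record.unlock()  # back to zero; the record is now safe to evict
assert not record.locked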

invokeai/backend/model_manager/load/model_cache/model_cache_base.py

@@ -15,56 +15,9 @@ from typing import Dict, Generic, Optional, TypeVar
 import torch
 
 from invokeai.backend.model_manager.config import AnyModel, SubModelType
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
 from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
 
-T = TypeVar("T")
-
-
-@dataclass
-class CacheRecord(Generic[T]):
-    """
-    Elements of the cache:
-
-    key: Unique key for each model, same as used in the models database.
-    model: Model in memory.
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
-    size: Size of the model
-    loaded: True if the model's state dict is currently in VRAM
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-
-    The state_dict should be treated as a read-only attribute. Do not attempt
-    to patch or otherwise modify it. Instead, patch the copy of the state_dict
-    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
-    context manager call `model_on_device()`.
-    """
-
-    key: str
-    model: T
-    device: torch.device
-    state_dict: Optional[Dict[str, torch.Tensor]]
-    size: int
-    loaded: bool = False
-    _locks: int = 0
-
-    def lock(self) -> None:
-        """Lock this record."""
-        self._locks += 1
-
-    def unlock(self) -> None:
-        """Unlock this record."""
-        self._locks -= 1
-        assert self._locks >= 0
-
-    @property
-    def locked(self) -> bool:
-        """Return true if record is locked."""
-        return self._locks > 0
-
-
 @dataclass
 class CacheStats(object):
@@ -79,6 +32,9 @@ class CacheStats(object):
     loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
 
 
+T = TypeVar("T")
+
+
 class ModelCacheBase(ABC, Generic[T]):
     """Virtual base class for RAM model cache."""

invokeai/backend/model_manager/load/model_cache/model_cache_default.py

@@ -13,8 +13,8 @@ import torch
 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
-    CacheRecord,
     CacheStats,
     ModelCacheBase,
 )

invokeai/backend/model_manager/load/model_cache/model_locker.py

@@ -7,7 +7,8 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.model_manager import AnyModel
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheRecord, ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
+from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
 
 
 class ModelLocker:
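
Not part of this commit, but to show the pattern ModelLocker builds on top of the relocated CacheRecord: a minimal context-manager sketch using only the record API shown above. The helper name `locked` is hypothetical; InvokeAI's real ModelLocker has a richer interface.

from contextlib import contextmanager
from typing import Iterator

from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord


@contextmanager
def locked(record: CacheRecord) -> Iterator[object]:
    """Hold a reference-count lock on `record` for the duration of a with-block."""
    record.lock()
    try:
        yield record.model
    finally:
        record.unlock()  # the assert in unlock() guards against double-release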