Mirror of https://github.com/invoke-ai/InvokeAI.git
mostly ported to new manager API; needs testing
invokeai/backend/model_management/__init__.py
@@ -2,4 +2,4 @@
 Initialization file for invokeai.backend.model_management
 """
 from .model_manager import ModelManager
-from .model_cache import ModelCache, ModelStatus
+from .model_cache import ModelCache, ModelStatus, SDModelType
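The package now re-exports SDModelType alongside ModelCache and ModelStatus, so callers can name a pipeline sub-part when asking the cache for a model (see the attach_model_part parameter further down). A rough sketch, not part of the diff; the member name vae is an assumption, chosen because get_model() resolves parts with getattr(model, submodel.name):

# Illustrative only; SDModelType.vae is an assumed member name that would have to
# match a StableDiffusionPipeline attribute for the getattr() lookup to work.
from invokeai.backend.model_management import SDModelType

wanted_part = SDModelType.vae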
invokeai/backend/model_management/model_cache.py
@@ -78,6 +78,10 @@ class UnscannableModelException(Exception):
     "Raised when picklescan is unable to scan a legacy model file"
     pass
 
+class ModelLocker(object):
+    "Forward declaration"
+    pass
+
 class ModelCache(object):
     def __init__(
             self,
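The empty module-level ModelLocker is a forward declaration: get_model() is annotated as returning ModelLocker, and that name has to exist when the ModelCache class body is evaluated, because the real locker is added later as a nested class. A minimal sketch of the same pattern (generic illustration, not the project's code):

# Generic example of the forward-declaration pattern used here.
class Locker(object):
    "Forward declaration"
    pass

class Cache(object):
    def get(self) -> Locker:       # annotation resolves against the stub above
        return self.Locker(self)

    class Locker(object):          # the real implementation, nested like ModelLocker below
        def __init__(self, cache):
            self.cache = cache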
@@ -112,8 +116,6 @@ class ModelCache(object):
         self.loaded_models: set = set() # set of model keys loaded in GPU
         self.locked_models: Counter = Counter() # set of model keys locked in GPU
 
-
-    @contextlib.contextmanager
     def get_model(
             self,
             repo_id_or_path: Union[str,Path],
@@ -124,7 +126,7 @@ class ModelCache(object):
             legacy_info: LegacyInfo=None,
             attach_model_part: Tuple[SDModelType, str] = (None,None),
             gpu_load: bool=True,
-            )->Generator[ModelClass, None, None]:
+            )->ModelLocker: # ?? what does it return
         '''
         Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
         Use like this:
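With the @contextlib.contextmanager decorator gone, get_model() no longer yields; it hands back a ModelLocker object. Because the locker implements __enter__/__exit__ (below), existing call sites written as a with-block should keep working, and callers can also hold the locker and enter it later. A hedged usage sketch; execution_device and precision are keyword arguments visible in this diff, the rest of the call is assumed:

import torch
from invokeai.backend.model_management import ModelCache

# Sketch only; other constructor arguments keep their defaults.
cache = ModelCache(execution_device=torch.device('cuda'), precision=torch.float16)
locker = cache.get_model('runwayml/stable-diffusion-v1-5')   # now returns a ModelLocker
with locker as pipeline:            # __enter__ moves the model onto the execution device
    image = pipeline('a watercolor fox').images[0]
# __exit__ drops the lock; unlocked models may be offloaded back to the storage device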
@@ -188,29 +190,45 @@ class ModelCache(object):
         if submodel:
             model = getattr(model, submodel.name)
 
-        if gpu_load and hasattr(model,'to'):
-            try:
-                self.loaded_models.add(key)
-                self.locked_models[key] += 1
-                if self.lazy_offloading:
-                    self._offload_unlocked_models()
-                self.logger.debug(f'Loading {key} into {self.execution_device}')
-                model.to(self.execution_device) # move into GPU
-                self._print_cuda_stats()
-                yield model
-            finally:
-                self.locked_models[key] -= 1
-                if not self.lazy_offloading:
-                    self._offload_unlocked_models()
-                self._print_cuda_stats()
-        else:
-            # in the event that the caller wants the model in RAM, we
-            # move it into CPU if it is in GPU and not locked
-            if hasattr(model,'to') and (key in self.loaded_models
-                                        and self.locked_models[key] == 0):
-                model.to(self.storage_device)
-                self.loaded_models.remove(key)
-            yield model
+        return self.ModelLocker(self, key, model, gpu_load)
+
+    class ModelLocker(object):
+        def __init__(self, cache, key, model, gpu_load):
+            self.gpu_load = gpu_load
+            self.cache = cache
+            self.key = key
+            # This will keep a copy of the model in RAM until the locker
+            # is garbage collected. Needs testing!
+            self.model = model
+
+        def __enter__(self)->ModelClass:
+            cache = self.cache
+            key = self.key
+            model = self.model
+            if self.gpu_load and hasattr(model,'to'):
+                cache.loaded_models.add(key)
+                cache.locked_models[key] += 1
+                if cache.lazy_offloading:
+                    cache._offload_unlocked_models()
+                cache.logger.debug(f'Loading {key} into {cache.execution_device}')
+                model.to(cache.execution_device) # move into GPU
+                cache._print_cuda_stats()
+            else:
+                # in the event that the caller wants the model in RAM, we
+                # move it into CPU if it is in GPU and not locked
+                if hasattr(model,'to') and (key in cache.loaded_models
+                                            and cache.locked_models[key] == 0):
+                    model.to(cache.storage_device)
+                    cache.loaded_models.remove(key)
+            return model
+
+        def __exit__(self, type, value, traceback):
+            key = self.key
+            cache = self.cache
+            cache.locked_models[key] -= 1
+            if not cache.lazy_offloading:
+                cache._offload_unlocked_models()
+            cache._print_cuda_stats()
 
     def attach_part(self,
                     diffusers_model: StableDiffusionPipeline,
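The locker keeps its reference count in ModelCache.locked_models, so the same model can be locked from several places and stays on the execution device until every locker has exited. Note that __exit__ decrements the counter unconditionally while __enter__ only increments it on the GPU path; that asymmetry is worth covering in the testing the commit message asks for. A sketch of the intended re-entrant behaviour, continuing the usage example above:

# The lock is counted per model key, so nested entry is allowed.
locker = cache.get_model('runwayml/stable-diffusion-v1-5')
with locker as pipe:
    with locker as pipe_again:   # locked_models[key] == 2
        ...                      # model guaranteed on the GPU here
    ...                          # still locked (count == 1), not offloaded
# count back to 0; _offload_unlocked_models() may now move it to the storage device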
@@ -381,10 +399,11 @@ class ModelCache(object):
         revisions = [revision] if revision \
             else ['fp16','main'] if self.precision==torch.float16 \
             else ['main']
-        extra_args = {'precision': self.precision} \
-            if model_class in DiffusionClasses \
-            else {}
-
+        extra_args = {'torch_dtype': self.precision,
+                      'safety_checker': None}\
+            if model_class in DiffusionClasses\
+            else {}
+
         # silence transformer and diffuser warnings
         with SilenceWarnings():
             for rev in revisions:
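For diffusers-style models the loader now forwards torch_dtype and disables the safety checker instead of passing a custom precision argument, and at half precision it still tries the fp16 revision before falling back to main. Roughly what the call inside the loop ends up doing; the loop and error handling are simplified here, and the repo id is a placeholder:

import torch
from diffusers import StableDiffusionPipeline

# Simplified view of the revision fallback for a diffusers model class.
for rev in ['fp16', 'main']:
    try:
        model = StableDiffusionPipeline.from_pretrained(
            'runwayml/stable-diffusion-v1-5',
            revision=rev,
            torch_dtype=torch.float16,   # replaces the old custom 'precision' kwarg
            safety_checker=None,         # skip loading the safety checker
        )
        break
    except OSError:
        continue   # this revision is not published for the repo; try the next one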
invokeai/backend/model_management/model_manager.py
@@ -69,7 +69,7 @@ class SDModelInfo():
     revision: str = None
     _cache: ModelCache = None
-
+
     @property
     def status(self)->ModelStatus:
         '''Return load status of this model as a model_cache.ModelStatus enum'''
         if not self._cache:
@@ -106,7 +106,7 @@ class ModelManager(object):
                  config_path: Path,
                  device_type: torch.device = CUDA_DEVICE,
                  precision: torch.dtype = torch.float16,
-                 max_models=DEFAULT_MAX_MODELS,
+                 max_loaded_models=DEFAULT_MAX_MODELS,
                  sequential_offload=False,
                  logger: types.ModuleType = logger,
                  ):
@@ -119,7 +119,7 @@ class ModelManager(object):
         self.config_path = config_path
         self.config = OmegaConf.load(self.config_path)
         self.cache = ModelCache(
-            max_models=max_models,
+            max_models=max_loaded_models,
             execution_device = device_type,
             precision = precision,
             sequential_offload = sequential_offload,
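The caller-facing constructor parameter is renamed from max_models to max_loaded_models, while the underlying ModelCache still receives it as max_models. A hedged construction example; the config path and values are placeholders and other parameters keep their defaults:

import torch
from pathlib import Path
from invokeai.backend.model_management import ModelManager

manager = ModelManager(
    config_path=Path('configs/models.yaml'),   # placeholder path
    precision=torch.float16,
    max_loaded_models=4,                       # renamed from max_models
    sequential_offload=False,
)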
@@ -164,7 +164,7 @@ class ModelManager(object):
             if mconfig.get('vae'):
                 legacy.vae_file = global_resolve_path(mconfig.vae)
         elif format=='diffusers':
-            location = mconfig.repo_id
+            location = mconfig.get('repo_id') or mconfig.get('path')
             revision = mconfig.get('revision')
         else:
            raise InvalidModelError(
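A diffusers stanza can now point at a local directory via a path key instead of a Hugging Face repo_id; the manager takes whichever is present, with repo_id winning when both exist. A sketch of the two accepted shapes, using OmegaConf the way the manager does (the key names come from the diff, the values are illustrative):

from omegaconf import OmegaConf

mconfig = OmegaConf.create({
    'format': 'diffusers',
    'repo_id': 'runwayml/stable-diffusion-v1-5',
    # 'path': '/opt/models/stable-diffusion-v1-5',   # a local checkout works too
    'revision': 'fp16',
})
location = mconfig.get('repo_id') or mconfig.get('path')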