From 27195b167280966f75463c2b1ab9ad2ea0c489a8 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 21 Jun 2024 15:36:37 -0400 Subject: [PATCH] code cleanup after @ryand review --- invokeai/app/api/routers/model_manager.py | 25 +++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index ffb7e909e5..a0c5ad8017 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -847,24 +847,37 @@ async def set_cache_size( """Set the current RAM or VRAM cache size setting (in GB). .""" cache = ApiDependencies.invoker.services.model_manager.load.ram_cache app_config = get_config() + vram_bak, ram_bak = (app_config.vram, app_config.ram) + if cache_type == CacheType.RAM: cache.max_cache_size = value app_config.ram = value elif cache_type == CacheType.VRAM: cache.max_vram_cache_size = value app_config.vram = value + else: + raise ValueError(f"Unexpected {cache_type=}.") if persist: config_path = app_config.config_file_path - print(f"DEBUG: config_path = {config_path}") + new_config_path = config_path.with_suffix(".yaml.new") + backup_config_path = config_path.with_suffix(".yaml.bak") + shutil.copy(config_path, backup_config_path) try: - shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) - app_config.write_file(config_path) + app_config.write_file(new_config_path) + shutil.move(new_config_path, config_path) except Exception as e: - shutil.move(config_path.with_suffix(".yaml.bak"), config_path) - raise RuntimeError(f"Failed to write modified configuration to {config_path}: {e}") from e + shutil.move(backup_config_path, config_path) + app_config.max_vram_cache_size = vram_bak + app_config.max_cache_size = ram_bak + raise RuntimeError(f"Failed to save configuration to {config_path}: {e}") from e - return cache.max_vram_cache_size if cache_type == CacheType.VRAM else cache.max_cache_size + if cache_type == CacheType.VRAM: + return cache.max_vram_cache_size + elif cache_type == CacheType.RAM: + return cache.max_cache_size + else: + raise ValueError(f"Unexpected {cache_type=}.") @model_manager_router.get(