Mirror of https://github.com/invoke-ai/InvokeAI.git
Improve handling of cases when application code modifies the size of a model after registering it with the model cache.
@@ -21,9 +21,14 @@ class CachedModelWithPartialLoad:
         # A CPU read-only copy of the model's state dict.
         self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()

-        # TODO(ryand): Handle the case where the model sizes changes after initial load (e.g. due to dtype casting).
-        # Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
-        self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
+        # A dictionary of the size of each tensor in the state dict.
+        # HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
+        # consistency in case the application code has modified the model's size (e.g. by casting to a different
+        # precision). Of course, this means that we are making model cache load/unload decisions based on model size
+        # data that may not be fully accurate.
+        self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
+
+        self._total_bytes = sum(self._state_dict_bytes.values())
         self._cur_vram_bytes: int | None = None

         self._modules_that_support_autocast = self._find_modules_that_support_autocast()
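The heart of the change is the __init__ hunk above: per-tensor sizes are captured once in self._state_dict_bytes at registration time, and every later byte calculation reads from that snapshot. The sketch below is a minimal illustration of the drift this guards against, using a local stand-in for calc_tensor_size (assumed here to be numel() * element_size()):

import torch


def calc_tensor_size(t: torch.Tensor) -> int:
    # Local stand-in for the repo's helper; assumed to be numel * element_size.
    return t.numel() * t.element_size()


model = torch.nn.Linear(1024, 1024, dtype=torch.float32)

# Snapshot the per-tensor sizes once, as the new self._state_dict_bytes does.
state_dict_bytes = {k: calc_tensor_size(v) for k, v in model.state_dict().items()}
total_bytes = sum(state_dict_bytes.values())  # 4_198_400 bytes for this toy model

# Application code later halves the footprint by casting to bfloat16.
model.to(dtype=torch.bfloat16)

# Recomputing from the live tensors now disagrees with the snapshot (2_099_200 bytes).
# Routing all accounting through the snapshot keeps the cache internally consistent.
live_bytes = sum(calc_tensor_size(v) for v in model.state_dict().values())
assert live_bytes != total_bytes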
@@ -79,7 +84,9 @@ class CachedModelWithPartialLoad:
         if self._cur_vram_bytes is None:
             cur_state_dict = self._model.state_dict()
             self._cur_vram_bytes = sum(
-                calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
+                self._state_dict_bytes[k]
+                for k, v in cur_state_dict.items()
+                if v.device.type == self._compute_device.type
             )
         return self._cur_vram_bytes

@@ -111,7 +118,7 @@ class CachedModelWithPartialLoad:
             if param.device.type == self._compute_device.type:
                 continue

-            param_size = calc_tensor_size(param)
+            param_size = self._state_dict_bytes[key]
             cur_state_dict[key] = param.to(self._compute_device, copy=True)
             vram_bytes_loaded += param_size

@@ -128,7 +135,7 @@ class CachedModelWithPartialLoad:
             if param.device.type == self._compute_device.type:
                 continue

-            param_size = calc_tensor_size(param)
+            param_size = self._state_dict_bytes[key]
             if vram_bytes_loaded + param_size > vram_bytes_to_load:
                 # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
                 # worth continuing to search for a smaller parameter that would fit?
@@ -149,7 +156,6 @@ class CachedModelWithPartialLoad:

         if fully_loaded:
             self._set_autocast_enabled_in_all_modules(False)
-            # TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
         else:
             self._set_autocast_enabled_in_all_modules(True)

@@ -178,7 +184,7 @@ class CachedModelWithPartialLoad:
                 continue

             cur_state_dict[key] = self._cpu_state_dict[key]
-            vram_bytes_freed += calc_tensor_size(param)
+            vram_bytes_freed += self._state_dict_bytes[key]

         if vram_bytes_freed > 0:
             self._model.load_state_dict(cur_state_dict, assign=True)
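Taken together, the hunks route every load/unload byte calculation through the same size snapshot, so cur_vram_bytes() and total_bytes() stay mutually consistent even when the live tensors no longer match their recorded sizes. Below is a condensed, hypothetical sketch of the resulting pattern, not the real class (which also handles autocast toggling and keeps a CPU read-only copy of the weights):

import torch


class SizeTrackedModel:
    """Hypothetical, stripped-down sketch of the byte-tracking pattern in the diff."""

    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
        self._model = model
        self._compute_device = compute_device
        # Per-tensor sizes are recorded once; all later accounting uses this snapshot.
        self._state_dict_bytes = {
            k: v.numel() * v.element_size() for k, v in model.state_dict().items()
        }
        self._total_bytes = sum(self._state_dict_bytes.values())

    def cur_vram_bytes(self) -> int:
        # Sum the snapshotted sizes of tensors currently on the compute device.
        return sum(
            self._state_dict_bytes[k]
            for k, v in self._model.state_dict().items()
            if v.device.type == self._compute_device.type
        )

    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
        # Move parameters to the compute device until the byte budget is exhausted.
        cur_state_dict = self._model.state_dict()
        vram_bytes_loaded = 0
        for key, param in cur_state_dict.items():
            if param.device.type == self._compute_device.type:
                continue
            param_size = self._state_dict_bytes[key]  # snapshot, not recomputed from param
            if vram_bytes_loaded + param_size > vram_bytes_to_load:
                continue
            cur_state_dict[key] = param.to(self._compute_device, copy=True)
            vram_bytes_loaded += param_size
        if vram_bytes_loaded > 0:
            # assign=True swaps the moved tensors into the module in place of the old ones.
            self._model.load_state_dict(cur_state_dict, assign=True)
        return vram_bytes_loaded

As the HACK comment in the diff acknowledges, the trade-off is that load/unload decisions may be based on stale sizes if the application really does change the model's footprint after registration.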