Improve handling of cases when application code modifies the size of a model after registering it with the model cache.

This commit is contained in:
Ryan Dick
2024-12-30 17:57:04 -05:00
parent 402dd840a1
commit 1b7bb70bde

View File

@@ -21,9 +21,14 @@ class CachedModelWithPartialLoad:
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
# TODO(ryand): Handle the case where the model size changes after initial load (e.g. due to dtype casting).
# Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
# A dictionary of the size of each tensor in the state dict.
# HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
# consistency in case the application code has modified the model's size (e.g. by casting to a different
# precision). Of course, this means that we are making model cache load/unload decisions based on model size
# data that may not be fully accurate.
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
self._total_bytes = sum(self._state_dict_bytes.values())
self._cur_vram_bytes: int | None = None
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
@@ -79,7 +84,9 @@ class CachedModelWithPartialLoad:
if self._cur_vram_bytes is None:
cur_state_dict = self._model.state_dict()
self._cur_vram_bytes = sum(
calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
self._state_dict_bytes[k]
for k, v in cur_state_dict.items()
if v.device.type == self._compute_device.type
)
return self._cur_vram_bytes
@@ -111,7 +118,7 @@ class CachedModelWithPartialLoad:
if param.device.type == self._compute_device.type:
continue
param_size = calc_tensor_size(param)
param_size = self._state_dict_bytes[key]
cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size
@@ -128,7 +135,7 @@ class CachedModelWithPartialLoad:
if param.device.type == self._compute_device.type:
continue
param_size = calc_tensor_size(param)
param_size = self._state_dict_bytes[key]
if vram_bytes_loaded + param_size > vram_bytes_to_load:
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
# worth continuing to search for a smaller parameter that would fit?
@@ -149,7 +156,6 @@ class CachedModelWithPartialLoad:
if fully_loaded:
self._set_autocast_enabled_in_all_modules(False)
# TODO(ryand): Warn if self.cur_vram_bytes() and self.total_bytes() are out of sync.
else:
self._set_autocast_enabled_in_all_modules(True)
@@ -178,7 +184,7 @@ class CachedModelWithPartialLoad:
continue
cur_state_dict[key] = self._cpu_state_dict[key]
vram_bytes_freed += calc_tensor_size(param)
vram_bytes_freed += self._state_dict_bytes[key]
if vram_bytes_freed > 0:
self._model.load_state_dict(cur_state_dict, assign=True)