Memoize frequently accessed values in CachedModelWithPartialLoad.

This commit is contained in:
Ryan Dick
2024-12-06 20:39:05 +00:00
parent b50dd8502f
commit 7a002e1b05
2 changed files with 17 additions and 6 deletions

View File

@@ -23,6 +23,9 @@ class CachedModelWithPartialLoad:
# TODO(ryand): Manage a read-only CPU copy of the model state dict.
# TODO(ryand): Add memoization for total_bytes and cur_vram_bytes?
self._total_bytes = sum(calc_tensor_size(p) for p in self._model.parameters())
self._cur_vram_bytes: int | None = None
@property
def model(self) -> torch.nn.Module:
return self._model
@@ -34,11 +37,15 @@ class CachedModelWithPartialLoad:
def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return sum(calc_tensor_size(p) for p in self._model.parameters())
return self._total_bytes
def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
return sum(calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type)
if self._cur_vram_bytes is None:
self._cur_vram_bytes = sum(
calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type
)
return self._cur_vram_bytes
def full_load_to_vram(self) -> int:
"""Load all weights into VRAM."""
@@ -77,6 +84,9 @@ class CachedModelWithPartialLoad:
self._model._apply(to_vram)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
return vram_bytes_loaded
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
@@ -101,4 +111,7 @@ class CachedModelWithPartialLoad:
self._model._apply(from_vram)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed
return vram_bytes_freed

View File

@@ -29,9 +29,6 @@ def test_cached_model_total_bytes(device: str):
linear_numel = 10 * 10 + 10
assert cached_model.total_bytes() == linear_numel * 4 * 2
cached_model.model.to(dtype=torch.float16)
assert cached_model.total_bytes() == linear_numel * 2 * 2
@parameterize_mps_and_cuda
def test_cached_model_cur_vram_bytes(device: str):
@@ -39,7 +36,8 @@ def test_cached_model_cur_vram_bytes(device: str):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
assert cached_model.cur_vram_bytes() == 0
cached_model.model.to(device=torch.device(device))
cached_model.full_load_to_vram()
assert cached_model.cur_vram_bytes() > 0
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()