Update MPS cache limit logic.

Ryan Dick
2024-12-17 23:44:17 -05:00
parent 79a4d0890f
commit 7a5dd084ad

@@ -263,11 +263,10 @@ class ModelCache:
         vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
         vram_available_to_process = vram_free + vram_reserved
     elif self._execution_device.type == "mps":
-        # TODO(ryand): Would it be better to use psutil.virtual_memory().total here? I haven't looked into the
-        # behaviors of some of these functions when multiple processes are using MPS memory.
         vram_reserved = torch.mps.driver_allocated_memory()
-        vram_total: int = torch.mps.recommended_max_memory()
-        vram_available_to_process = vram_total
+        # TODO(ryand): Is it accurate that MPS shares memory with the CPU?
+        vram_free = psutil.virtual_memory().available
+        vram_available_to_process = vram_free + vram_reserved
     else:
         raise ValueError(f"Unsupported execution device: {self._execution_device.type}")
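
The change replaces the static torch.mps.recommended_max_memory() ceiling with a dynamic estimate: because Apple-silicon GPUs use unified memory shared with the CPU, the memory available to the process is approximated as free system RAM plus the memory the Metal driver has already allocated for this process (which the cache could reclaim by releasing models). The new branch can be exercised on its own with a small script like the sketch below. This is a minimal illustration, not InvokeAI code; it assumes a recent PyTorch with MPS support and psutil installed, and the helper name mps_vram_available_to_process is hypothetical.

    # Minimal standalone sketch of the new MPS memory accounting. The helper
    # name is illustrative only; the two queries mirror the added diff lines.
    import psutil
    import torch


    def mps_vram_available_to_process() -> int:
        # Memory the Metal driver has already allocated for this process;
        # counted as reclaimable since the cache could free it.
        vram_reserved = torch.mps.driver_allocated_memory()
        # On unified-memory hardware, RAM the OS still reports as free is
        # also usable by MPS.
        vram_free = psutil.virtual_memory().available
        return vram_free + vram_reserved


    if __name__ == "__main__":
        if torch.backends.mps.is_available():
            print(f"Available to process: {mps_vram_available_to_process() / 2**30:.2f} GiB")
        else:
            print("MPS is not available on this machine.")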