mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 08:38:00 -05:00
21 lines
772 B
Python
import itertools
|
|
|
|
import torch
|
|
|
|
|
|
def get_effective_device(model: torch.nn.Module) -> torch.device:
    """Infer the device a model is 'effectively' running on.

    This is safer than ``next(iter(model.parameters())).device`` for models
    that are only partially loaded onto the GPU. The parameters of ``model``
    are scanned in order, followed by its buffers. The device of the first
    tensor that is not on the CPU is returned. If every tensor is on the CPU,
    or the model has no parameters or buffers at all, the CPU device is
    returned.

    In the worst case every parameter and buffer is inspected. If the intended
    device is already known, prefer using it over calling this function.

    Args:
        model: The module to inspect.

    Returns:
        The device of the first non-CPU parameter/buffer, otherwise
        ``torch.device("cpu")``.
    """
    tensors = itertools.chain(model.parameters(), model.buffers())
    # Stop lazily at the first tensor that lives off the CPU; fall back to CPU.
    return next(
        (t.device for t in tensors if t.device.type != "cpu"),
        torch.device("cpu"),
    )
|