Files

21 lines
772 B
Python

import itertools
import torch
def get_effective_device(model: torch.nn.Module) -> torch.device:
    """Infer the device a (possibly partially GPU-resident) model is effectively on.

    Unlike ``next(iter(model.parameters())).device``, which only inspects the first
    parameter, this scans parameters and buffers and reports the first tensor that is
    not on the CPU. That makes it safe to use on models that are only partially loaded
    onto an accelerator.

    In the worst case every parameter and buffer is visited, so when the caller already
    knows the intended device it is better to use that instead of calling this function.

    Args:
        model: The module to inspect.

    Returns:
        The device of the first parameter or buffer (parameters are checked before
        buffers) whose device type is not ``"cpu"``. Returns ``torch.device("cpu")`` when
        every tensor is on the CPU, or when the model has no parameters or buffers.
    """
    # Lazily walk parameters, then buffers; stop at the first tensor that lives off-CPU.
    all_tensors = itertools.chain(model.parameters(), model.buffers())
    non_cpu_devices = (t.device for t in all_tensors if t.device.type != "cpu")
    return next(non_cpu_devices, torch.device("cpu"))