import torch def calc_tensor_size(t: torch.Tensor) -> int: """Calculate the size of a tensor in bytes.""" return t.nelement() * t.element_size() def calc_tensors_size(tensors: list[torch.Tensor | None]) -> int: """Calculate the size of a list of tensors in bytes.""" return sum(calc_tensor_size(t) for t in tensors if t is not None)