Compare commits

...

4 Commits

Author     SHA1        Message                                                                Date
Ryan Dick  85661b9f78  WIP - LoRA sidecar layers.                                             2024-09-11 13:33:32 +00:00
Ryan Dick  0aea4e6800  WIP - adding LoRA sidecar layers                                       2024-09-10 21:45:18 +00:00
Ryan Dick  e21a66e7bf  WIP - tidy LoRA layer initialization code.                             2024-09-10 16:33:33 +00:00
Ryan Dick  c69e272fb3  Add util functions calc_tensor_size(...) and calc_tensors_size(...).  2024-09-10 16:02:21 +00:00
18 changed files with 457 additions and 69 deletions

View File

@@ -3,6 +3,7 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size
 
 
 class FullLayer(LoRALayerBase):
@@ -26,9 +27,7 @@ class FullLayer(LoRALayerBase):
         return self.weight
 
     def calc_size(self) -> int:
-        model_size = super().calc_size()
-        model_size += self.weight.nelement() * self.weight.element_size()
-        return model_size
+        return calc_tensor_size(self.weight) + super().calc_size()
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

View File

@@ -3,6 +3,7 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
 class IA3Layer(LoRALayerBase):
@@ -30,8 +31,7 @@ class IA3Layer(LoRALayerBase):
     def calc_size(self) -> int:
         model_size = super().calc_size()
-        model_size += self.weight.nelement() * self.weight.element_size()
-        model_size += self.on_input.nelement() * self.on_input.element_size()
+        model_size += calc_tensors_size([self.weight, self.on_input])
         return model_size
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):

View File

@@ -3,6 +3,7 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
 class LoHALayer(LoRALayerBase):
@@ -49,9 +50,7 @@ class LoHALayer(LoRALayerBase):
     def calc_size(self) -> int:
         model_size = super().calc_size()
-        for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
+        model_size += calc_tensors_size([self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2])
         return model_size
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:

View File

@@ -3,6 +3,7 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
 class LoKRLayer(LoRALayerBase):
@@ -85,9 +86,7 @@ class LoKRLayer(LoRALayerBase):
     def calc_size(self) -> int:
         model_size = super().calc_size()
-        for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
+        model_size += calc_tensors_size([self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2])
        return model_size
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:

View File

@@ -3,34 +3,61 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
-    # up: torch.Tensor
-    # mid: Optional[torch.Tensor]
-    # down: torch.Tensor
+    def __init__(
+        self,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        mid: Optional[torch.Tensor],
+        alpha: float | None,
+        bias: torch.Tensor | None,
+    ):
+        super().__init__(alpha=alpha, bias=bias)
+        self.up = up
+        self.down = down
+        self.mid = mid
 
-    def __init__(
-        self,
+    @classmethod
+    def from_state_dict_values(
+        cls,
         values: Dict[str, torch.Tensor],
     ):
-        super().__init__(values)
+        alpha = cls._parse_alpha(values.get("alpha", None))
+        bias = cls._parse_bias(
+            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+        )
 
-        self.up = values["lora_up.weight"]
-        self.down = values["lora_down.weight"]
-        self.mid = values.get("lora_mid.weight", None)
+        layer = cls(
+            up=values["lora_up.weight"],
+            down=values["lora_down.weight"],
+            mid=values.get("lora_mid.weight", None),
+            alpha=alpha,
+            bias=bias,
+        )
 
-        self.rank = self.down.shape[0]
-        self.check_keys(
+        cls.warn_on_unhandled_keys(
             values,
             {
+                # Default keys.
+                "alpha",
+                "bias_indices",
+                "bias_values",
+                "bias_size",
+                # Layer-specific keys.
                 "lora_up.weight",
                 "lora_down.weight",
                 "lora_mid.weight",
             },
         )
 
+        return layer
+
+    @property
+    def rank(self) -> int:
+        return self.down.shape[0]
+
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
@@ -43,9 +70,7 @@ class LoRALayer(LoRALayerBase):
     def calc_size(self) -> int:
         model_size = super().calc_size()
-        for val in [self.up, self.mid, self.down]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
+        model_size += calc_tensors_size([self.up, self.mid, self.down])
         return model_size
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:

View File

@@ -3,37 +3,37 @@ from typing import Dict, Optional, Set
 import torch
 
 import invokeai.backend.util.logging as logger
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
 class LoRALayerBase:
-    # rank: Optional[int]
-    # alpha: Optional[float]
-    # bias: Optional[torch.Tensor]
+    """Base class for all LoRA-like patching layers."""
 
-    # @property
-    # def scale(self):
-    #     return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
+    def __init__(self, alpha: float | None, bias: torch.Tensor | None):
+        self.alpha = alpha
+        self.bias = bias
 
-    def __init__(
-        self,
-        values: Dict[str, torch.Tensor],
-    ):
-        if "alpha" in values:
-            self.alpha = values["alpha"].item()
-        else:
-            self.alpha = None
+    @classmethod
+    def _parse_bias(
+        cls, bias_indices: torch.Tensor | None, bias_values: torch.Tensor | None, bias_size: torch.Tensor | None
+    ) -> torch.Tensor | None:
+        assert (bias_indices is None) == (bias_values is None) == (bias_size is None)
 
-        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
-            self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
-                values["bias_indices"],
-                values["bias_values"],
-                tuple(values["bias_size"]),
-            )
-        else:
-            self.bias = None
+        bias = None
+        if bias_indices is not None:
+            bias = torch.sparse_coo_tensor(bias_indices, bias_values, tuple(bias_size))
+        return bias
 
-        self.rank = None  # set in layer implementation
+    @classmethod
+    def _parse_alpha(
+        cls,
+        alpha: torch.Tensor | None,
+    ) -> float | None:
+        return alpha.item() if alpha is not None else None
+
+    @property
+    def rank(self) -> int | None:
+        raise NotImplementedError()
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()
@@ -49,23 +49,17 @@ class LoRALayerBase:
         return params
 
     def calc_size(self) -> int:
-        model_size = 0
-        for val in [self.bias]:
-            if val is not None:
-                model_size += val.nelement() * val.element_size()
-        return model_size
+        return calc_tensors_size([self.bias])
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)
 
-    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+    @classmethod
+    def warn_on_unhandled_keys(cls, values: Dict[str, torch.Tensor], handled_keys: Set[str]):
         """Log a warning if values contains unhandled keys."""
-        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
-        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
-        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
-        unknown_keys = set(values.keys()) - all_known_keys
+        unknown_keys = set(values.keys()) - handled_keys
         if unknown_keys:
             logger.warning(
-                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Unexpected keys: {unknown_keys}"
             )

View File

@@ -3,6 +3,7 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size
 
 
 class NormLayer(LoRALayerBase):
@@ -27,7 +28,7 @@ class NormLayer(LoRALayerBase):
     def calc_size(self) -> int:
         model_size = super().calc_size()
-        model_size += self.weight.nelement() * self.weight.element_size()
+        model_size += calc_tensor_size(self.weight)
         return model_size
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:

View File

@@ -3,7 +3,16 @@ from typing import Dict, Iterable, Optional, Tuple
 import torch
 
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
+from invokeai.backend.lora.layers.lora_layer import LoRALayer
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.sidecar_layers.lora.lora_conv_sidecar_layer import (
+    LoRAConv1dSidecarLayer,
+    LoRAConv2dSidecarLayer,
+    LoRAConv3dSidecarLayer,
+)
+from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
+from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
@@ -110,6 +119,113 @@ class LoraPatcher:
        layer.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    @torch.no_grad()
    @contextmanager
    def apply_lora_sidecar_patches(
        model: torch.nn.Module,
        patches: Iterable[Tuple[LoRAModelRaw, float]],
        prefix: str,
    ):
        original_modules: dict[str, torch.nn.Module] = {}
        try:
            for patch, patch_weight in patches:
                LoraPatcher._apply_lora_sidecar_patch(
                    model=model,
                    prefix=prefix,
                    patch=patch,
                    patch_weight=patch_weight,
                    original_modules=original_modules,
                )
            yield
        finally:
            # Restore original modules.
            # Note: This logic assumes no nested modules in original_modules.
            for module_key, orig_module in original_modules.items():
                module_parent_key, module_name = module_key.rsplit(".", 1)
                parent_module = model.get_submodule(module_parent_key)
                LoraPatcher._set_submodule(parent_module, module_name, orig_module)

    @staticmethod
    def _apply_lora_sidecar_patch(
        model: torch.nn.Module,
        patch: LoRAModelRaw,
        patch_weight: float,
        prefix: str,
        original_modules: dict[str, torch.nn.Module],
    ):
        if patch_weight == 0:
            return

        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
        # without searching, but some legacy code still uses flattened keys.
        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

        prefix_len = len(prefix)

        for layer_key, layer in patch.layers.items():
            if not layer_key.startswith(prefix):
                continue

            module_key, module = LoraPatcher._get_submodule(
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )

            # Initialize the LoRA sidecar layer.
            lora_sidecar_layer = LoraPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)

            # TODO(ryand): Should we move the LoRA sidecar layer to the same device/dtype as the orig module?
            if module_key in original_modules:
                # The module has already been patched with a LoRASidecarModule. Append to it.
                assert isinstance(module, LoRASidecarModule)
                module.add_lora_layer(lora_sidecar_layer)
            else:
                # The module has not yet been patched with a LoRASidecarModule. Create one.
                lora_sidecar_module = LoRASidecarModule(module, [lora_sidecar_layer])
                original_modules[module_key] = module
                module_parent_key, module_name = module_key.rsplit(".", 1)
                module_parent = model.get_submodule(module_parent_key)
                LoraPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)

    @staticmethod
    def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
        if isinstance(orig_layer, torch.nn.Linear):
            if isinstance(lora_layer, LoRALayer):
                return LoRALinearSidecarLayer.from_layers(orig_layer, lora_layer, patch_weight)
            else:
                raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
        elif isinstance(orig_layer, torch.nn.Conv1d):
            if isinstance(lora_layer, LoRALayer):
                return LoRAConv1dSidecarLayer.from_layers(orig_layer, lora_layer, patch_weight)
            else:
                raise ValueError(f"Unsupported Conv1D LoRA layer type: {type(lora_layer)}")
        elif isinstance(orig_layer, torch.nn.Conv2d):
            if isinstance(lora_layer, LoRALayer):
                return LoRAConv2dSidecarLayer.from_layers(orig_layer, lora_layer, patch_weight)
            else:
                raise ValueError(f"Unsupported Conv2D LoRA layer type: {type(lora_layer)}")
        elif isinstance(orig_layer, torch.nn.Conv3d):
            if isinstance(lora_layer, LoRALayer):
                return LoRAConv3dSidecarLayer.from_layers(orig_layer, lora_layer, patch_weight)
            else:
                raise ValueError(f"Unsupported Conv3D LoRA layer type: {type(lora_layer)}")
        else:
            raise ValueError(f"Unsupported layer type: {type(orig_layer)}")

    @staticmethod
    def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
        try:
            submodule_index = int(module_name)
            # If the module name is an integer, then we use the __setitem__ method to set the submodule.
            parent_module[submodule_index] = submodule
        except ValueError:
            # If the module name is not an integer, then we use the setattr method to set the submodule.
            setattr(parent_module, module_name, submodule)

    @staticmethod
    def _get_submodule(
        model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
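
For orientation, a minimal usage sketch of the new context manager above. The model, LoRA object, and key prefix are placeholders, and the LoraPatcher import path is assumed (it is not shown in this diff):

    import torch

    from invokeai.backend.lora.lora_patcher import LoraPatcher  # assumed import path


    def run_with_sidecar_lora(model: torch.nn.Module, lora_model, lora_weight: float, x: torch.Tensor):
        # `lora_model` is a LoRAModelRaw loaded elsewhere; "lora_unet_" is a placeholder key prefix.
        with LoraPatcher.apply_lora_sidecar_patches(
            model=model,
            patches=[(lora_model, lora_weight)],
            prefix="lora_unet_",
        ):
            # While the context is active, matching submodules are replaced with LoRASidecarModule wrappers.
            output = model(x)
        # On exiting the context, the original submodules are restored.
        return output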

View File

@@ -0,0 +1,135 @@
import typing

import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer


class LoRAConvSidecarLayer(torch.nn.Module):
    """An implementation of a conv LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
    (https://arxiv.org/pdf/2106.09685.pdf)
    """

    @property
    def conv_module(self) -> type[torch.nn.Conv1d | torch.nn.Conv2d | torch.nn.Conv3d]:
        """The conv module to be set by child classes. One of torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d."""
        raise NotImplementedError(
            "LoRAConvLayer cannot be used directly. Use LoRAConv1dLayer, LoRAConv2dLayer, or LoRAConv3dLayer instead."
        )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        include_mid: bool,
        rank: int,
        alpha: float,
        weight: float,
        kernel_size: typing.Union[int, tuple[int]] = 1,
        stride: typing.Union[int, tuple[int]] = 1,
        padding: typing.Union[str, int, tuple[int]] = 0,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()

        if rank > min(in_channels, out_channels):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_channels, out_channels)}")

        self._down = self.conv_module(
            in_channels,
            rank,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self._up = self.conv_module(rank, out_channels, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)
        self._mid = None
        if include_mid:
            self._mid = self.conv_module(rank, rank, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)

        # Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
        self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))

        self._weight = weight
        self._rank = rank

    @classmethod
    def from_layers(cls, orig_layer: torch.nn.Module, lora_layer: LoRALayer, weight: float):
        # Initialize the LoRA layer.
        with torch.device("meta"):
            model = cls.from_orig_layer(
                orig_layer,
                include_mid=lora_layer.mid is not None,
                rank=lora_layer.rank,
                # TODO(ryand): Is this the right default in case of missing alpha?
                alpha=lora_layer.alpha if lora_layer.alpha is not None else lora_layer.rank,
                weight=weight,
            )

        # Inject weight into the LoRA layer.
        model._up.weight.data = lora_layer.up
        model._down.weight.data = lora_layer.down
        if lora_layer.mid is not None:
            assert model._mid is not None
            model._mid.weight.data = lora_layer.mid

        return model

    @classmethod
    def from_orig_layer(
        cls,
        layer: torch.nn.Module,
        include_mid: bool,
        rank: int,
        alpha: float,
        weight: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        if not isinstance(layer, cls.conv_module):
            raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")

        return cls(
            in_channels=layer.in_channels,
            out_channels=layer.out_channels,
            include_mid=include_mid,
            weight=weight,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            rank=rank,
            alpha=alpha,
            device=layer.weight.device if device is None else device,
            dtype=layer.weight.dtype if dtype is None else dtype,
        )

    def forward(self, x: torch.Tensor):
        x = self._down(x)
        if self._mid is not None:
            x = self._mid(x)
        x = self._up(x)
        x *= self._weight * self.alpha / self._rank
        return x


class LoRAConv1dSidecarLayer(LoRAConvSidecarLayer):
    @property
    def conv_module(self):
        return torch.nn.Conv1d


class LoRAConv2dSidecarLayer(LoRAConvSidecarLayer):
    @property
    def conv_module(self):
        return torch.nn.Conv2d


class LoRAConv3dSidecarLayer(LoRAConvSidecarLayer):
    @property
    def conv_module(self):
        return torch.nn.Conv3d
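
A small self-contained sketch of how one of these conv sidecar layers lines up with the layer it shadows; the shapes and hyperparameters below are arbitrary illustration values, not taken from the diff:

    import torch

    from invokeai.backend.lora.sidecar_layers.lora.lora_conv_sidecar_layer import LoRAConv2dSidecarLayer

    # Hypothetical example: a rank-4 Conv2d sidecar constructed directly, mirroring a 16->32 Conv2d
    # with kernel_size=3 and padding=1, so its output can be summed with the original conv's output.
    conv = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
    sidecar = LoRAConv2dSidecarLayer(
        in_channels=16, out_channels=32, include_mid=False, rank=4, alpha=4.0, weight=1.0,
        kernel_size=3, padding=1,
    )

    x = torch.randn(1, 16, 8, 8)
    assert sidecar(x).shape == conv(x).shape  # both are (1, 32, 8, 8)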

View File

@@ -0,0 +1,95 @@
import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer


class LoRALinearSidecarLayer(torch.nn.Module):
    """An implementation of a linear LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
    (https://arxiv.org/pdf/2106.09685.pdf)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        include_mid: bool,
        rank: int,
        alpha: float,
        weight: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self._down = torch.nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self._up = torch.nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        self._mid = None
        if include_mid:
            self._mid = torch.nn.Linear(rank, rank, bias=False, device=device, dtype=dtype)

        # Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
        self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))

        self._weight = weight
        self._rank = rank

    @classmethod
    def from_layers(cls, orig_layer: torch.nn.Module, lora_layer: LoRALayer, weight: float):
        # Initialize the LoRA layer.
        with torch.device("meta"):
            model = cls.from_orig_layer(
                orig_layer,
                include_mid=lora_layer.mid is not None,
                rank=lora_layer.rank,
                # TODO(ryand): Is this the right default in case of missing alpha?
                alpha=lora_layer.alpha if lora_layer.alpha is not None else lora_layer.rank,
                weight=weight,
            )

        # TODO(ryand): Are there cases where we need to reshape the weight matrices to match the conv layers?
        # Inject weight into the LoRA layer.
        model._up.weight.data = lora_layer.up
        model._down.weight.data = lora_layer.down
        if lora_layer.mid is not None:
            assert model._mid is not None
            model._mid.weight.data = lora_layer.mid

        return model

    @classmethod
    def from_orig_layer(
        cls,
        layer: torch.nn.Module,
        include_mid: bool,
        rank: int,
        alpha: float,
        weight: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        if not isinstance(layer, torch.nn.Linear):
            raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")

        return cls(
            in_features=layer.in_features,
            out_features=layer.out_features,
            include_mid=include_mid,
            rank=rank,
            alpha=alpha,
            weight=weight,
            device=layer.weight.device if device is None else device,
            dtype=layer.weight.dtype if dtype is None else dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._down(x)
        if self._mid is not None:
            x = self._mid(x)
        x = self._up(x)
        x *= self._weight * self.alpha / self._rank
        return x
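
A short sketch of the linear variant; the sizes and rank below are arbitrary illustration values, not taken from the diff:

    import torch

    from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer

    # Hypothetical example: wrap a 64->32 Linear with a rank-8 sidecar built from the original layer.
    linear = torch.nn.Linear(64, 32)
    sidecar = LoRALinearSidecarLayer.from_orig_layer(linear, include_mid=False, rank=8, alpha=8.0, weight=1.0)

    x = torch.randn(2, 64)
    delta = sidecar(x)                     # low-rank update: up(down(x)) * (weight * alpha / rank)
    assert delta.shape == linear(x).shape  # (2, 32), so it can be summed with the original output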

View File

@@ -0,0 +1,17 @@
import torch


class LoRASidecarModule(torch.nn.Module):
    def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
        super().__init__()
        self._orig_module = orig_module
        self._lora_layers = lora_layers

    def add_lora_layer(self, lora_layer: torch.nn.Module):
        self._lora_layers.append(lora_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add each LoRA layer's contribution, computed from the original input, to the original module's output.
        orig_output = self._orig_module(x)
        for lora_layer in self._lora_layers:
            orig_output = orig_output + lora_layer(x)
        return orig_output
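
A brief sketch of how the sidecar module combines the original layer with its LoRA layers, assuming the forward shown above; the layer sizes are arbitrary illustration values:

    import torch

    from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
    from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule

    linear = torch.nn.Linear(16, 16)
    sidecar_layer = LoRALinearSidecarLayer.from_orig_layer(linear, include_mid=False, rank=4, alpha=4.0, weight=1.0)
    patched = LoRASidecarModule(linear, [sidecar_layer])

    x = torch.randn(1, 16)
    # The patched module returns the original output plus each LoRA layer's delta for the same input.
    assert torch.allclose(patched(x), linear(x) + sidecar_layer(x))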

View File

@@ -20,6 +20,7 @@ from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw
+from invokeai.backend.util.calc_tensor_size import calc_tensor_size
 
 
 def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
@@ -83,10 +84,9 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
 
 def calc_module_size(model: torch.nn.Module) -> int:
     """Calculate the size (in bytes) of a torch.nn.Module."""
-    mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
-    mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
-    mem: int = mem_params + mem_bufs  # in bytes
-    return mem
+    mem_params = sum([calc_tensor_size(param) for param in model.parameters()])
+    mem_bufs = sum([calc_tensor_size(buf) for buf in model.buffers()])
+    return mem_params + mem_bufs
 
 
 def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:

View File

@@ -10,6 +10,7 @@ from transformers import CLIPTokenizer
 from typing_extensions import Self
 
 from invokeai.backend.raw_model import RawModel
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
 class TextualInversionModelRaw(RawModel):
@@ -74,11 +75,7 @@ class TextualInversionModelRaw(RawModel):
 
     def calc_size(self) -> int:
         """Get the size of this model in bytes."""
-        embedding_size = self.embedding.element_size() * self.embedding.nelement()
-        embedding_2_size = 0
-        if self.embedding_2 is not None:
-            embedding_2_size = self.embedding_2.element_size() * self.embedding_2.nelement()
-        return embedding_size + embedding_2_size
+        return calc_tensors_size([self.embedding, self.embedding_2])
 
 
 class TextualInversionManager(BaseTextualInversionManager):

View File

@@ -0,0 +1,11 @@
import torch


def calc_tensor_size(t: torch.Tensor) -> int:
    """Calculate the size of a tensor in bytes."""
    return t.nelement() * t.element_size()


def calc_tensors_size(tensors: list[torch.Tensor | None]) -> int:
    """Calculate the size of a list of tensors in bytes."""
    return sum(calc_tensor_size(t) for t in tensors if t is not None)
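
For illustration, the byte accounting these helpers implement (the tensor sizes below are arbitrary):

    import torch

    from invokeai.backend.util.calc_tensor_size import calc_tensor_size, calc_tensors_size

    t = torch.zeros(1024, 1024, dtype=torch.float16)  # 1M elements * 2 bytes each
    assert calc_tensor_size(t) == 2 * 1024 * 1024

    # None entries are skipped, which is what lets the layers above pass optional tensors directly.
    assert calc_tensors_size([t, None, t]) == 4 * 1024 * 1024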