from typing import Dict, Optional, Set

import torch

import invokeai.backend.util.logging as logger


class LoRALayerBase:
    """Base class for LoRA/LyCORIS layer implementations."""

    # rank: Optional[int]
    # alpha: Optional[float]
    # bias: Optional[torch.Tensor]
    # layer_key: str

    # @property
    # def scale(self):
    #     return self.alpha / self.rank if (self.alpha and self.rank) else 1.0

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        if "alpha" in values:
            self.alpha = values["alpha"].item()
        else:
            self.alpha = None

        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
            # Reconstruct the optional bias as a sparse COO tensor from its indices, values, and shape.
            self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
                values["bias_indices"],
                values["bias_values"],
                tuple(values["bias_size"]),
            )
        else:
            self.bias = None

        self.rank = None  # set in layer implementation
        self.layer_key = layer_key

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
        return self.bias

    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
        params = {"weight": self.get_weight(orig_module.weight)}
        bias = self.get_bias(orig_module.bias)
        if bias is not None:
            params["bias"] = bias
        return params

    def calc_size(self) -> int:
        model_size = 0
        for val in [self.bias]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        if self.bias is not None:
            self.bias = self.bias.to(device=device, dtype=dtype)

    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
        """Log a warning if `values` contains unhandled keys."""
        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handle.
        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
        unknown_keys = set(values.keys()) - all_known_keys
        if unknown_keys:
            logger.warning(
                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
            )
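

# A minimal sketch (not part of this module) of how a concrete layer might build on `LoRALayerBase`.
# The class name `ExampleDiffLayer` and the "diff" state-dict key are illustrative assumptions, not the
# library's API; the sketch only shows the contract visible above: call `super().__init__()`, set
# `self.rank`, implement `get_weight()`, extend `to()`/`calc_size()` for any extra tensors, and report
# the keys the sub-class consumed via `check_keys()`.
#
# class ExampleDiffLayer(LoRALayerBase):
#     def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
#         super().__init__(layer_key, values)
#         self.weight = values["diff"]
#         self.rank = None  # full weight diff, no low-rank factorization
#         self.check_keys(values, {"diff"})
#
#     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
#         return self.weight
#
#     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
#         super().to(device=device, dtype=dtype)
#         self.weight = self.weight.to(device=device, dtype=dtype)
#
#     def calc_size(self) -> int:
#         return super().calc_size() + self.weight.nelement() * self.weight.element_size()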