mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-03 01:55:07 -05:00
26 lines
972 B
Python
26 lines
972 B
Python
import torch
|
|
|
|
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
|
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
|
|
|
|
|
class SetParameterLayer(BaseLayerPatch):
    """Patch layer that overrides one named parameter of a module with a fixed target tensor.

    The override is expressed as a delta: when parameters are requested, the difference
    between the stored target tensor and the module's current value of that parameter is
    computed, scaled, and returned.
    """

    def __init__(self, param_name: str, weight: torch.Tensor):
        super().__init__()
        # Fully-qualified name of the parameter to override; resolved on the original module
        # via `torch.nn.Module.get_parameter`.
        self.param_name = param_name
        # Target value that the parameter should take.
        self.weight = weight

    def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
        """Compute the scaled delta between the stored target and the module's current parameter.

        Args:
            orig_module: Module that owns the parameter named by `self.param_name`.
            weight: Scale factor applied to the delta.

        Returns:
            A single-entry dict mapping `self.param_name` to `(target - current) * weight`.
            Adding this delta to the current value with `weight == 1.0` reproduces the
            target exactly.
        """
        current_value = orig_module.get_parameter(self.param_name)
        delta = self.weight - current_value
        scaled_delta = delta * weight
        return {self.param_name: scaled_delta}

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        """Move and/or cast the stored target tensor.

        Passing `None` for `device` or `dtype` leaves that attribute unchanged.
        """
        relocated = self.weight.to(device=device, dtype=dtype)
        self.weight = relocated

    def calc_size(self) -> int:
        """Return the size of the stored target tensor as reported by `calc_tensor_size`.

        NOTE(review): presumably measured in bytes. Confirm against `calc_tensor_size`.
        """
        target = self.weight
        return calc_tensor_size(target)
|