Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-04-23 03:00:31 -04:00
Rename LoRAModelRaw to ModelPatchRaw.
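This commit is a mechanical rename: the class LoRAModelRaw in invokeai.backend.patches.lora_model_raw becomes ModelPatchRaw in invokeai.backend.patches.model_patch_raw, and every import, type annotation, isinstance check, and constructor call in the hunks below is updated to match. A minimal before/after sketch of what a caller changes (my_layers is a hypothetical stand-in for a mapping of layer keys to layer patches):

my_layers = {}  # hypothetical: would map layer key strings to BaseLayerPatch instances

# Before this commit:
# from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
# patch = LoRAModelRaw(layers=my_layers)

# After this commit:
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

patch = ModelPatchRaw(layers=my_layers)  # same constructor signature, more general name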
@@ -20,7 +20,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -66,10 +66,10 @@ class CompelInvocation(BaseInvocation):
         tokenizer_info = context.models.load(self.clip.tokenizer)
         text_encoder_info = context.models.load(self.clip.text_encoder)

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.clip.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
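Every invocation touched by this diff repeats the same _lora_loader shape: a generator that loads one patch at a time and yields (patch, weight) tuples, dropping its reference after each yield so the model cache can evict the patch. A self-contained sketch of the pattern, where ModelPatchRaw is re-declared as a stand-in and load_patch and PATCH_SPECS are hypothetical substitutes for context.models.load and a field's loras list:

from typing import Iterator, Mapping, Tuple

class ModelPatchRaw:
    # Stand-in mirroring the renamed class: a thin container of layer patches.
    def __init__(self, layers: Mapping[str, object]):
        self.layers = layers

# Hypothetical (name, weight) pairs; the real code reads these from self.clip.loras etc.
PATCH_SPECS = [("patch_a", 0.75), ("patch_b", 0.5)]

def load_patch(name: str) -> ModelPatchRaw:
    return ModelPatchRaw(layers={})  # a real loader would build layer patches from a state dict

def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
    for name, weight in PATCH_SPECS:
        patch = load_patch(name)  # load lazily, one patch at a time
        yield (patch, weight)     # the consumer applies the patch before the next load
        del patch                 # drop the reference promptly, as the real loaders do

for patch, weight in _lora_loader():
    print(type(patch).__name__, weight)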
@@ -162,11 +162,11 @@ class SDXLPromptInvocationBase:
             c_pooled = None
             return c, c_pooled

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in clip_field.loras:
                 lora_info = context.models.load(lora.lora)
                 lora_model = lora_info.model
-                assert isinstance(lora_model, LoRAModelRaw)
+                assert isinstance(lora_model, ModelPatchRaw)
                 yield (lora_model, lora.weight)
                 del lora_info
             return
@@ -39,7 +39,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
@@ -987,10 +987,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, unet_config.base)

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.unet.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
@@ -49,7 +49,7 @@ from invokeai.backend.flux.sampling_utils import (
 from invokeai.backend.flux.text_conditioning import FluxTextConditioning
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
@@ -715,7 +715,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

         return pos_ip_adapter_extensions, neg_ip_adapter_extensions

-    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
         loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
         if self.control_lora:
             # Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
@@ -723,7 +723,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             loras.append(self.control_lora)
         for lora in loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, LoRAModelRaw)
+            assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info

@@ -19,7 +19,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo

@@ -130,9 +130,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return pooled_prompt_embeds

-    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+    def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
         for lora in self.clip.loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, LoRAModelRaw)
+            assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
@@ -18,7 +18,7 @@ from invokeai.app.invocations.primitives import SD3ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo

@@ -193,9 +193,9 @@ class Sd3TextEncoderInvocation(BaseInvocation):

     def _clip_lora_iterator(
         self, context: InvocationContext, clip_model: CLIPField
-    ) -> Iterator[Tuple[LoRAModelRaw, float]]:
+    ) -> Iterator[Tuple[ModelPatchRaw, float]]:
         for lora in clip_model.loras:
             lora_info = context.models.load(lora.lora)
-            assert isinstance(lora_info.model, LoRAModelRaw)
+            assert isinstance(lora_info.model, ModelPatchRaw)
             yield (lora_info.model, lora.weight)
             del lora_info
@@ -22,7 +22,7 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
 from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
@@ -194,10 +194,10 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
             context.util.sd_step_callback(state, unet_config.base)

         # Prepare an iterator that yields the UNet's LoRA models and their weights.
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
             for lora in self.unet.loras:
                 lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
+                assert isinstance(lora_info.model, ModelPatchRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info

@@ -17,7 +17,7 @@ from invokeai.backend.image_util.segment_anything.segment_anything_pipeline impo
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -43,7 +43,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
         (
             TextualInversionModelRaw,
             IPAdapter,
-            LoRAModelRaw,
+            ModelPatchRaw,
             SpandrelImageToImageModel,
             GroundingDinoPipeline,
             SegmentAnythingPipeline,
@@ -8,7 +8,7 @@ from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlL
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

 # A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
 # Example keys:
@@ -43,7 +43,7 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
     )


-def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
+def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
     # Group keys by layer.
     grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
     for key, value in state_dict.items():
@@ -81,4 +81,4 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
         else:
             raise ValueError(f"{layer_key} not expected")

-    return LoRAModelRaw(layers=layers)
+    return ModelPatchRaw(layers=layers)
@@ -6,7 +6,7 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw


 def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
@@ -30,7 +30,9 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
     return all_keys_in_peft_format and all_expected_keys_present


-def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> LoRAModelRaw:
+def lora_model_from_flux_diffusers_state_dict(
+    state_dict: Dict[str, torch.Tensor], alpha: float | None
+) -> ModelPatchRaw:
     """Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.

     This function is based on:
@@ -215,7 +217,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor

     layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}

-    return LoRAModelRaw(layers=layers_with_prefix)
+    return ModelPatchRaw(layers=layers_with_prefix)


 def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
@@ -9,7 +9,7 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
     FLUX_LORA_CLIP_PREFIX,
     FLUX_LORA_TRANSFORMER_PREFIX,
 )
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

 # A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
 # Example keys:
@@ -39,7 +39,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
     )


-def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
+def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
     # Group keys by layer.
     grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
     for key, value in state_dict.items():
@@ -71,7 +71,7 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
         layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

     # Create and return the LoRAModelRaw.
-    return LoRAModelRaw(layers=layers)
+    return ModelPatchRaw(layers=layers)


 T = TypeVar("T")
@@ -4,17 +4,17 @@ import torch

 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw


-def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
+def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
     grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)

     layers: dict[str, BaseLayerPatch] = {}
     for layer_key, values in grouped_state_dict.items():
         layers[layer_key] = any_lora_layer_from_state_dict(values)

-    return LoRAModelRaw(layers=layers)
+    return ModelPatchRaw(layers=layers)


 def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
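The conversion helpers above all follow one flow: group a checkpoint's state-dict keys by layer, wrap each group in a layer patch via any_lora_layer_from_state_dict, and return the result as a ModelPatchRaw. A hedged usage sketch for the Stable Diffusion variant; the safetensors loading step, the file name, and the exact module path are assumptions inferred from the imports in this diff rather than shown by it:

import torch
from safetensors.torch import load_file

from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict

# Assumption: a LoRA checkpoint whose keys are in the SD format this helper expects.
state_dict: dict[str, torch.Tensor] = load_file("my_sd_lora.safetensors")

patch = lora_model_from_sd_state_dict(state_dict)  # returns a ModelPatchRaw after this commit
print(f"{len(patch.layers)} patched layers")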
@@ -7,7 +7,7 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.raw_model import RawModel


-class LoRAModelRaw(RawModel):
+class ModelPatchRaw(RawModel):
     def __init__(self, layers: Mapping[str, BaseLayerPatch]):
         self.layers = layers

@@ -5,7 +5,7 @@ import torch

 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
 from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
 from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
@@ -19,7 +19,7 @@ class ModelPatcher:
     @contextmanager
     def apply_model_patches(
         model: torch.nn.Module,
-        patches: Iterable[Tuple[LoRAModelRaw, float]],
+        patches: Iterable[Tuple[ModelPatchRaw, float]],
         prefix: str,
         cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ):
@@ -57,7 +57,7 @@ class ModelPatcher:
     def apply_model_patch(
         model: torch.nn.Module,
         prefix: str,
-        patch: LoRAModelRaw,
+        patch: ModelPatchRaw,
         patch_weight: float,
         original_weights: OriginalWeightsStorage,
     ):
@@ -148,7 +148,7 @@ class ModelPatcher:
     @contextmanager
     def apply_model_sidecar_patches(
         model: torch.nn.Module,
-        patches: Iterable[Tuple[LoRAModelRaw, float]],
+        patches: Iterable[Tuple[ModelPatchRaw, float]],
         prefix: str,
         dtype: torch.dtype,
     ):
@@ -189,7 +189,7 @@ class ModelPatcher:
     @staticmethod
     def _apply_model_sidecar_patch(
         model: torch.nn.Module,
-        patch: LoRAModelRaw,
+        patch: ModelPatchRaw,
         patch_weight: float,
         prefix: str,
         original_modules: dict[str, torch.nn.Module],
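ModelPatcher is the consumer of the (ModelPatchRaw, weight) tuples produced by the loaders earlier in this diff, and apply_model_patches is a context manager: weights are patched on entry and restored on exit. A minimal call-shape sketch, assuming apply_model_patches is callable on the class the way the LoRAExt hunk below calls apply_model_patch; the run_with_patches wrapper and the reuse of the prefix value are illustrative:

from typing import Iterable, Tuple

import torch

from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher

def run_with_patches(unet: torch.nn.Module, patches: Iterable[Tuple[ModelPatchRaw, float]]) -> None:
    # patches: e.g. one of the _lora_loader() generators from the invocation hunks above.
    with ModelPatcher.apply_model_patches(
        model=unet,
        patches=patches,
        prefix="lora_unet_",  # prefix value borrowed from the LoRAExt hunk below
    ):
        pass  # run inference here while the patch weights are applied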
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING

 from diffusers import UNet2DConditionModel

-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

@@ -30,7 +30,7 @@ class LoRAExt(ExtensionBase):
     @contextmanager
     def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
-        assert isinstance(lora_model, LoRAModelRaw)
+        assert isinstance(lora_model, ModelPatchRaw)
         ModelPatcher.apply_model_patch(
             model=unet,
             prefix="lora_unet_",
@@ -2,7 +2,7 @@ import pytest
 import torch

 from invokeai.backend.patches.layers.lora_layer import LoRALayer
-from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.model_patcher import ModelPatcher


@@ -37,7 +37,7 @@ def test_apply_lora_patches(device: str, num_layers: int):

     # Initialize num_layers LoRA models with weights of 0.5.
     lora_weight = 0.5
-    lora_models: list[tuple[LoRAModelRaw, float]] = []
+    lora_models: list[tuple[ModelPatchRaw, float]] = []
     for _ in range(num_layers):
         lora_layers = {
             "linear_layer_1": LoRALayer.from_state_dict_values(
@@ -47,7 +47,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
                 },
             )
         }
-        lora = LoRAModelRaw(lora_layers)
+        lora = ModelPatchRaw(lora_layers)
         lora_models.append((lora, lora_weight))

     orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
@@ -89,7 +89,7 @@ def test_apply_lora_patches_change_device():
             },
         )
     }
-    lora = LoRAModelRaw(lora_layers)
+    lora = ModelPatchRaw(lora_layers)

     orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

@@ -128,7 +128,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):

     # Initialize num_layers LoRA models with weights of 0.5.
     lora_weight = 0.5
-    lora_models: list[tuple[LoRAModelRaw, float]] = []
+    lora_models: list[tuple[ModelPatchRaw, float]] = []
     for _ in range(num_layers):
         lora_layers = {
             "linear_layer_1": LoRALayer.from_state_dict_values(
@@ -138,7 +138,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
                 },
             )
         }
-        lora = LoRAModelRaw(lora_layers)
+        lora = ModelPatchRaw(lora_layers)
         lora_models.append((lora, lora_weight))

     # Run inference before patching the model.
@@ -171,7 +171,7 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):

     # Initialize num_layers LoRA models with weights of 0.5.
     lora_weight = 0.5
-    lora_models: list[tuple[LoRAModelRaw, float]] = []
+    lora_models: list[tuple[ModelPatchRaw, float]] = []
     for _ in range(num_layers):
         lora_layers = {
             "linear_layer_1": LoRALayer.from_state_dict_values(
@@ -181,7 +181,7 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
                 },
             )
         }
-        lora = LoRAModelRaw(lora_layers)
+        lora = ModelPatchRaw(lora_layers)
         lora_models.append((lora, lora_weight))

     input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)