Bug fixes to get LoRA sidecar patching working for the first time.
@@ -192,6 +192,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         with (
             transformer_info.model_on_device() as (cached_weights, transformer),
             # Apply the LoRA after transformer has been moved to its target device for faster patching.
+            # LoraPatcher.apply_lora_sidecar_patches(
+            #     model=transformer,
+            #     patches=self._lora_iterator(context),
+            #     prefix="",
+            # ),
             LoraPatcher.apply_lora_patches(
                 model=transformer,
                 patches=self._lora_iterator(context),
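For context, the `with (...)` form above treats the patcher as a context manager: the LoRA weight patches are applied once the transformer is already on its target device and are presumably reverted when the block exits. Below is a minimal, hypothetical sketch of that general pattern with generic weight deltas; `apply_weight_deltas` is an illustrative name, not InvokeAI's LoraPatcher API.

from contextlib import contextmanager
from typing import Dict, Iterator

import torch


@contextmanager
def apply_weight_deltas(model: torch.nn.Module, deltas: Dict[str, torch.Tensor]) -> Iterator[None]:
    """Temporarily add named weight deltas to a model, restoring the originals on exit."""
    params = dict(model.named_parameters())
    originals: Dict[str, torch.Tensor] = {}
    try:
        with torch.no_grad():
            for name, delta in deltas.items():
                param = params[name]
                originals[name] = param.detach().clone()
                # Patching after the model is already on its target device keeps this addition cheap.
                param.add_(delta.to(device=param.device, dtype=param.dtype))
        yield
    finally:
        with torch.no_grad():
            for name, original in originals.items():
                params[name].copy_(original)


# Usage: patch lifetime is tied to the block, so several such patchers can be
# stacked inside one `with (...)` statement as in the diff above.
linear = torch.nn.Linear(4, 4)
with apply_weight_deltas(linear, {"weight": torch.full((4, 4), 0.01)}):
    patched_out = linear(torch.randn(2, 4))  # runs with the delta applied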
@@ -177,7 +177,9 @@ class LoraPatcher:
             # Initialize the LoRA sidecar layer.
             lora_sidecar_layer = LoraPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)

-            # TODO(ryand): Should we move the LoRA sidecar layer to the same device/dtype as the orig module?
+            # Move the LoRA sidecar layer to the same device/dtype as the orig module.
+            # TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
+            lora_sidecar_layer.to(device=module.weight.device, dtype=module.weight.dtype)

             if module_key in original_modules:
                 # The module has already been patched with a LoRASidecarModule. Append to it.
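The added `.to(device=..., dtype=...)` call aligns the freshly built sidecar layer with the module it wraps. A small illustrative sketch of the two variants the comments mention, with `orig_module` and `sidecar` as stand-in names:

import torch

# Stand-ins for the wrapped module and a freshly initialized sidecar layer.
orig_module = torch.nn.Linear(8, 8).to(dtype=torch.float16)
sidecar = torch.nn.Linear(8, 8)  # created in default float32 on CPU

# Combined call, as in the patch: move and cast to match the original module.
sidecar.to(device=orig_module.weight.device, dtype=orig_module.weight.dtype)

# The TODO's alternative: transfer to the device first, then cast there.
# The comment speculates this could be faster; it is not asserted here.
sidecar.to(device=orig_module.weight.device)
sidecar.to(dtype=orig_module.weight.dtype)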
@@ -71,10 +71,13 @@ class LoRAConvSidecarLayer(torch.nn.Module):
         )

         # Inject weight into the LoRA layer.
+        assert model._up.weight.shape == lora_layer.up.shape
+        assert model._down.weight.shape == lora_layer.down.shape
         model._up.weight.data = lora_layer.up
         model._down.weight.data = lora_layer.down
         if lora_layer.mid is not None:
             assert model._mid is not None
+            assert model._mid.weight.shape == lora_layer.mid.shape
             model._mid.weight.data = lora_layer.mid

         return model
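The new asserts guard the direct `weight.data` injection: if a LoRA checkpoint's up/down/mid tensors do not match the conv layers that were built for them, failing loudly beats silently copying mismatched weights. A simplified toy sketch of the same pattern; `ToyConvSidecar` and `from_weights` are hypothetical names, not the real LoRAConvSidecarLayer.

import torch


class ToyConvSidecar(torch.nn.Module):
    """Toy conv LoRA sidecar: a down conv followed by a 1x1 up conv."""

    def __init__(self, in_channels: int, out_channels: int, rank: int, kernel_size: int):
        super().__init__()
        self._down = torch.nn.Conv2d(in_channels, rank, kernel_size, padding=kernel_size // 2, bias=False)
        self._up = torch.nn.Conv2d(rank, out_channels, kernel_size=1, bias=False)

    @classmethod
    def from_weights(cls, down: torch.Tensor, up: torch.Tensor) -> "ToyConvSidecar":
        rank, in_channels, kernel_size = down.shape[0], down.shape[1], down.shape[-1]
        out_channels = up.shape[0]
        model = cls(in_channels, out_channels, rank, kernel_size)
        # Fail loudly on mismatched checkpoints instead of silently copying bad weights.
        assert model._down.weight.shape == down.shape
        assert model._up.weight.shape == up.shape
        model._down.weight.data = down
        model._up.weight.data = up
        return model


# Example: a rank-4 LoRA for a 3x3 conv with 16 input and 32 output channels.
layer = ToyConvSidecar.from_weights(down=torch.randn(4, 16, 3, 3), up=torch.randn(32, 4, 1, 1))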
@@ -52,10 +52,13 @@ class LoRALinearSidecarLayer(torch.nn.Module):
         # TODO(ryand): Are there cases where we need to reshape the weight matrices to match the conv layers?

         # Inject weight into the LoRA layer.
+        assert model._up.weight.shape == lora_layer.up.shape
+        assert model._down.weight.shape == lora_layer.down.shape
         model._up.weight.data = lora_layer.up
         model._down.weight.data = lora_layer.down
         if lora_layer.mid is not None:
             assert model._mid is not None
+            assert model._mid.weight.shape == lora_layer.mid.shape
             model._mid.weight.data = lora_layer.mid

         return model
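For reference, a linear LoRA sidecar of this kind typically computes a scaled down-then-up projection and returns only the delta, which the wrapping sidecar module adds to the original layer's output. A hedged sketch under those assumptions; the class and parameter names follow common LoRA conventions rather than the exact InvokeAI implementation.

import torch


class ToyLinearSidecar(torch.nn.Module):
    """Toy linear LoRA sidecar: returns only the scaled low-rank delta."""

    def __init__(self, down: torch.Tensor, up: torch.Tensor, alpha: float):
        super().__init__()
        rank, in_features = down.shape
        out_features = up.shape[0]
        self._down = torch.nn.Linear(in_features, rank, bias=False)
        self._up = torch.nn.Linear(rank, out_features, bias=False)
        # Same shape checks as in the diff before injecting the weights.
        assert self._down.weight.shape == down.shape
        assert self._up.weight.shape == up.shape
        self._down.weight.data = down
        self._up.weight.data = up
        self._scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The wrapping sidecar module adds this delta to the original layer's output.
        return self._up(self._down(x)) * self._scale


# Example: rank-4 LoRA for a 64-feature linear layer.
delta = ToyLinearSidecar(down=torch.randn(4, 64), up=torch.randn(64, 4), alpha=4.0)(torch.randn(2, 64))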
@@ -10,8 +10,8 @@ class LoRASidecarModule(torch.nn.Module):
     def add_lora_layer(self, lora_layer: torch.nn.Module):
         self._lora_layers.append(lora_layer)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self._orig_module(x)
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        x = self._orig_module(input)
         for lora_layer in self._lora_layers:
-            x += lora_layer(x)
+            x += lora_layer(input)
         return x
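This forward change is the core bug fix: each LoRA layer must be evaluated on the original `input`, not on the running sum `x`, otherwise every layer after the first sees a partially patched output and the deltas compound. A tiny self-contained check of the fixed behavior, using plain `torch.nn.Linear` layers as stand-ins for the real sidecar layers:

import torch


class ToySidecar(torch.nn.Module):
    """Stand-in for the sidecar module above, with Linear layers as LoRA deltas."""

    def __init__(self, orig: torch.nn.Module):
        super().__init__()
        self._orig_module = orig
        self._lora_layers = torch.nn.ModuleList()

    def add_lora_layer(self, lora_layer: torch.nn.Module) -> None:
        self._lora_layers.append(lora_layer)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = self._orig_module(input)
        for lora_layer in self._lora_layers:
            # Pre-fix code did `x += lora_layer(x)`, so each LoRA saw the partially
            # patched output and earlier deltas were fed into later layers.
            x += lora_layer(input)
        return x


sidecar = ToySidecar(torch.nn.Linear(4, 4))
sidecar.add_lora_layer(torch.nn.Linear(4, 4, bias=False))
sidecar.add_lora_layer(torch.nn.Linear(4, 4, bias=False))
out = sidecar(torch.randn(2, 4))  # orig(input) + lora1(input) + lora2(input)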