Bug fixes to get LoRA sidecar patching working for the first time.

Ryan Dick
2024-09-11 14:43:43 +00:00
parent bf9a661303
commit ef6507d9bb
5 changed files with 17 additions and 4 deletions

View File

@@ -192,6 +192,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         with (
             transformer_info.model_on_device() as (cached_weights, transformer),
             # Apply the LoRA after transformer has been moved to its target device for faster patching.
+            # LoraPatcher.apply_lora_sidecar_patches(
+            #     model=transformer,
+            #     patches=self._lora_iterator(context),
+            #     prefix="",
+            # ),
             LoraPatcher.apply_lora_patches(
                 model=transformer,
                 patches=self._lora_iterator(context),
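
Note: apply_lora_patches merges the LoRA deltas into the model weights in place (and restores them from cached_weights afterwards), while the still commented-out sidecar path leaves the original weights untouched and adds the LoRA contribution at forward time. A minimal sketch of the two ideas, using illustrative names (merge_patch, SidecarWrapper, lora_delta, scale) rather than the project's actual helpers:

import torch

def merge_patch(module: torch.nn.Linear, lora_delta: torch.Tensor, scale: float) -> None:
    # "Merge" style: fold the low-rank delta into the existing weight in place.
    module.weight.data += scale * lora_delta.to(device=module.weight.device, dtype=module.weight.dtype)

class SidecarWrapper(torch.nn.Module):
    # "Sidecar" style: keep the original module untouched and add the LoRA output at runtime.
    def __init__(self, orig: torch.nn.Module, lora: torch.nn.Module, scale: float):
        super().__init__()
        self._orig = orig
        self._lora = lora
        self._scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._orig(x) + self._scale * self._lora(x)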

View File

@@ -177,7 +177,9 @@ class LoraPatcher:
             # Initialize the LoRA sidecar layer.
             lora_sidecar_layer = LoraPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)

-            # TODO(ryand): Should we move the LoRA sidecar layer to the same device/dtype as the orig module?
+            # Move the LoRA sidecar layer to the same device/dtype as the orig module.
+            # TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
+            lora_sidecar_layer.to(device=module.weight.device, dtype=module.weight.dtype)

             if module_key in original_modules:
                 # The module has already been patched with a LoRASidecarModule. Append to it.
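
The TODO above concerns the order of the device transfer and the dtype cast. A rough sketch of the two orderings, assuming a CUDA target and stand-in modules (whether device-first is actually faster depends on the source/target dtypes and the hardware):

import torch

orig = torch.nn.Linear(16, 16).to(device="cuda", dtype=torch.float16)  # stand-in for the original module

# Option 1 (what the patch does): move and cast in a single call.
sidecar_a = torch.nn.Linear(16, 16)
sidecar_a.to(device=orig.weight.device, dtype=orig.weight.dtype)

# Option 2 (the TODO): move the full-precision weights to the device first, then cast there.
sidecar_b = torch.nn.Linear(16, 16)
sidecar_b.to(device=orig.weight.device)
sidecar_b.to(dtype=orig.weight.dtype)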

View File

@@ -71,10 +71,13 @@ class LoRAConvSidecarLayer(torch.nn.Module):
         )

         # Inject weight into the LoRA layer.
+        assert model._up.weight.shape == lora_layer.up.shape
+        assert model._down.weight.shape == lora_layer.down.shape
         model._up.weight.data = lora_layer.up
         model._down.weight.data = lora_layer.down
         if lora_layer.mid is not None:
             assert model._mid is not None
+            assert model._mid.weight.shape == lora_layer.mid.shape
             model._mid.weight.data = lora_layer.mid

         return model
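
The new asserts guard the weight injection: if a checkpoint's up/down tensors are transposed or sized for a different layer, the mismatch is caught immediately instead of silently producing a broken conv. A small illustration with made-up channel counts and rank:

import torch

# Hypothetical down/up projections for a rank-4 LoRA on a 3x3 conv with 64 in / 128 out channels.
down = torch.nn.Conv2d(64, 4, kernel_size=3, padding=1, bias=False)
up = torch.nn.Conv2d(4, 128, kernel_size=1, bias=False)

# Weights "loaded" from a LoRA file (random stand-ins with the matching shapes).
loaded_down = torch.randn(4, 64, 3, 3)
loaded_up = torch.randn(128, 4, 1, 1)

# The shape checks catch a mismatched or transposed checkpoint before it corrupts the layer.
assert down.weight.shape == loaded_down.shape
assert up.weight.shape == loaded_up.shape
down.weight.data = loaded_down
up.weight.data = loaded_up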

View File

@@ -52,10 +52,13 @@ class LoRALinearSidecarLayer(torch.nn.Module):
         # TODO(ryand): Are there cases where we need to reshape the weight matrices to match the conv layers?

         # Inject weight into the LoRA layer.
+        assert model._up.weight.shape == lora_layer.up.shape
+        assert model._down.weight.shape == lora_layer.down.shape
         model._up.weight.data = lora_layer.up
         model._down.weight.data = lora_layer.down
         if lora_layer.mid is not None:
             assert model._mid is not None
+            assert model._mid.weight.shape == lora_layer.mid.shape
             model._mid.weight.data = lora_layer.mid

         return model
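
For reference, a rough sketch of what a linear LoRA sidecar layer computes once the weights are injected (alpha/scale handling simplified; the rank, feature sizes, and names are illustrative):

import torch

in_features, out_features, rank = 64, 128, 8
down = torch.nn.Linear(in_features, rank, bias=False)   # holds lora_layer.down, shape (rank, in_features)
up = torch.nn.Linear(rank, out_features, bias=False)    # holds lora_layer.up, shape (out_features, rank)
scale = 1.0  # typically alpha / rank

x = torch.randn(2, in_features)
lora_out = up(down(x)) * scale  # the residual that the sidecar adds to the original module's output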

View File

@@ -10,8 +10,8 @@ class LoRASidecarModule(torch.nn.Module):
     def add_lora_layer(self, lora_layer: torch.nn.Module):
         self._lora_layers.append(lora_layer)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self._orig_module(x)
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        x = self._orig_module(input)
         for lora_layer in self._lora_layers:
-            x += lora_layer(x)
+            x += lora_layer(input)
         return x
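
This is the core fix of the commit: each sidecar layer's contribution must be computed from the original input, not from the running sum, otherwise every LoRA layer is fed an already-modified activation and the contributions compound. A small before/after sketch with toy modules and illustrative names:

import torch

orig = torch.nn.Linear(4, 4)
loras = [torch.nn.Linear(4, 4, bias=False) for _ in range(2)]
x = torch.randn(1, 4)

# Buggy version: each LoRA layer sees the partially accumulated output.
out_buggy = orig(x)
for lora in loras:
    out_buggy = out_buggy + lora(out_buggy)

# Fixed version (matches the new forward()): every LoRA layer sees the original input.
out_fixed = orig(x)
for lora in loras:
    out_fixed = out_fixed + lora(x)

print(torch.allclose(out_buggy, out_fixed))  # False in general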