Ensure that patches are on the correct device when used in sidecar wrappers.

This commit is contained in:
Ryan Dick
2024-12-13 21:24:32 +00:00
parent 606d58d7db
commit e7e3f7e144

View File

@@ -260,7 +260,9 @@ class LoRAPatcher:
wrapped_module = module_to_patch
# Move the LoRA layer to the same device/dtype as the orig module.
patch.to(device=wrapped_module.orig_module.weight.device, dtype=dtype)
first_param = next(module_to_patch.parameters())
device = first_param.device
patch.to(device=device, dtype=dtype)
# Add the patch to the sidecar wrapper.
wrapped_module.add_patch(patch, patch_weight)