mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-19 09:54:24 -05:00
adjusting back to hooks, forcing to be last in execution
This commit is contained in:
@@ -38,25 +38,19 @@ LORA_PREFIX_UNET = 'lora_unet'
|
||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||
|
||||
|
||||
def lora_forward(module, input_h, output):
|
||||
if len(loaded_loras) == 0:
|
||||
def lora_forward_hook(name):
|
||||
def lora_forward(module, input_h, output):
|
||||
if len(loaded_loras) == 0:
|
||||
return output
|
||||
|
||||
for lora in applied_loras.values():
|
||||
layer = lora.layers.get(name, None)
|
||||
if layer is None:
|
||||
continue
|
||||
output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale
|
||||
return output
|
||||
|
||||
lora_name = getattr(module, 'lora_name', None)
|
||||
for lora in applied_loras.values():
|
||||
layer = lora.layers.get(lora_name, None)
|
||||
if layer is None:
|
||||
continue
|
||||
output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale
|
||||
return output
|
||||
|
||||
|
||||
def lora_linear_forward(self, input_h):
|
||||
return lora_forward(self, input_h, torch.nn.Linear.forward_before_lora(self, input_h))
|
||||
|
||||
|
||||
def lora_conv2d_forward(self, input_h):
|
||||
return lora_forward(self, input_h, torch.nn.Conv2d.forward_before_lora(self, input_h))
|
||||
return lora_forward
|
||||
|
||||
|
||||
def load_lora(
|
||||
@@ -141,6 +135,7 @@ def load_lora(
|
||||
|
||||
class LoraManager:
|
||||
loras_to_load: dict[str, float]
|
||||
hooks: list[RemovableHandle]
|
||||
|
||||
def __init__(self, pipe):
|
||||
self.lora_path = Path(global_models_dir(), 'lora')
|
||||
@@ -149,26 +144,19 @@ class LoraManager:
|
||||
self.device = torch.device(choose_torch_device())
|
||||
self.dtype = pipe.unet.dtype
|
||||
self.loras_to_load = {}
|
||||
|
||||
if not hasattr(torch.nn.Linear, 'forward_before_lora'):
|
||||
torch.nn.Linear.forward_before_lora = torch.nn.Linear.forward
|
||||
|
||||
if not hasattr(torch.nn.Conv2d, 'forward_before_lora'):
|
||||
torch.nn.Conv2d.forward_before_lora = torch.nn.Conv2d.forward
|
||||
|
||||
torch.nn.Linear.forward = lora_linear_forward
|
||||
torch.nn.Conv2d.forward = lora_conv2d_forward
|
||||
self.hooks = []
|
||||
|
||||
def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
|
||||
mapping = {}
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
||||
layer_type = child_module.__class__.__name__
|
||||
if layer_type == "Linear" or (layer_type == "Conv2d" and child_module.kernel_size == (1, 1)):
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
mapping[lora_name] = child_module
|
||||
module.lora_name = lora_name
|
||||
self.apply_module_forward(child_module, lora_name)
|
||||
return mapping
|
||||
|
||||
self.text_modules = find_modules(
|
||||
@@ -182,6 +170,12 @@ class LoraManager:
|
||||
loaded_loras[name] = lora
|
||||
return lora
|
||||
|
||||
def apply_module_forward(self, module, lora_name):
|
||||
handle = RemovableHandle(module._forward_hooks)
|
||||
handle.id = 9000
|
||||
module._forward_hooks[handle.id] = lora_forward_hook(lora_name)
|
||||
self.hooks.append(handle)
|
||||
|
||||
def apply_lora_model(self, name, mult: float = 1.0):
|
||||
path = Path(self.lora_path, name)
|
||||
file = Path(path, "pytorch_lora_weights.bin")
|
||||
@@ -258,16 +252,14 @@ class LoraManager:
|
||||
clear_applied_loras()
|
||||
self.loras_to_load = {}
|
||||
|
||||
def clear_hooks(self):
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
|
||||
self.hooks.clear()
|
||||
|
||||
def __del__(self):
|
||||
# cleanup overrides
|
||||
if hasattr(torch.nn.Linear, 'forward_before_lora'):
|
||||
torch.nn.Linear.forward = torch.nn.Linear.forward_before_lora
|
||||
del torch.nn.Linear.forward_before_lora
|
||||
|
||||
if hasattr(torch.nn.Conv2d, 'forward_before_lora'):
|
||||
torch.nn.Conv2d.forward = torch.nn.Conv2d.forward_before_lora
|
||||
del torch.nn.Conv2d.forward_before_lora
|
||||
|
||||
self.clear_hooks()
|
||||
clear_applied_loras()
|
||||
clear_loaded_loras()
|
||||
del self.text_modules
|
||||
|
||||
Reference in New Issue
Block a user